From 8d8e43abfdc989b29038330d59118701bcfec5d7 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 22 Oct 2015 15:11:56 -0700 Subject: [PATCH] fix(library/class_instance_resolution): transitive instances in the new type class resolution procedure --- src/library/class_instance_resolution.cpp | 76 ++++++++++++++--------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/src/library/class_instance_resolution.cpp b/src/library/class_instance_resolution.cpp index 8de4402fc..c3344264b 100644 --- a/src/library/class_instance_resolution.cpp +++ b/src/library/class_instance_resolution.cpp @@ -112,10 +112,25 @@ struct cienv { unsigned m_next_uvar_idx; unsigned m_next_mvar_idx; + struct stack_entry { + // We only use transitive instances when we can solve the problem in a single step. + // That is, the transitive instance does not have any instance argument, OR + // it uses local instances to fill them. + // We accomplish that by not considering global instances when solving + // transitive instance subproblems. + expr m_mvar; + unsigned m_depth; + bool m_trans_inst_subproblem; + stack_entry(expr const & m, unsigned d, bool s = false): + m_mvar(m), m_depth(d), m_trans_inst_subproblem(s) {} + }; + struct state { - list> m_stack; // stack of meta-variables that need to be synthesized; - uassignment m_uassignment; - eassignment m_eassignment; + bool m_trans_inst_subproblem; + list m_stack; // stack of meta-variables that need to be synthesized; + uassignment m_uassignment; + eassignment m_eassignment; + state():m_trans_inst_subproblem(false) {} }; state m_state; // active state @@ -609,9 +624,12 @@ struct cienv { out << mvar << " : " << instantiate_uvars_mvars(mvar_type) << " := " << r << endl; } - bool try_instance(unsigned depth, expr const & mvar, expr const & inst, expr const & inst_type) { + // Try to synthesize e.m_mvar using instance inst : inst_type. + // trans_inst is true if inst is a transitive instance. + bool try_instance(stack_entry const & e, expr const & inst, expr const & inst_type, bool trans_inst) { try { buffer locals; + expr const & mvar = e.m_mvar; expr mvar_type = mlocal_type(mvar); while (true) { mvar_type = whnf(mvar_type); @@ -637,7 +655,7 @@ struct cienv { type = instantiate(binding_body(type), new_arg); } if (m_trace_instances) { - trace(depth, mk_app(mvar, locals), mvar_type, r); + trace(e.m_depth, mk_app(mvar, locals), mvar_type, r); } if (!is_def_eq(mvar_type, type)) { return false; @@ -656,7 +674,7 @@ struct cienv { unsigned i = new_inst_mvars.size(); while (i > 0) { --i; - m_state.m_stack = cons(mk_pair(depth+1, new_inst_mvars[i]), m_state.m_stack); + m_state.m_stack = cons(stack_entry(new_inst_mvars[i], e.m_depth+1, trans_inst), m_state.m_stack); } return true; } catch (exception &) { @@ -664,7 +682,7 @@ struct cienv { } } - bool try_instance(unsigned depth, expr const & mvar, name const & inst_name) { + bool try_instance(stack_entry const & e, name const & inst_name, bool trans_inst) { if (auto decl = m_env.find(inst_name)) { buffer ls_buffer; unsigned num_univ_ps = decl->get_num_univ_params(); @@ -673,7 +691,7 @@ struct cienv { levels ls = to_list(ls_buffer.begin(), ls_buffer.end()); expr inst_cnst = mk_constant(inst_name, ls); expr inst_type = instantiate_type_univ_params(*decl, ls); - return try_instance(depth, mvar, inst_cnst, inst_type); + return try_instance(e, inst_cnst, inst_type, trans_inst); } else { return false; } @@ -725,47 +743,49 @@ struct cienv { return true; } - bool process_next_alt_core(unsigned depth, expr const & mvar, list & insts) { + bool process_next_alt_core(stack_entry const & e, list & insts) { while (!empty(insts)) { expr inst = head(insts); insts = tail(insts); expr inst_type = infer_type(inst); - if (try_instance(depth, mvar, inst, inst_type)) + bool trans_inst = false; + if (try_instance(e, inst, inst_type, trans_inst)) return true; } return false; } - bool process_next_alt_core(unsigned depth, expr const & mvar, list & inst_names) { + bool process_next_alt_core(stack_entry const & e, list & inst_names, bool trans_inst) { while (!empty(inst_names)) { name inst_name = head(inst_names); inst_names = tail(inst_names); - if (try_instance(depth, mvar, inst_name)) + if (try_instance(e, inst_name, trans_inst)) return true; } return false; } - bool process_next_alt(unsigned depth, expr const & mvar) { + bool process_next_alt(stack_entry const & e) { lean_assert(!m_choices.empty()); choice & c = m_choices.back(); - if (process_next_alt_core(depth, mvar, c.m_local_instances)) - return true; - if (process_next_alt_core(depth, mvar, c.m_trans_instances)) - return true; - if (process_next_alt_core(depth, mvar, c.m_instances)) + if (process_next_alt_core(e, c.m_local_instances)) return true; + if (!e.m_trans_inst_subproblem) { + if (process_next_alt_core(e, c.m_trans_instances, true)) + return true; + if (process_next_alt_core(e, c.m_instances, false)) + return true; + } return false; } bool process_next_mvar() { lean_assert(!is_done()); - unsigned depth = head(m_state.m_stack).first; - expr mvar = head(m_state.m_stack).second; - if (!mk_choice_point(mvar)) + stack_entry e = head(m_state.m_stack); + if (!mk_choice_point(e.m_mvar)) return false; m_state.m_stack = tail(m_state.m_stack); - return process_next_alt(depth, mvar); + return process_next_alt(e); } bool backtrack() { @@ -776,10 +796,9 @@ struct cienv { if (m_choices.empty()) return false; m_state = m_choices.back().m_state; - unsigned depth = head(m_state.m_stack).first; - expr mvar = head(m_state.m_stack).second; + stack_entry e = head(m_state.m_stack); m_state.m_stack = tail(m_state.m_stack); - if (process_next_alt(depth, mvar)) + if (process_next_alt(e)) return true; } } @@ -798,10 +817,9 @@ struct cienv { if (m_choices.empty()) return none_expr(); m_state = m_choices.back().m_state; - unsigned depth = head(m_state.m_stack).first; - expr mvar = head(m_state.m_stack).second; + stack_entry e = head(m_state.m_stack); m_state.m_stack = tail(m_state.m_stack); - if (process_next_alt(depth, mvar)) + if (process_next_alt(e)) return search(); else if (backtrack()) return search(); @@ -812,7 +830,7 @@ struct cienv { void init_search(expr const & type) { m_state = state(); m_main_mvar = mk_mvar(type); - m_state.m_stack = cons(mk_pair(0u, m_main_mvar), m_state.m_stack); + m_state.m_stack = cons(stack_entry(m_main_mvar, 0), m_state.m_stack); m_choices.clear(); }