fix(library/class_instance_resolution): transitive instances in the new type class resolution procedure

This commit is contained in:
Leonardo de Moura 2015-10-22 15:11:56 -07:00
parent 5468076400
commit 8d8e43abfd

View file

@ -112,10 +112,25 @@ struct cienv {
unsigned m_next_uvar_idx; unsigned m_next_uvar_idx;
unsigned m_next_mvar_idx; unsigned m_next_mvar_idx;
struct stack_entry {
// We only use transitive instances when we can solve the problem in a single step.
// That is, the transitive instance does not have any instance argument, OR
// it uses local instances to fill them.
// We accomplish that by not considering global instances when solving
// transitive instance subproblems.
expr m_mvar;
unsigned m_depth;
bool m_trans_inst_subproblem;
stack_entry(expr const & m, unsigned d, bool s = false):
m_mvar(m), m_depth(d), m_trans_inst_subproblem(s) {}
};
struct state { struct state {
list<pair<unsigned, expr>> m_stack; // stack of meta-variables that need to be synthesized; bool m_trans_inst_subproblem;
uassignment m_uassignment; list<stack_entry> m_stack; // stack of meta-variables that need to be synthesized;
eassignment m_eassignment; uassignment m_uassignment;
eassignment m_eassignment;
state():m_trans_inst_subproblem(false) {}
}; };
state m_state; // active state state m_state; // active state
@ -609,9 +624,12 @@ struct cienv {
out << mvar << " : " << instantiate_uvars_mvars(mvar_type) << " := " << r << endl; out << mvar << " : " << instantiate_uvars_mvars(mvar_type) << " := " << r << endl;
} }
bool try_instance(unsigned depth, expr const & mvar, expr const & inst, expr const & inst_type) { // Try to synthesize e.m_mvar using instance inst : inst_type.
// trans_inst is true if inst is a transitive instance.
bool try_instance(stack_entry const & e, expr const & inst, expr const & inst_type, bool trans_inst) {
try { try {
buffer<expr> locals; buffer<expr> locals;
expr const & mvar = e.m_mvar;
expr mvar_type = mlocal_type(mvar); expr mvar_type = mlocal_type(mvar);
while (true) { while (true) {
mvar_type = whnf(mvar_type); mvar_type = whnf(mvar_type);
@ -637,7 +655,7 @@ struct cienv {
type = instantiate(binding_body(type), new_arg); type = instantiate(binding_body(type), new_arg);
} }
if (m_trace_instances) { if (m_trace_instances) {
trace(depth, mk_app(mvar, locals), mvar_type, r); trace(e.m_depth, mk_app(mvar, locals), mvar_type, r);
} }
if (!is_def_eq(mvar_type, type)) { if (!is_def_eq(mvar_type, type)) {
return false; return false;
@ -656,7 +674,7 @@ struct cienv {
unsigned i = new_inst_mvars.size(); unsigned i = new_inst_mvars.size();
while (i > 0) { while (i > 0) {
--i; --i;
m_state.m_stack = cons(mk_pair(depth+1, new_inst_mvars[i]), m_state.m_stack); m_state.m_stack = cons(stack_entry(new_inst_mvars[i], e.m_depth+1, trans_inst), m_state.m_stack);
} }
return true; return true;
} catch (exception &) { } catch (exception &) {
@ -664,7 +682,7 @@ struct cienv {
} }
} }
bool try_instance(unsigned depth, expr const & mvar, name const & inst_name) { bool try_instance(stack_entry const & e, name const & inst_name, bool trans_inst) {
if (auto decl = m_env.find(inst_name)) { if (auto decl = m_env.find(inst_name)) {
buffer<level> ls_buffer; buffer<level> ls_buffer;
unsigned num_univ_ps = decl->get_num_univ_params(); unsigned num_univ_ps = decl->get_num_univ_params();
@ -673,7 +691,7 @@ struct cienv {
levels ls = to_list(ls_buffer.begin(), ls_buffer.end()); levels ls = to_list(ls_buffer.begin(), ls_buffer.end());
expr inst_cnst = mk_constant(inst_name, ls); expr inst_cnst = mk_constant(inst_name, ls);
expr inst_type = instantiate_type_univ_params(*decl, ls); expr inst_type = instantiate_type_univ_params(*decl, ls);
return try_instance(depth, mvar, inst_cnst, inst_type); return try_instance(e, inst_cnst, inst_type, trans_inst);
} else { } else {
return false; return false;
} }
@ -725,47 +743,49 @@ struct cienv {
return true; return true;
} }
bool process_next_alt_core(unsigned depth, expr const & mvar, list<expr> & insts) { bool process_next_alt_core(stack_entry const & e, list<expr> & insts) {
while (!empty(insts)) { while (!empty(insts)) {
expr inst = head(insts); expr inst = head(insts);
insts = tail(insts); insts = tail(insts);
expr inst_type = infer_type(inst); expr inst_type = infer_type(inst);
if (try_instance(depth, mvar, inst, inst_type)) bool trans_inst = false;
if (try_instance(e, inst, inst_type, trans_inst))
return true; return true;
} }
return false; return false;
} }
bool process_next_alt_core(unsigned depth, expr const & mvar, list<name> & inst_names) { bool process_next_alt_core(stack_entry const & e, list<name> & inst_names, bool trans_inst) {
while (!empty(inst_names)) { while (!empty(inst_names)) {
name inst_name = head(inst_names); name inst_name = head(inst_names);
inst_names = tail(inst_names); inst_names = tail(inst_names);
if (try_instance(depth, mvar, inst_name)) if (try_instance(e, inst_name, trans_inst))
return true; return true;
} }
return false; return false;
} }
bool process_next_alt(unsigned depth, expr const & mvar) { bool process_next_alt(stack_entry const & e) {
lean_assert(!m_choices.empty()); lean_assert(!m_choices.empty());
choice & c = m_choices.back(); choice & c = m_choices.back();
if (process_next_alt_core(depth, mvar, c.m_local_instances)) if (process_next_alt_core(e, c.m_local_instances))
return true;
if (process_next_alt_core(depth, mvar, c.m_trans_instances))
return true;
if (process_next_alt_core(depth, mvar, c.m_instances))
return true; return true;
if (!e.m_trans_inst_subproblem) {
if (process_next_alt_core(e, c.m_trans_instances, true))
return true;
if (process_next_alt_core(e, c.m_instances, false))
return true;
}
return false; return false;
} }
bool process_next_mvar() { bool process_next_mvar() {
lean_assert(!is_done()); lean_assert(!is_done());
unsigned depth = head(m_state.m_stack).first; stack_entry e = head(m_state.m_stack);
expr mvar = head(m_state.m_stack).second; if (!mk_choice_point(e.m_mvar))
if (!mk_choice_point(mvar))
return false; return false;
m_state.m_stack = tail(m_state.m_stack); m_state.m_stack = tail(m_state.m_stack);
return process_next_alt(depth, mvar); return process_next_alt(e);
} }
bool backtrack() { bool backtrack() {
@ -776,10 +796,9 @@ struct cienv {
if (m_choices.empty()) if (m_choices.empty())
return false; return false;
m_state = m_choices.back().m_state; m_state = m_choices.back().m_state;
unsigned depth = head(m_state.m_stack).first; stack_entry e = head(m_state.m_stack);
expr mvar = head(m_state.m_stack).second;
m_state.m_stack = tail(m_state.m_stack); m_state.m_stack = tail(m_state.m_stack);
if (process_next_alt(depth, mvar)) if (process_next_alt(e))
return true; return true;
} }
} }
@ -798,10 +817,9 @@ struct cienv {
if (m_choices.empty()) if (m_choices.empty())
return none_expr(); return none_expr();
m_state = m_choices.back().m_state; m_state = m_choices.back().m_state;
unsigned depth = head(m_state.m_stack).first; stack_entry e = head(m_state.m_stack);
expr mvar = head(m_state.m_stack).second;
m_state.m_stack = tail(m_state.m_stack); m_state.m_stack = tail(m_state.m_stack);
if (process_next_alt(depth, mvar)) if (process_next_alt(e))
return search(); return search();
else if (backtrack()) else if (backtrack())
return search(); return search();
@ -812,7 +830,7 @@ struct cienv {
void init_search(expr const & type) { void init_search(expr const & type) {
m_state = state(); m_state = state();
m_main_mvar = mk_mvar(type); m_main_mvar = mk_mvar(type);
m_state.m_stack = cons(mk_pair(0u, m_main_mvar), m_state.m_stack); m_state.m_stack = cons(stack_entry(m_main_mvar, 0), m_state.m_stack);
m_choices.clear(); m_choices.clear();
} }