feat(library/blast/forward/ematch): use type class resolution to infer missing arguments

This commit is contained in:
Leonardo de Moura 2015-11-30 08:44:42 -07:00
parent 37ad850455
commit 4d63a27f13
3 changed files with 33 additions and 12 deletions

View file

@ -243,7 +243,19 @@ struct ematch_fn {
}
void instantiate(hi_lemma const & lemma) {
std::cout << "FOUND\n";
list<bool> const * it = &lemma.m_is_inst_implicit;
for (expr const & mvar : lemma.m_mvars) {
if (!m_ctx->get_assignment(mvar)) {
if (!head(*it))
return; // fail, argument is not instance implicit
auto new_val = m_ctx->mk_class_instance(m_ctx->infer(mvar));
if (!new_val)
return; // fail, instance could not be generated
if (!m_ctx->assign(mvar, *new_val))
return; // fail, type error
}
it = &tail(*it);
}
for (expr const & mvar : lemma.m_mvars) {
diagnostic(env(), ios()) << "[" << mvar << "] := " << ppb(m_ctx->instantiate_uvars_mvars(mvar)) << "\n";
}

View file

@ -212,6 +212,7 @@ struct hi_config {
hi_lemma const & l = e.m_lemma;
s << l.m_num_uvars << l.m_num_mvars << l.m_priority << l.m_prop << l.m_proof;
write_list(s, l.m_mvars);
write_list(s, l.m_is_inst_implicit);
write_list(s, l.m_multi_patterns);
}
}
@ -222,8 +223,9 @@ struct hi_config {
if (!e.m_no_pattern) {
hi_lemma & l = e.m_lemma;
d >> l.m_num_uvars >> l.m_num_mvars >> l.m_priority >> l.m_prop >> l.m_proof;
l.m_mvars = read_list<expr>(d);
l.m_multi_patterns = read_list<multi_pattern>(d);
l.m_mvars = read_list<expr>(d);
l.m_is_inst_implicit = read_list<bool>(d);
l.m_multi_patterns = read_list<multi_pattern>(d);
}
return e;
}
@ -259,7 +261,9 @@ static bool is_higher_order(tmp_type_context & ctx, expr const & e) {
and store in trackable and residue the subsets of these meta-variables as
described in the beginning of this file. Then returns B (instantiated with the new meta-variables) */
expr extract_trackable(tmp_type_context & ctx, expr const & type,
buffer<expr> & mvars, idx_metavar_set & trackable, idx_metavar_set & residue) {
buffer<expr> & mvars,
buffer<bool> & inst_implicit_flags,
idx_metavar_set & trackable, idx_metavar_set & residue) {
// 1. Create mvars and initialize trackable and residue sets
expr it = type;
while (true) {
@ -274,6 +278,7 @@ expr extract_trackable(tmp_type_context & ctx, expr const & type,
lean_assert(is_idx_metavar(new_mvar));
mvars.push_back(new_mvar);
bool is_inst_implicit = binding_info(it).is_inst_implicit();
inst_implicit_flags.push_back(is_inst_implicit);
bool is_prop = ctx.is_prop(binding_domain(it));
if (!is_inst_implicit) {
unsigned midx = to_meta_idx(new_mvar);
@ -586,7 +591,9 @@ struct mk_hi_lemma_fn {
hi_lemma operator()() {
expr H_type = m_ctx.infer(m_H);
expr B = extract_trackable(m_ctx, H_type, m_mvars, m_trackable, m_residue);
buffer<bool> inst_implicit_flags;
expr B = extract_trackable(m_ctx, H_type, m_mvars, inst_implicit_flags, m_trackable, m_residue);
lean_assert(m_mvars.size() == inst_implicit_flags.size());
buffer<expr> subst;
buffer<expr> residue_locals;
expr proof = mk_proof(residue_locals, subst);
@ -619,13 +626,14 @@ struct mk_hi_lemma_fn {
"(solution: provide pattern hints using the notation '(: t :)' )");
}
hi_lemma r;
r.m_num_uvars = m_num_uvars;
r.m_num_mvars = m_mvars.size();
r.m_priority = m_priority;
r.m_multi_patterns = mps;
r.m_mvars = to_list(m_mvars);
r.m_prop = m_ctx.infer(proof);
r.m_proof = proof;
r.m_num_uvars = m_num_uvars;
r.m_num_mvars = m_mvars.size();
r.m_priority = m_priority;
r.m_multi_patterns = mps;
r.m_mvars = to_list(m_mvars);
r.m_is_inst_implicit = to_list(inst_implicit_flags);
r.m_prop = m_ctx.infer(proof);
r.m_proof = proof;
return r;
}
};

View file

@ -40,6 +40,7 @@ struct hi_lemma {
unsigned m_num_mvars;
unsigned m_priority;
list<multi_pattern> m_multi_patterns;
list<bool> m_is_inst_implicit;
list<expr> m_mvars;
expr m_prop;
expr m_proof;