feat(library/blast/forward/ematch): basic support for heq classes

This commit is contained in:
Daniel Selsam 2016-01-15 20:34:28 -08:00
parent f2eef7aa1b
commit a883101a3b
4 changed files with 75 additions and 30 deletions

View file

@ -1855,6 +1855,10 @@ expr congruence_closure::get_next(name const & R, expr const & e) const {
} }
} }
bool congruence_closure::eq_class_heterogeneous(expr const & e) const {
return has_heq_proofs(get_root(get_eq_name(), e));
}
unsigned congruence_closure::get_mt(name const & R, expr const & e) const { unsigned congruence_closure::get_mt(name const & R, expr const & e) const {
if (auto n = m_entries.find(eqc_key(R, e))) { if (auto n = m_entries.find(eqc_key(R, e))) {
return n->m_mt; return n->m_mt;

View file

@ -252,6 +252,8 @@ public:
expr get_root(name const & R, expr const & e) const; expr get_root(name const & R, expr const & e) const;
expr get_next(name const & R, expr const & e) const; expr get_next(name const & R, expr const & e) const;
bool eq_class_heterogeneous(expr const & e) const;
/** \brief Mark the root of each equivalence class as an "abstract value" /** \brief Mark the root of each equivalence class as an "abstract value"
After this method is invoked, proof production is disabled. Moreover, After this method is invoked, proof production is disabled. Moreover,
merging two different partitions will trigger an inconsistency. */ merging two different partitions will trigger an inconsistency. */

View file

@ -252,7 +252,7 @@ struct ematch_fn {
blast_tmp_type_context m_ctx; blast_tmp_type_context m_ctx;
congruence_closure & m_cc; congruence_closure & m_cc;
enum frame_kind { DefEqOnly, Match, MatchSS /* match subsingleton */, Continue }; enum frame_kind { DefEqOnly, EqvOnly, Match, MatchSS /* match subsingleton */, Continue };
typedef std::tuple<name, frame_kind, expr, expr> entry; typedef std::tuple<name, frame_kind, expr, expr> entry;
typedef list<entry> state; typedef list<entry> state;
@ -347,30 +347,64 @@ struct ematch_fn {
return true; return true;
} }
/* If the eq equivalence class of `t` is heterogeneous, then even though
`t` may fail to match because of its type, another element that is
heterogeneously equal to `t` but that has a different type may match
successfully. */
bool match_leaf(name const & R, expr const & p, expr const & t) {
if (R == get_eq_name() && m_cc.eq_class_heterogeneous(t)) {
lean_trace_debug_ematch(tout() << "match_leaf with heq\n";);
buffer<state> new_states;
expr it = t;
do {
expr_set types_seen;
expr it_type = m_ctx->infer(it);
if (types_seen.find(it_type)) continue;
types_seen.insert(it_type);
new_states.emplace_back(cons(entry(get_eq_name(), EqvOnly, p, it), m_state));
it = m_cc.get_next(R, it);
} while (it != t);
push_states(new_states);
return true;
} else {
lean_trace_debug_ematch(tout() << "match_leaf no heq\n";);
return is_eqv(R, p, t);
}
}
void push_states(buffer<state> & new_states) {
if (new_states.size() == 1) {
lean_trace_debug_ematch(tout() << "(only one match)\n";);
m_state = new_states[0];
} else {
lean_trace_debug_ematch(tout() << "# matches: " << new_states.size() << "\n";);
m_state = new_states.back();
new_states.pop_back();
choice c = to_list(new_states);
m_choice_stack.push_back(c);
m_ctx->push();
}
}
bool process_match(name const & R, expr const & p, expr const & t) { bool process_match(name const & R, expr const & p, expr const & t) {
lean_trace_debug_ematch(tout() << "try process_match: " lean_trace_debug_ematch(tout() << "try process_match: "
<< ppb(p) << " <=?=> " << ppb(t) << "\n";); << ppb(p) << " <=?=> " << ppb(t) << "\n";);
if (!is_app(p)) { if (!is_app(p)) {
bool success = is_eqv(R, p, t); bool success = match_leaf(R, p, t);
lean_trace_debug_ematch(
expr new_p = m_ctx->instantiate_uvars_mvars(p);
expr new_p_type = m_ctx->instantiate_uvars_mvars(m_ctx->infer(p));
expr t_type = m_ctx->infer(t);
tout() << "is_eqv " << ppb(new_p) << " : " << ppb(new_p_type)
<< " <- " << ppb(t) << " : " << ppb(t_type) << " ... " << (success ? "succeeded" : "failed") << "\n";);
return success; return success;
} }
buffer<expr> p_args; buffer<expr> p_args;
expr const & fn = get_app_args(p, p_args); expr const & fn = get_app_args(p, p_args);
if (m_ctx->is_mvar(fn)) if (m_ctx->is_mvar(fn)) {
return is_eqv(R, p, t); return match_leaf(R, p, t);
}
buffer<expr> candidates; buffer<expr> candidates;
expr t_fn; expr t_fn;
expr it = t; expr it = t;
do { do {
expr const & it_fn = get_app_fn(it); expr const & it_fn = get_app_fn(it);
bool ok = false; bool ok = false;
if (m_cc.is_congr_root(R, it) && m_ctx->is_def_eq(it_fn, fn) && if ((m_cc.is_congr_root(R, it) || m_cc.eq_class_heterogeneous(it)) && m_ctx->is_def_eq(it_fn, fn) &&
get_app_num_args(it) == p_args.size()) { get_app_num_args(it) == p_args.size()) {
t_fn = it_fn; t_fn = it_fn;
ok = true; ok = true;
@ -391,19 +425,8 @@ struct ematch_fn {
new_states.push_back(new_state); new_states.push_back(new_state);
} }
} }
if (new_states.size() == 1) { push_states(new_states);
lean_trace_debug_ematch(tout() << "(only one match)\n";); return true;
m_state = new_states[0];
return true;
} else {
lean_trace_debug_ematch(tout() << "# matches: " << new_states.size() << "\n";);
m_state = new_states.back();
new_states.pop_back();
choice c = to_list(new_states);
m_choice_stack.push_back(c);
m_ctx->push();
return true;
}
} }
bool process_continue(name const & R, expr const & p) { bool process_continue(name const & R, expr const & p) {
@ -413,7 +436,7 @@ struct ematch_fn {
buffer<state> new_states; buffer<state> new_states;
if (auto s = m_inst_ext.get_apps().find(head_index(f))) { if (auto s = m_inst_ext.get_apps().find(head_index(f))) {
s->for_each([&](expr const & t) { s->for_each([&](expr const & t) {
if (m_cc.is_congr_root(R, t)) { if (m_cc.is_congr_root(R, t) || m_cc.eq_class_heterogeneous(t)) {
state new_state = m_state; state new_state = m_state;
if (match_args(new_state, R, p_args, t)) if (match_args(new_state, R, p_args, t))
new_states.push_back(new_state); new_states.push_back(new_state);
@ -469,13 +492,29 @@ struct ematch_fn {
std::tie(R, kind, p, t) = head(m_state); std::tie(R, kind, p, t) = head(m_state);
m_state = tail(m_state); m_state = tail(m_state);
// diagnostic(env(), ios()) << ">> " << R << ", " << ppb(p) << " =?= " << ppb(t) << "\n"; // diagnostic(env(), ios()) << ">> " << R << ", " << ppb(p) << " =?= " << ppb(t) << "\n";
bool success;
switch (kind) { switch (kind) {
case DefEqOnly: case DefEqOnly:
lean_trace_debug_ematch(tout() << "must be def-eq: " success = m_ctx->is_def_eq(p, t);
<< ppb(p) << " <=?=> " << ppb(t) << "\n";); lean_trace_debug_ematch(
return m_ctx->is_def_eq(p, t); expr new_p = m_ctx->instantiate_uvars_mvars(p);
expr new_p_type = m_ctx->instantiate_uvars_mvars(m_ctx->infer(p));
expr t_type = m_ctx->infer(t);
tout() << "must be def-eq: " << ppb(new_p) << " : " << ppb(new_p_type)
<< " =?= " << ppb(t) << " : " << ppb(t_type)
<< " ... " << (success ? "succeeded" : "failed") << "\n";);
return success;
case Match: case Match:
return process_match(R, p, t); return process_match(R, p, t);
case EqvOnly:
success = is_eqv(R, p, t);
lean_trace_debug_ematch(
expr new_p = m_ctx->instantiate_uvars_mvars(p);
expr new_p_type = m_ctx->instantiate_uvars_mvars(m_ctx->infer(p));
expr t_type = m_ctx->infer(t);
tout() << "must be eqv: " << ppb(new_p) << " : " << ppb(new_p_type) << " =?= "
<< ppb(t) << " : " << ppb(t_type) << " ... " << (success ? "succeeded" : "failed") << "\n";);
return success;
case MatchSS: case MatchSS:
return process_matchss(p, t); return process_matchss(p, t);
case Continue: case Continue:
@ -555,7 +594,7 @@ struct ematch_fn {
unsigned gmt = m_cc.get_gmt(); unsigned gmt = m_cc.get_gmt();
if (auto s = m_inst_ext.get_apps().find(head_index(f))) { if (auto s = m_inst_ext.get_apps().find(head_index(f))) {
s->for_each([&](expr const & t) { s->for_each([&](expr const & t) {
if (m_cc.is_congr_root(R, t) && (!filter || m_cc.get_mt(R, t) == gmt)) { if ((m_cc.is_congr_root(R, t) || m_cc.eq_class_heterogeneous(t)) && (!filter || m_cc.get_mt(R, t) == gmt)) {
lean_trace_debug_ematch(tout() << "ematch " << ppb(get_app_fn(lemma.m_proof)) << " [using] " << ppb(t) << "\n";); lean_trace_debug_ematch(tout() << "ematch " << ppb(get_app_fn(lemma.m_proof)) << " [using] " << ppb(t) << "\n";);
m_ctx->clear(); m_ctx->clear();
m_ctx->set_next_uvar_idx(lemma.m_num_uvars); m_ctx->set_next_uvar_idx(lemma.m_num_uvars);

View file

@ -91,6 +91,6 @@ lemma vplus.def2 [simp] {n : } (v₁ v₂ : vector n) (a₁ a₂ : ) :
lemma vplus_weird {n₁ n₂ : } (v₁ : vector n₁) (v₂ : vector n₂) (a b : ) : lemma vplus_weird {n₁ n₂ : } (v₁ : vector n₁) (v₂ : vector n₂) (a b : ) :
vplus (a :: append v₁ v₂) ⟨b :: append v₂ v₁⟩ == (a + b) :: vplus (append v₁ v₂) ⟨append v₂ v₁⟩ := vplus (a :: append v₁ v₂) ⟨b :: append v₂ v₁⟩ == (a + b) :: vplus (append v₁ v₂) ⟨append v₂ v₁⟩ :=
sorry -- TODO need to traverse equivalence class when matching against a meta-variable by inst_simp
end vector end vector