diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 598090a99..09cd648e6 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -1855,6 +1855,10 @@ expr congruence_closure::get_next(name const & R, expr const & e) const { } } +bool congruence_closure::eq_class_heterogeneous(expr const & e) const { + return has_heq_proofs(get_root(get_eq_name(), e)); +} + unsigned congruence_closure::get_mt(name const & R, expr const & e) const { if (auto n = m_entries.find(eqc_key(R, e))) { return n->m_mt; diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index b4e30894f..68a1cffef 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -252,6 +252,8 @@ public: expr get_root(name const & R, expr const & e) const; expr get_next(name const & R, expr const & e) const; + bool eq_class_heterogeneous(expr const & e) const; + /** \brief Mark the root of each equivalence class as an "abstract value" After this method is invoked, proof production is disabled. Moreover, merging two different partitions will trigger an inconsistency. */ diff --git a/src/library/blast/forward/ematch.cpp b/src/library/blast/forward/ematch.cpp index d575c8ee1..ebda5ecc8 100644 --- a/src/library/blast/forward/ematch.cpp +++ b/src/library/blast/forward/ematch.cpp @@ -252,7 +252,7 @@ struct ematch_fn { blast_tmp_type_context m_ctx; congruence_closure & m_cc; - enum frame_kind { DefEqOnly, Match, MatchSS /* match subsingleton */, Continue }; + enum frame_kind { DefEqOnly, EqvOnly, Match, MatchSS /* match subsingleton */, Continue }; typedef std::tuple entry; typedef list state; @@ -347,30 +347,64 @@ struct ematch_fn { return true; } + /* If the eq equivalence class of `t` is heterogeneous, then even though + `t` may fail to match because of its type, another element that is + heterogeneously equal to `t` but that has a different type may match + successfully. */ + bool match_leaf(name const & R, expr const & p, expr const & t) { + if (R == get_eq_name() && m_cc.eq_class_heterogeneous(t)) { + lean_trace_debug_ematch(tout() << "match_leaf with heq\n";); + buffer new_states; + expr it = t; + do { + expr_set types_seen; + expr it_type = m_ctx->infer(it); + if (types_seen.find(it_type)) continue; + types_seen.insert(it_type); + new_states.emplace_back(cons(entry(get_eq_name(), EqvOnly, p, it), m_state)); + it = m_cc.get_next(R, it); + } while (it != t); + push_states(new_states); + return true; + } else { + lean_trace_debug_ematch(tout() << "match_leaf no heq\n";); + return is_eqv(R, p, t); + } + } + + void push_states(buffer & new_states) { + if (new_states.size() == 1) { + lean_trace_debug_ematch(tout() << "(only one match)\n";); + m_state = new_states[0]; + } else { + lean_trace_debug_ematch(tout() << "# matches: " << new_states.size() << "\n";); + m_state = new_states.back(); + new_states.pop_back(); + choice c = to_list(new_states); + m_choice_stack.push_back(c); + m_ctx->push(); + } + } + bool process_match(name const & R, expr const & p, expr const & t) { lean_trace_debug_ematch(tout() << "try process_match: " << ppb(p) << " <=?=> " << ppb(t) << "\n";); if (!is_app(p)) { - bool success = is_eqv(R, p, t); - lean_trace_debug_ematch( - expr new_p = m_ctx->instantiate_uvars_mvars(p); - expr new_p_type = m_ctx->instantiate_uvars_mvars(m_ctx->infer(p)); - expr t_type = m_ctx->infer(t); - tout() << "is_eqv " << ppb(new_p) << " : " << ppb(new_p_type) - << " <- " << ppb(t) << " : " << ppb(t_type) << " ... " << (success ? "succeeded" : "failed") << "\n";); + bool success = match_leaf(R, p, t); return success; } buffer p_args; expr const & fn = get_app_args(p, p_args); - if (m_ctx->is_mvar(fn)) - return is_eqv(R, p, t); + if (m_ctx->is_mvar(fn)) { + return match_leaf(R, p, t); + } buffer candidates; expr t_fn; expr it = t; do { expr const & it_fn = get_app_fn(it); bool ok = false; - if (m_cc.is_congr_root(R, it) && m_ctx->is_def_eq(it_fn, fn) && + if ((m_cc.is_congr_root(R, it) || m_cc.eq_class_heterogeneous(it)) && m_ctx->is_def_eq(it_fn, fn) && get_app_num_args(it) == p_args.size()) { t_fn = it_fn; ok = true; @@ -391,19 +425,8 @@ struct ematch_fn { new_states.push_back(new_state); } } - if (new_states.size() == 1) { - lean_trace_debug_ematch(tout() << "(only one match)\n";); - m_state = new_states[0]; - return true; - } else { - lean_trace_debug_ematch(tout() << "# matches: " << new_states.size() << "\n";); - m_state = new_states.back(); - new_states.pop_back(); - choice c = to_list(new_states); - m_choice_stack.push_back(c); - m_ctx->push(); - return true; - } + push_states(new_states); + return true; } bool process_continue(name const & R, expr const & p) { @@ -413,7 +436,7 @@ struct ematch_fn { buffer new_states; if (auto s = m_inst_ext.get_apps().find(head_index(f))) { s->for_each([&](expr const & t) { - if (m_cc.is_congr_root(R, t)) { + if (m_cc.is_congr_root(R, t) || m_cc.eq_class_heterogeneous(t)) { state new_state = m_state; if (match_args(new_state, R, p_args, t)) new_states.push_back(new_state); @@ -469,13 +492,29 @@ struct ematch_fn { std::tie(R, kind, p, t) = head(m_state); m_state = tail(m_state); // diagnostic(env(), ios()) << ">> " << R << ", " << ppb(p) << " =?= " << ppb(t) << "\n"; + bool success; switch (kind) { case DefEqOnly: - lean_trace_debug_ematch(tout() << "must be def-eq: " - << ppb(p) << " <=?=> " << ppb(t) << "\n";); - return m_ctx->is_def_eq(p, t); + success = m_ctx->is_def_eq(p, t); + lean_trace_debug_ematch( + expr new_p = m_ctx->instantiate_uvars_mvars(p); + expr new_p_type = m_ctx->instantiate_uvars_mvars(m_ctx->infer(p)); + expr t_type = m_ctx->infer(t); + tout() << "must be def-eq: " << ppb(new_p) << " : " << ppb(new_p_type) + << " =?= " << ppb(t) << " : " << ppb(t_type) + << " ... " << (success ? "succeeded" : "failed") << "\n";); + return success; case Match: return process_match(R, p, t); + case EqvOnly: + success = is_eqv(R, p, t); + lean_trace_debug_ematch( + expr new_p = m_ctx->instantiate_uvars_mvars(p); + expr new_p_type = m_ctx->instantiate_uvars_mvars(m_ctx->infer(p)); + expr t_type = m_ctx->infer(t); + tout() << "must be eqv: " << ppb(new_p) << " : " << ppb(new_p_type) << " =?= " + << ppb(t) << " : " << ppb(t_type) << " ... " << (success ? "succeeded" : "failed") << "\n";); + return success; case MatchSS: return process_matchss(p, t); case Continue: @@ -555,7 +594,7 @@ struct ematch_fn { unsigned gmt = m_cc.get_gmt(); if (auto s = m_inst_ext.get_apps().find(head_index(f))) { s->for_each([&](expr const & t) { - if (m_cc.is_congr_root(R, t) && (!filter || m_cc.get_mt(R, t) == gmt)) { + if ((m_cc.is_congr_root(R, t) || m_cc.eq_class_heterogeneous(t)) && (!filter || m_cc.get_mt(R, t) == gmt)) { lean_trace_debug_ematch(tout() << "ematch " << ppb(get_app_fn(lemma.m_proof)) << " [using] " << ppb(t) << "\n";); m_ctx->clear(); m_ctx->set_next_uvar_idx(lemma.m_num_uvars); diff --git a/tests/lean/run/blast_vector_test.lean b/tests/lean/run/blast_vector_test.lean index fd30a79a2..6ce58bdb8 100644 --- a/tests/lean/run/blast_vector_test.lean +++ b/tests/lean/run/blast_vector_test.lean @@ -91,6 +91,6 @@ lemma vplus.def2 [simp] {n : ℕ} (v₁ v₂ : vector ℕ n) (a₁ a₂ : ℕ) : lemma vplus_weird {n₁ n₂ : ℕ} (v₁ : vector ℕ n₁) (v₂ : vector ℕ n₂) (a b : ℕ) : vplus (a :: append v₁ v₂) ⟨b :: append v₂ v₁⟩ == (a + b) :: vplus (append v₁ v₂) ⟨append v₂ v₁⟩ := -sorry -- TODO need to traverse equivalence class when matching against a meta-variable + by inst_simp end vector