diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index bf6a2ac7e..6b9795a02 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -21,14 +21,23 @@ Author: Leonardo de Moura #define LEAN_DEFAULT_BLAST_CC_HEQ false #endif +#ifndef LEAN_DEFAULT_BLAST_CC_SUBSINGLETON +#define LEAN_DEFAULT_BLAST_CC_SUBSINGLETON false +#endif + namespace lean { namespace blast { -static name * g_blast_cc_heq = nullptr; +static name * g_blast_cc_heq = nullptr; +static name * g_blast_cc_subsingleton = nullptr; bool get_blast_cc_heq(options const & o) { return o.get_bool(*g_blast_cc_heq, LEAN_DEFAULT_BLAST_CC_HEQ); } +bool get_blast_cc_subsingleton(options const & o) { + return o.get_bool(*g_blast_cc_subsingleton, LEAN_DEFAULT_BLAST_CC_SUBSINGLETON); +} + /* Not all user-defined congruence lemmas can be use by this module We cache the ones that can be used. */ struct congr_lemma_key { @@ -52,6 +61,7 @@ struct congr_lemma_key_eq_fn { }; LEAN_THREAD_VALUE(bool, g_heq_based, false); +LEAN_THREAD_VALUE(bool, g_propagate_subsingletons, false); static list> rel_names_from_arg_kinds(list const & kinds, name const & R) { return map2>(kinds, [&](congr_arg_kind k) { @@ -125,7 +135,8 @@ static void push_todo(name const & R, expr const & lhs, expr const & rhs, expr c scope_congruence_closure::scope_congruence_closure(): m_old_cache(g_congr_cache) { g_congr_cache = new congr_cache(); - g_heq_based = is_standard(env()) && get_blast_cc_heq(ios().get_options()); + g_heq_based = is_standard(env()) && get_blast_cc_heq(ios().get_options()); + g_propagate_subsingletons = get_blast_cc_subsingleton(ios().get_options()); } scope_congruence_closure::~scope_congruence_closure() { @@ -148,6 +159,64 @@ void congruence_closure::initialize() { add_eqv(get_eq_name(), nat_zero, zero_nat, b.mk_eq_refl(nat_zero)); } +void congruence_closure::push_subsingleton_eq(expr const & a, expr const & b) { + expr A = infer_type(a); + expr B = infer_type(b); + if (is_def_eq(A, B)) { + // TODO(Leo): to improve performance we can create the following proof lazily + bool heq_proof = false; + expr proof = get_app_builder().mk_app(get_subsingleton_elim_name(), a, b); + push_todo(get_eq_name(), a, b, proof, heq_proof); + } else if (g_heq_based) { + bool heq_proof = true; + expr A_eq_B = *get_eqv_proof(get_eq_name(), A, B); + expr proof = get_app_builder().mk_app(get_subsingleton_helim_name(), A_eq_B, a, b); + push_todo(get_eq_name(), a, b, proof, heq_proof); + } +} + +void congruence_closure::check_new_subsingleton_eq(expr const & old_root, expr const & new_root) { + lean_assert(is_eqv(get_eq_name(), old_root, new_root)); + lean_assert(get_root(get_eq_name(), old_root) == new_root); + auto it1 = m_subsingleton_reprs.find(old_root); + if (!it1) return; + if (auto it2 = m_subsingleton_reprs.find(new_root)) { + push_subsingleton_eq(*it1, *it2); + } else { + m_subsingleton_reprs.insert(new_root, *it1); + } +} + +/* If \c typeof(e) is a subsingleton, then try to propagate equality */ +void congruence_closure::process_subsingleton_elem(expr const & e) { + if (!g_propagate_subsingletons) + return; + expr type = infer_type(e); + optional ss = mk_subsingleton_instance(type); + if (!ss) + return; /* type is not a subsingleton */ + /* Make sure type has been internalized */ + bool toplevel = true; + bool propagate = false; + internalize_core(get_eq_name(), type, toplevel, propagate); + /* Try to find representative */ + if (auto it = m_subsingleton_reprs.find(type)) { + push_subsingleton_eq(e, *it); + } else { + m_subsingleton_reprs.insert(type, e); + } + if (!g_heq_based) + return; + expr type_root = get_root(get_eq_name(), type); + if (type_root == type) + return; + if (auto it2 = m_subsingleton_reprs.find(type_root)) { + push_subsingleton_eq(e, *it2); + } else { + m_subsingleton_reprs.insert(type_root, e); + } +} + void congruence_closure::mk_entry_core(name const & R, expr const & e, bool to_propagate, bool interpreted, bool constructor) { lean_assert(!m_entries.find(eqc_key(R, e))); entry n; @@ -180,6 +249,7 @@ void congruence_closure::mk_entry_core(name const & R, expr const & e, bool to_p } } } + process_subsingleton_elem(e); } void congruence_closure::mk_entry_core(name const & R, expr const & e, bool to_propagate) { @@ -738,7 +808,6 @@ void congruence_closure::internalize_core(name R, expr const & e, bool toplevel, lean_assert(closed(e)); if (g_heq_based && R == get_heq_name()) R = get_eq_name(); - // we allow metavariables after partitions have been frozen if (has_expr_metavar(e) && !m_froze_partitions) return; @@ -1055,7 +1124,9 @@ void congruence_closure::add_eqv_step(name const & R, expr e1, expr e2, expr con } update_mt(R, e2_root); - + if (R == get_eq_name()) { + check_new_subsingleton_eq(e1_root, e2_root); + } lean_trace(name({"cc", "merge"}), tout() << ppb(e1_root) << " [" << R << "] " << ppb(e2_root) << "\n";); lean_trace(name({"cc", "state"}), trace();); } @@ -1711,16 +1782,21 @@ void initialize_congruence_closure() { g_iff_true_mark = new expr(mk_constant(name(prefix, "[iff-true]"))); g_lift_mark = new expr(mk_constant(name(prefix, "[lift]"))); - g_blast_cc_heq = new name{"blast", "cc", "heq"}; + g_blast_cc_heq = new name{"blast", "cc", "heq"}; + g_blast_cc_subsingleton = new name{"blast", "cc", "subsingleton"}; register_bool_option(*g_blast_cc_heq, LEAN_DEFAULT_BLAST_CC_HEQ, "(blast) enable support for heterogeneous equality " "and more general congruence lemmas in the congruence closure module " "(this option is ignore in HoTT mode)"); + + register_bool_option(*g_blast_cc_subsingleton, LEAN_DEFAULT_BLAST_CC_SUBSINGLETON, + "(blast) enable support for subsingleton equality propagation in congruence closure module"); } void finalize_congruence_closure() { delete g_blast_cc_heq; + delete g_blast_cc_subsingleton; delete g_congr_mark; delete g_iff_true_mark; delete g_lift_mark; diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index ab76d27f9..ef6c681d1 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -106,10 +106,12 @@ class congruence_closure { typedef rb_tree parent_occ_set; typedef rb_map parents; typedef rb_tree congruences; - - entries m_entries; - parents m_parents; - congruences m_congruences; + typedef rb_map subsingleton_reprs; + entries m_entries; + parents m_parents; + congruences m_congruences; + /** The following mapping store a representative for each subsingleton type */ + subsingleton_reprs m_subsingleton_reprs; list m_non_eq_relations; /** The congruence closure module has a mode where the root of each equivalence class is marked as an interpreted/abstract @@ -130,6 +132,9 @@ class congruence_closure { congr_key mk_congr_key(ext_congr_lemma const & lemma, expr const & e) const; void check_iff_true(congr_key const & k); + void push_subsingleton_eq(expr const & a, expr const & b); + void check_new_subsingleton_eq(expr const & old_root, expr const & new_root); + void process_subsingleton_elem(expr const & e); void mk_entry_core(name const & R, expr const & e, bool to_propagate, bool interpreted, bool constructor); void mk_entry_core(name const & R, expr const & e, bool to_propagate); void mk_entry(name const & R, expr const & e, bool to_propagate); diff --git a/tests/lean/run/blast_cc_heq6.lean b/tests/lean/run/blast_cc_heq6.lean new file mode 100644 index 000000000..94e7d6c74 --- /dev/null +++ b/tests/lean/run/blast_cc_heq6.lean @@ -0,0 +1,18 @@ +import data.unit +open unit + +set_option blast.strategy "cc" +set_option blast.cc.subsingleton true +set_option blast.cc.heq true + +example (a b : unit) : a = b := +by blast + +example (a b : nat) (h₁ : a = 0) (h₂ : b = 0) : a = b → h₁ == h₂ := +by blast + +definition inv' : ∀ (a : nat), a ≠ 0 → nat := +sorry + +example (a b : nat) (h₁ : a ≠ 0) (h₂ : b ≠ 0) : a = b → inv' a h₁ = inv' b h₂ := +by blast