diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index 4ca7158db..59ae8e962 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -810,6 +810,48 @@ static bool is_logical_app(expr const & n) { (const_name(fn) == get_ite_name() && is_prop(n))); } +/* This method is invoked during internalization and eagerly apply basic equivalences for term \c e + Examples: + - If e := cast H e', then it merges the equivalence classes of (cast H e') and e' + + In principle, we could mark theorems such as cast_eq as simplification rules, but this creates + problems with the builtin support for cast-introduction in the ematching module. + + Eagerly merging the equivalence classes is also more efficient. */ +void congruence_closure::apply_simple_eqvs(expr const & e) { + if (g_heq_based) { + /* equivalences when == support is enabled */ + if (is_app_of(e, get_cast_name(), 4)) { + /* cast H a == a + + theorem cast_heq : ∀ {A B : Type.{l_1}} (H : A = B) (a : A), @cast.{l_1} A B H a == a + */ + buffer args; + expr const & cast = get_app_args(e, args); + expr const & a = args[3]; + expr proof = mk_app(mk_constant(get_cast_heq_name(), const_levels(cast)), args); + bool heq_proof = true; + push_todo(get_eq_name(), e, a, proof, heq_proof); + } + + if (is_app_of(e, get_eq_rec_name(), 6)) { + /* eq.rec p H == p + + theorem eq_rec_heq : ∀ {A : Type.{l_1}} {P : A → Type.{l_2}} {a a' : A} (H : a = a') (p : P a), @eq.rec.{l_2 l_1} A a P p a' H == p + */ + buffer args; + expr const & eq_rec = get_app_args(e, args); + expr A = args[0]; expr a = args[1]; expr P = args[2]; expr p = args[3]; + expr a_prime = args[4]; expr H = args[5]; + level l_2 = head(const_levels(eq_rec)); + level l_1 = head(tail(const_levels(eq_rec))); + expr proof = mk_app({mk_constant(get_eq_rec_heq_name(), {l_1, l_2}), A, P, a, a_prime, H, p}); + bool heq_proof = true; + push_todo(get_eq_name(), e, p, proof, heq_proof); + } + } +} + void congruence_closure::internalize_core(name R, expr const & e, bool toplevel, bool to_propagate) { lean_assert(closed(e)); if (g_heq_based && R == get_heq_name()) @@ -878,6 +920,7 @@ void congruence_closure::internalize_core(name R, expr const & e, bool toplevel, } add_congruence_table(*lemma, e); } + apply_simple_eqvs(e); break; }} } diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index ef6c681d1..06c513d93 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -122,6 +122,7 @@ class congruence_closure { unsigned m_gmt{0}; void update_non_eq_relations(name const & R); + void apply_simple_eqvs(expr const & e); void register_to_propagate(expr const & e); void internalize_core(name R, expr const & e, bool toplevel, bool to_propagate); void process_todo(optional const & added_prop); diff --git a/src/library/constants.cpp b/src/library/constants.cpp index e5c40541a..7a234e40c 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -19,6 +19,8 @@ name const * g_bool = nullptr; name const * g_bool_ff = nullptr; name const * g_bool_tt = nullptr; name const * g_cast = nullptr; +name const * g_cast_eq = nullptr; +name const * g_cast_heq = nullptr; name const * g_char = nullptr; name const * g_char_mk = nullptr; name const * g_classical = nullptr; @@ -46,6 +48,7 @@ name const * g_eq_subst = nullptr; name const * g_eq_symm = nullptr; name const * g_eq_trans = nullptr; name const * g_eq_of_heq = nullptr; +name const * g_eq_rec_heq = nullptr; name const * g_exists_elim = nullptr; name const * g_false = nullptr; name const * g_false_of_true_iff_false = nullptr; @@ -284,6 +287,8 @@ void initialize_constants() { g_bool_ff = new name{"bool", "ff"}; g_bool_tt = new name{"bool", "tt"}; g_cast = new name{"cast"}; + g_cast_eq = new name{"cast_eq"}; + g_cast_heq = new name{"cast_heq"}; g_char = new name{"char"}; g_char_mk = new name{"char", "mk"}; g_classical = new name{"classical"}; @@ -311,6 +316,7 @@ void initialize_constants() { g_eq_symm = new name{"eq", "symm"}; g_eq_trans = new name{"eq", "trans"}; g_eq_of_heq = new name{"eq_of_heq"}; + g_eq_rec_heq = new name{"eq_rec_heq"}; g_exists_elim = new name{"exists", "elim"}; g_false = new name{"false"}; g_false_of_true_iff_false = new name{"false_of_true_iff_false"}; @@ -550,6 +556,8 @@ void finalize_constants() { delete g_bool_ff; delete g_bool_tt; delete g_cast; + delete g_cast_eq; + delete g_cast_heq; delete g_char; delete g_char_mk; delete g_classical; @@ -577,6 +585,7 @@ void finalize_constants() { delete g_eq_symm; delete g_eq_trans; delete g_eq_of_heq; + delete g_eq_rec_heq; delete g_exists_elim; delete g_false; delete g_false_of_true_iff_false; @@ -815,6 +824,8 @@ name const & get_bool_name() { return *g_bool; } name const & get_bool_ff_name() { return *g_bool_ff; } name const & get_bool_tt_name() { return *g_bool_tt; } name const & get_cast_name() { return *g_cast; } +name const & get_cast_eq_name() { return *g_cast_eq; } +name const & get_cast_heq_name() { return *g_cast_heq; } name const & get_char_name() { return *g_char; } name const & get_char_mk_name() { return *g_char_mk; } name const & get_classical_name() { return *g_classical; } @@ -842,6 +853,7 @@ name const & get_eq_subst_name() { return *g_eq_subst; } name const & get_eq_symm_name() { return *g_eq_symm; } name const & get_eq_trans_name() { return *g_eq_trans; } name const & get_eq_of_heq_name() { return *g_eq_of_heq; } +name const & get_eq_rec_heq_name() { return *g_eq_rec_heq; } name const & get_exists_elim_name() { return *g_exists_elim; } name const & get_false_name() { return *g_false; } name const & get_false_of_true_iff_false_name() { return *g_false_of_true_iff_false; } diff --git a/src/library/constants.h b/src/library/constants.h index 8a231b673..0872a185e 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -21,6 +21,8 @@ name const & get_bool_name(); name const & get_bool_ff_name(); name const & get_bool_tt_name(); name const & get_cast_name(); +name const & get_cast_eq_name(); +name const & get_cast_heq_name(); name const & get_char_name(); name const & get_char_mk_name(); name const & get_classical_name(); @@ -48,6 +50,7 @@ name const & get_eq_subst_name(); name const & get_eq_symm_name(); name const & get_eq_trans_name(); name const & get_eq_of_heq_name(); +name const & get_eq_rec_heq_name(); name const & get_exists_elim_name(); name const & get_false_name(); name const & get_false_of_true_iff_false_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index 53dae95b8..d71ac62ec 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -14,6 +14,8 @@ bool bool.ff bool.tt cast +cast_eq +cast_heq char char.mk classical @@ -41,6 +43,7 @@ eq.subst eq.symm eq.trans eq_of_heq +eq_rec_heq exists.elim false false_of_true_iff_false