feat(library/blast/congruence_closure): add support for specialized congr lemmas in the congruence closure module

This commit is contained in:
Leonardo de Moura 2016-01-06 17:30:20 -08:00
parent ef691d6cf5
commit cb02d1deae
7 changed files with 114 additions and 26 deletions

View file

@ -755,6 +755,10 @@ public:
return m_fun_info_manager.get_specialized(a);
}
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) {
return m_fun_info_manager.get_specialization_prefix_size(fn, nargs);
}
unsigned abstract_hash(expr const & e) {
return m_abstract_expr_manager.hash(e);
}
@ -1140,6 +1144,11 @@ fun_info get_specialized_fun_info(expr const & a) {
return g_blastenv->get_specialized_fun_info(a);
}
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs) {
lean_assert(g_blastenv);
return g_blastenv->get_specialization_prefix_size(fn, nargs);
}
unsigned abstract_hash(expr const & e) {
lean_assert(g_blastenv);
return g_blastenv->abstract_hash(e);

View file

@ -160,6 +160,8 @@ fun_info get_fun_info(expr const & fn, unsigned nargs);
taking into account the actual arguments.
\pre is_app(a) */
fun_info get_specialized_fun_info(expr const & a);
/** \brief Return the given function specialization prefix size. */
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
/** \brief Hash and equality test for abstract expressions */
unsigned abstract_hash(expr const & e);

View file

@ -192,7 +192,6 @@ static optional<ext_congr_lemma> to_ext_congr_lemma(name const & R, expr const &
Rcs.resize(lhs_args.size(), optional<name>());
r_hyps.resize(lhs_args.size(), none_expr());
// Set Fixed args
// TODO(Leo): handle FixedNoParam case?
for (unsigned i = 0; i < lhs_args.size(); i++) {
if (lhs_args[i] == rhs_args[i])
kinds[i] = congr_arg_kind::Fixed;
@ -243,7 +242,8 @@ static optional<ext_congr_lemma> to_ext_congr_lemma(name const & R, expr const &
}
switch (kinds[i]) {
case congr_arg_kind::FixedNoParam:
// TODO(Leo): revise this code
// User defined congruence rules do not use FixedNoParam
lean_unreachable();
break;
case congr_arg_kind::Fixed:
break;
@ -283,7 +283,7 @@ static optional<ext_congr_lemma> to_ext_congr_lemma(name const & R, expr const &
return optional<ext_congr_lemma>(R, new_lemma, to_list(Rcs), lift_needed);
}
static optional<ext_congr_lemma> mk_ext_congr_lemma_core(name const & R, expr const & fn, unsigned nargs) {
static optional<ext_congr_lemma> mk_ext_user_congr_lemma(name const & R, expr const & fn, unsigned nargs) {
simp_lemmas_for const * sr = get_simp_lemmas().find(R);
if (sr) {
list<user_congr_lemma> const * crs = sr->find_congr(fn);
@ -294,8 +294,11 @@ static optional<ext_congr_lemma> mk_ext_congr_lemma_core(name const & R, expr co
}
}
}
return optional<ext_congr_lemma>();
}
// Automatically generated lemma for equivalence relation over iff/eq
/* Automatically generated lemma for equivalence relation over iff/eq. */
static optional<ext_congr_lemma> mk_relation_congr_lemma(name const & R, expr const & fn, unsigned nargs) {
if (auto info = is_relation(fn)) {
if (info->get_arity() == nargs) {
if (R == get_iff_name()) {
@ -311,9 +314,13 @@ static optional<ext_congr_lemma> mk_ext_congr_lemma_core(name const & R, expr co
}
}
}
return optional<ext_congr_lemma>();
}
// Automatically generated lemma
optional<congr_lemma> eq_congr = mk_congr_lemma(fn, nargs);
/* Automatically generated lemma for function application \c e. The lemma is specialized using the
specialization prefix for \c e. */
static optional<ext_congr_lemma> mk_ext_specialized_congr_lemma(name const & R, expr const & e) {
optional<congr_lemma> eq_congr = mk_specialized_congr_lemma(e);
if (!eq_congr)
return optional<ext_congr_lemma>();
ext_congr_lemma res1(*eq_congr);
@ -326,14 +333,45 @@ static optional<ext_congr_lemma> mk_ext_congr_lemma_core(name const & R, expr co
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed);
}
optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs) {
congr_lemma_key key(R, fn, nargs);
auto it = g_congr_cache->find(key);
if (it != g_congr_cache->end())
return it->second;
auto r = mk_ext_congr_lemma_core(R, fn, nargs);
g_congr_cache->insert(mk_pair(key, r));
return r;
optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e) {
expr const & fn = get_app_fn(e);
unsigned nargs = get_app_num_args(e);
/* Check if (R, fn, nargs) is in the cache */
congr_lemma_key key1(R, fn, nargs);
auto it1 = g_congr_cache->find(key1);
if (it1 != g_congr_cache->end())
return it1->second;
/* Check if (g := fn+specialization prefix) is in the cache */
unsigned prefix_sz = get_specialization_prefix_size(fn, nargs);
unsigned rest_nargs = nargs - prefix_sz;
expr g = e;
for (unsigned i = 0; i < rest_nargs; i++) g = app_fn(g);
congr_lemma_key key2(R, g, rest_nargs);
auto it2 = g_congr_cache->find(key2);
if (it2 != g_congr_cache->end())
return it2->second;
/* Check if there is user defined lemma for (R, fn, nargs).
Remark: specialization prefix is irrelevan for used defined congruence lemmas. */
if (auto lemma = mk_ext_user_congr_lemma(R, fn, nargs)) {
g_congr_cache->insert(mk_pair(key1, lemma));
return lemma;
}
/* Try automatically generated lemma for equivalence relation over iff/eq */
if (auto lemma = mk_relation_congr_lemma(R, fn, nargs)) {
g_congr_cache->insert(mk_pair(key1, lemma));
return lemma;
}
/* Try automatically generated specialized congruence lemma */
if (auto lemma = mk_ext_specialized_congr_lemma(R, e)) {
if (prefix_sz == 0)
g_congr_cache->insert(mk_pair(key1, lemma));
else
g_congr_cache->insert(mk_pair(key2, lemma));
return lemma;
}
/* cache failure */
g_congr_cache->insert(mk_pair(key1, optional<ext_congr_lemma>()));
return optional<ext_congr_lemma>();
}
void congruence_closure::update_non_eq_relations(name const & R) {
@ -412,7 +450,7 @@ int congruence_closure::congr_key_cmp::operator()(congr_key const & k1, congr_ke
expr const & fn2 = get_app_args(k2.m_expr, args2);
if (args1.size() != args2.size())
return unsigned_cmp()(args1.size(), args2.size());
auto lemma = mk_ext_congr_lemma(k1.m_R, fn1, args1.size());
auto lemma = mk_ext_congr_lemma(k1.m_R, k1.m_expr);
lean_assert(lemma);
if (!lemma->m_fixed_fun) {
int r = g_cc->compare_root(get_eq_name(), fn1, fn2);
@ -552,6 +590,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
check_iff_true(k);
}
// TODO(Leo): this should not be hard-coded
static bool is_logical_app(expr const & n) {
if (!is_app(n)) return false;
expr const & fn = get_app_fn(n);
@ -614,7 +653,7 @@ void congruence_closure::internalize_core(name const & R, expr const & e, bool t
} else {
to_propagate = false;
}
if (auto lemma = mk_ext_congr_lemma(R, fn, args.size())) {
if (auto lemma = mk_ext_congr_lemma(R, e)) {
list<optional<name>> const * it = &(lemma->m_rel_names);
for (expr const & arg : args) {
lean_assert(*it);
@ -681,9 +720,7 @@ void congruence_closure::remove_parents(name const & R, expr const & e) {
auto ps = m_parents.find(child_key(R, e));
if (!ps) return;
ps->for_each([&](parent_occ const & p) {
expr const & fn = get_app_fn(p.m_expr);
unsigned nargs = get_app_num_args(p.m_expr);
auto lemma = mk_ext_congr_lemma(p.m_R, fn, nargs);
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
lean_assert(lemma);
congr_key k = mk_congr_key(*lemma, p.m_expr);
m_congruences.erase(k);
@ -694,9 +731,7 @@ void congruence_closure::reinsert_parents(name const & R, expr const & e) {
auto ps = m_parents.find(child_key(R, e));
if (!ps) return;
ps->for_each([&](parent_occ const & p) {
expr const & fn = get_app_fn(p.m_expr);
unsigned nargs = get_app_num_args(p.m_expr);
auto lemma = mk_ext_congr_lemma(p.m_R, fn, nargs);
auto lemma = mk_ext_congr_lemma(p.m_R, p.m_expr);
lean_assert(lemma);
add_congruence_table(*lemma, p.m_expr);
});
@ -955,7 +990,7 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
expr const & lhs_fn = get_app_args(lhs, lhs_args);
expr const & rhs_fn = get_app_args(rhs, rhs_args);
lean_assert(lhs_args.size() == rhs_args.size());
auto lemma = mk_ext_congr_lemma(R, lhs_fn, lhs_args.size());
auto lemma = mk_ext_congr_lemma(R, lhs);
lean_assert(lemma);
if (lemma->m_fixed_fun) {
list<optional<name>> const * it1 = &lemma->m_rel_names;

View file

@ -261,9 +261,9 @@ struct ext_congr_lemma {
list<optional<name>> const & get_arg_rel_names() const { return m_rel_names; }
};
/** \brief Build an extended congruence lemma for function \c fn with \c nargs expected arguments over relation \c R.
/** \brief Build an extended congruence lemma for function the function application \c e over relation \c R.
A subset of user-defined congruence lemmas is considered by this procedure. */
optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs);
optional<ext_congr_lemma> mk_ext_congr_lemma(name const & R, expr const & e);
void initialize_congruence_closure();
void finalize_congruence_closure();

View file

@ -293,7 +293,7 @@ struct ematch_fn {
}
bool match_args(state & s, name const & R, buffer<expr> const & p_args, expr const & t) {
optional<ext_congr_lemma> cg_lemma = mk_ext_congr_lemma(R, get_app_fn(t), p_args.size());
optional<ext_congr_lemma> cg_lemma = mk_ext_congr_lemma(R, t);
if (!cg_lemma)
return false;
buffer<expr> t_args;

View file

@ -0,0 +1,28 @@
import data.unit
open nat unit
constant f {A : Type} (a : A) {B : Type} (b : B) : nat
constant g : unit → nat
set_option blast.strategy "cc"
example (a b : unit) : g a = g b :=
by blast
example (a c : unit) (b d : nat) : b = d → f a b = f c d :=
by blast
constant h {A B : Type} : A → B → nat
example (a b c d : unit) : h a b = h c d :=
by blast
definition C [reducible] : nat → Type₁
| nat.zero := unit
| (nat.succ a) := nat
constant g₂ : Π {n : nat}, C n → nat → nat
example (a b : C zero) (c d : nat) : c = d → g₂ a c = g₂ b d :=
by blast

View file

@ -0,0 +1,14 @@
import data.unit
open nat unit
set_option blast.strategy "cc"
constant r {A B : Type} : A → B → A
definition ex1 (a b c d : unit) : r a b = r c d :=
by blast
-- The congruence closure module does not automatically merge subsingleton equivalence classes.
--
-- example (a b : unit) : a = b :=
-- by blast