feat(library/blast/congruence_closure): support for congruence lemmas that use heterogeneous equality

This commit is contained in:
Leonardo de Moura 2016-01-10 13:45:40 -08:00
parent 934f3b67ff
commit ea7da31bba
4 changed files with 99 additions and 33 deletions

View file

@ -79,21 +79,21 @@ ext_congr_lemma::ext_congr_lemma(congr_lemma const & H):
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
m_lift_needed(false),
m_fixed_fun(true),
m_heq_based(false) {}
m_heq_result(false) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based):
m_R(R),
m_congr_lemma(H),
m_rel_names(rel_names_from_arg_kinds(H.get_arg_kinds(), get_eq_name())),
m_lift_needed(lift_needed),
m_fixed_fun(true),
m_heq_based(heq_based) {}
m_heq_result(heq_based) {}
ext_congr_lemma::ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed):
m_R(R),
m_congr_lemma(H),
m_rel_names(rel_names),
m_lift_needed(lift_needed),
m_fixed_fun(true),
m_heq_based(false) {}
m_heq_result(false) {}
/* We use the following cache for user-defined lemmas and automatically generated ones. */
typedef std::unordered_map<congr_lemma_key, optional<ext_congr_lemma>, congr_lemma_key_hash_fn, congr_lemma_key_eq_fn> congr_cache;
@ -381,8 +381,8 @@ static optional<ext_congr_lemma> mk_ext_specialized_congr_lemma(name const & R,
if (R == get_eq_name())
return optional<ext_congr_lemma>(res1);
bool lift_needed = true;
bool heq_based = false;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_based);
bool heq_result = false;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_result);
}
/* Automatically generated congruence lemma based on heterogeneous equality. */
@ -391,23 +391,26 @@ static optional<ext_congr_lemma> mk_hcongr_lemma(name const & R, expr const & fn
if (!eq_congr)
return optional<ext_congr_lemma>();
ext_congr_lemma res1(*eq_congr);
/* If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. */
if (eq_congr->all_eq_kind())
res1.m_fixed_fun = false;
if (R == get_eq_name() || R == get_heq_name())
return optional<ext_congr_lemma>(res1);
/* If R is not equality (=) nor heterogeneous equality (==),
we try to lift, but we can only lift if the congruence lemma produces an equality. */
expr type = eq_congr->get_type();
while (is_pi(type)) type = binding_body(type);
/* If all arguments are Eq kind, then we can use generic congr axiom and consider equality for the function. */
if (!is_heq(type) && eq_congr->all_eq_kind())
res1.m_fixed_fun = false;
lean_assert(is_eq(type) || is_heq(type));
if (R == get_eq_name() || R == get_heq_name()) {
if (is_heq(type))
res1.m_heq_result = true;
return optional<ext_congr_lemma>(res1);
}
/* If R is not equality (=) nor heterogeneous equality (==),
we try to lift, but we can only lift if the congruence lemma produces an equality. */
if (is_heq(type)) {
/* We cannot lift heterogeneous equality. */
return optional<ext_congr_lemma>();
} else {
bool heq_based = !eq_congr->all_eq_kind() || is_heq(type);
bool heq_result = false;
bool lift_needed = true;
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_based);
return optional<ext_congr_lemma>(R, *eq_congr, lift_needed, heq_result);
}
}
@ -684,7 +687,7 @@ void congruence_closure::add_congruence_table(ext_congr_lemma const & lemma, exp
new_entry.m_cg_root = old_k->m_expr;
m_entries.insert(k, new_entry);
// 2. Put new equivalence in the TODO queue
bool heq_proof = false; // TODO(Leo): fix this
bool heq_proof = lemma.m_heq_result;
push_todo(lemma.m_R, e, old_k->m_expr, *g_congr_mark, heq_proof);
} else {
m_congruences.insert(k);
@ -1147,7 +1150,7 @@ bool congruence_closure::is_eqv(name const & R, expr const & e1, expr const & e2
return n1->m_root == n2->m_root;
}
expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs) const {
expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const {
app_builder & b = get_app_builder();
buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
@ -1163,7 +1166,10 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
lean_assert(*it1 && *it2);
switch (head(*it2)) {
case congr_arg_kind::HEq:
lean_unreachable();
lemma_args.push_back(lhs_args[i]);
lemma_args.push_back(rhs_args[i]);
lemma_args.push_back(*get_eqv_proof(get_heq_name(), lhs_args[i], rhs_args[i]));
break;
case congr_arg_kind::Eq:
lean_assert(head(*it1));
lemma_args.push_back(lhs_args[i]);
@ -1184,8 +1190,14 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
it2 = &(tail(*it2));
}
expr r = mk_app(lemma->m_congr_lemma.get_proof(), lemma_args);
if (lemma->m_lift_needed)
if (lemma->m_lift_needed) {
lean_assert(!lemma->m_heq_result);
r = b.lift_from_eq(R, r);
}
if (lemma->m_heq_result && !heq_proofs)
r = b.mk_eq_of_heq(r);
else if (!lemma->m_heq_result && heq_proofs)
r = b.mk_heq_of_eq(r);
return r;
} else {
optional<expr> r;
@ -1222,7 +1234,7 @@ expr congruence_closure::mk_congr_proof_core(name const & R, expr const & lhs, e
}
}
expr congruence_closure::mk_congr_proof(name const & R, expr const & e1, expr const & e2) const {
expr congruence_closure::mk_congr_proof(name const & R, expr const & e1, expr const & e2, bool heq_proofs) const {
name R1; expr lhs1, rhs1;
if (is_equivalence_relation_app(e1, R1, lhs1, rhs1)) {
name R2; expr lhs2, rhs2;
@ -1243,16 +1255,16 @@ expr congruence_closure::mk_congr_proof(name const & R, expr const & e1, expr co
if (R != get_eq_name())
e1_eqv_new_e1 = b.lift_from_eq(R, e1_eqv_new_e1);
}
return b.mk_trans(R, e1_eqv_new_e1, mk_congr_proof_core(R, new_e1, e2));
return b.mk_trans(R, e1_eqv_new_e1, mk_congr_proof_core(R, new_e1, e2, heq_proofs));
}
}
}
return mk_congr_proof_core(R, e1, e2);
return mk_congr_proof_core(R, e1, e2, heq_proofs);
}
expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const {
expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const {
if (H == *g_congr_mark) {
return mk_congr_proof(R, lhs, rhs);
return mk_congr_proof(R, lhs, rhs, heq_proofs);
} else if (H == *g_iff_true_mark) {
bool flip;
name R1; expr a, b;
@ -1275,13 +1287,13 @@ expr congruence_closure::mk_proof(name const & R, expr const & lhs, expr const &
}
}
static expr flip_proof(name const & R, expr const & H, bool flipped, bool has_heq_proofs) {
static expr flip_proof(name const & R, expr const & H, bool flipped, bool heq_proofs) {
if (H == *g_congr_mark || H == *g_iff_true_mark || H == *g_lift_mark) {
return H;
} else {
auto & b = get_app_builder();
expr new_H = H;
if (has_heq_proofs && is_eq(relaxed_whnf(infer_type(new_H)))) {
if (heq_proofs && is_eq(relaxed_whnf(infer_type(new_H)))) {
new_H = b.mk_heq_of_eq(new_H);
}
if (!flipped) {
@ -1361,13 +1373,13 @@ optional<expr> congruence_closure::get_eqv_proof(name const & R, expr const & e1
optional<expr> pr;
expr lhs = e1;
for (unsigned i = 0; i < path1.size(); i++) {
pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path1[i], Hs1[i]));
pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path1[i], Hs1[i], heq_proofs));
lhs = path1[i];
}
unsigned i = Hs2.size();
while (i > 0) {
--i;
pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path2[i], Hs2[i]));
pr = mk_trans(R_trans, pr, mk_proof(R, lhs, path2[i], Hs2[i], heq_proofs));
lhs = path2[i];
}
lean_assert(pr);

View file

@ -146,9 +146,9 @@ class congruence_closure {
void add_eqv_core(name const & R, expr const & lhs, expr const & rhs, expr const & H, optional<expr> const & added_prop, bool heq_proof);
void propagate_no_confusion_eq(expr const & e1, expr const & e2);
expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs) const;
expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs) const;
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H) const;
expr mk_congr_proof_core(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const;
expr mk_congr_proof(name const & R, expr const & lhs, expr const & rhs, bool heq_proofs) const;
expr mk_proof(name const & R, expr const & lhs, expr const & rhs, expr const & H, bool heq_proofs) const;
bool has_heq_proofs(expr const & root) const;
@ -261,10 +261,10 @@ struct ext_congr_lemma {
/* If m_fixed_fun is false, then we build equivalences for functions, and use generic congr lemma, and ignore m_congr_lemma.
That is, even the function can be treated as an Eq argument. */
unsigned m_fixed_fun:1;
/* If m_uses_heq is true, then lemma is based on heterogeneous equality. */
unsigned m_heq_based:1;
/* If m_heq_result is true, then lemma is based on heterogeneous equality and the conclusion is a heterogeneous equality. */
unsigned m_heq_result:1;
ext_congr_lemma(congr_lemma const & H);
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_based);
ext_congr_lemma(name const & R, congr_lemma const & H, bool lift_needed, bool heq_result);
ext_congr_lemma(name const & R, congr_lemma const & H, list<optional<name>> const & rel_names, bool lift_needed);
name const & get_relation() const { return m_R; }

View file

@ -0,0 +1,17 @@
set_option blast.strategy "cc"
set_option blast.cc.heq true -- make sure heterogeneous congruence lemmas are enabled
axiom vector.{l} : Type.{l} → nat → Type.{l}
axiom app : Π {A : Type} {n m : nat}, vector A m → vector A n → vector A (m+n)
example (n1 n2 n3 : nat) (v1 w1 : vector nat n1) (w1' : vector nat n3) (v2 w2 : vector nat n2) :
n1 = n3 → v1 = w1 → w1 == w1' → v2 = w2 → app v1 v2 == app w1' w2 :=
by blast
example (n1 n2 n3 : nat) (v1 w1 : vector nat n1) (w1' : vector nat n3) (v2 w2 : vector nat n2) :
n1 == n3 → v1 = w1 → w1 == w1' → v2 == w2 → app v1 v2 == app w1' w2 :=
by blast
example (n1 n2 n3 : nat) (v1 w1 v : vector nat n1) (w1' : vector nat n3) (v2 w2 w : vector nat n2) :
n1 == n3 → v1 = w1 → w1 == w1' → v2 == w2 → app w1' w2 == app v w → app v1 v2 = app v w :=
by blast

View file

@ -0,0 +1,37 @@
universes l1 l2 l3 l4 l5 l6
constants (A : Type.{l1}) (B : A → Type.{l2}) (C : ∀ (a : A) (ba : B a), Type.{l3})
(D : ∀ (a : A) (ba : B a) (cba : C a ba), Type.{l4})
(E : ∀ (a : A) (ba : B a) (cba : C a ba) (dcba : D a ba cba), Type.{l5})
(F : ∀ (a : A) (ba : B a) (cba : C a ba) (dcba : D a ba cba) (edcba : E a ba cba dcba), Type.{l6})
(C_ss : ∀ a ba, subsingleton (C a ba))
(a1 a2 a3 : A)
(mk_B1 mk_B2 : ∀ a, B a)
(mk_C1 mk_C2 : ∀ {a} ba, C a ba)
(tr_B : ∀ {a}, B a → B a)
(x y z : A → A)
(f f' : ∀ {a : A} {ba : B a} (cba : C a ba), D a ba cba)
(g : ∀ {a : A} {ba : B a} {cba : C a ba} (dcba : D a ba cba), E a ba cba dcba)
(h : ∀ {a : A} {ba : B a} {cba : C a ba} {dcba : D a ba cba} (edcba : E a ba cba dcba), F a ba cba dcba edcba)
attribute C_ss [instance]
set_option blast.strategy "cc"
set_option blast.cc.heq true
example : ∀ {a a' : A}, a == a' → mk_B1 a == mk_B1 a' :=
by blast
example : ∀ {a a' : A}, a == a' → mk_B2 a == mk_B2 a' :=
by blast
example : a1 == y a2 → mk_B1 a1 == mk_B1 (y a2) :=
by blast
example : a1 == x a2 → a2 == y a1 → mk_B1 (x (y a1)) == mk_B1 (x (y (x a2))) :=
by blast
-- The following one needs subsingleton support
-- example : a1 == y a2 → mk_B1 a1 == mk_B2 (y a2) → f (mk_C1 (mk_B2 a1)) == f (mk_C2 (mk_B1 (y a2))) :=
-- by blast