feat(library/blast): add 'ematch_simp' strategy for blast and msimp shortcut for it.

This strategy is based on ematching and congruence closure, but it uses
the [simp] lemmas instead of [forward] lemmas.
This commit is contained in:
Leonardo de Moura 2015-12-29 20:04:31 -08:00
parent 8c87f90a29
commit 0148bb08fd
8 changed files with 171 additions and 127 deletions

View file

@ -155,9 +155,7 @@ section group
by simp
theorem inv_eq_of_mul_eq_one {a b : A} (H : a * b = 1) : a⁻¹ = b :=
calc a⁻¹ = a⁻¹ * 1 : by simp
... = a⁻¹ * (a * b) : by simp
... = b : by simp_nohyps
by msimp
theorem one_inv [simp] : 1⁻¹ = (1 : A) :=
inv_eq_of_mul_eq_one (one_mul 1)
@ -166,7 +164,7 @@ section group
inv_eq_of_mul_eq_one (mul.left_inv a)
theorem inv.inj {a b : A} (H : a⁻¹ = b⁻¹) : a = b :=
by rewrite [-inv_inv a, H, inv_inv b]
by msimp
theorem inv_eq_inv_iff_eq (a b : A) : a⁻¹ = b⁻¹ ↔ a = b :=
iff.intro (assume H, inv.inj H) (by simp)
@ -187,33 +185,19 @@ section group
begin apply eq_inv_of_eq_inv, symmetry, exact inv_eq_of_mul_eq_one H end
theorem mul.right_inv [simp] (a : A) : a * a⁻¹ = 1 :=
calc
a * a⁻¹ = (a⁻¹)⁻¹ * a⁻¹ : by simp
... = 1 : mul.left_inv
by msimp
theorem mul_inv_cancel_left [simp] (a b : A) : a * (a⁻¹ * b) = b :=
calc
a * (a⁻¹ * b) = a * a⁻¹ * b : by rewrite mul.assoc
... = 1 * b : by simp
... = b : by simp
by msimp
theorem mul_inv_cancel_right [simp] (a b : A) : a * b * b⁻¹ = a :=
calc
a * b * b⁻¹ = a * (b * b⁻¹) : by simp
... = a * 1 : by simp
... = a : by simp
by msimp
theorem mul_inv [simp] (a b : A) : (a * b)⁻¹ = b⁻¹ * a⁻¹ :=
inv_eq_of_mul_eq_one
(calc
a * b * (b⁻¹ * a⁻¹) = a * (b * (b⁻¹ * a⁻¹)) : by simp
... = a * a⁻¹ : by simp
... = 1 : by simp)
inv_eq_of_mul_eq_one (by msimp)
theorem eq_of_mul_inv_eq_one {a b : A} (H : a * b⁻¹ = 1) : a = b :=
calc
a = a * b⁻¹ * b : by simp_nohyps
... = b : by simp
by msimp
theorem eq_mul_inv_of_mul_eq {a b c : A} (H : a * c = b) : a = b * c⁻¹ :=
by simp
@ -246,13 +230,13 @@ section group
iff.intro eq_mul_inv_of_mul_eq mul_eq_of_eq_mul_inv
theorem mul_left_cancel {a b c : A} (H : a * b = a * c) : b = c :=
by rewrite [-inv_mul_cancel_left a b, H, inv_mul_cancel_left]
by msimp
theorem mul_right_cancel {a b c : A} (H : a * b = c * b) : a = c :=
by rewrite [-mul_inv_cancel_right a b, H, mul_inv_cancel_right]
by msimp
theorem mul_eq_one_of_mul_eq_one {a b : A} (H : b * a = 1) : a * b = 1 :=
by rewrite [-inv_eq_of_mul_eq_one H, mul.left_inv]
by msimp
theorem mul_eq_one_iff_mul_eq_one (a b : A) : a * b = 1 ↔ b * a = 1 :=
iff.intro !mul_eq_one_of_mul_eq_one !mul_eq_one_of_mul_eq_one
@ -263,29 +247,22 @@ section group
local infixl ` ~ ` := is_conjugate
local infixr ` ∘c `:55 := conj_by
local attribute conj_by [reducible]
lemma conj_compose [simp] (f g a : A) : f ∘c g ∘c a = f*g ∘c a :=
calc f ∘c g ∘c a = f * (g * a * g⁻¹) * f⁻¹ : rfl
... = f * g * a * (f * g)⁻¹ : by simp
by msimp
lemma conj_id [simp] (a : A) : 1 ∘c a = a :=
calc 1 * a * 1⁻¹ = a * 1⁻¹ : by simp
... = a : by simp
by msimp
lemma conj_one [simp] (g : A) : g ∘c 1 = 1 :=
calc g * 1 * g⁻¹ = g * g⁻¹ : by simp
... = 1 : by simp
by msimp
lemma conj_inv_cancel [simp] (g : A) : ∀ a, g⁻¹ ∘c g ∘c a = a :=
assume a, calc
g⁻¹ ∘c g ∘c a = g⁻¹*g ∘c a : by simp
... = a : by simp
by msimp
lemma conj_inv [simp] (g : A) : ∀ a, (g ∘c a)⁻¹ = g ∘c a⁻¹ :=
take a, calc
(g * a * g⁻¹)⁻¹ = g⁻¹⁻¹ * (g * a)⁻¹ : by simp
... = g⁻¹⁻¹ * (a⁻¹ * g⁻¹) : by simp
... = g⁻¹⁻¹ * a⁻¹ * g⁻¹ : by simp
... = g * a⁻¹ * g⁻¹ : by simp
by msimp
lemma is_conj.refl (a : A) : a ~ a := exists.intro 1 (conj_id a)
@ -298,10 +275,8 @@ section group
assume Pab, assume Pbc,
obtain x (Px : x ∘c b = a), from Pab,
obtain y (Py : y ∘c c = b), from Pbc,
exists.intro (x*y) (calc
x*y ∘c c = x ∘c y ∘c c : by simp
... = x ∘c b : Py
... = a : Px)
exists.intro (x*y) (by msimp)
end group
definition group.to_left_cancel_semigroup [trans_instance] [reducible] [s : group A] :
@ -340,14 +315,14 @@ section add_group
by simp
theorem neg_eq_of_add_eq_zero {a b : A} (H : a + b = 0) : -a = b :=
by rewrite [-add_zero, -H, neg_add_cancel_left]
by msimp
theorem neg_zero [simp] : -0 = (0 : A) := neg_eq_of_add_eq_zero (zero_add 0)
theorem neg_neg [simp] (a : A) : -(-a) = a := neg_eq_of_add_eq_zero (add.left_inv a)
theorem eq_neg_of_add_eq_zero {a b : A} (H : a + b = 0) : a = -b :=
by rewrite [-neg_eq_of_add_eq_zero H, neg_neg]
by msimp
theorem neg.inj {a b : A} (H : -a = -b) : a = b :=
calc
@ -373,13 +348,10 @@ section add_group
iff.intro !eq_neg_of_eq_neg !eq_neg_of_eq_neg
theorem add.right_inv [simp] (a : A) : a + -a = 0 :=
calc
a + -a = -(-a) + -a : by simp
... = 0 : add.left_inv
by msimp
theorem add_neg_cancel_left [simp] (a b : A) : a + (-a + b) = b :=
calc a + (-a + b) = (a + -a) + b : by rewrite add.assoc
... = b : by simp
by msimp
theorem add_neg_cancel_right [simp] (a b : A) : a + b + -b = a :=
by simp
@ -419,14 +391,10 @@ section add_group
iff.intro eq_add_neg_of_add_eq add_eq_of_eq_add_neg
theorem add_left_cancel {a b c : A} (H : a + b = a + c) : b = c :=
calc b = -a + (a + b) : by simp_nohyps
... = -a + (a + c) : by simp
... = c : by simp
by msimp
theorem add_right_cancel {a b c : A} (H : a + b = c + b) : a = c :=
calc a = (a + b) + -b : by simp_nohyps
... = (c + b) + -b : by simp
... = c : by simp
by msimp
definition add_group.to_left_cancel_semigroup [trans_instance] [reducible] :
add_left_cancel_semigroup A :=
@ -458,10 +426,7 @@ section add_group
theorem add_sub_cancel (a b : A) : a + b - b = a := !add_neg_cancel_right
theorem eq_of_sub_eq_zero {a b : A} (H : a - b = 0) : a = b :=
calc
a = (a - b) + b : by simp_nohyps
... = 0 + b : H
... = b : by simp
by msimp
theorem eq_iff_sub_eq_zero (a b : A) : a = b ↔ a - b = 0 :=
iff.intro (assume H, H ▸ !sub_self) (assume H, eq_of_sub_eq_zero H)
@ -475,18 +440,12 @@ section add_group
by simp
theorem neg_sub (a b : A) : -(a - b) = b - a :=
neg_eq_of_add_eq_zero
(calc
a - b + (b - a) = a - b + b - a : by simp
... = a - a : by simp
... = 0 : by simp)
neg_eq_of_add_eq_zero (by msimp)
theorem add_sub (a b c : A) : a + (b - c) = a + b - c := !add.assoc⁻¹
theorem sub_add_eq_sub_sub_swap (a b c : A) : a - (b + c) = a - c - b :=
calc
a - (b + c) = a + (-c - b) : by simp
... = a - c - b : by simp
by msimp
theorem sub_eq_iff_eq_add (a b c : A) : a - b = c ↔ a = c + b :=
iff.intro (assume H, eq_add_of_add_neg_eq H) (assume H, add_neg_eq_of_eq_add H)
@ -616,9 +575,7 @@ by simp
theorem bit1_add_bit1_helper [add_comm_semigroup A] [has_one A] (a b t s: A)
(H : (a + b) = t) (H2 : add1 t = s) : bit1 a + bit1 b = bit0 s :=
calc bit1 a + bit1 b = bit0 (add1 (a + b)) : by simp_nohyps
... = bit0 (add1 t) : by simp
... = bit0 s : by simp
by msimp
theorem bin_add_zero [add_monoid A] (a : A) : a + zero = a :=
by simp
@ -637,14 +594,14 @@ rfl
theorem bit1_add_one_helper [has_add A] [has_one A] (a t : A) (H : add1 (bit1 a) = t) :
bit1 a + one = t :=
by rewrite -H
by msimp
theorem one_add_bit1 [add_comm_semigroup A] [has_one A] (a : A) : one + bit1 a = add1 (bit1 a) :=
by simp
theorem one_add_bit1_helper [add_comm_semigroup A] [has_one A] (a t : A)
(H : add1 (bit1 a) = t) : one + bit1 a = t :=
by rewrite -H; simp
by msimp
theorem add1_bit0 [has_add A] [has_one A] (a : A) : add1 (bit0 a) = bit1 a :=
rfl
@ -655,7 +612,7 @@ by simp
theorem add1_bit1_helper [add_comm_semigroup A] [has_one A] (a t : A) (H : add1 a = t) :
add1 (bit1 a) = bit0 t :=
by rewrite -H; simp
by msimp
theorem add1_one [has_add A] [has_one A] : add1 (one : A) = bit0 one :=
rfl

View file

@ -110,6 +110,7 @@ definition with_attributes_tac (o : expr) (n : identifier_list) (t : tactic) : t
definition simp : tactic := #tactic with_options [blast.strategy "simp"] blast
definition simp_nohyps : tactic := #tactic with_options [blast.strategy "simp_nohyps"] blast
definition topdown_simp : tactic := #tactic with_options [blast.strategy "simp", simplify.top_down true] blast
definition msimp : tactic := #tactic with_options [blast.strategy "ematch_simp"] blast
definition cases (h : expr) (ids : opt_identifier_list) : tactic := builtin

View file

@ -351,8 +351,6 @@ add_subdirectory(library/blast/backward)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:backward>)
add_subdirectory(library/blast/unit)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:unit>)
add_subdirectory(library/blast/forward)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:forward>)
add_subdirectory(library/blast/actions)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:blast_actions>)
add_subdirectory(library/blast/strategies)
@ -361,6 +359,8 @@ add_subdirectory(library/blast/grinder)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:blast_grinder>)
add_subdirectory(library/blast/simplifier)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:simplifier>)
add_subdirectory(library/blast/forward)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:forward>)
add_subdirectory(compiler)
set(LEAN_OBJS ${LEAN_OBJS} $<TARGET_OBJECTS:compiler>)
add_subdirectory(frontends/lean)

View file

@ -15,42 +15,24 @@ Author: Leonardo de Moura
#include "library/blast/congruence_closure.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/forward_lemmas.h"
#include "library/blast/simplifier/simp_lemmas.h"
namespace lean {
namespace blast {
/*
When a hypothesis hidx is activated:
1- Traverse its type and for each f-application.
If it is the first f-application found, and f is a constant then
retrieve lemmas which contain a multi-pattern starting with f.
#define lean_trace_ematch(Code) lean_trace(name({"blast", "ematch"}), Code)
2- If hypothesis is a proposition and a quantifier,
try to create a hi-lemma for it, and add it to
set of recently activated hi_lemmas
E-match round action
1- For each active hi-lemma L, and mulit-pattern P,
If L has been recently activated, then we ematch ignoring
gmt.
If L has been processed before, we try to ematch starting
at each each element of the multi-pattern.
We only consider the head f-applications that have a mt
equal to gmt
*/
typedef rb_tree<expr, expr_quick_cmp> expr_set;
typedef rb_tree<hi_lemma, hi_lemma_cmp> hi_lemma_set;
static unsigned g_ext_id = 0;
struct ematch_branch_extension : public branch_extension {
hi_lemma_set m_lemmas;
hi_lemma_set m_new_lemmas;
/* Branch extension for supporting heuristic instantiations methods.
It contains
1- mapping functions to its applications.
2- set of lemmas that have been instantiated. */
struct instances_branch_extension : public branch_extension {
rb_map<head_index, expr_set, head_index::cmp> m_apps;
expr_set m_instances;
ematch_branch_extension() {}
ematch_branch_extension(ematch_branch_extension const &) {}
instances_branch_extension() {}
instances_branch_extension(instances_branch_extension const & src):
m_apps(src.m_apps), m_instances(src.m_instances) {}
void collect_apps(expr const & e) {
switch (e.kind()) {
@ -86,6 +68,52 @@ struct ematch_branch_extension : public branch_extension {
}}
}
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override {
collect_apps(h.get_type());
}
virtual void target_updated() override { collect_apps(curr_state().get_target()); }
virtual branch_extension * clone() override { return new instances_branch_extension(*this); }
};
static unsigned g_inst_ext_id = 0;
instances_branch_extension & get_inst_ext() {
return static_cast<instances_branch_extension&>(curr_state().get_extension(g_inst_ext_id));
}
/*
When a hypothesis hidx is activated:
1- Traverse its type and for each f-application.
If it is the first f-application found, and f is a constant then
retrieve lemmas which contain a multi-pattern starting with f.
2- If hypothesis is a proposition and a quantifier,
try to create a hi-lemma for it, and add it to
set of recently activated hi_lemmas
E-match round action
1- For each active hi-lemma L, and mulit-pattern P,
If L has been recently activated, then we ematch ignoring
gmt.
If L has been processed before, we try to ematch starting
at each each element of the multi-pattern.
We only consider the head f-applications that have a mt
equal to gmt
*/
typedef rb_tree<hi_lemma, hi_lemma_cmp> hi_lemma_set;
struct ematch_branch_extension_core : public branch_extension {
hi_lemma_set m_lemmas;
hi_lemma_set m_new_lemmas;
ematch_branch_extension_core() {}
ematch_branch_extension_core(ematch_branch_extension_core const & src):
m_lemmas(src.m_lemmas), m_new_lemmas(src.m_new_lemmas) {}
virtual ~ematch_branch_extension_core() {}
void register_lemma(hypothesis const & h) {
if (is_pi(h.get_type()) && !is_arrow(h.get_type())) {
try {
@ -94,24 +122,52 @@ struct ematch_branch_extension : public branch_extension {
}
}
virtual ~ematch_branch_extension() {}
virtual branch_extension * clone() override { return new ematch_branch_extension(*this); }
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override {
register_lemma(h);
}
};
/* Extension that populates initial lemma set using [forward] lemmas */
struct ematch_branch_extension : public ematch_branch_extension_core {
virtual void initialized() override {
forward_lemmas s = get_forward_lemmas(env());
s.for_each([&](name const & n, unsigned prio) {
try {
m_new_lemmas.insert(mk_hi_lemma(n, prio));
} catch (exception &) {}
} catch (exception & ex) {
lean_trace_ematch(tout() << "discarding [forward] '" << n << "', " << ex.what() << "\n";);
}
});
}
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override {
collect_apps(h.get_type());
register_lemma(h);
}
virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) override {}
virtual void target_updated() override { collect_apps(curr_state().get_target()); }
virtual branch_extension * clone() override { return new ematch_branch_extension(*this); }
};
static unsigned g_ematch_ext_id = 0;
ematch_branch_extension_core & get_ematch_ext() {
return static_cast<ematch_branch_extension&>(curr_state().get_extension(g_ematch_ext_id));
}
/* Extension that populates initial lemma set using [simp] lemmas */
struct ematch_simp_branch_extension : public ematch_branch_extension_core {
virtual void initialized() override {
buffer<name> simp_lemmas;
get_simp_lemmas(env(), simp_lemmas);
for (name const & n : simp_lemmas) {
try {
m_new_lemmas.insert(mk_hi_lemma(n, get_simp_lemma_priority(env(), n)));
} catch (exception & ex) {
lean_trace_ematch(tout() << "discarding [simp] '" << n << "', " << ex.what() << "\n";);
}
}
}
virtual branch_extension * clone() override { return new ematch_simp_branch_extension(*this); }
};
static unsigned g_ematch_simp_ext_id = 0;
ematch_branch_extension_core & get_ematch_simp_ext() {
return static_cast<ematch_simp_branch_extension&>(curr_state().get_extension(g_ematch_simp_ext_id));
}
/* Auxiliary proof step used to bump proof depth */
struct noop_proof_step_cell : public proof_step_cell {
virtual ~noop_proof_step_cell() {}
@ -121,14 +177,17 @@ struct noop_proof_step_cell : public proof_step_cell {
};
void initialize_ematch() {
g_ext_id = register_branch_extension(new ematch_branch_extension());
g_inst_ext_id = register_branch_extension(new instances_branch_extension());
g_ematch_ext_id = register_branch_extension(new ematch_branch_extension());
g_ematch_simp_ext_id = register_branch_extension(new ematch_simp_branch_extension());
register_trace_class(name{"blast", "ematch"});
}
void finalize_ematch() {}
struct ematch_fn {
ematch_branch_extension & m_ext;
ematch_branch_extension_core & m_ext;
instances_branch_extension & m_inst_ext;
blast_tmp_type_context m_ctx;
congruence_closure & m_cc;
@ -143,8 +202,10 @@ struct ematch_fn {
bool m_new_instances;
ematch_fn():
m_ext(static_cast<ematch_branch_extension&>(curr_state().get_extension(g_ext_id))),
/** If fwd == true, then use [forward] lemmas, otherwise use [simp] lemmas */
ematch_fn(bool fwd = true):
m_ext(fwd ? get_ematch_ext() : get_ematch_simp_ext()),
m_inst_ext(get_inst_ext()),
m_cc(get_cc()),
m_new_instances(false) {
}
@ -226,7 +287,7 @@ struct ematch_fn {
buffer<expr> p_args;
expr const & f = get_app_args(p, p_args);
buffer<state> new_states;
if (auto s = m_ext.m_apps.find(head_index(f))) {
if (auto s = m_inst_ext.m_apps.find(head_index(f))) {
s->for_each([&](expr const & t) {
if (m_cc.is_congr_root(R, t)) {
state new_state = m_state;
@ -299,15 +360,15 @@ struct ematch_fn {
expr new_inst = normalize(m_ctx->instantiate_uvars_mvars(lemma.m_prop));
if (has_idx_metavar(new_inst))
return; // result contains temporary metavariables
if (m_ext.m_instances.contains(new_inst))
if (m_inst_ext.m_instances.contains(new_inst))
return; // already added this instance
if (!m_new_instances) {
trace_action("ematch");
}
lean_trace(name({"blast", "ematch"}), tout() << "instance: " << ppb(new_inst) << "\n";);
lean_trace_ematch(tout() << "instance: " << ppb(new_inst) << "\n";);
m_new_instances = true;
expr new_proof = m_ctx->instantiate_uvars_mvars(lemma.m_proof);
m_ext.m_instances.insert(new_inst);
m_inst_ext.m_instances.insert(new_inst);
curr_state().mk_hypothesis(new_inst, new_proof);
}
@ -332,7 +393,7 @@ struct ematch_fn {
expr const & f = get_app_args(p0, p0_args);
name const & R = is_prop(p0) ? get_iff_name() : get_eq_name();
unsigned gmt = m_cc.get_gmt();
if (auto s = m_ext.m_apps.find(head_index(f))) {
if (auto s = m_inst_ext.m_apps.find(head_index(f))) {
s->for_each([&](expr const & t) {
if (m_cc.is_congr_root(R, t) && (!filter || m_cc.get_mt(R, t) == gmt)) {
m_ctx->clear();
@ -398,7 +459,14 @@ struct ematch_fn {
action_result ematch_action() {
if (get_config().m_ematch)
return ematch_fn()();
return ematch_fn(true)();
else
return action_result::failed();
}
action_result ematch_simp_action() {
if (get_config().m_ematch)
return ematch_fn(false)();
else
return action_result::failed();
}

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
namespace lean {
namespace blast {
action_result ematch_action();
action_result ematch_simp_action();
void initialize_ematch();
void finalize_ematch();
}}

View file

@ -84,6 +84,13 @@ bool is_simp_lemma(environment const & env, name const & c) {
return simp_ext::get_state(env).m_simp_lemmas.contains(c);
}
unsigned get_simp_lemma_priority(environment const & env, name const & n) {
if (auto r = simp_ext::get_state(env).m_simp_lemmas.get_prio(n))
return *r;
else
return LEAN_DEFAULT_PRIORITY;
}
bool is_congr_lemma(environment const & env, name const & c) {
return simp_ext::get_state(env).m_congr_lemmas.contains(c);
}

View file

@ -14,6 +14,7 @@ Author: Leonardo de Moura
namespace lean {
environment add_simp_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent);
environment add_congr_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent);
unsigned get_simp_lemma_priority(environment const & env, name const & n);
bool is_simp_lemma(environment const & env, name const & n);
bool is_congr_lemma(environment const & env, name const & n);
void get_simp_lemmas(environment const & env, buffer<name> & r);

View file

@ -49,6 +49,13 @@ static optional<expr> apply_ematch() {
ematch_action)();
}
static optional<expr> apply_ematch_simp() {
flet<bool> set(get_config().m_ematch, true);
return mk_debug_action_strategy(assert_cc_action,
unit_propagate,
ematch_simp_action)();
}
static optional<expr> apply_constructor() {
return mk_debug_action_strategy([]() { return constructor_action(); })();
}
@ -92,6 +99,8 @@ optional<expr> apply_strategy() {
return apply_core_grind();
} else if (s_name == "ematch") {
return apply_ematch();
} else if (s_name == "ematch_simp") {
return apply_ematch_simp();
} else if (s_name == "constructor") {
return apply_constructor();
} else if (s_name == "unit") {