feat(library/blast/forward): inst_simp should use the left-hand-side as a pattern (if none is provided by the user)

The motivation is to reduce the number of instances generated by ematching.

For example, given

   inv_inv:  forall a, (a⁻¹)⁻¹ = a

the new heuristic uses ((a⁻¹)⁻¹) as the pattern.
This matches the intuition that inv_inv should be used as a simplification
rule.

The default pattern inference procedure would use (a⁻¹). This is bad
because it generates an infinite chain of instances whenever there is a
term (a⁻¹) in the proof state.
With (a⁻¹) as the pattern, a term (a⁻¹) in the proof state produces the instance
   (a⁻¹)⁻¹ = a
Now that we have the new term (a⁻¹)⁻¹, we can match again and generate
   ((a⁻¹)⁻¹)⁻¹ = a⁻¹
and so on.
This commit is contained in:
Leonardo de Moura 2015-12-31 20:20:39 -08:00
parent 03f9e9acb0
commit 54f2c0f254
13 changed files with 105 additions and 62 deletions

View file

@ -84,7 +84,8 @@ section division_ring
suppose a = 0,
have 0 = (1:A), by inst_simp,
absurd this zero_ne_one,
by inst_simp
assert b = (1 / a) * a * b, by inst_simp,
show b = 1 / a, by inst_simp
theorem eq_one_div_of_mul_eq_one_left (H : b * a = 1) : b = 1 / a :=
assert a ≠ 0, from

View file

@ -9,7 +9,6 @@ Various multiplicative and additive structures. Partially modeled on Isabelle's
import logic.eq data.unit data.sigma data.prod
import algebra.binary algebra.priority
open eq eq.ops -- note: ⁻¹ will be overloaded
open binary
variable {A : Type}
@ -22,7 +21,9 @@ attribute neg [light 3]
structure semigroup [class] (A : Type) extends has_mul A :=
(mul_assoc : ∀a b c, mul (mul a b) c = mul a (mul b c))
theorem mul.assoc [simp] [semigroup A] (a b c : A) : a * b * c = a * (b * c) :=
-- We add pattern hints to the following lemma because we want it to be used in both directions
-- by the inst_simp strategy.
theorem mul.assoc [simp] [semigroup A] (a b c : A) : (: a * b * c :) = (: a * (b * c) :) :=
!semigroup.mul_assoc
structure comm_semigroup [class] (A : Type) extends semigroup A :=
@ -58,7 +59,7 @@ abbreviation eq_of_mul_eq_mul_right' := @mul.right_cancel
structure add_semigroup [class] (A : Type) extends has_add A :=
(add_assoc : ∀a b c, add (add a b) c = add a (add b c))
theorem add.assoc [simp] [add_semigroup A] (a b c : A) : a + b + c = a + (b + c) :=
theorem add.assoc [simp] [add_semigroup A] (a b c : A) : (: a + b + c :) = (: a + (b + c) :) :=
!add_semigroup.add_assoc
structure add_comm_semigroup [class] (A : Type) extends add_semigroup A :=
@ -155,6 +156,7 @@ section group
by simp
theorem inv_eq_of_mul_eq_one {a b : A} (H : a * b = 1) : a⁻¹ = b :=
assert a⁻¹ * 1 = b, by inst_simp,
by inst_simp
theorem one_inv [simp] : 1⁻¹ = (1 : A) :=
@ -164,13 +166,15 @@ section group
inv_eq_of_mul_eq_one (mul.left_inv a)
theorem inv.inj {a b : A} (H : a⁻¹ = b⁻¹) : a = b :=
assert a = a⁻¹⁻¹, by simp_nohyps,
by inst_simp
theorem inv_eq_inv_iff_eq (a b : A) : a⁻¹ = b⁻¹ ↔ a = b :=
iff.intro (assume H, inv.inj H) (by simp)
theorem inv_eq_one_iff_eq_one (a : A) : a⁻¹ = 1 ↔ a = 1 :=
one_inv ▸ inv_eq_inv_iff_eq a 1
assert a⁻¹ = 1⁻¹ ↔ a = 1, from inv_eq_inv_iff_eq a 1,
by simp
theorem eq_one_of_inv_eq_one (a : A) : a⁻¹ = 1 → a = 1 :=
iff.mp !inv_eq_one_iff_eq_one
@ -182,9 +186,11 @@ section group
iff.intro !eq_inv_of_eq_inv !eq_inv_of_eq_inv
theorem eq_inv_of_mul_eq_one {a b : A} (H : a * b = 1) : a = b⁻¹ :=
begin apply eq_inv_of_eq_inv, symmetry, exact inv_eq_of_mul_eq_one H end
assert a⁻¹ = b, from inv_eq_of_mul_eq_one H,
by inst_simp
theorem mul.right_inv [simp] (a : A) : a * a⁻¹ = 1 :=
assert a = a⁻¹⁻¹, by simp,
by inst_simp
theorem mul_inv_cancel_left [simp] (a b : A) : a * (a⁻¹ * b) = b :=
@ -194,9 +200,11 @@ section group
by inst_simp
theorem mul_inv [simp] (a b : A) : (a * b)⁻¹ = b⁻¹ * a⁻¹ :=
assert a * a⁻¹ = 1, by inst_simp, -- why do we need it?
inv_eq_of_mul_eq_one (by inst_simp)
theorem eq_of_mul_inv_eq_one {a b : A} (H : a * b⁻¹ = 1) : a = b :=
assert a⁻¹ * 1 = a⁻¹, by inst_simp,
by inst_simp
theorem eq_mul_inv_of_mul_eq {a b c : A} (H : a * c = b) : a = b * c⁻¹ :=
@ -230,13 +238,15 @@ section group
iff.intro eq_mul_inv_of_mul_eq mul_eq_of_eq_mul_inv
theorem mul_left_cancel {a b c : A} (H : a * b = a * c) : b = c :=
assert a⁻¹ * (a * b) = b, by inst_simp,
by inst_simp
theorem mul_right_cancel {a b c : A} (H : a * b = c * b) : a = c :=
assert a * b * b⁻¹ = a, by inst_simp,
by inst_simp
theorem mul_eq_one_of_mul_eq_one {a b : A} (H : b * a = 1) : a * b = 1 :=
by inst_simp
by rewrite [-inv_eq_of_mul_eq_one H, mul.left_inv]
theorem mul_eq_one_iff_mul_eq_one (a b : A) : a * b = 1 ↔ b * a = 1 :=
iff.intro !mul_eq_one_of_mul_eq_one !mul_eq_one_of_mul_eq_one
@ -256,6 +266,7 @@ section group
by inst_simp
lemma conj_one [simp] (g : A) : g ∘c 1 = 1 :=
assert g * g⁻¹ = 1, by inst_simp, -- why do we need it?
by inst_simp
lemma conj_inv_cancel [simp] (g : A) : ∀ a, g⁻¹ ∘c g ∘c a = a :=
@ -315,6 +326,7 @@ section add_group
by simp
theorem neg_eq_of_add_eq_zero {a b : A} (H : a + b = 0) : -a = b :=
assert -a + 0 = b, by inst_simp,
by inst_simp
theorem neg_zero [simp] : -0 = (0 : A) := neg_eq_of_add_eq_zero (zero_add 0)
@ -322,9 +334,11 @@ section add_group
theorem neg_neg [simp] (a : A) : -(-a) = a := neg_eq_of_add_eq_zero (add.left_inv a)
theorem eq_neg_of_add_eq_zero {a b : A} (H : a + b = 0) : a = -b :=
assert -a = b, from neg_eq_of_add_eq_zero H,
by inst_simp
theorem neg.inj {a b : A} (H : -a = -b) : a = b :=
assert a = -(-a), by simp_nohyps,
by inst_simp
theorem neg_eq_neg_iff_eq (a b : A) : -a = -b ↔ a = b :=
@ -334,7 +348,8 @@ section add_group
iff.mp !neg_eq_neg_iff_eq
theorem neg_eq_zero_iff_eq_zero (a : A) : -a = 0 ↔ a = 0 :=
neg_zero ▸ !neg_eq_neg_iff_eq
assert -a = -0 ↔ a = 0, from neg_eq_neg_iff_eq a 0,
by simp
theorem eq_zero_of_neg_eq_zero {a : A} : -a = 0 → a = 0 :=
iff.mp !neg_eq_zero_iff_eq_zero
@ -346,6 +361,7 @@ section add_group
iff.intro !eq_neg_of_eq_neg !eq_neg_of_eq_neg
theorem add.right_inv [simp] (a : A) : a + -a = 0 :=
assert a = -(-a), by simp,
by inst_simp
theorem add_neg_cancel_left [simp] (a b : A) : a + (-a + b) = b :=
@ -389,9 +405,11 @@ section add_group
iff.intro eq_add_neg_of_add_eq add_eq_of_eq_add_neg
theorem add_left_cancel {a b c : A} (H : a + b = a + c) : b = c :=
assert -a + (a + b) = b, by inst_simp,
by inst_simp
theorem add_right_cancel {a b c : A} (H : a + b = c + b) : a = c :=
assert a + b + -b = a, by inst_simp,
by inst_simp
definition add_group.to_left_cancel_semigroup [trans_instance] [reducible] :
@ -424,10 +442,11 @@ section add_group
theorem add_sub_cancel (a b : A) : a + b - b = a := !add_neg_cancel_right
theorem eq_of_sub_eq_zero {a b : A} (H : a - b = 0) : a = b :=
assert -a + 0 = -a, by inst_simp,
by inst_simp
theorem eq_iff_sub_eq_zero (a b : A) : a = b ↔ a - b = 0 :=
iff.intro (assume H, H ▸ !sub_self) (assume H, eq_of_sub_eq_zero H)
iff.intro (assume H, eq.subst H !sub_self) (assume H, eq_of_sub_eq_zero H)
theorem zero_sub (a : A) : 0 - a = -a := !zero_add
@ -440,7 +459,8 @@ section add_group
theorem neg_sub (a b : A) : -(a - b) = b - a :=
neg_eq_of_add_eq_zero (by inst_simp)
theorem add_sub (a b c : A) : a + (b - c) = a + b - c := !add.assoc⁻¹
theorem add_sub (a b c : A) : a + (b - c) = a + b - c :=
by simp
theorem sub_add_eq_sub_sub_swap (a b c : A) : a - (b + c) = a - c - b :=
by inst_simp

View file

@ -69,13 +69,13 @@ rfl
theorem length_append [simp] : ∀ (s t : list T), length (s ++ t) = length s + length t :=
by rec_inst_simp
theorem eq_nil_of_length_eq_zero : ∀ {l : list T}, length l = 0 → l = [] :=
by rec_inst_simp
theorem eq_nil_of_length_eq_zero : ∀ {l : list T}, length l = 0 → l = []
| [] H := rfl
| (a::s) H := by contradiction
theorem ne_nil_of_length_eq_succ : ∀ {l : list T} {n : nat}, length l = succ n → l ≠ [] :=
by rec_inst_simp
-- add_rewrite length_nil length_cons
theorem ne_nil_of_length_eq_succ : ∀ {l : list T} {n : nat}, length l = succ n → l ≠ []
| [] n h := by contradiction
| (a::l) n h := by contradiction
/- concat -/
@ -93,7 +93,7 @@ theorem concat_eq_append [simp] (a : T) : ∀ (l : list T), concat a l = l ++ [a
by rec_inst_simp
theorem concat_ne_nil [simp] (a : T) : ∀ (l : list T), concat a l ≠ [] :=
by rec_inst_simp
by intro l; induction l; repeat contradiction
theorem length_concat [simp] (a : T) : ∀ (l : list T), length (concat a l) = length l + 1 :=
by rec_inst_simp

View file

@ -44,7 +44,10 @@ nat.induction_on x
/- successor and predecessor -/
theorem succ_ne_zero (n : ) : succ n ≠ 0 :=
theorem succ_ne_zero [simp] (n : ) : succ n ≠ 0 :=
by contradiction
theorem add_one_ne_zero [simp] (n : ) : n + 1 ≠ 0 :=
by contradiction
-- add_rewrite succ_ne_zero
@ -141,14 +144,24 @@ protected theorem add_right_comm : Π (n m k : ), n + m + k = n + k + m :=
right_comm nat.add_comm nat.add_assoc
protected theorem add_left_cancel {n m k : } : n + m = n + k → m = k :=
nat.induction_on n (by simp) (by inst_simp)
nat.induction_on n
(by simp)
(take a iH,
-- TODO(Leo): replace with forward reasoning after we add strategies for it.
assert succ (a + m) = succ (a + k) → a + m = a + k, from !succ.inj,
by inst_simp)
protected theorem add_right_cancel {n m k : } (H : n + m = k + m) : n = k :=
have H2 : m + n = m + k, by simp,
nat.add_left_cancel H2
theorem eq_zero_of_add_eq_zero_right {n m : } : n + m = 0 → n = 0 :=
nat.induction_on n (by simp) (by inst_simp)
nat.induction_on n
(by simp)
(take k iH, assume H : succ k + m = 0,
absurd
(show succ (k + m) = 0, by simp)
!succ_ne_zero)
theorem eq_zero_of_add_eq_zero_left {n m : } (H : n + m = 0) : m = 0 :=
eq_zero_of_add_eq_zero_right (!nat.add_comm ⬝ H)

View file

@ -41,7 +41,7 @@ nat.induction_on m (by simp) (by simp)
protected theorem add_sub_cancel_left [simp] (n m : ) : n + m - n = m :=
nat.induction_on n (by simp) (by simp)
protected theorem sub_sub [simp] (n m k : ) : n - m - k = n - (m + k) :=
protected theorem sub_sub [simp] (n m k : ) : (: n - m - k :) = (: n - (m + k) :) :=
nat.induction_on k (by simp) (by simp)
theorem succ_sub_sub_succ [simp] (n m k : ) : succ n - m - succ k = n - m - k :=

View file

@ -69,7 +69,8 @@
"apply" "fapply" "eapply" "rename" "intro" "intros" "all_goals" "fold" "focus" "focus_at"
"generalize" "generalizes" "clear" "clears" "revert" "reverts" "back" "beta" "done" "exact" "rexact"
"refine" "repeat" "whnf" "rotate" "rotate_left" "rotate_right" "inversion" "cases" "rewrite"
"xrewrite" "krewrite" "blast" "simp" "simp_nohyps" "simp_topdown" "esimp" "unfold" "change" "check_expr" "contradiction"
"xrewrite" "krewrite" "blast" "rec_simp" "rec_inst_simp"
"inst_simp" "simp" "simp_nohyps" "simp_topdown" "esimp" "unfold" "change" "check_expr" "contradiction"
"exfalso" "split" "existsi" "constructor" "fconstructor" "left" "right" "injection" "congruence" "reflexivity"
"symmetry" "transitivity" "state" "induction" "induction_using" "fail" "append"
"substvars" "now" "with_options" "with_attributes" "with_attrs" "note")

View file

@ -6,7 +6,6 @@ Author: Leonardo de Moura
*/
#include "library/replace_visitor.h"
#include "library/attribute_manager.h"
#include "library/blast/forward/pattern.h"
#include "frontends/lean/decl_attributes.h"
#include "frontends/lean/parser.h"
#include "frontends/lean/tokens.h"
@ -108,10 +107,6 @@ void decl_attributes::parse(parser & p) {
environment decl_attributes::apply(environment env, io_state const & ios, name const & d, name const & ns) const {
buffer<entry> entries;
to_buffer(m_entries, entries);
if (has_pattern_hints(env.get(d).get_type())) {
// turn on [forward] if patterns hints have been used in the type.
entries.push_back(entry("forward"));
}
unsigned i = entries.size();
while (i > 0) {
--i;

View file

@ -207,7 +207,7 @@ struct ematch_simp_branch_extension : public ematch_branch_extension_core {
get_simp_lemmas(env(), simp_lemmas);
for (name const & n : simp_lemmas) {
try {
m_new_lemmas.insert(mk_hi_lemma(n, get_simp_lemma_priority(env(), n)));
m_new_lemmas.insert(mk_hi_simp_lemma(n, get_simp_lemma_priority(env(), n)));
} catch (exception & ex) {
lean_trace_event_ematch(tout() << "ematcher discarding [simp] '" << n << "', " << ex.what() << "\n";);
}

View file

@ -21,6 +21,7 @@ Author: Leonardo de Moura
#include "library/attribute_manager.h"
#include "library/idx_metavar.h"
#include "library/blast/options.h"
#include "library/blast/trace.h"
#include "library/blast/blast.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/forward_lemmas.h"
@ -288,6 +289,9 @@ struct mk_hi_lemma_fn {
unsigned m_num_uvars;
unsigned m_priority;
unsigned m_max_steps;
/* If m_simp is true, the pattern inference procedure assumes the given lemma is a [simp] lemma.
That is, the conclusion is of the form (t ~ s), and it will try to use t as a pattern. */
bool m_simp;
buffer<expr> m_mvars;
idx_metavar_set m_trackable;
@ -295,9 +299,10 @@ struct mk_hi_lemma_fn {
unsigned m_num_steps;
mk_hi_lemma_fn(tmp_type_context & ctx, expr const & H,
unsigned num_uvars, unsigned prio, unsigned max_steps):
unsigned num_uvars, unsigned prio, unsigned max_steps, bool simp):
m_ctx(ctx), m_no_patterns(no_pattern_ext::get_state(ctx.env())),
m_H(H), m_num_uvars(num_uvars), m_priority(prio), m_max_steps(max_steps) {}
m_H(H), m_num_uvars(num_uvars), m_priority(prio), m_max_steps(max_steps),
m_simp(simp) {}
struct candidate {
expr m_expr;
@ -410,7 +415,14 @@ struct mk_hi_lemma_fn {
candidate_set collect(expr const & a) {
m_candidates = candidate_set();
if (m_simp) {
name R; expr lhs, rhs;
if (is_relation_app(a, R, lhs, rhs)) {
m_candidates.insert(candidate(lhs));
}
} else {
save_candidates(collect_core(a));
}
return m_candidates;
}
@ -571,7 +583,7 @@ struct mk_hi_lemma_fn {
candidate_set B_candidates = collect(B);
if (auto r1 = mk_multi_patterns_using(B_candidates, true)) {
mps = r1;
} else {
} else if (!m_simp) {
candidate_set residue_candidates;
for (expr const & r : residue_locals) {
residue_candidates.merge(collect(m_ctx.infer(r)));
@ -604,15 +616,15 @@ struct mk_hi_lemma_fn {
};
hi_lemma mk_hi_lemma_core(tmp_type_context & ctx, expr const & H, unsigned num_uvars,
unsigned priority, unsigned max_steps) {
unsigned priority, unsigned max_steps, bool simp) {
try {
bool erase_hints = false;
return mk_hi_lemma_fn(ctx, H, num_uvars, priority, max_steps)(erase_hints);
return mk_hi_lemma_fn(ctx, H, num_uvars, priority, max_steps, simp)(erase_hints);
} catch (mk_hi_lemma_fn::try_again_without_hints &) {
ctx.clear();
try {
bool erase_hints = true;
return mk_hi_lemma_fn(ctx, H, num_uvars, priority, max_steps)(erase_hints);
return mk_hi_lemma_fn(ctx, H, num_uvars, priority, max_steps, simp)(erase_hints);
} catch (mk_hi_lemma_fn::try_again_without_hints &) {
lean_unreachable();
}
@ -622,10 +634,11 @@ hi_lemma mk_hi_lemma_core(tmp_type_context & ctx, expr const & H, unsigned num_u
hi_lemma mk_hi_lemma(expr const & H) {
blast_tmp_type_context ctx;
unsigned max_steps = get_config().m_pattern_max_steps;
return mk_hi_lemma_core(*ctx, H, 0, LEAN_DEFAULT_PRIORITY, max_steps);
bool simp = false;
return mk_hi_lemma_core(*ctx, H, 0, LEAN_DEFAULT_PRIORITY, max_steps, simp);
}
hi_lemma mk_hi_lemma(name const & c, unsigned priority) {
hi_lemma mk_hi_lemma(name const & c, unsigned priority, bool simp) {
blast_tmp_type_context ctx;
unsigned max_steps = get_config().m_pattern_max_steps;
declaration const & d = env().get(c);
@ -634,7 +647,17 @@ hi_lemma mk_hi_lemma(name const & c, unsigned priority) {
for (unsigned i = 0; i < num_us; i++)
us.push_back(ctx->mk_uvar());
expr H = mk_constant(c, to_list(us));
return mk_hi_lemma_core(*ctx, H, num_us, priority, max_steps);
return mk_hi_lemma_core(*ctx, H, num_us, priority, max_steps, simp);
}
hi_lemma mk_hi_lemma(name const & c, unsigned priority) {
bool simp = false;
return mk_hi_lemma(c, priority, simp);
}
hi_lemma mk_hi_simp_lemma(name const & c, unsigned priority) {
bool simp = true;
return mk_hi_lemma(c, priority, simp);
}
}

View file

@ -58,6 +58,10 @@ namespace blast {
The maximum number of steps is extracted from the blast config object. */
hi_lemma mk_hi_lemma(expr const & H);
hi_lemma mk_hi_lemma(name const & n, unsigned prio);
/** \brief Similar to \c mk_hi_lemma, but uses a different pattern inference procedure.
It assumes the given lemma has a conclusion of the form t ~ s, and uses \c t as the pattern.
\remark This procedure is used to automatically convert [simp] lemmas into [forward] lemmas. */
hi_lemma mk_hi_simp_lemma(name const & n, unsigned prio);
}
void initialize_pattern();

View file

@ -4,8 +4,8 @@ import logic data.nat.basic
open nat eq.ops algebra
theorem tst (a b c : nat) : a + b + c = a + c + b :=
calc a + b + c = a + (b + c) : !add.assoc
... = a + (c + b) : {!add.comm}
... = a + c + b : (!add.assoc)⁻¹
calc a + b + c = a + (b + c) : by rewrite add.assoc
... = a + (c + b) : by rewrite (add.comm b c)
... = a + c + b : by rewrite add.assoc
WAIT
INFO 7

View file

@ -35,29 +35,15 @@ c
b
-- ACK
-- SYMBOL|7|31
(
by
-- ACK
-- SYMBOL|7|32
!
-- SYMBOL|7|34
rewrite
-- ACK
-- TYPE|7|33
a + c + b = a + (c + b)
-- TYPE|7|42
∀ (a_1 b_1 c_1 : ?A), (:a_1 + b_1 + c_1:) = (:a_1 + (b_1 + c_1):)
-- ACK
-- IDENTIFIER|7|33
-- IDENTIFIER|7|42
add.assoc
-- ACK
-- TYPE|7|43
a + c + b = a + (c + b) → a + (c + b) = a + c + b
-- ACK
-- OVERLOAD|7|43
eq.symm #0
--
inv #0
-- ACK
-- SYMBOL|7|43
⁻¹
-- ACK
-- IDENTIFIER|7|43
eq.symm
-- ACK
-- ENDINFO

View file

@ -1,5 +1,5 @@
constants {A : Type.{1}} (P : A → Prop) (Q : A → Prop)
definition H : ∀ a, (: P a :) → Exists Q := sorry
definition H [forward] : ∀ a, (: P a :) → Exists Q := sorry
set_option blast.strategy "ematch"