fix(library/blast/unit/unit_propagate): memory access violation

This commit is contained in:
Leonardo de Moura 2016-01-27 15:22:34 -08:00
parent fb95b71a5e
commit 684995640a
2 changed files with 68 additions and 15 deletions

View file

@ -86,9 +86,9 @@ struct unit_branch_extension : public branch_extension {
} }
virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx hidx) override { virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx hidx) override {
if (is_lemma(h.get_type())) { if (is_lemma(h.get_type())) {
list<expr> const * facts = find_facts_watching_lemma(hidx); list<expr> facts = find_facts_watching_lemma(hidx);
if (facts) { if (facts) {
for_each(*facts, [&](expr const & fact) { for_each(facts, [&](expr const & fact) {
unwatch(hidx, fact); unwatch(hidx, fact);
}); });
} }
@ -106,11 +106,11 @@ struct unit_branch_extension : public branch_extension {
} }
public: public:
list<hypothesis_idx> const * find_lemmas_watching_fact(expr const & fact_type) { list<hypothesis_idx> find_lemmas_watching_fact(expr const & fact_type) {
return m_facts_to_lemmas.find(fact_type); return ptr_to_list(m_facts_to_lemmas.find(fact_type));
} }
list<expr> const * find_facts_watching_lemma(hypothesis_idx lemma_hidx) { list<expr> find_facts_watching_lemma(hypothesis_idx lemma_hidx) {
return m_lemmas_to_facts.find(lemma_hidx); return ptr_to_list(m_lemmas_to_facts.find(lemma_hidx));
} }
void unwatch(hypothesis_idx lemma_hidx, expr const & fact_type) { void unwatch(hypothesis_idx lemma_hidx, expr const & fact_type) {
m_lemmas_to_facts.filter(lemma_hidx, [&](expr const & fact_type2) { m_lemmas_to_facts.filter(lemma_hidx, [&](expr const & fact_type2) {
@ -125,8 +125,8 @@ public:
m_lemmas_to_facts.insert(lemma_hidx, fact_type); m_lemmas_to_facts.insert(lemma_hidx, fact_type);
m_facts_to_lemmas.insert(fact_type, lemma_hidx); m_facts_to_lemmas.insert(fact_type, lemma_hidx);
} }
list<hypothesis_idx> const * find_dep_lemmas_watching_fact(expr const & fact_type) { list<hypothesis_idx> find_dep_lemmas_watching_fact(expr const & fact_type) {
return m_facts_to_dep_lemmas.find(fact_type); return ptr_to_list(m_facts_to_dep_lemmas.find(fact_type));
} }
}; };
@ -210,10 +210,10 @@ static action_result unit_lemma(hypothesis_idx hidx, expr const & _type, expr co
unit_branch_extension & ext = get_extension(); unit_branch_extension & ext = get_extension();
/* (1) Find the facts that are watching this lemma and clear them. */ /* (1) Find the facts that are watching this lemma and clear them. */
list<expr> const * watching = ext.find_facts_watching_lemma(hidx); list<expr> watching = ext.find_facts_watching_lemma(hidx);
if (watching) { if (watching) {
lean_assert(length(*watching) == 2); lean_assert(length(watching) == 2);
for_each(*watching, [&](expr const & fact) { ext.unwatch(hidx, fact); }); for_each(watching, [&](expr const & fact) { ext.unwatch(hidx, fact); });
} }
/* (2) Check if we can propagate */ /* (2) Check if we can propagate */
@ -328,8 +328,8 @@ static action_result unit_fact(expr const & type) {
unit_branch_extension & ext = get_extension(); unit_branch_extension & ext = get_extension();
bool success = false; bool success = false;
/* non dependent lemmas */ /* non dependent lemmas */
if (list<hypothesis_idx> const * lemmas = ext.find_lemmas_watching_fact(type)) { if (list<hypothesis_idx> lemmas = ext.find_lemmas_watching_fact(type)) {
for_each(*lemmas, [&](hypothesis_idx const & hidx) { for_each(lemmas, [&](hypothesis_idx const & hidx) {
hypothesis const & h = curr_state().get_hypothesis_decl(hidx); hypothesis const & h = curr_state().get_hypothesis_decl(hidx);
// TODO(Leo): it is not clear to me why we need whnf in the following statement. // TODO(Leo): it is not clear to me why we need whnf in the following statement.
action_result r = unit_lemma(hidx, whnf(h.get_type()), h.get_self()); action_result r = unit_lemma(hidx, whnf(h.get_type()), h.get_self());
@ -337,8 +337,8 @@ static action_result unit_fact(expr const & type) {
}); });
} }
/* dependent lemmas */ /* dependent lemmas */
if (list<hypothesis_idx> const * lemmas = ext.find_dep_lemmas_watching_fact(type)) { if (list<hypothesis_idx> lemmas = ext.find_dep_lemmas_watching_fact(type)) {
for_each(*lemmas, [&](hypothesis_idx const & hidx) { for_each(lemmas, [&](hypothesis_idx const & hidx) {
hypothesis const & h = curr_state().get_hypothesis_decl(hidx); hypothesis const & h = curr_state().get_hypothesis_decl(hidx);
action_result r = unit_dep_lemma(hidx, whnf(h.get_type()), h.get_self()); action_result r = unit_dep_lemma(hidx, whnf(h.get_type()), h.get_self());
success = success || (r.get_kind() == action_result::NewBranch); success = success || (r.get_kind() == action_result::NewBranch);

View file

@ -0,0 +1,53 @@
import data.real
open real
namespace safe
definition pos (x : ) := x > 0
definition nzero (x : ) := x ≠ 0
constants (exp : )
constants (safe_log : Π (x : ), pos x → )
constants (safe_inv : Π (x : ), nzero x → )
notation `log`:max x:max := (@safe_log x (by grind))
notation [priority 100000] x:max ⁻¹ := (@safe_inv x (by grind))
definition sq (x : ) := x * x
notation a `²` := sq a
lemma nzero_of_pos [intro!] {x : } : pos x → nzero x := sorry
lemma pos_bit1 [intro!] {x : } : pos x → pos (: bit1 x :) := sorry
lemma pos_bit0 [intro!] {x : } : pos x → pos (: bit0 x :) := sorry
lemma pos_inv [forward] {x : } : pos x → (: pos (inv x) :) := sorry
lemma pos_1 [intro!] : pos (1:) := sorry
lemma pos_add [intro!] [forward] {x y : } : pos x → pos y → pos (: x + y :) := sorry
lemma pos_mul [intro!] [forward] {x y : } : pos x → pos y → pos (: x * y :) := sorry
lemma pos_sq [intro!] {x : } : pos x → pos (: sq x :) := sorry
lemma inv_pos [intro!] {x : } : pos x → pos (: x⁻¹ :) := sorry
lemma exp_pos [intro!] (x : ) : pos (: exp x :) := sorry
lemma log_mul [forward] : ∀ (x y : ), pos x → pos y → (: log (x * y) :) = (: log x + log y :) := sorry
lemma log_sq [forward] : ∀ (x : ), pos x → (: log (sq x) :) = (: 2 * log x :) := sorry
lemma log_inv [forward] : ∀ (x : ), pos x → (: log (x⁻¹) :) = (: - log x :) := sorry
lemma inv_mul [forward] : ∀ (x y : ), pos x → pos y → (: (x * y)⁻¹ :) = (: x⁻¹ * y⁻¹ :) := sorry
lemma exp_add [forward] : ∀ (x y : ), (: exp (x + y) :) = (: exp x * exp y :) := sorry
lemma pair_prod [forward] : ∀ (x y : ), (: sq x + 2 * x * y + sq y :) = sq (x + y) := sorry
lemma mul_two_sum [forward] : ∀ (x : ), (: 2 * x :) = (: x + x :) := sorry
lemma sub_def [forward] : ∀ (x y : ), (x - y) = x + -y := sorry
lemma mul_div_cancel [forward] : ∀ (x y : ), (y * x) / y = x := sorry
lemma div_neg [forward] : ∀ (x y : ), x / -y = - (x / y) := sorry
attribute right_distrib [forward]
attribute left_distrib [forward]
set_option blast.strategy "ematch"
example (x y z w : ) : pos x → pos y → pos z → pos w → x * y = w + exp z →
log ((w² + 2 * w * exp z + (exp z)²)) / -2 = log (x⁻¹) - log y :=
by blast
end safe