diff --git a/src/library/blast/hypothesis.h b/src/library/blast/hypothesis.h index f1b35b566..03867390d 100644 --- a/src/library/blast/hypothesis.h +++ b/src/library/blast/hypothesis.h @@ -47,4 +47,29 @@ public: bool depends_on(expr const & h) const { return m_deps.contains(href_index(h)); } bool is_assumption() const { return !m_value || is_local_non_href(*m_value); } }; + +class hypothesis_idx_buffer_set { + friend class state; + hypothesis_idx_buffer m_buffer; + hypothesis_idx_set m_set; +public: + hypothesis_idx_buffer_set() {} + hypothesis_idx_buffer_set(hypothesis_idx_buffer const & b) { + for (auto hidx : b) + insert(hidx); + } + + void insert(hypothesis_idx h) { + if (!m_set.contains(h)) { + m_set.insert(h); + m_buffer.push_back(h); + } + } + hypothesis_idx_buffer const & as_buffer() const { + return m_buffer; + } + hypothesis_idx_set const & as_set() const { + return m_set; + } +}; }} diff --git a/src/library/blast/revert.cpp b/src/library/blast/revert.cpp index 7444a5acf..db2954c40 100644 --- a/src/library/blast/revert.cpp +++ b/src/library/blast/revert.cpp @@ -23,26 +23,25 @@ struct revert_proof_step_cell : public proof_step_cell { virtual bool is_silent() const override { return true; } }; -unsigned revert_action(buffer & hidxs, hypothesis_idx_set & hidxs_set) { - lean_assert(hidxs.size() == hidxs_set.size()); +unsigned revert_action(hypothesis_idx_buffer_set & hidxs) { state & s = curr_state(); - unsigned hidxs_size = hidxs.size(); + unsigned hidxs_size = hidxs.as_buffer().size(); for (unsigned i = 0; i < hidxs_size; i++) { - s.collect_forward_deps(hidxs[i], hidxs, hidxs_set); + s.collect_forward_deps(hidxs.as_buffer()[i], hidxs); } s.sort_hypotheses(hidxs); buffer hs; - s.to_hrefs(hidxs, hs); + s.to_hrefs(hidxs.as_buffer(), hs); expr target = s.get_target(); expr new_target = s.mk_pi(hs, target); s.set_target(new_target); s.push_proof_step(new revert_proof_step_cell(to_list(hs))); - lean_verify(s.del_hypotheses(hidxs)); - return hidxs.size(); + lean_verify(s.del_hypotheses(hidxs.as_buffer())); + return hidxs.as_buffer().size(); } unsigned revert_action(buffer & hidxs) { - hypothesis_idx_set hidxs_set(hidxs); - return revert_action(hidxs, hidxs_set); + hypothesis_idx_buffer_set _hidxs(hidxs); + return revert_action(_hidxs); } }} diff --git a/src/library/blast/revert.h b/src/library/blast/revert.h index 5440f09d4..bf778019a 100644 --- a/src/library/blast/revert.h +++ b/src/library/blast/revert.h @@ -12,8 +12,5 @@ namespace blast { /** \brief Revert the given hypotheses and their dependencies. Return the total number of hypotheses reverted. */ unsigned revert_action(buffer & hidxs); - -/** \brief Lower-level version of previous procedure. - \pre hidxs and hidxs_set contain the same elements. */ -unsigned revert_action(buffer & hidxs, hypothesis_idx_set & hidxs_set); +unsigned revert_action(hypothesis_idx_buffer_set & hidxs); }} diff --git a/src/library/blast/state.cpp b/src/library/blast/state.cpp index 15f9f4d44..8700c1d5c 100644 --- a/src/library/blast/state.cpp +++ b/src/library/blast/state.cpp @@ -363,6 +363,10 @@ void state::sort_hypotheses(hypothesis_idx_buffer & r) const { std::sort(r.begin(), r.end(), hypothesis_dep_depth_lt(*this)); } +void state::sort_hypotheses(hypothesis_idx_buffer_set & r) const { + std::sort(r.m_buffer.begin(), r.m_buffer.end(), hypothesis_dep_depth_lt(*this)); +} + void state::to_hrefs(hypothesis_idx_buffer const & hidxs, buffer & r) const { for (hypothesis_idx hidx : hidxs) r.push_back(get_hypothesis_decl(hidx)->get_self()); @@ -485,19 +489,15 @@ void state::del_hypotheses(buffer const & to_delete, hypothesis_ } } -void state::collect_forward_deps(hypothesis_idx hidx, buffer & result, hypothesis_idx_set & already_found) { - unsigned qhead = result.size(); +void state::collect_forward_deps(hypothesis_idx hidx, hypothesis_idx_buffer_set & result) { + hypothesis_idx_buffer const & b = result.as_buffer(); + unsigned qhead = b.size(); while (true) { - hypothesis_idx_set s = get_forward_deps(hidx); - s.for_each([&](hypothesis_idx h_dep) { - if (already_found.contains(h_dep)) - return; - already_found.insert(h_dep); - result.push_back(h_dep); - }); - if (qhead == result.size()) + hypothesis_idx_set s = get_direct_forward_deps(hidx); + s.for_each([&](hypothesis_idx h_dep) { result.insert(h_dep); }); + if (qhead == b.size()) return; - hidx = result[qhead]; + hidx = b[qhead]; qhead++; } } @@ -514,38 +514,29 @@ bool state::safe_to_delete(buffer const & to_delete) { return true; } -void state::collect_forward_deps(hypothesis_idx hidx, buffer & result) { - hypothesis_idx_set found; - collect_forward_deps(hidx, result, found); -} - bool state::del_hypotheses(buffer const & hs) { - hypothesis_idx_set found; - buffer to_delete; + hypothesis_idx_buffer_set to_delete; for (hypothesis_idx hidx : hs) { - to_delete.push_back(hidx); - found.insert(hidx); - collect_forward_deps(hidx, to_delete, found); + to_delete.insert(hidx); + collect_forward_deps(hidx, to_delete); } - if (!safe_to_delete(to_delete)) + if (!safe_to_delete(to_delete.as_buffer())) return false; - del_hypotheses(to_delete, found); + del_hypotheses(to_delete.as_buffer(), to_delete.as_set()); return true; } bool state::del_hypothesis(hypothesis_idx hidx) { - hypothesis_idx_set found; - buffer to_delete; - to_delete.push_back(hidx); - found.insert(hidx); - collect_forward_deps(hidx, to_delete, found); - if (!safe_to_delete(to_delete)) + hypothesis_idx_buffer_set to_delete; + to_delete.insert(hidx); + collect_forward_deps(hidx, to_delete); + if (!safe_to_delete(to_delete.as_buffer())) return false; - del_hypotheses(to_delete, found); + del_hypotheses(to_delete.as_buffer(), to_delete.as_set()); return true; } -hypothesis_idx_set state::get_forward_deps(hypothesis_idx hidx) const { +hypothesis_idx_set state::get_direct_forward_deps(hypothesis_idx hidx) const { if (auto r = m_branch.m_forward_deps.find(hidx)) return *r; else diff --git a/src/library/blast/state.h b/src/library/blast/state.h index 59ff065fb..5c89864e4 100644 --- a/src/library/blast/state.h +++ b/src/library/blast/state.h @@ -219,9 +219,22 @@ public: bool del_hypothesis(hypothesis_idx hidx); bool del_hypotheses(buffer const & hs); + /** \brief Return the set of hypotheses that (directly) depend on the given one */ + hypothesis_idx_set get_direct_forward_deps(hypothesis_idx hidx) const; + /** \brief Collect in \c result the hypotheses that (directly) depend on \c hidx and satisfy \c pred. */ + template + void collect_direct_forward_deps(hypothesis_idx hidx, hypothesis_idx_buffer_set & result, P && pred) { + get_direct_forward_deps(hidx).for_each([&](hypothesis_idx d) { + if (pred(d)) result.insert(d); + }); + } + /** \brief Collect in \c result the hypotheses that (directly) depend on \c hidx and satisfy \c pred. */ + void collect_direct_forward_deps(hypothesis_idx hidx, hypothesis_idx_buffer_set & result) { + return collect_direct_forward_deps(hidx, result, [](hypothesis_idx) { return true; }); + } + /** \brief Collect all hypothesis in \c result that depend directly or indirectly on hidx */ - void collect_forward_deps(hypothesis_idx hidx, buffer & result); - void collect_forward_deps(hypothesis_idx hidx, buffer & result, hypothesis_idx_set & already_found); + void collect_forward_deps(hypothesis_idx hidx, hypothesis_idx_buffer_set & result); /** \brief Return true iff the hypothesis with index \c hidx_user depends on the hypothesis with index \c hidx_provider. */ @@ -244,6 +257,7 @@ public: void get_sorted_hypotheses(hypothesis_idx_buffer & r) const; /** \brief Sort hypotheses in r */ void sort_hypotheses(hypothesis_idx_buffer & r) const; + void sort_hypotheses(hypothesis_idx_buffer_set & r) const; /** \brief Convert hypotheses indices into hrefs */ void to_hrefs(hypothesis_idx_buffer const & hidxs, buffer & r) const; @@ -261,13 +275,6 @@ public: /** \brief Return (active) hypotheses whose head symbol is equal to target or it is the negation of */ list get_head_related() const; - /** \brief Return the set of hypotheses that (directly) depend on the given one */ - hypothesis_idx_set get_forward_deps(hypothesis_idx hidx) const; - template - void for_each_forward_dep(hypothesis_idx hidx, F && f) const { - get_forward_deps(hidx).for_each(f); - } - /************************ Abstracting hypotheses *************************/ diff --git a/src/library/blast/subst.cpp b/src/library/blast/subst.cpp index dc8d0edff..f829f8b14 100644 --- a/src/library/blast/subst.cpp +++ b/src/library/blast/subst.cpp @@ -52,13 +52,11 @@ bool subst_core(hypothesis_idx hidx) { lean_verify(is_eq(type, lhs, rhs)); lean_assert(is_href(rhs)); try { - hypothesis_idx_buffer to_revert; - s.for_each_forward_dep(href_index(rhs), - [&](hypothesis_idx d) { - if (d != hidx) to_revert.push_back(d); - }); - s.for_each_forward_dep(hidx, - [&](hypothesis_idx d) { to_revert.push_back(d); }); + hypothesis_idx_buffer_set to_revert; + s.collect_direct_forward_deps(href_index(rhs), + to_revert, + [&](hypothesis_idx d) { return d != hidx; }); + s.collect_direct_forward_deps(hidx, to_revert); unsigned num = revert_action(to_revert); expr target = s.get_target(); expr new_target = abstract(target, h->get_self()); @@ -90,7 +88,7 @@ bool subst_action(hypothesis_idx hidx) { if (is_href(rhs)) { return subst_core(hidx); } else if (is_href(lhs)) { - if (!s.get_forward_deps(href_index(lhs)).empty()) { + if (!s.get_direct_forward_deps(href_index(lhs)).empty()) { // TODO(Leo): we don't handle this case yet. // Other hypotheses depend on this equality. return false; diff --git a/tests/lean/run/blast6.lean b/tests/lean/run/blast6.lean new file mode 100644 index 000000000..432c28ab7 --- /dev/null +++ b/tests/lean/run/blast6.lean @@ -0,0 +1,7 @@ +set_option blast.init_depth 10 + +lemma lemma1 (bv : nat → Type) (n m : nat) (H : n = m) (b1 : bv n) (b2 : bv m) (H2 : eq.rec_on H b1 = b2) : b1 = eq.rec_on (eq.symm H) b2 := +by blast + +reveal lemma1 +print lemma1