diff --git a/src/library/blast/state.cpp b/src/library/blast/state.cpp index 231c47acc..5d848db3e 100644 --- a/src/library/blast/state.cpp +++ b/src/library/blast/state.cpp @@ -29,6 +29,115 @@ bool metavar_decl::restrict_context_using(metavar_decl const & other) { return !to_erase.empty(); } +class extension_manager { + std::vector m_exts; +public: + ~extension_manager() { + for (auto ext : m_exts) + ext->dec_ref(); + m_exts.clear(); + } + + unsigned register_extension(branch_extension * ext) { + ext->inc_ref(); + unsigned r = m_exts.size(); + m_exts.push_back(ext); + return r; + } + + bool has_ext(unsigned extid) const { + return extid < m_exts.size(); + } + + branch_extension * get_initial(unsigned extid) { + return m_exts[extid]; + } + + unsigned get_num_extensions() const { + return m_exts.size(); + } +}; + +static extension_manager * g_extension_manager = nullptr; + +static extension_manager & get_extension_manager() { + return *g_extension_manager; +} + +branch::branch() { + unsigned n = get_extension_manager().get_num_extensions(); + m_extensions = new branch_extension*[n]; + for (unsigned i = 0; i < n; i++) + m_extensions[i] = nullptr; +} + +branch::~branch() { + unsigned n = get_extension_manager().get_num_extensions(); + for (unsigned i = 0; i < n; i++) { + if (m_extensions[i]) + m_extensions[i]->dec_ref(); + } + delete m_extensions; +} + +branch::branch(branch const & b): + m_hyp_decls(b.m_hyp_decls), + m_assumption(b.m_assumption), + m_active(b.m_active), + m_todo_queue(b.m_todo_queue), + m_rec_queue(b.m_rec_queue), + m_head_to_hyps(b.m_head_to_hyps), + m_forward_deps(b.m_forward_deps), + m_target(b.m_target), + m_target_deps(b.m_target_deps), + m_simp_rule_sets(b.m_simp_rule_sets) { + unsigned n = get_extension_manager().get_num_extensions(); + m_extensions = new branch_extension*[n]; + for (unsigned i = 0; i < n; i++) { + m_extensions[i] = b.m_extensions[i]; + if (m_extensions[i]) + m_extensions[i]->inc_ref(); + } +} + +branch::branch(branch && b): + m_hyp_decls(std::move(b.m_hyp_decls)), + m_assumption(std::move(b.m_assumption)), + m_active(std::move(b.m_active)), + m_todo_queue(std::move(b.m_todo_queue)), + m_rec_queue(std::move(b.m_rec_queue)), + m_head_to_hyps(std::move(b.m_head_to_hyps)), + m_forward_deps(std::move(b.m_forward_deps)), + m_target(std::move(b.m_target)), + m_target_deps(std::move(b.m_target_deps)), + m_simp_rule_sets(std::move(b.m_simp_rule_sets)) { + unsigned n = get_extension_manager().get_num_extensions(); + m_extensions = new branch_extension*[n]; + for (unsigned i = 0; i < n; i++) { + m_extensions[i] = b.m_extensions[i]; + b.m_extensions[i] = nullptr; + } +} + +void branch::swap(branch & b) { + std::swap(m_hyp_decls, b.m_hyp_decls); + std::swap(m_assumption, b.m_assumption); + std::swap(m_active, b.m_active); + std::swap(m_todo_queue, b.m_todo_queue); + std::swap(m_rec_queue, b.m_rec_queue); + std::swap(m_head_to_hyps, b.m_head_to_hyps); + std::swap(m_forward_deps, b.m_forward_deps); + std::swap(m_target, b.m_target); + std::swap(m_target_deps, b.m_target_deps); + std::swap(m_simp_rule_sets, b.m_simp_rule_sets); + std::swap(m_extensions, b.m_extensions); +} + +branch & branch::operator=(branch s) { + swap(s); + return *this; +} + state::state() {} expr state::mk_metavar(hypothesis_idx_set const & c, expr const & type) { @@ -581,16 +690,59 @@ list state::get_head_related() const { return list(); } +branch_extension * state::get_extension_core(unsigned i) { + branch_extension * ext = m_branch.m_extensions[i]; + if (ext && ext->get_rc() > 1) { + branch_extension * new_ext = ext->clone(); + new_ext->inc_ref(); + ext->dec_ref(); + m_branch.m_extensions[i] = new_ext; + return new_ext; + } + return ext; +} + +branch_extension & state::get_extension(unsigned extid) { + lean_assert(extid < get_extension_manager().get_num_extensions()); + if (!m_branch.m_extensions[extid]) { + /* lazy initialization */ + branch_extension * ext = get_extension_manager().get_initial(extid)->clone();; + ext->inc_ref(); + m_branch.m_extensions[extid] = ext; + lean_assert(ext.get_rc() == 1); + m_branch.m_active.for_each([&](hypothesis_idx hidx) { + hypothesis const * h = get_hypothesis_decl(hidx); + lean_assert(h); + ext->hypothesis_activated(*h, hidx); + }); + return *ext; + } else { + branch_extension * ext = get_extension_core(extid); + lean_assert(ext); + return *ext; + } +} + void state::update_indices(hypothesis_idx hidx) { hypothesis const * h = get_hypothesis_decl(hidx); lean_assert(h); /* update m_head_to_hyps */ if (auto i = to_head_index(*h)) m_branch.m_head_to_hyps.insert(*i, hidx); + unsigned n = get_extension_manager().get_num_extensions(); + for (unsigned i = 0; i < n; i++) { + branch_extension * ext = get_extension_core(i); + if (ext) ext->hypothesis_activated(*h, hidx); + } /* TODO(Leo): update congruence closure indices */ } void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) { + unsigned n = get_extension_manager().get_num_extensions(); + for (unsigned i = 0; i < n; i++) { + branch_extension * ext = get_extension_core(i); + if (ext) ext->hypothesis_deleted(h, hidx); + } if (auto i = to_head_index(h)) m_branch.m_head_to_hyps.erase(*i, hidx); } @@ -710,12 +862,14 @@ expr state::mk_pi(list const & hrefs, expr const & b) const { } void initialize_state() { - g_prefix = new name(name::mk_internal_unique_name()); - g_H = new name("H"); + g_extension_manager = new extension_manager(); + g_prefix = new name(name::mk_internal_unique_name()); + g_H = new name("H"); } void finalize_state() { delete g_prefix; delete g_H; + delete g_extension_manager; } }} diff --git a/src/library/blast/state.h b/src/library/blast/state.h index ae2d2d289..87ec9753f 100644 --- a/src/library/blast/state.h +++ b/src/library/blast/state.h @@ -96,10 +96,25 @@ public: } }; +/** \brief Actions that require additional indexing data-structures may store them + at a branch_extension */ +class branch_extension { + MK_LEAN_RC(); + void dealloc() { delete this; } +public: + virtual ~branch_extension() {} + virtual branch_extension * clone() = 0; + virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) = 0; + virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx hidx) = 0; +}; + +unsigned register_branch_extension(branch_extension * initial); + /** \brief Information associated with the current branch of the proof state. This is essentially a mechanism for creating snapshots of the current branch. */ class branch { friend class state; + typedef hypothesis_idx_map forward_deps; /* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */ struct inv_double_cmp { @@ -108,6 +123,7 @@ class branch { typedef rb_map priority_queue; typedef priority_queue todo_queue; typedef priority_queue rec_queue; + // Hypothesis/facts in the current state hypothesis_decls m_hyp_decls; // We break the set of hypotheses in m_hyp_decls in 4 sets that are not necessarily disjoint: @@ -134,6 +150,14 @@ class branch { expr m_target; hypothesis_idx_set m_target_deps; simp_rule_sets m_simp_rule_sets; + branch_extension ** m_extensions; +public: + branch(); + branch(branch const & b); + branch(branch && b); + ~branch(); + void swap(branch & b); + branch & operator=(branch s); }; /** \brief Proof state for the blast tactic */ @@ -177,6 +201,8 @@ class state { void display_active(output_channel & out) const; + branch_extension * get_extension_core(unsigned i); + #ifdef LEAN_DEBUG bool check_hypothesis(expr const & e, hypothesis_idx hidx, hypothesis const & h) const; bool check_hypothesis(hypothesis_idx hidx, hypothesis const & h) const; @@ -425,6 +451,12 @@ public: return m_branch.m_simp_rule_sets; } + /************************ + Branch extensions + *************************/ + + branch_extension & get_extension(unsigned extid); + /************************ Debugging support *************************/