diff --git a/src/library/blast/hypothesis.h b/src/library/blast/hypothesis.h index 336569c84..d14336291 100644 --- a/src/library/blast/hypothesis.h +++ b/src/library/blast/hypothesis.h @@ -22,6 +22,13 @@ typedef buffer hypothesis_idx_buffer; template using hypothesis_idx_map = typename lean::rb_map; +/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */ +struct inv_double_cmp { + int operator()(double const & d1, double const & d2) const { return d1 > d2 ? -1 : (d1 < d2 ? 1 : 0); } +}; + +typedef rb_map hypothesis_priority_queue; + class hypothesis { friend class state; name m_name; // for pretty printing diff --git a/src/library/blast/init_module.cpp b/src/library/blast/init_module.cpp index 4513f23c2..2605a072d 100644 --- a/src/library/blast/init_module.cpp +++ b/src/library/blast/init_module.cpp @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "library/blast/blast_tactic.h" #include "library/blast/simplifier.h" #include "library/blast/options.h" +#include "library/blast/recursor_action.h" namespace lean { void initialize_blast_module() { @@ -19,8 +20,10 @@ void initialize_blast_module() { initialize_blast(); blast::initialize_simplifier(); initialize_blast_tactic(); + blast::initialize_recursor_action(); } void finalize_blast_module() { + blast::finalize_recursor_action(); finalize_blast_tactic(); blast::finalize_simplifier(); finalize_blast(); diff --git a/src/library/blast/recursor_action.cpp b/src/library/blast/recursor_action.cpp index 21956d33b..385a0a7e7 100644 --- a/src/library/blast/recursor_action.cpp +++ b/src/library/blast/recursor_action.cpp @@ -13,6 +13,28 @@ Author: Leonardo de Moura namespace lean { namespace blast { +static unsigned g_ext_id = 0; +struct recursor_branch_extension : public branch_extension { + hypothesis_priority_queue m_rec_queue; + recursor_branch_extension() {} + recursor_branch_extension(recursor_branch_extension const & b):m_rec_queue(b.m_rec_queue) {} + virtual ~recursor_branch_extension() {} + virtual branch_extension * clone() { return new recursor_branch_extension(*this); } + virtual void hypothesis_activated(hypothesis const &, hypothesis_idx) {} + virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) {} +}; + +void initialize_recursor_action() { + g_ext_id = register_branch_extension(new recursor_branch_extension()); +} + +void finalize_recursor_action() { +} + +static recursor_branch_extension & get_extension() { + return static_cast(curr_state().get_extension(g_ext_id)); +} + optional is_recursor_action_target(hypothesis_idx hidx) { state & s = curr_state(); hypothesis const * h = s.get_hypothesis_decl(hidx); @@ -312,7 +334,8 @@ action_result recursor_preprocess_action(hypothesis_idx hidx) { if (!is_recursive_recursor(*R)) { // TODO(Leo): we need a better strategy for handling recursive recursors... w += static_cast(num_minor); - curr_state().add_to_rec_queue(hidx, w); + recursor_branch_extension & ext = get_extension(); + ext.m_rec_queue.insert(w, hidx); return action_result::new_branch(); } } @@ -321,11 +344,17 @@ action_result recursor_preprocess_action(hypothesis_idx hidx) { } action_result recursor_action() { - while (auto hidx = curr_state().select_rec_hypothesis()) { - if (optional R = is_recursor_action_target(*hidx)) { - Try(recursor_action(*hidx, *R)); + recursor_branch_extension & ext = get_extension(); + while (true) { + if (ext.m_rec_queue.empty()) + return action_result::failed(); + unsigned hidx = ext.m_rec_queue.erase_min(); + hypothesis const * h_decl = curr_state().get_hypothesis_decl(hidx); + if (h_decl->is_dead()) + continue; + if (optional R = is_recursor_action_target(hidx)) { + Try(recursor_action(hidx, *R)); } } - return action_result::failed(); } }} diff --git a/src/library/blast/recursor_action.h b/src/library/blast/recursor_action.h index f10f84dd4..f0764f3fc 100644 --- a/src/library/blast/recursor_action.h +++ b/src/library/blast/recursor_action.h @@ -11,4 +11,7 @@ namespace lean { namespace blast { action_result recursor_preprocess_action(hypothesis_idx hidx); action_result recursor_action(); + +void initialize_recursor_action(); +void finalize_recursor_action(); }} diff --git a/src/library/blast/state.cpp b/src/library/blast/state.cpp index 5d848db3e..5a09a2b40 100644 --- a/src/library/blast/state.cpp +++ b/src/library/blast/state.cpp @@ -64,6 +64,10 @@ static extension_manager & get_extension_manager() { return *g_extension_manager; } +unsigned register_branch_extension(branch_extension * initial) { + return get_extension_manager().register_extension(initial); +} + branch::branch() { unsigned n = get_extension_manager().get_num_extensions(); m_extensions = new branch_extension*[n]; @@ -85,7 +89,6 @@ branch::branch(branch const & b): m_assumption(b.m_assumption), m_active(b.m_active), m_todo_queue(b.m_todo_queue), - m_rec_queue(b.m_rec_queue), m_head_to_hyps(b.m_head_to_hyps), m_forward_deps(b.m_forward_deps), m_target(b.m_target), @@ -105,7 +108,6 @@ branch::branch(branch && b): m_assumption(std::move(b.m_assumption)), m_active(std::move(b.m_active)), m_todo_queue(std::move(b.m_todo_queue)), - m_rec_queue(std::move(b.m_rec_queue)), m_head_to_hyps(std::move(b.m_head_to_hyps)), m_forward_deps(std::move(b.m_forward_deps)), m_target(std::move(b.m_target)), @@ -124,7 +126,6 @@ void branch::swap(branch & b) { std::swap(m_assumption, b.m_assumption); std::swap(m_active, b.m_active); std::swap(m_todo_queue, b.m_todo_queue); - std::swap(m_rec_queue, b.m_rec_queue); std::swap(m_head_to_hyps, b.m_head_to_hyps); std::swap(m_forward_deps, b.m_forward_deps); std::swap(m_target, b.m_target); @@ -761,21 +762,6 @@ optional state::activate_hypothesis() { } } -optional state::select_rec_hypothesis() { - while (true) { - if (m_branch.m_rec_queue.empty()) - return optional(); - unsigned hidx = m_branch.m_rec_queue.erase_min(); - hypothesis const * h_decl = get_hypothesis_decl(hidx); - if (!h_decl->is_dead()) - return optional(hidx); - } -} - -void state::add_to_rec_queue(hypothesis_idx hidx, double w) { - m_branch.m_rec_queue.insert(w, hidx); -} - bool state::hidx_depends_on(unsigned hidx_user, unsigned hidx_provider) const { if (auto s = m_branch.m_forward_deps.find(hidx_provider)) { return s->contains(hidx_user); diff --git a/src/library/blast/state.h b/src/library/blast/state.h index 87ec9753f..8bae1a751 100644 --- a/src/library/blast/state.h +++ b/src/library/blast/state.h @@ -116,13 +116,7 @@ class branch { friend class state; typedef hypothesis_idx_map forward_deps; - /* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */ - struct inv_double_cmp { - int operator()(double const & d1, double const & d2) const { return d1 > d2 ? -1 : (d1 < d2 ? 1 : 0); } - }; - typedef rb_map priority_queue; - typedef priority_queue todo_queue; - typedef priority_queue rec_queue; + typedef hypothesis_priority_queue todo_queue; // Hypothesis/facts in the current state hypothesis_decls m_hyp_decls; @@ -144,7 +138,6 @@ class branch { hypothesis_idx_set m_assumption; hypothesis_idx_set m_active; todo_queue m_todo_queue; - rec_queue m_rec_queue; // priority queue for hypothesis we want to eliminate/recurse head_map m_head_to_hyps; forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h. expr m_target; @@ -289,10 +282,6 @@ public: /** \brief Activate the next hypothesis in the TODO queue, return none if the TODO queue is empty. */ optional activate_hypothesis(); - /** \brief Pick next hypothesis from the rec queue */ - optional select_rec_hypothesis(); - void add_to_rec_queue(hypothesis_idx hidx, double w); - /** \brief Store in \c r the hypotheses in this branch sorted by dependency depth */ void get_sorted_hypotheses(hypothesis_idx_buffer & r) const; /** \brief Sort hypotheses in r */