refactor(library/blast): use branch_extension for recursor_action

This commit is contained in:
Leonardo de Moura 2015-11-18 15:49:02 -08:00
parent 6823a34935
commit 223e77c4a9
6 changed files with 52 additions and 35 deletions

View file

@ -22,6 +22,13 @@ typedef buffer<unsigned> hypothesis_idx_buffer;
template<typename T> template<typename T>
using hypothesis_idx_map = typename lean::rb_map<unsigned, T, unsigned_cmp>; using hypothesis_idx_map = typename lean::rb_map<unsigned, T, unsigned_cmp>;
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
struct inv_double_cmp {
int operator()(double const & d1, double const & d2) const { return d1 > d2 ? -1 : (d1 < d2 ? 1 : 0); }
};
typedef rb_map<double, hypothesis_idx, inv_double_cmp> hypothesis_priority_queue;
class hypothesis { class hypothesis {
friend class state; friend class state;
name m_name; // for pretty printing name m_name; // for pretty printing

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
#include "library/blast/blast_tactic.h" #include "library/blast/blast_tactic.h"
#include "library/blast/simplifier.h" #include "library/blast/simplifier.h"
#include "library/blast/options.h" #include "library/blast/options.h"
#include "library/blast/recursor_action.h"
namespace lean { namespace lean {
void initialize_blast_module() { void initialize_blast_module() {
@ -19,8 +20,10 @@ void initialize_blast_module() {
initialize_blast(); initialize_blast();
blast::initialize_simplifier(); blast::initialize_simplifier();
initialize_blast_tactic(); initialize_blast_tactic();
blast::initialize_recursor_action();
} }
void finalize_blast_module() { void finalize_blast_module() {
blast::finalize_recursor_action();
finalize_blast_tactic(); finalize_blast_tactic();
blast::finalize_simplifier(); blast::finalize_simplifier();
finalize_blast(); finalize_blast();

View file

@ -13,6 +13,28 @@ Author: Leonardo de Moura
namespace lean { namespace lean {
namespace blast { namespace blast {
static unsigned g_ext_id = 0;
struct recursor_branch_extension : public branch_extension {
hypothesis_priority_queue m_rec_queue;
recursor_branch_extension() {}
recursor_branch_extension(recursor_branch_extension const & b):m_rec_queue(b.m_rec_queue) {}
virtual ~recursor_branch_extension() {}
virtual branch_extension * clone() { return new recursor_branch_extension(*this); }
virtual void hypothesis_activated(hypothesis const &, hypothesis_idx) {}
virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) {}
};
void initialize_recursor_action() {
g_ext_id = register_branch_extension(new recursor_branch_extension());
}
void finalize_recursor_action() {
}
static recursor_branch_extension & get_extension() {
return static_cast<recursor_branch_extension&>(curr_state().get_extension(g_ext_id));
}
optional<name> is_recursor_action_target(hypothesis_idx hidx) { optional<name> is_recursor_action_target(hypothesis_idx hidx) {
state & s = curr_state(); state & s = curr_state();
hypothesis const * h = s.get_hypothesis_decl(hidx); hypothesis const * h = s.get_hypothesis_decl(hidx);
@ -312,7 +334,8 @@ action_result recursor_preprocess_action(hypothesis_idx hidx) {
if (!is_recursive_recursor(*R)) { if (!is_recursive_recursor(*R)) {
// TODO(Leo): we need a better strategy for handling recursive recursors... // TODO(Leo): we need a better strategy for handling recursive recursors...
w += static_cast<double>(num_minor); w += static_cast<double>(num_minor);
curr_state().add_to_rec_queue(hidx, w); recursor_branch_extension & ext = get_extension();
ext.m_rec_queue.insert(w, hidx);
return action_result::new_branch(); return action_result::new_branch();
} }
} }
@ -321,11 +344,17 @@ action_result recursor_preprocess_action(hypothesis_idx hidx) {
} }
action_result recursor_action() { action_result recursor_action() {
while (auto hidx = curr_state().select_rec_hypothesis()) { recursor_branch_extension & ext = get_extension();
if (optional<name> R = is_recursor_action_target(*hidx)) { while (true) {
Try(recursor_action(*hidx, *R)); if (ext.m_rec_queue.empty())
return action_result::failed();
unsigned hidx = ext.m_rec_queue.erase_min();
hypothesis const * h_decl = curr_state().get_hypothesis_decl(hidx);
if (h_decl->is_dead())
continue;
if (optional<name> R = is_recursor_action_target(hidx)) {
Try(recursor_action(hidx, *R));
} }
} }
return action_result::failed();
} }
}} }}

View file

@ -11,4 +11,7 @@ namespace lean {
namespace blast { namespace blast {
action_result recursor_preprocess_action(hypothesis_idx hidx); action_result recursor_preprocess_action(hypothesis_idx hidx);
action_result recursor_action(); action_result recursor_action();
void initialize_recursor_action();
void finalize_recursor_action();
}} }}

View file

@ -64,6 +64,10 @@ static extension_manager & get_extension_manager() {
return *g_extension_manager; return *g_extension_manager;
} }
unsigned register_branch_extension(branch_extension * initial) {
return get_extension_manager().register_extension(initial);
}
branch::branch() { branch::branch() {
unsigned n = get_extension_manager().get_num_extensions(); unsigned n = get_extension_manager().get_num_extensions();
m_extensions = new branch_extension*[n]; m_extensions = new branch_extension*[n];
@ -85,7 +89,6 @@ branch::branch(branch const & b):
m_assumption(b.m_assumption), m_assumption(b.m_assumption),
m_active(b.m_active), m_active(b.m_active),
m_todo_queue(b.m_todo_queue), m_todo_queue(b.m_todo_queue),
m_rec_queue(b.m_rec_queue),
m_head_to_hyps(b.m_head_to_hyps), m_head_to_hyps(b.m_head_to_hyps),
m_forward_deps(b.m_forward_deps), m_forward_deps(b.m_forward_deps),
m_target(b.m_target), m_target(b.m_target),
@ -105,7 +108,6 @@ branch::branch(branch && b):
m_assumption(std::move(b.m_assumption)), m_assumption(std::move(b.m_assumption)),
m_active(std::move(b.m_active)), m_active(std::move(b.m_active)),
m_todo_queue(std::move(b.m_todo_queue)), m_todo_queue(std::move(b.m_todo_queue)),
m_rec_queue(std::move(b.m_rec_queue)),
m_head_to_hyps(std::move(b.m_head_to_hyps)), m_head_to_hyps(std::move(b.m_head_to_hyps)),
m_forward_deps(std::move(b.m_forward_deps)), m_forward_deps(std::move(b.m_forward_deps)),
m_target(std::move(b.m_target)), m_target(std::move(b.m_target)),
@ -124,7 +126,6 @@ void branch::swap(branch & b) {
std::swap(m_assumption, b.m_assumption); std::swap(m_assumption, b.m_assumption);
std::swap(m_active, b.m_active); std::swap(m_active, b.m_active);
std::swap(m_todo_queue, b.m_todo_queue); std::swap(m_todo_queue, b.m_todo_queue);
std::swap(m_rec_queue, b.m_rec_queue);
std::swap(m_head_to_hyps, b.m_head_to_hyps); std::swap(m_head_to_hyps, b.m_head_to_hyps);
std::swap(m_forward_deps, b.m_forward_deps); std::swap(m_forward_deps, b.m_forward_deps);
std::swap(m_target, b.m_target); std::swap(m_target, b.m_target);
@ -761,21 +762,6 @@ optional<unsigned> state::activate_hypothesis() {
} }
} }
optional<unsigned> state::select_rec_hypothesis() {
while (true) {
if (m_branch.m_rec_queue.empty())
return optional<unsigned>();
unsigned hidx = m_branch.m_rec_queue.erase_min();
hypothesis const * h_decl = get_hypothesis_decl(hidx);
if (!h_decl->is_dead())
return optional<unsigned>(hidx);
}
}
void state::add_to_rec_queue(hypothesis_idx hidx, double w) {
m_branch.m_rec_queue.insert(w, hidx);
}
bool state::hidx_depends_on(unsigned hidx_user, unsigned hidx_provider) const { bool state::hidx_depends_on(unsigned hidx_user, unsigned hidx_provider) const {
if (auto s = m_branch.m_forward_deps.find(hidx_provider)) { if (auto s = m_branch.m_forward_deps.find(hidx_provider)) {
return s->contains(hidx_user); return s->contains(hidx_user);

View file

@ -116,13 +116,7 @@ class branch {
friend class state; friend class state;
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps; typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */ typedef hypothesis_priority_queue todo_queue;
struct inv_double_cmp {
int operator()(double const & d1, double const & d2) const { return d1 > d2 ? -1 : (d1 < d2 ? 1 : 0); }
};
typedef rb_map<double, hypothesis_idx, inv_double_cmp> priority_queue;
typedef priority_queue todo_queue;
typedef priority_queue rec_queue;
// Hypothesis/facts in the current state // Hypothesis/facts in the current state
hypothesis_decls m_hyp_decls; hypothesis_decls m_hyp_decls;
@ -144,7 +138,6 @@ class branch {
hypothesis_idx_set m_assumption; hypothesis_idx_set m_assumption;
hypothesis_idx_set m_active; hypothesis_idx_set m_active;
todo_queue m_todo_queue; todo_queue m_todo_queue;
rec_queue m_rec_queue; // priority queue for hypothesis we want to eliminate/recurse
head_map<hypothesis_idx> m_head_to_hyps; head_map<hypothesis_idx> m_head_to_hyps;
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h. forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
expr m_target; expr m_target;
@ -289,10 +282,6 @@ public:
/** \brief Activate the next hypothesis in the TODO queue, return none if the TODO queue is empty. */ /** \brief Activate the next hypothesis in the TODO queue, return none if the TODO queue is empty. */
optional<hypothesis_idx> activate_hypothesis(); optional<hypothesis_idx> activate_hypothesis();
/** \brief Pick next hypothesis from the rec queue */
optional<hypothesis_idx> select_rec_hypothesis();
void add_to_rec_queue(hypothesis_idx hidx, double w);
/** \brief Store in \c r the hypotheses in this branch sorted by dependency depth */ /** \brief Store in \c r the hypotheses in this branch sorted by dependency depth */
void get_sorted_hypotheses(hypothesis_idx_buffer & r) const; void get_sorted_hypotheses(hypothesis_idx_buffer & r) const;
/** \brief Sort hypotheses in r */ /** \brief Sort hypotheses in r */