refactor(library/blast): use branch_extension for recursor_action
This commit is contained in:
parent
6823a34935
commit
223e77c4a9
6 changed files with 52 additions and 35 deletions
|
@ -22,6 +22,13 @@ typedef buffer<unsigned> hypothesis_idx_buffer;
|
||||||
template<typename T>
|
template<typename T>
|
||||||
using hypothesis_idx_map = typename lean::rb_map<unsigned, T, unsigned_cmp>;
|
using hypothesis_idx_map = typename lean::rb_map<unsigned, T, unsigned_cmp>;
|
||||||
|
|
||||||
|
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
|
||||||
|
struct inv_double_cmp {
|
||||||
|
int operator()(double const & d1, double const & d2) const { return d1 > d2 ? -1 : (d1 < d2 ? 1 : 0); }
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef rb_map<double, hypothesis_idx, inv_double_cmp> hypothesis_priority_queue;
|
||||||
|
|
||||||
class hypothesis {
|
class hypothesis {
|
||||||
friend class state;
|
friend class state;
|
||||||
name m_name; // for pretty printing
|
name m_name; // for pretty printing
|
||||||
|
|
|
@ -10,6 +10,7 @@ Author: Leonardo de Moura
|
||||||
#include "library/blast/blast_tactic.h"
|
#include "library/blast/blast_tactic.h"
|
||||||
#include "library/blast/simplifier.h"
|
#include "library/blast/simplifier.h"
|
||||||
#include "library/blast/options.h"
|
#include "library/blast/options.h"
|
||||||
|
#include "library/blast/recursor_action.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
void initialize_blast_module() {
|
void initialize_blast_module() {
|
||||||
|
@ -19,8 +20,10 @@ void initialize_blast_module() {
|
||||||
initialize_blast();
|
initialize_blast();
|
||||||
blast::initialize_simplifier();
|
blast::initialize_simplifier();
|
||||||
initialize_blast_tactic();
|
initialize_blast_tactic();
|
||||||
|
blast::initialize_recursor_action();
|
||||||
}
|
}
|
||||||
void finalize_blast_module() {
|
void finalize_blast_module() {
|
||||||
|
blast::finalize_recursor_action();
|
||||||
finalize_blast_tactic();
|
finalize_blast_tactic();
|
||||||
blast::finalize_simplifier();
|
blast::finalize_simplifier();
|
||||||
finalize_blast();
|
finalize_blast();
|
||||||
|
|
|
@ -13,6 +13,28 @@ Author: Leonardo de Moura
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
namespace blast {
|
namespace blast {
|
||||||
|
static unsigned g_ext_id = 0;
|
||||||
|
struct recursor_branch_extension : public branch_extension {
|
||||||
|
hypothesis_priority_queue m_rec_queue;
|
||||||
|
recursor_branch_extension() {}
|
||||||
|
recursor_branch_extension(recursor_branch_extension const & b):m_rec_queue(b.m_rec_queue) {}
|
||||||
|
virtual ~recursor_branch_extension() {}
|
||||||
|
virtual branch_extension * clone() { return new recursor_branch_extension(*this); }
|
||||||
|
virtual void hypothesis_activated(hypothesis const &, hypothesis_idx) {}
|
||||||
|
virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
void initialize_recursor_action() {
|
||||||
|
g_ext_id = register_branch_extension(new recursor_branch_extension());
|
||||||
|
}
|
||||||
|
|
||||||
|
void finalize_recursor_action() {
|
||||||
|
}
|
||||||
|
|
||||||
|
static recursor_branch_extension & get_extension() {
|
||||||
|
return static_cast<recursor_branch_extension&>(curr_state().get_extension(g_ext_id));
|
||||||
|
}
|
||||||
|
|
||||||
optional<name> is_recursor_action_target(hypothesis_idx hidx) {
|
optional<name> is_recursor_action_target(hypothesis_idx hidx) {
|
||||||
state & s = curr_state();
|
state & s = curr_state();
|
||||||
hypothesis const * h = s.get_hypothesis_decl(hidx);
|
hypothesis const * h = s.get_hypothesis_decl(hidx);
|
||||||
|
@ -312,7 +334,8 @@ action_result recursor_preprocess_action(hypothesis_idx hidx) {
|
||||||
if (!is_recursive_recursor(*R)) {
|
if (!is_recursive_recursor(*R)) {
|
||||||
// TODO(Leo): we need a better strategy for handling recursive recursors...
|
// TODO(Leo): we need a better strategy for handling recursive recursors...
|
||||||
w += static_cast<double>(num_minor);
|
w += static_cast<double>(num_minor);
|
||||||
curr_state().add_to_rec_queue(hidx, w);
|
recursor_branch_extension & ext = get_extension();
|
||||||
|
ext.m_rec_queue.insert(w, hidx);
|
||||||
return action_result::new_branch();
|
return action_result::new_branch();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -321,11 +344,17 @@ action_result recursor_preprocess_action(hypothesis_idx hidx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
action_result recursor_action() {
|
action_result recursor_action() {
|
||||||
while (auto hidx = curr_state().select_rec_hypothesis()) {
|
recursor_branch_extension & ext = get_extension();
|
||||||
if (optional<name> R = is_recursor_action_target(*hidx)) {
|
while (true) {
|
||||||
Try(recursor_action(*hidx, *R));
|
if (ext.m_rec_queue.empty())
|
||||||
|
return action_result::failed();
|
||||||
|
unsigned hidx = ext.m_rec_queue.erase_min();
|
||||||
|
hypothesis const * h_decl = curr_state().get_hypothesis_decl(hidx);
|
||||||
|
if (h_decl->is_dead())
|
||||||
|
continue;
|
||||||
|
if (optional<name> R = is_recursor_action_target(hidx)) {
|
||||||
|
Try(recursor_action(hidx, *R));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return action_result::failed();
|
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
|
|
|
@ -11,4 +11,7 @@ namespace lean {
|
||||||
namespace blast {
|
namespace blast {
|
||||||
action_result recursor_preprocess_action(hypothesis_idx hidx);
|
action_result recursor_preprocess_action(hypothesis_idx hidx);
|
||||||
action_result recursor_action();
|
action_result recursor_action();
|
||||||
|
|
||||||
|
void initialize_recursor_action();
|
||||||
|
void finalize_recursor_action();
|
||||||
}}
|
}}
|
||||||
|
|
|
@ -64,6 +64,10 @@ static extension_manager & get_extension_manager() {
|
||||||
return *g_extension_manager;
|
return *g_extension_manager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned register_branch_extension(branch_extension * initial) {
|
||||||
|
return get_extension_manager().register_extension(initial);
|
||||||
|
}
|
||||||
|
|
||||||
branch::branch() {
|
branch::branch() {
|
||||||
unsigned n = get_extension_manager().get_num_extensions();
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
m_extensions = new branch_extension*[n];
|
m_extensions = new branch_extension*[n];
|
||||||
|
@ -85,7 +89,6 @@ branch::branch(branch const & b):
|
||||||
m_assumption(b.m_assumption),
|
m_assumption(b.m_assumption),
|
||||||
m_active(b.m_active),
|
m_active(b.m_active),
|
||||||
m_todo_queue(b.m_todo_queue),
|
m_todo_queue(b.m_todo_queue),
|
||||||
m_rec_queue(b.m_rec_queue),
|
|
||||||
m_head_to_hyps(b.m_head_to_hyps),
|
m_head_to_hyps(b.m_head_to_hyps),
|
||||||
m_forward_deps(b.m_forward_deps),
|
m_forward_deps(b.m_forward_deps),
|
||||||
m_target(b.m_target),
|
m_target(b.m_target),
|
||||||
|
@ -105,7 +108,6 @@ branch::branch(branch && b):
|
||||||
m_assumption(std::move(b.m_assumption)),
|
m_assumption(std::move(b.m_assumption)),
|
||||||
m_active(std::move(b.m_active)),
|
m_active(std::move(b.m_active)),
|
||||||
m_todo_queue(std::move(b.m_todo_queue)),
|
m_todo_queue(std::move(b.m_todo_queue)),
|
||||||
m_rec_queue(std::move(b.m_rec_queue)),
|
|
||||||
m_head_to_hyps(std::move(b.m_head_to_hyps)),
|
m_head_to_hyps(std::move(b.m_head_to_hyps)),
|
||||||
m_forward_deps(std::move(b.m_forward_deps)),
|
m_forward_deps(std::move(b.m_forward_deps)),
|
||||||
m_target(std::move(b.m_target)),
|
m_target(std::move(b.m_target)),
|
||||||
|
@ -124,7 +126,6 @@ void branch::swap(branch & b) {
|
||||||
std::swap(m_assumption, b.m_assumption);
|
std::swap(m_assumption, b.m_assumption);
|
||||||
std::swap(m_active, b.m_active);
|
std::swap(m_active, b.m_active);
|
||||||
std::swap(m_todo_queue, b.m_todo_queue);
|
std::swap(m_todo_queue, b.m_todo_queue);
|
||||||
std::swap(m_rec_queue, b.m_rec_queue);
|
|
||||||
std::swap(m_head_to_hyps, b.m_head_to_hyps);
|
std::swap(m_head_to_hyps, b.m_head_to_hyps);
|
||||||
std::swap(m_forward_deps, b.m_forward_deps);
|
std::swap(m_forward_deps, b.m_forward_deps);
|
||||||
std::swap(m_target, b.m_target);
|
std::swap(m_target, b.m_target);
|
||||||
|
@ -761,21 +762,6 @@ optional<unsigned> state::activate_hypothesis() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<unsigned> state::select_rec_hypothesis() {
|
|
||||||
while (true) {
|
|
||||||
if (m_branch.m_rec_queue.empty())
|
|
||||||
return optional<unsigned>();
|
|
||||||
unsigned hidx = m_branch.m_rec_queue.erase_min();
|
|
||||||
hypothesis const * h_decl = get_hypothesis_decl(hidx);
|
|
||||||
if (!h_decl->is_dead())
|
|
||||||
return optional<unsigned>(hidx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void state::add_to_rec_queue(hypothesis_idx hidx, double w) {
|
|
||||||
m_branch.m_rec_queue.insert(w, hidx);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool state::hidx_depends_on(unsigned hidx_user, unsigned hidx_provider) const {
|
bool state::hidx_depends_on(unsigned hidx_user, unsigned hidx_provider) const {
|
||||||
if (auto s = m_branch.m_forward_deps.find(hidx_provider)) {
|
if (auto s = m_branch.m_forward_deps.find(hidx_provider)) {
|
||||||
return s->contains(hidx_user);
|
return s->contains(hidx_user);
|
||||||
|
|
|
@ -116,13 +116,7 @@ class branch {
|
||||||
friend class state;
|
friend class state;
|
||||||
|
|
||||||
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
|
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
|
||||||
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
|
typedef hypothesis_priority_queue todo_queue;
|
||||||
struct inv_double_cmp {
|
|
||||||
int operator()(double const & d1, double const & d2) const { return d1 > d2 ? -1 : (d1 < d2 ? 1 : 0); }
|
|
||||||
};
|
|
||||||
typedef rb_map<double, hypothesis_idx, inv_double_cmp> priority_queue;
|
|
||||||
typedef priority_queue todo_queue;
|
|
||||||
typedef priority_queue rec_queue;
|
|
||||||
|
|
||||||
// Hypothesis/facts in the current state
|
// Hypothesis/facts in the current state
|
||||||
hypothesis_decls m_hyp_decls;
|
hypothesis_decls m_hyp_decls;
|
||||||
|
@ -144,7 +138,6 @@ class branch {
|
||||||
hypothesis_idx_set m_assumption;
|
hypothesis_idx_set m_assumption;
|
||||||
hypothesis_idx_set m_active;
|
hypothesis_idx_set m_active;
|
||||||
todo_queue m_todo_queue;
|
todo_queue m_todo_queue;
|
||||||
rec_queue m_rec_queue; // priority queue for hypothesis we want to eliminate/recurse
|
|
||||||
head_map<hypothesis_idx> m_head_to_hyps;
|
head_map<hypothesis_idx> m_head_to_hyps;
|
||||||
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
|
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
|
||||||
expr m_target;
|
expr m_target;
|
||||||
|
@ -289,10 +282,6 @@ public:
|
||||||
/** \brief Activate the next hypothesis in the TODO queue, return none if the TODO queue is empty. */
|
/** \brief Activate the next hypothesis in the TODO queue, return none if the TODO queue is empty. */
|
||||||
optional<hypothesis_idx> activate_hypothesis();
|
optional<hypothesis_idx> activate_hypothesis();
|
||||||
|
|
||||||
/** \brief Pick next hypothesis from the rec queue */
|
|
||||||
optional<hypothesis_idx> select_rec_hypothesis();
|
|
||||||
void add_to_rec_queue(hypothesis_idx hidx, double w);
|
|
||||||
|
|
||||||
/** \brief Store in \c r the hypotheses in this branch sorted by dependency depth */
|
/** \brief Store in \c r the hypotheses in this branch sorted by dependency depth */
|
||||||
void get_sorted_hypotheses(hypothesis_idx_buffer & r) const;
|
void get_sorted_hypotheses(hypothesis_idx_buffer & r) const;
|
||||||
/** \brief Sort hypotheses in r */
|
/** \brief Sort hypotheses in r */
|
||||||
|
|
Loading…
Reference in a new issue