feat(library/blast): branch extensions

This commit is contained in:
Leonardo de Moura 2015-11-18 15:30:59 -08:00
parent bf5fe8d180
commit 6823a34935
2 changed files with 188 additions and 2 deletions

View file

@ -29,6 +29,115 @@ bool metavar_decl::restrict_context_using(metavar_decl const & other) {
return !to_erase.empty();
}
class extension_manager {
std::vector<branch_extension *> m_exts;
public:
~extension_manager() {
for (auto ext : m_exts)
ext->dec_ref();
m_exts.clear();
}
unsigned register_extension(branch_extension * ext) {
ext->inc_ref();
unsigned r = m_exts.size();
m_exts.push_back(ext);
return r;
}
bool has_ext(unsigned extid) const {
return extid < m_exts.size();
}
branch_extension * get_initial(unsigned extid) {
return m_exts[extid];
}
unsigned get_num_extensions() const {
return m_exts.size();
}
};
static extension_manager * g_extension_manager = nullptr;
static extension_manager & get_extension_manager() {
return *g_extension_manager;
}
branch::branch() {
unsigned n = get_extension_manager().get_num_extensions();
m_extensions = new branch_extension*[n];
for (unsigned i = 0; i < n; i++)
m_extensions[i] = nullptr;
}
branch::~branch() {
unsigned n = get_extension_manager().get_num_extensions();
for (unsigned i = 0; i < n; i++) {
if (m_extensions[i])
m_extensions[i]->dec_ref();
}
delete m_extensions;
}
branch::branch(branch const & b):
m_hyp_decls(b.m_hyp_decls),
m_assumption(b.m_assumption),
m_active(b.m_active),
m_todo_queue(b.m_todo_queue),
m_rec_queue(b.m_rec_queue),
m_head_to_hyps(b.m_head_to_hyps),
m_forward_deps(b.m_forward_deps),
m_target(b.m_target),
m_target_deps(b.m_target_deps),
m_simp_rule_sets(b.m_simp_rule_sets) {
unsigned n = get_extension_manager().get_num_extensions();
m_extensions = new branch_extension*[n];
for (unsigned i = 0; i < n; i++) {
m_extensions[i] = b.m_extensions[i];
if (m_extensions[i])
m_extensions[i]->inc_ref();
}
}
branch::branch(branch && b):
m_hyp_decls(std::move(b.m_hyp_decls)),
m_assumption(std::move(b.m_assumption)),
m_active(std::move(b.m_active)),
m_todo_queue(std::move(b.m_todo_queue)),
m_rec_queue(std::move(b.m_rec_queue)),
m_head_to_hyps(std::move(b.m_head_to_hyps)),
m_forward_deps(std::move(b.m_forward_deps)),
m_target(std::move(b.m_target)),
m_target_deps(std::move(b.m_target_deps)),
m_simp_rule_sets(std::move(b.m_simp_rule_sets)) {
unsigned n = get_extension_manager().get_num_extensions();
m_extensions = new branch_extension*[n];
for (unsigned i = 0; i < n; i++) {
m_extensions[i] = b.m_extensions[i];
b.m_extensions[i] = nullptr;
}
}
void branch::swap(branch & b) {
std::swap(m_hyp_decls, b.m_hyp_decls);
std::swap(m_assumption, b.m_assumption);
std::swap(m_active, b.m_active);
std::swap(m_todo_queue, b.m_todo_queue);
std::swap(m_rec_queue, b.m_rec_queue);
std::swap(m_head_to_hyps, b.m_head_to_hyps);
std::swap(m_forward_deps, b.m_forward_deps);
std::swap(m_target, b.m_target);
std::swap(m_target_deps, b.m_target_deps);
std::swap(m_simp_rule_sets, b.m_simp_rule_sets);
std::swap(m_extensions, b.m_extensions);
}
branch & branch::operator=(branch s) {
swap(s);
return *this;
}
state::state() {}
expr state::mk_metavar(hypothesis_idx_set const & c, expr const & type) {
@ -581,16 +690,59 @@ list<hypothesis_idx> state::get_head_related() const {
return list<hypothesis_idx>();
}
branch_extension * state::get_extension_core(unsigned i) {
branch_extension * ext = m_branch.m_extensions[i];
if (ext && ext->get_rc() > 1) {
branch_extension * new_ext = ext->clone();
new_ext->inc_ref();
ext->dec_ref();
m_branch.m_extensions[i] = new_ext;
return new_ext;
}
return ext;
}
branch_extension & state::get_extension(unsigned extid) {
lean_assert(extid < get_extension_manager().get_num_extensions());
if (!m_branch.m_extensions[extid]) {
/* lazy initialization */
branch_extension * ext = get_extension_manager().get_initial(extid)->clone();;
ext->inc_ref();
m_branch.m_extensions[extid] = ext;
lean_assert(ext.get_rc() == 1);
m_branch.m_active.for_each([&](hypothesis_idx hidx) {
hypothesis const * h = get_hypothesis_decl(hidx);
lean_assert(h);
ext->hypothesis_activated(*h, hidx);
});
return *ext;
} else {
branch_extension * ext = get_extension_core(extid);
lean_assert(ext);
return *ext;
}
}
void state::update_indices(hypothesis_idx hidx) {
hypothesis const * h = get_hypothesis_decl(hidx);
lean_assert(h);
/* update m_head_to_hyps */
if (auto i = to_head_index(*h))
m_branch.m_head_to_hyps.insert(*i, hidx);
unsigned n = get_extension_manager().get_num_extensions();
for (unsigned i = 0; i < n; i++) {
branch_extension * ext = get_extension_core(i);
if (ext) ext->hypothesis_activated(*h, hidx);
}
/* TODO(Leo): update congruence closure indices */
}
void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
unsigned n = get_extension_manager().get_num_extensions();
for (unsigned i = 0; i < n; i++) {
branch_extension * ext = get_extension_core(i);
if (ext) ext->hypothesis_deleted(h, hidx);
}
if (auto i = to_head_index(h))
m_branch.m_head_to_hyps.erase(*i, hidx);
}
@ -710,6 +862,7 @@ expr state::mk_pi(list<expr> const & hrefs, expr const & b) const {
}
void initialize_state() {
g_extension_manager = new extension_manager();
g_prefix = new name(name::mk_internal_unique_name());
g_H = new name("H");
}
@ -717,5 +870,6 @@ void initialize_state() {
void finalize_state() {
delete g_prefix;
delete g_H;
delete g_extension_manager;
}
}}

View file

@ -96,10 +96,25 @@ public:
}
};
/** \brief Actions that require additional indexing data-structures may store them
at a branch_extension */
class branch_extension {
MK_LEAN_RC();
void dealloc() { delete this; }
public:
virtual ~branch_extension() {}
virtual branch_extension * clone() = 0;
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) = 0;
virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx hidx) = 0;
};
unsigned register_branch_extension(branch_extension * initial);
/** \brief Information associated with the current branch of the proof state.
This is essentially a mechanism for creating snapshots of the current branch. */
class branch {
friend class state;
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
struct inv_double_cmp {
@ -108,6 +123,7 @@ class branch {
typedef rb_map<double, hypothesis_idx, inv_double_cmp> priority_queue;
typedef priority_queue todo_queue;
typedef priority_queue rec_queue;
// Hypothesis/facts in the current state
hypothesis_decls m_hyp_decls;
// We break the set of hypotheses in m_hyp_decls in 4 sets that are not necessarily disjoint:
@ -134,6 +150,14 @@ class branch {
expr m_target;
hypothesis_idx_set m_target_deps;
simp_rule_sets m_simp_rule_sets;
branch_extension ** m_extensions;
public:
branch();
branch(branch const & b);
branch(branch && b);
~branch();
void swap(branch & b);
branch & operator=(branch s);
};
/** \brief Proof state for the blast tactic */
@ -177,6 +201,8 @@ class state {
void display_active(output_channel & out) const;
branch_extension * get_extension_core(unsigned i);
#ifdef LEAN_DEBUG
bool check_hypothesis(expr const & e, hypothesis_idx hidx, hypothesis const & h) const;
bool check_hypothesis(hypothesis_idx hidx, hypothesis const & h) const;
@ -425,6 +451,12 @@ public:
return m_branch.m_simp_rule_sets;
}
/************************
Branch extensions
*************************/
branch_extension & get_extension(unsigned extid);
/************************
Debugging support
*************************/