feat(library/blast): branch extensions
This commit is contained in:
parent
bf5fe8d180
commit
6823a34935
2 changed files with 188 additions and 2 deletions
|
@ -29,6 +29,115 @@ bool metavar_decl::restrict_context_using(metavar_decl const & other) {
|
|||
return !to_erase.empty();
|
||||
}
|
||||
|
||||
class extension_manager {
|
||||
std::vector<branch_extension *> m_exts;
|
||||
public:
|
||||
~extension_manager() {
|
||||
for (auto ext : m_exts)
|
||||
ext->dec_ref();
|
||||
m_exts.clear();
|
||||
}
|
||||
|
||||
unsigned register_extension(branch_extension * ext) {
|
||||
ext->inc_ref();
|
||||
unsigned r = m_exts.size();
|
||||
m_exts.push_back(ext);
|
||||
return r;
|
||||
}
|
||||
|
||||
bool has_ext(unsigned extid) const {
|
||||
return extid < m_exts.size();
|
||||
}
|
||||
|
||||
branch_extension * get_initial(unsigned extid) {
|
||||
return m_exts[extid];
|
||||
}
|
||||
|
||||
unsigned get_num_extensions() const {
|
||||
return m_exts.size();
|
||||
}
|
||||
};
|
||||
|
||||
static extension_manager * g_extension_manager = nullptr;
|
||||
|
||||
static extension_manager & get_extension_manager() {
|
||||
return *g_extension_manager;
|
||||
}
|
||||
|
||||
branch::branch() {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
m_extensions = new branch_extension*[n];
|
||||
for (unsigned i = 0; i < n; i++)
|
||||
m_extensions[i] = nullptr;
|
||||
}
|
||||
|
||||
branch::~branch() {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
if (m_extensions[i])
|
||||
m_extensions[i]->dec_ref();
|
||||
}
|
||||
delete m_extensions;
|
||||
}
|
||||
|
||||
branch::branch(branch const & b):
|
||||
m_hyp_decls(b.m_hyp_decls),
|
||||
m_assumption(b.m_assumption),
|
||||
m_active(b.m_active),
|
||||
m_todo_queue(b.m_todo_queue),
|
||||
m_rec_queue(b.m_rec_queue),
|
||||
m_head_to_hyps(b.m_head_to_hyps),
|
||||
m_forward_deps(b.m_forward_deps),
|
||||
m_target(b.m_target),
|
||||
m_target_deps(b.m_target_deps),
|
||||
m_simp_rule_sets(b.m_simp_rule_sets) {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
m_extensions = new branch_extension*[n];
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
m_extensions[i] = b.m_extensions[i];
|
||||
if (m_extensions[i])
|
||||
m_extensions[i]->inc_ref();
|
||||
}
|
||||
}
|
||||
|
||||
branch::branch(branch && b):
|
||||
m_hyp_decls(std::move(b.m_hyp_decls)),
|
||||
m_assumption(std::move(b.m_assumption)),
|
||||
m_active(std::move(b.m_active)),
|
||||
m_todo_queue(std::move(b.m_todo_queue)),
|
||||
m_rec_queue(std::move(b.m_rec_queue)),
|
||||
m_head_to_hyps(std::move(b.m_head_to_hyps)),
|
||||
m_forward_deps(std::move(b.m_forward_deps)),
|
||||
m_target(std::move(b.m_target)),
|
||||
m_target_deps(std::move(b.m_target_deps)),
|
||||
m_simp_rule_sets(std::move(b.m_simp_rule_sets)) {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
m_extensions = new branch_extension*[n];
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
m_extensions[i] = b.m_extensions[i];
|
||||
b.m_extensions[i] = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void branch::swap(branch & b) {
|
||||
std::swap(m_hyp_decls, b.m_hyp_decls);
|
||||
std::swap(m_assumption, b.m_assumption);
|
||||
std::swap(m_active, b.m_active);
|
||||
std::swap(m_todo_queue, b.m_todo_queue);
|
||||
std::swap(m_rec_queue, b.m_rec_queue);
|
||||
std::swap(m_head_to_hyps, b.m_head_to_hyps);
|
||||
std::swap(m_forward_deps, b.m_forward_deps);
|
||||
std::swap(m_target, b.m_target);
|
||||
std::swap(m_target_deps, b.m_target_deps);
|
||||
std::swap(m_simp_rule_sets, b.m_simp_rule_sets);
|
||||
std::swap(m_extensions, b.m_extensions);
|
||||
}
|
||||
|
||||
branch & branch::operator=(branch s) {
|
||||
swap(s);
|
||||
return *this;
|
||||
}
|
||||
|
||||
state::state() {}
|
||||
|
||||
expr state::mk_metavar(hypothesis_idx_set const & c, expr const & type) {
|
||||
|
@ -581,16 +690,59 @@ list<hypothesis_idx> state::get_head_related() const {
|
|||
return list<hypothesis_idx>();
|
||||
}
|
||||
|
||||
branch_extension * state::get_extension_core(unsigned i) {
|
||||
branch_extension * ext = m_branch.m_extensions[i];
|
||||
if (ext && ext->get_rc() > 1) {
|
||||
branch_extension * new_ext = ext->clone();
|
||||
new_ext->inc_ref();
|
||||
ext->dec_ref();
|
||||
m_branch.m_extensions[i] = new_ext;
|
||||
return new_ext;
|
||||
}
|
||||
return ext;
|
||||
}
|
||||
|
||||
branch_extension & state::get_extension(unsigned extid) {
|
||||
lean_assert(extid < get_extension_manager().get_num_extensions());
|
||||
if (!m_branch.m_extensions[extid]) {
|
||||
/* lazy initialization */
|
||||
branch_extension * ext = get_extension_manager().get_initial(extid)->clone();;
|
||||
ext->inc_ref();
|
||||
m_branch.m_extensions[extid] = ext;
|
||||
lean_assert(ext.get_rc() == 1);
|
||||
m_branch.m_active.for_each([&](hypothesis_idx hidx) {
|
||||
hypothesis const * h = get_hypothesis_decl(hidx);
|
||||
lean_assert(h);
|
||||
ext->hypothesis_activated(*h, hidx);
|
||||
});
|
||||
return *ext;
|
||||
} else {
|
||||
branch_extension * ext = get_extension_core(extid);
|
||||
lean_assert(ext);
|
||||
return *ext;
|
||||
}
|
||||
}
|
||||
|
||||
void state::update_indices(hypothesis_idx hidx) {
|
||||
hypothesis const * h = get_hypothesis_decl(hidx);
|
||||
lean_assert(h);
|
||||
/* update m_head_to_hyps */
|
||||
if (auto i = to_head_index(*h))
|
||||
m_branch.m_head_to_hyps.insert(*i, hidx);
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
branch_extension * ext = get_extension_core(i);
|
||||
if (ext) ext->hypothesis_activated(*h, hidx);
|
||||
}
|
||||
/* TODO(Leo): update congruence closure indices */
|
||||
}
|
||||
|
||||
void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
branch_extension * ext = get_extension_core(i);
|
||||
if (ext) ext->hypothesis_deleted(h, hidx);
|
||||
}
|
||||
if (auto i = to_head_index(h))
|
||||
m_branch.m_head_to_hyps.erase(*i, hidx);
|
||||
}
|
||||
|
@ -710,12 +862,14 @@ expr state::mk_pi(list<expr> const & hrefs, expr const & b) const {
|
|||
}
|
||||
|
||||
void initialize_state() {
|
||||
g_prefix = new name(name::mk_internal_unique_name());
|
||||
g_H = new name("H");
|
||||
g_extension_manager = new extension_manager();
|
||||
g_prefix = new name(name::mk_internal_unique_name());
|
||||
g_H = new name("H");
|
||||
}
|
||||
|
||||
void finalize_state() {
|
||||
delete g_prefix;
|
||||
delete g_H;
|
||||
delete g_extension_manager;
|
||||
}
|
||||
}}
|
||||
|
|
|
@ -96,10 +96,25 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/** \brief Actions that require additional indexing data-structures may store them
|
||||
at a branch_extension */
|
||||
class branch_extension {
|
||||
MK_LEAN_RC();
|
||||
void dealloc() { delete this; }
|
||||
public:
|
||||
virtual ~branch_extension() {}
|
||||
virtual branch_extension * clone() = 0;
|
||||
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) = 0;
|
||||
virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx hidx) = 0;
|
||||
};
|
||||
|
||||
unsigned register_branch_extension(branch_extension * initial);
|
||||
|
||||
/** \brief Information associated with the current branch of the proof state.
|
||||
This is essentially a mechanism for creating snapshots of the current branch. */
|
||||
class branch {
|
||||
friend class state;
|
||||
|
||||
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
|
||||
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
|
||||
struct inv_double_cmp {
|
||||
|
@ -108,6 +123,7 @@ class branch {
|
|||
typedef rb_map<double, hypothesis_idx, inv_double_cmp> priority_queue;
|
||||
typedef priority_queue todo_queue;
|
||||
typedef priority_queue rec_queue;
|
||||
|
||||
// Hypothesis/facts in the current state
|
||||
hypothesis_decls m_hyp_decls;
|
||||
// We break the set of hypotheses in m_hyp_decls in 4 sets that are not necessarily disjoint:
|
||||
|
@ -134,6 +150,14 @@ class branch {
|
|||
expr m_target;
|
||||
hypothesis_idx_set m_target_deps;
|
||||
simp_rule_sets m_simp_rule_sets;
|
||||
branch_extension ** m_extensions;
|
||||
public:
|
||||
branch();
|
||||
branch(branch const & b);
|
||||
branch(branch && b);
|
||||
~branch();
|
||||
void swap(branch & b);
|
||||
branch & operator=(branch s);
|
||||
};
|
||||
|
||||
/** \brief Proof state for the blast tactic */
|
||||
|
@ -177,6 +201,8 @@ class state {
|
|||
|
||||
void display_active(output_channel & out) const;
|
||||
|
||||
branch_extension * get_extension_core(unsigned i);
|
||||
|
||||
#ifdef LEAN_DEBUG
|
||||
bool check_hypothesis(expr const & e, hypothesis_idx hidx, hypothesis const & h) const;
|
||||
bool check_hypothesis(hypothesis_idx hidx, hypothesis const & h) const;
|
||||
|
@ -425,6 +451,12 @@ public:
|
|||
return m_branch.m_simp_rule_sets;
|
||||
}
|
||||
|
||||
/************************
|
||||
Branch extensions
|
||||
*************************/
|
||||
|
||||
branch_extension & get_extension(unsigned extid);
|
||||
|
||||
/************************
|
||||
Debugging support
|
||||
*************************/
|
||||
|
|
Loading…
Reference in a new issue