feat(library/blast): branch extensions
This commit is contained in:
parent
bf5fe8d180
commit
6823a34935
2 changed files with 188 additions and 2 deletions
|
@ -29,6 +29,115 @@ bool metavar_decl::restrict_context_using(metavar_decl const & other) {
|
||||||
return !to_erase.empty();
|
return !to_erase.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class extension_manager {
|
||||||
|
std::vector<branch_extension *> m_exts;
|
||||||
|
public:
|
||||||
|
~extension_manager() {
|
||||||
|
for (auto ext : m_exts)
|
||||||
|
ext->dec_ref();
|
||||||
|
m_exts.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned register_extension(branch_extension * ext) {
|
||||||
|
ext->inc_ref();
|
||||||
|
unsigned r = m_exts.size();
|
||||||
|
m_exts.push_back(ext);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_ext(unsigned extid) const {
|
||||||
|
return extid < m_exts.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
branch_extension * get_initial(unsigned extid) {
|
||||||
|
return m_exts[extid];
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned get_num_extensions() const {
|
||||||
|
return m_exts.size();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static extension_manager * g_extension_manager = nullptr;
|
||||||
|
|
||||||
|
static extension_manager & get_extension_manager() {
|
||||||
|
return *g_extension_manager;
|
||||||
|
}
|
||||||
|
|
||||||
|
branch::branch() {
|
||||||
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
|
m_extensions = new branch_extension*[n];
|
||||||
|
for (unsigned i = 0; i < n; i++)
|
||||||
|
m_extensions[i] = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
branch::~branch() {
|
||||||
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
|
for (unsigned i = 0; i < n; i++) {
|
||||||
|
if (m_extensions[i])
|
||||||
|
m_extensions[i]->dec_ref();
|
||||||
|
}
|
||||||
|
delete m_extensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
branch::branch(branch const & b):
|
||||||
|
m_hyp_decls(b.m_hyp_decls),
|
||||||
|
m_assumption(b.m_assumption),
|
||||||
|
m_active(b.m_active),
|
||||||
|
m_todo_queue(b.m_todo_queue),
|
||||||
|
m_rec_queue(b.m_rec_queue),
|
||||||
|
m_head_to_hyps(b.m_head_to_hyps),
|
||||||
|
m_forward_deps(b.m_forward_deps),
|
||||||
|
m_target(b.m_target),
|
||||||
|
m_target_deps(b.m_target_deps),
|
||||||
|
m_simp_rule_sets(b.m_simp_rule_sets) {
|
||||||
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
|
m_extensions = new branch_extension*[n];
|
||||||
|
for (unsigned i = 0; i < n; i++) {
|
||||||
|
m_extensions[i] = b.m_extensions[i];
|
||||||
|
if (m_extensions[i])
|
||||||
|
m_extensions[i]->inc_ref();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
branch::branch(branch && b):
|
||||||
|
m_hyp_decls(std::move(b.m_hyp_decls)),
|
||||||
|
m_assumption(std::move(b.m_assumption)),
|
||||||
|
m_active(std::move(b.m_active)),
|
||||||
|
m_todo_queue(std::move(b.m_todo_queue)),
|
||||||
|
m_rec_queue(std::move(b.m_rec_queue)),
|
||||||
|
m_head_to_hyps(std::move(b.m_head_to_hyps)),
|
||||||
|
m_forward_deps(std::move(b.m_forward_deps)),
|
||||||
|
m_target(std::move(b.m_target)),
|
||||||
|
m_target_deps(std::move(b.m_target_deps)),
|
||||||
|
m_simp_rule_sets(std::move(b.m_simp_rule_sets)) {
|
||||||
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
|
m_extensions = new branch_extension*[n];
|
||||||
|
for (unsigned i = 0; i < n; i++) {
|
||||||
|
m_extensions[i] = b.m_extensions[i];
|
||||||
|
b.m_extensions[i] = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void branch::swap(branch & b) {
|
||||||
|
std::swap(m_hyp_decls, b.m_hyp_decls);
|
||||||
|
std::swap(m_assumption, b.m_assumption);
|
||||||
|
std::swap(m_active, b.m_active);
|
||||||
|
std::swap(m_todo_queue, b.m_todo_queue);
|
||||||
|
std::swap(m_rec_queue, b.m_rec_queue);
|
||||||
|
std::swap(m_head_to_hyps, b.m_head_to_hyps);
|
||||||
|
std::swap(m_forward_deps, b.m_forward_deps);
|
||||||
|
std::swap(m_target, b.m_target);
|
||||||
|
std::swap(m_target_deps, b.m_target_deps);
|
||||||
|
std::swap(m_simp_rule_sets, b.m_simp_rule_sets);
|
||||||
|
std::swap(m_extensions, b.m_extensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
branch & branch::operator=(branch s) {
|
||||||
|
swap(s);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
state::state() {}
|
state::state() {}
|
||||||
|
|
||||||
expr state::mk_metavar(hypothesis_idx_set const & c, expr const & type) {
|
expr state::mk_metavar(hypothesis_idx_set const & c, expr const & type) {
|
||||||
|
@ -581,16 +690,59 @@ list<hypothesis_idx> state::get_head_related() const {
|
||||||
return list<hypothesis_idx>();
|
return list<hypothesis_idx>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
branch_extension * state::get_extension_core(unsigned i) {
|
||||||
|
branch_extension * ext = m_branch.m_extensions[i];
|
||||||
|
if (ext && ext->get_rc() > 1) {
|
||||||
|
branch_extension * new_ext = ext->clone();
|
||||||
|
new_ext->inc_ref();
|
||||||
|
ext->dec_ref();
|
||||||
|
m_branch.m_extensions[i] = new_ext;
|
||||||
|
return new_ext;
|
||||||
|
}
|
||||||
|
return ext;
|
||||||
|
}
|
||||||
|
|
||||||
|
branch_extension & state::get_extension(unsigned extid) {
|
||||||
|
lean_assert(extid < get_extension_manager().get_num_extensions());
|
||||||
|
if (!m_branch.m_extensions[extid]) {
|
||||||
|
/* lazy initialization */
|
||||||
|
branch_extension * ext = get_extension_manager().get_initial(extid)->clone();;
|
||||||
|
ext->inc_ref();
|
||||||
|
m_branch.m_extensions[extid] = ext;
|
||||||
|
lean_assert(ext.get_rc() == 1);
|
||||||
|
m_branch.m_active.for_each([&](hypothesis_idx hidx) {
|
||||||
|
hypothesis const * h = get_hypothesis_decl(hidx);
|
||||||
|
lean_assert(h);
|
||||||
|
ext->hypothesis_activated(*h, hidx);
|
||||||
|
});
|
||||||
|
return *ext;
|
||||||
|
} else {
|
||||||
|
branch_extension * ext = get_extension_core(extid);
|
||||||
|
lean_assert(ext);
|
||||||
|
return *ext;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void state::update_indices(hypothesis_idx hidx) {
|
void state::update_indices(hypothesis_idx hidx) {
|
||||||
hypothesis const * h = get_hypothesis_decl(hidx);
|
hypothesis const * h = get_hypothesis_decl(hidx);
|
||||||
lean_assert(h);
|
lean_assert(h);
|
||||||
/* update m_head_to_hyps */
|
/* update m_head_to_hyps */
|
||||||
if (auto i = to_head_index(*h))
|
if (auto i = to_head_index(*h))
|
||||||
m_branch.m_head_to_hyps.insert(*i, hidx);
|
m_branch.m_head_to_hyps.insert(*i, hidx);
|
||||||
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
|
for (unsigned i = 0; i < n; i++) {
|
||||||
|
branch_extension * ext = get_extension_core(i);
|
||||||
|
if (ext) ext->hypothesis_activated(*h, hidx);
|
||||||
|
}
|
||||||
/* TODO(Leo): update congruence closure indices */
|
/* TODO(Leo): update congruence closure indices */
|
||||||
}
|
}
|
||||||
|
|
||||||
void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
|
void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
|
||||||
|
unsigned n = get_extension_manager().get_num_extensions();
|
||||||
|
for (unsigned i = 0; i < n; i++) {
|
||||||
|
branch_extension * ext = get_extension_core(i);
|
||||||
|
if (ext) ext->hypothesis_deleted(h, hidx);
|
||||||
|
}
|
||||||
if (auto i = to_head_index(h))
|
if (auto i = to_head_index(h))
|
||||||
m_branch.m_head_to_hyps.erase(*i, hidx);
|
m_branch.m_head_to_hyps.erase(*i, hidx);
|
||||||
}
|
}
|
||||||
|
@ -710,12 +862,14 @@ expr state::mk_pi(list<expr> const & hrefs, expr const & b) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
void initialize_state() {
|
void initialize_state() {
|
||||||
g_prefix = new name(name::mk_internal_unique_name());
|
g_extension_manager = new extension_manager();
|
||||||
g_H = new name("H");
|
g_prefix = new name(name::mk_internal_unique_name());
|
||||||
|
g_H = new name("H");
|
||||||
}
|
}
|
||||||
|
|
||||||
void finalize_state() {
|
void finalize_state() {
|
||||||
delete g_prefix;
|
delete g_prefix;
|
||||||
delete g_H;
|
delete g_H;
|
||||||
|
delete g_extension_manager;
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
|
|
|
@ -96,10 +96,25 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** \brief Actions that require additional indexing data-structures may store them
|
||||||
|
at a branch_extension */
|
||||||
|
class branch_extension {
|
||||||
|
MK_LEAN_RC();
|
||||||
|
void dealloc() { delete this; }
|
||||||
|
public:
|
||||||
|
virtual ~branch_extension() {}
|
||||||
|
virtual branch_extension * clone() = 0;
|
||||||
|
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) = 0;
|
||||||
|
virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx hidx) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
unsigned register_branch_extension(branch_extension * initial);
|
||||||
|
|
||||||
/** \brief Information associated with the current branch of the proof state.
|
/** \brief Information associated with the current branch of the proof state.
|
||||||
This is essentially a mechanism for creating snapshots of the current branch. */
|
This is essentially a mechanism for creating snapshots of the current branch. */
|
||||||
class branch {
|
class branch {
|
||||||
friend class state;
|
friend class state;
|
||||||
|
|
||||||
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
|
typedef hypothesis_idx_map<hypothesis_idx_set> forward_deps;
|
||||||
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
|
/* trick to make sure the rb_map::erase_min removes the hypothesis with biggest weight */
|
||||||
struct inv_double_cmp {
|
struct inv_double_cmp {
|
||||||
|
@ -108,6 +123,7 @@ class branch {
|
||||||
typedef rb_map<double, hypothesis_idx, inv_double_cmp> priority_queue;
|
typedef rb_map<double, hypothesis_idx, inv_double_cmp> priority_queue;
|
||||||
typedef priority_queue todo_queue;
|
typedef priority_queue todo_queue;
|
||||||
typedef priority_queue rec_queue;
|
typedef priority_queue rec_queue;
|
||||||
|
|
||||||
// Hypothesis/facts in the current state
|
// Hypothesis/facts in the current state
|
||||||
hypothesis_decls m_hyp_decls;
|
hypothesis_decls m_hyp_decls;
|
||||||
// We break the set of hypotheses in m_hyp_decls in 4 sets that are not necessarily disjoint:
|
// We break the set of hypotheses in m_hyp_decls in 4 sets that are not necessarily disjoint:
|
||||||
|
@ -134,6 +150,14 @@ class branch {
|
||||||
expr m_target;
|
expr m_target;
|
||||||
hypothesis_idx_set m_target_deps;
|
hypothesis_idx_set m_target_deps;
|
||||||
simp_rule_sets m_simp_rule_sets;
|
simp_rule_sets m_simp_rule_sets;
|
||||||
|
branch_extension ** m_extensions;
|
||||||
|
public:
|
||||||
|
branch();
|
||||||
|
branch(branch const & b);
|
||||||
|
branch(branch && b);
|
||||||
|
~branch();
|
||||||
|
void swap(branch & b);
|
||||||
|
branch & operator=(branch s);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** \brief Proof state for the blast tactic */
|
/** \brief Proof state for the blast tactic */
|
||||||
|
@ -177,6 +201,8 @@ class state {
|
||||||
|
|
||||||
void display_active(output_channel & out) const;
|
void display_active(output_channel & out) const;
|
||||||
|
|
||||||
|
branch_extension * get_extension_core(unsigned i);
|
||||||
|
|
||||||
#ifdef LEAN_DEBUG
|
#ifdef LEAN_DEBUG
|
||||||
bool check_hypothesis(expr const & e, hypothesis_idx hidx, hypothesis const & h) const;
|
bool check_hypothesis(expr const & e, hypothesis_idx hidx, hypothesis const & h) const;
|
||||||
bool check_hypothesis(hypothesis_idx hidx, hypothesis const & h) const;
|
bool check_hypothesis(hypothesis_idx hidx, hypothesis const & h) const;
|
||||||
|
@ -425,6 +451,12 @@ public:
|
||||||
return m_branch.m_simp_rule_sets;
|
return m_branch.m_simp_rule_sets;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/************************
|
||||||
|
Branch extensions
|
||||||
|
*************************/
|
||||||
|
|
||||||
|
branch_extension & get_extension(unsigned extid);
|
||||||
|
|
||||||
/************************
|
/************************
|
||||||
Debugging support
|
Debugging support
|
||||||
*************************/
|
*************************/
|
||||||
|
|
Loading…
Reference in a new issue