feat(library/blast): use discrimination trees instead of head_map for indexing hypotheses
This commit is contained in:
parent
1f1fafd535
commit
93b912ec89
5 changed files with 169 additions and 99 deletions
|
@ -15,14 +15,18 @@ namespace blast {
|
|||
action_result assumption_action() {
|
||||
state const & s = curr_state();
|
||||
expr const & target = s.get_target();
|
||||
for (hypothesis_idx hidx : s.get_head_related()) {
|
||||
hypothesis const & h = s.get_hypothesis_decl(hidx);
|
||||
if (is_def_eq(h.get_type(), target)) {
|
||||
lean_trace_action(tout() << "assumption " << h << "\n";);
|
||||
return action_result(h.get_self());
|
||||
}
|
||||
}
|
||||
return action_result::failed();
|
||||
action_result r = action_result::failed();
|
||||
s.find_hypotheses(target, [&](hypothesis_idx hidx) {
|
||||
hypothesis const & h = s.get_hypothesis_decl(hidx);
|
||||
if (is_def_eq(h.get_type(), target)) {
|
||||
lean_trace_action(tout() << "assumption " << h << "\n";);
|
||||
r = action_result(h.get_self());
|
||||
return false; // stop search
|
||||
} else {
|
||||
return true; // continue
|
||||
}
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
/* Close branch IF h is of the form (H : not a ~ a) where ~ is a reflexive relation */
|
||||
|
@ -76,27 +80,31 @@ action_result assumption_contradiction_actions(hypothesis_idx hidx) {
|
|||
expr p1 = type;
|
||||
unsigned num_not1 = consume_not(p1);
|
||||
/* try to find complement */
|
||||
for (hypothesis_idx hidx2 : s.get_head_related(hidx)) {
|
||||
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
|
||||
expr p2 = h2.get_type();
|
||||
unsigned num_not2 = consume_not(p2);
|
||||
if ((num_not1 % 2) != (num_not2 % 2)) {
|
||||
if (is_def_eq(p1, p2)) {
|
||||
lean_trace_action(tout() << "contradiction " << h << " " << h2 << "\n";);
|
||||
expr pr1 = h.get_self();
|
||||
expr pr2 = h2.get_self();
|
||||
reduce_nots(pr1, num_not1);
|
||||
reduce_nots(pr2, num_not2);
|
||||
if (num_not1 > num_not2) {
|
||||
return action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr2, pr1}));
|
||||
} else {
|
||||
lean_assert(num_not1 < num_not2);
|
||||
return action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr1, pr2}));
|
||||
action_result r = action_result::failed();
|
||||
s.find_hypotheses(type, [&](hypothesis_idx hidx2) {
|
||||
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
|
||||
expr p2 = h2.get_type();
|
||||
unsigned num_not2 = consume_not(p2);
|
||||
if ((num_not1 % 2) != (num_not2 % 2)) {
|
||||
if (is_def_eq(p1, p2)) {
|
||||
lean_trace_action(tout() << "contradiction " << h << " " << h2 << "\n";);
|
||||
expr pr1 = h.get_self();
|
||||
expr pr2 = h2.get_self();
|
||||
reduce_nots(pr1, num_not1);
|
||||
reduce_nots(pr2, num_not2);
|
||||
if (num_not1 > num_not2) {
|
||||
r = action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr2, pr1}));
|
||||
return false; // stop search
|
||||
} else {
|
||||
lean_assert(num_not1 < num_not2);
|
||||
r = action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr1, pr2}));
|
||||
return false; // stop search
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return action_result::failed();
|
||||
return true; // continue search
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
action_result trivial_action() {
|
||||
|
@ -144,15 +152,20 @@ bool discard(hypothesis_idx hidx) {
|
|||
if (is_relation_app(type, rop, lhs, rhs) && is_def_eq(lhs, rhs) && is_reflexive(rop))
|
||||
return true;
|
||||
// 3- We already have an equivalent hypothesis
|
||||
for (hypothesis_idx hidx2 : s.get_head_related(hidx)) {
|
||||
if (hidx == hidx2)
|
||||
continue;
|
||||
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
|
||||
expr type2 = h2.get_type();
|
||||
if (is_def_eq(type, type2))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
bool r = false;
|
||||
s.find_hypotheses(type, [&](hypothesis_idx hidx2) {
|
||||
if (hidx == hidx2)
|
||||
return true; // continue
|
||||
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
|
||||
expr type2 = h2.get_type();
|
||||
if (is_def_eq(type, type2)) {
|
||||
r = true;
|
||||
return false; // stop search
|
||||
} else {
|
||||
return true; // continue
|
||||
}
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
action_result discard_action(hypothesis_idx hidx) {
|
||||
|
|
|
@ -245,6 +245,97 @@ void discr_tree::insert_erase(expr const & k, expr const & v, bool ins) {
|
|||
lean_trace("discr_tree", tout() << "\n"; trace(););
|
||||
}
|
||||
|
||||
bool discr_tree::find_atom(node const & n, edge const & e, list<expr> todo, std::function<bool(expr const &)> const & fn) {
|
||||
if (auto child = n.m_ptr->m_children.find(e)) {
|
||||
return find(*child, todo, fn);
|
||||
} else {
|
||||
return true; // continue
|
||||
}
|
||||
}
|
||||
|
||||
bool discr_tree::find_star(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn) {
|
||||
bool cont = true;
|
||||
n.m_ptr->m_skip.for_each([&](node const & skip_child) {
|
||||
if (cont && !find(skip_child, todo, fn))
|
||||
cont = false;
|
||||
});
|
||||
if (!cont)
|
||||
return false;
|
||||
// we also have to traverse children whose edge is an atom.
|
||||
n.m_ptr->m_children.for_each([&](edge const & e, node const & child) {
|
||||
if (cont && !e.m_fn && !find(child, todo, fn))
|
||||
cont = false;
|
||||
});
|
||||
return cont;
|
||||
}
|
||||
|
||||
bool discr_tree::find_app(node const & n, expr const & e, list<expr> todo, std::function<bool(expr const &)> const & fn) {
|
||||
lean_assert(is_app(e));
|
||||
buffer<expr> args;
|
||||
expr const & f = get_app_args(e, args);
|
||||
if (is_constant(f) || is_local(f)) {
|
||||
fun_info info = get_fun_info(f);
|
||||
buffer<param_info> pinfos;
|
||||
to_buffer(info.get_params_info(), pinfos);
|
||||
lean_assert(pinfos.size() == args.size());
|
||||
unsigned i = args.size();
|
||||
list<expr> new_todo = todo;
|
||||
while (i > 0) {
|
||||
--i;
|
||||
if (pinfos[i].is_prop() || pinfos[i].is_inst_implicit() || pinfos[i].is_implicit())
|
||||
continue; // We ignore propositions, implicit and inst-implict arguments
|
||||
new_todo = cons(args[i], new_todo);
|
||||
}
|
||||
new_todo = cons(f, new_todo);
|
||||
return find(n, new_todo, fn);
|
||||
} else if (is_meta(f)) {
|
||||
return find_star(n, todo, fn);
|
||||
} else {
|
||||
return find_atom(n, edge(edge_kind::Unsupported), todo, fn);
|
||||
}
|
||||
}
|
||||
|
||||
bool discr_tree::find(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn) {
|
||||
if (!todo) {
|
||||
bool cont = true;
|
||||
n.m_ptr->m_values.for_each([&](expr const & v) {
|
||||
if (cont && !fn(v))
|
||||
cont = false;
|
||||
});
|
||||
return cont;
|
||||
}
|
||||
|
||||
if (n.m_ptr->m_star_child && !find(n.m_ptr->m_star_child, tail(todo), fn))
|
||||
return false; // stop search
|
||||
|
||||
expr const & e = head(todo);
|
||||
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Constant: case expr_kind::Local:
|
||||
return find_atom(n, edge(e), tail(todo), fn);
|
||||
case expr_kind::Meta:
|
||||
return find_star(n, tail(todo), fn);
|
||||
case expr_kind::App:
|
||||
return find_app(n, e, tail(todo), fn);
|
||||
case expr_kind::Var:
|
||||
lean_unreachable();
|
||||
case expr_kind::Sort: case expr_kind::Lambda:
|
||||
case expr_kind::Pi: case expr_kind::Macro:
|
||||
// unsupported
|
||||
return find_atom(n, edge(edge_kind::Unsupported), tail(todo), fn);
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
void discr_tree::find(expr const & e, std::function<bool(expr const &)> const & fn) const {
|
||||
if (m_root)
|
||||
find(m_root, to_list(e), fn);
|
||||
}
|
||||
|
||||
void discr_tree::collect(expr const & e, buffer<expr> & r) const {
|
||||
find(e, [&](expr const & v) { r.push_back(v); return true; });
|
||||
}
|
||||
|
||||
static void indent(unsigned depth) {
|
||||
for (unsigned i = 0; i < depth; i++) tout() << " ";
|
||||
}
|
||||
|
|
|
@ -58,6 +58,11 @@ private:
|
|||
static node insert_erase(node && n, bool is_root, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip, bool ins);
|
||||
void insert_erase(expr const & k, expr const & v, bool ins);
|
||||
|
||||
static bool find_atom(node const & n, edge const & e, list<expr> todo, std::function<bool(expr const &)> const & fn);
|
||||
static bool find_star(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn);
|
||||
static bool find_app(node const & n, expr const & e, list<expr> todo, std::function<bool(expr const &)> const & fn);
|
||||
static bool find(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn);
|
||||
|
||||
node m_root;
|
||||
public:
|
||||
void insert(expr const & k, expr const & v) { insert_erase(k, v, true); }
|
||||
|
|
|
@ -97,10 +97,10 @@ branch::branch(branch const & b):
|
|||
m_assumption(b.m_assumption),
|
||||
m_active(b.m_active),
|
||||
m_todo_queue(b.m_todo_queue),
|
||||
m_head_to_hyps(b.m_head_to_hyps),
|
||||
m_forward_deps(b.m_forward_deps),
|
||||
m_target(b.m_target),
|
||||
m_target_deps(b.m_target_deps) {
|
||||
m_target_deps(b.m_target_deps),
|
||||
m_hyp_index(b.m_hyp_index) {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
m_extensions = new branch_extension*[n];
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
|
@ -115,10 +115,10 @@ branch::branch(branch && b):
|
|||
m_assumption(std::move(b.m_assumption)),
|
||||
m_active(std::move(b.m_active)),
|
||||
m_todo_queue(std::move(b.m_todo_queue)),
|
||||
m_head_to_hyps(std::move(b.m_head_to_hyps)),
|
||||
m_forward_deps(std::move(b.m_forward_deps)),
|
||||
m_target(std::move(b.m_target)),
|
||||
m_target_deps(std::move(b.m_target_deps)) {
|
||||
m_target_deps(std::move(b.m_target_deps)),
|
||||
m_hyp_index(std::move(b.m_hyp_index)) {
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
m_extensions = new branch_extension*[n];
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
|
@ -132,7 +132,7 @@ void branch::swap(branch & b) {
|
|||
std::swap(m_assumption, b.m_assumption);
|
||||
std::swap(m_active, b.m_active);
|
||||
std::swap(m_todo_queue, b.m_todo_queue);
|
||||
std::swap(m_head_to_hyps, b.m_head_to_hyps);
|
||||
std::swap(m_hyp_index, b.m_hyp_index);
|
||||
std::swap(m_forward_deps, b.m_forward_deps);
|
||||
std::swap(m_target, b.m_target);
|
||||
std::swap(m_target_deps, b.m_target_deps);
|
||||
|
@ -720,49 +720,18 @@ hypothesis_idx_set state::get_direct_forward_deps(hypothesis_idx hidx) const {
|
|||
return hypothesis_idx_set();
|
||||
}
|
||||
|
||||
static optional<head_index> to_head_index(expr type) {
|
||||
while (is_not(type, type)) {}
|
||||
expr const & f = get_app_fn(type);
|
||||
if (is_constant(f) || is_local(f))
|
||||
return optional<head_index>(head_index(f));
|
||||
else
|
||||
return optional<head_index>();
|
||||
}
|
||||
|
||||
static optional<head_index> to_head_index(hypothesis const & h) {
|
||||
return to_head_index(h.get_type());
|
||||
}
|
||||
|
||||
list<hypothesis_idx> state::get_occurrences_of(head_index const & h) const {
|
||||
if (auto r = m_branch.m_head_to_hyps.find(h))
|
||||
return *r;
|
||||
else
|
||||
return list<hypothesis_idx>();
|
||||
}
|
||||
|
||||
list<hypothesis_idx> state::get_head_related(hypothesis_idx hidx) const {
|
||||
hypothesis const & h = get_hypothesis_decl(hidx);
|
||||
/* update m_head_to_hyps */
|
||||
if (auto i = to_head_index(h))
|
||||
return get_occurrences_of(*i);
|
||||
else
|
||||
return list<hypothesis_idx>();
|
||||
}
|
||||
|
||||
list<hypothesis_idx> state::get_head_related() const {
|
||||
if (auto i = to_head_index(m_branch.m_target))
|
||||
return get_occurrences_of(*i);
|
||||
else
|
||||
return list<hypothesis_idx>();
|
||||
}
|
||||
|
||||
optional<hypothesis_idx> state::contains_hypothesis(expr const & type) const {
|
||||
for (auto hidx : get_occurrences_of(head_index(type))) {
|
||||
hypothesis const & h = get_hypothesis_decl(hidx);
|
||||
if (h.get_type() == type)
|
||||
return optional<hypothesis_idx>(hidx);
|
||||
}
|
||||
return optional<hypothesis_idx>();
|
||||
optional<hypothesis_idx> r;
|
||||
find_hypotheses(type, [&](hypothesis_idx hidx) {
|
||||
hypothesis const & h = get_hypothesis_decl(hidx);
|
||||
if (h.get_type() == type) {
|
||||
r = hidx;
|
||||
return false; // stop search
|
||||
} else {
|
||||
return true; // continue search
|
||||
}
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
branch_extension * state::get_extension_core(unsigned i) {
|
||||
|
@ -800,7 +769,7 @@ branch_extension & state::get_extension(unsigned extid) {
|
|||
}
|
||||
|
||||
void state::deactivate_all() {
|
||||
m_branch.m_head_to_hyps = head_map<hypothesis_idx>();
|
||||
m_branch.m_hyp_index = discr_tree();
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
if (m_branch.m_extensions[i]) {
|
||||
|
@ -821,9 +790,6 @@ static expr get_key_for(expr type) {
|
|||
|
||||
void state::update_indices(hypothesis_idx hidx) {
|
||||
hypothesis const & h = get_hypothesis_decl(hidx);
|
||||
/* update m_head_to_hyps */
|
||||
if (auto i = to_head_index(h))
|
||||
m_branch.m_head_to_hyps.insert(*i, hidx);
|
||||
unsigned n = get_extension_manager().get_num_extensions();
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
branch_extension * ext = get_extension_core(i);
|
||||
|
@ -838,11 +804,13 @@ void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
|
|||
branch_extension * ext = get_extension_core(i);
|
||||
if (ext) ext->hypothesis_deleted(h, hidx);
|
||||
}
|
||||
if (auto i = to_head_index(h))
|
||||
m_branch.m_head_to_hyps.erase(*i, hidx);
|
||||
m_branch.m_hyp_index.erase(get_key_for(h.get_type()), h.get_self());
|
||||
}
|
||||
|
||||
void state::find_hypotheses(expr const & e, std::function<bool(hypothesis_idx)> const & fn) const {
|
||||
m_branch.m_hyp_index.find(get_key_for(e), [&](expr const & h) { return fn(href_index(h)); });
|
||||
}
|
||||
|
||||
optional<unsigned> state::select_hypothesis_to_activate() {
|
||||
while (true) {
|
||||
if (m_branch.m_todo_queue.empty())
|
||||
|
|
|
@ -152,7 +152,6 @@ class branch {
|
|||
hypothesis_idx_set m_assumption;
|
||||
hypothesis_idx_set m_active;
|
||||
todo_queue m_todo_queue;
|
||||
head_map<hypothesis_idx> m_head_to_hyps;
|
||||
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
|
||||
expr m_target;
|
||||
hypothesis_idx_set m_target_deps;
|
||||
|
@ -318,18 +317,12 @@ public:
|
|||
|
||||
hypothesis_idx_set get_assumptions() const { return m_branch.m_assumption; }
|
||||
|
||||
/** \brief Return (active) hypotheses whose head symbol is h or (not h) */
|
||||
list<hypothesis_idx> get_occurrences_of(head_index const & h) const;
|
||||
|
||||
/** \brief Return (active) hypotheses whose head symbol is equal to the of hidx or it is the negation of */
|
||||
list<hypothesis_idx> get_head_related(hypothesis_idx hidx) const;
|
||||
|
||||
/** \brief Return (active) hypotheses whose head symbol is equal to target or it is the negation of */
|
||||
list<hypothesis_idx> get_head_related() const;
|
||||
|
||||
/** \brief If there is an hypothesis with the given type (return it), otherwise return none */
|
||||
optional<hypothesis_idx> contains_hypothesis(expr const & type) const;
|
||||
|
||||
/** \brief Find hypotheses whose type may unify with \c e or its negation */
|
||||
void find_hypotheses(expr const & e, std::function<bool(hypothesis_idx)> const & fn) const;
|
||||
|
||||
/************************
|
||||
Abstracting hypotheses
|
||||
*************************/
|
||||
|
|
Loading…
Reference in a new issue