feat(library/blast): use discrimination trees instead of head_map for indexing hypotheses

This commit is contained in:
Leonardo de Moura 2015-12-26 15:13:27 -05:00
parent 1f1fafd535
commit 93b912ec89
5 changed files with 169 additions and 99 deletions

View file

@ -15,14 +15,18 @@ namespace blast {
action_result assumption_action() {
state const & s = curr_state();
expr const & target = s.get_target();
for (hypothesis_idx hidx : s.get_head_related()) {
hypothesis const & h = s.get_hypothesis_decl(hidx);
if (is_def_eq(h.get_type(), target)) {
lean_trace_action(tout() << "assumption " << h << "\n";);
return action_result(h.get_self());
}
}
return action_result::failed();
action_result r = action_result::failed();
s.find_hypotheses(target, [&](hypothesis_idx hidx) {
hypothesis const & h = s.get_hypothesis_decl(hidx);
if (is_def_eq(h.get_type(), target)) {
lean_trace_action(tout() << "assumption " << h << "\n";);
r = action_result(h.get_self());
return false; // stop search
} else {
return true; // continue
}
});
return r;
}
/* Close branch IF h is of the form (H : not a ~ a) where ~ is a reflexive relation */
@ -76,27 +80,31 @@ action_result assumption_contradiction_actions(hypothesis_idx hidx) {
expr p1 = type;
unsigned num_not1 = consume_not(p1);
/* try to find complement */
for (hypothesis_idx hidx2 : s.get_head_related(hidx)) {
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
expr p2 = h2.get_type();
unsigned num_not2 = consume_not(p2);
if ((num_not1 % 2) != (num_not2 % 2)) {
if (is_def_eq(p1, p2)) {
lean_trace_action(tout() << "contradiction " << h << " " << h2 << "\n";);
expr pr1 = h.get_self();
expr pr2 = h2.get_self();
reduce_nots(pr1, num_not1);
reduce_nots(pr2, num_not2);
if (num_not1 > num_not2) {
return action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr2, pr1}));
} else {
lean_assert(num_not1 < num_not2);
return action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr1, pr2}));
action_result r = action_result::failed();
s.find_hypotheses(type, [&](hypothesis_idx hidx2) {
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
expr p2 = h2.get_type();
unsigned num_not2 = consume_not(p2);
if ((num_not1 % 2) != (num_not2 % 2)) {
if (is_def_eq(p1, p2)) {
lean_trace_action(tout() << "contradiction " << h << " " << h2 << "\n";);
expr pr1 = h.get_self();
expr pr2 = h2.get_self();
reduce_nots(pr1, num_not1);
reduce_nots(pr2, num_not2);
if (num_not1 > num_not2) {
r = action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr2, pr1}));
return false; // stop search
} else {
lean_assert(num_not1 < num_not2);
r = action_result(b.mk_app(get_absurd_name(), {s.get_target(), pr1, pr2}));
return false; // stop search
}
}
}
}
}
return action_result::failed();
return true; // continue search
});
return r;
}
action_result trivial_action() {
@ -144,15 +152,20 @@ bool discard(hypothesis_idx hidx) {
if (is_relation_app(type, rop, lhs, rhs) && is_def_eq(lhs, rhs) && is_reflexive(rop))
return true;
// 3- We already have an equivalent hypothesis
for (hypothesis_idx hidx2 : s.get_head_related(hidx)) {
if (hidx == hidx2)
continue;
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
expr type2 = h2.get_type();
if (is_def_eq(type, type2))
return true;
}
return false;
bool r = false;
s.find_hypotheses(type, [&](hypothesis_idx hidx2) {
if (hidx == hidx2)
return true; // continue
hypothesis const & h2 = s.get_hypothesis_decl(hidx2);
expr type2 = h2.get_type();
if (is_def_eq(type, type2)) {
r = true;
return false; // stop search
} else {
return true; // continue
}
});
return r;
}
action_result discard_action(hypothesis_idx hidx) {

View file

@ -245,6 +245,97 @@ void discr_tree::insert_erase(expr const & k, expr const & v, bool ins) {
lean_trace("discr_tree", tout() << "\n"; trace(););
}
bool discr_tree::find_atom(node const & n, edge const & e, list<expr> todo, std::function<bool(expr const &)> const & fn) {
if (auto child = n.m_ptr->m_children.find(e)) {
return find(*child, todo, fn);
} else {
return true; // continue
}
}
bool discr_tree::find_star(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn) {
bool cont = true;
n.m_ptr->m_skip.for_each([&](node const & skip_child) {
if (cont && !find(skip_child, todo, fn))
cont = false;
});
if (!cont)
return false;
// we also have to traverse children whose edge is an atom.
n.m_ptr->m_children.for_each([&](edge const & e, node const & child) {
if (cont && !e.m_fn && !find(child, todo, fn))
cont = false;
});
return cont;
}
bool discr_tree::find_app(node const & n, expr const & e, list<expr> todo, std::function<bool(expr const &)> const & fn) {
lean_assert(is_app(e));
buffer<expr> args;
expr const & f = get_app_args(e, args);
if (is_constant(f) || is_local(f)) {
fun_info info = get_fun_info(f);
buffer<param_info> pinfos;
to_buffer(info.get_params_info(), pinfos);
lean_assert(pinfos.size() == args.size());
unsigned i = args.size();
list<expr> new_todo = todo;
while (i > 0) {
--i;
if (pinfos[i].is_prop() || pinfos[i].is_inst_implicit() || pinfos[i].is_implicit())
continue; // We ignore propositions, implicit and inst-implict arguments
new_todo = cons(args[i], new_todo);
}
new_todo = cons(f, new_todo);
return find(n, new_todo, fn);
} else if (is_meta(f)) {
return find_star(n, todo, fn);
} else {
return find_atom(n, edge(edge_kind::Unsupported), todo, fn);
}
}
bool discr_tree::find(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn) {
if (!todo) {
bool cont = true;
n.m_ptr->m_values.for_each([&](expr const & v) {
if (cont && !fn(v))
cont = false;
});
return cont;
}
if (n.m_ptr->m_star_child && !find(n.m_ptr->m_star_child, tail(todo), fn))
return false; // stop search
expr const & e = head(todo);
switch (e.kind()) {
case expr_kind::Constant: case expr_kind::Local:
return find_atom(n, edge(e), tail(todo), fn);
case expr_kind::Meta:
return find_star(n, tail(todo), fn);
case expr_kind::App:
return find_app(n, e, tail(todo), fn);
case expr_kind::Var:
lean_unreachable();
case expr_kind::Sort: case expr_kind::Lambda:
case expr_kind::Pi: case expr_kind::Macro:
// unsupported
return find_atom(n, edge(edge_kind::Unsupported), tail(todo), fn);
}
lean_unreachable();
}
void discr_tree::find(expr const & e, std::function<bool(expr const &)> const & fn) const {
if (m_root)
find(m_root, to_list(e), fn);
}
void discr_tree::collect(expr const & e, buffer<expr> & r) const {
find(e, [&](expr const & v) { r.push_back(v); return true; });
}
static void indent(unsigned depth) {
for (unsigned i = 0; i < depth; i++) tout() << " ";
}

View file

@ -58,6 +58,11 @@ private:
static node insert_erase(node && n, bool is_root, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip, bool ins);
void insert_erase(expr const & k, expr const & v, bool ins);
static bool find_atom(node const & n, edge const & e, list<expr> todo, std::function<bool(expr const &)> const & fn);
static bool find_star(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn);
static bool find_app(node const & n, expr const & e, list<expr> todo, std::function<bool(expr const &)> const & fn);
static bool find(node const & n, list<expr> todo, std::function<bool(expr const &)> const & fn);
node m_root;
public:
void insert(expr const & k, expr const & v) { insert_erase(k, v, true); }

View file

@ -97,10 +97,10 @@ branch::branch(branch const & b):
m_assumption(b.m_assumption),
m_active(b.m_active),
m_todo_queue(b.m_todo_queue),
m_head_to_hyps(b.m_head_to_hyps),
m_forward_deps(b.m_forward_deps),
m_target(b.m_target),
m_target_deps(b.m_target_deps) {
m_target_deps(b.m_target_deps),
m_hyp_index(b.m_hyp_index) {
unsigned n = get_extension_manager().get_num_extensions();
m_extensions = new branch_extension*[n];
for (unsigned i = 0; i < n; i++) {
@ -115,10 +115,10 @@ branch::branch(branch && b):
m_assumption(std::move(b.m_assumption)),
m_active(std::move(b.m_active)),
m_todo_queue(std::move(b.m_todo_queue)),
m_head_to_hyps(std::move(b.m_head_to_hyps)),
m_forward_deps(std::move(b.m_forward_deps)),
m_target(std::move(b.m_target)),
m_target_deps(std::move(b.m_target_deps)) {
m_target_deps(std::move(b.m_target_deps)),
m_hyp_index(std::move(b.m_hyp_index)) {
unsigned n = get_extension_manager().get_num_extensions();
m_extensions = new branch_extension*[n];
for (unsigned i = 0; i < n; i++) {
@ -132,7 +132,7 @@ void branch::swap(branch & b) {
std::swap(m_assumption, b.m_assumption);
std::swap(m_active, b.m_active);
std::swap(m_todo_queue, b.m_todo_queue);
std::swap(m_head_to_hyps, b.m_head_to_hyps);
std::swap(m_hyp_index, b.m_hyp_index);
std::swap(m_forward_deps, b.m_forward_deps);
std::swap(m_target, b.m_target);
std::swap(m_target_deps, b.m_target_deps);
@ -720,49 +720,18 @@ hypothesis_idx_set state::get_direct_forward_deps(hypothesis_idx hidx) const {
return hypothesis_idx_set();
}
static optional<head_index> to_head_index(expr type) {
while (is_not(type, type)) {}
expr const & f = get_app_fn(type);
if (is_constant(f) || is_local(f))
return optional<head_index>(head_index(f));
else
return optional<head_index>();
}
static optional<head_index> to_head_index(hypothesis const & h) {
return to_head_index(h.get_type());
}
list<hypothesis_idx> state::get_occurrences_of(head_index const & h) const {
if (auto r = m_branch.m_head_to_hyps.find(h))
return *r;
else
return list<hypothesis_idx>();
}
list<hypothesis_idx> state::get_head_related(hypothesis_idx hidx) const {
hypothesis const & h = get_hypothesis_decl(hidx);
/* update m_head_to_hyps */
if (auto i = to_head_index(h))
return get_occurrences_of(*i);
else
return list<hypothesis_idx>();
}
list<hypothesis_idx> state::get_head_related() const {
if (auto i = to_head_index(m_branch.m_target))
return get_occurrences_of(*i);
else
return list<hypothesis_idx>();
}
optional<hypothesis_idx> state::contains_hypothesis(expr const & type) const {
for (auto hidx : get_occurrences_of(head_index(type))) {
hypothesis const & h = get_hypothesis_decl(hidx);
if (h.get_type() == type)
return optional<hypothesis_idx>(hidx);
}
return optional<hypothesis_idx>();
optional<hypothesis_idx> r;
find_hypotheses(type, [&](hypothesis_idx hidx) {
hypothesis const & h = get_hypothesis_decl(hidx);
if (h.get_type() == type) {
r = hidx;
return false; // stop search
} else {
return true; // continue search
}
});
return r;
}
branch_extension * state::get_extension_core(unsigned i) {
@ -800,7 +769,7 @@ branch_extension & state::get_extension(unsigned extid) {
}
void state::deactivate_all() {
m_branch.m_head_to_hyps = head_map<hypothesis_idx>();
m_branch.m_hyp_index = discr_tree();
unsigned n = get_extension_manager().get_num_extensions();
for (unsigned i = 0; i < n; i++) {
if (m_branch.m_extensions[i]) {
@ -821,9 +790,6 @@ static expr get_key_for(expr type) {
void state::update_indices(hypothesis_idx hidx) {
hypothesis const & h = get_hypothesis_decl(hidx);
/* update m_head_to_hyps */
if (auto i = to_head_index(h))
m_branch.m_head_to_hyps.insert(*i, hidx);
unsigned n = get_extension_manager().get_num_extensions();
for (unsigned i = 0; i < n; i++) {
branch_extension * ext = get_extension_core(i);
@ -838,11 +804,13 @@ void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
branch_extension * ext = get_extension_core(i);
if (ext) ext->hypothesis_deleted(h, hidx);
}
if (auto i = to_head_index(h))
m_branch.m_head_to_hyps.erase(*i, hidx);
m_branch.m_hyp_index.erase(get_key_for(h.get_type()), h.get_self());
}
void state::find_hypotheses(expr const & e, std::function<bool(hypothesis_idx)> const & fn) const {
m_branch.m_hyp_index.find(get_key_for(e), [&](expr const & h) { return fn(href_index(h)); });
}
optional<unsigned> state::select_hypothesis_to_activate() {
while (true) {
if (m_branch.m_todo_queue.empty())

View file

@ -152,7 +152,6 @@ class branch {
hypothesis_idx_set m_assumption;
hypothesis_idx_set m_active;
todo_queue m_todo_queue;
head_map<hypothesis_idx> m_head_to_hyps;
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
expr m_target;
hypothesis_idx_set m_target_deps;
@ -318,18 +317,12 @@ public:
hypothesis_idx_set get_assumptions() const { return m_branch.m_assumption; }
/** \brief Return (active) hypotheses whose head symbol is h or (not h) */
list<hypothesis_idx> get_occurrences_of(head_index const & h) const;
/** \brief Return (active) hypotheses whose head symbol is equal to the of hidx or it is the negation of */
list<hypothesis_idx> get_head_related(hypothesis_idx hidx) const;
/** \brief Return (active) hypotheses whose head symbol is equal to target or it is the negation of */
list<hypothesis_idx> get_head_related() const;
/** \brief If there is an hypothesis with the given type (return it), otherwise return none */
optional<hypothesis_idx> contains_hypothesis(expr const & type) const;
/** \brief Find hypotheses whose type may unify with \c e or its negation */
void find_hypotheses(expr const & e, std::function<bool(hypothesis_idx)> const & fn) const;
/************************
Abstracting hypotheses
*************************/