feat(library/blast): add discrimination trees

This commit is contained in:
Leonardo de Moura 2015-12-21 06:20:15 -08:00
parent b09fdc8c61
commit 43e1292f22
6 changed files with 389 additions and 2 deletions

View file

@ -1,4 +1,4 @@
add_library(blast OBJECT hypothesis.cpp state.cpp blast.cpp blast_tactic.cpp
init_module.cpp proof_expr.cpp options.cpp choice_point.cpp util.cpp
gexpr.cpp revert.cpp strategy.cpp congruence_closure.cpp trace.cpp
imp_extension.cpp)
imp_extension.cpp discr_tree.cpp)

View file

@ -0,0 +1,307 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "util/rb_map.h"
#include "util/memory_pool.h"
#include "library/blast/trace.h"
#include "library/blast/blast.h"
#include "library/blast/discr_tree.h"
namespace lean {
namespace blast {
/* Auxiliary expr used to implement insert/erase operations.
When adding the children of a (non-root) application into the todo stack,
we use g_delimiter to indicate where the arguments end.
For example, suppose the current stack is S, and we want
to add (f a b), where no argument of f is ignored. Then, the new stack will
be [S, *g_delimiter, b, a, f], i.e., the head symbol f is processed first,
then the arguments from left to right, and finally the delimiter.
\remark g_delimiter is a unique expression. It is recognized by pointer
equality (is_eqp), and it is created by initialize_discr_tree. */
static expr * g_delimiter = nullptr;
/* Comparison functor for nodes, used as the ordering of the skip set.
It returns a negative number, zero or a positive number; the definition
appears after node_cell because it needs the complete type. */
struct discr_tree::node_cmp {
int operator()(node const & n1, node const & n2) const;
};
/* Label of an edge in the discrimination tree.
   m_kind selects the category of the label, and m_name carries the
   constant/local name when the category needs one. */
struct discr_tree::edge {
    edge_kind m_kind;
    bool      m_fn;   // TODO(Leo): set this field
    name      m_name; // only relevant for Local/Const

    /* Default label: an Unsupported edge. */
    edge():
        m_kind(edge_kind::Unsupported),
        m_fn(false) {
    }

    /* Label without a name (e.g., Star or Unsupported). */
    edge(edge_kind k, bool fn = false):
        m_kind(k),
        m_fn(fn) {
    }

    /* Label with a name (Local or Constant). */
    edge(edge_kind k, name const & n, bool fn = false):
        m_kind(k),
        m_fn(fn),
        m_name(n) {
    }
};
/* Total order on edges: first by kind, then by name.
   The m_fn flag does not take part in the comparison. */
struct discr_tree::edge_cmp {
    int operator()(edge const & e1, edge const & e2) const {
        int const kind1 = static_cast<int>(e1.m_kind);
        int const kind2 = static_cast<int>(e2.m_kind);
        return kind1 != kind2 ? kind1 - kind2 : quick_cmp(e1.m_name, e2.m_name);
    }
};
/* Heap representation of a tree node. Cells are reference counted and are
referenced through the discr_tree::node handle. */
struct discr_tree::node_cell {
/* Reference counting support (provides m_rc, inc_ref, dec_ref and get_rc as used in this file). */
MK_LEAN_RC();
/* Unique id. We use it to implement node_cmp */
unsigned m_id;
/* We use a map based tree to map edges to nodes, we should investigate whether we really need a tree here.
We suspect the set of edges is usually quite small. So, an assoc-list may be enough.
We should also investigate whether a small array + hash code based on the edge is not enough.
Note that we may even ignore collisions since this is an imperfect discrimination tree anyway. */
rb_map<edge, node, edge_cmp> m_children;
/* Child reached through a Star (metavariable) edge; it is kept outside of m_children. */
node m_star_child;
/* The skip set is needed when searching for the set of terms stored in the discrimination tree
that may match an input term containing metavariables. In the literature, it is called
skip set/list. */
rb_tree<node, node_cmp> m_skip;
/* Values stored at this node, i.e., values whose key ends here. */
rb_tree<expr, expr_quick_cmp> m_values;
/* Destroy this cell, and give its id and its memory back for reuse. */
void dealloc();
/* Create an empty cell with a fresh id. */
node_cell();
/* Create a copy of s with a fresh id and reference counter zero. */
node_cell(node_cell const & s);
};
/* Memory pool used to allocate and recycle node_cell objects
(presumably one pool per thread, judging from the macro name). */
DEF_THREAD_MEMORY_POOL(get_allocator, sizeof(discr_tree::node_cell));
/* Next id that has never been used; see mk_id. */
LEAN_THREAD_VALUE(unsigned, g_next_id, 0);
/* Ids released by node_cell::dealloc. mk_id reuses them before consuming g_next_id. */
MK_THREAD_LOCAL_GET_DEF(std::vector<unsigned>, get_recycled_ids);
/* Return an id for a new node cell.
   The most recently recycled id is reused when one is available;
   otherwise a fresh id is taken from g_next_id. */
static unsigned mk_id() {
    auto & free_ids = get_recycled_ids();
    if (!free_ids.empty()) {
        unsigned const reused = free_ids.back();
        free_ids.pop_back();
        return reused;
    }
    unsigned const fresh = g_next_id;
    g_next_id++;
    return fresh;
}
/* Build an empty cell: no children, no values, reference counter zero,
   and an id obtained from mk_id. */
discr_tree::node_cell::node_cell():
    m_rc(0),
    m_id(mk_id()) {
}
/* Copy constructor used by the copy-on-write machinery (see ensure_unshared).
   The copy gets its own id and a reference counter of zero, and shares the
   children, star child, skip set and values of s.

   The skip set must be copied as well: insert_app updates the skip set of the
   (possibly copied) node by erasing the old skip node and inserting the new
   one. If the copy started with an empty skip set, every skip entry that
   belongs to other terms stored below this node would be lost whenever the
   node had to be copied. */
discr_tree::node_cell::node_cell(node_cell const & s):
    m_rc(0), m_id(mk_id()),
    m_children(s.m_children),
    m_star_child(s.m_star_child),
    m_skip(s.m_skip),
    m_values(s.m_values) {
}
void discr_tree::node_cell::dealloc() {
this->~node_cell();
get_recycled_ids().push_back(m_id);
get_allocator().recycle(this);
}
/* Return a node that can be safely updated in place.
   - If n is null, a new empty node is created.
   - If n is shared (reference counter > 1), a copy of its cell is created.
   - Otherwise, n itself is returned.

   \remark n is a named rvalue reference, so a plain `return n` would invoke
   the copy constructor (an extra inc_ref/dec_ref pair). We move it instead. */
auto discr_tree::ensure_unshared(node && n) -> node {
    if (!n.m_ptr)
        return node(new (get_allocator().allocate()) node_cell());
    if (n.is_shared())
        return node(new (get_allocator().allocate()) node_cell(*n.m_ptr));
    return std::move(n);
}
/* Take a (possibly null) cell and register one more reference to it. */
discr_tree::node::node(node_cell * ptr):
    m_ptr(ptr) {
    if (m_ptr)
        ptr->inc_ref();
}

/* Copy: both handles refer to the same cell. */
discr_tree::node::node(node const & s):
    m_ptr(s.m_ptr) {
    if (m_ptr)
        m_ptr->inc_ref();
}

/* Move: the reference held by s is transferred, and s becomes null. */
discr_tree::node::node(node && s):
    m_ptr(s.m_ptr) {
    s.m_ptr = nullptr;
}

/* Release the reference held by this handle (if any). */
discr_tree::node::~node() {
    if (m_ptr)
        m_ptr->dec_ref();
}

discr_tree::node & discr_tree::node::operator=(node const & n) {
    LEAN_COPY_REF(n);
}

discr_tree::node & discr_tree::node::operator=(node&& n) {
    LEAN_MOVE_REF(n);
}

/* A node is shared when its cell is referenced by more than one handle.
   The null node is never shared. */
bool discr_tree::node::is_shared() const {
    if (!m_ptr)
        return false;
    return m_ptr->get_rc() > 1;
}
/* Order nodes by the id of their cells.
   The null node is smaller than any non-null node, and equal to itself. */
int discr_tree::node_cmp::operator()(node const & n1, node const & n2) const {
    bool const has1 = n1.m_ptr != nullptr;
    bool const has2 = n2.m_ptr != nullptr;
    if (has1 && has2)
        return unsigned_cmp()(n1.m_ptr->m_id, n2.m_ptr->m_id);
    if (has1)
        return 1;
    if (has2)
        return -1;
    return 0;
}
/* Follow (or create) the edge e leaving n, and continue the insertion of v
   with the remaining todo stack in the corresponding child.
   The result is the (unshared) replacement for n. */
auto discr_tree::insert_atom(node && n, edge const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
    node result = ensure_unshared(n.steal());
    auto & children = result.m_ptr->m_children;
    node subtree; // stays null when there is no child for e yet
    if (auto existing = children.find(e)) {
        subtree = *existing;
        /* Drop the reference stored in the map before recursing, so it does
           not keep the child shared. */
        children.erase(e);
    }
    subtree = insert(subtree.steal(), false, todo, v, skip);
    children.insert(e, subtree);
    return result;
}
/* Continue the insertion of v through the star (wildcard) child of n.
   The result is the (unshared) replacement for n. */
auto discr_tree::insert_star(node && n, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
    node result = ensure_unshared(n.steal());
    node & star = result.m_ptr->m_star_child;
    star = insert(star.steal(), false, todo, v, skip);
    return result;
}
/* Insert the application e (already removed from the todo stack) below node n.
\param n       node where the application starts
\param is_root true iff e is the whole key given to the public insert method
\param e       the application being indexed
\param todo    stack with the terms that still have to be indexed
\param v       value to be stored at the end of the path
\param skip    stack of (old node, new node) pairs produced when a delimiter is consumed
\return the replacement for n

If the head of e is a constant or local, the head and the relevant arguments
are pushed into todo. If the head is a metavariable, the whole application is
indexed as a star edge; any other head is indexed as an Unsupported edge.
In the last two cases the arguments are not pushed into todo. */
auto discr_tree::insert_app(node && n, bool is_root, expr const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
lean_assert(is_app(e));
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (is_constant(fn) || is_local(fn)) {
// Nested applications are closed by a delimiter, so we can find the node where they end.
if (!is_root)
todo.push_back(*g_delimiter);
fun_info info = get_fun_info(fn);
buffer<param_info> pinfos;
to_buffer(info.get_params_info(), pinfos);
// NOTE(review): the loop below indexes pinfos with argument positions, so it relies on
// this equality; confirm that get_fun_info(fn) returns one entry per actual argument.
lean_assert(pinfos.size() == args.size());
// Arguments are pushed from last to first, thus they are processed from left to right.
unsigned i = args.size();
while (i > 0) {
--i;
if (pinfos[i].is_prop() || pinfos[i].is_inst_implicit() || pinfos[i].is_implicit())
continue; // we ignore propositions, implicit and inst-implicit arguments
todo.push_back(args[i]);
}
// The head symbol is pushed last, thus it is the first element processed.
todo.push_back(fn);
node new_n = insert(std::move(n), false, todo, v, skip);
if (!is_root) {
// The pair on top of skip was added when the delimiter pushed above was consumed:
// p.first is the node where this application used to end, p.second where it ends now.
lean_assert(!skip.empty());
pair<node, node> const & p = skip.back();
new_n.m_ptr->m_skip.erase(p.first); // remove old skip node
new_n.m_ptr->m_skip.insert(p.second); // insert new skip node
skip.pop_back();
}
return new_n;
} else if (is_meta(fn)) {
return insert_star(std::move(n), todo, v, skip);
} else {
return insert_atom(std::move(n), edge(), todo, v, skip);
}
}
/* Main recursive procedure for insertion.
It consumes the top of the todo stack, and dispatches to insert_atom,
insert_star or insert_app based on the kind of the consumed term.
\param n       current node (may be null)
\param is_root true iff the term on top of todo is the whole key
\param todo    stack with the terms that still have to be indexed
\param v       value to be stored when todo becomes empty
\param skip    stack of (old node, new node) pairs; one pair is pushed whenever a delimiter is consumed
\return the replacement for n */
auto discr_tree::insert(node && n, bool is_root, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
if (todo.empty()) {
// End of the key: store the value in the current node.
node new_n = ensure_unshared(n.steal());
new_n.m_ptr->m_values.insert(v);
return new_n;
}
expr e = todo.back();
todo.pop_back();
if (is_eqp(e, *g_delimiter)) {
// End of the arguments of a nested application.
// old_n holds an extra reference to the current node, and it is recorded together
// with its replacement so that insert_app can update the skip set.
node old_n(n);
node new_n = insert(std::move(n), false, todo, v, skip);
skip.emplace_back(old_n, new_n);
return new_n;
}
switch (e.kind()) {
case expr_kind::Constant:
return insert_atom(std::move(n), edge(edge_kind::Constant, const_name(e)), todo, v, skip);
case expr_kind::Local:
return insert_atom(std::move(n), edge(edge_kind::Local, mlocal_name(e)), todo, v, skip);
case expr_kind::Meta:
return insert_star(std::move(n), todo, v, skip);
case expr_kind::App:
return insert_app(std::move(n), is_root, e, todo, v, skip);
case expr_kind::Var:
// Terms containing free variables are not expected here.
lean_unreachable();
case expr_kind::Sort: case expr_kind::Lambda:
case expr_kind::Pi: case expr_kind::Macro:
// unsupported
return insert_atom(std::move(n), edge(), todo, v, skip);
}
lean_unreachable();
}
/* Associate the value v with the key k.
   The updated tree is printed when the trace class "discr_tree" is enabled. */
void discr_tree::insert(expr const & k, expr const & v) {
    buffer<pair<node, node>> skip_updates;
    buffer<expr> pending;
    pending.push_back(k); // the whole key is the only element at the beginning
    m_root = insert(m_root.steal(), true, pending, v, skip_updates);
    lean_trace("discr_tree", tout() << "\n"; trace(););
}
/* Write the indentation for the given depth into the trace stream
   (one indentation unit per level). */
static void indent(unsigned depth) {
    unsigned remaining = depth;
    while (remaining > 0) {
        tout() << " ";
        --remaining;
    }
}
/* Print the subtree rooted at this node into the trace stream.
   \param e     label of the edge used to reach this node (none for the root)
   \param depth indentation level
   \param disj  true iff this node is one of several alternatives of its parent

   Each node is displayed as: edge label, node id, skip set (as node ids),
   and the values stored in the node (if any). */
void discr_tree::node::trace(optional<edge> const & e, unsigned depth, bool disj) const {
    if (m_ptr == nullptr) {
        tout() << "[null]\n";
        return;
    }

    /* Prefix */
    indent(depth);
    if (disj) {
        tout() << "| ";
    } else if (depth > 0) {
        tout() << " ";
    }

    /* Edge label */
    if (e) {
        switch (e->m_kind) {
        case edge_kind::Constant:
            tout() << e->m_name;
            break;
        case edge_kind::Local: {
            // This is a hack for getting nicer traces:
            // numeric names are displayed using the hypothesis name when we can find it.
            hypothesis const * decl = nullptr;
            if (e->m_name.is_numeral()) {
                unsigned hidx = e->m_name.get_numeral();
                decl = curr_state().find_hypothesis_decl(hidx);
            }
            if (decl)
                tout() << decl->get_name();
            else
                tout() << e->m_name;
            break;
        }
        case edge_kind::Star:
            tout() << "*";
            break;
        case edge_kind::Unsupported:
            tout() << "#";
            break;
        }
        tout() << " -> ";
    }

    /* Node id and skip set */
    tout() << "[" << m_ptr->m_id << "] {";
    bool is_first = true;
    m_ptr->m_skip.for_each([&](node const & skip_node) {
            if (is_first)
                is_first = false;
            else
                tout() << ", ";
            tout() << skip_node.m_ptr->m_id;
        });
    tout() << "}";

    /* Values */
    if (!m_ptr->m_values.empty()) {
        tout() << " {";
        is_first = true;
        m_ptr->m_values.for_each([&](expr const & val) {
                if (is_first)
                    is_first = false;
                else
                    tout() << ", ";
                tout() << ppb(val);
            });
        tout() << "}";
    }
    tout() << "\n";

    /* Children: we only increase the indentation when there is a real choice. */
    unsigned num_children = m_ptr->m_children.size();
    if (m_ptr->m_star_child)
        num_children++;
    bool const branching     = num_children > 1;
    unsigned const sub_depth = branching ? depth + 1 : depth;
    m_ptr->m_children.for_each([&](edge const & label, node const & child) {
            child.trace(optional<edge>(label), sub_depth, branching);
        });
    if (m_ptr->m_star_child) {
        m_ptr->m_star_child.trace(optional<edge>(edge_kind::Star), sub_depth, branching);
    }
}
void discr_tree::trace() const {
m_root.trace(optional<edge>(), 0, false);
}
void initialize_discr_tree() {
register_trace_class(name{"discr_tree"});
g_delimiter = new expr(mk_constant(name::mk_internal_unique_name()));
}
/* Module finalization: release the delimiter created by initialize_discr_tree.
   The pointer is reset so that it does not dangle after finalization
   (a repeated finalize becomes a no-op). */
void finalize_discr_tree() {
    delete g_delimiter;
    g_delimiter = nullptr;
}
}}

View file

@ -0,0 +1,74 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/expr.h"
#include "library/expr_lt.h"
namespace lean {
namespace blast {
/**
\brief (Imperfect) discrimination trees.
The edges are labeled with:
1- Constant names (the universes are ignored)
2- Local names (e.g., hypotheses)
3- Star/Wildcard (we use them to encode metavariables). We use the same symbol
for all metavariables. Remark: in the discrimination tree literature, our
metavariables are called variables.
4- Unsupported. We use them to encode nested lambda's, Pi's, Sort's
Anything that is not an application, constant or local.
When indexing terms, we ignore propositions, implicit and instance
implicit arguments. We use the blast get_fun_info procedure for retrieving
this information. Thus, this data structure should only be used
inside of the blast module. */
class discr_tree {
public:
/* Heap representation of the tree nodes. It is defined in the implementation file. */
struct node_cell;
private:
/* Categories of edge labels; see the description at the top of this class. */
enum class edge_kind { Local, Constant, Star, Unsupported };
struct edge;
struct edge_cmp;
struct node_cmp;
/* Reference counted handle to a node_cell. The null handle represents the empty tree. */
struct node {
node_cell * m_ptr;
node():m_ptr(nullptr) {}
node(node_cell * ptr);
node(node const & s);
node(node && s);
~node();
node & operator=(node const & n);
node & operator=(node&& n);
/* Return true iff this handle is not null. */
operator bool() const { return m_ptr != nullptr; }
/* Return true iff the cell is referenced by more than one handle. */
bool is_shared() const;
/* Transfer the content of this handle to the result, and leave this handle null. */
node steal() { node r; swap(r, *this); return r; }
/* Print the subtree rooted at this node; e is the label of the incoming edge (if any). */
void trace(optional<edge> const & e, unsigned depth, bool disj) const;
friend void swap(node & n1, node & n2) { std::swap(n1.m_ptr, n2.m_ptr); }
};
/* Return a node that can be updated in place: a new node if n is null, a copy if n is shared, and n otherwise. */
static node ensure_unshared(node && n);
/* Auxiliary procedures used to implement the public insert method.
In all of them: n is the current node, todo is the stack of terms that still have to be
indexed, v is the value being stored, and skip is the stack of (old node, new node)
pairs used to update skip sets. They return the replacement for n. */
static node insert_atom(node && n, edge const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
static node insert_star(node && n, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
static node insert_app(node && n, bool is_root, expr const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
static node insert(node && n, bool is_root, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
node m_root;
public:
/* Associate the value v with the key k. */
void insert(expr const & k, expr const & v);
/* Insert k using the key itself as the value. */
void insert(expr const & k) { insert(k, k); }
void erase(expr const & k, expr const & v);
/* Erase k, where the key itself was used as the value. */
void erase(expr const & k) { erase(k, k); }
void find(expr const & e, std::function<bool(expr const &)> const & fn) const;
void collect(expr const & e, buffer<expr> & r) const;
/* Print the tree into the trace stream. */
void trace() const;
};
void initialize_discr_tree();
void finalize_discr_tree();
}}

View file

@ -9,6 +9,7 @@ Author: Leonardo de Moura
#include "library/blast/blast_tactic.h"
#include "library/blast/options.h"
#include "library/blast/congruence_closure.h"
#include "library/blast/discr_tree.h"
#include "library/blast/simplifier/init_module.h"
#include "library/blast/backward/init_module.h"
#include "library/blast/forward/init_module.h"
@ -21,6 +22,7 @@ void initialize_blast_module() {
blast::initialize_options();
blast::initialize_state();
initialize_blast();
blast::initialize_discr_tree();
blast::initialize_simplifier_module();
blast::initialize_backward_module();
blast::initialize_forward_module();
@ -39,6 +41,7 @@ void finalize_blast_module() {
blast::finalize_forward_module();
blast::finalize_backward_module();
blast::finalize_simplifier_module();
blast::finalize_discr_tree();
finalize_blast();
blast::finalize_state();
blast::finalize_options();

View file

@ -824,7 +824,7 @@ void state::update_indices(hypothesis_idx hidx) {
branch_extension * ext = get_extension_core(i);
if (ext) ext->hypothesis_activated(h, hidx);
}
/* TODO(Leo): update congruence closure indices */
m_branch.m_hyp_index.insert(h.get_type(), h.get_self());
}
void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {

View file

@ -9,6 +9,7 @@ Author: Leonardo de Moura
#include "kernel/expr.h"
#include "library/head_map.h"
#include "library/tactic/goal.h"
#include "library/blast/discr_tree.h"
#include "library/blast/action_result.h"
#include "library/blast/hypothesis.h"
@ -155,6 +156,7 @@ class branch {
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
expr m_target;
hypothesis_idx_set m_target_deps;
discr_tree m_hyp_index;
branch_extension ** m_extensions;
public:
branch();
@ -287,6 +289,7 @@ public:
hypothesis const & get_hypothesis_decl(hypothesis_idx hidx) const { auto h = m_branch.m_hyp_decls.find(hidx); lean_assert(h); return *h; }
hypothesis const & get_hypothesis_decl(expr const & h) const;
hypothesis const * find_hypothesis_decl(hypothesis_idx hidx) const { return m_branch.m_hyp_decls.find(hidx); }
void for_each_hypothesis(std::function<void(hypothesis_idx, hypothesis const &)> const & fn) const { m_branch.m_hyp_decls.for_each(fn); }
optional<hypothesis_idx> find_active_hypothesis(std::function<bool(hypothesis_idx, hypothesis const &)> const & fn) const { // NOLINT