feat(library/blast): add discrimination trees
This commit is contained in:
parent
b09fdc8c61
commit
43e1292f22
6 changed files with 389 additions and 2 deletions
|
@ -1,4 +1,4 @@
|
|||
add_library(blast OBJECT hypothesis.cpp state.cpp blast.cpp blast_tactic.cpp
|
||||
init_module.cpp proof_expr.cpp options.cpp choice_point.cpp util.cpp
|
||||
gexpr.cpp revert.cpp strategy.cpp congruence_closure.cpp trace.cpp
|
||||
imp_extension.cpp)
|
||||
imp_extension.cpp discr_tree.cpp)
|
||||
|
|
307
src/library/blast/discr_tree.cpp
Normal file
307
src/library/blast/discr_tree.cpp
Normal file
|
@ -0,0 +1,307 @@
|
|||
/*
|
||||
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include "util/rb_map.h"
|
||||
#include "util/memory_pool.h"
|
||||
#include "library/blast/trace.h"
|
||||
#include "library/blast/blast.h"
|
||||
#include "library/blast/discr_tree.h"
|
||||
|
||||
namespace lean {
|
||||
namespace blast {
|
||||
/* Auxiliary expr used to implement insert/erase operations.
|
||||
When adding the children of an application into the todo stack,
|
||||
we use g_delimiter to indicate where the arguments end.
|
||||
For example, suppose the current stack is S, and we want
|
||||
to add the children of (f a b). Then, the new stack will
|
||||
be [S, *g_delimiter, b, a]
|
||||
|
||||
\remark g_delimiter is an unique expression. */
|
||||
static expr * g_delimiter = nullptr;
|
||||
|
||||
struct discr_tree::node_cmp {
|
||||
int operator()(node const & n1, node const & n2) const;
|
||||
};
|
||||
|
||||
struct discr_tree::edge {
|
||||
edge_kind m_kind;
|
||||
bool m_fn; // TODO(Leo): set this field
|
||||
name m_name; // only relevant for Local/Const
|
||||
edge():m_kind(edge_kind::Unsupported), m_fn(false) {}
|
||||
edge(edge_kind k, bool fn = false):m_kind(k), m_fn(fn) {}
|
||||
edge(edge_kind k, name const & n, bool fn = false):m_kind(k), m_fn(fn), m_name(n) {}
|
||||
};
|
||||
|
||||
struct discr_tree::edge_cmp {
|
||||
int operator()(edge const & e1, edge const & e2) const {
|
||||
if (e1.m_kind != e2.m_kind)
|
||||
return static_cast<int>(e1.m_kind) - static_cast<int>(e2.m_kind);
|
||||
return quick_cmp(e1.m_name, e2.m_name);
|
||||
}
|
||||
};
|
||||
|
||||
struct discr_tree::node_cell {
|
||||
MK_LEAN_RC();
|
||||
/* Unique id. We use it to implement node_cmp */
|
||||
unsigned m_id;
|
||||
/* We use a map based tree to map edges to nodes, we should investigate whether we really need a tree here.
|
||||
We suspect the set of edges is usually quite small. So, an assoc-list may be enough.
|
||||
We should also investigate whether a small array + hash code based on the edge is not enough.
|
||||
Note that we may even ignore collisions since this is an imperfect discrimination tree anyway. */
|
||||
rb_map<edge, node, edge_cmp> m_children;
|
||||
node m_star_child;
|
||||
/* The skip set is needed when searching for the set of terms stored in the discrimination tree
|
||||
that may match an input term containing metavariables. In the literature, they are called
|
||||
skip set/list. */
|
||||
rb_tree<node, node_cmp> m_skip;
|
||||
rb_tree<expr, expr_quick_cmp> m_values;
|
||||
void dealloc();
|
||||
node_cell();
|
||||
node_cell(node_cell const & s);
|
||||
};
|
||||
|
||||
DEF_THREAD_MEMORY_POOL(get_allocator, sizeof(discr_tree::node_cell));
|
||||
LEAN_THREAD_VALUE(unsigned, g_next_id, 0);
|
||||
MK_THREAD_LOCAL_GET_DEF(std::vector<unsigned>, get_recycled_ids);
|
||||
|
||||
static unsigned mk_id() {
|
||||
auto & ids = get_recycled_ids();
|
||||
unsigned r;
|
||||
if (ids.empty()) {
|
||||
r = g_next_id;
|
||||
g_next_id++;
|
||||
} else {
|
||||
r = ids.back();
|
||||
ids.pop_back();
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
discr_tree::node_cell::node_cell():m_rc(0), m_id(mk_id()) {
|
||||
}
|
||||
|
||||
discr_tree::node_cell::node_cell(node_cell const & s):
|
||||
m_rc(0), m_id(mk_id()),
|
||||
m_children(s.m_children),
|
||||
m_star_child(s.m_star_child),
|
||||
m_values(s.m_values) {
|
||||
}
|
||||
|
||||
void discr_tree::node_cell::dealloc() {
|
||||
this->~node_cell();
|
||||
get_recycled_ids().push_back(m_id);
|
||||
get_allocator().recycle(this);
|
||||
}
|
||||
|
||||
auto discr_tree::ensure_unshared(node && n) -> node {
|
||||
if (!n.m_ptr)
|
||||
return node(new (get_allocator().allocate()) node_cell());
|
||||
else if (n.is_shared())
|
||||
return node(new (get_allocator().allocate()) node_cell(*n.m_ptr));
|
||||
else
|
||||
return n;
|
||||
}
|
||||
|
||||
discr_tree::node::node(node_cell * ptr):m_ptr(ptr) { if (m_ptr) ptr->inc_ref(); }
|
||||
discr_tree::node::node(node const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); }
|
||||
discr_tree::node::node(node && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; }
|
||||
discr_tree::node::~node() { if (m_ptr) m_ptr->dec_ref(); }
|
||||
|
||||
discr_tree::node & discr_tree::node::operator=(node const & n) { LEAN_COPY_REF(n); }
|
||||
discr_tree::node & discr_tree::node::operator=(node&& n) { LEAN_MOVE_REF(n); }
|
||||
bool discr_tree::node::is_shared() const { return m_ptr && m_ptr->get_rc() > 1; }
|
||||
|
||||
int discr_tree::node_cmp::operator()(node const & n1, node const & n2) const {
|
||||
if (n1.m_ptr) {
|
||||
return n2.m_ptr ? unsigned_cmp()(n1.m_ptr->m_id, n2.m_ptr->m_id) : 1;
|
||||
} else {
|
||||
return n2.m_ptr ? -1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
auto discr_tree::insert_atom(node && n, edge const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
|
||||
node new_n = ensure_unshared(n.steal());
|
||||
if (auto child = new_n.m_ptr->m_children.find(e)) {
|
||||
node new_child(*child);
|
||||
new_n.m_ptr->m_children.erase(e);
|
||||
new_child = insert(new_child.steal(), false, todo, v, skip);
|
||||
new_n.m_ptr->m_children.insert(e, new_child);
|
||||
return new_n;
|
||||
} else {
|
||||
node new_child = insert(node(), false, todo, v, skip);
|
||||
new_n.m_ptr->m_children.insert(e, new_child);
|
||||
return new_n;
|
||||
}
|
||||
}
|
||||
|
||||
auto discr_tree::insert_star(node && n, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
|
||||
node new_n = ensure_unshared(n.steal());
|
||||
new_n.m_ptr->m_star_child = insert(new_n.m_ptr->m_star_child.steal(), false, todo, v, skip);
|
||||
return new_n;
|
||||
}
|
||||
|
||||
auto discr_tree::insert_app(node && n, bool is_root, expr const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
|
||||
lean_assert(is_app(e));
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (is_constant(fn) || is_local(fn)) {
|
||||
if (!is_root)
|
||||
todo.push_back(*g_delimiter);
|
||||
fun_info info = get_fun_info(fn);
|
||||
buffer<param_info> pinfos;
|
||||
to_buffer(info.get_params_info(), pinfos);
|
||||
lean_assert(pinfos.size() == args.size());
|
||||
unsigned i = args.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
if (pinfos[i].is_prop() || pinfos[i].is_inst_implicit() || pinfos[i].is_implicit())
|
||||
continue; // we ignore propositions, implicit and inst-implict arguments
|
||||
todo.push_back(args[i]);
|
||||
}
|
||||
todo.push_back(fn);
|
||||
node new_n = insert(std::move(n), false, todo, v, skip);
|
||||
if (!is_root) {
|
||||
lean_assert(!skip.empty());
|
||||
pair<node, node> const & p = skip.back();
|
||||
new_n.m_ptr->m_skip.erase(p.first); // remove old skip node
|
||||
new_n.m_ptr->m_skip.insert(p.second); // insert new skip node
|
||||
skip.pop_back();
|
||||
}
|
||||
return new_n;
|
||||
} else if (is_meta(fn)) {
|
||||
return insert_star(std::move(n), todo, v, skip);
|
||||
} else {
|
||||
return insert_atom(std::move(n), edge(), todo, v, skip);
|
||||
}
|
||||
}
|
||||
|
||||
auto discr_tree::insert(node && n, bool is_root, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip) -> node {
|
||||
if (todo.empty()) {
|
||||
node new_n = ensure_unshared(n.steal());
|
||||
new_n.m_ptr->m_values.insert(v);
|
||||
return new_n;
|
||||
}
|
||||
|
||||
expr e = todo.back();
|
||||
todo.pop_back();
|
||||
|
||||
if (is_eqp(e, *g_delimiter)) {
|
||||
node old_n(n);
|
||||
node new_n = insert(std::move(n), false, todo, v, skip);
|
||||
skip.emplace_back(old_n, new_n);
|
||||
return new_n;
|
||||
}
|
||||
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Constant:
|
||||
return insert_atom(std::move(n), edge(edge_kind::Constant, const_name(e)), todo, v, skip);
|
||||
case expr_kind::Local:
|
||||
return insert_atom(std::move(n), edge(edge_kind::Local, mlocal_name(e)), todo, v, skip);
|
||||
case expr_kind::Meta:
|
||||
return insert_star(std::move(n), todo, v, skip);
|
||||
case expr_kind::App:
|
||||
return insert_app(std::move(n), is_root, e, todo, v, skip);
|
||||
case expr_kind::Var:
|
||||
lean_unreachable();
|
||||
case expr_kind::Sort: case expr_kind::Lambda:
|
||||
case expr_kind::Pi: case expr_kind::Macro:
|
||||
// unsupported
|
||||
return insert_atom(std::move(n), edge(), todo, v, skip);
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
void discr_tree::insert(expr const & k, expr const & v) {
|
||||
buffer<expr> todo;
|
||||
buffer<pair<node, node>> skip;
|
||||
todo.push_back(k);
|
||||
m_root = insert(m_root.steal(), true, todo, v, skip);
|
||||
lean_trace("discr_tree", tout() << "\n"; trace(););
|
||||
}
|
||||
|
||||
static void indent(unsigned depth) {
|
||||
for (unsigned i = 0; i < depth; i++) tout() << " ";
|
||||
}
|
||||
|
||||
void discr_tree::node::trace(optional<edge> const & e, unsigned depth, bool disj) const {
|
||||
if (!m_ptr) {
|
||||
tout() << "[null]\n";
|
||||
return;
|
||||
}
|
||||
indent(depth);
|
||||
if (disj)
|
||||
tout() << "| ";
|
||||
else if (depth > 0)
|
||||
tout() << " ";
|
||||
if (e) {
|
||||
switch (e->m_kind) {
|
||||
case edge_kind::Constant:
|
||||
tout() << e->m_name;
|
||||
break;
|
||||
case edge_kind::Local:
|
||||
if (e->m_name.is_numeral()) {
|
||||
// This is a hack for getting nicer traces.
|
||||
unsigned hidx = e->m_name.get_numeral();
|
||||
if (hypothesis const * h = curr_state().find_hypothesis_decl(hidx)) {
|
||||
tout() << h->get_name();
|
||||
break;
|
||||
}
|
||||
}
|
||||
tout() << e->m_name;
|
||||
break;
|
||||
case edge_kind::Star:
|
||||
tout() << "*";
|
||||
break;
|
||||
case edge_kind::Unsupported:
|
||||
tout() << "#";
|
||||
break;
|
||||
}
|
||||
tout() << " -> ";
|
||||
}
|
||||
tout() << "[" << m_ptr->m_id << "] {";
|
||||
bool first = true;
|
||||
m_ptr->m_skip.for_each([&](node const & s) {
|
||||
if (first) first = false; else tout() << ", ";
|
||||
tout() << s.m_ptr->m_id;
|
||||
});
|
||||
tout() << "}";
|
||||
if (!m_ptr->m_values.empty()) {
|
||||
tout() << " {";
|
||||
first = true;
|
||||
m_ptr->m_values.for_each([&](expr const & v) {
|
||||
if (first) first = false; else tout() << ", ";
|
||||
tout() << ppb(v);
|
||||
});
|
||||
tout() << "}";
|
||||
}
|
||||
tout() << "\n";
|
||||
unsigned new_depth = depth;
|
||||
unsigned num_children = m_ptr->m_children.size();
|
||||
if (m_ptr->m_star_child)
|
||||
num_children++;
|
||||
if (num_children > 1)
|
||||
new_depth++;
|
||||
m_ptr->m_children.for_each([&](edge const & e, node const & n) {
|
||||
n.trace(optional<edge>(e), new_depth, num_children > 1);
|
||||
});
|
||||
if (m_ptr->m_star_child) {
|
||||
m_ptr->m_star_child.trace(optional<edge>(edge_kind::Star), new_depth, num_children > 1);
|
||||
}
|
||||
}
|
||||
|
||||
void discr_tree::trace() const {
|
||||
m_root.trace(optional<edge>(), 0, false);
|
||||
}
|
||||
|
||||
void initialize_discr_tree() {
|
||||
register_trace_class(name{"discr_tree"});
|
||||
g_delimiter = new expr(mk_constant(name::mk_internal_unique_name()));
|
||||
}
|
||||
|
||||
void finalize_discr_tree() {
|
||||
delete g_delimiter;
|
||||
}
|
||||
}}
|
74
src/library/blast/discr_tree.h
Normal file
74
src/library/blast/discr_tree.h
Normal file
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "kernel/expr.h"
|
||||
#include "library/expr_lt.h"
|
||||
|
||||
namespace lean {
|
||||
namespace blast {
|
||||
/**
|
||||
\brief (Imperfect) discrimination trees.
|
||||
|
||||
The edges are labeled with:
|
||||
1- Constant names (the universes are ignored)
|
||||
2- Local names (e.g., hypotheses)
|
||||
3- Star/Wildcard (we use them to encode metavariables). We use the same symbol
|
||||
for all metavariables. Remark: in the discrimination tree literature, our
|
||||
metavariables are called variables.
|
||||
4- Unsupported. We use them to encode nested lambda's, Pi's, Sort's
|
||||
Anything that is not an application, constant or local.
|
||||
|
||||
When indexing terms, we ignore propositions and instance implicit
|
||||
arguments. We use blast get_fun_info procedure for retrieving
|
||||
this information. Thus, this data-structure should only be used
|
||||
inside of the blast module. */
|
||||
class discr_tree {
|
||||
public:
|
||||
struct node_cell;
|
||||
private:
|
||||
enum class edge_kind { Local, Constant, Star, Unsupported };
|
||||
struct edge;
|
||||
struct edge_cmp;
|
||||
struct node_cmp;
|
||||
struct node {
|
||||
node_cell * m_ptr;
|
||||
node():m_ptr(nullptr) {}
|
||||
node(node_cell * ptr);
|
||||
node(node const & s);
|
||||
node(node && s);
|
||||
|
||||
~node();
|
||||
node & operator=(node const & n);
|
||||
node & operator=(node&& n);
|
||||
operator bool() const { return m_ptr != nullptr; }
|
||||
bool is_shared() const;
|
||||
node steal() { node r; swap(r, *this); return r; }
|
||||
void trace(optional<edge> const & e, unsigned depth, bool disj) const;
|
||||
friend void swap(node & n1, node & n2) { std::swap(n1.m_ptr, n2.m_ptr); }
|
||||
};
|
||||
|
||||
static node ensure_unshared(node && n);
|
||||
static node insert_atom(node && n, edge const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
|
||||
static node insert_star(node && n, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
|
||||
static node insert_app(node && n, bool is_root, expr const & e, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
|
||||
static node insert(node && n, bool is_root, buffer<expr> & todo, expr const & v, buffer<pair<node, node>> & skip);
|
||||
|
||||
node m_root;
|
||||
public:
|
||||
void insert(expr const & k, expr const & v);
|
||||
void insert(expr const & k) { insert(k, k); }
|
||||
void erase(expr const & k, expr const & v);
|
||||
void erase(expr const & k) { erase(k, k); }
|
||||
|
||||
void find(expr const & e, std::function<bool(expr const &)> const & fn) const;
|
||||
void collect(expr const & e, buffer<expr> & r) const;
|
||||
|
||||
void trace() const;
|
||||
};
|
||||
void initialize_discr_tree();
|
||||
void finalize_discr_tree();
|
||||
}}
|
|
@ -9,6 +9,7 @@ Author: Leonardo de Moura
|
|||
#include "library/blast/blast_tactic.h"
|
||||
#include "library/blast/options.h"
|
||||
#include "library/blast/congruence_closure.h"
|
||||
#include "library/blast/discr_tree.h"
|
||||
#include "library/blast/simplifier/init_module.h"
|
||||
#include "library/blast/backward/init_module.h"
|
||||
#include "library/blast/forward/init_module.h"
|
||||
|
@ -21,6 +22,7 @@ void initialize_blast_module() {
|
|||
blast::initialize_options();
|
||||
blast::initialize_state();
|
||||
initialize_blast();
|
||||
blast::initialize_discr_tree();
|
||||
blast::initialize_simplifier_module();
|
||||
blast::initialize_backward_module();
|
||||
blast::initialize_forward_module();
|
||||
|
@ -39,6 +41,7 @@ void finalize_blast_module() {
|
|||
blast::finalize_forward_module();
|
||||
blast::finalize_backward_module();
|
||||
blast::finalize_simplifier_module();
|
||||
blast::finalize_discr_tree();
|
||||
finalize_blast();
|
||||
blast::finalize_state();
|
||||
blast::finalize_options();
|
||||
|
|
|
@ -824,7 +824,7 @@ void state::update_indices(hypothesis_idx hidx) {
|
|||
branch_extension * ext = get_extension_core(i);
|
||||
if (ext) ext->hypothesis_activated(h, hidx);
|
||||
}
|
||||
/* TODO(Leo): update congruence closure indices */
|
||||
m_branch.m_hyp_index.insert(h.get_type(), h.get_self());
|
||||
}
|
||||
|
||||
void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) {
|
||||
|
|
|
@ -9,6 +9,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/expr.h"
|
||||
#include "library/head_map.h"
|
||||
#include "library/tactic/goal.h"
|
||||
#include "library/blast/discr_tree.h"
|
||||
#include "library/blast/action_result.h"
|
||||
#include "library/blast/hypothesis.h"
|
||||
|
||||
|
@ -155,6 +156,7 @@ class branch {
|
|||
forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h.
|
||||
expr m_target;
|
||||
hypothesis_idx_set m_target_deps;
|
||||
discr_tree m_hyp_index;
|
||||
branch_extension ** m_extensions;
|
||||
public:
|
||||
branch();
|
||||
|
@ -287,6 +289,7 @@ public:
|
|||
|
||||
hypothesis const & get_hypothesis_decl(hypothesis_idx hidx) const { auto h = m_branch.m_hyp_decls.find(hidx); lean_assert(h); return *h; }
|
||||
hypothesis const & get_hypothesis_decl(expr const & h) const;
|
||||
hypothesis const * find_hypothesis_decl(hypothesis_idx hidx) const { return m_branch.m_hyp_decls.find(hidx); }
|
||||
|
||||
void for_each_hypothesis(std::function<void(hypothesis_idx, hypothesis const &)> const & fn) const { m_branch.m_hyp_decls.for_each(fn); }
|
||||
optional<hypothesis_idx> find_active_hypothesis(std::function<bool(hypothesis_idx, hypothesis const &)> const & fn) const { // NOLINT
|
||||
|
|
Loading…
Reference in a new issue