From 43e1292f221f5ce171e9691be99e078cbbc2df7a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 21 Dec 2015 06:20:15 -0800 Subject: [PATCH] feat(library/blast): add discrimination trees --- src/library/blast/CMakeLists.txt | 2 +- src/library/blast/discr_tree.cpp | 307 ++++++++++++++++++++++++++++++ src/library/blast/discr_tree.h | 74 +++++++ src/library/blast/init_module.cpp | 3 + src/library/blast/state.cpp | 2 +- src/library/blast/state.h | 3 + 6 files changed, 389 insertions(+), 2 deletions(-) create mode 100644 src/library/blast/discr_tree.cpp create mode 100644 src/library/blast/discr_tree.h diff --git a/src/library/blast/CMakeLists.txt b/src/library/blast/CMakeLists.txt index 0edb3b9fa..949d06655 100644 --- a/src/library/blast/CMakeLists.txt +++ b/src/library/blast/CMakeLists.txt @@ -1,4 +1,4 @@ add_library(blast OBJECT hypothesis.cpp state.cpp blast.cpp blast_tactic.cpp init_module.cpp proof_expr.cpp options.cpp choice_point.cpp util.cpp gexpr.cpp revert.cpp strategy.cpp congruence_closure.cpp trace.cpp - imp_extension.cpp) + imp_extension.cpp discr_tree.cpp) diff --git a/src/library/blast/discr_tree.cpp b/src/library/blast/discr_tree.cpp new file mode 100644 index 000000000..f6c59fe26 --- /dev/null +++ b/src/library/blast/discr_tree.cpp @@ -0,0 +1,307 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "util/rb_map.h" +#include "util/memory_pool.h" +#include "library/blast/trace.h" +#include "library/blast/blast.h" +#include "library/blast/discr_tree.h" + +namespace lean { +namespace blast { +/* Auxiliary expr used to implement insert/erase operations. + When adding the children of an application into the todo stack, + we use g_delimiter to indicate where the arguments end. + For example, suppose the current stack is S, and we want + to add the children of (f a b). Then, the new stack will + be [S, *g_delimiter, b, a] + + \remark g_delimiter is an unique expression. */ +static expr * g_delimiter = nullptr; + +struct discr_tree::node_cmp { + int operator()(node const & n1, node const & n2) const; +}; + +struct discr_tree::edge { + edge_kind m_kind; + bool m_fn; // TODO(Leo): set this field + name m_name; // only relevant for Local/Const + edge():m_kind(edge_kind::Unsupported), m_fn(false) {} + edge(edge_kind k, bool fn = false):m_kind(k), m_fn(fn) {} + edge(edge_kind k, name const & n, bool fn = false):m_kind(k), m_fn(fn), m_name(n) {} +}; + +struct discr_tree::edge_cmp { + int operator()(edge const & e1, edge const & e2) const { + if (e1.m_kind != e2.m_kind) + return static_cast(e1.m_kind) - static_cast(e2.m_kind); + return quick_cmp(e1.m_name, e2.m_name); + } +}; + +struct discr_tree::node_cell { + MK_LEAN_RC(); + /* Unique id. We use it to implement node_cmp */ + unsigned m_id; + /* We use a map based tree to map edges to nodes, we should investigate whether we really need a tree here. + We suspect the set of edges is usually quite small. So, an assoc-list may be enough. + We should also investigate whether a small array + hash code based on the edge is not enough. + Note that we may even ignore collisions since this is an imperfect discrimination tree anyway. */ + rb_map m_children; + node m_star_child; + /* The skip set is needed when searching for the set of terms stored in the discrimination tree + that may match an input term containing metavariables. In the literature, they are called + skip set/list. */ + rb_tree m_skip; + rb_tree m_values; + void dealloc(); + node_cell(); + node_cell(node_cell const & s); +}; + +DEF_THREAD_MEMORY_POOL(get_allocator, sizeof(discr_tree::node_cell)); +LEAN_THREAD_VALUE(unsigned, g_next_id, 0); +MK_THREAD_LOCAL_GET_DEF(std::vector, get_recycled_ids); + +static unsigned mk_id() { + auto & ids = get_recycled_ids(); + unsigned r; + if (ids.empty()) { + r = g_next_id; + g_next_id++; + } else { + r = ids.back(); + ids.pop_back(); + } + return r; +} + +discr_tree::node_cell::node_cell():m_rc(0), m_id(mk_id()) { +} + +discr_tree::node_cell::node_cell(node_cell const & s): + m_rc(0), m_id(mk_id()), + m_children(s.m_children), + m_star_child(s.m_star_child), + m_values(s.m_values) { +} + +void discr_tree::node_cell::dealloc() { + this->~node_cell(); + get_recycled_ids().push_back(m_id); + get_allocator().recycle(this); +} + +auto discr_tree::ensure_unshared(node && n) -> node { + if (!n.m_ptr) + return node(new (get_allocator().allocate()) node_cell()); + else if (n.is_shared()) + return node(new (get_allocator().allocate()) node_cell(*n.m_ptr)); + else + return n; +} + +discr_tree::node::node(node_cell * ptr):m_ptr(ptr) { if (m_ptr) ptr->inc_ref(); } +discr_tree::node::node(node const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } +discr_tree::node::node(node && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; } +discr_tree::node::~node() { if (m_ptr) m_ptr->dec_ref(); } + +discr_tree::node & discr_tree::node::operator=(node const & n) { LEAN_COPY_REF(n); } +discr_tree::node & discr_tree::node::operator=(node&& n) { LEAN_MOVE_REF(n); } +bool discr_tree::node::is_shared() const { return m_ptr && m_ptr->get_rc() > 1; } + +int discr_tree::node_cmp::operator()(node const & n1, node const & n2) const { + if (n1.m_ptr) { + return n2.m_ptr ? unsigned_cmp()(n1.m_ptr->m_id, n2.m_ptr->m_id) : 1; + } else { + return n2.m_ptr ? -1 : 0; + } +} + +auto discr_tree::insert_atom(node && n, edge const & e, buffer & todo, expr const & v, buffer> & skip) -> node { + node new_n = ensure_unshared(n.steal()); + if (auto child = new_n.m_ptr->m_children.find(e)) { + node new_child(*child); + new_n.m_ptr->m_children.erase(e); + new_child = insert(new_child.steal(), false, todo, v, skip); + new_n.m_ptr->m_children.insert(e, new_child); + return new_n; + } else { + node new_child = insert(node(), false, todo, v, skip); + new_n.m_ptr->m_children.insert(e, new_child); + return new_n; + } +} + +auto discr_tree::insert_star(node && n, buffer & todo, expr const & v, buffer> & skip) -> node { + node new_n = ensure_unshared(n.steal()); + new_n.m_ptr->m_star_child = insert(new_n.m_ptr->m_star_child.steal(), false, todo, v, skip); + return new_n; +} + +auto discr_tree::insert_app(node && n, bool is_root, expr const & e, buffer & todo, expr const & v, buffer> & skip) -> node { + lean_assert(is_app(e)); + buffer args; + expr const & fn = get_app_args(e, args); + if (is_constant(fn) || is_local(fn)) { + if (!is_root) + todo.push_back(*g_delimiter); + fun_info info = get_fun_info(fn); + buffer pinfos; + to_buffer(info.get_params_info(), pinfos); + lean_assert(pinfos.size() == args.size()); + unsigned i = args.size(); + while (i > 0) { + --i; + if (pinfos[i].is_prop() || pinfos[i].is_inst_implicit() || pinfos[i].is_implicit()) + continue; // we ignore propositions, implicit and inst-implict arguments + todo.push_back(args[i]); + } + todo.push_back(fn); + node new_n = insert(std::move(n), false, todo, v, skip); + if (!is_root) { + lean_assert(!skip.empty()); + pair const & p = skip.back(); + new_n.m_ptr->m_skip.erase(p.first); // remove old skip node + new_n.m_ptr->m_skip.insert(p.second); // insert new skip node + skip.pop_back(); + } + return new_n; + } else if (is_meta(fn)) { + return insert_star(std::move(n), todo, v, skip); + } else { + return insert_atom(std::move(n), edge(), todo, v, skip); + } +} + +auto discr_tree::insert(node && n, bool is_root, buffer & todo, expr const & v, buffer> & skip) -> node { + if (todo.empty()) { + node new_n = ensure_unshared(n.steal()); + new_n.m_ptr->m_values.insert(v); + return new_n; + } + + expr e = todo.back(); + todo.pop_back(); + + if (is_eqp(e, *g_delimiter)) { + node old_n(n); + node new_n = insert(std::move(n), false, todo, v, skip); + skip.emplace_back(old_n, new_n); + return new_n; + } + + switch (e.kind()) { + case expr_kind::Constant: + return insert_atom(std::move(n), edge(edge_kind::Constant, const_name(e)), todo, v, skip); + case expr_kind::Local: + return insert_atom(std::move(n), edge(edge_kind::Local, mlocal_name(e)), todo, v, skip); + case expr_kind::Meta: + return insert_star(std::move(n), todo, v, skip); + case expr_kind::App: + return insert_app(std::move(n), is_root, e, todo, v, skip); + case expr_kind::Var: + lean_unreachable(); + case expr_kind::Sort: case expr_kind::Lambda: + case expr_kind::Pi: case expr_kind::Macro: + // unsupported + return insert_atom(std::move(n), edge(), todo, v, skip); + } + lean_unreachable(); +} + +void discr_tree::insert(expr const & k, expr const & v) { + buffer todo; + buffer> skip; + todo.push_back(k); + m_root = insert(m_root.steal(), true, todo, v, skip); + lean_trace("discr_tree", tout() << "\n"; trace();); +} + +static void indent(unsigned depth) { + for (unsigned i = 0; i < depth; i++) tout() << " "; +} + +void discr_tree::node::trace(optional const & e, unsigned depth, bool disj) const { + if (!m_ptr) { + tout() << "[null]\n"; + return; + } + indent(depth); + if (disj) + tout() << "| "; + else if (depth > 0) + tout() << " "; + if (e) { + switch (e->m_kind) { + case edge_kind::Constant: + tout() << e->m_name; + break; + case edge_kind::Local: + if (e->m_name.is_numeral()) { + // This is a hack for getting nicer traces. + unsigned hidx = e->m_name.get_numeral(); + if (hypothesis const * h = curr_state().find_hypothesis_decl(hidx)) { + tout() << h->get_name(); + break; + } + } + tout() << e->m_name; + break; + case edge_kind::Star: + tout() << "*"; + break; + case edge_kind::Unsupported: + tout() << "#"; + break; + } + tout() << " -> "; + } + tout() << "[" << m_ptr->m_id << "] {"; + bool first = true; + m_ptr->m_skip.for_each([&](node const & s) { + if (first) first = false; else tout() << ", "; + tout() << s.m_ptr->m_id; + }); + tout() << "}"; + if (!m_ptr->m_values.empty()) { + tout() << " {"; + first = true; + m_ptr->m_values.for_each([&](expr const & v) { + if (first) first = false; else tout() << ", "; + tout() << ppb(v); + }); + tout() << "}"; + } + tout() << "\n"; + unsigned new_depth = depth; + unsigned num_children = m_ptr->m_children.size(); + if (m_ptr->m_star_child) + num_children++; + if (num_children > 1) + new_depth++; + m_ptr->m_children.for_each([&](edge const & e, node const & n) { + n.trace(optional(e), new_depth, num_children > 1); + }); + if (m_ptr->m_star_child) { + m_ptr->m_star_child.trace(optional(edge_kind::Star), new_depth, num_children > 1); + } +} + +void discr_tree::trace() const { + m_root.trace(optional(), 0, false); +} + +void initialize_discr_tree() { + register_trace_class(name{"discr_tree"}); + g_delimiter = new expr(mk_constant(name::mk_internal_unique_name())); +} + +void finalize_discr_tree() { + delete g_delimiter; +} +}} diff --git a/src/library/blast/discr_tree.h b/src/library/blast/discr_tree.h new file mode 100644 index 000000000..4c8d3578e --- /dev/null +++ b/src/library/blast/discr_tree.h @@ -0,0 +1,74 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "kernel/expr.h" +#include "library/expr_lt.h" + +namespace lean { +namespace blast { +/** + \brief (Imperfect) discrimination trees. + + The edges are labeled with: + 1- Constant names (the universes are ignored) + 2- Local names (e.g., hypotheses) + 3- Star/Wildcard (we use them to encode metavariables). We use the same symbol + for all metavariables. Remark: in the discrimination tree literature, our + metavariables are called variables. + 4- Unsupported. We use them to encode nested lambda's, Pi's, Sort's + Anything that is not an application, constant or local. + + When indexing terms, we ignore propositions and instance implicit + arguments. We use blast get_fun_info procedure for retrieving + this information. Thus, this data-structure should only be used + inside of the blast module. */ +class discr_tree { +public: + struct node_cell; +private: + enum class edge_kind { Local, Constant, Star, Unsupported }; + struct edge; + struct edge_cmp; + struct node_cmp; + struct node { + node_cell * m_ptr; + node():m_ptr(nullptr) {} + node(node_cell * ptr); + node(node const & s); + node(node && s); + + ~node(); + node & operator=(node const & n); + node & operator=(node&& n); + operator bool() const { return m_ptr != nullptr; } + bool is_shared() const; + node steal() { node r; swap(r, *this); return r; } + void trace(optional const & e, unsigned depth, bool disj) const; + friend void swap(node & n1, node & n2) { std::swap(n1.m_ptr, n2.m_ptr); } + }; + + static node ensure_unshared(node && n); + static node insert_atom(node && n, edge const & e, buffer & todo, expr const & v, buffer> & skip); + static node insert_star(node && n, buffer & todo, expr const & v, buffer> & skip); + static node insert_app(node && n, bool is_root, expr const & e, buffer & todo, expr const & v, buffer> & skip); + static node insert(node && n, bool is_root, buffer & todo, expr const & v, buffer> & skip); + + node m_root; +public: + void insert(expr const & k, expr const & v); + void insert(expr const & k) { insert(k, k); } + void erase(expr const & k, expr const & v); + void erase(expr const & k) { erase(k, k); } + + void find(expr const & e, std::function const & fn) const; + void collect(expr const & e, buffer & r) const; + + void trace() const; +}; +void initialize_discr_tree(); +void finalize_discr_tree(); +}} diff --git a/src/library/blast/init_module.cpp b/src/library/blast/init_module.cpp index 4466873b0..6dfce0b51 100644 --- a/src/library/blast/init_module.cpp +++ b/src/library/blast/init_module.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include "library/blast/blast_tactic.h" #include "library/blast/options.h" #include "library/blast/congruence_closure.h" +#include "library/blast/discr_tree.h" #include "library/blast/simplifier/init_module.h" #include "library/blast/backward/init_module.h" #include "library/blast/forward/init_module.h" @@ -21,6 +22,7 @@ void initialize_blast_module() { blast::initialize_options(); blast::initialize_state(); initialize_blast(); + blast::initialize_discr_tree(); blast::initialize_simplifier_module(); blast::initialize_backward_module(); blast::initialize_forward_module(); @@ -39,6 +41,7 @@ void finalize_blast_module() { blast::finalize_forward_module(); blast::finalize_backward_module(); blast::finalize_simplifier_module(); + blast::finalize_discr_tree(); finalize_blast(); blast::finalize_state(); blast::finalize_options(); diff --git a/src/library/blast/state.cpp b/src/library/blast/state.cpp index a89cf8a9f..92d45bf39 100644 --- a/src/library/blast/state.cpp +++ b/src/library/blast/state.cpp @@ -824,7 +824,7 @@ void state::update_indices(hypothesis_idx hidx) { branch_extension * ext = get_extension_core(i); if (ext) ext->hypothesis_activated(h, hidx); } - /* TODO(Leo): update congruence closure indices */ + m_branch.m_hyp_index.insert(h.get_type(), h.get_self()); } void state::remove_from_indices(hypothesis const & h, hypothesis_idx hidx) { diff --git a/src/library/blast/state.h b/src/library/blast/state.h index 72269a85e..fdce4ab33 100644 --- a/src/library/blast/state.h +++ b/src/library/blast/state.h @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include "kernel/expr.h" #include "library/head_map.h" #include "library/tactic/goal.h" +#include "library/blast/discr_tree.h" #include "library/blast/action_result.h" #include "library/blast/hypothesis.h" @@ -155,6 +156,7 @@ class branch { forward_deps m_forward_deps; // given an entry (h -> {h_1, ..., h_n}), we have that each h_i uses h. expr m_target; hypothesis_idx_set m_target_deps; + discr_tree m_hyp_index; branch_extension ** m_extensions; public: branch(); @@ -287,6 +289,7 @@ public: hypothesis const & get_hypothesis_decl(hypothesis_idx hidx) const { auto h = m_branch.m_hyp_decls.find(hidx); lean_assert(h); return *h; } hypothesis const & get_hypothesis_decl(expr const & h) const; + hypothesis const * find_hypothesis_decl(hypothesis_idx hidx) const { return m_branch.m_hyp_decls.find(hidx); } void for_each_hypothesis(std::function const & fn) const { m_branch.m_hyp_decls.for_each(fn); } optional find_active_hypothesis(std::function const & fn) const { // NOLINT