feat(library/tc_multigraph): finish transitive closed multigraph
This commit is contained in:
parent
3626bd83bf
commit
b8243934de
2 changed files with 227 additions and 5 deletions
|
@ -4,18 +4,206 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include "util/sstream.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "library/tc_multigraph.h"
|
||||
#include "library/composition_manager.h"
|
||||
#include "library/util.h"
|
||||
|
||||
namespace lean {
|
||||
pair<environment, list<name>> tc_multigraph::add(environment const & env, name const & /* e */, unsigned /* num_args */) {
|
||||
// TODO(Leo)
|
||||
return mk_pair(env, list<name>());
|
||||
static name * g_fun_sink = nullptr;
|
||||
static name * g_sort_sink = nullptr;
|
||||
struct add_edge_fn {
|
||||
environment m_env;
|
||||
type_checker_ptr m_tc;
|
||||
tc_multigraph & m_graph;
|
||||
|
||||
add_edge_fn(environment const & env, tc_multigraph & g):
|
||||
m_env(env), m_tc(new type_checker(env)), m_graph(g) {}
|
||||
|
||||
// Return true iff the types of constants c1 and c2 are equal.
|
||||
bool is_def_eq(name const & c1, name const & c2) {
|
||||
if (c1 == c2)
|
||||
return true;
|
||||
declaration const & d1 = m_env.get(c1);
|
||||
declaration const & d2 = m_env.get(c2);
|
||||
if (d1.get_num_univ_params() != d2.get_num_univ_params())
|
||||
return false;
|
||||
return m_tc->is_def_eq(d1.get_type(), d2.get_type()).first;
|
||||
}
|
||||
|
||||
// Erase edges that are definitionally equal to edge
|
||||
void erase_def_eqs(name const & src, name const & edge, name const & tgt) {
|
||||
buffer<name> to_delete;
|
||||
for (auto const & p : m_graph.get_successors(src)) {
|
||||
if (p.second == tgt) {
|
||||
if (is_def_eq(p.first, edge))
|
||||
to_delete.push_back(p.first);
|
||||
}
|
||||
}
|
||||
for (name const & e : to_delete)
|
||||
m_graph.erase(e);
|
||||
}
|
||||
|
||||
template<typename Val>
|
||||
static void insert_maplist(name_map<list<Val>> & m, name const & k, Val const & v) {
|
||||
if (auto it = m.find(k)) {
|
||||
m.insert(k, cons(v, filter(*it, [&](Val const & v2) { return v2 != v; })));
|
||||
} else {
|
||||
m.insert(k, list<Val>(v));
|
||||
}
|
||||
}
|
||||
|
||||
void add_core(name const & src, name const & edge, name const & tgt) {
|
||||
erase_def_eqs(src, edge, tgt);
|
||||
insert_maplist(m_graph.m_successors, src, mk_pair(edge, tgt));
|
||||
insert_maplist(m_graph.m_predecessors, tgt, src);
|
||||
m_graph.m_edges.insert(edge, src);
|
||||
}
|
||||
|
||||
name compose(name const & src, name const & e1, name const & e2, name const & tgt) {
|
||||
name n = src + name("to") + tgt;
|
||||
pair<environment, name> env_e = ::lean::compose(*m_tc, e2, e1, optional<name>(n));
|
||||
m_env = env_e.first;
|
||||
return env_e.second;
|
||||
}
|
||||
|
||||
pair<environment, list<name>> operator()(name const & src, name const & edge, name const & tgt) {
|
||||
buffer<std::tuple<name, name, name>> new_edges;
|
||||
if (auto preds = m_graph.m_predecessors.find(src)) {
|
||||
for (name const & pred : *preds) {
|
||||
if (pred == tgt)
|
||||
continue; // avoid loops
|
||||
if (auto pred_succ = m_graph.m_successors.find(pred)) {
|
||||
for (pair<name, name> const & p : *pred_succ) {
|
||||
if (p.second != src)
|
||||
continue;
|
||||
name new_e = compose(pred, p.first, edge, tgt);
|
||||
new_edges.emplace_back(pred, new_e, tgt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
m_tc.reset(new type_checker(m_env)); // update to reflect new constants in the environment
|
||||
buffer<std::tuple<name, name, name>> new_back_edges;
|
||||
new_back_edges.append(new_edges);
|
||||
if (auto succs = m_graph.m_successors.find(tgt)) {
|
||||
for (pair<name, name> const & p : *succs) {
|
||||
if (src == p.second)
|
||||
continue; // avoid loops
|
||||
name new_e = compose(src, edge, p.first, p.second);
|
||||
new_edges.emplace_back(src, new_e, p.second);
|
||||
for (auto const & back_edge : new_back_edges) {
|
||||
name bsrc, bedge, btgt;
|
||||
std::tie(bsrc, bedge, btgt) = back_edge;
|
||||
if (bsrc != p.second)
|
||||
continue;
|
||||
name new_e = compose(bsrc, bedge, p.first, p.second);
|
||||
new_edges.emplace_back(bsrc, new_e, p.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
buffer<name> new_cnsts;
|
||||
add_core(src, edge, tgt);
|
||||
for (auto const & new_edge : new_edges) {
|
||||
name nsrc, nedge, ntgt;
|
||||
std::tie(nsrc, nedge, ntgt) = new_edge;
|
||||
add_core(nsrc, nedge, ntgt);
|
||||
new_cnsts.push_back(nedge);
|
||||
}
|
||||
return mk_pair(m_env, to_list(new_cnsts));
|
||||
}
|
||||
};
|
||||
|
||||
/** \brief Return true iff args contains Var(0), Var(1), ..., Var(args.size() - 1) */
|
||||
static bool check_var_args(buffer<expr> const & args) {
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (!is_var(args[i]) || var_idx(args[i]) != i)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/** \brief Return true iff param_id(levels[i]) == level_params[i] */
|
||||
static bool check_levels(levels ls, level_param_names ps) {
|
||||
while (!is_nil(ls) && !is_nil(ps)) {
|
||||
if (!is_param(head(ls)))
|
||||
return false;
|
||||
if (param_id(head(ls)) != head(ps))
|
||||
return false;
|
||||
ls = tail(ls);
|
||||
ps = tail(ps);
|
||||
}
|
||||
return is_nil(ls) && is_nil(ps);
|
||||
}
|
||||
|
||||
static void throw_ex(name const & k, name const & e) {
|
||||
throw exception(sstream() << "invalid " << k << ", type of '" << e
|
||||
<< "' does not match any of the allowed expected types for " << k << "s\n"
|
||||
<< " Pi (x_1 : A_1) ... (x_n : A_n) (y: C x_1 ... x_n), D t_1 ... t_m\n"
|
||||
<< " Pi (x_1 : A_1) ... (x_n : A_n) (y: C x_1 ... x_n), Type\n"
|
||||
<< " Pi (x_1 : A_1) ... (x_n : A_n) (y: C x_1 ... x_n), A -> B");
|
||||
}
|
||||
|
||||
pair<name, name> tc_multigraph::validate(environment const & env, name const & e, unsigned num_args) {
|
||||
declaration const & d = env.get(e);
|
||||
expr type = d.get_type();
|
||||
for (unsigned i = 0; i < num_args; i++) {
|
||||
if (!is_pi(type))
|
||||
throw_ex(m_kind, e);
|
||||
type = binding_body(type);
|
||||
}
|
||||
if (!is_pi(type))
|
||||
throw_ex(m_kind, e);
|
||||
buffer<expr> args;
|
||||
expr const & C = get_app_args(binding_domain(type), args);
|
||||
if (!is_constant(C) || !check_levels(const_levels(C), d.get_univ_params()) || !check_var_args(args))
|
||||
throw_ex(m_kind, e);
|
||||
name src = const_name(C);
|
||||
type = binding_body(type);
|
||||
name tgt;
|
||||
if (is_sort(type)) {
|
||||
tgt = *g_sort_sink;
|
||||
} else if (is_pi(type)) {
|
||||
tgt = *g_fun_sink;
|
||||
} else {
|
||||
expr const & D = get_app_fn(type);
|
||||
if (is_constant(D))
|
||||
tgt = const_name(D);
|
||||
else
|
||||
throw_ex(m_kind, e);
|
||||
}
|
||||
if (src == tgt)
|
||||
throw_ex(m_kind, e);
|
||||
return mk_pair(src, tgt);
|
||||
}
|
||||
|
||||
pair<environment, list<name>> tc_multigraph::add(environment const & env, name const & e, unsigned num_args) {
|
||||
auto src_tgt = validate(env, e, num_args);
|
||||
return add_edge_fn(env, *this)(src_tgt.first, e, src_tgt.second);
|
||||
}
|
||||
|
||||
pair<environment, list<name>> tc_multigraph::add(environment const & env, name const & e) {
|
||||
declaration const & d = env.get(e);
|
||||
return add(env, e, get_arity(d.get_type()));
|
||||
unsigned n = get_arity(d.get_type());
|
||||
if (n == 0)
|
||||
throw_ex(m_kind, e);
|
||||
return add(env, e, n-1);
|
||||
}
|
||||
|
||||
void tc_multigraph::add1(environment const & env, name const & e, unsigned num_args) {
|
||||
auto src_tgt = validate(env, e, num_args);
|
||||
return add_edge_fn(env, *this).add_core(src_tgt.first, e, src_tgt.second);
|
||||
}
|
||||
|
||||
void tc_multigraph::add1(environment const & env, name const & e) {
|
||||
declaration const & d = env.get(e);
|
||||
unsigned n = get_arity(d.get_type());
|
||||
if (n == 0)
|
||||
throw_ex(m_kind, e);
|
||||
return add1(env, e, n-1);
|
||||
}
|
||||
|
||||
void tc_multigraph::erase(name const & e) {
|
||||
auto src = m_edges.find(e);
|
||||
if (!src)
|
||||
|
@ -41,26 +229,52 @@ void tc_multigraph::erase(name const & e) {
|
|||
auto pred_lst = m_predecessors.find(tgt);
|
||||
lean_assert(pred_lst);
|
||||
list<name> new_pred_lst = filter(*pred_lst, [&](name const & n) { return n != *src; });
|
||||
m_predecessors.insert(tgt, new_pred_lst);
|
||||
if (new_pred_lst)
|
||||
m_predecessors.insert(tgt, new_pred_lst);
|
||||
else
|
||||
m_predecessors.erase(tgt);
|
||||
}
|
||||
m_edges.erase(e);
|
||||
}
|
||||
|
||||
bool tc_multigraph::is_edge(name const & e) const {
|
||||
return m_edges.contains(e);
|
||||
}
|
||||
|
||||
bool tc_multigraph::is_node(name const & c) const {
|
||||
return m_successors.contains(c) || m_predecessors.contains(c);
|
||||
}
|
||||
|
||||
list<pair<name, name>> tc_multigraph::get_successors(name const & c) const {
|
||||
if (auto r = m_successors.find(c))
|
||||
return *r;
|
||||
else
|
||||
return list<pair<name, name>>();
|
||||
}
|
||||
|
||||
list<name> tc_multigraph::get_predecessors(name const & c) const {
|
||||
if (auto r = m_predecessors.find(c))
|
||||
return *r;
|
||||
else
|
||||
return list<name>();
|
||||
}
|
||||
|
||||
bool tc_multigraph::is_fun_sink(name const & c) {
|
||||
return c == *g_fun_sink;
|
||||
}
|
||||
|
||||
bool tc_multigraph::is_sort_sink(name const & c) {
|
||||
return c == *g_sort_sink;
|
||||
}
|
||||
|
||||
void initialize_tc_multigraph() {
|
||||
name p = name::mk_internal_unique_name();
|
||||
g_fun_sink = new name(p, "Fun");
|
||||
g_sort_sink = new name(p, "Sort");
|
||||
}
|
||||
|
||||
void finalize_tc_multigraph() {
|
||||
delete g_fun_sink;
|
||||
delete g_sort_sink;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,12 +9,18 @@ Author: Leonardo de Moura
|
|||
namespace lean {
|
||||
/** \brief Transitive closed multigraph */
|
||||
class tc_multigraph {
|
||||
name m_kind;
|
||||
name_map<list<pair<name, name>>> m_successors;
|
||||
name_map<list<name>> m_predecessors;
|
||||
name_map<name> m_edges;
|
||||
pair<name, name> validate(environment const & env, name const & e, unsigned num_args);
|
||||
friend struct add_edge_fn;
|
||||
public:
|
||||
tc_multigraph(name const & kind):m_kind(kind) {}
|
||||
pair<environment, list<name>> add(environment const & env, name const & e, unsigned num_args);
|
||||
pair<environment, list<name>> add(environment const & env, name const & e);
|
||||
void add1(environment const & env, name const & e, unsigned num_args);
|
||||
void add1(environment const & env, name const & e);
|
||||
void erase(name const & e);
|
||||
bool is_edge(name const & e) const;
|
||||
bool is_node(name const & c) const;
|
||||
|
@ -23,4 +29,6 @@ public:
|
|||
static bool is_fun_sink(name const & c);
|
||||
static bool is_sort_sink(name const & c);
|
||||
};
|
||||
void initialize_tc_multigraph();
|
||||
void finalize_tc_multigraph();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue