/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
#include <algorithm>
#include <tuple>
#include "util/sstream.h"
#include "kernel/type_checker.h"
#include "library/tc_multigraph.h"
#include "library/composition_manager.h"
#include "library/util.h"

namespace lean {
/** \brief Target node used (see tc_multigraph::validate) for edges whose codomain is a Pi type. */
static name * g_fun_sink  = nullptr;
/** \brief Target node used (see tc_multigraph::validate) for edges whose codomain is a sort. */
static name * g_sort_sink = nullptr;

/** \brief Function object that adds the edge `edge : src -> tgt` to a tc_multigraph.

    Besides the edge itself, it creates composite constants (via ::lean::compose) for the
    paths pred -> src -> tgt, src -> tgt -> succ and pred -> src -> tgt -> succ, skipping
    any composition whose source and target coincide, and records them as edges as well. */
struct add_edge_fn {
    environment      m_env;   // environment, extended with every composite constant created
    type_checker_ptr m_tc;    // used for definitional-equality tests and for composition
    tc_multigraph &  m_graph; // graph updated in place

    add_edge_fn(environment const & env, tc_multigraph & g):
        m_env(env), m_tc(new type_checker(env)), m_graph(g) {}

    /** \brief Return true iff c1 and c2 are the same constant, or they have the same number of
        universe parameters and definitionally equal types. */
    bool is_def_eq(name const & c1, name const & c2) {
        if (c1 == c2)
            return true;
        declaration const & d1 = m_env.get(c1);
        declaration const & d2 = m_env.get(c2);
        if (d1.get_num_univ_params() != d2.get_num_univ_params())
            return false;
        return m_tc->is_def_eq(d1.get_type(), d2.get_type()).first;
    }

    /** \brief Erase every existing edge src -> tgt whose type is definitionally equal to the type of `edge`. */
    void erase_def_eqs(name const & src, name const & edge, name const & tgt) {
        buffer<name> to_delete;
        for (auto const & p : m_graph.get_successors(src)) {
            if (p.second == tgt && is_def_eq(p.first, edge))
                to_delete.push_back(p.first);
        }
        for (name const & e : to_delete)
            m_graph.erase(e);
    }

    /** \brief Prepend `v` to the list stored under key `k` in `m`, after removing any previous
        occurrences of `v` from it. If `k` is absent, store the singleton list `[v]`. */
    template<typename Val>
    static void insert_maplist(name_map<list<Val>> & m, name const & k, Val const & v) {
        if (auto it = m.find(k)) {
            m.insert(k, cons(v, filter(*it, [&](Val const & v2) { return v2 != v; })));
        } else {
            m.insert(k, list<Val>(v));
        }
    }

    /** \brief Record the edge `edge : src -> tgt` in the graph (successors, predecessors and edge
        table), first erasing src -> tgt edges with definitionally equal types.
        No compositions are computed here. */
    void add_core(name const & src, name const & edge, name const & tgt) {
        erase_def_eqs(src, edge, tgt);
        insert_maplist(m_graph.m_successors, src, mk_pair(edge, tgt));
        insert_maplist(m_graph.m_predecessors, tgt, src);
        m_graph.m_edges.insert(edge, src);
    }

    /** \brief Create a constant for the path src --e1--> ... --e2--> tgt and return its name.

        The name built from `src`, "to" and `tgt` is passed to ::lean::compose as the requested
        name; m_env is updated to the environment returned by ::lean::compose. */
    name compose(name const & src, name const & e1, name const & e2, name const & tgt) {
        name n = src + name("to") + tgt;
        pair<environment, name> env_e = ::lean::compose(*m_tc, e2, e1, optional<name>(n));
        m_env = env_e.first;
        return env_e.second;
    }

    /** \brief Add `edge : src -> tgt` together with its compositions with existing edges.

        \return the extended environment and the names of the composite constants created. */
    pair<environment, list<name>> operator()(name const & src, name const & edge, name const & tgt) {
        buffer<std::tuple<name, name, name>> new_edges;
        // pred -> src -> tgt
        if (auto preds = m_graph.m_predecessors.find(src)) {
            for (name const & pred : *preds) {
                if (pred == tgt)
                    continue; // avoid loops
                if (auto pred_succ = m_graph.m_successors.find(pred)) {
                    for (pair<name, name> const & p : *pred_succ) {
                        if (p.second != src)
                            continue;
                        name new_e = compose(pred, p.first, edge, tgt);
                        new_edges.emplace_back(pred, new_e, tgt);
                    }
                }
            }
        }
        m_tc.reset(new type_checker(m_env)); // update to reflect new constants in the environment
        // Snapshot of the pred -> tgt edges created above: new_edges keeps growing in the loop below.
        buffer<std::tuple<name, name, name>> new_back_edges;
        new_back_edges.append(new_edges);
        // src -> tgt -> succ, and pred -> src -> tgt -> succ through the snapshot
        if (auto succs = m_graph.m_successors.find(tgt)) {
            for (pair<name, name> const & p : *succs) {
                if (src == p.second)
                    continue; // avoid loops
                name new_e = compose(src, edge, p.first, p.second);
                new_edges.emplace_back(src, new_e, p.second);
                for (auto const & back_edge : new_back_edges) {
                    name bsrc, bedge;
                    std::tie(bsrc, bedge, std::ignore) = back_edge;
                    if (bsrc == p.second)
                        continue; // avoid loops
                    name new_back_e = compose(bsrc, bedge, p.first, p.second);
                    new_edges.emplace_back(bsrc, new_back_e, p.second);
                }
            }
        }
        buffer<name> new_cnsts;
        add_core(src, edge, tgt);
        for (auto const & new_edge : new_edges) {
            name nsrc, nedge, ntgt;
            std::tie(nsrc, nedge, ntgt) = new_edge;
            add_core(nsrc, nedge, ntgt);
            new_cnsts.push_back(nedge);
        }
        return mk_pair(m_env, to_list(new_cnsts));
    }
};

/** \brief Return true iff args is exactly Var(n-1), ..., Var(1), Var(0), where n == args.size().

    In `Pi (x_1 : A_1) ... (x_n : A_n) (y : C x_1 ... x_n), ...` the domain of `y` refers to
    x_1 ... x_n through the de Bruijn indices n-1 ... 0. This checks that C is applied to the
    innermost args.size() enclosing binders, in declaration order. */
static bool check_var_args(buffer<expr> const & args) {
    unsigned n = args.size();
    for (unsigned i = 0; i < n; i++) {
        if (!is_var(args[i]) || var_idx(args[i]) != n - i - 1)
            return false;
    }
    return true;
}

/** \brief Return true iff param_id(levels[i]) == level_params[i] */
static bool check_levels(levels ls, level_param_names ps) {
    while (!is_nil(ls) && !is_nil(ps)) {
        if (!is_param(head(ls)))
            return false;
        if (param_id(head(ls)) != head(ps))
            return false;
        ls = tail(ls);
        ps = tail(ps);
    }
    return is_nil(ls) && is_nil(ps);
}

/** \brief Throw an exception reporting that the type of `e` is not a valid `k` type. */
static void throw_ex(name const & k, name const & e) {
    throw exception(sstream() << "invalid " << k << ", type of '" << e
                    << "' does not match any of the allowed expected types for " << k << "s\n"
                    << " Pi (x_1 : A_1) ... (x_n : A_n) (y: C x_1 ... x_n), D t_1 ... t_m\n"
                    << " Pi (x_1 : A_1) ... (x_n : A_n) (y: C x_1 ... x_n), Type\n"
                    << " Pi (x_1 : A_1) ... (x_n : A_n) (y: C x_1 ... x_n), A -> B");
}

/** \brief Check that the type of `e`, after skipping `num_args` Pi binders, has one of the
    shapes listed in throw_ex, and return (source node, target node) for the edge.

    The target is `g_sort_sink` for a sort codomain, `g_fun_sink` for a Pi codomain, and the head
    constant otherwise. Throws an exception if the type does not fit, or if source == target. */
pair<name, name> tc_multigraph::validate(environment const & env, name const & e, unsigned num_args) {
    declaration const & d = env.get(e);
    expr type = d.get_type();
    for (unsigned i = 0; i < num_args; i++) {
        if (!is_pi(type))
            throw_ex(m_kind, e);
        type = binding_body(type);
    }
    if (!is_pi(type))
        throw_ex(m_kind, e);
    buffer<expr> args;
    expr const & C = get_app_args(binding_domain(type), args);
    if (!is_constant(C) || !check_levels(const_levels(C), d.get_univ_params()) || !check_var_args(args))
        throw_ex(m_kind, e);
    name src = const_name(C);
    type = binding_body(type);
    name tgt;
    if (is_sort(type)) {
        tgt = *g_sort_sink;
    } else if (is_pi(type)) {
        tgt = *g_fun_sink;
    } else {
        expr const & D = get_app_fn(type);
        if (is_constant(D))
            tgt = const_name(D);
        else
            throw_ex(m_kind, e);
    }
    if (src == tgt)
        throw_ex(m_kind, e);
    return mk_pair(src, tgt);
}

/** \brief Validate `e` and add it, with all its compositions, to the graph. */
pair<environment, list<name>> tc_multigraph::add(environment const & env, name const & e, unsigned num_args) {
    auto src_tgt = validate(env, e, num_args);
    return add_edge_fn(env, *this)(src_tgt.first, e, src_tgt.second);
}

/** \brief Same as add(env, e, arity(e) - 1): the last argument of `e` is the one mapped. */
pair<environment, list<name>> tc_multigraph::add(environment const & env, name const & e) {
    declaration const & d = env.get(e);
    unsigned n = get_arity(d.get_type());
    if (n == 0)
        throw_ex(m_kind, e);
    return add(env, e, n-1);
}

/** \brief Validate `e` and add only the edge itself (no compositions; environment unchanged). */
void tc_multigraph::add1(environment const & env, name const & e, unsigned num_args) {
    auto src_tgt = validate(env, e, num_args);
    add_edge_fn(env, *this).add_core(src_tgt.first, e, src_tgt.second);
}

/** \brief Same as add1(env, e, arity(e) - 1). */
void tc_multigraph::add1(environment const & env, name const & e) {
    declaration const & d = env.get(e);
    unsigned n = get_arity(d.get_type());
    if (n == 0)
        throw_ex(m_kind, e);
    add1(env, e, n-1);
}

/** \brief Remove edge `e` from the graph; no-op if `e` is not an edge.

    If `e` was the last edge from its source to its target, the source is also removed from the
    target's predecessor list (and that entry is dropped when it becomes empty). */
void tc_multigraph::erase(name const & e) {
    auto src = m_edges.find(e);
    if (!src)
        return;
    auto succ_lst = m_successors.find(*src);
    lean_assert(succ_lst);
    name tgt;
    list<pair<name, name>> new_succ_lst = filter(*succ_lst, [&](pair<name, name> const & p) {
            if (p.first == e) {
                lean_assert(tgt.is_anonymous());
                tgt = p.second;
                return false;
            } else {
                return true;
            }
        });
    lean_assert(!tgt.is_anonymous());
    m_successors.insert(*src, new_succ_lst);
    if (std::all_of(new_succ_lst.begin(), new_succ_lst.end(),
                    [&](pair<name, name> const & p) { return p.second != tgt; })) {
        // e is the last edge from src to tgt
        auto pred_lst = m_predecessors.find(tgt);
        lean_assert(pred_lst);
        list<name> new_pred_lst = filter(*pred_lst, [&](name const & n) { return n != *src; });
        if (new_pred_lst)
            m_predecessors.insert(tgt, new_pred_lst);
        else
            m_predecessors.erase(tgt);
    }
    m_edges.erase(e);
}

bool tc_multigraph::is_edge(name const & e) const {
    return m_edges.contains(e);
}

bool tc_multigraph::is_node(name const & c) const {
    return m_successors.contains(c) || m_predecessors.contains(c);
}

/** \brief Return the (edge, target) pairs leaving `c`, or the empty list. */
list<pair<name, name>> tc_multigraph::get_successors(name const & c) const {
    if (auto r = m_successors.find(c))
        return *r;
    else
        return list<pair<name, name>>();
}

/** \brief Return the sources of edges entering `c`, or the empty list. */
list<name> tc_multigraph::get_predecessors(name const & c) const {
    if (auto r = m_predecessors.find(c))
        return *r;
    else
        return list<name>();
}

bool tc_multigraph::is_fun_sink(name const & c) {
    return c == *g_fun_sink;
}

bool tc_multigraph::is_sort_sink(name const & c) {
    return c == *g_sort_sink;
}

void initialize_tc_multigraph() {
    name p      = name::mk_internal_unique_name();
    g_fun_sink  = new name(p, "Fun");
    g_sort_sink = new name(p, "Sort");
}

void finalize_tc_multigraph() {
    delete g_fun_sink;
    delete g_sort_sink;
}
}