diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt index 8f916bc25..46179eb56 100644 --- a/src/library/CMakeLists.txt +++ b/src/library/CMakeLists.txt @@ -3,7 +3,8 @@ add_library(library deep_copy.cpp expr_lt.cpp io_state.cpp occurs.cpp resolve_macro.cpp kernel_serializer.cpp max_sharing.cpp normalize.cpp shared_environment.cpp module.cpp coercion.cpp private.cpp placeholder.cpp aliases.cpp level_names.cpp - update_declaration.cpp choice.cpp scoped_ext.cpp locals.cpp) + update_declaration.cpp choice.cpp scoped_ext.cpp locals.cpp + unifier.cpp) # hop_match.cpp) target_link_libraries(library ${LEAN_LIBS}) diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp new file mode 100644 index 000000000..a502f4903 --- /dev/null +++ b/src/library/unifier.cpp @@ -0,0 +1,346 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include "util/luaref.h" +#include "util/lazy_list_fn.h" +#include "kernel/for_each_fn.h" +#include "kernel/abstract.h" +#include "kernel/type_checker.h" +#include "library/unifier.h" +#include "library/kernel_bindings.h" + +namespace lean { +static std::pair unify_simple_core(substitution const & s, expr const & lhs, expr const & rhs, + justification const & j) { + lean_assert(is_meta(lhs)); + buffer args; + expr const & m = get_app_args(lhs, args); + lean_assert(is_metavar(m)); + for (auto it = args.begin(); it != args.end(); it++) { + if (!is_local(*it) || std::find(args.begin(), it, *it) != it) + return mk_pair(unify_status::Unsupported, s); + } + if (is_meta(rhs) && get_app_fn(rhs) == m) + return mk_pair(unify_status::Unsupported, s); + bool failed = false; + for_each(rhs, [&](expr const & e, unsigned) { + if (failed) + return false; + if (is_local(e) && std::find(args.begin(), args.end(), e) == args.end()) { + // right-hand-side contains variable that is not in the scope + // of metavariable. + failed = true; + return false; + } + if (is_metavar(e) && e == m) { + // occurs-check failed + failed = true; + return false; + } + // we only need to continue exploring e if it contains + // metavariables and/or local constants. + return has_metavar(e) || has_local(e); + }); + if (failed) + return mk_pair(unify_status::Failed, s); + expr v = abstract_locals(rhs, args.size(), args.data()); + unsigned i = args.size(); + while (i > 0) { + --i; + v = mk_lambda(local_pp_name(args[i]), mlocal_type(args[i]), v); + } + return mk_pair(unify_status::Solved, s.assign(mlocal_name(m), v, j)); +} + +std::pair unify_simple(substitution const & s, expr const & lhs, expr const & rhs, justification const & j) { + if (lhs == rhs) + return mk_pair(unify_status::Solved, s); + else if (!has_metavar(lhs) && !has_metavar(rhs)) + return mk_pair(unify_status::Failed, s); + else if (is_meta(lhs)) + return unify_simple_core(s, lhs, rhs, j); + else if (is_meta(rhs)) + return unify_simple_core(s, rhs, lhs, j); + else + return mk_pair(unify_status::Unsupported, s); +} + +std::pair unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) { + lean_assert(is_meta(lhs)); + bool contains = false; + for_each(rhs, [&](level const & l) { + if (contains) + return false; + if (l == lhs) { + // occurs-check failed + contains = true; + return false; + } + return true; + }); + if (contains) { + if (is_succ(rhs)) + return mk_pair(unify_status::Failed, s); + else + return mk_pair(unify_status::Unsupported, s); + } + return mk_pair(unify_status::Solved, s.assign(meta_id(lhs), rhs, j)); +} + +std::pair unify_simple(substitution const & s, level const & lhs, level const & rhs, justification const & j) { + if (lhs == rhs) + return mk_pair(unify_status::Solved, s); + else if (!has_meta(lhs) && !has_meta(rhs)) + return mk_pair(unify_status::Failed, s); + else if (is_meta(lhs)) + return unify_simple_core(s, lhs, rhs, j); + else if (is_meta(rhs)) + return unify_simple_core(s, rhs, lhs, j); + else if (is_succ(lhs) && is_succ(rhs)) + return unify_simple(s, succ_of(lhs), succ_of(rhs), j); + else + return mk_pair(unify_status::Unsupported, s); +} + +std::pair unify_simple(substitution const & s, constraint const & c) { + if (is_eq_cnstr(c)) + return unify_simple(s, cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification()); + else if (is_level_cnstr(c)) + return unify_simple(s, cnstr_lhs_level(c), cnstr_rhs_level(c), c.get_justification()); + else + return mk_pair(unify_status::Unsupported, s); +} + +struct unifier_fn { + environment m_env; + name_generator m_ngen; + substitution m_subst; + unifier_plugin m_plugin; + bool m_use_exception; + + unifier_fn(environment const & env, unsigned /* num_cs */, constraint const * /* cs */, + name_generator const & ngen, substitution const & s, unifier_plugin const & p, + bool use_exception): + m_env(env), m_ngen(ngen), m_subst(s), m_plugin(p), m_use_exception(use_exception) { + } + + optional next() { + // TODO(Leo) + return optional(); + } +}; + +lazy_list unify(std::shared_ptr const & u) { + return mk_lazy_list([=]() { + auto s = u->next(); + if (s) + return some(mk_pair(*s, unify(u))); + else + return lazy_list::maybe_pair(); + }); +} + +lazy_list unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, + unifier_plugin const & p, bool use_exception) { + return unify(std::make_shared(env, num_cs, cs, ngen, substitution(), p, use_exception)); +} + +lazy_list unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, + bool use_exception) { + return unify(env, num_cs, cs, ngen, [](constraint const &, name_generator const &) { return lazy_list(); }, use_exception); +} + +lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, unifier_plugin const & p) { + substitution s; + buffer cs; + name_generator new_ngen(ngen); + bool failed = false; + type_checker tc(env, new_ngen.mk_child(), [&](constraint const & c) { + if (!failed) { + auto r = unify_simple(s, c); + switch (r.first) { + case unify_status::Solved: + s = r.second; break; + case unify_status::Failed: + failed = true; break; + case unify_status::Unsupported: + cs.push_back(c); break; + } + } + }); + if (!tc.is_def_eq(lhs, rhs) || failed) { + return lazy_list(); + } else if (cs.empty()) { + return lazy_list(s); + } else { + return unify(std::make_shared(env, cs.size(), cs.data(), ngen, s, p, false)); + } +} + +lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen) { + return unify(env, lhs, rhs, ngen, [](constraint const &, name_generator const &) { return lazy_list(); }); +} + +static int unify_simple(lua_State * L) { + int nargs = lua_gettop(L); + std::pair r; + if (nargs == 2) + r = unify_simple(to_substitution(L, 1), to_constraint(L, 2)); + else if (nargs == 3 && is_expr(L, 2)) + r = unify_simple(to_substitution(L, 1), to_expr(L, 2), to_expr(L, 3), justification()); + else if (nargs == 3 && is_level(L, 2)) + r = unify_simple(to_substitution(L, 1), to_level(L, 2), to_level(L, 3), justification()); + else if (is_expr(L, 2)) + r = unify_simple(to_substitution(L, 1), to_expr(L, 2), to_expr(L, 3), to_justification(L, 4)); + else + r = unify_simple(to_substitution(L, 1), to_level(L, 2), to_level(L, 3), to_justification(L, 4)); + push_integer(L, static_cast(r.first)); + push_substitution(L, r.second); + return 2; +} + +typedef lazy_list substitution_seq; +DECL_UDATA(substitution_seq) + +static const struct luaL_Reg substitution_seq_m[] = { + {"__gc", substitution_seq_gc}, + {0, 0} +}; + +static int substitution_seq_next(lua_State * L) { + substitution_seq seq = to_substitution_seq(L, lua_upvalueindex(1)); + substitution_seq::maybe_pair p; + p = seq.pull(); + if (p) { + push_substitution_seq(L, p->second); + lua_replace(L, lua_upvalueindex(1)); + push_substitution(L, p->first); + } else { + lua_pushnil(L); + } + return 1; +} + +static int push_substitution_seq_it(lua_State * L, substitution_seq const & seq) { + push_substitution_seq(L, seq); + lua_pushcclosure(L, &safe_function, 1); // create closure with 1 upvalue + return 1; +} + +static void to_constraint_buffer(lua_State * L, int idx, buffer & cs) { + luaL_checktype(L, idx, LUA_TTABLE); + lua_pushvalue(L, idx); // put table on top of the stack + int n = objlen(L, idx); + for (int i = 1; i <= n; i++) { + lua_rawgeti(L, -1, i); + cs.push_back(to_constraint(L, -1)); + lua_pop(L, 1); + } + lua_pop(L, 1); +} + +static constraints to_constraints(lua_State * L, int idx) { + buffer cs; + to_constraint_buffer(L, idx, cs); + return to_list(cs.begin(), cs.end()); +} + +static unifier_plugin to_unifier_plugin(lua_State * L, int idx) { + luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun + luaref f(L, idx); + return unifier_plugin([=](constraint const & c, name_generator const & ngen) { + lua_State * L = f.get_state(); + f.push(); + push_constraint(L, c); + push_name_generator(L, ngen); + pcall(L, 2, 1, 0); + lazy_list r; + if (is_constraint(L, -1)) { + // single constraint + r = lazy_list(constraints(to_constraint(L, -1))); + } else if (lua_istable(L, -1)) { + int num = objlen(L, -1); + if (num == 0) { + // empty table + r = lazy_list(); + } else { + lua_rawgeti(L, -1, 1); + if (is_constraint(L, -1)) { + // array of constraints case + lua_pop(L, 1); + r = lazy_list(to_constraints(L, -1)); + } else { + lua_pop(L, 1); + buffer css; + // array of array of constraints + for (int i = 1; i <= num; i++) { + lua_rawgeti(L, -1, i); + css.push_back(to_constraints(L, -1)); + lua_pop(L, 1); + } + r = to_lazy(to_list(css.begin(), css.end())); + } + } + } else if (lua_isnil(L, -1)) { + // nil case + r = lazy_list(); + } else { + throw exception("invalid unifier plugin, the result value must be a constrant, " + "nil, an array of constraints, or an array of arrays of constraints"); + } + lua_pop(L, 1); + return r; + }); +} + +static name g_tmp_prefix = name::mk_internal_unique_name(); + +static int unify(lua_State * L) { + int nargs = lua_gettop(L); + lazy_list r; + environment const & env = to_environment(L, 1); + if (is_expr(L, 2)) { + if (nargs == 3) + r = unify(env, to_expr(L, 2), to_expr(L, 3), name_generator(g_tmp_prefix)); + else if (nargs == 4 && is_name_generator(L, 4)) + r = unify(env, to_expr(L, 2), to_expr(L, 3), to_name_generator(L, 4)); + else if (nargs == 4) + r = unify(env, to_expr(L, 2), to_expr(L, 3), name_generator(g_tmp_prefix), to_unifier_plugin(L, 4)); + else + r = unify(env, to_expr(L, 2), to_expr(L, 3), to_name_generator(L, 4), to_unifier_plugin(L, 5)); + } else { + buffer cs; + to_constraint_buffer(L, 2, cs); + if (nargs == 2) + r = unify(env, cs.size(), cs.data(), name_generator(g_tmp_prefix)); + else if (nargs == 3 && is_name_generator(L, 3)) + r = unify(env, cs.size(), cs.data(), to_name_generator(L, 3)); + else if (nargs == 3) + r = unify(env, cs.size(), cs.data(), name_generator(g_tmp_prefix), to_unifier_plugin(L, 3)); + else + r = unify(env, cs.size(), cs.data(), to_name_generator(L, 3), to_unifier_plugin(L, 4)); + } + return push_substitution_seq_it(L, r); +} + +void open_unifier(lua_State * L) { + luaL_newmetatable(L, substitution_seq_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, substitution_seq_m, 0); + SET_GLOBAL_FUN(substitution_seq_pred, "is_substitution_seq"); + + SET_GLOBAL_FUN(unify_simple, "unify_simple"); + SET_GLOBAL_FUN(unify, "unify"); + + lua_newtable(L); + SET_ENUM("Solved", unify_status::Solved); + SET_ENUM("Failed", unify_status::Failed); + SET_ENUM("Unsupported", unify_status::Unsupported); + lua_setglobal(L, "unify_status"); +} +} diff --git a/src/library/unifier.h b/src/library/unifier.h new file mode 100644 index 000000000..cb9993e21 --- /dev/null +++ b/src/library/unifier.h @@ -0,0 +1,44 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include +#include "util/lua.h" +#include "util/lazy_list.h" +#include "util/name_generator.h" +#include "kernel/constraint.h" +#include "kernel/environment.h" +#include "kernel/metavar.h" + +namespace lean { +enum class unify_status { Solved, Failed, Unsupported }; +/** + \brief Handle the easy-cases: first-order unification, higher-order patterns, identical terms, and terms without metavariables. + + This function assumes that all assigned metavariables have been substituted. +*/ +std::pair unify_simple(substitution const & s, expr const & lhs, expr const & rhs, justification const & j); +std::pair unify_simple(substitution const & s, level const & lhs, level const & rhs, justification const & j); +std::pair unify_simple(substitution const & s, constraint const & c); + +/** + \brief A unifier_plugin provides a simple way to extend the \c unify procedures. + Whenever, the default implementation does not know how to solve a constraint, it invokes the plugin. + The plugin return a lazy_list (stream) of possible solutions. Each "solution" is represented as + a new list of constraints. +*/ +typedef std::function(constraint const &, name_generator const &)> unifier_plugin; + +lazy_list unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, + unifier_plugin const & p, bool use_exception = true); +lazy_list unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, + bool use_exception = true); +lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, unifier_plugin const & p); +lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen); + +void open_unifier(lua_State * L); +} diff --git a/tests/lua/unify1.lua b/tests/lua/unify1.lua new file mode 100644 index 000000000..14aa0019b --- /dev/null +++ b/tests/lua/unify1.lua @@ -0,0 +1,55 @@ +function test_unify_simple(lhs, rhs, expected) + print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(expected)) + r, s = unify_simple(substitution(), lhs, rhs, justification()) + if r == unify_status.Solved then + s:for_each_expr(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + s:for_each_level(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + end + assert(r == expected) +end + +local f = Const("f") +local a = Const("a") +local l1 = mk_local("l1", "x", Bool) +local l2 = mk_local("l2", "y", Bool) +local l3 = mk_local("l3", "z", Bool) +local m = mk_metavar("m", Bool) + +test_unify_simple(m(l1, l2), f(f(a, l1), l1), unify_status.Solved) +test_unify_simple(m(l1, l2), m(l1, l2), unify_status.Solved) +test_unify_simple(m(l1, l2), m(l2, l2), unify_status.Unsupported) +test_unify_simple(m(l1, l2), f(f(l2, l1), l1), unify_status.Solved) +test_unify_simple(m(l1, a), f(f(a), l1), unify_status.Unsupported) +test_unify_simple(f(m, a), f(f(a), l1), unify_status.Unsupported) +test_unify_simple(m(l1, l2), f(f(m), l1), unify_status.Failed) +test_unify_simple(m(l1, l2), f(f(l3), l1), unify_status.Failed) +test_unify_simple(m, f(a), unify_status.Solved) +test_unify_simple(f(a), f(a), unify_status.Solved) +test_unify_simple(f(a), f(f(a)), unify_status.Failed) +test_unify_simple(f(f(a, l1), l1), m(l1, l2), unify_status.Solved) +test_unify_simple(m(l1, l1), f(f(l2, l1), l1), unify_status.Unsupported) +test_unify_simple(m(l1, l2, l1), f(f(l2, l1), l1), unify_status.Unsupported) +test_unify_simple(m(l1, l2), Fun(l3, f(f(a, l1), l3)), unify_status.Solved) + +local zero = level() +local one = zero+1 +local l = mk_param_univ("l") +local u = mk_global_univ("u") +local m = mk_meta_univ("m") +test_unify_simple(m+1, u+1, unify_status.Solved) +test_unify_simple(u+1, u+1, unify_status.Solved) +test_unify_simple(u+1, m+1, unify_status.Solved) +test_unify_simple(m, u+1, unify_status.Solved) +test_unify_simple(m, max_univ(u, l), unify_status.Solved) +test_unify_simple(max_univ(u, l), max_univ(u, l), unify_status.Solved) +test_unify_simple(m+1, m+1, unify_status.Solved) +test_unify_simple(l, l+1, unify_status.Failed) +test_unify_simple(m, m+1, unify_status.Failed) +test_unify_simple(m, max_univ(m, u), unify_status.Unsupported) +test_unify_simple(m, max_univ(m, u)+1, unify_status.Failed) +test_unify_simple(m+2, m+1, unify_status.Failed) +test_unify_simple(u+2, m+1, unify_status.Solved) diff --git a/tests/lua/unify2.lua b/tests/lua/unify2.lua new file mode 100644 index 000000000..99bfc625f --- /dev/null +++ b/tests/lua/unify2.lua @@ -0,0 +1,31 @@ +function test_unify(env, lhs, rhs, num_s) + print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s)) + local ss = unify(env, lhs, rhs) + local n = 0 + for s in ss do + print("solution: ") + s:for_each_expr(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + s:for_each_level(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + n = n + 1 + end + assert(num_s == n) +end + +local env = environment() +env = add_decl(env, mk_var_decl("N", Type)) +local N = Const("N") +env = add_decl(env, mk_var_decl("f", mk_arrow(N, N, N))) +env = add_decl(env, mk_var_decl("a", N)) +local f = Const("f") +local a = Const("a") +local l1 = mk_local("l1", "x", N) +local l2 = mk_local("l2", "y", N) +local l3 = mk_local("l3", "z", N) +local m = mk_metavar("m", mk_arrow(N, N, N)) +test_unify(env, m(l1, l2), f(f(a, l1), l1), 1) +test_unify(env, f(m(l1, l2), l1), f(f(a, l1), l1), 1) +test_unify(env, f(m(l1, l2), a), f(f(a, l1), l1), 0)