diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 6be2f2d82..a750d1617 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "util/lua_list.h" #include "util/lua_pair.h" #include "util/lua_named_param.h" +#include "util/lazy_list_fn.h" #include "util/luaref.h" #include "kernel/abstract.h" #include "kernel/for_each_fn.h" @@ -1622,9 +1623,87 @@ static int constraint_tostring(lua_State * L) { out << to_constraint(L, 1); return push_string(L, out.str().c_str()); } -static int mk_eq_cnstr(lua_State * L) { return push_constraint(L, mk_eq_cnstr(to_expr(L, 1), to_expr(L, 2), to_justification(L, 3))); } -static int mk_level_eq_cnstr(lua_State * L) { return push_constraint(L, mk_level_eq_cnstr(to_level_ext(L, 1), to_level_ext(L, 2), - to_justification(L, 3))); } +static int mk_eq_cnstr(lua_State * L) { + int nargs = lua_gettop(L); + return push_constraint(L, mk_eq_cnstr(to_expr(L, 1), to_expr(L, 2), nargs == 3 ? to_justification(L, 3) : justification())); +} +static int mk_level_eq_cnstr(lua_State * L) { + int nargs = lua_gettop(L); + return push_constraint(L, mk_level_eq_cnstr(to_level_ext(L, 1), to_level_ext(L, 2), + nargs == 3 ? to_justification(L, 3) : justification())); +} + +static choice_fn to_choice_fn(lua_State * L, int idx) { + luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun + luaref f(L, idx); + return choice_fn([=](expr const & e, substitution const & s, name_generator const & ngen) { + lua_State * L = f.get_state(); + f.push(); + push_expr(L, e); + push_substitution(L, s); + push_name_generator(L, ngen); + pcall(L, 3, 1, 0); + buffer r; + if (lua_isnil(L, -1)) { + // do nothing + } else if (lua_istable(L, -1)) { + int num = objlen(L, -1); + // each entry is an alternative + for (int i = 1; i <= num; i++) { + lua_rawgeti(L, -1, i); + if (is_expr(L, -1)) { + r.push_back(choice_fn_result(to_expr(L, -1), justification(), constraints())); + } else if (lua_istable(L, -1) && objlen(L, -1) == 3) { + lua_rawgeti(L, -1, 1); + expr c = to_expr(L, -1); + lua_pop(L, 1); + lua_rawgeti(L, -1, 2); + justification j = to_justification(L, -1); + lua_pop(L, 1); + lua_rawgeti(L, -1, 3); + buffer cs; + if (lua_isnil(L, -1)) { + // do nothing + } else if (lua_istable(L, -1)) { + int num_cs = objlen(L, -1); + for (int i = 1; i <= num_cs; i++) { + lua_rawgeti(L, -1, i); + cs.push_back(to_constraint(L, -1)); + lua_pop(L, 1); + } + } else { + throw exception("invalid choice function, result must be an array of triples, " + "where the third element of each triple is an array of constraints"); + } + lua_pop(L, 1); + r.push_back(choice_fn_result(c, j, to_list(cs.begin(), cs.end()))); + } else { + throw exception("invalid choice function, result must be an array of triples"); + } + lua_pop(L, 1); + } + } else { + throw exception("invalid choice function, result must be an array of triples"); + } + lua_pop(L, 1); + return to_lazy(to_list(r.begin(), r.end())); + }); +} + +static int mk_choice_cnstr(lua_State * L) { + int nargs = lua_gettop(L); + expr m = to_expr(L, 1); + choice_fn fn = to_choice_fn(L, 2); + if (nargs == 2) + return push_constraint(L, mk_choice_cnstr(m, fn, false, justification())); + else if (nargs == 3 && is_justification(L, 3)) + return push_constraint(L, mk_choice_cnstr(m, fn, false, to_justification(L, 3))); + else if (nargs == 3) + return push_constraint(L, mk_choice_cnstr(m, fn, lua_toboolean(L, 3), justification())); + else + return push_constraint(L, mk_choice_cnstr(m, fn, lua_toboolean(L, 3), to_justification(L, 4))); +} + static const struct luaL_Reg constraint_m[] = { {"__gc", constraint_gc}, // never throws {"__tostring", safe_function}, @@ -1648,6 +1727,7 @@ static void open_constraint(lua_State * L) { SET_GLOBAL_FUN(constraint_pred, "is_constraint"); SET_GLOBAL_FUN(mk_eq_cnstr, "mk_eq_cnstr"); SET_GLOBAL_FUN(mk_level_eq_cnstr, "mk_level_eq_cnstr"); + SET_GLOBAL_FUN(mk_choice_cnstr, "mk_choice_cnstr"); lua_newtable(L); SET_ENUM("Eq", constraint_kind::Eq); diff --git a/tests/lua/unify5.lua b/tests/lua/unify5.lua new file mode 100644 index 000000000..5e96f23da --- /dev/null +++ b/tests/lua/unify5.lua @@ -0,0 +1,37 @@ +local env = environment() +env = add_decl(env, mk_var_decl("N", Type)) +local N = Const("N") +env = add_decl(env, mk_var_decl("f", mk_arrow(N, N, N))) +env = add_decl(env, mk_var_decl("a", N)) +env = add_decl(env, mk_var_decl("b", N)) +local f = Const("f") +local a = Const("a") +local b = Const("b") +local m1 = mk_metavar("m1", N) +local m2 = mk_metavar("m2", N) +local m3 = mk_metavar("m3", N) +local m4 = mk_metavar("m4", N) + +local o = options({"unifier", "use_exceptions"}, false) + +function display_solutions(m, ss) + local n = 0 + for s in ss do + print("solution: " .. tostring(s:instantiate(m))) + s:for_each_expr(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + s:for_each_level(function(n, v, j) + print(" " .. tostring(n) .. " := " .. tostring(v)) + end) + n = n + 1 + end +end + +cs = { mk_eq_cnstr(m1, f(m2, f(m3, m4))), + mk_choice_cnstr(m2, function(e, s, ngen) return {{a, justification(), {}}, {f(a, a), justification(), {}}} end), + mk_choice_cnstr(m3, function(e, s, ngen) return {{b, justification(), {}}, {f(b, b), justification(), {}}} end), + mk_choice_cnstr(m4, function(e, s, ngen) return {a, b} end) + } + +display_solutions(m1, unify(env, cs, o))