feat(library/kernel_bindings): add mk_choice_cnstr to Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-23 13:39:21 -07:00
parent 6d14de76f3
commit 60c60c6cf5
2 changed files with 120 additions and 3 deletions

View file

@ -11,6 +11,7 @@ Author: Leonardo de Moura
#include "util/lua_list.h"
#include "util/lua_pair.h"
#include "util/lua_named_param.h"
#include "util/lazy_list_fn.h"
#include "util/luaref.h"
#include "kernel/abstract.h"
#include "kernel/for_each_fn.h"
@ -1622,9 +1623,87 @@ static int constraint_tostring(lua_State * L) {
out << to_constraint(L, 1);
return push_string(L, out.str().c_str());
}
static int mk_eq_cnstr(lua_State * L) { return push_constraint(L, mk_eq_cnstr(to_expr(L, 1), to_expr(L, 2), to_justification(L, 3))); }
static int mk_level_eq_cnstr(lua_State * L) { return push_constraint(L, mk_level_eq_cnstr(to_level_ext(L, 1), to_level_ext(L, 2),
to_justification(L, 3))); }
static int mk_eq_cnstr(lua_State * L) {
int nargs = lua_gettop(L);
return push_constraint(L, mk_eq_cnstr(to_expr(L, 1), to_expr(L, 2), nargs == 3 ? to_justification(L, 3) : justification()));
}
static int mk_level_eq_cnstr(lua_State * L) {
int nargs = lua_gettop(L);
return push_constraint(L, mk_level_eq_cnstr(to_level_ext(L, 1), to_level_ext(L, 2),
nargs == 3 ? to_justification(L, 3) : justification()));
}
static choice_fn to_choice_fn(lua_State * L, int idx) {
luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun
luaref f(L, idx);
return choice_fn([=](expr const & e, substitution const & s, name_generator const & ngen) {
lua_State * L = f.get_state();
f.push();
push_expr(L, e);
push_substitution(L, s);
push_name_generator(L, ngen);
pcall(L, 3, 1, 0);
buffer<choice_fn_result> r;
if (lua_isnil(L, -1)) {
// do nothing
} else if (lua_istable(L, -1)) {
int num = objlen(L, -1);
// each entry is an alternative
for (int i = 1; i <= num; i++) {
lua_rawgeti(L, -1, i);
if (is_expr(L, -1)) {
r.push_back(choice_fn_result(to_expr(L, -1), justification(), constraints()));
} else if (lua_istable(L, -1) && objlen(L, -1) == 3) {
lua_rawgeti(L, -1, 1);
expr c = to_expr(L, -1);
lua_pop(L, 1);
lua_rawgeti(L, -1, 2);
justification j = to_justification(L, -1);
lua_pop(L, 1);
lua_rawgeti(L, -1, 3);
buffer<constraint> cs;
if (lua_isnil(L, -1)) {
// do nothing
} else if (lua_istable(L, -1)) {
int num_cs = objlen(L, -1);
for (int i = 1; i <= num_cs; i++) {
lua_rawgeti(L, -1, i);
cs.push_back(to_constraint(L, -1));
lua_pop(L, 1);
}
} else {
throw exception("invalid choice function, result must be an array of triples, "
"where the third element of each triple is an array of constraints");
}
lua_pop(L, 1);
r.push_back(choice_fn_result(c, j, to_list(cs.begin(), cs.end())));
} else {
throw exception("invalid choice function, result must be an array of triples");
}
lua_pop(L, 1);
}
} else {
throw exception("invalid choice function, result must be an array of triples");
}
lua_pop(L, 1);
return to_lazy(to_list(r.begin(), r.end()));
});
}
static int mk_choice_cnstr(lua_State * L) {
int nargs = lua_gettop(L);
expr m = to_expr(L, 1);
choice_fn fn = to_choice_fn(L, 2);
if (nargs == 2)
return push_constraint(L, mk_choice_cnstr(m, fn, false, justification()));
else if (nargs == 3 && is_justification(L, 3))
return push_constraint(L, mk_choice_cnstr(m, fn, false, to_justification(L, 3)));
else if (nargs == 3)
return push_constraint(L, mk_choice_cnstr(m, fn, lua_toboolean(L, 3), justification()));
else
return push_constraint(L, mk_choice_cnstr(m, fn, lua_toboolean(L, 3), to_justification(L, 4)));
}
static const struct luaL_Reg constraint_m[] = {
{"__gc", constraint_gc}, // never throws
{"__tostring", safe_function<constraint_tostring>},
@ -1648,6 +1727,7 @@ static void open_constraint(lua_State * L) {
SET_GLOBAL_FUN(constraint_pred, "is_constraint");
SET_GLOBAL_FUN(mk_eq_cnstr, "mk_eq_cnstr");
SET_GLOBAL_FUN(mk_level_eq_cnstr, "mk_level_eq_cnstr");
SET_GLOBAL_FUN(mk_choice_cnstr, "mk_choice_cnstr");
lua_newtable(L);
SET_ENUM("Eq", constraint_kind::Eq);

37
tests/lua/unify5.lua Normal file
View file

@ -0,0 +1,37 @@
local env = environment()
env = add_decl(env, mk_var_decl("N", Type))
local N = Const("N")
env = add_decl(env, mk_var_decl("f", mk_arrow(N, N, N)))
env = add_decl(env, mk_var_decl("a", N))
env = add_decl(env, mk_var_decl("b", N))
local f = Const("f")
local a = Const("a")
local b = Const("b")
local m1 = mk_metavar("m1", N)
local m2 = mk_metavar("m2", N)
local m3 = mk_metavar("m3", N)
local m4 = mk_metavar("m4", N)
local o = options({"unifier", "use_exceptions"}, false)
function display_solutions(m, ss)
local n = 0
for s in ss do
print("solution: " .. tostring(s:instantiate(m)))
s:for_each_expr(function(n, v, j)
print(" " .. tostring(n) .. " := " .. tostring(v))
end)
s:for_each_level(function(n, v, j)
print(" " .. tostring(n) .. " := " .. tostring(v))
end)
n = n + 1
end
end
cs = { mk_eq_cnstr(m1, f(m2, f(m3, m4))),
mk_choice_cnstr(m2, function(e, s, ngen) return {{a, justification(), {}}, {f(a, a), justification(), {}}} end),
mk_choice_cnstr(m3, function(e, s, ngen) return {{b, justification(), {}}, {f(b, b), justification(), {}}} end),
mk_choice_cnstr(m4, function(e, s, ngen) return {a, b} end)
}
display_solutions(m1, unify(env, cs, o))