feat(library/kernel_bindings): substitution Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-05-01 15:30:30 -07:00
parent 0bc101ac28
commit 340c0e0945
3 changed files with 315 additions and 241 deletions

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "kernel/for_each_fn.h"
#include "kernel/free_vars.h"
#include "kernel/instantiate.h"
#include "kernel/metavar.h"
#include "library/occurs.h"
#include "library/io_state_stream.h"
#include "library/expr_lt.h"
@ -29,6 +30,8 @@ io_state * get_io_state(lua_State * L);
DECL_UDATA(level)
DEFINE_LUA_LIST(level, push_level, to_level)
int push_optional_level(lua_State * L, optional<level> const & l) { return l ? push_level(L, *l) : pushnil(L); }
static int level_tostring(lua_State * L) {
std::ostringstream out;
options opts = get_global_options(L);
@ -827,6 +830,279 @@ io_state * get_io_state(lua_State * L) {
return nullptr;
}
DECL_UDATA(justification)
int push_optional_justification(lua_State * L, optional<justification> const & j) {
if (j)
push_justification(L, *j);
else
lua_pushnil(L);
return 1;
}
#if 0
static int justification_tostring(lua_State * L) {
std::ostringstream out;
justification & jst = to_justification(L, 1);
if (jst) {
formatter fmt = get_global_formatter(L);
options opts = get_global_options(L);
out << mk_pair(jst.pp(fmt, opts), opts);
} else {
out << "<null-justification>";
}
lua_pushstring(L, out.str().c_str());
return 1;
}
static int justification_has_children(lua_State * L) {
lua_pushboolean(L, to_justification(L, 1).has_children());
return 1;
}
static int justification_is_null(lua_State * L) {
lua_pushboolean(L, !to_justification(L, 1));
return 1;
}
/**
\brief Iterator (closure base function) for justification children. See \c justification_children
*/
static int justification_next_child(lua_State * L) {
unsigned i = lua_tointeger(L, lua_upvalueindex(2));
unsigned num = objlen(L, lua_upvalueindex(1));
if (i > num) {
lua_pushnil(L);
} else {
lua_pushinteger(L, i + 1);
lua_replace(L, lua_upvalueindex(2)); // update i
lua_rawgeti(L, lua_upvalueindex(1), i); // read children[i]
}
return 1;
}
static int justification_children(lua_State * L) {
buffer<justification_cell*> children;
to_justification(L, 1).get_children(children);
lua_newtable(L);
int i = 1;
for (auto jcell : children) {
push_justification(L, justification(jcell));
lua_rawseti(L, -2, i);
i = i + 1;
}
lua_pushinteger(L, 1);
lua_pushcclosure(L, &safe_function<justification_next_child>, 2); // create closure with 2 upvalues
return 1;
}
static int justification_get_main_expr(lua_State * L) {
optional<expr> r = to_justification(L, 1).get_main_expr();
if (r)
push_expr(L, *r);
else
lua_pushnil(L);
return 1;
}
static int justification_pp(lua_State * L) {
int nargs = lua_gettop(L);
justification & jst = to_justification(L, 1);
formatter fmt = get_global_formatter(L);
options opts = get_global_options(L);
bool display_children = true;
if (nargs == 2) {
if (lua_isboolean(L, 2)) {
display_children = lua_toboolean(L, 2);
} else {
luaL_checktype(L, 2, LUA_TTABLE);
lua_pushstring(L, "formatter");
lua_gettable(L, 2);
if (is_formatter(L, -1))
fmt = to_formatter(L, -1);
lua_pop(L, 1);
lua_pushstring(L, "options");
lua_gettable(L, 2);
if (is_options(L, -1))
opts = to_options(L, -1);
lua_pop(L, 1);
lua_pushstring(L, "display_children");
lua_gettable(L, 2);
if (lua_isboolean(L, -1))
display_children = lua_toboolean(L, -1);
lua_pop(L, 1);
}
}
return push_format(L, jst.pp(fmt, opts, nullptr, display_children));
}
static int justification_depends_on(lua_State * L) {
lua_pushboolean(L, depends_on(to_justification(L, 1), to_justification(L, 2)));
return 1;
}
static int mk_assumption_justification(lua_State * L) {
return push_justification(L, mk_assumption_justification(luaL_checkinteger(L, 1)));
}
#endif
static const struct luaL_Reg justification_m[] = {
{"__gc", justification_gc}, // never throws
// {"__tostring", safe_function<justification_tostring>},
// {"is_null", safe_function<justification_is_null>},
// {"has_children", safe_function<justification_has_children>},
// {"children", safe_function<justification_children>},
// {"get_main_expr", safe_function<justification_get_main_expr>},
// {"pp", safe_function<justification_pp>},
// {"depends_on", safe_function<justification_depends_on>},
{0, 0}
};
static void open_justification(lua_State * L) {
luaL_newmetatable(L, justification_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, justification_m, 0);
// SET_GLOBAL_FUN(mk_assumption_justification, "mk_assumption_justification");
SET_GLOBAL_FUN(justification_pred, "is_justification");
}
// Substitution
DECL_UDATA(substitution)
static int mk_substitution(lua_State * L) { return push_substitution(L, substitution()); }
static int subst_get_expr(lua_State * L) {
if (is_expr(L, 2))
return push_optional_expr(L, to_substitution(L, 1).get_expr(to_expr(L, 2)));
else
return push_optional_expr(L, to_substitution(L, 1).get_expr(to_name_ext(L, 2)));
}
static int subst_get_level(lua_State * L) {
if (is_level(L, 2))
return push_optional_level(L, to_substitution(L, 1).get_level(to_level(L, 2)));
else
return push_optional_level(L, to_substitution(L, 1).get_level(to_name_ext(L, 2)));
}
static int subst_assign(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs == 3) {
if (is_expr(L, 3)) {
if (is_expr(L, 2))
return push_substitution(L, to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3)));
else
return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3)));
} else {
if (is_level(L, 2))
return push_substitution(L, to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3)));
else
return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3)));
}
} else {
if (is_expr(L, 3)) {
if (is_expr(L, 2))
return push_substitution(L, to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3), to_justification(L, 4)));
else
return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3), to_justification(L, 4)));
} else {
if (is_level(L, 2))
return push_substitution(L, to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3), to_justification(L, 4)));
else
return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3), to_justification(L, 4)));
}
}
}
static int subst_is_assigned(lua_State * L) {
if (is_expr(L, 2))
return pushboolean(L, to_substitution(L, 1).is_assigned(to_expr(L, 2)));
else
return pushboolean(L, to_substitution(L, 1).is_assigned(to_level(L, 2)));
}
static int subst_is_expr_assigned(lua_State * L) { return pushboolean(L, to_substitution(L, 1).is_expr_assigned(to_name_ext(L, 2))); }
static int subst_is_level_assigned(lua_State * L) { return pushboolean(L, to_substitution(L, 1).is_level_assigned(to_name_ext(L, 2))); }
static int subst_occurs(lua_State * L) { return pushboolean(L, to_substitution(L, 1).occurs(to_expr(L, 2), to_expr(L, 3))); }
static int subst_occurs_expr(lua_State * L) { return pushboolean(L, to_substitution(L, 1).occurs_expr(to_name_ext(L, 2), to_expr(L, 3))); }
static int subst_get_expr_assignment(lua_State * L) {
auto r = to_substitution(L, 1).get_expr_assignment(to_name_ext(L, 2));
if (r) {
push_expr(L, r->first);
push_justification(L, r->second);
} else {
pushnil(L); pushnil(L);
}
return 2;
}
static int subst_get_level_assignment(lua_State * L) {
auto r = to_substitution(L, 1).get_level_assignment(to_name_ext(L, 2));
if (r) {
push_level(L, r->first);
push_justification(L, r->second);
} else {
pushnil(L); pushnil(L);
}
return 2;
}
static int subst_get_assignment(lua_State * L) {
if (is_expr(L, 2)) {
auto r = to_substitution(L, 1).get_assignment(to_expr(L, 2));
if (r) {
push_expr(L, r->first);
push_justification(L, r->second);
} else {
pushnil(L); pushnil(L);
}
} else {
auto r = to_substitution(L, 1).get_assignment(to_level(L, 2));
if (r) {
push_level(L, r->first);
push_justification(L, r->second);
} else {
pushnil(L); pushnil(L);
}
}
return 2;
}
static int subst_instantiate(lua_State * L) {
if (is_expr(L, 2)) {
auto r = to_substitution(L, 1).instantiate_metavars(to_expr(L, 2));
push_expr(L, r.first); push_justification(L, r.second);
} else {
auto r = to_substitution(L, 1).instantiate_metavars(to_level(L, 2));
push_level(L, r.first); push_justification(L, r.second);
}
return 2;
}
static const struct luaL_Reg substitution_m[] = {
{"__gc", substitution_gc},
{"get_expr", safe_function<subst_get_expr>},
{"get_level", safe_function<subst_get_level>},
{"assign", safe_function<subst_assign>},
{"is_assigned", safe_function<subst_is_assigned>},
{"is_expr_assigned", safe_function<subst_is_expr_assigned>},
{"is_level_assigned", safe_function<subst_is_level_assigned>},
{"occurs", safe_function<subst_occurs>},
{"occurs_expr", safe_function<subst_occurs_expr>},
{"get_expr_assignment", safe_function<subst_get_expr_assignment>},
{"get_level_assignment", safe_function<subst_get_level_assignment>},
{"get_assignment", safe_function<subst_get_assignment>},
{"instantiate", safe_function<subst_instantiate>},
{0, 0}
};
static void open_substitution(lua_State * L) {
luaL_newmetatable(L, substitution_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, substitution_m, 0);
SET_GLOBAL_FUN(mk_substitution, "substitution");
SET_GLOBAL_FUN(substitution_pred, "is_substitution");
}
void open_kernel_module(lua_State * L) {
// TODO(Leo)
open_level(L);
@ -837,113 +1113,14 @@ void open_kernel_module(lua_State * L) {
open_formatter(L);
open_environment(L);
open_io_state(L);
open_justification(L);
open_substitution(L);
}
}
#if 0
namespace lean {
static const struct luaL_Reg expr_m[] = {
{"__gc", expr_gc}, // never throws
{"__tostring", safe_function<expr_tostring>},
{"__eq", safe_function<expr_eq>},
{"__lt", safe_function<expr_lt>},
{"__call", safe_function<expr_mk_app>},
{"kind", safe_function<expr_get_kind>},
{"is_var", safe_function<expr_is_var>},
{"is_constant", safe_function<expr_is_constant>},
{"is_app", safe_function<expr_is_app>},
{"is_lambda", safe_function<expr_is_lambda>},
{"is_pi", safe_function<expr_is_pi>},
{"is_abstraction", safe_function<expr_is_abstraction>},
{"is_let", safe_function<expr_is_let>},
{"is_value", safe_function<expr_is_value>},
{"is_metavar", safe_function<expr_is_metavar>},
{"fields", safe_function<expr_fields>},
{"data", safe_function<expr_fields>},
{"args", safe_function<expr_args>},
{"num_args", safe_function<expr_num_args>},
{"depth", safe_function<expr_depth>},
{"arg", safe_function<expr_arg>},
{"abst_name", safe_function<expr_abst_name>},
{"abst_domain", safe_function<expr_abst_domain>},
{"abst_body", safe_function<expr_abst_body>},
{"for_each", safe_function<expr_for_each>},
{"has_free_vars", safe_function<expr_has_free_vars>},
{"closed", safe_function<expr_closed>},
{"has_free_var", safe_function<expr_has_free_var>},
{"lift_free_vars", safe_function<expr_lift_free_vars>},
{"lower_free_vars", safe_function<expr_lower_free_vars>},
{"instantiate", safe_function<expr_instantiate>},
{"beta_reduce", safe_function<expr_beta_reduce>},
{"head_beta_reduce", safe_function<expr_head_beta_reduce>},
{"abstract", safe_function<expr_abstract>},
{"occurs", safe_function<expr_occurs>},
{"has_metavar", safe_function<expr_has_metavar>},
{"is_eqp", safe_function<expr_is_eqp>},
{"is_lt", safe_function<expr_is_lt>},
{"hash", safe_function<expr_hash>},
{"is_not", safe_function<expr_is_not>},
{"is_and", safe_function<expr_is_and>},
{"is_or", safe_function<expr_is_or>},
{"is_implies", safe_function<expr_is_implies>},
{"is_exists", safe_function<expr_is_exists>},
{"is_eq", safe_function<expr_is_eq>},
{0, 0}
};
static void expr_migrate(lua_State * src, int i, lua_State * tgt) {
push_expr(tgt, to_expr(src, i));
}
static void open_expr(lua_State * L) {
luaL_newmetatable(L, expr_mt);
set_migrate_fn_field(L, -1, expr_migrate);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, expr_m, 0);
SET_GLOBAL_FUN(expr_mk_constant, "mk_constant");
SET_GLOBAL_FUN(expr_mk_constant, "Const");
SET_GLOBAL_FUN(expr_mk_var, "mk_var");
SET_GLOBAL_FUN(expr_mk_var, "Var");
SET_GLOBAL_FUN(expr_mk_app, "mk_app");
SET_GLOBAL_FUN(expr_mk_lambda, "mk_lambda");
SET_GLOBAL_FUN(expr_mk_pi, "mk_pi");
SET_GLOBAL_FUN(expr_mk_arrow, "mk_arrow");
SET_GLOBAL_FUN(expr_mk_let, "mk_let");
SET_GLOBAL_FUN(expr_fun, "fun");
SET_GLOBAL_FUN(expr_fun, "Fun");
SET_GLOBAL_FUN(expr_pi, "Pi");
SET_GLOBAL_FUN(expr_let, "Let");
SET_GLOBAL_FUN(expr_type, "mk_type");
SET_GLOBAL_FUN(expr_mk_eq, "mk_eq");
SET_GLOBAL_FUN(expr_type, "Type");
SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar");
SET_GLOBAL_FUN(expr_pred, "is_expr");
lua_newtable(L);
SET_ENUM("Var", expr_kind::Var);
SET_ENUM("Constant", expr_kind::Constant);
SET_ENUM("Type", expr_kind::Type);
SET_ENUM("Value", expr_kind::Value);
SET_ENUM("Pair", expr_kind::Pair);
SET_ENUM("Proj", expr_kind::Proj);
SET_ENUM("App", expr_kind::App);
SET_ENUM("Sigma", expr_kind::Sigma);
SET_ENUM("Lambda", expr_kind::Lambda);
SET_ENUM("Pi", expr_kind::Pi);
SET_ENUM("Let", expr_kind::Let);
SET_ENUM("HEq", expr_kind::HEq);
SET_ENUM("MetaVar", expr_kind::MetaVar);
lua_setglobal(L, "expr_kind");
}
DECL_UDATA(object)
int push_optional_object(lua_State * L, optional<object> const & o) {
@ -1083,145 +1260,6 @@ static void open_object(lua_State * L) {
SET_GLOBAL_FUN(object_pred, "is_kernel_object");
}
DECL_UDATA(justification)
int push_optional_justification(lua_State * L, optional<justification> const & j) {
if (j)
push_justification(L, *j);
else
lua_pushnil(L);
return 1;
}
static int justification_tostring(lua_State * L) {
std::ostringstream out;
justification & jst = to_justification(L, 1);
if (jst) {
formatter fmt = get_global_formatter(L);
options opts = get_global_options(L);
out << mk_pair(jst.pp(fmt, opts), opts);
} else {
out << "<null-justification>";
}
lua_pushstring(L, out.str().c_str());
return 1;
}
static int justification_has_children(lua_State * L) {
lua_pushboolean(L, to_justification(L, 1).has_children());
return 1;
}
static int justification_is_null(lua_State * L) {
lua_pushboolean(L, !to_justification(L, 1));
return 1;
}
/**
\brief Iterator (closure base function) for justification children. See \c justification_children
*/
static int justification_next_child(lua_State * L) {
unsigned i = lua_tointeger(L, lua_upvalueindex(2));
unsigned num = objlen(L, lua_upvalueindex(1));
if (i > num) {
lua_pushnil(L);
} else {
lua_pushinteger(L, i + 1);
lua_replace(L, lua_upvalueindex(2)); // update i
lua_rawgeti(L, lua_upvalueindex(1), i); // read children[i]
}
return 1;
}
static int justification_children(lua_State * L) {
buffer<justification_cell*> children;
to_justification(L, 1).get_children(children);
lua_newtable(L);
int i = 1;
for (auto jcell : children) {
push_justification(L, justification(jcell));
lua_rawseti(L, -2, i);
i = i + 1;
}
lua_pushinteger(L, 1);
lua_pushcclosure(L, &safe_function<justification_next_child>, 2); // create closure with 2 upvalues
return 1;
}
static int justification_get_main_expr(lua_State * L) {
optional<expr> r = to_justification(L, 1).get_main_expr();
if (r)
push_expr(L, *r);
else
lua_pushnil(L);
return 1;
}
static int justification_pp(lua_State * L) {
int nargs = lua_gettop(L);
justification & jst = to_justification(L, 1);
formatter fmt = get_global_formatter(L);
options opts = get_global_options(L);
bool display_children = true;
if (nargs == 2) {
if (lua_isboolean(L, 2)) {
display_children = lua_toboolean(L, 2);
} else {
luaL_checktype(L, 2, LUA_TTABLE);
lua_pushstring(L, "formatter");
lua_gettable(L, 2);
if (is_formatter(L, -1))
fmt = to_formatter(L, -1);
lua_pop(L, 1);
lua_pushstring(L, "options");
lua_gettable(L, 2);
if (is_options(L, -1))
opts = to_options(L, -1);
lua_pop(L, 1);
lua_pushstring(L, "display_children");
lua_gettable(L, 2);
if (lua_isboolean(L, -1))
display_children = lua_toboolean(L, -1);
lua_pop(L, 1);
}
}
return push_format(L, jst.pp(fmt, opts, nullptr, display_children));
}
static int justification_depends_on(lua_State * L) {
lua_pushboolean(L, depends_on(to_justification(L, 1), to_justification(L, 2)));
return 1;
}
static int mk_assumption_justification(lua_State * L) {
return push_justification(L, mk_assumption_justification(luaL_checkinteger(L, 1)));
}
static const struct luaL_Reg justification_m[] = {
{"__gc", justification_gc}, // never throws
{"__tostring", safe_function<justification_tostring>},
{"is_null", safe_function<justification_is_null>},
{"has_children", safe_function<justification_has_children>},
{"children", safe_function<justification_children>},
{"get_main_expr", safe_function<justification_get_main_expr>},
{"pp", safe_function<justification_pp>},
{"depends_on", safe_function<justification_depends_on>},
{0, 0}
};
static void open_justification(lua_State * L) {
luaL_newmetatable(L, justification_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, justification_m, 0);
SET_GLOBAL_FUN(mk_assumption_justification, "mk_assumption_justification");
SET_GLOBAL_FUN(justification_pred, "is_justification");
}
DECL_UDATA(metavar_env)

View file

@ -11,10 +11,11 @@ Author: Leonardo de Moura
namespace lean {
void open_kernel_module(lua_State * L);
UDATA_DEFS(level)
UDATA_DEFS(expr);
UDATA_DEFS(expr)
UDATA_DEFS(formatter)
UDATA_DEFS(definition)
UDATA_DEFS(environment)
UDATA_DEFS(substitution)
UDATA_DEFS(justification)
UDATA_DEFS(constraint)
UDATA_DEFS(substitution)

35
tests/lua/subst1.lua Normal file
View file

@ -0,0 +1,35 @@
local m = mk_metavar("m", Bool)
local s = substitution()
assert(not s:is_assigned(m))
assert(not s:is_expr_assigned("m"))
assert(not s:is_level_assigned("m"))
local f = Const("f")
local g = Const("g")
local a = Const("a")
local t = f(f(a))
s = s:assign(m, t)
assert(s:is_assigned(m))
assert(s:is_expr_assigned("m"))
assert(not s:is_level_assigned("m"))
assert(s:instantiate(g(m)) == g(t))
s = s:assign("m", a)
assert(s:instantiate(g(m)) == g(a))
local l = mk_level_one()
local u = mk_meta_univ("u")
s = s:assign(u, l)
assert(s:is_assigned(u))
assert(s:is_level_assigned("u"))
assert(not s:is_expr_assigned("u"))
assert(s:get_expr("m") == a)
local m2 = mk_metavar("m2", Bool)
s = s:assign(m2, f(m))
print(s:get_expr("m2"))
assert(s:occurs(m, f(m2)))
assert(s:occurs_expr("m", f(m2)))
print(s:get_level("u"))
print(s:instantiate(mk_sort(u)))
assert(s:instantiate(mk_sort(u)) == mk_sort(l))
assert(s:get_assignment(m) == a)
assert(s:get_assignment(u) == l)
assert(s:get_expr_assignment("m") == a)
assert(s:get_level_assignment("u") == l)