From 340c0e094569ea55a634c75a56b0e436c58555e5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 1 May 2014 15:30:30 -0700 Subject: [PATCH] feat(library/kernel_bindings): substitution Lua API Signed-off-by: Leonardo de Moura --- src/library/kernel_bindings.cpp | 518 +++++++++++++++++--------------- src/library/kernel_bindings.h | 3 +- tests/lua/subst1.lua | 35 +++ 3 files changed, 315 insertions(+), 241 deletions(-) create mode 100644 tests/lua/subst1.lua diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 6fcdfc01f..87c6443ed 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -13,6 +13,7 @@ Author: Leonardo de Moura #include "kernel/for_each_fn.h" #include "kernel/free_vars.h" #include "kernel/instantiate.h" +#include "kernel/metavar.h" #include "library/occurs.h" #include "library/io_state_stream.h" #include "library/expr_lt.h" @@ -29,6 +30,8 @@ io_state * get_io_state(lua_State * L); DECL_UDATA(level) DEFINE_LUA_LIST(level, push_level, to_level) +int push_optional_level(lua_State * L, optional const & l) { return l ? push_level(L, *l) : pushnil(L); } + static int level_tostring(lua_State * L) { std::ostringstream out; options opts = get_global_options(L); @@ -827,6 +830,279 @@ io_state * get_io_state(lua_State * L) { return nullptr; } +DECL_UDATA(justification) + +int push_optional_justification(lua_State * L, optional const & j) { + if (j) + push_justification(L, *j); + else + lua_pushnil(L); + return 1; +} + +#if 0 +static int justification_tostring(lua_State * L) { + std::ostringstream out; + justification & jst = to_justification(L, 1); + if (jst) { + formatter fmt = get_global_formatter(L); + options opts = get_global_options(L); + out << mk_pair(jst.pp(fmt, opts), opts); + } else { + out << ""; + } + lua_pushstring(L, out.str().c_str()); + return 1; +} + +static int justification_has_children(lua_State * L) { + lua_pushboolean(L, to_justification(L, 1).has_children()); + return 1; +} + +static int justification_is_null(lua_State * L) { + lua_pushboolean(L, !to_justification(L, 1)); + return 1; +} + +/** + \brief Iterator (closure base function) for justification children. See \c justification_children +*/ +static int justification_next_child(lua_State * L) { + unsigned i = lua_tointeger(L, lua_upvalueindex(2)); + unsigned num = objlen(L, lua_upvalueindex(1)); + if (i > num) { + lua_pushnil(L); + } else { + lua_pushinteger(L, i + 1); + lua_replace(L, lua_upvalueindex(2)); // update i + lua_rawgeti(L, lua_upvalueindex(1), i); // read children[i] + } + return 1; +} + +static int justification_children(lua_State * L) { + buffer children; + to_justification(L, 1).get_children(children); + lua_newtable(L); + int i = 1; + for (auto jcell : children) { + push_justification(L, justification(jcell)); + lua_rawseti(L, -2, i); + i = i + 1; + } + lua_pushinteger(L, 1); + lua_pushcclosure(L, &safe_function, 2); // create closure with 2 upvalues + return 1; +} + +static int justification_get_main_expr(lua_State * L) { + optional r = to_justification(L, 1).get_main_expr(); + if (r) + push_expr(L, *r); + else + lua_pushnil(L); + return 1; +} + +static int justification_pp(lua_State * L) { + int nargs = lua_gettop(L); + justification & jst = to_justification(L, 1); + formatter fmt = get_global_formatter(L); + options opts = get_global_options(L); + bool display_children = true; + + if (nargs == 2) { + if (lua_isboolean(L, 2)) { + display_children = lua_toboolean(L, 2); + } else { + luaL_checktype(L, 2, LUA_TTABLE); + + lua_pushstring(L, "formatter"); + lua_gettable(L, 2); + if (is_formatter(L, -1)) + fmt = to_formatter(L, -1); + lua_pop(L, 1); + + lua_pushstring(L, "options"); + lua_gettable(L, 2); + if (is_options(L, -1)) + opts = to_options(L, -1); + lua_pop(L, 1); + + lua_pushstring(L, "display_children"); + lua_gettable(L, 2); + if (lua_isboolean(L, -1)) + display_children = lua_toboolean(L, -1); + lua_pop(L, 1); + } + } + return push_format(L, jst.pp(fmt, opts, nullptr, display_children)); +} + +static int justification_depends_on(lua_State * L) { + lua_pushboolean(L, depends_on(to_justification(L, 1), to_justification(L, 2))); + return 1; +} + +static int mk_assumption_justification(lua_State * L) { + return push_justification(L, mk_assumption_justification(luaL_checkinteger(L, 1))); +} +#endif + +static const struct luaL_Reg justification_m[] = { + {"__gc", justification_gc}, // never throws + // {"__tostring", safe_function}, + // {"is_null", safe_function}, + // {"has_children", safe_function}, + // {"children", safe_function}, + // {"get_main_expr", safe_function}, + // {"pp", safe_function}, + // {"depends_on", safe_function}, + {0, 0} +}; + +static void open_justification(lua_State * L) { + luaL_newmetatable(L, justification_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, justification_m, 0); + + // SET_GLOBAL_FUN(mk_assumption_justification, "mk_assumption_justification"); + SET_GLOBAL_FUN(justification_pred, "is_justification"); +} + +// Substitution +DECL_UDATA(substitution) +static int mk_substitution(lua_State * L) { return push_substitution(L, substitution()); } +static int subst_get_expr(lua_State * L) { + if (is_expr(L, 2)) + return push_optional_expr(L, to_substitution(L, 1).get_expr(to_expr(L, 2))); + else + return push_optional_expr(L, to_substitution(L, 1).get_expr(to_name_ext(L, 2))); +} +static int subst_get_level(lua_State * L) { + if (is_level(L, 2)) + return push_optional_level(L, to_substitution(L, 1).get_level(to_level(L, 2))); + else + return push_optional_level(L, to_substitution(L, 1).get_level(to_name_ext(L, 2))); +} +static int subst_assign(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 3) { + if (is_expr(L, 3)) { + if (is_expr(L, 2)) + return push_substitution(L, to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3))); + else + return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3))); + } else { + if (is_level(L, 2)) + return push_substitution(L, to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3))); + else + return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3))); + } + } else { + if (is_expr(L, 3)) { + if (is_expr(L, 2)) + return push_substitution(L, to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3), to_justification(L, 4))); + else + return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3), to_justification(L, 4))); + } else { + if (is_level(L, 2)) + return push_substitution(L, to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3), to_justification(L, 4))); + else + return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3), to_justification(L, 4))); + } + } +} +static int subst_is_assigned(lua_State * L) { + if (is_expr(L, 2)) + return pushboolean(L, to_substitution(L, 1).is_assigned(to_expr(L, 2))); + else + return pushboolean(L, to_substitution(L, 1).is_assigned(to_level(L, 2))); +} +static int subst_is_expr_assigned(lua_State * L) { return pushboolean(L, to_substitution(L, 1).is_expr_assigned(to_name_ext(L, 2))); } +static int subst_is_level_assigned(lua_State * L) { return pushboolean(L, to_substitution(L, 1).is_level_assigned(to_name_ext(L, 2))); } +static int subst_occurs(lua_State * L) { return pushboolean(L, to_substitution(L, 1).occurs(to_expr(L, 2), to_expr(L, 3))); } +static int subst_occurs_expr(lua_State * L) { return pushboolean(L, to_substitution(L, 1).occurs_expr(to_name_ext(L, 2), to_expr(L, 3))); } +static int subst_get_expr_assignment(lua_State * L) { + auto r = to_substitution(L, 1).get_expr_assignment(to_name_ext(L, 2)); + if (r) { + push_expr(L, r->first); + push_justification(L, r->second); + } else { + pushnil(L); pushnil(L); + } + return 2; +} +static int subst_get_level_assignment(lua_State * L) { + auto r = to_substitution(L, 1).get_level_assignment(to_name_ext(L, 2)); + if (r) { + push_level(L, r->first); + push_justification(L, r->second); + } else { + pushnil(L); pushnil(L); + } + return 2; +} +static int subst_get_assignment(lua_State * L) { + if (is_expr(L, 2)) { + auto r = to_substitution(L, 1).get_assignment(to_expr(L, 2)); + if (r) { + push_expr(L, r->first); + push_justification(L, r->second); + } else { + pushnil(L); pushnil(L); + } + } else { + auto r = to_substitution(L, 1).get_assignment(to_level(L, 2)); + if (r) { + push_level(L, r->first); + push_justification(L, r->second); + } else { + pushnil(L); pushnil(L); + } + } + return 2; +} +static int subst_instantiate(lua_State * L) { + if (is_expr(L, 2)) { + auto r = to_substitution(L, 1).instantiate_metavars(to_expr(L, 2)); + push_expr(L, r.first); push_justification(L, r.second); + } else { + auto r = to_substitution(L, 1).instantiate_metavars(to_level(L, 2)); + push_level(L, r.first); push_justification(L, r.second); + } + return 2; +} + +static const struct luaL_Reg substitution_m[] = { + {"__gc", substitution_gc}, + {"get_expr", safe_function}, + {"get_level", safe_function}, + {"assign", safe_function}, + {"is_assigned", safe_function}, + {"is_expr_assigned", safe_function}, + {"is_level_assigned", safe_function}, + {"occurs", safe_function}, + {"occurs_expr", safe_function}, + {"get_expr_assignment", safe_function}, + {"get_level_assignment", safe_function}, + {"get_assignment", safe_function}, + {"instantiate", safe_function}, + {0, 0} +}; + +static void open_substitution(lua_State * L) { + luaL_newmetatable(L, substitution_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, substitution_m, 0); + + SET_GLOBAL_FUN(mk_substitution, "substitution"); + SET_GLOBAL_FUN(substitution_pred, "is_substitution"); +} + void open_kernel_module(lua_State * L) { // TODO(Leo) open_level(L); @@ -837,113 +1113,14 @@ void open_kernel_module(lua_State * L) { open_formatter(L); open_environment(L); open_io_state(L); + open_justification(L); + open_substitution(L); } } #if 0 namespace lean { - - - - - -static const struct luaL_Reg expr_m[] = { - {"__gc", expr_gc}, // never throws - {"__tostring", safe_function}, - {"__eq", safe_function}, - {"__lt", safe_function}, - {"__call", safe_function}, - {"kind", safe_function}, - {"is_var", safe_function}, - {"is_constant", safe_function}, - {"is_app", safe_function}, - {"is_lambda", safe_function}, - {"is_pi", safe_function}, - {"is_abstraction", safe_function}, - {"is_let", safe_function}, - {"is_value", safe_function}, - {"is_metavar", safe_function}, - {"fields", safe_function}, - {"data", safe_function}, - {"args", safe_function}, - {"num_args", safe_function}, - {"depth", safe_function}, - {"arg", safe_function}, - {"abst_name", safe_function}, - {"abst_domain", safe_function}, - {"abst_body", safe_function}, - {"for_each", safe_function}, - {"has_free_vars", safe_function}, - {"closed", safe_function}, - {"has_free_var", safe_function}, - {"lift_free_vars", safe_function}, - {"lower_free_vars", safe_function}, - {"instantiate", safe_function}, - {"beta_reduce", safe_function}, - {"head_beta_reduce", safe_function}, - {"abstract", safe_function}, - {"occurs", safe_function}, - {"has_metavar", safe_function}, - {"is_eqp", safe_function}, - {"is_lt", safe_function}, - {"hash", safe_function}, - {"is_not", safe_function}, - {"is_and", safe_function}, - {"is_or", safe_function}, - {"is_implies", safe_function}, - {"is_exists", safe_function}, - {"is_eq", safe_function}, - {0, 0} -}; - -static void expr_migrate(lua_State * src, int i, lua_State * tgt) { - push_expr(tgt, to_expr(src, i)); -} - -static void open_expr(lua_State * L) { - luaL_newmetatable(L, expr_mt); - set_migrate_fn_field(L, -1, expr_migrate); - lua_pushvalue(L, -1); - lua_setfield(L, -2, "__index"); - setfuncs(L, expr_m, 0); - - SET_GLOBAL_FUN(expr_mk_constant, "mk_constant"); - SET_GLOBAL_FUN(expr_mk_constant, "Const"); - SET_GLOBAL_FUN(expr_mk_var, "mk_var"); - SET_GLOBAL_FUN(expr_mk_var, "Var"); - SET_GLOBAL_FUN(expr_mk_app, "mk_app"); - SET_GLOBAL_FUN(expr_mk_lambda, "mk_lambda"); - SET_GLOBAL_FUN(expr_mk_pi, "mk_pi"); - SET_GLOBAL_FUN(expr_mk_arrow, "mk_arrow"); - SET_GLOBAL_FUN(expr_mk_let, "mk_let"); - SET_GLOBAL_FUN(expr_fun, "fun"); - SET_GLOBAL_FUN(expr_fun, "Fun"); - SET_GLOBAL_FUN(expr_pi, "Pi"); - SET_GLOBAL_FUN(expr_let, "Let"); - SET_GLOBAL_FUN(expr_type, "mk_type"); - SET_GLOBAL_FUN(expr_mk_eq, "mk_eq"); - SET_GLOBAL_FUN(expr_type, "Type"); - SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar"); - SET_GLOBAL_FUN(expr_pred, "is_expr"); - - lua_newtable(L); - SET_ENUM("Var", expr_kind::Var); - SET_ENUM("Constant", expr_kind::Constant); - SET_ENUM("Type", expr_kind::Type); - SET_ENUM("Value", expr_kind::Value); - SET_ENUM("Pair", expr_kind::Pair); - SET_ENUM("Proj", expr_kind::Proj); - SET_ENUM("App", expr_kind::App); - SET_ENUM("Sigma", expr_kind::Sigma); - SET_ENUM("Lambda", expr_kind::Lambda); - SET_ENUM("Pi", expr_kind::Pi); - SET_ENUM("Let", expr_kind::Let); - SET_ENUM("HEq", expr_kind::HEq); - SET_ENUM("MetaVar", expr_kind::MetaVar); - lua_setglobal(L, "expr_kind"); -} - DECL_UDATA(object) int push_optional_object(lua_State * L, optional const & o) { @@ -1083,145 +1260,6 @@ static void open_object(lua_State * L) { SET_GLOBAL_FUN(object_pred, "is_kernel_object"); } -DECL_UDATA(justification) - -int push_optional_justification(lua_State * L, optional const & j) { - if (j) - push_justification(L, *j); - else - lua_pushnil(L); - return 1; -} - -static int justification_tostring(lua_State * L) { - std::ostringstream out; - justification & jst = to_justification(L, 1); - if (jst) { - formatter fmt = get_global_formatter(L); - options opts = get_global_options(L); - out << mk_pair(jst.pp(fmt, opts), opts); - } else { - out << ""; - } - lua_pushstring(L, out.str().c_str()); - return 1; -} - -static int justification_has_children(lua_State * L) { - lua_pushboolean(L, to_justification(L, 1).has_children()); - return 1; -} - -static int justification_is_null(lua_State * L) { - lua_pushboolean(L, !to_justification(L, 1)); - return 1; -} - -/** - \brief Iterator (closure base function) for justification children. See \c justification_children -*/ -static int justification_next_child(lua_State * L) { - unsigned i = lua_tointeger(L, lua_upvalueindex(2)); - unsigned num = objlen(L, lua_upvalueindex(1)); - if (i > num) { - lua_pushnil(L); - } else { - lua_pushinteger(L, i + 1); - lua_replace(L, lua_upvalueindex(2)); // update i - lua_rawgeti(L, lua_upvalueindex(1), i); // read children[i] - } - return 1; -} - -static int justification_children(lua_State * L) { - buffer children; - to_justification(L, 1).get_children(children); - lua_newtable(L); - int i = 1; - for (auto jcell : children) { - push_justification(L, justification(jcell)); - lua_rawseti(L, -2, i); - i = i + 1; - } - lua_pushinteger(L, 1); - lua_pushcclosure(L, &safe_function, 2); // create closure with 2 upvalues - return 1; -} - -static int justification_get_main_expr(lua_State * L) { - optional r = to_justification(L, 1).get_main_expr(); - if (r) - push_expr(L, *r); - else - lua_pushnil(L); - return 1; -} - -static int justification_pp(lua_State * L) { - int nargs = lua_gettop(L); - justification & jst = to_justification(L, 1); - formatter fmt = get_global_formatter(L); - options opts = get_global_options(L); - bool display_children = true; - - if (nargs == 2) { - if (lua_isboolean(L, 2)) { - display_children = lua_toboolean(L, 2); - } else { - luaL_checktype(L, 2, LUA_TTABLE); - - lua_pushstring(L, "formatter"); - lua_gettable(L, 2); - if (is_formatter(L, -1)) - fmt = to_formatter(L, -1); - lua_pop(L, 1); - - lua_pushstring(L, "options"); - lua_gettable(L, 2); - if (is_options(L, -1)) - opts = to_options(L, -1); - lua_pop(L, 1); - - lua_pushstring(L, "display_children"); - lua_gettable(L, 2); - if (lua_isboolean(L, -1)) - display_children = lua_toboolean(L, -1); - lua_pop(L, 1); - } - } - return push_format(L, jst.pp(fmt, opts, nullptr, display_children)); -} - -static int justification_depends_on(lua_State * L) { - lua_pushboolean(L, depends_on(to_justification(L, 1), to_justification(L, 2))); - return 1; -} - -static int mk_assumption_justification(lua_State * L) { - return push_justification(L, mk_assumption_justification(luaL_checkinteger(L, 1))); -} - -static const struct luaL_Reg justification_m[] = { - {"__gc", justification_gc}, // never throws - {"__tostring", safe_function}, - {"is_null", safe_function}, - {"has_children", safe_function}, - {"children", safe_function}, - {"get_main_expr", safe_function}, - {"pp", safe_function}, - {"depends_on", safe_function}, - {0, 0} -}; - -static void open_justification(lua_State * L) { - luaL_newmetatable(L, justification_mt); - lua_pushvalue(L, -1); - lua_setfield(L, -2, "__index"); - setfuncs(L, justification_m, 0); - - SET_GLOBAL_FUN(mk_assumption_justification, "mk_assumption_justification"); - SET_GLOBAL_FUN(justification_pred, "is_justification"); -} DECL_UDATA(metavar_env) diff --git a/src/library/kernel_bindings.h b/src/library/kernel_bindings.h index a35adae65..2f4346987 100644 --- a/src/library/kernel_bindings.h +++ b/src/library/kernel_bindings.h @@ -11,10 +11,11 @@ Author: Leonardo de Moura namespace lean { void open_kernel_module(lua_State * L); UDATA_DEFS(level) -UDATA_DEFS(expr); +UDATA_DEFS(expr) UDATA_DEFS(formatter) UDATA_DEFS(definition) UDATA_DEFS(environment) +UDATA_DEFS(substitution) UDATA_DEFS(justification) UDATA_DEFS(constraint) UDATA_DEFS(substitution) diff --git a/tests/lua/subst1.lua b/tests/lua/subst1.lua new file mode 100644 index 000000000..c7383a264 --- /dev/null +++ b/tests/lua/subst1.lua @@ -0,0 +1,35 @@ +local m = mk_metavar("m", Bool) +local s = substitution() +assert(not s:is_assigned(m)) +assert(not s:is_expr_assigned("m")) +assert(not s:is_level_assigned("m")) +local f = Const("f") +local g = Const("g") +local a = Const("a") +local t = f(f(a)) +s = s:assign(m, t) +assert(s:is_assigned(m)) +assert(s:is_expr_assigned("m")) +assert(not s:is_level_assigned("m")) +assert(s:instantiate(g(m)) == g(t)) +s = s:assign("m", a) +assert(s:instantiate(g(m)) == g(a)) +local l = mk_level_one() +local u = mk_meta_univ("u") +s = s:assign(u, l) +assert(s:is_assigned(u)) +assert(s:is_level_assigned("u")) +assert(not s:is_expr_assigned("u")) +assert(s:get_expr("m") == a) +local m2 = mk_metavar("m2", Bool) +s = s:assign(m2, f(m)) +print(s:get_expr("m2")) +assert(s:occurs(m, f(m2))) +assert(s:occurs_expr("m", f(m2))) +print(s:get_level("u")) +print(s:instantiate(mk_sort(u))) +assert(s:instantiate(mk_sort(u)) == mk_sort(l)) +assert(s:get_assignment(m) == a) +assert(s:get_assignment(u) == l) +assert(s:get_expr_assignment("m") == a) +assert(s:get_level_assignment("u") == l)