feat(library/kernel_bindings): expose replace_fn in the Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-04 15:26:55 -07:00
parent a522398194
commit 1c96373c1a
2 changed files with 33 additions and 1 deletions

View file

@ -19,6 +19,7 @@ Author: Leonardo de Moura
#include "kernel/metavar.h" #include "kernel/metavar.h"
#include "kernel/error_msgs.h" #include "kernel/error_msgs.h"
#include "kernel/type_checker.h" #include "kernel/type_checker.h"
#include "kernel/replace_fn.h"
#include "kernel/inductive/inductive.h" #include "kernel/inductive/inductive.h"
#include "kernel/standard/standard.h" #include "kernel/standard/standard.h"
#include "kernel/hott/hott.h" #include "kernel/hott/hott.h"
@ -593,7 +594,7 @@ static int expr_fn(lua_State * L) { return push_expr(L, app_fn(to_app(L, 1))); }
static int expr_arg(lua_State * L) { return push_expr(L, app_arg(to_app(L, 1))); } static int expr_arg(lua_State * L) { return push_expr(L, app_arg(to_app(L, 1))); }
static int expr_for_each(lua_State * L) { static int expr_for_each(lua_State * L) {
expr & e = to_expr(L, 1); // expr expr const & e = to_expr(L, 1); // expr
luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun
for_each(e, [&](expr const & a, unsigned offset) { for_each(e, [&](expr const & a, unsigned offset) {
lua_pushvalue(L, 2); // push user-fun lua_pushvalue(L, 2); // push user-fun
@ -609,6 +610,26 @@ static int expr_for_each(lua_State * L) {
return 0; return 0;
} }
static int expr_replace(lua_State * L) {
expr const & e = to_expr(L, 1);
luaL_checktype(L, 2, LUA_TFUNCTION);
expr r = replace(e, [&](expr const & a, unsigned offset) {
lua_pushvalue(L, 2);
push_expr(L, a);
lua_pushinteger(L, offset);
pcall(L, 2, 1, 0);
if (is_expr(L, -1)) {
expr r = to_expr(L, -1);
lua_pop(L, 1);
return some_expr(r);
} else {
lua_pop(L, 1);
return none_expr();
}
});
return push_expr(L, r);
}
static int expr_has_free_var(lua_State * L) { static int expr_has_free_var(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
if (nargs == 2) if (nargs == 2)
@ -753,6 +774,7 @@ static const struct luaL_Reg expr_m[] = {
{"macro_num_args", safe_function<macro_num_args>}, {"macro_num_args", safe_function<macro_num_args>},
{"macro_arg", safe_function<macro_arg>}, {"macro_arg", safe_function<macro_arg>},
{"for_each", safe_function<expr_for_each>}, {"for_each", safe_function<expr_for_each>},
{"replace", safe_function<expr_replace>},
{"has_free_var", safe_function<expr_has_free_var>}, {"has_free_var", safe_function<expr_has_free_var>},
{"lift_free_vars", safe_function<expr_lift_free_vars>}, {"lift_free_vars", safe_function<expr_lift_free_vars>},
{"lower_free_vars", safe_function<expr_lower_free_vars>}, {"lower_free_vars", safe_function<expr_lower_free_vars>},

10
tests/lua/replace1.lua Normal file
View file

@ -0,0 +1,10 @@
local f = Const("f")
local a = Const("a")
local b = Const("b")
local t = f(a, f(a))
local new_t = t:replace(function(e)
if e == a then
return b
end
end)
print(new_t)