From 1c96373c1a9c7d7f450f30929f7f5d091b9854d8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 4 Jun 2014 15:26:55 -0700 Subject: [PATCH] feat(library/kernel_bindings): expose replace_fn in the Lua API Signed-off-by: Leonardo de Moura --- src/library/kernel_bindings.cpp | 24 +++++++++++++++++++++++- tests/lua/replace1.lua | 10 ++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/lua/replace1.lua diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index cc606bc53..1d25a8de3 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -19,6 +19,7 @@ Author: Leonardo de Moura #include "kernel/metavar.h" #include "kernel/error_msgs.h" #include "kernel/type_checker.h" +#include "kernel/replace_fn.h" #include "kernel/inductive/inductive.h" #include "kernel/standard/standard.h" #include "kernel/hott/hott.h" @@ -593,7 +594,7 @@ static int expr_fn(lua_State * L) { return push_expr(L, app_fn(to_app(L, 1))); } static int expr_arg(lua_State * L) { return push_expr(L, app_arg(to_app(L, 1))); } static int expr_for_each(lua_State * L) { - expr & e = to_expr(L, 1); // expr + expr const & e = to_expr(L, 1); // expr luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun for_each(e, [&](expr const & a, unsigned offset) { lua_pushvalue(L, 2); // push user-fun @@ -609,6 +610,26 @@ static int expr_for_each(lua_State * L) { return 0; } +static int expr_replace(lua_State * L) { + expr const & e = to_expr(L, 1); + luaL_checktype(L, 2, LUA_TFUNCTION); + expr r = replace(e, [&](expr const & a, unsigned offset) { + lua_pushvalue(L, 2); + push_expr(L, a); + lua_pushinteger(L, offset); + pcall(L, 2, 1, 0); + if (is_expr(L, -1)) { + expr r = to_expr(L, -1); + lua_pop(L, 1); + return some_expr(r); + } else { + lua_pop(L, 1); + return none_expr(); + } + }); + return push_expr(L, r); +} + static int expr_has_free_var(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 2) @@ -753,6 +774,7 @@ static const struct luaL_Reg expr_m[] = { {"macro_num_args", safe_function}, {"macro_arg", safe_function}, {"for_each", safe_function}, + {"replace", safe_function}, {"has_free_var", safe_function}, {"lift_free_vars", safe_function}, {"lower_free_vars", safe_function}, diff --git a/tests/lua/replace1.lua b/tests/lua/replace1.lua new file mode 100644 index 000000000..1d01377bf --- /dev/null +++ b/tests/lua/replace1.lua @@ -0,0 +1,10 @@ +local f = Const("f") +local a = Const("a") +local b = Const("b") +local t = f(a, f(a)) +local new_t = t:replace(function(e) + if e == a then + return b + end + end) +print(new_t)