diff --git a/src/bindings/lua/expr.cpp b/src/bindings/lua/expr.cpp index ee54747e1..df9c9d3e8 100644 --- a/src/bindings/lua/expr.cpp +++ b/src/bindings/lua/expr.cpp @@ -20,12 +20,16 @@ Author: Leonardo de Moura #include "kernel/occurs.h" #include "kernel/metavar.h" #include "library/expr_lt.h" +#include "library/arith/nat.h" +#include "library/arith/int.h" +#include "library/arith/real.h" #include "bindings/lua/util.h" #include "bindings/lua/name.h" #include "bindings/lua/options.h" #include "bindings/lua/level.h" #include "bindings/lua/local_context.h" #include "bindings/lua/formatter.h" +#include "bindings/lua/numerics.h" namespace lean { constexpr char const * expr_mt = "expr.mt"; @@ -424,6 +428,18 @@ static const struct luaL_Reg expr_m[] = { {0, 0} }; +static int mk_nat_value(lua_State * L) { + return push_expr(L, mk_nat_value(to_mpz_ext(L, 1))); +} + +static int mk_int_value(lua_State * L) { + return push_expr(L, mk_int_value(to_mpz_ext(L, 1))); +} + +static int mk_real_value(lua_State * L) { + return push_expr(L, mk_real_value(to_mpq_ext(L, 1))); +} + void open_expr(lua_State * L) { luaL_newmetatable(L, expr_mt); lua_pushvalue(L, -1); @@ -449,6 +465,12 @@ void open_expr(lua_State * L) { SET_GLOBAL_FUN(expr_type, "Type"); SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar"); SET_GLOBAL_FUN(expr_pred, "is_expr"); + SET_GLOBAL_FUN(mk_nat_value, "mk_nat_value"); + SET_GLOBAL_FUN(mk_nat_value, "nVal"); + SET_GLOBAL_FUN(mk_int_value, "mk_int_value"); + SET_GLOBAL_FUN(mk_int_value, "iVal"); + SET_GLOBAL_FUN(mk_real_value, "mk_real_value"); + SET_GLOBAL_FUN(mk_real_value, "rVal"); lua_newtable(L); SET_ENUM("Var", expr_kind::Var); diff --git a/src/bindings/lua/name.cpp b/src/bindings/lua/name.cpp index 7e13f5da3..82e42774a 100644 --- a/src/bindings/lua/name.cpp +++ b/src/bindings/lua/name.cpp @@ -21,10 +21,28 @@ name & to_name(lua_State * L, int idx) { } name to_name_ext(lua_State * L, int idx) { - if (lua_isstring(L, idx)) + if (lua_isstring(L, idx)) { return luaL_checkstring(L, idx); - else + } else if (lua_istable(L, idx)) { + name r; + int n = objlen(L, idx); + for (int i = 1; i <= n; i++) { + lua_rawgeti(L, idx, i); + if (lua_isnil(L, -1)) { + // skip + } else if (lua_isuserdata(L, -1)) { + r = r + to_name(L, -1); + } else if (lua_isstring(L, -1)) { + r = name(r, luaL_checkstring(L, -1)); + } else { + r = name(r, luaL_checkinteger(L, -1)); + } + lua_pop(L, 1); + } + return r; + } else { return to_name(L, idx); + } } int push_name(lua_State * L, name const & n) { diff --git a/src/bindings/lua/numerics.cpp b/src/bindings/lua/numerics.cpp index fd3793557..857ce8ec6 100644 --- a/src/bindings/lua/numerics.cpp +++ b/src/bindings/lua/numerics.cpp @@ -20,8 +20,7 @@ static mpz const & to_mpz(lua_State * L) { if (lua_isuserdata(L, idx)) { return *static_cast(luaL_checkudata(L, idx, mpz_mt)); } else if (lua_isstring(L, idx)) { - char const * str = luaL_checkstring(L, idx); - arg = mpz(str); + arg = mpz(luaL_checkstring(L, idx)); return arg; } else { arg = static_cast(luaL_checkinteger(L, 1)); @@ -37,6 +36,16 @@ mpz & to_mpz(lua_State * L, int idx) { return *static_cast(luaL_checkudata(L, idx, mpz_mt)); } +mpz to_mpz_ext(lua_State * L, int idx) { + if (lua_isuserdata(L, idx)) { + return *static_cast(luaL_checkudata(L, idx, mpz_mt)); + } else if (lua_isstring(L, idx)) { + return mpz(luaL_checkstring(L, idx)); + } else { + return mpz(static_cast(luaL_checkinteger(L, 1))); + } +} + int push_mpz(lua_State * L, mpz const & val) { void * mem = lua_newuserdata(L, sizeof(mpz)); new (mem) mpz(val); @@ -135,15 +144,17 @@ template static mpq const & to_mpq(lua_State * L) { static thread_local mpq arg; if (lua_isuserdata(L, idx)) { - return *static_cast(luaL_checkudata(L, idx, mpq_mt)); + if (is_mpz(L, idx)) { + arg = mpq(to_mpz(L)); + } else { + return *static_cast(luaL_checkudata(L, idx, mpq_mt)); + } } else if (lua_isstring(L, idx)) { - char const * str = luaL_checkstring(L, idx); - arg = mpq(str); - return arg; + arg = mpq(luaL_checkstring(L, idx)); } else { arg = static_cast(luaL_checkinteger(L, 1)); - return arg; } + return arg; } bool is_mpq(lua_State * L, int idx) { @@ -154,6 +165,20 @@ mpq & to_mpq(lua_State * L, int idx) { return *static_cast(luaL_checkudata(L, idx, mpq_mt)); } +mpq to_mpq_ext(lua_State * L, int idx) { + if (lua_isuserdata(L, idx)) { + if (is_mpz(L, idx)) { + return mpq(to_mpz(L, idx)); + } else { + return *static_cast(luaL_checkudata(L, idx, mpq_mt)); + } + } else if (lua_isstring(L, idx)) { + return mpq(luaL_checkstring(L, idx)); + } else { + return mpq(static_cast(luaL_checkinteger(L, 1))); + } +} + int push_mpq(lua_State * L, mpq const & val) { void * mem = lua_newuserdata(L, sizeof(mpq)); new (mem) mpq(val); diff --git a/src/bindings/lua/numerics.h b/src/bindings/lua/numerics.h index ac63d2bb5..e2596f641 100644 --- a/src/bindings/lua/numerics.h +++ b/src/bindings/lua/numerics.h @@ -11,11 +11,13 @@ class mpz; void open_mpz(lua_State * L); bool is_mpz(lua_State * L, int idx); mpz & to_mpz(lua_State * L, int idx); +mpz to_mpz_ext(lua_State * L, int idx); int push_mpz(lua_State * L, mpz const & val); class mpq; void open_mpq(lua_State * L); bool is_mpq(lua_State * L, int idx); mpq & to_mpq(lua_State * L, int idx); +mpq to_mpq_ext(lua_State * L, int idx); int push_mpq(lua_State * L, mpq const & val); } diff --git a/tests/lean/lua12.lean b/tests/lean/lua12.lean new file mode 100644 index 000000000..d95f01fa4 --- /dev/null +++ b/tests/lean/lua12.lean @@ -0,0 +1,14 @@ +Variables x y z : Int + +(** + + local env = get_environment() + local plus = Const{"Int", "add"} + local x, y = Consts("x y") + local def = plus(plus(x, y), iVal(1000)) + print(def, ":", env:check_type(def)) + env:add_definition("sum", def) + +**) + +Eval sum + 3