feat(lua): add semantic attachments for builtin arithmetical values to Lua API, improve mk_constant

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-14 15:15:04 -08:00
parent 05f254f605
commit 09bed4786c
5 changed files with 90 additions and 9 deletions

View file

@ -20,12 +20,16 @@ Author: Leonardo de Moura
#include "kernel/occurs.h" #include "kernel/occurs.h"
#include "kernel/metavar.h" #include "kernel/metavar.h"
#include "library/expr_lt.h" #include "library/expr_lt.h"
#include "library/arith/nat.h"
#include "library/arith/int.h"
#include "library/arith/real.h"
#include "bindings/lua/util.h" #include "bindings/lua/util.h"
#include "bindings/lua/name.h" #include "bindings/lua/name.h"
#include "bindings/lua/options.h" #include "bindings/lua/options.h"
#include "bindings/lua/level.h" #include "bindings/lua/level.h"
#include "bindings/lua/local_context.h" #include "bindings/lua/local_context.h"
#include "bindings/lua/formatter.h" #include "bindings/lua/formatter.h"
#include "bindings/lua/numerics.h"
namespace lean { namespace lean {
constexpr char const * expr_mt = "expr.mt"; constexpr char const * expr_mt = "expr.mt";
@ -424,6 +428,18 @@ static const struct luaL_Reg expr_m[] = {
{0, 0} {0, 0}
}; };
static int mk_nat_value(lua_State * L) {
return push_expr(L, mk_nat_value(to_mpz_ext(L, 1)));
}
static int mk_int_value(lua_State * L) {
return push_expr(L, mk_int_value(to_mpz_ext(L, 1)));
}
static int mk_real_value(lua_State * L) {
return push_expr(L, mk_real_value(to_mpq_ext(L, 1)));
}
void open_expr(lua_State * L) { void open_expr(lua_State * L) {
luaL_newmetatable(L, expr_mt); luaL_newmetatable(L, expr_mt);
lua_pushvalue(L, -1); lua_pushvalue(L, -1);
@ -449,6 +465,12 @@ void open_expr(lua_State * L) {
SET_GLOBAL_FUN(expr_type, "Type"); SET_GLOBAL_FUN(expr_type, "Type");
SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar"); SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar");
SET_GLOBAL_FUN(expr_pred, "is_expr"); SET_GLOBAL_FUN(expr_pred, "is_expr");
SET_GLOBAL_FUN(mk_nat_value, "mk_nat_value");
SET_GLOBAL_FUN(mk_nat_value, "nVal");
SET_GLOBAL_FUN(mk_int_value, "mk_int_value");
SET_GLOBAL_FUN(mk_int_value, "iVal");
SET_GLOBAL_FUN(mk_real_value, "mk_real_value");
SET_GLOBAL_FUN(mk_real_value, "rVal");
lua_newtable(L); lua_newtable(L);
SET_ENUM("Var", expr_kind::Var); SET_ENUM("Var", expr_kind::Var);

View file

@ -21,11 +21,29 @@ name & to_name(lua_State * L, int idx) {
} }
name to_name_ext(lua_State * L, int idx) { name to_name_ext(lua_State * L, int idx) {
if (lua_isstring(L, idx)) if (lua_isstring(L, idx)) {
return luaL_checkstring(L, idx); return luaL_checkstring(L, idx);
else } else if (lua_istable(L, idx)) {
name r;
int n = objlen(L, idx);
for (int i = 1; i <= n; i++) {
lua_rawgeti(L, idx, i);
if (lua_isnil(L, -1)) {
// skip
} else if (lua_isuserdata(L, -1)) {
r = r + to_name(L, -1);
} else if (lua_isstring(L, -1)) {
r = name(r, luaL_checkstring(L, -1));
} else {
r = name(r, luaL_checkinteger(L, -1));
}
lua_pop(L, 1);
}
return r;
} else {
return to_name(L, idx); return to_name(L, idx);
} }
}
int push_name(lua_State * L, name const & n) { int push_name(lua_State * L, name const & n) {
void * mem = lua_newuserdata(L, sizeof(name)); void * mem = lua_newuserdata(L, sizeof(name));

View file

@ -20,8 +20,7 @@ static mpz const & to_mpz(lua_State * L) {
if (lua_isuserdata(L, idx)) { if (lua_isuserdata(L, idx)) {
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt)); return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
} else if (lua_isstring(L, idx)) { } else if (lua_isstring(L, idx)) {
char const * str = luaL_checkstring(L, idx); arg = mpz(luaL_checkstring(L, idx));
arg = mpz(str);
return arg; return arg;
} else { } else {
arg = static_cast<long int>(luaL_checkinteger(L, 1)); arg = static_cast<long int>(luaL_checkinteger(L, 1));
@ -37,6 +36,16 @@ mpz & to_mpz(lua_State * L, int idx) {
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt)); return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
} }
mpz to_mpz_ext(lua_State * L, int idx) {
if (lua_isuserdata(L, idx)) {
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
} else if (lua_isstring(L, idx)) {
return mpz(luaL_checkstring(L, idx));
} else {
return mpz(static_cast<long int>(luaL_checkinteger(L, 1)));
}
}
int push_mpz(lua_State * L, mpz const & val) { int push_mpz(lua_State * L, mpz const & val) {
void * mem = lua_newuserdata(L, sizeof(mpz)); void * mem = lua_newuserdata(L, sizeof(mpz));
new (mem) mpz(val); new (mem) mpz(val);
@ -135,15 +144,17 @@ template<unsigned idx>
static mpq const & to_mpq(lua_State * L) { static mpq const & to_mpq(lua_State * L) {
static thread_local mpq arg; static thread_local mpq arg;
if (lua_isuserdata(L, idx)) { if (lua_isuserdata(L, idx)) {
if (is_mpz(L, idx)) {
arg = mpq(to_mpz<idx>(L));
} else {
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt)); return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
}
} else if (lua_isstring(L, idx)) { } else if (lua_isstring(L, idx)) {
char const * str = luaL_checkstring(L, idx); arg = mpq(luaL_checkstring(L, idx));
arg = mpq(str);
return arg;
} else { } else {
arg = static_cast<long int>(luaL_checkinteger(L, 1)); arg = static_cast<long int>(luaL_checkinteger(L, 1));
return arg;
} }
return arg;
} }
bool is_mpq(lua_State * L, int idx) { bool is_mpq(lua_State * L, int idx) {
@ -154,6 +165,20 @@ mpq & to_mpq(lua_State * L, int idx) {
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt)); return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
} }
mpq to_mpq_ext(lua_State * L, int idx) {
if (lua_isuserdata(L, idx)) {
if (is_mpz(L, idx)) {
return mpq(to_mpz(L, idx));
} else {
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
}
} else if (lua_isstring(L, idx)) {
return mpq(luaL_checkstring(L, idx));
} else {
return mpq(static_cast<long int>(luaL_checkinteger(L, 1)));
}
}
int push_mpq(lua_State * L, mpq const & val) { int push_mpq(lua_State * L, mpq const & val) {
void * mem = lua_newuserdata(L, sizeof(mpq)); void * mem = lua_newuserdata(L, sizeof(mpq));
new (mem) mpq(val); new (mem) mpq(val);

View file

@ -11,11 +11,13 @@ class mpz;
void open_mpz(lua_State * L); void open_mpz(lua_State * L);
bool is_mpz(lua_State * L, int idx); bool is_mpz(lua_State * L, int idx);
mpz & to_mpz(lua_State * L, int idx); mpz & to_mpz(lua_State * L, int idx);
mpz to_mpz_ext(lua_State * L, int idx);
int push_mpz(lua_State * L, mpz const & val); int push_mpz(lua_State * L, mpz const & val);
class mpq; class mpq;
void open_mpq(lua_State * L); void open_mpq(lua_State * L);
bool is_mpq(lua_State * L, int idx); bool is_mpq(lua_State * L, int idx);
mpq & to_mpq(lua_State * L, int idx); mpq & to_mpq(lua_State * L, int idx);
mpq to_mpq_ext(lua_State * L, int idx);
int push_mpq(lua_State * L, mpq const & val); int push_mpq(lua_State * L, mpq const & val);
} }

14
tests/lean/lua12.lean Normal file
View file

@ -0,0 +1,14 @@
Variables x y z : Int
(**
local env = get_environment()
local plus = Const{"Int", "add"}
local x, y = Consts("x y")
local def = plus(plus(x, y), iVal(1000))
print(def, ":", env:check_type(def))
env:add_definition("sum", def)
**)
Eval sum + 3