feat(lua): add semantic attachments for builtin arithmetical values to Lua API, improve mk_constant
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
05f254f605
commit
09bed4786c
5 changed files with 90 additions and 9 deletions
|
@ -20,12 +20,16 @@ Author: Leonardo de Moura
|
|||
#include "kernel/occurs.h"
|
||||
#include "kernel/metavar.h"
|
||||
#include "library/expr_lt.h"
|
||||
#include "library/arith/nat.h"
|
||||
#include "library/arith/int.h"
|
||||
#include "library/arith/real.h"
|
||||
#include "bindings/lua/util.h"
|
||||
#include "bindings/lua/name.h"
|
||||
#include "bindings/lua/options.h"
|
||||
#include "bindings/lua/level.h"
|
||||
#include "bindings/lua/local_context.h"
|
||||
#include "bindings/lua/formatter.h"
|
||||
#include "bindings/lua/numerics.h"
|
||||
|
||||
namespace lean {
|
||||
constexpr char const * expr_mt = "expr.mt";
|
||||
|
@ -424,6 +428,18 @@ static const struct luaL_Reg expr_m[] = {
|
|||
{0, 0}
|
||||
};
|
||||
|
||||
static int mk_nat_value(lua_State * L) {
|
||||
return push_expr(L, mk_nat_value(to_mpz_ext(L, 1)));
|
||||
}
|
||||
|
||||
static int mk_int_value(lua_State * L) {
|
||||
return push_expr(L, mk_int_value(to_mpz_ext(L, 1)));
|
||||
}
|
||||
|
||||
static int mk_real_value(lua_State * L) {
|
||||
return push_expr(L, mk_real_value(to_mpq_ext(L, 1)));
|
||||
}
|
||||
|
||||
void open_expr(lua_State * L) {
|
||||
luaL_newmetatable(L, expr_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
|
@ -449,6 +465,12 @@ void open_expr(lua_State * L) {
|
|||
SET_GLOBAL_FUN(expr_type, "Type");
|
||||
SET_GLOBAL_FUN(expr_mk_metavar, "mk_metavar");
|
||||
SET_GLOBAL_FUN(expr_pred, "is_expr");
|
||||
SET_GLOBAL_FUN(mk_nat_value, "mk_nat_value");
|
||||
SET_GLOBAL_FUN(mk_nat_value, "nVal");
|
||||
SET_GLOBAL_FUN(mk_int_value, "mk_int_value");
|
||||
SET_GLOBAL_FUN(mk_int_value, "iVal");
|
||||
SET_GLOBAL_FUN(mk_real_value, "mk_real_value");
|
||||
SET_GLOBAL_FUN(mk_real_value, "rVal");
|
||||
|
||||
lua_newtable(L);
|
||||
SET_ENUM("Var", expr_kind::Var);
|
||||
|
|
|
@ -21,10 +21,28 @@ name & to_name(lua_State * L, int idx) {
|
|||
}
|
||||
|
||||
name to_name_ext(lua_State * L, int idx) {
|
||||
if (lua_isstring(L, idx))
|
||||
if (lua_isstring(L, idx)) {
|
||||
return luaL_checkstring(L, idx);
|
||||
else
|
||||
} else if (lua_istable(L, idx)) {
|
||||
name r;
|
||||
int n = objlen(L, idx);
|
||||
for (int i = 1; i <= n; i++) {
|
||||
lua_rawgeti(L, idx, i);
|
||||
if (lua_isnil(L, -1)) {
|
||||
// skip
|
||||
} else if (lua_isuserdata(L, -1)) {
|
||||
r = r + to_name(L, -1);
|
||||
} else if (lua_isstring(L, -1)) {
|
||||
r = name(r, luaL_checkstring(L, -1));
|
||||
} else {
|
||||
r = name(r, luaL_checkinteger(L, -1));
|
||||
}
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
return r;
|
||||
} else {
|
||||
return to_name(L, idx);
|
||||
}
|
||||
}
|
||||
|
||||
int push_name(lua_State * L, name const & n) {
|
||||
|
|
|
@ -20,8 +20,7 @@ static mpz const & to_mpz(lua_State * L) {
|
|||
if (lua_isuserdata(L, idx)) {
|
||||
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
char const * str = luaL_checkstring(L, idx);
|
||||
arg = mpz(str);
|
||||
arg = mpz(luaL_checkstring(L, idx));
|
||||
return arg;
|
||||
} else {
|
||||
arg = static_cast<long int>(luaL_checkinteger(L, 1));
|
||||
|
@ -37,6 +36,16 @@ mpz & to_mpz(lua_State * L, int idx) {
|
|||
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
}
|
||||
|
||||
mpz to_mpz_ext(lua_State * L, int idx) {
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
return *static_cast<mpz*>(luaL_checkudata(L, idx, mpz_mt));
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
return mpz(luaL_checkstring(L, idx));
|
||||
} else {
|
||||
return mpz(static_cast<long int>(luaL_checkinteger(L, 1)));
|
||||
}
|
||||
}
|
||||
|
||||
int push_mpz(lua_State * L, mpz const & val) {
|
||||
void * mem = lua_newuserdata(L, sizeof(mpz));
|
||||
new (mem) mpz(val);
|
||||
|
@ -135,15 +144,17 @@ template<unsigned idx>
|
|||
static mpq const & to_mpq(lua_State * L) {
|
||||
static thread_local mpq arg;
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
|
||||
if (is_mpz(L, idx)) {
|
||||
arg = mpq(to_mpz<idx>(L));
|
||||
} else {
|
||||
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
|
||||
}
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
char const * str = luaL_checkstring(L, idx);
|
||||
arg = mpq(str);
|
||||
return arg;
|
||||
arg = mpq(luaL_checkstring(L, idx));
|
||||
} else {
|
||||
arg = static_cast<long int>(luaL_checkinteger(L, 1));
|
||||
return arg;
|
||||
}
|
||||
return arg;
|
||||
}
|
||||
|
||||
bool is_mpq(lua_State * L, int idx) {
|
||||
|
@ -154,6 +165,20 @@ mpq & to_mpq(lua_State * L, int idx) {
|
|||
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
|
||||
}
|
||||
|
||||
mpq to_mpq_ext(lua_State * L, int idx) {
|
||||
if (lua_isuserdata(L, idx)) {
|
||||
if (is_mpz(L, idx)) {
|
||||
return mpq(to_mpz(L, idx));
|
||||
} else {
|
||||
return *static_cast<mpq*>(luaL_checkudata(L, idx, mpq_mt));
|
||||
}
|
||||
} else if (lua_isstring(L, idx)) {
|
||||
return mpq(luaL_checkstring(L, idx));
|
||||
} else {
|
||||
return mpq(static_cast<long int>(luaL_checkinteger(L, 1)));
|
||||
}
|
||||
}
|
||||
|
||||
int push_mpq(lua_State * L, mpq const & val) {
|
||||
void * mem = lua_newuserdata(L, sizeof(mpq));
|
||||
new (mem) mpq(val);
|
||||
|
|
|
@ -11,11 +11,13 @@ class mpz;
|
|||
void open_mpz(lua_State * L);
|
||||
bool is_mpz(lua_State * L, int idx);
|
||||
mpz & to_mpz(lua_State * L, int idx);
|
||||
mpz to_mpz_ext(lua_State * L, int idx);
|
||||
int push_mpz(lua_State * L, mpz const & val);
|
||||
|
||||
class mpq;
|
||||
void open_mpq(lua_State * L);
|
||||
bool is_mpq(lua_State * L, int idx);
|
||||
mpq & to_mpq(lua_State * L, int idx);
|
||||
mpq to_mpq_ext(lua_State * L, int idx);
|
||||
int push_mpq(lua_State * L, mpq const & val);
|
||||
}
|
||||
|
|
14
tests/lean/lua12.lean
Normal file
14
tests/lean/lua12.lean
Normal file
|
@ -0,0 +1,14 @@
|
|||
Variables x y z : Int
|
||||
|
||||
(**
|
||||
|
||||
local env = get_environment()
|
||||
local plus = Const{"Int", "add"}
|
||||
local x, y = Consts("x y")
|
||||
local def = plus(plus(x, y), iVal(1000))
|
||||
print(def, ":", env:check_type(def))
|
||||
env:add_definition("sum", def)
|
||||
|
||||
**)
|
||||
|
||||
Eval sum + 3
|
Loading…
Reference in a new issue