feat(lua): add splay_maps to the Lua API
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
8e56726116
commit
64cce595a5
9 changed files with 342 additions and 5 deletions
|
@ -1,5 +1,6 @@
|
||||||
add_library(lua util.cpp lua_exception.cpp name.cpp numerics.cpp
|
add_library(lua util.cpp lua_exception.cpp name.cpp numerics.cpp
|
||||||
options.cpp sexpr.cpp format.cpp level.cpp local_context.cpp expr.cpp
|
splay_map.cpp options.cpp sexpr.cpp format.cpp level.cpp
|
||||||
context.cpp object.cpp environment.cpp formatter.cpp state.cpp leanlua_state.cpp)
|
local_context.cpp expr.cpp context.cpp object.cpp environment.cpp
|
||||||
|
formatter.cpp state.cpp leanlua_state.cpp)
|
||||||
|
|
||||||
target_link_libraries(lua ${LEAN_LIBS})
|
target_link_libraries(lua ${LEAN_LIBS})
|
||||||
|
|
|
@ -377,6 +377,16 @@ static int expr_occurs(lua_State * L) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int expr_is_eqp(lua_State * L) {
|
||||||
|
lua_pushboolean(L, is_eqp(to_expr(L, 1), to_expr(L, 2)));
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int expr_hash(lua_State * L) {
|
||||||
|
lua_pushinteger(L, to_expr(L, 1).hash());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
static const struct luaL_Reg expr_m[] = {
|
static const struct luaL_Reg expr_m[] = {
|
||||||
{"__gc", expr_gc}, // never throws
|
{"__gc", expr_gc}, // never throws
|
||||||
{"__tostring", safe_function<expr_tostring>},
|
{"__tostring", safe_function<expr_tostring>},
|
||||||
|
@ -389,7 +399,7 @@ static const struct luaL_Reg expr_m[] = {
|
||||||
{"is_constant", safe_function<expr_is_constant>},
|
{"is_constant", safe_function<expr_is_constant>},
|
||||||
{"is_app", safe_function<expr_is_app>},
|
{"is_app", safe_function<expr_is_app>},
|
||||||
{"is_eq", safe_function<expr_is_eq>},
|
{"is_eq", safe_function<expr_is_eq>},
|
||||||
{"is_lambda", safe_function<expr_is_lambda>},
|
{"is_lambda", safe_function<expr_is_lambda>},
|
||||||
{"is_pi", safe_function<expr_is_pi>},
|
{"is_pi", safe_function<expr_is_pi>},
|
||||||
{"is_abstraction", safe_function<expr_is_abstraction>},
|
{"is_abstraction", safe_function<expr_is_abstraction>},
|
||||||
{"is_let", safe_function<expr_is_let>},
|
{"is_let", safe_function<expr_is_let>},
|
||||||
|
@ -409,6 +419,8 @@ static const struct luaL_Reg expr_m[] = {
|
||||||
{"abstract", safe_function<expr_abstract>},
|
{"abstract", safe_function<expr_abstract>},
|
||||||
{"occurs", safe_function<expr_occurs>},
|
{"occurs", safe_function<expr_occurs>},
|
||||||
{"has_metavar", safe_function<expr_has_metavar>},
|
{"has_metavar", safe_function<expr_has_metavar>},
|
||||||
|
{"is_eqp", safe_function<expr_is_eqp>},
|
||||||
|
{"hash", safe_function<expr_hash>},
|
||||||
{0, 0}
|
{0, 0}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ Author: Leonardo de Moura
|
||||||
#include "bindings/lua/leanlua_state.h"
|
#include "bindings/lua/leanlua_state.h"
|
||||||
#include "bindings/lua/util.h"
|
#include "bindings/lua/util.h"
|
||||||
#include "bindings/lua/name.h"
|
#include "bindings/lua/name.h"
|
||||||
|
#include "bindings/lua/splay_map.h"
|
||||||
#include "bindings/lua/numerics.h"
|
#include "bindings/lua/numerics.h"
|
||||||
#include "bindings/lua/options.h"
|
#include "bindings/lua/options.h"
|
||||||
#include "bindings/lua/sexpr.h"
|
#include "bindings/lua/sexpr.h"
|
||||||
|
@ -145,6 +146,7 @@ struct leanlua_state::imp {
|
||||||
luaL_openlibs(m_state);
|
luaL_openlibs(m_state);
|
||||||
open_patch(m_state);
|
open_patch(m_state);
|
||||||
open_name(m_state);
|
open_name(m_state);
|
||||||
|
open_splay_map(m_state);
|
||||||
open_mpz(m_state);
|
open_mpz(m_state);
|
||||||
open_mpq(m_state);
|
open_mpq(m_state);
|
||||||
open_options(m_state);
|
open_options(m_state);
|
||||||
|
|
205
src/bindings/lua/splay_map.cpp
Normal file
205
src/bindings/lua/splay_map.cpp
Normal file
|
@ -0,0 +1,205 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||||
|
Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
|
Author: Leonardo de Moura
|
||||||
|
*/
|
||||||
|
#include <lua.hpp>
|
||||||
|
#include "util/splay_map.h"
|
||||||
|
#include "bindings/lua/util.h"
|
||||||
|
|
||||||
|
#include "bindings/lua/expr.h"
|
||||||
|
#include "library/expr_lt.h"
|
||||||
|
|
||||||
|
namespace lean {
|
||||||
|
/**
|
||||||
|
\brief Reference to Lua object.
|
||||||
|
*/
|
||||||
|
class lua_ref {
|
||||||
|
lua_State * m_state;
|
||||||
|
int m_ref;
|
||||||
|
public:
|
||||||
|
lua_ref():m_state(nullptr) {}
|
||||||
|
|
||||||
|
lua_ref(lua_State * L, int i) {
|
||||||
|
lean_assert(L);
|
||||||
|
m_state = L;
|
||||||
|
lua_pushvalue(m_state, i);
|
||||||
|
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
|
||||||
|
}
|
||||||
|
|
||||||
|
lua_ref(lua_ref const & r) {
|
||||||
|
m_state = r.m_state;
|
||||||
|
if (m_state) {
|
||||||
|
r.push();
|
||||||
|
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lua_ref(lua_ref && r) {
|
||||||
|
m_state = r.m_state;
|
||||||
|
m_ref = r.m_ref;
|
||||||
|
r.m_state = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
~lua_ref() {
|
||||||
|
if (m_state)
|
||||||
|
luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
lua_ref & operator=(lua_ref const & r) {
|
||||||
|
if (m_ref == r.m_ref)
|
||||||
|
return *this;
|
||||||
|
if (m_state)
|
||||||
|
luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref);
|
||||||
|
m_state = r.m_state;
|
||||||
|
if (m_state) {
|
||||||
|
r.push();
|
||||||
|
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void push() const {
|
||||||
|
lean_assert(m_state);
|
||||||
|
lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
lua_State * get_state() const {
|
||||||
|
return m_state;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct lua_lt_proc {
|
||||||
|
int operator()(lua_ref const & r1, lua_ref const & r2) const {
|
||||||
|
lean_assert(r1.get_state() == r2.get_state());
|
||||||
|
lua_State * L = r1.get_state();
|
||||||
|
r1.push();
|
||||||
|
r2.push();
|
||||||
|
int r;
|
||||||
|
if (lessthan(L, -2, -1)) {
|
||||||
|
r = -1;
|
||||||
|
} else if (lessthan(L, -1, -2)) {
|
||||||
|
r = 1;
|
||||||
|
} else if (equal(L, -2, -1)) {
|
||||||
|
r = 0;
|
||||||
|
} else {
|
||||||
|
throw exception("'<' is not a total order for the elements inserted on the table");
|
||||||
|
}
|
||||||
|
lua_pop(L, 2);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef splay_map<lua_ref, lua_ref, lua_lt_proc> lua_splay_map;
|
||||||
|
|
||||||
|
constexpr char const * splay_map_mt = "splay_map.mt";
|
||||||
|
|
||||||
|
bool is_splay_map(lua_State * L, int idx) {
|
||||||
|
return testudata(L, idx, splay_map_mt);
|
||||||
|
}
|
||||||
|
|
||||||
|
lua_splay_map & to_splay_map(lua_State * L, int idx) {
|
||||||
|
return *static_cast<lua_splay_map*>(luaL_checkudata(L, idx, splay_map_mt));
|
||||||
|
}
|
||||||
|
|
||||||
|
int push_splay_map(lua_State * L, lua_splay_map const & o) {
|
||||||
|
void * mem = lua_newuserdata(L, sizeof(lua_splay_map));
|
||||||
|
new (mem) lua_splay_map(o);
|
||||||
|
luaL_getmetatable(L, splay_map_mt);
|
||||||
|
lua_setmetatable(L, -2);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int mk_splay_map(lua_State * L) {
|
||||||
|
lua_splay_map r;
|
||||||
|
return push_splay_map(L, r);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_gc(lua_State * L) {
|
||||||
|
to_splay_map(L, 1).~lua_splay_map();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_size(lua_State * L) {
|
||||||
|
lua_pushinteger(L, to_splay_map(L, 1).size());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_contains(lua_State * L) {
|
||||||
|
lua_pushboolean(L, to_splay_map(L, 1).contains(lua_ref(L, 2)));
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_empty(lua_State * L) {
|
||||||
|
lua_pushboolean(L, to_splay_map(L, 1).empty());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_insert(lua_State * L) {
|
||||||
|
to_splay_map(L, 1).insert(lua_ref(L, 2), lua_ref(L, 3));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_erase(lua_State * L) {
|
||||||
|
to_splay_map(L, 1).erase(lua_ref(L, 2));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_find(lua_State * L) {
|
||||||
|
lua_splay_map & m = to_splay_map(L, 1);
|
||||||
|
lua_ref * val = m.splay_find(lua_ref(L, 2));
|
||||||
|
if (val) {
|
||||||
|
lean_assert(val->get_state() == L);
|
||||||
|
val->push();
|
||||||
|
} else {
|
||||||
|
lua_pushnil(L);
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_copy(lua_State * L) {
|
||||||
|
return push_splay_map(L, to_splay_map(L, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_pred(lua_State * L) {
|
||||||
|
lua_pushboolean(L, is_splay_map(L, 1));
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int splay_map_for_each(lua_State * L) {
|
||||||
|
lua_splay_map & m = to_splay_map(L, 1); // map
|
||||||
|
luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun
|
||||||
|
m.for_each([&](lua_ref const & k, lua_ref const & v) {
|
||||||
|
lua_pushvalue(L, 2); // push user-fun
|
||||||
|
k.push();
|
||||||
|
v.push();
|
||||||
|
pcall(L, 2, 0, 0);
|
||||||
|
});
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const struct luaL_Reg splay_map_m[] = {
|
||||||
|
{"__gc", splay_map_gc}, // never throws
|
||||||
|
{"__len", safe_function<splay_map_size> },
|
||||||
|
{"contains", safe_function<splay_map_contains>},
|
||||||
|
{"size", safe_function<splay_map_size>},
|
||||||
|
{"empty", safe_function<splay_map_empty>},
|
||||||
|
{"insert", safe_function<splay_map_insert>},
|
||||||
|
{"erase", safe_function<splay_map_erase>},
|
||||||
|
{"find", safe_function<splay_map_find>},
|
||||||
|
{"copy", safe_function<splay_map_copy>},
|
||||||
|
{"for_each", safe_function<splay_map_for_each>},
|
||||||
|
{0, 0}
|
||||||
|
};
|
||||||
|
|
||||||
|
void open_splay_map(lua_State * L) {
|
||||||
|
luaL_newmetatable(L, splay_map_mt);
|
||||||
|
lua_pushvalue(L, -1);
|
||||||
|
lua_setfield(L, -2, "__index");
|
||||||
|
setfuncs(L, splay_map_m, 0);
|
||||||
|
|
||||||
|
SET_GLOBAL_FUN(mk_splay_map, "splay_map");
|
||||||
|
SET_GLOBAL_FUN(splay_map_pred, "is_splay_map");
|
||||||
|
}
|
||||||
|
}
|
11
src/bindings/lua/splay_map.h
Normal file
11
src/bindings/lua/splay_map.h
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||||
|
Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
|
Author: Leonardo de Moura
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <lua.hpp>
|
||||||
|
namespace lean {
|
||||||
|
void open_splay_map(lua_State * L);
|
||||||
|
}
|
|
@ -58,6 +58,22 @@ size_t objlen(lua_State * L, int idx) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int lessthan(lua_State * L, int idx1, int idx2) {
|
||||||
|
#if LUA_VERSION_NUM < 502
|
||||||
|
return lua_lessthan(L, idx1, idx2);
|
||||||
|
#else
|
||||||
|
return lua_compare(L, idx1, idx2, LUA_OPLT);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
int equal(lua_State * L, int idx1, int idx2) {
|
||||||
|
#if LUA_VERSION_NUM < 502
|
||||||
|
return lua_equal(L, idx1, idx2);
|
||||||
|
#else
|
||||||
|
return lua_compare(L, idx1, idx2, LUA_OPEQ);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
static void exec(lua_State * L) {
|
static void exec(lua_State * L) {
|
||||||
pcall(L, 0, LUA_MULTRET, 0);
|
pcall(L, 0, LUA_MULTRET, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,8 @@ size_t objlen(lua_State * L, int idx);
|
||||||
void dofile(lua_State * L, char const * fname);
|
void dofile(lua_State * L, char const * fname);
|
||||||
void dostring(lua_State * L, char const * str);
|
void dostring(lua_State * L, char const * str);
|
||||||
void pcall(lua_State * L, int nargs, int nresults, int errorfun);
|
void pcall(lua_State * L, int nargs, int nresults, int errorfun);
|
||||||
|
int lessthan(lua_State * L, int idx1, int idx2);
|
||||||
|
int equal(lua_State * L, int idx1, int idx2);
|
||||||
/**
|
/**
|
||||||
\brief Wrapper for invoking function f, and catching Lean exceptions.
|
\brief Wrapper for invoking function f, and catching Lean exceptions.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -183,11 +183,11 @@ class splay_tree : public CMP {
|
||||||
if (n) {
|
if (n) {
|
||||||
if (n->m_left) {
|
if (n->m_left) {
|
||||||
check_invariant(n->m_left);
|
check_invariant(n->m_left);
|
||||||
lean_assert(cmp(n->m_left->m_value, n->m_value) < 0);
|
lean_assert_lt(cmp(n->m_left->m_value, n->m_value), 0);
|
||||||
}
|
}
|
||||||
if (n->m_right) {
|
if (n->m_right) {
|
||||||
check_invariant(n->m_right);
|
check_invariant(n->m_right);
|
||||||
lean_assert(cmp(n->m_value, n->m_right->m_value) < 0);
|
lean_assert_lt(cmp(n->m_value, n->m_right->m_value), 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
88
tests/lua/map.lua
Normal file
88
tests/lua/map.lua
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
-- This examples demonstrates that Lean objects are not very useful as Lua table keys.
|
||||||
|
local f = Const("f")
|
||||||
|
local m = {}
|
||||||
|
local env = environment()
|
||||||
|
env:add_var("T", Type())
|
||||||
|
env:add_var("f", mk_arrow(Const("T"), Const("T")))
|
||||||
|
for i = 1, 100 do
|
||||||
|
env:add_var("a" .. i, Const("T"))
|
||||||
|
local t = f(Const("a" .. i))
|
||||||
|
-- Any object can be a key of a Lua table.
|
||||||
|
-- But, Lua does not use the method __eq for comparing keys.
|
||||||
|
-- The problem is that Lua uses its own hashcode that may not
|
||||||
|
-- be compatible with the __eq implemantion.
|
||||||
|
-- By non-compatible, we mean two objects my be equal by __eq, but
|
||||||
|
-- the hashcodes may be different.
|
||||||
|
m[t] = i
|
||||||
|
end
|
||||||
|
|
||||||
|
for t1, i in pairs(m) do
|
||||||
|
local t2 = f(Const("a" .. i))
|
||||||
|
-- print(t1, i, t2)
|
||||||
|
assert(m[t1] == i)
|
||||||
|
-- t1 and t2 are structurally equal
|
||||||
|
assert(t1 == t2)
|
||||||
|
-- t1 and t2 are different objects
|
||||||
|
assert(not t1:is_eqp(t2))
|
||||||
|
-- t2 is not a key of map
|
||||||
|
assert(m[t2] == nil)
|
||||||
|
assert(env:normalize(t1) == t1)
|
||||||
|
assert(t1:instantiate(Const("a")) == t1)
|
||||||
|
local t1_prime = t1:instantiate(Const("a"))
|
||||||
|
-- t1 and t1_prime are structurally equal
|
||||||
|
assert(t1 == t1_prime)
|
||||||
|
-- Moreover, they are references to the same Lean object
|
||||||
|
assert(t1:is_eqp(t1_prime))
|
||||||
|
-- But, they are wrapped by different Lua userdata
|
||||||
|
assert(m[t1_prime] == nil)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- We can store elements that implement __lt and __eq metamethods in splay_maps.
|
||||||
|
-- The implementation assumes that the elements stored in the splay map can be totally ordered by __lt
|
||||||
|
m = splay_map()
|
||||||
|
m:insert(Const("a"), 10)
|
||||||
|
assert(m:contains(Const("a")))
|
||||||
|
assert(not m:contains(Const("b")))
|
||||||
|
assert(m:find(Const("a")) == 10)
|
||||||
|
local a, b = Consts("a, b")
|
||||||
|
m:insert(f(a, b), 20)
|
||||||
|
assert(m:find(f(a, b)) == 20)
|
||||||
|
assert(m:find(f(a, a)) == nil)
|
||||||
|
assert(m:size() == 2)
|
||||||
|
assert(#m == 2)
|
||||||
|
m:erase(f(a, a))
|
||||||
|
assert(m:size() == 2)
|
||||||
|
m:erase(f(a, b))
|
||||||
|
assert(m:size() == 1)
|
||||||
|
for i = 1, 100 do
|
||||||
|
local t = f(Const("a" .. i))
|
||||||
|
m:insert(t, i)
|
||||||
|
assert(m:contains(t))
|
||||||
|
assert(m:find(t) == i)
|
||||||
|
end
|
||||||
|
|
||||||
|
assert(m:size() == 101)
|
||||||
|
|
||||||
|
for i = 1, 100 do
|
||||||
|
local t = f(Const("a" .. i))
|
||||||
|
assert(m:find(t) == i)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- The following call fails because integers cannot be compared with Lean expressions
|
||||||
|
assert(not pcall(function() m:insert(10, 20) end))
|
||||||
|
|
||||||
|
-- Splay maps copy operation is O(1)
|
||||||
|
local m2 = m:copy()
|
||||||
|
m2:insert(b, 20)
|
||||||
|
assert(m:size() == 101)
|
||||||
|
assert(m2:size() == 102)
|
||||||
|
|
||||||
|
-- We can also traverse all elements in the map
|
||||||
|
local num = 0
|
||||||
|
m:for_each(
|
||||||
|
function(k, v)
|
||||||
|
print(tostring(k) .. " -> " .. v)
|
||||||
|
num = num + 1
|
||||||
|
end
|
||||||
|
)
|
||||||
|
assert(num == 101)
|
Loading…
Reference in a new issue