feat(lua): add splay_maps to the Lua API
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
8e56726116
commit
64cce595a5
9 changed files with 342 additions and 5 deletions
|
@ -1,5 +1,6 @@
|
|||
add_library(lua util.cpp lua_exception.cpp name.cpp numerics.cpp
|
||||
options.cpp sexpr.cpp format.cpp level.cpp local_context.cpp expr.cpp
|
||||
context.cpp object.cpp environment.cpp formatter.cpp state.cpp leanlua_state.cpp)
|
||||
splay_map.cpp options.cpp sexpr.cpp format.cpp level.cpp
|
||||
local_context.cpp expr.cpp context.cpp object.cpp environment.cpp
|
||||
formatter.cpp state.cpp leanlua_state.cpp)
|
||||
|
||||
target_link_libraries(lua ${LEAN_LIBS})
|
||||
|
|
|
@ -377,6 +377,16 @@ static int expr_occurs(lua_State * L) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
static int expr_is_eqp(lua_State * L) {
|
||||
lua_pushboolean(L, is_eqp(to_expr(L, 1), to_expr(L, 2)));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int expr_hash(lua_State * L) {
|
||||
lua_pushinteger(L, to_expr(L, 1).hash());
|
||||
return 1;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg expr_m[] = {
|
||||
{"__gc", expr_gc}, // never throws
|
||||
{"__tostring", safe_function<expr_tostring>},
|
||||
|
@ -389,7 +399,7 @@ static const struct luaL_Reg expr_m[] = {
|
|||
{"is_constant", safe_function<expr_is_constant>},
|
||||
{"is_app", safe_function<expr_is_app>},
|
||||
{"is_eq", safe_function<expr_is_eq>},
|
||||
{"is_lambda", safe_function<expr_is_lambda>},
|
||||
{"is_lambda", safe_function<expr_is_lambda>},
|
||||
{"is_pi", safe_function<expr_is_pi>},
|
||||
{"is_abstraction", safe_function<expr_is_abstraction>},
|
||||
{"is_let", safe_function<expr_is_let>},
|
||||
|
@ -409,6 +419,8 @@ static const struct luaL_Reg expr_m[] = {
|
|||
{"abstract", safe_function<expr_abstract>},
|
||||
{"occurs", safe_function<expr_occurs>},
|
||||
{"has_metavar", safe_function<expr_has_metavar>},
|
||||
{"is_eqp", safe_function<expr_is_eqp>},
|
||||
{"hash", safe_function<expr_hash>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ Author: Leonardo de Moura
|
|||
#include "bindings/lua/leanlua_state.h"
|
||||
#include "bindings/lua/util.h"
|
||||
#include "bindings/lua/name.h"
|
||||
#include "bindings/lua/splay_map.h"
|
||||
#include "bindings/lua/numerics.h"
|
||||
#include "bindings/lua/options.h"
|
||||
#include "bindings/lua/sexpr.h"
|
||||
|
@ -145,6 +146,7 @@ struct leanlua_state::imp {
|
|||
luaL_openlibs(m_state);
|
||||
open_patch(m_state);
|
||||
open_name(m_state);
|
||||
open_splay_map(m_state);
|
||||
open_mpz(m_state);
|
||||
open_mpq(m_state);
|
||||
open_options(m_state);
|
||||
|
|
205
src/bindings/lua/splay_map.cpp
Normal file
205
src/bindings/lua/splay_map.cpp
Normal file
|
@ -0,0 +1,205 @@
|
|||
/*
|
||||
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <lua.hpp>
|
||||
#include "util/splay_map.h"
|
||||
#include "bindings/lua/util.h"
|
||||
|
||||
#include "bindings/lua/expr.h"
|
||||
#include "library/expr_lt.h"
|
||||
|
||||
namespace lean {
|
||||
/**
|
||||
\brief Reference to Lua object.
|
||||
*/
|
||||
class lua_ref {
|
||||
lua_State * m_state;
|
||||
int m_ref;
|
||||
public:
|
||||
lua_ref():m_state(nullptr) {}
|
||||
|
||||
lua_ref(lua_State * L, int i) {
|
||||
lean_assert(L);
|
||||
m_state = L;
|
||||
lua_pushvalue(m_state, i);
|
||||
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
|
||||
}
|
||||
|
||||
lua_ref(lua_ref const & r) {
|
||||
m_state = r.m_state;
|
||||
if (m_state) {
|
||||
r.push();
|
||||
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
|
||||
}
|
||||
}
|
||||
|
||||
lua_ref(lua_ref && r) {
|
||||
m_state = r.m_state;
|
||||
m_ref = r.m_ref;
|
||||
r.m_state = nullptr;
|
||||
}
|
||||
|
||||
~lua_ref() {
|
||||
if (m_state)
|
||||
luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref);
|
||||
}
|
||||
|
||||
lua_ref & operator=(lua_ref const & r) {
|
||||
if (m_ref == r.m_ref)
|
||||
return *this;
|
||||
if (m_state)
|
||||
luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref);
|
||||
m_state = r.m_state;
|
||||
if (m_state) {
|
||||
r.push();
|
||||
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void push() const {
|
||||
lean_assert(m_state);
|
||||
lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_ref);
|
||||
}
|
||||
|
||||
lua_State * get_state() const {
|
||||
return m_state;
|
||||
}
|
||||
};
|
||||
|
||||
struct lua_lt_proc {
|
||||
int operator()(lua_ref const & r1, lua_ref const & r2) const {
|
||||
lean_assert(r1.get_state() == r2.get_state());
|
||||
lua_State * L = r1.get_state();
|
||||
r1.push();
|
||||
r2.push();
|
||||
int r;
|
||||
if (lessthan(L, -2, -1)) {
|
||||
r = -1;
|
||||
} else if (lessthan(L, -1, -2)) {
|
||||
r = 1;
|
||||
} else if (equal(L, -2, -1)) {
|
||||
r = 0;
|
||||
} else {
|
||||
throw exception("'<' is not a total order for the elements inserted on the table");
|
||||
}
|
||||
lua_pop(L, 2);
|
||||
return r;
|
||||
}
|
||||
};
|
||||
|
||||
typedef splay_map<lua_ref, lua_ref, lua_lt_proc> lua_splay_map;
|
||||
|
||||
constexpr char const * splay_map_mt = "splay_map.mt";
|
||||
|
||||
bool is_splay_map(lua_State * L, int idx) {
|
||||
return testudata(L, idx, splay_map_mt);
|
||||
}
|
||||
|
||||
lua_splay_map & to_splay_map(lua_State * L, int idx) {
|
||||
return *static_cast<lua_splay_map*>(luaL_checkudata(L, idx, splay_map_mt));
|
||||
}
|
||||
|
||||
int push_splay_map(lua_State * L, lua_splay_map const & o) {
|
||||
void * mem = lua_newuserdata(L, sizeof(lua_splay_map));
|
||||
new (mem) lua_splay_map(o);
|
||||
luaL_getmetatable(L, splay_map_mt);
|
||||
lua_setmetatable(L, -2);
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int mk_splay_map(lua_State * L) {
|
||||
lua_splay_map r;
|
||||
return push_splay_map(L, r);
|
||||
}
|
||||
|
||||
static int splay_map_gc(lua_State * L) {
|
||||
to_splay_map(L, 1).~lua_splay_map();
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int splay_map_size(lua_State * L) {
|
||||
lua_pushinteger(L, to_splay_map(L, 1).size());
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int splay_map_contains(lua_State * L) {
|
||||
lua_pushboolean(L, to_splay_map(L, 1).contains(lua_ref(L, 2)));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int splay_map_empty(lua_State * L) {
|
||||
lua_pushboolean(L, to_splay_map(L, 1).empty());
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int splay_map_insert(lua_State * L) {
|
||||
to_splay_map(L, 1).insert(lua_ref(L, 2), lua_ref(L, 3));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int splay_map_erase(lua_State * L) {
|
||||
to_splay_map(L, 1).erase(lua_ref(L, 2));
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int splay_map_find(lua_State * L) {
|
||||
lua_splay_map & m = to_splay_map(L, 1);
|
||||
lua_ref * val = m.splay_find(lua_ref(L, 2));
|
||||
if (val) {
|
||||
lean_assert(val->get_state() == L);
|
||||
val->push();
|
||||
} else {
|
||||
lua_pushnil(L);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int splay_map_copy(lua_State * L) {
|
||||
return push_splay_map(L, to_splay_map(L, 1));
|
||||
}
|
||||
|
||||
static int splay_map_pred(lua_State * L) {
|
||||
lua_pushboolean(L, is_splay_map(L, 1));
|
||||
return 1;
|
||||
}
|
||||
|
||||
static int splay_map_for_each(lua_State * L) {
|
||||
lua_splay_map & m = to_splay_map(L, 1); // map
|
||||
luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun
|
||||
m.for_each([&](lua_ref const & k, lua_ref const & v) {
|
||||
lua_pushvalue(L, 2); // push user-fun
|
||||
k.push();
|
||||
v.push();
|
||||
pcall(L, 2, 0, 0);
|
||||
});
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const struct luaL_Reg splay_map_m[] = {
|
||||
{"__gc", splay_map_gc}, // never throws
|
||||
{"__len", safe_function<splay_map_size> },
|
||||
{"contains", safe_function<splay_map_contains>},
|
||||
{"size", safe_function<splay_map_size>},
|
||||
{"empty", safe_function<splay_map_empty>},
|
||||
{"insert", safe_function<splay_map_insert>},
|
||||
{"erase", safe_function<splay_map_erase>},
|
||||
{"find", safe_function<splay_map_find>},
|
||||
{"copy", safe_function<splay_map_copy>},
|
||||
{"for_each", safe_function<splay_map_for_each>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
void open_splay_map(lua_State * L) {
|
||||
luaL_newmetatable(L, splay_map_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, splay_map_m, 0);
|
||||
|
||||
SET_GLOBAL_FUN(mk_splay_map, "splay_map");
|
||||
SET_GLOBAL_FUN(splay_map_pred, "is_splay_map");
|
||||
}
|
||||
}
|
11
src/bindings/lua/splay_map.h
Normal file
11
src/bindings/lua/splay_map.h
Normal file
|
@ -0,0 +1,11 @@
|
|||
/*
|
||||
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include <lua.hpp>
|
||||
namespace lean {
|
||||
void open_splay_map(lua_State * L);
|
||||
}
|
|
@ -58,6 +58,22 @@ size_t objlen(lua_State * L, int idx) {
|
|||
#endif
|
||||
}
|
||||
|
||||
int lessthan(lua_State * L, int idx1, int idx2) {
|
||||
#if LUA_VERSION_NUM < 502
|
||||
return lua_lessthan(L, idx1, idx2);
|
||||
#else
|
||||
return lua_compare(L, idx1, idx2, LUA_OPLT);
|
||||
#endif
|
||||
}
|
||||
|
||||
int equal(lua_State * L, int idx1, int idx2) {
|
||||
#if LUA_VERSION_NUM < 502
|
||||
return lua_equal(L, idx1, idx2);
|
||||
#else
|
||||
return lua_compare(L, idx1, idx2, LUA_OPEQ);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void exec(lua_State * L) {
|
||||
pcall(L, 0, LUA_MULTRET, 0);
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@ size_t objlen(lua_State * L, int idx);
|
|||
void dofile(lua_State * L, char const * fname);
|
||||
void dostring(lua_State * L, char const * str);
|
||||
void pcall(lua_State * L, int nargs, int nresults, int errorfun);
|
||||
int lessthan(lua_State * L, int idx1, int idx2);
|
||||
int equal(lua_State * L, int idx1, int idx2);
|
||||
/**
|
||||
\brief Wrapper for invoking function f, and catching Lean exceptions.
|
||||
*/
|
||||
|
|
|
@ -183,11 +183,11 @@ class splay_tree : public CMP {
|
|||
if (n) {
|
||||
if (n->m_left) {
|
||||
check_invariant(n->m_left);
|
||||
lean_assert(cmp(n->m_left->m_value, n->m_value) < 0);
|
||||
lean_assert_lt(cmp(n->m_left->m_value, n->m_value), 0);
|
||||
}
|
||||
if (n->m_right) {
|
||||
check_invariant(n->m_right);
|
||||
lean_assert(cmp(n->m_value, n->m_right->m_value) < 0);
|
||||
lean_assert_lt(cmp(n->m_value, n->m_right->m_value), 0);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
88
tests/lua/map.lua
Normal file
88
tests/lua/map.lua
Normal file
|
@ -0,0 +1,88 @@
|
|||
-- This examples demonstrates that Lean objects are not very useful as Lua table keys.
|
||||
local f = Const("f")
|
||||
local m = {}
|
||||
local env = environment()
|
||||
env:add_var("T", Type())
|
||||
env:add_var("f", mk_arrow(Const("T"), Const("T")))
|
||||
for i = 1, 100 do
|
||||
env:add_var("a" .. i, Const("T"))
|
||||
local t = f(Const("a" .. i))
|
||||
-- Any object can be a key of a Lua table.
|
||||
-- But, Lua does not use the method __eq for comparing keys.
|
||||
-- The problem is that Lua uses its own hashcode that may not
|
||||
-- be compatible with the __eq implemantion.
|
||||
-- By non-compatible, we mean two objects my be equal by __eq, but
|
||||
-- the hashcodes may be different.
|
||||
m[t] = i
|
||||
end
|
||||
|
||||
for t1, i in pairs(m) do
|
||||
local t2 = f(Const("a" .. i))
|
||||
-- print(t1, i, t2)
|
||||
assert(m[t1] == i)
|
||||
-- t1 and t2 are structurally equal
|
||||
assert(t1 == t2)
|
||||
-- t1 and t2 are different objects
|
||||
assert(not t1:is_eqp(t2))
|
||||
-- t2 is not a key of map
|
||||
assert(m[t2] == nil)
|
||||
assert(env:normalize(t1) == t1)
|
||||
assert(t1:instantiate(Const("a")) == t1)
|
||||
local t1_prime = t1:instantiate(Const("a"))
|
||||
-- t1 and t1_prime are structurally equal
|
||||
assert(t1 == t1_prime)
|
||||
-- Moreover, they are references to the same Lean object
|
||||
assert(t1:is_eqp(t1_prime))
|
||||
-- But, they are wrapped by different Lua userdata
|
||||
assert(m[t1_prime] == nil)
|
||||
end
|
||||
|
||||
-- We can store elements that implement __lt and __eq metamethods in splay_maps.
|
||||
-- The implementation assumes that the elements stored in the splay map can be totally ordered by __lt
|
||||
m = splay_map()
|
||||
m:insert(Const("a"), 10)
|
||||
assert(m:contains(Const("a")))
|
||||
assert(not m:contains(Const("b")))
|
||||
assert(m:find(Const("a")) == 10)
|
||||
local a, b = Consts("a, b")
|
||||
m:insert(f(a, b), 20)
|
||||
assert(m:find(f(a, b)) == 20)
|
||||
assert(m:find(f(a, a)) == nil)
|
||||
assert(m:size() == 2)
|
||||
assert(#m == 2)
|
||||
m:erase(f(a, a))
|
||||
assert(m:size() == 2)
|
||||
m:erase(f(a, b))
|
||||
assert(m:size() == 1)
|
||||
for i = 1, 100 do
|
||||
local t = f(Const("a" .. i))
|
||||
m:insert(t, i)
|
||||
assert(m:contains(t))
|
||||
assert(m:find(t) == i)
|
||||
end
|
||||
|
||||
assert(m:size() == 101)
|
||||
|
||||
for i = 1, 100 do
|
||||
local t = f(Const("a" .. i))
|
||||
assert(m:find(t) == i)
|
||||
end
|
||||
|
||||
-- The following call fails because integers cannot be compared with Lean expressions
|
||||
assert(not pcall(function() m:insert(10, 20) end))
|
||||
|
||||
-- Splay maps copy operation is O(1)
|
||||
local m2 = m:copy()
|
||||
m2:insert(b, 20)
|
||||
assert(m:size() == 101)
|
||||
assert(m2:size() == 102)
|
||||
|
||||
-- We can also traverse all elements in the map
|
||||
local num = 0
|
||||
m:for_each(
|
||||
function(k, v)
|
||||
print(tostring(k) .. " -> " .. v)
|
||||
num = num + 1
|
||||
end
|
||||
)
|
||||
assert(num == 101)
|
Loading…
Reference in a new issue