feat(lua): add splay_maps to the Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-11-14 13:32:33 -08:00
parent 8e56726116
commit 64cce595a5
9 changed files with 342 additions and 5 deletions

View file

@ -1,5 +1,6 @@
add_library(lua util.cpp lua_exception.cpp name.cpp numerics.cpp
options.cpp sexpr.cpp format.cpp level.cpp local_context.cpp expr.cpp
context.cpp object.cpp environment.cpp formatter.cpp state.cpp leanlua_state.cpp)
splay_map.cpp options.cpp sexpr.cpp format.cpp level.cpp
local_context.cpp expr.cpp context.cpp object.cpp environment.cpp
formatter.cpp state.cpp leanlua_state.cpp)
target_link_libraries(lua ${LEAN_LIBS})

View file

@ -377,6 +377,16 @@ static int expr_occurs(lua_State * L) {
return 1;
}
static int expr_is_eqp(lua_State * L) {
lua_pushboolean(L, is_eqp(to_expr(L, 1), to_expr(L, 2)));
return 1;
}
static int expr_hash(lua_State * L) {
lua_pushinteger(L, to_expr(L, 1).hash());
return 1;
}
static const struct luaL_Reg expr_m[] = {
{"__gc", expr_gc}, // never throws
{"__tostring", safe_function<expr_tostring>},
@ -389,7 +399,7 @@ static const struct luaL_Reg expr_m[] = {
{"is_constant", safe_function<expr_is_constant>},
{"is_app", safe_function<expr_is_app>},
{"is_eq", safe_function<expr_is_eq>},
{"is_lambda", safe_function<expr_is_lambda>},
{"is_lambda", safe_function<expr_is_lambda>},
{"is_pi", safe_function<expr_is_pi>},
{"is_abstraction", safe_function<expr_is_abstraction>},
{"is_let", safe_function<expr_is_let>},
@ -409,6 +419,8 @@ static const struct luaL_Reg expr_m[] = {
{"abstract", safe_function<expr_abstract>},
{"occurs", safe_function<expr_occurs>},
{"has_metavar", safe_function<expr_has_metavar>},
{"is_eqp", safe_function<expr_is_eqp>},
{"hash", safe_function<expr_hash>},
{0, 0}
};

View file

@ -19,6 +19,7 @@ Author: Leonardo de Moura
#include "bindings/lua/leanlua_state.h"
#include "bindings/lua/util.h"
#include "bindings/lua/name.h"
#include "bindings/lua/splay_map.h"
#include "bindings/lua/numerics.h"
#include "bindings/lua/options.h"
#include "bindings/lua/sexpr.h"
@ -145,6 +146,7 @@ struct leanlua_state::imp {
luaL_openlibs(m_state);
open_patch(m_state);
open_name(m_state);
open_splay_map(m_state);
open_mpz(m_state);
open_mpq(m_state);
open_options(m_state);

View file

@ -0,0 +1,205 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <lua.hpp>
#include "util/splay_map.h"
#include "bindings/lua/util.h"
#include "bindings/lua/expr.h"
#include "library/expr_lt.h"
namespace lean {
/**
\brief Reference to Lua object.
*/
class lua_ref {
lua_State * m_state;
int m_ref;
public:
lua_ref():m_state(nullptr) {}
lua_ref(lua_State * L, int i) {
lean_assert(L);
m_state = L;
lua_pushvalue(m_state, i);
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
}
lua_ref(lua_ref const & r) {
m_state = r.m_state;
if (m_state) {
r.push();
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
}
}
lua_ref(lua_ref && r) {
m_state = r.m_state;
m_ref = r.m_ref;
r.m_state = nullptr;
}
~lua_ref() {
if (m_state)
luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref);
}
lua_ref & operator=(lua_ref const & r) {
if (m_ref == r.m_ref)
return *this;
if (m_state)
luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref);
m_state = r.m_state;
if (m_state) {
r.push();
m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX);
}
return *this;
}
void push() const {
lean_assert(m_state);
lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_ref);
}
lua_State * get_state() const {
return m_state;
}
};
struct lua_lt_proc {
int operator()(lua_ref const & r1, lua_ref const & r2) const {
lean_assert(r1.get_state() == r2.get_state());
lua_State * L = r1.get_state();
r1.push();
r2.push();
int r;
if (lessthan(L, -2, -1)) {
r = -1;
} else if (lessthan(L, -1, -2)) {
r = 1;
} else if (equal(L, -2, -1)) {
r = 0;
} else {
throw exception("'<' is not a total order for the elements inserted on the table");
}
lua_pop(L, 2);
return r;
}
};
typedef splay_map<lua_ref, lua_ref, lua_lt_proc> lua_splay_map;
constexpr char const * splay_map_mt = "splay_map.mt";
bool is_splay_map(lua_State * L, int idx) {
return testudata(L, idx, splay_map_mt);
}
lua_splay_map & to_splay_map(lua_State * L, int idx) {
return *static_cast<lua_splay_map*>(luaL_checkudata(L, idx, splay_map_mt));
}
int push_splay_map(lua_State * L, lua_splay_map const & o) {
void * mem = lua_newuserdata(L, sizeof(lua_splay_map));
new (mem) lua_splay_map(o);
luaL_getmetatable(L, splay_map_mt);
lua_setmetatable(L, -2);
return 1;
}
static int mk_splay_map(lua_State * L) {
lua_splay_map r;
return push_splay_map(L, r);
}
static int splay_map_gc(lua_State * L) {
to_splay_map(L, 1).~lua_splay_map();
return 0;
}
static int splay_map_size(lua_State * L) {
lua_pushinteger(L, to_splay_map(L, 1).size());
return 1;
}
static int splay_map_contains(lua_State * L) {
lua_pushboolean(L, to_splay_map(L, 1).contains(lua_ref(L, 2)));
return 1;
}
static int splay_map_empty(lua_State * L) {
lua_pushboolean(L, to_splay_map(L, 1).empty());
return 1;
}
static int splay_map_insert(lua_State * L) {
to_splay_map(L, 1).insert(lua_ref(L, 2), lua_ref(L, 3));
return 0;
}
static int splay_map_erase(lua_State * L) {
to_splay_map(L, 1).erase(lua_ref(L, 2));
return 0;
}
static int splay_map_find(lua_State * L) {
lua_splay_map & m = to_splay_map(L, 1);
lua_ref * val = m.splay_find(lua_ref(L, 2));
if (val) {
lean_assert(val->get_state() == L);
val->push();
} else {
lua_pushnil(L);
}
return 1;
}
static int splay_map_copy(lua_State * L) {
return push_splay_map(L, to_splay_map(L, 1));
}
static int splay_map_pred(lua_State * L) {
lua_pushboolean(L, is_splay_map(L, 1));
return 1;
}
static int splay_map_for_each(lua_State * L) {
lua_splay_map & m = to_splay_map(L, 1); // map
luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun
m.for_each([&](lua_ref const & k, lua_ref const & v) {
lua_pushvalue(L, 2); // push user-fun
k.push();
v.push();
pcall(L, 2, 0, 0);
});
return 0;
}
static const struct luaL_Reg splay_map_m[] = {
{"__gc", splay_map_gc}, // never throws
{"__len", safe_function<splay_map_size> },
{"contains", safe_function<splay_map_contains>},
{"size", safe_function<splay_map_size>},
{"empty", safe_function<splay_map_empty>},
{"insert", safe_function<splay_map_insert>},
{"erase", safe_function<splay_map_erase>},
{"find", safe_function<splay_map_find>},
{"copy", safe_function<splay_map_copy>},
{"for_each", safe_function<splay_map_for_each>},
{0, 0}
};
void open_splay_map(lua_State * L) {
luaL_newmetatable(L, splay_map_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, splay_map_m, 0);
SET_GLOBAL_FUN(mk_splay_map, "splay_map");
SET_GLOBAL_FUN(splay_map_pred, "is_splay_map");
}
}

View file

@ -0,0 +1,11 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <lua.hpp>
namespace lean {
void open_splay_map(lua_State * L);
}

View file

@ -58,6 +58,22 @@ size_t objlen(lua_State * L, int idx) {
#endif
}
int lessthan(lua_State * L, int idx1, int idx2) {
#if LUA_VERSION_NUM < 502
return lua_lessthan(L, idx1, idx2);
#else
return lua_compare(L, idx1, idx2, LUA_OPLT);
#endif
}
int equal(lua_State * L, int idx1, int idx2) {
#if LUA_VERSION_NUM < 502
return lua_equal(L, idx1, idx2);
#else
return lua_compare(L, idx1, idx2, LUA_OPEQ);
#endif
}
static void exec(lua_State * L) {
pcall(L, 0, LUA_MULTRET, 0);
}

View file

@ -15,6 +15,8 @@ size_t objlen(lua_State * L, int idx);
void dofile(lua_State * L, char const * fname);
void dostring(lua_State * L, char const * str);
void pcall(lua_State * L, int nargs, int nresults, int errorfun);
int lessthan(lua_State * L, int idx1, int idx2);
int equal(lua_State * L, int idx1, int idx2);
/**
\brief Wrapper for invoking function f, and catching Lean exceptions.
*/

View file

@ -183,11 +183,11 @@ class splay_tree : public CMP {
if (n) {
if (n->m_left) {
check_invariant(n->m_left);
lean_assert(cmp(n->m_left->m_value, n->m_value) < 0);
lean_assert_lt(cmp(n->m_left->m_value, n->m_value), 0);
}
if (n->m_right) {
check_invariant(n->m_right);
lean_assert(cmp(n->m_value, n->m_right->m_value) < 0);
lean_assert_lt(cmp(n->m_value, n->m_right->m_value), 0);
}
}
return true;

88
tests/lua/map.lua Normal file
View file

@ -0,0 +1,88 @@
-- This examples demonstrates that Lean objects are not very useful as Lua table keys.
local f = Const("f")
local m = {}
local env = environment()
env:add_var("T", Type())
env:add_var("f", mk_arrow(Const("T"), Const("T")))
for i = 1, 100 do
env:add_var("a" .. i, Const("T"))
local t = f(Const("a" .. i))
-- Any object can be a key of a Lua table.
-- But, Lua does not use the method __eq for comparing keys.
-- The problem is that Lua uses its own hashcode that may not
-- be compatible with the __eq implemantion.
-- By non-compatible, we mean two objects my be equal by __eq, but
-- the hashcodes may be different.
m[t] = i
end
for t1, i in pairs(m) do
local t2 = f(Const("a" .. i))
-- print(t1, i, t2)
assert(m[t1] == i)
-- t1 and t2 are structurally equal
assert(t1 == t2)
-- t1 and t2 are different objects
assert(not t1:is_eqp(t2))
-- t2 is not a key of map
assert(m[t2] == nil)
assert(env:normalize(t1) == t1)
assert(t1:instantiate(Const("a")) == t1)
local t1_prime = t1:instantiate(Const("a"))
-- t1 and t1_prime are structurally equal
assert(t1 == t1_prime)
-- Moreover, they are references to the same Lean object
assert(t1:is_eqp(t1_prime))
-- But, they are wrapped by different Lua userdata
assert(m[t1_prime] == nil)
end
-- We can store elements that implement __lt and __eq metamethods in splay_maps.
-- The implementation assumes that the elements stored in the splay map can be totally ordered by __lt
m = splay_map()
m:insert(Const("a"), 10)
assert(m:contains(Const("a")))
assert(not m:contains(Const("b")))
assert(m:find(Const("a")) == 10)
local a, b = Consts("a, b")
m:insert(f(a, b), 20)
assert(m:find(f(a, b)) == 20)
assert(m:find(f(a, a)) == nil)
assert(m:size() == 2)
assert(#m == 2)
m:erase(f(a, a))
assert(m:size() == 2)
m:erase(f(a, b))
assert(m:size() == 1)
for i = 1, 100 do
local t = f(Const("a" .. i))
m:insert(t, i)
assert(m:contains(t))
assert(m:find(t) == i)
end
assert(m:size() == 101)
for i = 1, 100 do
local t = f(Const("a" .. i))
assert(m:find(t) == i)
end
-- The following call fails because integers cannot be compared with Lean expressions
assert(not pcall(function() m:insert(10, 20) end))
-- Splay maps copy operation is O(1)
local m2 = m:copy()
m2:insert(b, 20)
assert(m:size() == 101)
assert(m2:size() == 102)
-- We can also traverse all elements in the map
local num = 0
m:for_each(
function(k, v)
print(tostring(k) .. " -> " .. v)
num = num + 1
end
)
assert(num == 101)