diff --git a/src/bindings/lua/CMakeLists.txt b/src/bindings/lua/CMakeLists.txt index 3886bf8dd..0ec167bf5 100644 --- a/src/bindings/lua/CMakeLists.txt +++ b/src/bindings/lua/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(lua util.cpp lua_exception.cpp name.cpp numerics.cpp -options.cpp sexpr.cpp format.cpp level.cpp local_context.cpp expr.cpp -context.cpp object.cpp environment.cpp formatter.cpp state.cpp leanlua_state.cpp) +splay_map.cpp options.cpp sexpr.cpp format.cpp level.cpp +local_context.cpp expr.cpp context.cpp object.cpp environment.cpp +formatter.cpp state.cpp leanlua_state.cpp) target_link_libraries(lua ${LEAN_LIBS}) diff --git a/src/bindings/lua/expr.cpp b/src/bindings/lua/expr.cpp index 781674480..ee54747e1 100644 --- a/src/bindings/lua/expr.cpp +++ b/src/bindings/lua/expr.cpp @@ -377,6 +377,16 @@ static int expr_occurs(lua_State * L) { return 1; } +static int expr_is_eqp(lua_State * L) { + lua_pushboolean(L, is_eqp(to_expr(L, 1), to_expr(L, 2))); + return 1; +} + +static int expr_hash(lua_State * L) { + lua_pushinteger(L, to_expr(L, 1).hash()); + return 1; +} + static const struct luaL_Reg expr_m[] = { {"__gc", expr_gc}, // never throws {"__tostring", safe_function}, @@ -389,7 +399,7 @@ static const struct luaL_Reg expr_m[] = { {"is_constant", safe_function}, {"is_app", safe_function}, {"is_eq", safe_function}, - {"is_lambda", safe_function}, + {"is_lambda", safe_function}, {"is_pi", safe_function}, {"is_abstraction", safe_function}, {"is_let", safe_function}, @@ -409,6 +419,8 @@ static const struct luaL_Reg expr_m[] = { {"abstract", safe_function}, {"occurs", safe_function}, {"has_metavar", safe_function}, + {"is_eqp", safe_function}, + {"hash", safe_function}, {0, 0} }; diff --git a/src/bindings/lua/leanlua_state.cpp b/src/bindings/lua/leanlua_state.cpp index b60d8615c..38ecf4ce4 100644 --- a/src/bindings/lua/leanlua_state.cpp +++ b/src/bindings/lua/leanlua_state.cpp @@ -19,6 +19,7 @@ Author: Leonardo de Moura #include "bindings/lua/leanlua_state.h" #include "bindings/lua/util.h" #include "bindings/lua/name.h" +#include "bindings/lua/splay_map.h" #include "bindings/lua/numerics.h" #include "bindings/lua/options.h" #include "bindings/lua/sexpr.h" @@ -145,6 +146,7 @@ struct leanlua_state::imp { luaL_openlibs(m_state); open_patch(m_state); open_name(m_state); + open_splay_map(m_state); open_mpz(m_state); open_mpq(m_state); open_options(m_state); diff --git a/src/bindings/lua/splay_map.cpp b/src/bindings/lua/splay_map.cpp new file mode 100644 index 000000000..9b75c94ab --- /dev/null +++ b/src/bindings/lua/splay_map.cpp @@ -0,0 +1,205 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include "util/splay_map.h" +#include "bindings/lua/util.h" + +#include "bindings/lua/expr.h" +#include "library/expr_lt.h" + +namespace lean { +/** + \brief Reference to Lua object. +*/ +class lua_ref { + lua_State * m_state; + int m_ref; +public: + lua_ref():m_state(nullptr) {} + + lua_ref(lua_State * L, int i) { + lean_assert(L); + m_state = L; + lua_pushvalue(m_state, i); + m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX); + } + + lua_ref(lua_ref const & r) { + m_state = r.m_state; + if (m_state) { + r.push(); + m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX); + } + } + + lua_ref(lua_ref && r) { + m_state = r.m_state; + m_ref = r.m_ref; + r.m_state = nullptr; + } + + ~lua_ref() { + if (m_state) + luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref); + } + + lua_ref & operator=(lua_ref const & r) { + if (m_ref == r.m_ref) + return *this; + if (m_state) + luaL_unref(m_state, LUA_REGISTRYINDEX, m_ref); + m_state = r.m_state; + if (m_state) { + r.push(); + m_ref = luaL_ref(m_state, LUA_REGISTRYINDEX); + } + return *this; + } + + void push() const { + lean_assert(m_state); + lua_rawgeti(m_state, LUA_REGISTRYINDEX, m_ref); + } + + lua_State * get_state() const { + return m_state; + } +}; + +struct lua_lt_proc { + int operator()(lua_ref const & r1, lua_ref const & r2) const { + lean_assert(r1.get_state() == r2.get_state()); + lua_State * L = r1.get_state(); + r1.push(); + r2.push(); + int r; + if (lessthan(L, -2, -1)) { + r = -1; + } else if (lessthan(L, -1, -2)) { + r = 1; + } else if (equal(L, -2, -1)) { + r = 0; + } else { + throw exception("'<' is not a total order for the elements inserted on the table"); + } + lua_pop(L, 2); + return r; + } +}; + +typedef splay_map lua_splay_map; + +constexpr char const * splay_map_mt = "splay_map.mt"; + +bool is_splay_map(lua_State * L, int idx) { + return testudata(L, idx, splay_map_mt); +} + +lua_splay_map & to_splay_map(lua_State * L, int idx) { + return *static_cast(luaL_checkudata(L, idx, splay_map_mt)); +} + +int push_splay_map(lua_State * L, lua_splay_map const & o) { + void * mem = lua_newuserdata(L, sizeof(lua_splay_map)); + new (mem) lua_splay_map(o); + luaL_getmetatable(L, splay_map_mt); + lua_setmetatable(L, -2); + return 1; +} + +static int mk_splay_map(lua_State * L) { + lua_splay_map r; + return push_splay_map(L, r); +} + +static int splay_map_gc(lua_State * L) { + to_splay_map(L, 1).~lua_splay_map(); + return 0; +} + +static int splay_map_size(lua_State * L) { + lua_pushinteger(L, to_splay_map(L, 1).size()); + return 1; +} + +static int splay_map_contains(lua_State * L) { + lua_pushboolean(L, to_splay_map(L, 1).contains(lua_ref(L, 2))); + return 1; +} + +static int splay_map_empty(lua_State * L) { + lua_pushboolean(L, to_splay_map(L, 1).empty()); + return 1; +} + +static int splay_map_insert(lua_State * L) { + to_splay_map(L, 1).insert(lua_ref(L, 2), lua_ref(L, 3)); + return 0; +} + +static int splay_map_erase(lua_State * L) { + to_splay_map(L, 1).erase(lua_ref(L, 2)); + return 0; +} + +static int splay_map_find(lua_State * L) { + lua_splay_map & m = to_splay_map(L, 1); + lua_ref * val = m.splay_find(lua_ref(L, 2)); + if (val) { + lean_assert(val->get_state() == L); + val->push(); + } else { + lua_pushnil(L); + } + return 1; +} + +static int splay_map_copy(lua_State * L) { + return push_splay_map(L, to_splay_map(L, 1)); +} + +static int splay_map_pred(lua_State * L) { + lua_pushboolean(L, is_splay_map(L, 1)); + return 1; +} + +static int splay_map_for_each(lua_State * L) { + lua_splay_map & m = to_splay_map(L, 1); // map + luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun + m.for_each([&](lua_ref const & k, lua_ref const & v) { + lua_pushvalue(L, 2); // push user-fun + k.push(); + v.push(); + pcall(L, 2, 0, 0); + }); + return 0; +} + +static const struct luaL_Reg splay_map_m[] = { + {"__gc", splay_map_gc}, // never throws + {"__len", safe_function }, + {"contains", safe_function}, + {"size", safe_function}, + {"empty", safe_function}, + {"insert", safe_function}, + {"erase", safe_function}, + {"find", safe_function}, + {"copy", safe_function}, + {"for_each", safe_function}, + {0, 0} +}; + +void open_splay_map(lua_State * L) { + luaL_newmetatable(L, splay_map_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, splay_map_m, 0); + + SET_GLOBAL_FUN(mk_splay_map, "splay_map"); + SET_GLOBAL_FUN(splay_map_pred, "is_splay_map"); +} +} diff --git a/src/bindings/lua/splay_map.h b/src/bindings/lua/splay_map.h new file mode 100644 index 000000000..7fb5b69a3 --- /dev/null +++ b/src/bindings/lua/splay_map.h @@ -0,0 +1,11 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +namespace lean { +void open_splay_map(lua_State * L); +} diff --git a/src/bindings/lua/util.cpp b/src/bindings/lua/util.cpp index 311d3c668..065675c57 100644 --- a/src/bindings/lua/util.cpp +++ b/src/bindings/lua/util.cpp @@ -58,6 +58,22 @@ size_t objlen(lua_State * L, int idx) { #endif } +int lessthan(lua_State * L, int idx1, int idx2) { + #if LUA_VERSION_NUM < 502 + return lua_lessthan(L, idx1, idx2); + #else + return lua_compare(L, idx1, idx2, LUA_OPLT); + #endif +} + +int equal(lua_State * L, int idx1, int idx2) { + #if LUA_VERSION_NUM < 502 + return lua_equal(L, idx1, idx2); + #else + return lua_compare(L, idx1, idx2, LUA_OPEQ); + #endif +} + static void exec(lua_State * L) { pcall(L, 0, LUA_MULTRET, 0); } diff --git a/src/bindings/lua/util.h b/src/bindings/lua/util.h index df7abc81a..b25dd2495 100644 --- a/src/bindings/lua/util.h +++ b/src/bindings/lua/util.h @@ -15,6 +15,8 @@ size_t objlen(lua_State * L, int idx); void dofile(lua_State * L, char const * fname); void dostring(lua_State * L, char const * str); void pcall(lua_State * L, int nargs, int nresults, int errorfun); +int lessthan(lua_State * L, int idx1, int idx2); +int equal(lua_State * L, int idx1, int idx2); /** \brief Wrapper for invoking function f, and catching Lean exceptions. */ diff --git a/src/util/splay_tree.h b/src/util/splay_tree.h index 9e90d7362..0814f6ae9 100644 --- a/src/util/splay_tree.h +++ b/src/util/splay_tree.h @@ -183,11 +183,11 @@ class splay_tree : public CMP { if (n) { if (n->m_left) { check_invariant(n->m_left); - lean_assert(cmp(n->m_left->m_value, n->m_value) < 0); + lean_assert_lt(cmp(n->m_left->m_value, n->m_value), 0); } if (n->m_right) { check_invariant(n->m_right); - lean_assert(cmp(n->m_value, n->m_right->m_value) < 0); + lean_assert_lt(cmp(n->m_value, n->m_right->m_value), 0); } } return true; diff --git a/tests/lua/map.lua b/tests/lua/map.lua new file mode 100644 index 000000000..826b411dd --- /dev/null +++ b/tests/lua/map.lua @@ -0,0 +1,88 @@ +-- This examples demonstrates that Lean objects are not very useful as Lua table keys. +local f = Const("f") +local m = {} +local env = environment() +env:add_var("T", Type()) +env:add_var("f", mk_arrow(Const("T"), Const("T"))) +for i = 1, 100 do + env:add_var("a" .. i, Const("T")) + local t = f(Const("a" .. i)) + -- Any object can be a key of a Lua table. + -- But, Lua does not use the method __eq for comparing keys. + -- The problem is that Lua uses its own hashcode that may not + -- be compatible with the __eq implemantion. + -- By non-compatible, we mean two objects my be equal by __eq, but + -- the hashcodes may be different. + m[t] = i +end + +for t1, i in pairs(m) do + local t2 = f(Const("a" .. i)) + -- print(t1, i, t2) + assert(m[t1] == i) + -- t1 and t2 are structurally equal + assert(t1 == t2) + -- t1 and t2 are different objects + assert(not t1:is_eqp(t2)) + -- t2 is not a key of map + assert(m[t2] == nil) + assert(env:normalize(t1) == t1) + assert(t1:instantiate(Const("a")) == t1) + local t1_prime = t1:instantiate(Const("a")) + -- t1 and t1_prime are structurally equal + assert(t1 == t1_prime) + -- Moreover, they are references to the same Lean object + assert(t1:is_eqp(t1_prime)) + -- But, they are wrapped by different Lua userdata + assert(m[t1_prime] == nil) +end + +-- We can store elements that implement __lt and __eq metamethods in splay_maps. +-- The implementation assumes that the elements stored in the splay map can be totally ordered by __lt +m = splay_map() +m:insert(Const("a"), 10) +assert(m:contains(Const("a"))) +assert(not m:contains(Const("b"))) +assert(m:find(Const("a")) == 10) +local a, b = Consts("a, b") +m:insert(f(a, b), 20) +assert(m:find(f(a, b)) == 20) +assert(m:find(f(a, a)) == nil) +assert(m:size() == 2) +assert(#m == 2) +m:erase(f(a, a)) +assert(m:size() == 2) +m:erase(f(a, b)) +assert(m:size() == 1) +for i = 1, 100 do + local t = f(Const("a" .. i)) + m:insert(t, i) + assert(m:contains(t)) + assert(m:find(t) == i) +end + +assert(m:size() == 101) + +for i = 1, 100 do + local t = f(Const("a" .. i)) + assert(m:find(t) == i) +end + +-- The following call fails because integers cannot be compared with Lean expressions +assert(not pcall(function() m:insert(10, 20) end)) + +-- Splay maps copy operation is O(1) +local m2 = m:copy() +m2:insert(b, 20) +assert(m:size() == 101) +assert(m2:size() == 102) + +-- We can also traverse all elements in the map +local num = 0 +m:for_each( + function(k, v) + print(tostring(k) .. " -> " .. v) + num = num + 1 + end +) +assert(num == 101)