diff --git a/src/frontends/lean/register_module.cpp b/src/frontends/lean/register_module.cpp index db0a5aad4..3cb528c27 100644 --- a/src/frontends/lean/register_module.cpp +++ b/src/frontends/lean/register_module.cpp @@ -7,10 +7,12 @@ Author: Leonardo de Moura #include #include "util/lua.h" #include "util/script_state.h" +#include "frontends/lean/token_set.h" #include "frontends/lean/parse_table.h" namespace lean { void open_frontend_lean(lua_State * L) { + open_token_set(L); open_parse_table(L); } void register_frontend_lean_module() { diff --git a/src/frontends/lean/token_set.cpp b/src/frontends/lean/token_set.cpp index 9f72121fa..ac5439ad2 100644 --- a/src/frontends/lean/token_set.cpp +++ b/src/frontends/lean/token_set.cpp @@ -20,15 +20,21 @@ token_set add_token(token_set const & s, char const * token, unsigned prec) { token_set add_token(token_set const & s, char const * token, char const * val, unsigned prec) { return insert(s, token, token_info(val, prec)); } -token_set merge(token_set const & s1, token_set const & s2) { - return merge(s1, s2); -} token_set const * find(token_set const & s, char c) { return s.find(c); } token_info const * value_of(token_set const & s) { return s.value(); } +void for_each(token_set const & s, std::function const & fn) { + s.for_each([&](unsigned num, char const * keys, token_info const & info) { + buffer str; + str.append(num, keys); + str.push_back(0); + fn(str.data(), info); + }); +} + static char const * g_lambda_unicode = "\u03BB"; static char const * g_pi_unicode = "\u03A0"; static char const * g_forall_unicode = "\u2200"; @@ -82,5 +88,90 @@ public: } }; static init_token_set_fn g_init; +token_set mk_token_set() { return token_set(); } token_set mk_default_token_set() { return g_init.m_token_set; } + +DECL_UDATA(token_set) +static int mk_token_set(lua_State * L) { return push_token_set(L, mk_token_set()); } +static int mk_default_token_set(lua_State * L) { return push_token_set(L, mk_default_token_set()); } +static int add_command_token(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 2) + return push_token_set(L, add_command_token(to_token_set(L, 1), lua_tostring(L, 2))); + else + return push_token_set(L, add_command_token(to_token_set(L, 1), lua_tostring(L, 2), lua_tostring(L, 3))); +} +static int add_token(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 3) + return push_token_set(L, add_token(to_token_set(L, 1), lua_tostring(L, 2), lua_tonumber(L, 3))); + else + return push_token_set(L, add_token(to_token_set(L, 1), lua_tostring(L, 2), lua_tostring(L, 3), lua_tonumber(L, 4))); +} +static int merge(lua_State * L) { + return push_token_set(L, merge(to_token_set(L, 1), to_token_set(L, 2))); +} +static int find(lua_State * L) { + char k; + if (lua_isnumber(L, 2)) { + k = lua_tonumber(L, 2); + } else { + char const * str = lua_tostring(L, 2); + if (strlen(str) != 1) + throw exception("arg #2 must be a string of length 1"); + k = str[0]; + } + auto it = to_token_set(L, 1).find(k); + if (it) + return push_token_set(L, *it); + else + return push_nil(L); +} +static int value_of(lua_State * L) { + auto it = value_of(to_token_set(L, 1)); + if (it) { + push_boolean(L, it->is_command()); + push_name(L, it->value()); + push_integer(L, it->precedence()); + return 3; + } else { + push_nil(L); + return 1; + } +} +static int for_each(lua_State * L) { + token_set const & t = to_token_set(L, 1); + luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun + for_each(t, [&](char const * k, token_info const & info) { + lua_pushvalue(L, 2); + lua_pushstring(L, k); + lua_pushboolean(L, info.is_command()); + push_name(L, info.value()); + lua_pushinteger(L, info.precedence()); + pcall(L, 4, 0, 0); + }); + return 0; +} + +static const struct luaL_Reg token_set_m[] = { + {"__gc", token_set_gc}, + {"add_command_token", safe_function}, + {"add_token", safe_function}, + {"merge", safe_function}, + {"find", safe_function}, + {"value_of", safe_function}, + {"for_each", safe_function}, + {0, 0} +}; + +void open_token_set(lua_State * L) { + luaL_newmetatable(L, token_set_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, token_set_m, 0); + + SET_GLOBAL_FUN(token_set_pred, "is_token_set"); + SET_GLOBAL_FUN(mk_default_token_set, "default_token_set"); + SET_GLOBAL_FUN(mk_token_set, "token_set"); +} } diff --git a/src/frontends/lean/token_set.h b/src/frontends/lean/token_set.h index 4fa241ef0..c7534bc5c 100644 --- a/src/frontends/lean/token_set.h +++ b/src/frontends/lean/token_set.h @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include #include "util/trie.h" #include "util/name.h" +#include "util/lua.h" namespace lean { class token_info { @@ -25,12 +26,14 @@ public: }; typedef ctrie token_set; +token_set mk_token_set(); token_set mk_default_token_set(); token_set add_command_token(token_set const & s, char const * token); token_set add_command_token(token_set const & s, char const * token, char const * val); token_set add_token(token_set const & s, char const * token, unsigned prec = 0); token_set add_token(token_set const & s, char const * token, char const * val, unsigned prec = 0); -token_set merge(token_set const & s1, token_set const & s2); +void for_each(token_set const & s, std::function const & fn); token_set const * find(token_set const & s, char c); token_info const * value_of(token_set const & s); +void open_token_set(lua_State * L); } diff --git a/tests/lua/token_set.lua b/tests/lua/token_set.lua new file mode 100644 index 000000000..971dda7e9 --- /dev/null +++ b/tests/lua/token_set.lua @@ -0,0 +1,45 @@ +function display_token_set(s) + s:for_each(function(k, cmd, val, prec) + io.write(k) + if cmd then + io.write(" [command]") + end + print(" => " .. tostring(val) .. " " .. tostring(prec)) + end) +end + +function token_set_size(s) + local r = 0 + s:for_each(function() r = r + 1 end) + return r +end + +local s = token_set() +assert(is_token_set(s)) +assert(token_set_size(s) == 0) +s = s:add_command_token("test", "tst1") +s = s:add_command_token("tast", "tst2") +s = s:add_command_token("tests", "tst3") +s = s:add_command_token("fests", "tst4") +s = s:add_command_token("tes", "tst5") +s = s:add_token("++", "++", 65) +s = s:add_token("++-", "plusminus") +assert(token_set_size(s) == 7) +display_token_set(s) + + +print("========") +local s2 = default_token_set() +display_token_set(s2) +assert(token_set_size(s2) > 0) +local sz1 = token_set_size(s) +local sz2 = token_set_size(s2) +s2 = s2:merge(s) +assert(token_set_size(s2) == sz1 + sz2) +s2 = s2:find("t"):find("e") +print("========") +display_token_set(s2) +assert(token_set_size(s2) == 3) +s2 = s2:find("s") +local cmd, val, prec = s2:value_of() +assert(val == name("tst5"))