feat(frontends/lean): allow parser actions to be implemented using Lua

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-17 20:39:42 -07:00
parent 4cbc429192
commit f17e8a853a
7 changed files with 129 additions and 3 deletions

View file

@ -27,6 +27,7 @@ static name g_infixl("infixl");
static name g_infixr("infixr"); static name g_infixr("infixr");
static name g_postfix("postfix"); static name g_postfix("postfix");
static name g_notation("notation"); static name g_notation("notation");
static name g_call("call");
static std::string parse_symbol(parser & p, char const * msg) { static std::string parse_symbol(parser & p, char const * msg) {
name n; name n;
@ -74,6 +75,7 @@ using notation::mk_binders_action;
using notation::mk_exprs_action; using notation::mk_exprs_action;
using notation::mk_scoped_expr_action; using notation::mk_scoped_expr_action;
using notation::mk_skip_action; using notation::mk_skip_action;
using notation::mk_ext_lua_action;
using notation::transition; using notation::transition;
using notation::action; using notation::action;
@ -180,6 +182,10 @@ static action parse_action(parser & p, buffer<expr> & locals, buffer<token_entry
if (p.curr_is_numeral()) { if (p.curr_is_numeral()) {
unsigned prec = parse_precedence(p, "invalid notation declaration, small numeral expected"); unsigned prec = parse_precedence(p, "invalid notation declaration, small numeral expected");
return mk_expr_action(prec); return mk_expr_action(prec);
} else if (p.curr_is_string()) {
std::string fn = p.get_str_val();
p.next();
return mk_ext_lua_action(fn.c_str());
} else if (p.curr_is_token_or_id(g_scoped)) { } else if (p.curr_is_token_or_id(g_scoped)) {
p.next(); p.next();
return mk_scoped_expr_action(mk_var(0)); return mk_scoped_expr_action(mk_var(0));
@ -218,6 +224,11 @@ static action parse_action(parser & p, buffer<expr> & locals, buffer<token_entry
} }
p.check_token_next(g_rparen, "invalid scoped notation argument, ')' expected"); p.check_token_next(g_rparen, "invalid scoped notation argument, ')' expected");
return mk_scoped_expr_action(rec, prec ? *prec : 0); return mk_scoped_expr_action(rec, prec ? *prec : 0);
} else if (p.curr_is_token_or_id(g_call)) {
p.next();
name fn = p.check_id_next("invalid call notation argument, identifier expected");
p.check_token_next(g_rparen, "invalid call notation argument, ')' expected");
return mk_ext_lua_action(fn.to_string().c_str());
} else { } else {
throw parser_error("invalid notation declaration, 'foldl', 'foldr' or 'scoped' expected", p.pos()); throw parser_error("invalid notation declaration, 'foldl', 'foldr' or 'scoped' expected", p.pos());
} }

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#include <string>
#include <utility> #include <utility>
#include "util/rb_map.h" #include "util/rb_map.h"
#include "util/sstream.h" #include "util/sstream.h"
@ -58,6 +59,12 @@ struct ext_action_cell : public action_cell {
action_cell(action_kind::Ext), m_parse_fn(fn) {} action_cell(action_kind::Ext), m_parse_fn(fn) {}
}; };
struct ext_lua_action_cell : public action_cell {
std::string m_lua_fn;
ext_lua_action_cell(char const * fn):
action_cell(action_kind::LuaExt), m_lua_fn(fn) {}
};
action::action(action_cell * ptr):m_ptr(ptr) { lean_assert(ptr); } action::action(action_cell * ptr):m_ptr(ptr) { lean_assert(ptr); }
action::action():action(mk_skip_action()) {} action::action():action(mk_skip_action()) {}
action::action(action const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } action::action(action const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); }
@ -82,6 +89,10 @@ ext_action_cell * to_ext_action(action_cell * c) {
lean_assert(c->m_kind == action_kind::Ext); lean_assert(c->m_kind == action_kind::Ext);
return static_cast<ext_action_cell*>(c); return static_cast<ext_action_cell*>(c);
} }
ext_lua_action_cell * to_ext_lua_action(action_cell * c) {
lean_assert(c->m_kind == action_kind::LuaExt);
return static_cast<ext_lua_action_cell*>(c);
}
unsigned action::rbp() const { return to_expr_action(m_ptr)->m_rbp; } unsigned action::rbp() const { return to_expr_action(m_ptr)->m_rbp; }
name const & action::get_sep() const { return to_exprs_action(m_ptr)->m_token_sep; } name const & action::get_sep() const { return to_exprs_action(m_ptr)->m_token_sep; }
expr const & action::get_rec() const { expr const & action::get_rec() const {
@ -94,6 +105,7 @@ bool action::use_lambda_abstraction() const { return to_scoped_expr_action(m_ptr
expr const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; } expr const & action::get_initial() const { return to_exprs_action(m_ptr)->m_ini; }
bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; } bool action::is_fold_right() const { return to_exprs_action(m_ptr)->m_fold_right; }
parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; } parse_fn const & action::get_parse_fn() const { return to_ext_action(m_ptr)->m_parse_fn; }
std::string const & action::get_lua_fn() const { return to_ext_lua_action(m_ptr)->m_lua_fn; }
bool action::is_compatible(action const & a) const { bool action::is_compatible(action const & a) const {
if (kind() != a.kind()) if (kind() != a.kind())
return false; return false;
@ -102,6 +114,8 @@ bool action::is_compatible(action const & a) const {
return true; return true;
case action_kind::Ext: case action_kind::Ext:
return m_ptr == a.m_ptr; return m_ptr == a.m_ptr;
case action_kind::LuaExt:
return get_lua_fn() == a.get_lua_fn();
case action_kind::Expr: case action_kind::Expr:
return rbp() == a.rbp(); return rbp() == a.rbp();
case action_kind::Exprs: case action_kind::Exprs:
@ -124,6 +138,7 @@ void action_cell::dealloc() {
case action_kind::Exprs: delete(to_exprs_action(this)); break; case action_kind::Exprs: delete(to_exprs_action(this)); break;
case action_kind::ScopedExpr: delete(to_scoped_expr_action(this)); break; case action_kind::ScopedExpr: delete(to_scoped_expr_action(this)); break;
case action_kind::Ext: delete(to_ext_action(this)); break; case action_kind::Ext: delete(to_ext_action(this)); break;
case action_kind::LuaExt: delete(to_ext_lua_action(this)); break;
default: delete this; break; default: delete this; break;
} }
} }
@ -154,6 +169,7 @@ action mk_scoped_expr_action(expr const & rec, unsigned rb, bool lambda) {
return action(new scoped_expr_action_cell(rec, rb, lambda)); return action(new scoped_expr_action_cell(rec, rb, lambda));
} }
action mk_ext_action(parse_fn const & fn) { return action(new ext_action_cell(fn)); } action mk_ext_action(parse_fn const & fn) { return action(new ext_action_cell(fn)); }
action mk_ext_lua_action(char const * fn) { return action(new ext_lua_action_cell(fn)); }
struct parse_table::cell { struct parse_table::cell {
bool m_nud; bool m_nud;
@ -196,7 +212,7 @@ static void validate_transitions(bool nud, unsigned num, transition const * ts,
case action_kind::Binder: case action_kind::Binders: case action_kind::Binder: case action_kind::Binders:
found_binder = true; found_binder = true;
break; break;
case action_kind::Expr: case action_kind::Exprs: case action_kind::Ext: case action_kind::Expr: case action_kind::Exprs: case action_kind::Ext: case action_kind::LuaExt:
nargs++; nargs++;
break; break;
case action_kind::ScopedExpr: case action_kind::ScopedExpr:
@ -296,6 +312,14 @@ static int mk_scoped_expr_action(lua_State * L) {
bool lambda = (nargs <= 2) || lua_toboolean(L, 3); bool lambda = (nargs <= 2) || lua_toboolean(L, 3);
return push_notation_action(L, mk_scoped_expr_action(to_expr(L, 1), rbp, lambda)); return push_notation_action(L, mk_scoped_expr_action(to_expr(L, 1), rbp, lambda));
} }
static int mk_ext_lua_action(lua_State * L) {
char const * fn = lua_tostring(L, 1);
lua_getglobal(L, fn);
if (lua_isnil(L, -1))
throw exception("arg #1 is a unknown function name");
lua_pop(L, 1);
return push_notation_action(L, mk_ext_lua_action(fn));
}
static int is_compatible(lua_State * L) { static int is_compatible(lua_State * L) {
return push_boolean(L, to_notation_action(L, 1).is_compatible(to_notation_action(L, 2))); return push_boolean(L, to_notation_action(L, 1).is_compatible(to_notation_action(L, 2)));
} }
@ -329,6 +353,10 @@ static int use_lambda_abstraction(lua_State * L) {
check_action(L, 1, { action_kind::ScopedExpr }); check_action(L, 1, { action_kind::ScopedExpr });
return push_boolean(L, to_notation_action(L, 1).use_lambda_abstraction()); return push_boolean(L, to_notation_action(L, 1).use_lambda_abstraction());
} }
static int fn(lua_State * L) {
check_action(L, 1, { action_kind::LuaExt });
return push_string(L, to_notation_action(L, 1).get_lua_fn().c_str());
}
static const struct luaL_Reg notation_action_m[] = { static const struct luaL_Reg notation_action_m[] = {
{"__gc", notation_action_gc}, {"__gc", notation_action_gc},
@ -341,6 +369,7 @@ static const struct luaL_Reg notation_action_m[] = {
{"initial", safe_function<initial>}, {"initial", safe_function<initial>},
{"is_fold_right", safe_function<is_fold_right>}, {"is_fold_right", safe_function<is_fold_right>},
{"use_lambda_abstraction", safe_function<use_lambda_abstraction>}, {"use_lambda_abstraction", safe_function<use_lambda_abstraction>},
{"fn", safe_function<fn>},
{0, 0} {0, 0}
}; };
@ -357,6 +386,7 @@ static void open_notation_action(lua_State * L) {
SET_GLOBAL_FUN(mk_expr_action, "expr_notation_action"); SET_GLOBAL_FUN(mk_expr_action, "expr_notation_action");
SET_GLOBAL_FUN(mk_exprs_action, "exprs_notation_action"); SET_GLOBAL_FUN(mk_exprs_action, "exprs_notation_action");
SET_GLOBAL_FUN(mk_scoped_expr_action, "scoped_expr_notation_action"); SET_GLOBAL_FUN(mk_scoped_expr_action, "scoped_expr_notation_action");
SET_GLOBAL_FUN(mk_ext_lua_action, "ext_action");
push_notation_action(L, mk_skip_action()); push_notation_action(L, mk_skip_action());
lua_setglobal(L, "Skip"); lua_setglobal(L, "Skip");
@ -373,6 +403,7 @@ static void open_notation_action(lua_State * L) {
SET_ENUM("Binders", action_kind::Binders); SET_ENUM("Binders", action_kind::Binders);
SET_ENUM("ScopedExpr", action_kind::ScopedExpr); SET_ENUM("ScopedExpr", action_kind::ScopedExpr);
SET_ENUM("Ext", action_kind::Ext); SET_ENUM("Ext", action_kind::Ext);
SET_ENUM("LuaExt", action_kind::LuaExt);
lua_setglobal(L, "notation_action_kind"); lua_setglobal(L, "notation_action_kind");
} }

View file

@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#pragma once #pragma once
#include <string>
#include <utility> #include <utility>
#include "util/buffer.h" #include "util/buffer.h"
#include "util/lua.h" #include "util/lua.h"
@ -17,7 +18,7 @@ class parser;
namespace notation { namespace notation {
typedef std::function<expr(parser &, unsigned, expr const *, pos_info const &)> parse_fn; typedef std::function<expr(parser &, unsigned, expr const *, pos_info const &)> parse_fn;
enum class action_kind { Skip, Expr, Exprs, Binder, Binders, ScopedExpr, Ext }; enum class action_kind { Skip, Expr, Exprs, Binder, Binders, ScopedExpr, Ext, LuaExt };
struct action_cell; struct action_cell;
/** /**
@ -63,6 +64,7 @@ public:
friend action mk_binders_action(); friend action mk_binders_action();
friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda); friend action mk_scoped_expr_action(expr const & rec, unsigned rbp, bool lambda);
friend action mk_ext_action(parse_fn const & fn); friend action mk_ext_action(parse_fn const & fn);
friend action mk_ext_lua_action(char const * lua_fn);
action_kind kind() const; action_kind kind() const;
unsigned rbp() const; unsigned rbp() const;
@ -72,6 +74,7 @@ public:
bool is_fold_right() const; bool is_fold_right() const;
bool use_lambda_abstraction() const; bool use_lambda_abstraction() const;
parse_fn const & get_parse_fn() const; parse_fn const & get_parse_fn() const;
std::string const & get_lua_fn() const;
bool is_compatible(action const & a) const; bool is_compatible(action const & a) const;
}; };
@ -83,6 +86,7 @@ action mk_binder_action();
action mk_binders_action(); action mk_binders_action();
action mk_scoped_expr_action(expr const & rec, unsigned rbp = 0, bool lambda = true); action mk_scoped_expr_action(expr const & rec, unsigned rbp = 0, bool lambda = true);
action mk_ext_action(parse_fn const & fn); action mk_ext_action(parse_fn const & fn);
action mk_ext_lua_action(char const * lua_fn);
class transition { class transition {
name m_token; name m_token;

View file

@ -64,6 +64,42 @@ parser::no_undef_id_error_scope::~no_undef_id_error_scope() {
m_p.m_no_undef_id_error = m_old; m_p.m_no_undef_id_error = m_old;
} }
static char g_parser_key;
void set_global_parser(lua_State * L, parser * p) {
lua_pushlightuserdata(L, static_cast<void *>(&g_parser_key));
lua_pushlightuserdata(L, static_cast<void *>(p));
lua_settable(L, LUA_REGISTRYINDEX);
}
parser * get_global_parser_ptr(lua_State * L) {
lua_pushlightuserdata(L, static_cast<void *>(&g_parser_key));
lua_gettable(L, LUA_REGISTRYINDEX);
if (!lua_islightuserdata(L, -1))
return nullptr;
parser * p = static_cast<parser*>(const_cast<void*>(lua_topointer(L, -1)));
lua_pop(L, 1);
return p;
}
parser & get_global_parser(lua_State * L) {
parser * p = get_global_parser_ptr(L);
if (p == nullptr)
throw exception("there is no Lean parser on the Lua stack");
return *p;
}
struct scoped_set_parser {
lua_State * m_state;
parser * m_old;
scoped_set_parser(lua_State * L, parser & p):m_state(L) {
m_old = get_global_parser_ptr(L);
set_global_parser(L, &p);
}
~scoped_set_parser() {
set_global_parser(m_state, m_old);
}
};
parser::parser(environment const & env, io_state const & ios, parser::parser(environment const & env, io_state const & ios,
std::istream & strm, char const * strm_name, std::istream & strm, char const * strm_name,
script_state * ss, bool use_exceptions, unsigned num_threads, script_state * ss, bool use_exceptions, unsigned num_threads,
@ -694,10 +730,32 @@ expr parser::parse_notation(parse_table t, expr * left) {
args.push_back(r); args.push_back(r);
break; break;
} }
case notation::action_kind::LuaExt:
if (!m_ss)
throw parser_error("failed to use notation implemented in Lua, parser does not contain a Lua state", p);
using_script([&](lua_State * L) {
scoped_set_parser scope(L, *this);
lua_getglobal(L, a.get_lua_fn().c_str());
if (!lua_isfunction(L, -1))
throw parser_error(sstream() << "failed to use notation implemented in Lua, Lua state does not contain function '"
<< a.get_lua_fn() << "'", p);
lua_pushinteger(L, p.first);
lua_pushinteger(L, p.second);
for (unsigned i = 0; i < args.size(); i++)
push_expr(L, args[i]);
pcall(L, args.size() + 2, 1, 0);
if (!is_expr(L, -1))
throw parser_error(sstream() << "failed to use notation implemented in Lua, value returned by function '"
<< a.get_lua_fn() << "' is not an expression", p);
args.push_back(to_expr(L, -1));
lua_pop(L, 1);
});
break;
case notation::action_kind::Ext: case notation::action_kind::Ext:
args.push_back(a.get_parse_fn()(*this, args.size(), args.data(), p)); args.push_back(a.get_parse_fn()(*this, args.size(), args.data(), p));
break; break;
} }
t = r->second; t = r->second;
} }
list<expr> const & as = t.is_accepting(); list<expr> const & as = t.is_accepting();
@ -979,7 +1037,6 @@ bool parser::parse_commands() {
return !m_found_errors; return !m_found_errors;
} }
bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, script_state * S, bool use_exceptions, bool parse_commands(environment & env, io_state & ios, std::istream & in, char const * strm_name, script_state * S, bool use_exceptions,
unsigned num_threads) { unsigned num_threads) {
parser p(env, ios, in, strm_name, S, use_exceptions, num_threads); parser p(env, ios, in, strm_name, S, use_exceptions, num_threads);
@ -995,4 +1052,18 @@ bool parse_commands(environment & env, io_state & ios, char const * fname, scrip
throw exception(sstream() << "failed to open file '" << fname << "'"); throw exception(sstream() << "failed to open file '" << fname << "'");
return parse_commands(env, ios, in, fname, S, use_exceptions, num_threads); return parse_commands(env, ios, in, fname, S, use_exceptions, num_threads);
} }
static int parse_expr(lua_State * L) {
script_state S = to_script_state(L);
int nargs = lua_gettop(L);
expr r;
S.exec_unprotected([&]() {
r = get_global_parser(L).parse_expr(nargs == 0 ? 0 : lua_tointeger(L, 1));
});
return push_expr(L, r);
}
void open_parser(lua_State * L) {
SET_GLOBAL_FUN(parse_expr, "parse_expr");
}
} }

View file

@ -270,4 +270,5 @@ bool parse_commands(environment & env, io_state & ios, std::istream & in, char c
script_state * S, bool use_exceptions, unsigned num_threads); script_state * S, bool use_exceptions, unsigned num_threads);
bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S, bool parse_commands(environment & env, io_state & ios, char const * fname, script_state * S,
bool use_exceptions, unsigned num_threads); bool use_exceptions, unsigned num_threads);
void open_parser(lua_State * L);
} }

View file

@ -75,6 +75,9 @@ serializer & operator<<(serializer & s, action const & a) {
case action_kind::ScopedExpr: case action_kind::ScopedExpr:
s << a.get_rec() << a.rbp() << a.use_lambda_abstraction(); s << a.get_rec() << a.rbp() << a.use_lambda_abstraction();
break; break;
case action_kind::LuaExt:
s << a.get_lua_fn();
break;
case action_kind::Ext: case action_kind::Ext:
lean_unreachable(); lean_unreachable();
} }
@ -104,6 +107,8 @@ action read_action(deserializer & d) {
d >> rec >> rbp >> use_lambda_abstraction; d >> rec >> rbp >> use_lambda_abstraction;
return notation::mk_scoped_expr_action(rec, rbp, use_lambda_abstraction); return notation::mk_scoped_expr_action(rec, rbp, use_lambda_abstraction);
} }
case action_kind::LuaExt:
return notation::mk_ext_lua_action(d.read_string().c_str());
case action_kind::Ext: case action_kind::Ext:
break; break;
} }

View file

@ -9,11 +9,14 @@ Author: Leonardo de Moura
#include "util/script_state.h" #include "util/script_state.h"
#include "frontends/lean/token_table.h" #include "frontends/lean/token_table.h"
#include "frontends/lean/parse_table.h" #include "frontends/lean/parse_table.h"
#include "frontends/lean/parser.h"
namespace lean { namespace lean {
void open_frontend_lean(lua_State * L) { void open_frontend_lean(lua_State * L) {
open_token_table(L); open_token_table(L);
open_parse_table(L); open_parse_table(L);
open_parser(L);
} }
void register_frontend_lean_module() { void register_frontend_lean_module() {
script_state::register_module(open_frontend_lean); script_state::register_module(open_frontend_lean);