From ce14ced08e4f8b2e276f5fb22e9a797689b556a2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 7 Jul 2014 10:45:19 -0700 Subject: [PATCH] feat(util/sexpr): allow Lua objects to be embedded in Lean s-expressions Signed-off-by: Leonardo de Moura --- src/util/sexpr/format.cpp | 5 ++ src/util/sexpr/sexpr.cpp | 154 +++++++++++++++++++++++++++++++------- src/util/sexpr/sexpr.h | 36 ++++++--- tests/lua/sexpr_bug1.lua | 21 +++++- 4 files changed, 176 insertions(+), 40 deletions(-) diff --git a/src/util/sexpr/format.cpp b/src/util/sexpr/format.cpp index b5f8af802..6a7e1a5fc 100644 --- a/src/util/sexpr/format.cpp +++ b/src/util/sexpr/format.cpp @@ -369,6 +369,11 @@ struct sexpr_pp_fn { case sexpr_kind::Name: return pp(to_name(s)); case sexpr_kind::MPZ: return format(to_mpz(s)); case sexpr_kind::MPQ: return format(to_mpq(s)); + case sexpr_kind::Ext: { + std::ostringstream ss; + to_ext(s).display(ss); + return format(ss.str()); + } case sexpr_kind::Cons: { sexpr const * curr = &s; format r; diff --git a/src/util/sexpr/sexpr.cpp b/src/util/sexpr/sexpr.cpp index c4854d5a2..bd4f06cd8 100644 --- a/src/util/sexpr/sexpr.cpp +++ b/src/util/sexpr/sexpr.cpp @@ -14,6 +14,7 @@ Author: Leonardo de Moura #include "util/buffer.h" #include "util/sstream.h" #include "util/object_serializer.h" +#include "util/luaref.h" #include "util/numerics/mpz.h" #include "util/numerics/mpq.h" #include "util/sexpr/sexpr.h" @@ -88,6 +89,16 @@ struct sexpr_mpq : public sexpr_cell { m_value(v) {} }; +/** \brief S-expression cell: external atom */ +struct sexpr_ext : public sexpr_cell { + std::unique_ptr m_value; + sexpr_ext(std::unique_ptr && v): + sexpr_cell(sexpr_kind::Ext, v->hash()), + m_value(std::move(v)) { + lean_assert(m_value); + } +}; + /** \brief S-expression cell: cons cell (aka pair) */ struct sexpr_cons : public sexpr_cell { sexpr m_head; @@ -137,6 +148,7 @@ void sexpr_cell::dealloc() { case sexpr_kind::Name: delete static_cast(this); break; case sexpr_kind::MPZ: delete static_cast(this); break; case sexpr_kind::MPQ: delete static_cast(this); break; + case sexpr_kind::Ext: delete static_cast(this); break; case sexpr_kind::Cons: static_cast(this)->dealloc_cons(); break; } } @@ -149,6 +161,7 @@ sexpr::sexpr(double v):m_ptr(new sexpr_double(v)) {} sexpr::sexpr(name const & v):m_ptr(new sexpr_name(v)) {} sexpr::sexpr(mpz const & v):m_ptr(new sexpr_mpz(v)) {} sexpr::sexpr(mpq const & v):m_ptr(new sexpr_mpq(v)) {} +sexpr::sexpr(std::unique_ptr && v):m_ptr(new sexpr_ext(std::move(v))) {} sexpr::sexpr(sexpr const & h, sexpr const & t):m_ptr(new sexpr_cons(h, t)) {} sexpr::sexpr(sexpr const & s):m_ptr(s.m_ptr) { if (m_ptr) @@ -172,6 +185,7 @@ double sexpr::get_double() const { return static_cast(m_ptr)->m_v name const & sexpr::get_name() const { return static_cast(m_ptr)->m_value; } mpz const & sexpr::get_mpz() const { return static_cast(m_ptr)->m_value; } mpq const & sexpr::get_mpq() const { return static_cast(m_ptr)->m_value; } +sexpr_ext_atom const & sexpr::get_ext() const { return *static_cast(m_ptr)->m_value; } unsigned sexpr::hash() const { return m_ptr == nullptr ? 23 : m_ptr->m_hash; } @@ -224,6 +238,7 @@ bool operator==(sexpr const & a, sexpr const & b) { case sexpr_kind::Name: return to_name(a) == to_name(b); case sexpr_kind::MPZ: return to_mpz(a) == to_mpz(b); case sexpr_kind::MPQ: return to_mpq(a) == to_mpq(b); + case sexpr_kind::Ext: return to_ext(a).cmp(to_ext(b)) == 0; case sexpr_kind::Cons: return head(a) == head(b) && tail(a) == tail(b); } lean_unreachable(); // LCOV_EXCL_LINE @@ -249,6 +264,7 @@ int cmp(sexpr const & a, sexpr const & b) { case sexpr_kind::Name: return cmp(to_name(a), to_name(b)); case sexpr_kind::MPZ: return cmp(to_mpz(a), to_mpz(b)); case sexpr_kind::MPQ: return cmp(to_mpq(a), to_mpq(b)); + case sexpr_kind::Ext: return to_ext(a).cmp(to_ext(b)); case sexpr_kind::Cons: { int r = cmp(head(a), head(b)); if (r != 0) @@ -268,6 +284,7 @@ std::ostream & operator<<(std::ostream & out, sexpr const & s) { case sexpr_kind::Name: out << to_name(s); break; case sexpr_kind::MPZ: out << to_mpz(s); break; case sexpr_kind::MPQ: out << to_mpq(s); break; + case sexpr_kind::Ext: to_ext(s).display(out); break; case sexpr_kind::Cons: { out << "("; sexpr const * curr = &s; @@ -311,6 +328,8 @@ public: case sexpr_kind::MPZ: s << to_mpz(a); break; case sexpr_kind::MPQ: s << to_mpq(a); break; case sexpr_kind::Cons: write(car(a)); write(cdr(a)); break; + case sexpr_kind::Ext: + throw exception("s-expressions constaining external atoms cannot be serialized"); } }); } @@ -332,6 +351,7 @@ public: case sexpr_kind::Name: return sexpr(read_name(d)); case sexpr_kind::MPZ: return sexpr(read_mpz(d)); case sexpr_kind::MPQ: return sexpr(read_mpq(d)); + case sexpr_kind::Ext: lean_unreachable(); // LCOV_EXCL_LINE case sexpr_kind::Cons: { sexpr h = read(); sexpr t = read(); @@ -363,6 +383,70 @@ sexpr read_sexpr(deserializer & d) { DECL_UDATA(sexpr) +class lua_sexpr_atom : public sexpr_ext_atom { + luaref m_ref; +public: + lua_sexpr_atom(luaref && r):m_ref(r) {} + virtual ~lua_sexpr_atom() {} + virtual int cmp(sexpr_ext_atom const & e) const { + if (dynamic_cast(&e) == nullptr) { + return strcmp(typeid(*this).name(), typeid(e).name()); + } else { + luaref other = static_cast(e).m_ref; + lua_State * L = m_ref.get_state(); + if (other.get_state() != L) + throw exception("missing Lua objects from different Lua states"); + m_ref.push(); + other.push(); + int r; + if (equal(L, -2, -1)) + r = 0; + else if (lessthan(L, -2, -1)) + r = -1; + else + r = 0; + lua_pop(L, 2); + return r; + } + } + + virtual unsigned hash() const { + lua_State * L = m_ref.get_state(); + m_ref.push(); + lua_getfield(L, -1, "hash"); + if (lua_isnil(L, -1)) { + lua_pop(L, 2); + return 0; + } else { + m_ref.push(); + pcall(L, 1, 1, 0); + if (lua_isnumber(L, -1)) { + unsigned r = lua_tointeger(L, -1); + lua_pop(L, 1); + return r; + } else { + lua_pop(L, 1); + return 0; + } + } + } + + virtual int push_lua(lua_State * L) const { + if (m_ref.get_state() != L) + throw exception("using Lua object in a different Lua state"); + m_ref.push(); + return 1; + } + + virtual void display(std::ostream & out) const { + lua_State * L = m_ref.get_state(); + m_ref.push(); + out << luaL_tolstring(L, -1, nullptr); + lua_pop(L, 1); + } +}; + + static int sexpr_tostring(lua_State * L) { std::ostringstream out; out << to_sexpr(L, 1); @@ -389,7 +473,7 @@ static sexpr to_sexpr_elem(lua_State * L, int idx) { std::string str = lua_tostring(L, idx); return sexpr(str); } else { - throw exception(sstream() << "arg #" << idx << " cannot be casted into an s-expression"); + return sexpr(std::unique_ptr(new lua_sexpr_atom(luaref(L, idx)))); } } @@ -425,6 +509,7 @@ SEXPR_PRED(is_double) SEXPR_PRED(is_name) SEXPR_PRED(is_mpz) SEXPR_PRED(is_mpq) +SEXPR_PRED(is_external) static int sexpr_length(lua_State * L) { sexpr const & e = to_sexpr(L, 1); @@ -496,6 +581,13 @@ static int sexpr_to_mpq(lua_State * L) { return push_mpq(L, to_mpq(e)); } +static int sexpr_to_external(lua_State * L) { + sexpr const & e = to_sexpr(L, 1); + if (!is_external(e)) + throw exception("s-expression is not an external atom"); + return to_ext(e).push_lua(L); +} + static int sexpr_get_kind(lua_State * L) { return push_integer(L, static_cast(to_sexpr(L, 1).kind())); } @@ -511,6 +603,7 @@ static int sexpr_fields(lua_State * L) { case sexpr_kind::Name: return sexpr_to_name(L); case sexpr_kind::MPZ: return sexpr_to_mpz(L); case sexpr_kind::MPQ: return sexpr_to_mpq(L); + case sexpr_kind::Ext: return sexpr_to_external(L); case sexpr_kind::Cons: sexpr_head(L); sexpr_tail(L); return 2; } lean_unreachable(); // LCOV_EXCL_LINE @@ -518,34 +611,36 @@ static int sexpr_fields(lua_State * L) { } static const struct luaL_Reg sexpr_m[] = { - {"__gc", sexpr_gc}, // never throws - {"__tostring", safe_function}, - {"__eq", safe_function}, - {"__lt", safe_function}, - {"kind", safe_function}, - {"is_nil", safe_function}, - {"is_cons", safe_function}, - {"is_pair", safe_function}, - {"is_list", safe_function}, - {"is_atom", safe_function}, - {"is_string", safe_function}, - {"is_bool", safe_function}, - {"is_int", safe_function}, - {"is_double", safe_function}, - {"is_name", safe_function}, - {"is_mpz", safe_function}, - {"is_mpq", safe_function}, - {"head", safe_function}, - {"tail", safe_function}, - {"length", safe_function}, - {"to_bool", safe_function}, - {"to_string", safe_function}, - {"to_int", safe_function}, - {"to_double", safe_function}, - {"to_name", safe_function}, - {"to_mpz", safe_function}, - {"to_mpq", safe_function}, - {"fields", safe_function}, + {"__gc", sexpr_gc}, // never throws + {"__tostring", safe_function}, + {"__eq", safe_function}, + {"__lt", safe_function}, + {"kind", safe_function}, + {"is_nil", safe_function}, + {"is_cons", safe_function}, + {"is_pair", safe_function}, + {"is_list", safe_function}, + {"is_atom", safe_function}, + {"is_string", safe_function}, + {"is_bool", safe_function}, + {"is_int", safe_function}, + {"is_double", safe_function}, + {"is_name", safe_function}, + {"is_mpz", safe_function}, + {"is_mpq", safe_function}, + {"is_external", safe_function}, + {"head", safe_function}, + {"tail", safe_function}, + {"length", safe_function}, + {"to_bool", safe_function}, + {"to_string", safe_function}, + {"to_int", safe_function}, + {"to_double", safe_function}, + {"to_name", safe_function}, + {"to_mpz", safe_function}, + {"to_mpq", safe_function}, + {"to_external", safe_function}, + {"fields", safe_function}, {0, 0} }; @@ -568,6 +663,7 @@ void open_sexpr(lua_State * L) { SET_ENUM("MPZ", sexpr_kind::MPZ); SET_ENUM("MPQ", sexpr_kind::MPQ); SET_ENUM("Cons", sexpr_kind::Cons); + SET_ENUM("Ext", sexpr_kind::Ext); lua_setglobal(L, "sexpr_kind"); } } diff --git a/src/util/sexpr/sexpr.h b/src/util/sexpr/sexpr.h index 0002f6771..9164a0c85 100644 --- a/src/util/sexpr/sexpr.h +++ b/src/util/sexpr/sexpr.h @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include #include #include +#include #include "util/lua.h" #include "util/serializer.h" @@ -18,7 +19,9 @@ class mpq; class mpz; struct sexpr_cell; -enum class sexpr_kind { Nil, String, Bool, Int, Double, Name, MPZ, MPQ, Cons }; +enum class sexpr_kind { Nil, String, Bool, Int, Double, Name, MPZ, MPQ, Cons, Ext }; + +class sexpr_ext_atom; /** \brief Simple LISP-like S-expressions. @@ -43,6 +46,7 @@ public: explicit sexpr(name const & v); explicit sexpr(mpz const & v); explicit sexpr(mpq const & v); + explicit sexpr(std::unique_ptr && v); sexpr(sexpr const & h, sexpr const & t); template sexpr(T const & h, sexpr const & t):sexpr(sexpr(h), t) {} @@ -75,6 +79,7 @@ public: name const & get_name() const; mpz const & get_mpz() const; mpq const & get_mpq() const; + sexpr_ext_atom const & get_ext() const; /** \brief Hash code for this S-expression*/ unsigned hash() const; @@ -95,6 +100,15 @@ public: friend std::ostream & operator<<(std::ostream & out, sexpr const & s); }; +class sexpr_ext_atom { +public: + virtual ~sexpr_ext_atom() {} + virtual int cmp(sexpr_ext_atom const & e) const = 0; + virtual unsigned hash() const = 0; + virtual int push_lua(lua_State * L) const = 0; + virtual void display(std::ostream & out) const = 0; +}; + /** \brief Return the nil S-expression */ inline sexpr nil() { return sexpr(); } /** \brief Return a cons-cell (aka pair) composed of \c head and \c tail */ @@ -110,16 +124,17 @@ inline sexpr const & car(sexpr const & s) { return head(s); } */ inline sexpr const & cdr(sexpr const & s) { return tail(s); } /** \brief Return true iff \c s is not an atom (i.e., it is not a cons cell). */ -inline bool is_atom(sexpr const & s) { return s.kind() != sexpr_kind::Cons; } +inline bool is_atom(sexpr const & s) { return s.kind() != sexpr_kind::Cons; } /** \brief Return true iff \c s is not a cons cell. */ -inline bool is_cons(sexpr const & s) { return s.kind() == sexpr_kind::Cons; } -inline bool is_string(sexpr const & s) { return s.kind() == sexpr_kind::String; } -inline bool is_bool(sexpr const & s) { return s.kind() == sexpr_kind::Bool; } -inline bool is_int(sexpr const & s) { return s.kind() == sexpr_kind::Int; } -inline bool is_double(sexpr const & s) { return s.kind() == sexpr_kind::Double; } -inline bool is_name(sexpr const & s) { return s.kind() == sexpr_kind::Name; } -inline bool is_mpz(sexpr const & s) { return s.kind() == sexpr_kind::MPZ; } -inline bool is_mpq(sexpr const & s) { return s.kind() == sexpr_kind::MPQ; } +inline bool is_cons(sexpr const & s) { return s.kind() == sexpr_kind::Cons; } +inline bool is_string(sexpr const & s) { return s.kind() == sexpr_kind::String; } +inline bool is_bool(sexpr const & s) { return s.kind() == sexpr_kind::Bool; } +inline bool is_int(sexpr const & s) { return s.kind() == sexpr_kind::Int; } +inline bool is_double(sexpr const & s) { return s.kind() == sexpr_kind::Double; } +inline bool is_name(sexpr const & s) { return s.kind() == sexpr_kind::Name; } +inline bool is_mpz(sexpr const & s) { return s.kind() == sexpr_kind::MPZ; } +inline bool is_mpq(sexpr const & s) { return s.kind() == sexpr_kind::MPQ; } +inline bool is_external(sexpr const & s) { return s.kind() == sexpr_kind::Ext; } inline std::string const & to_string(sexpr const & s) { return s.get_string(); } inline bool to_bool(sexpr const & s) { return s.get_bool(); } @@ -128,6 +143,7 @@ inline double to_double(sexpr const & s) { return s.get_double(); } inline name const & to_name(sexpr const & s) { return s.get_name(); } inline mpz const & to_mpz(sexpr const & s) { return s.get_mpz(); } inline mpq const & to_mpq(sexpr const & s) { return s.get_mpq(); } +inline sexpr_ext_atom const & to_ext(sexpr const & s) { return s.get_ext(); } /** \brief Return true iff \c s is nil or \c s is a cons cell where \c is_list(tail(s)). */ bool is_list(sexpr const & s); diff --git a/tests/lua/sexpr_bug1.lua b/tests/lua/sexpr_bug1.lua index c42bcc853..025cbee64 100644 --- a/tests/lua/sexpr_bug1.lua +++ b/tests/lua/sexpr_bug1.lua @@ -1 +1,20 @@ -check_error(function() print(sexpr(Local("a", Bool))) end) +local s = sexpr(Local("a", Bool), Local("b", Bool)) +print(s) +local a, b = s:fields() +print(a) +print(b) +assert(a ~= Local("a", Bool)) +assert(a:to_external() == Local("a", Bool)) +assert(a:fields() == Local("a", Bool)) +assert(is_expr(a:to_external())) + +local s = sexpr(Local("a", Bool), Local("b", Bool)) +local s = sexpr({}) + +local s1 = sexpr(Local("a", Bool), Local("b", Bool)) +local s2 = sexpr(Local("a", Bool), Local("c", Bool)) +assert(Local("b", Bool) > Local("c", Bool)) +assert(s1 > s2) +assert(s2 < s1) +assert(s2 == sexpr(Local("a", Bool), Local("c", Bool))) +