From b4a8418d38ddc8d7916000132a1c8ab4e3b75c61 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 27 Nov 2013 17:47:29 -0800 Subject: [PATCH] feat(library/tactic): expose tactics in the Lua API Signed-off-by: Leonardo de Moura --- src/library/tactic/register_module.h | 2 + src/library/tactic/tactic.cpp | 227 ++++++++++++++++++++++++++- src/library/tactic/tactic.h | 12 +- src/util/script_state.h | 6 + tests/lua/tactic1.lua | 37 +++++ 5 files changed, 281 insertions(+), 3 deletions(-) create mode 100644 tests/lua/tactic1.lua diff --git a/src/library/tactic/register_module.h b/src/library/tactic/register_module.h index 9e1eb0e33..03e084c79 100644 --- a/src/library/tactic/register_module.h +++ b/src/library/tactic/register_module.h @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "library/tactic/proof_builder.h" #include "library/tactic/cex_builder.h" #include "library/tactic/proof_state.h" +#include "library/tactic/tactic.h" namespace lean { inline void open_tactic_module(lua_State * L) { @@ -17,6 +18,7 @@ inline void open_tactic_module(lua_State * L) { open_proof_builder(L); open_cex_builder(L); open_proof_state(L); + open_tactic(L); } inline void register_tactic_module() { script_state::register_module(open_tactic_module); diff --git a/src/library/tactic/tactic.cpp b/src/library/tactic/tactic.cpp index f0dcf544c..ebf92fd53 100644 --- a/src/library/tactic/tactic.cpp +++ b/src/library/tactic/tactic.cpp @@ -7,29 +7,54 @@ Author: Leonardo de Moura #include #include #include +#include "util/luaref.h" +#include "util/script_state.h" #include "util/sstream.h" #include "util/interrupt.h" #include "util/lazy_list_fn.h" +#include "library/kernel_bindings.h" #include "library/tactic/tactic.h" namespace lean { solve_result::solve_result(expr const & pr):m_kind(solve_result_kind::Proof) { new (&m_proof) expr(pr); } solve_result::solve_result(counterexample const & cex):m_kind(solve_result_kind::Counterexample) { new (&m_cex) counterexample(cex); } solve_result::solve_result(list const & fs):m_kind(solve_result_kind::Failure) { new (&m_failures) list(fs); } -solve_result::solve_result(solve_result const & r):m_kind(r.m_kind) { +void solve_result::init(solve_result const & r) { + m_kind = r.m_kind; switch (m_kind) { + case solve_result_kind::None: break; case solve_result_kind::Proof: new (&m_proof) expr(r.m_proof); break; case solve_result_kind::Counterexample: new (&m_cex) counterexample(r.m_cex); break; case solve_result_kind::Failure: new (&m_failures) list(r.m_failures); break; } } -solve_result::~solve_result() { +void solve_result::destroy() { switch (m_kind) { + case solve_result_kind::None: break; case solve_result_kind::Proof: m_proof.~expr(); break; case solve_result_kind::Counterexample: m_cex.~counterexample(); break; case solve_result_kind::Failure: m_failures.~list(); break; } } +solve_result::solve_result(solve_result const & r) { + init(r); +} +solve_result::~solve_result() { + destroy(); +} +solve_result & solve_result::operator=(solve_result & other) { + if (this == &other) + return *this; + destroy(); + init(other); + return *this; +} +solve_result & solve_result::operator=(solve_result && other) { + lean_assert(this != &other); + destroy(); + init(other); + return *this; +} tactic & tactic::operator=(tactic const & s) { LEAN_COPY_REF(tactic, s); @@ -227,4 +252,202 @@ tactic take(tactic const & t, unsigned k) { return take(k, t(env, io, s)); }); } + +DECL_UDATA(proof_state_seq) + +static const struct luaL_Reg proof_state_seq_m[] = { + {"__gc", proof_state_seq_gc}, // never throws + {0, 0} +}; + +static int proof_state_seq_next(lua_State * L) { + proof_state_seq seq = to_proof_state_seq(L, lua_upvalueindex(1)); + script_state S = to_script_state(L); + proof_state_seq::maybe_pair p; + S.exec_unprotected([&]() { + p = seq.pull(); + }); + if (p) { + push_proof_state_seq(L, p->second); + lua_replace(L, lua_upvalueindex(1)); + push_proof_state(L, p->first); + } else { + lua_pushnil(L); + } + return 1; +} + +static int push_proof_state_seq_it(lua_State * L, proof_state_seq const & seq) { + push_proof_state_seq(L, seq); + lua_pushcclosure(L, &safe_function, 1); // create closure with 1 upvalue + return 1; +} + +DECL_UDATA(tactic) + +static void check_ios(io_state * ios) { + if (!ios) + throw exception("failed to invoke tactic, io_state is not available"); +} + +static int tactic_call_core(lua_State * L, tactic t, environment env, io_state ios, proof_state s) { + script_state S = to_script_state(L); + proof_state_seq seq; + S.exec_unprotected([&]() { + seq = t(env, ios, s); + }); + return push_proof_state_seq_it(L, seq); +} + +static int tactic_call(lua_State * L) { + int nargs = lua_gettop(L); + tactic & t = to_tactic(L, 1); + ro_environment env(L, 2); + if (nargs == 3) { + io_state * ios = get_io_state(L); + check_ios(ios); + return tactic_call_core(L, t, env, *ios, to_proof_state(L, 3)); + } else { + return tactic_call_core(L, t, env, to_io_state(L, 3), to_proof_state(L, 4)); + } +} + +static int tactic_then(lua_State * L) { return push_tactic(L, then(to_tactic(L, 1), to_tactic(L, 2))); } +static int tactic_orelse(lua_State * L) { return push_tactic(L, orelse(to_tactic(L, 1), to_tactic(L, 2))); } +static int tactic_append(lua_State * L) { return push_tactic(L, append(to_tactic(L, 1), to_tactic(L, 2))); } +static int tactic_par(lua_State * L) { return push_tactic(L, par(to_tactic(L, 1), to_tactic(L, 2))); } +static int tactic_repeat(lua_State * L) { return push_tactic(L, repeat(to_tactic(L, 1))); } +static int tactic_repeat1(lua_State * L) { return push_tactic(L, repeat1(to_tactic(L, 1))); } +static int tactic_repeat_at_most(lua_State * L) { return push_tactic(L, repeat_at_most(to_tactic(L, 1), luaL_checkinteger(L, 2))); } +static int tactic_take(lua_State * L) { return push_tactic(L, take(to_tactic(L, 1), luaL_checkinteger(L, 2))); } +static int tactic_determ(lua_State * L) { return push_tactic(L, determ(to_tactic(L, 1))); } +static int tactic_suppress_trace(lua_State * L) { return push_tactic(L, suppress_trace(to_tactic(L, 1))); } +static int tactic_try_for(lua_State * L) { return push_tactic(L, try_for(to_tactic(L, 1), luaL_checkinteger(L, 2))); } + +static int push_solve_result(lua_State * L, solve_result const & r) { + switch (r.kind()) { + case solve_result_kind::None: lua_pushnil(L); break; + case solve_result_kind::Proof: push_expr(L, r.get_proof()); break; + case solve_result_kind::Counterexample: push_environment(L, r.get_cex()); break; + case solve_result_kind::Failure: + lua_newtable(L); + int i = 1; + for (auto s : r.get_failures()) { + push_proof_state(L, s); + lua_rawseti(L, -2, i); + i++; + } + } + return 1; +} + +static int tactic_solve_core(lua_State * L, tactic t, environment env, io_state ios, proof_state s) { + script_state S = to_script_state(L); + solve_result result; + S.exec_unprotected([&]() { + result = t.solve(env, ios, s);; + }); + return push_solve_result(L, result); +} + +static int tactic_solve_core(lua_State * L, tactic t, environment env, io_state ios, context ctx, expr e) { + script_state S = to_script_state(L); + solve_result result; + S.exec_unprotected([&]() { + result = t.solve(env, ios, ctx, e); + }); + return push_solve_result(L, result); +} + +static int tactic_solve(lua_State * L) { + int nargs = lua_gettop(L); + tactic & t = to_tactic(L, 1); + ro_environment env(L, 2); + if (nargs == 3) { + io_state * ios = get_io_state(L); + check_ios(ios); + return tactic_solve_core(L, t, env, *ios, to_proof_state(L, 3)); + } else if (nargs == 4) { + if (is_io_state(L, 3)) { + return tactic_solve_core(L, t, env, to_io_state(L, 3), to_proof_state(L, 4)); + } else { + io_state * ios = get_io_state(L); + check_ios(ios); + return tactic_solve_core(L, t, env, *ios, to_context(L, 3), to_expr(L, 4)); + } + } else { + return tactic_solve_core(L, t, env, to_io_state(L, 3), to_context(L, 4), to_expr(L, 5)); + } +} + +static int mk_lua_tactic01(lua_State * L) { + luaL_checktype(L, 1, LUA_TFUNCTION); // user-fun + luaref ref(L, 1); + return push_tactic(L, + mk_tactic01([=](environment const & env, io_state const & ios, proof_state const & s) -> optional { + script_state S = to_script_state(L); + optional r; + S.exec_protected([&]() { + ref.push(); // push user-fun on the stack + push_environment(L, env); + push_io_state(L, ios); + push_proof_state(L, s); + pcall(L, 3, 1, 0); + if (is_proof_state(L, -1)) { + r = to_proof_state(L, -1); + } + lua_pop(L, 1); + }); + return r; + })); +} + +static int mk_id_tactic(lua_State * L) { return push_tactic(L, id_tactic()); } +static int mk_now_tactic(lua_State * L) { return push_tactic(L, now_tactic()); } +static int mk_fail_tactic(lua_State * L) { return push_tactic(L, fail_tactic()); } +static int mk_trace_tactic(lua_State * L) { return push_tactic(L, trace_tactic(luaL_checkstring(L, 1))); } +static int mk_assumption_tactic(lua_State * L) { return push_tactic(L, assumption_tactic()); } + +static const struct luaL_Reg tactic_m[] = { + {"__gc", tactic_gc}, // never throws + {"__call", safe_function}, + {"__concat", safe_function}, + {"__pow", safe_function}, + {"__add", safe_function}, + {"then", safe_function}, + {"orelse", safe_function}, + {"append", safe_function}, + {"solve", safe_function}, + {"par", safe_function}, + {"determ", safe_function}, + {"repeat", safe_function}, + {"repeat1", safe_function}, + {"repeat_at_most", safe_function}, + {"take", safe_function}, + {"suppress_trace", safe_function}, + {"try_for", safe_function}, + {0, 0} +}; + +void open_tactic(lua_State * L) { + luaL_newmetatable(L, proof_state_seq_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, proof_state_seq_m, 0); + SET_GLOBAL_FUN(proof_state_seq_pred, "is_proof_state_seq"); + + luaL_newmetatable(L, tactic_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, tactic_m, 0); + + SET_GLOBAL_FUN(tactic_pred, "is_tactic"); + SET_GLOBAL_FUN(mk_trace_tactic, "trace_tactic"); + SET_GLOBAL_FUN(mk_id_tactic, "id_tactic"); + SET_GLOBAL_FUN(mk_now_tactic, "now_tactic"); + SET_GLOBAL_FUN(mk_fail_tactic, "fail_tactic"); + SET_GLOBAL_FUN(mk_assumption_tactic, "assumption_tactic"); + SET_GLOBAL_FUN(mk_assumption_tactic, "assump_tactic"); + SET_GLOBAL_FUN(mk_lua_tactic01, "tactic"); +} } diff --git a/src/library/tactic/tactic.h b/src/library/tactic/tactic.h index 027c73aa3..0c4d9b118 100644 --- a/src/library/tactic/tactic.h +++ b/src/library/tactic/tactic.h @@ -36,7 +36,7 @@ public: } }; -enum class solve_result_kind { Proof, Counterexample, Failure }; +enum class solve_result_kind { None, Proof, Counterexample, Failure }; /** \brief Result for the solve method in the tactic class. The result may be a proof, a counterexample, or a list of unsolved proof_states. @@ -48,12 +48,17 @@ class solve_result { counterexample m_cex; list m_failures; }; + void init(solve_result const & r); + void destroy(); public: + solve_result():m_kind(solve_result_kind::None) {} solve_result(expr const & pr); solve_result(counterexample const & cex); solve_result(list const & fs); solve_result(solve_result const & r); ~solve_result(); + solve_result & operator=(solve_result & other); + solve_result & operator=(solve_result && other); solve_result_kind kind() const { return m_kind; } expr get_proof() const { lean_assert(kind() == solve_result_kind::Proof); return m_proof; } counterexample get_cex() const { lean_assert(kind() == solve_result_kind::Counterexample); return m_cex; } @@ -218,6 +223,7 @@ tactic interleave(tactic const & t1, tactic const & t2); threads finished. */ tactic par(tactic const & t1, tactic const & t2, unsigned check_ms = 1); + /** \brief Return a tactic that keeps applying \c t until it fails. */ @@ -264,4 +270,8 @@ tactic cond(P && p, tactic const & t1, tactic const & t2) { */ template tactic when(P && p, tactic const & t) { return cond(std::forward

(p), t, id_tactic()); } + +UDATA_DEFS_CORE(proof_state_seq) +UDATA_DEFS(tactic); +void open_tactic(lua_State * L); } diff --git a/src/util/script_state.h b/src/util/script_state.h index 8bd14a47e..d77afdcfd 100644 --- a/src/util/script_state.h +++ b/src/util/script_state.h @@ -63,6 +63,12 @@ public: unlock_guard unlock(get_mutex()); f(); } + + template + void exec_protected(F && f) { + std::lock_guard lock(get_mutex()); + f(); + } }; /** \brief Return a reference to the script_state object that is wrapping \c L. diff --git a/tests/lua/tactic1.lua b/tests/lua/tactic1.lua new file mode 100644 index 000000000..832195bb8 --- /dev/null +++ b/tests/lua/tactic1.lua @@ -0,0 +1,37 @@ +local ps = proof_state() +local env = environment() +local Bool = Const("Bool") +env:add_var("p", Bool) +env:add_var("q", Bool) +local p, q = Consts("p, q") +local ctx = context() +ctx = ctx:extend("H1", p) +ctx = ctx:extend("H2", q) +ps = to_proof_state(env, ctx, p) +local ios = io_state() +print(ps) +local ltac = tactic(function(env, ios, s) + print("FIRST tactic in Lua, current state: " .. tostring(s)); + return s +end) +local t = (trace_tactic("hello") .. trace_tactic("world")) + (trace_tactic("again") .. ltac .. assumption_tactic()) +for s in t(env, ios, ps) do + if s:is_proof_final_state() then + local m = proof_map() + local a = assignment(s:get_menv()) + print(s:proof_builder()(m, a)) + else + print(s) + end +end +print("-------------------") +print(t:solve(env, ios, ps)) +print(t:solve(env, ios, ctx, p)) +assert(t:solve(env, ios, ps) == Var(0)) +assert(t:solve(env, ios, ctx, q) == Var(1)) +local t2 = id_tactic() + id_tactic() + id_tactic() +local r = t2:solve(env, ios, ps) +assert(#r == 3) +for i, out_state in ipairs(r) do + print(i, out_state) +end