diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index ae39998a5..0a5188558 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -1076,15 +1076,11 @@ class parser::imp { name tac_name = check_identifier_next("invalid apply command, identifier or 'script-block' expected"); m_script_state->apply([&](lua_State * L) { lua_getglobal(L, tac_name.to_string().c_str()); - if (lua_type(L, -1) != LUA_TFUNCTION && !is_tactic(L, -1)) + try { + t = to_tactic_ext(L, -1); + } catch (...) { throw parser_error(sstream() << "unknown tactic '" << tac_name << "'", tac_pos); - if (lua_type(L, -1) == LUA_TFUNCTION) { - pcall(L, 0, 1, 0); - if (!is_tactic(L, -1)) - throw parser_error(sstream() << "invalid function '" << tac_name << "', it does not return a tactic", - tac_pos); } - t = to_tactic(L, -1); lua_pop(L, 1); }); } diff --git a/src/library/tactic/tactic.cpp b/src/library/tactic/tactic.cpp index f8a8033e9..c4e83ff23 100644 --- a/src/library/tactic/tactic.cpp +++ b/src/library/tactic/tactic.cpp @@ -290,6 +290,39 @@ static int push_proof_state_seq_it(lua_State * L, proof_state_seq const & seq) { DECL_UDATA(tactic) +[[ noreturn ]] void throw_tactic_expected(int i) { + throw exception(sstream() << "arg #" << i << " must be a tactic or a function that returns a tactic"); +} + +/** + \brief We allow functions (that return tactics) to be used where a tactic + is expected. The idea is to be able to write + ORELSE(assumption_tactic, conj_tactic) + instead of + ORELSE(assumption_tactic(), conj_tactic()) +*/ +tactic to_tactic_ext(lua_State * L, int i) { + if (is_tactic(L, i)) { + return to_tactic(L, i); + } else if (lua_isfunction(L, i)) { + try { + lua_pushvalue(L, i); + pcall(L, 0, 1, 0); + } catch (...) { + throw_tactic_expected(i); + } + if (is_tactic(L, -1)) { + tactic t = to_tactic(L, -1); + lua_pop(L, 1); + return t; + } else { + throw_tactic_expected(i); + } + } else { + throw_tactic_expected(i); + } +} + static void check_ios(io_state * ios) { if (!ios) throw exception("failed to invoke tactic, io_state is not available"); @@ -306,7 +339,7 @@ static int tactic_call_core(lua_State * L, tactic t, environment env, io_state i static int tactic_call(lua_State * L) { int nargs = lua_gettop(L); - tactic & t = to_tactic(L, 1); + tactic t = to_tactic_ext(L, 1); ro_environment env(L, 2); if (nargs == 3) { io_state * ios = get_io_state(L); @@ -324,27 +357,27 @@ static int nary_tactic(lua_State * L) { int nargs = lua_gettop(L); if (nargs < 2) throw exception("tactical expects at least two arguments"); - tactic r = F(to_tactic(L, 1), to_tactic(L, 2)); + tactic r = F(to_tactic_ext(L, 1), to_tactic_ext(L, 2)); for (int i = 3; i <= nargs; i++) - r = F(r, to_tactic(L, i)); + r = F(r, to_tactic_ext(L, i)); return push_tactic(L, r); } -static int tactic_then(lua_State * L) { return push_tactic(L, then(to_tactic(L, 1), to_tactic(L, 2))); } -static int tactic_orelse(lua_State * L) { return push_tactic(L, orelse(to_tactic(L, 1), to_tactic(L, 2))); } -static int tactic_append(lua_State * L) { return push_tactic(L, append(to_tactic(L, 1), to_tactic(L, 2))); } -static int tactic_interleave(lua_State * L) { return push_tactic(L, interleave(to_tactic(L, 1), to_tactic(L, 2))); } -static int tactic_par(lua_State * L) { return push_tactic(L, par(to_tactic(L, 1), to_tactic(L, 2))); } +static int tactic_then(lua_State * L) { return push_tactic(L, then(to_tactic_ext(L, 1), to_tactic_ext(L, 2))); } +static int tactic_orelse(lua_State * L) { return push_tactic(L, orelse(to_tactic_ext(L, 1), to_tactic_ext(L, 2))); } +static int tactic_append(lua_State * L) { return push_tactic(L, append(to_tactic_ext(L, 1), to_tactic_ext(L, 2))); } +static int tactic_interleave(lua_State * L) { return push_tactic(L, interleave(to_tactic_ext(L, 1), to_tactic_ext(L, 2))); } +static int tactic_par(lua_State * L) { return push_tactic(L, par(to_tactic_ext(L, 1), to_tactic_ext(L, 2))); } -static int tactic_repeat(lua_State * L) { return push_tactic(L, repeat(to_tactic(L, 1))); } -static int tactic_repeat1(lua_State * L) { return push_tactic(L, repeat1(to_tactic(L, 1))); } -static int tactic_repeat_at_most(lua_State * L) { return push_tactic(L, repeat_at_most(to_tactic(L, 1), luaL_checkinteger(L, 2))); } -static int tactic_take(lua_State * L) { return push_tactic(L, take(to_tactic(L, 1), luaL_checkinteger(L, 2))); } -static int tactic_determ(lua_State * L) { return push_tactic(L, determ(to_tactic(L, 1))); } -static int tactic_suppress_trace(lua_State * L) { return push_tactic(L, suppress_trace(to_tactic(L, 1))); } -static int tactic_try_for(lua_State * L) { return push_tactic(L, try_for(to_tactic(L, 1), luaL_checkinteger(L, 2))); } -static int tactic_using_params(lua_State * L) { return push_tactic(L, using_params(to_tactic(L, 1), to_options(L, 2))); } -static int tactic_try(lua_State * L) { return push_tactic(L, orelse(to_tactic(L, 1), id_tactic())); } +static int tactic_repeat(lua_State * L) { return push_tactic(L, repeat(to_tactic_ext(L, 1))); } +static int tactic_repeat1(lua_State * L) { return push_tactic(L, repeat1(to_tactic_ext(L, 1))); } +static int tactic_repeat_at_most(lua_State * L) { return push_tactic(L, repeat_at_most(to_tactic_ext(L, 1), luaL_checkinteger(L, 2))); } +static int tactic_take(lua_State * L) { return push_tactic(L, take(to_tactic_ext(L, 1), luaL_checkinteger(L, 2))); } +static int tactic_determ(lua_State * L) { return push_tactic(L, determ(to_tactic_ext(L, 1))); } +static int tactic_suppress_trace(lua_State * L) { return push_tactic(L, suppress_trace(to_tactic_ext(L, 1))); } +static int tactic_try_for(lua_State * L) { return push_tactic(L, try_for(to_tactic_ext(L, 1), luaL_checkinteger(L, 2))); } +static int tactic_using_params(lua_State * L) { return push_tactic(L, using_params(to_tactic_ext(L, 1), to_options(L, 2))); } +static int tactic_try(lua_State * L) { return push_tactic(L, orelse(to_tactic_ext(L, 1), id_tactic())); } static int push_solve_result(lua_State * L, solve_result const & r) { switch (r.kind()) { @@ -383,7 +416,7 @@ static int tactic_solve_core(lua_State * L, tactic t, environment env, io_state static int tactic_solve(lua_State * L) { int nargs = lua_gettop(L); - tactic & t = to_tactic(L, 1); + tactic t = to_tactic_ext(L, 1); ro_environment env(L, 2); if (nargs == 3) { io_state * ios = get_io_state(L); @@ -473,11 +506,11 @@ static int mk_lua_cond_tactic(lua_State * L, tactic t1, tactic t2) { } static int mk_lua_cond_tactic(lua_State * L) { - return mk_lua_cond_tactic(L, to_tactic(L, 2), to_tactic(L, 3)); + return mk_lua_cond_tactic(L, to_tactic_ext(L, 2), to_tactic_ext(L, 3)); } static int mk_lua_when_tactic(lua_State * L) { - return mk_lua_cond_tactic(L, to_tactic(L, 2), id_tactic()); + return mk_lua_cond_tactic(L, to_tactic_ext(L, 2), id_tactic()); } static int mk_id_tactic(lua_State * L) { return push_tactic(L, id_tactic()); } diff --git a/src/library/tactic/tactic.h b/src/library/tactic/tactic.h index d167b5252..3e3fb5200 100644 --- a/src/library/tactic/tactic.h +++ b/src/library/tactic/tactic.h @@ -274,5 +274,6 @@ tactic when(P && p, tactic const & t) { return cond(std::forward
(p), t, id_ta UDATA_DEFS_CORE(proof_state_seq) UDATA_DEFS(tactic); +tactic to_tactic_ext(lua_State * L, int i); void open_tactic(lua_State * L); } diff --git a/tests/lean/tactic5.lean b/tests/lean/tactic5.lean new file mode 100644 index 000000000..5e0538385 --- /dev/null +++ b/tests/lean/tactic5.lean @@ -0,0 +1,9 @@ +(** +simple_tac = REPEAT(ORELSE(imp_tactic, conj_tactic)) .. assumption_tactic +**) + +Theorem T4 (a b : Bool) : a => b => a /\ b /\ a := _. + apply simple_tac + done + +Show Environment 1. \ No newline at end of file diff --git a/tests/lean/tactic5.lean.expected.out b/tests/lean/tactic5.lean.expected.out new file mode 100644 index 000000000..da29a8912 --- /dev/null +++ b/tests/lean/tactic5.lean.expected.out @@ -0,0 +1,5 @@ + Set: pp::colors + Set: pp::unicode + Proved: T4 +Theorem T4 (a b : Bool) : a ⇒ b ⇒ a ∧ b ∧ a := + Discharge (λ H : a, Discharge (λ H::1 : b, Conj H (Conj H::1 H)))