From 8bccfb947ac40d745a0dfaf75f5e1ae54fa50828 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 27 Jan 2014 15:02:05 -0800 Subject: [PATCH] feat(library/simplifier): expose simplier and simplifier_monitor objects in the Lua API Signed-off-by: Leonardo de Moura --- src/library/simplifier/simplifier.cpp | 357 ++++++++++++++++++++++---- src/library/simplifier/simplifier.h | 71 +++-- tests/lean/simp11.lean | 4 +- tests/lean/simp12.lean | 2 +- tests/lean/simp15.lean | 2 +- tests/lean/simp24.lean | 6 +- tests/lean/simp3.lean | 4 +- tests/lean/simp31.lean | 29 +++ tests/lean/simp31.lean.expected.out | 42 +++ 9 files changed, 441 insertions(+), 76 deletions(-) create mode 100644 tests/lean/simp31.lean create mode 100644 tests/lean/simp31.lean.expected.out diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 48f208225..005a6ec29 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -9,6 +9,8 @@ Author: Leonardo de Moura #include "util/flet.h" #include "util/freset.h" #include "util/interrupt.h" +#include "util/luaref.h" +#include "util/script_state.h" #include "kernel/type_checker.h" #include "kernel/free_vars.h" #include "kernel/instantiate.h" @@ -106,7 +108,9 @@ static name g_H("H"); static name g_x("x"); static name g_unique = name::mk_internal_unique_name(); -class simplifier_fn { +class simplifier_cell::imp { + friend class simplifier_cell; + friend class simplifier; struct result { expr m_expr; // the result of a simplification step optional m_proof; // a proof that the result is equal to the input (when m_proofs_enabled) @@ -126,7 +130,9 @@ class simplifier_fn { typedef expr_map cache; typedef std::vector congr_thms; typedef cache const_map; + std::weak_ptr m_this; ro_environment m_env; + options m_options; type_checker m_tc; bool m_has_heq; bool m_has_cast; @@ -154,7 +160,7 @@ class simplifier_fn { unsigned m_max_steps; struct updt_rule_set { - simplifier_fn & m_fn; + imp & m_fn; rewrite_rule_set m_old; freset m_reset_cache; // must reset the cache whenever we update the rule set. /** @@ -162,7 +168,7 @@ class simplifier_fn { \pre const_type(H) */ - updt_rule_set(simplifier_fn & fn, expr const & H): + updt_rule_set(imp & fn, expr const & H): m_fn(fn), m_old(m_fn.m_rule_sets[0]), m_reset_cache(m_fn.m_cache) { lean_assert(const_type(H)); m_fn.m_rule_sets[0].insert(g_local, *const_type(H), H); @@ -174,9 +180,9 @@ class simplifier_fn { }; struct updt_const_map { - simplifier_fn & m_fn; - expr const & m_old_x; - updt_const_map(simplifier_fn & fn, expr const & old_x, expr const & new_x, expr const & H): + imp & m_fn; + expr const & m_old_x; + updt_const_map(imp & fn, expr const & old_x, expr const & new_x, expr const & H): m_fn(fn), m_old_x(old_x) { m_fn.m_const_map[old_x] = result(new_x, H, true); } @@ -304,7 +310,7 @@ class simplifier_fn { if (is_TypeU(A)) { if (!is_definitionally_equal(f, new_f)) { if (m_monitor) - m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported); + m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported); return none_expr(); // can't handle } // The congruence axiom cannot be used in this case. @@ -324,7 +330,7 @@ class simplifier_fn { expr a_type = infer_type(a); if (!is_convertible(a_type, A)) { if (m_monitor) - m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch); + m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch); return none_expr(); // can't handle } expr a_prime = new_a.m_expr; @@ -347,7 +353,7 @@ class simplifier_fn { expr new_a_type = infer_type(new_a.m_expr); if (!is_convertible(new_a_type, new_A)) { if (m_monitor) - m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch); + m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch); return none_expr(); // failed } expr Heq_a = get_proof(new_a); @@ -359,7 +365,7 @@ class simplifier_fn { is_heq_proof = false; } else { if (m_monitor) - m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported); + m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported); return none_expr(); // we don't know how to handle this case } } @@ -867,7 +873,7 @@ class simplifier_fn { } } if (m_monitor) - m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq()); + m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id()); return true; } else { // Conditional rewriting: we try to fill the missing @@ -906,19 +912,22 @@ class simplifier_fn { // but proof generation is not enabled. // So, we should fail if (m_monitor) - m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::Unsupported); + m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(), + i, simplifier_monitor::failure_kind::Unsupported); return false; } } else { // failed to prove proposition if (m_monitor) - m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::AssumptionNotProved); + m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(), + i, simplifier_monitor::failure_kind::AssumptionNotProved); return false; } } else { // failed, the argument is not a proposition if (m_monitor) - m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::MissingArgument); + m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(), + i, simplifier_monitor::failure_kind::MissingArgument); return false; } } @@ -927,11 +936,12 @@ class simplifier_fn { new_rhs = arg(ceq, num_args(ceq) - 1); if (rule.is_permutation() && !is_lt(new_rhs, target, false)) { if (m_monitor) - m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), 0, simplifier_monitor::failure_kind::PermutationGe); + m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(), + 0, simplifier_monitor::failure_kind::PermutationGe); return false; } if (m_monitor) - m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq()); + m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id()); return true; } } @@ -1099,7 +1109,7 @@ class simplifier_fn { if (occurs(x_old, new_bi)) { // failed, simplifier didn't manage to replace x_old with x_new if (m_monitor) - m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody); + m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody); return rewrite(e, result(e)); } expr new_e = update_lambda(e, new_d, abstract(new_bi, x_new)); @@ -1202,7 +1212,7 @@ class simplifier_fn { expr e_type = infer_type(e); if (is_TypeU(e_type) || !ensure_homogeneous(A, res_A)) { if (m_monitor) - m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch); + m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch); return result(e); // failed, we can't use subst theorem } else { expr H = get_proof(res_A); @@ -1237,7 +1247,7 @@ class simplifier_fn { expr e_type = infer_type(e); if (is_TypeU(e_type) || !ensure_homogeneous(B, res_B)) { if (m_monitor) - m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch); + m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch); return result(e); // failed, we can't use subst theorem } else { expr H = get_proof(res_B); @@ -1318,13 +1328,13 @@ class simplifier_fn { if (occurs(x_old, new_bi)) { // failed, simplifier didn't manage to replace x_old with x_new if (m_monitor) - m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody); + m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody); return rewrite(e, result(e)); } expr new_e = update_pi(e, new_d, abstract(new_bi, x_new)); if (!m_proofs_enabled || is_definitionally_equal(e, new_e)) { if (m_monitor) - m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch); + m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch); return rewrite(e, result(new_e)); } ensure_homogeneous(d, res_d); @@ -1364,7 +1374,7 @@ class simplifier_fn { } else { // We currently do simplify (forall x : A, B x) when it is not a proposition. if (m_monitor) - m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::Unsupported); + m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::Unsupported); return result(e); } } @@ -1374,7 +1384,7 @@ class simplifier_fn { result new_r = r.update_expr(m_max_sharing(r.m_expr)); m_cache.insert(mk_pair(e, new_r)); if (m_monitor) - m_monitor->step_eh(m_depth, e, new_r.m_expr, new_r.m_proof); + m_monitor->step_eh(ro_simplifier(m_this), e, new_r.m_expr, new_r.m_proof); return new_r; } else { return r; @@ -1395,7 +1405,7 @@ class simplifier_fn { } } if (m_monitor) - m_monitor->pre_eh(m_depth, e); + m_monitor->pre_eh(ro_simplifier(m_this), e); switch (e.kind()) { case expr_kind::Var: return result(e); case expr_kind::Constant: return save(e, simplify_constant(e)); @@ -1445,9 +1455,9 @@ class simplifier_fn { } public: - simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, - std::shared_ptr const & monitor): - m_env(env), m_tc(env), m_monitor(monitor) { + imp(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, + std::shared_ptr const & monitor): + m_env(env), m_options(o), m_tc(env), m_monitor(monitor) { m_has_heq = m_env->imported("heq"); m_has_cast = m_env->imported("cast"); set_options(o); @@ -1469,10 +1479,38 @@ public: } }; +simplifier_cell::simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, + std::shared_ptr const & monitor): + m_ptr(new imp(env, o, num_rs, rs, monitor)) { +} + +expr_pair simplifier_cell::operator()(expr const & e, context const & ctx) { return m_ptr->operator()(e, ctx); } +void simplifier_cell::clear() { return m_ptr->m_cache.clear(); } +unsigned simplifier_cell::get_depth() const { return m_ptr->m_depth; } +context const & simplifier_cell::get_context() const { return m_ptr->m_ctx; } +ro_environment const & simplifier_cell::get_environment() const { return m_ptr->m_env; } +options const & simplifier_cell::get_options() const { return m_ptr->m_options; } + +simplifier::simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, + std::shared_ptr const & monitor): + m_ptr(std::make_shared(env, o, num_rs, rs, monitor)) { + m_ptr->m_ptr->m_this = m_ptr; +} + +ro_simplifier::ro_simplifier(simplifier const & env): + m_ptr(env.m_ptr) { +} + +ro_simplifier::ro_simplifier(weak_ref const & r) { + if (r.expired()) + throw exception("weak reference to simplifier object has expired (i.e., the simplifier has been deleted)"); + m_ptr = r.lock(); +} + expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, unsigned num_rs, rewrite_rule_set const * rs, std::shared_ptr const & monitor) { - return simplifier_fn(env, opts, num_rs, rs, monitor)(e, ctx); + return simplifier(env, opts, num_rs, rs, monitor)(e, ctx); } expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, @@ -1484,32 +1522,222 @@ expr_pair simplify(expr const & e, ro_environment const & env, context const & c return simplify(e, env, ctx, opts, num_ns, rules.data(), monitor); } +DECL_UDATA(simplifier) +DECL_UDATA(ro_simplifier) + +/** + \brief Simplifier monitor implemented using Lua functions +*/ +class lua_simplifier_monitor : public simplifier_monitor { + optional m_pre_eh; + optional m_step_eh; + optional m_rewrite_eh; + optional m_failed_app_eh; + optional m_failed_rewrite_eh; + optional m_failed_abstraction_eh; +public: + lua_simplifier_monitor(optional const & pre_eh, optional const & step_eh, optional const & rewrite_eh, + optional const & failed_app_eh, optional const & failed_rewrite_eh, optional const & failed_abstraction_eh): + m_pre_eh(pre_eh), m_step_eh(step_eh), m_rewrite_eh(rewrite_eh), + m_failed_app_eh(failed_app_eh), m_failed_rewrite_eh(failed_rewrite_eh), m_failed_abstraction_eh(failed_abstraction_eh) { + } + virtual ~lua_simplifier_monitor() {} + + virtual void pre_eh(ro_simplifier const & s, expr const & e) { + if (m_pre_eh) { + lua_State * L = m_pre_eh->get_state(); + m_pre_eh->push(); + push_ro_simplifier(L, s); + push_expr(L, e); + pcall(L, 2, 0, 0); + } + } + + virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional const & pr) { + if (m_step_eh) { + lua_State * L = m_step_eh->get_state(); + m_step_eh->push(); + push_ro_simplifier(L, s); + push_expr(L, e); + push_expr(L, new_e); + push_optional_expr(L, pr); + pcall(L, 4, 0, 0); + } + } + + virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) { + if (m_rewrite_eh) { + lua_State * L = m_rewrite_eh->get_state(); + m_rewrite_eh->push(); + push_ro_simplifier(L, s); + push_expr(L, e); + push_expr(L, new_e); + push_expr(L, ceq); + push_name(L, ceq_id); + pcall(L, 5, 0, 0); + } + } + + virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) { + if (m_failed_app_eh) { + lua_State * L = m_failed_app_eh->get_state(); + m_failed_app_eh->push(); + push_ro_simplifier(L, s); + push_expr(L, e); + lua_pushinteger(L, i); + lua_pushinteger(L, static_cast(k)); + pcall(L, 4, 0, 0); + } + } + + virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) { + if (m_failed_rewrite_eh) { + lua_State * L = m_failed_rewrite_eh->get_state(); + m_failed_rewrite_eh->push(); + push_ro_simplifier(L, s); + push_expr(L, e); + push_expr(L, ceq); + push_name(L, ceq_id); + lua_pushinteger(L, i); + lua_pushinteger(L, static_cast(k)); + pcall(L, 6, 0, 0); + } + } + + virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) { + if (m_failed_abstraction_eh) { + lua_State * L = m_failed_abstraction_eh->get_state(); + m_failed_abstraction_eh->push(); + push_ro_simplifier(L, s); + push_expr(L, e); + lua_pushinteger(L, static_cast(k)); + pcall(L, 3, 0, 0); + } + } +}; + +typedef std::shared_ptr simplifier_monitor_ptr; + +DECL_UDATA(simplifier_monitor_ptr) + +static const struct luaL_Reg simplifier_monitor_ptr_m[] = { + {"__gc", simplifier_monitor_ptr_gc}, + {0, 0} +}; + +static optional get_opt_callback(lua_State * L, int i) { + if (i > lua_gettop(L) || lua_isnil(L, i)) { + return optional(); + } else { + luaL_checktype(L, i, LUA_TFUNCTION); // user-fun + return optional(luaref(L, i)); + } +} + +static int mk_simplifier_monitor(lua_State * L) { + simplifier_monitor_ptr r = std::make_shared(get_opt_callback(L, 1), + get_opt_callback(L, 2), + get_opt_callback(L, 3), + get_opt_callback(L, 4), + get_opt_callback(L, 5), + get_opt_callback(L, 6)); + return push_simplifier_monitor_ptr(L, r); +} + +/** + \brief Fill the the rewrite_rule_set \c rs using the object at position \c i in the Lua stack. +*/ +static void get_rewrite_rule_set(lua_State * L, int i, ro_environment const & env, buffer & rs) { + if (i > lua_gettop(L)) { + rs.push_back(get_rewrite_rule_set(env)); + } else if (lua_isstring(L, i)) { + rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, i))); + } else { + luaL_checktype(L, i, LUA_TTABLE); + name r; + int n = objlen(L, i); + for (int j = 1; j <= n; j++) { + lua_rawgeti(L, i, j); + rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1))); + lua_pop(L, 1); + } + } +} + +static int mk_simplifier(lua_State * L, ro_environment const & env) { + int nargs = lua_gettop(L); + buffer rules; + get_rewrite_rule_set(L, 1, env, rules); + options opts; + if (nargs >= 2) + opts = to_options(L, 2); + simplifier_monitor_ptr monitor; + if (nargs >= 3 && !lua_isnil(L, 3)) + monitor = to_simplifier_monitor_ptr(L, 3); + return push_simplifier(L, simplifier(env, opts, rules.size(), rules.data(), monitor)); +} + +static int mk_simplifier(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs <= 3) + return mk_simplifier(L, ro_shared_environment(L)); + else + return mk_simplifier(L, ro_shared_environment(L, 4)); +} + +static int simplifier_apply(lua_State * L) { + int nargs = lua_gettop(L); + expr_pair r; + if (nargs == 2) + r = to_simplifier(L, 1)(to_expr(L, 2), context()); + else + r = to_simplifier(L, 1)(to_expr(L, 2), to_context(L, 3)); + push_expr(L, r.first); + push_expr(L, r.second); + return 2; +} + +static int simplifier_clear(lua_State * L) { to_simplifier(L, 1)->clear(); return 0; } +static int simplifier_depth(lua_State * L) { lua_pushinteger(L, to_simplifier(L, 1)->get_depth()); return 1; } +static int simplifier_context(lua_State * L) { return push_context(L, to_simplifier(L, 1)->get_context()); } +static int simplifier_environment(lua_State * L) { return push_environment(L, to_simplifier(L, 1)->get_environment()); } +static int simplifier_options(lua_State * L) { return push_options(L, to_simplifier(L, 1)->get_options()); } +static int ro_simplifier_depth(lua_State * L) { lua_pushinteger(L, to_ro_simplifier(L, 1)->get_depth()); return 1; } +static int ro_simplifier_context(lua_State * L) { return push_context(L, to_ro_simplifier(L, 1)->get_context()); } +static int ro_simplifier_environment(lua_State * L) { return push_environment(L, to_ro_simplifier(L, 1)->get_environment()); } +static int ro_simplifier_options(lua_State * L) { return push_options(L, to_ro_simplifier(L, 1)->get_options()); } + +static const struct luaL_Reg simplifier_m[] = { + {"__gc", simplifier_gc}, + {"__call", safe_function}, + {"clear", safe_function}, + {"depth", safe_function}, + {"get_environment", safe_function}, + {"get_context", safe_function}, + {"get_options", safe_function}, + {0, 0} +}; + +static const struct luaL_Reg ro_simplifier_m[] = { + {"__gc", ro_simplifier_gc}, + {"depth", safe_function}, + {"get_environment", safe_function}, + {"get_context", safe_function}, + {"get_options", safe_function}, + {0, 0} +}; + static int simplify_core(lua_State * L, ro_shared_environment const & env) { int nargs = lua_gettop(L); expr const & e = to_expr(L, 1); buffer rules; - if (nargs == 1) { - rules.push_back(get_rewrite_rule_set(env)); - } else { - if (lua_isstring(L, 2)) { - rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, 2))); - } else { - luaL_checktype(L, 2, LUA_TTABLE); - name r; - int n = objlen(L, 2); - for (int i = 1; i <= n; i++) { - lua_rawgeti(L, 2, i); - rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1))); - lua_pop(L, 1); - } - } - } + get_rewrite_rule_set(L, 2, env, rules); context ctx; options opts; - if (nargs >= 4) - ctx = to_context(L, 4); + if (nargs >= 3) + opts = to_options(L, 3); if (nargs >= 5) - opts = to_options(L, 5); + ctx = to_context(L, 5); auto r = simplify(e, env, ctx, opts, rules.size(), rules.data()); push_expr(L, r.first); push_expr(L, r.second); @@ -1518,13 +1746,44 @@ static int simplify_core(lua_State * L, ro_shared_environment const & env) { static int simplify(lua_State * L) { int nargs = lua_gettop(L); - if (nargs <= 2) + if (nargs <= 4) return simplify_core(L, ro_shared_environment(L)); else - return simplify_core(L, ro_shared_environment(L, 3)); + return simplify_core(L, ro_shared_environment(L, 4)); } void open_simplifier(lua_State * L) { + luaL_newmetatable(L, simplifier_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, simplifier_m, 0); + SET_GLOBAL_FUN(simplifier_pred, "is_simplifier"); + + SET_GLOBAL_FUN(mk_simplifier, "simplifier"); + + luaL_newmetatable(L, ro_simplifier_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, ro_simplifier_m, 0); + SET_GLOBAL_FUN(ro_simplifier_pred, "is_ro_simplifier"); + + luaL_newmetatable(L, simplifier_monitor_ptr_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, simplifier_monitor_ptr_m, 0); + SET_GLOBAL_FUN(simplifier_monitor_ptr_pred, "is_simplifier_monitor"); + + SET_GLOBAL_FUN(mk_simplifier_monitor, "simplifier_monitor"); + + lua_newtable(L); + SET_ENUM("Unsupported", simplifier_monitor::failure_kind::Unsupported); + SET_ENUM("TypeMismatch", simplifier_monitor::failure_kind::TypeMismatch); + SET_ENUM("AssumptionNotProved", simplifier_monitor::failure_kind::AssumptionNotProved); + SET_ENUM("MissingArgument", simplifier_monitor::failure_kind::MissingArgument); + SET_ENUM("PermutationGe", simplifier_monitor::failure_kind::PermutationGe); + SET_ENUM("AbstractionBody", simplifier_monitor::failure_kind::AbstractionBody); + lua_setglobal(L, "simplifier_failure"); + SET_GLOBAL_FUN(simplify, "simplify"); } } diff --git a/src/library/simplifier/simplifier.h b/src/library/simplifier/simplifier.h index f4813f2f1..6df27aeb5 100644 --- a/src/library/simplifier/simplifier.h +++ b/src/library/simplifier/simplifier.h @@ -5,38 +5,79 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once +#include #include "util/lua.h" #include "kernel/environment.h" #include "library/expr_pair.h" #include "library/simplifier/rewrite_rule_set.h" namespace lean { +class simplifier_monitor; + +/** \brief Simplifier object cell. */ +class simplifier_cell { + friend class simplifier; + struct imp; + std::unique_ptr m_ptr; +public: + simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, + std::shared_ptr const & monitor); + + expr_pair operator()(expr const & e, context const & ctx); + void clear(); + + unsigned get_depth() const; + context const & get_context() const; + ro_environment const & get_environment() const; + options const & get_options() const; +}; + +/** \brief Reference to simplifier object */ +class simplifier { + friend class ro_simplifier; + std::shared_ptr m_ptr; +public: + simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, + std::shared_ptr const & monitor); + simplifier_cell * operator->() const { return m_ptr.get(); } + simplifier_cell & operator*() const { return *(m_ptr.get()); } + expr_pair operator()(expr const & e, context const & ctx) { return (*m_ptr)(e, ctx); } +}; + +/** \brief Read only reference to simplifier object */ +class ro_simplifier { + std::shared_ptr m_ptr; +public: + typedef std::weak_ptr weak_ref; + ro_simplifier(simplifier const & s); + ro_simplifier(weak_ref const & s); + explicit operator weak_ref() const { return weak_ref(m_ptr); } + weak_ref to_weak_ref() const { return weak_ref(m_ptr); } + simplifier_cell const * operator->() const { return m_ptr.get(); } + simplifier_cell const & operator*() const { return *(m_ptr.get()); } +}; + /** \brief Abstract class that specifies the interface for monitoring the behavior of the simplifier. */ class simplifier_monitor { public: + virtual ~simplifier_monitor() {} /** \brief This method is invoked to sign that the simplifier is starting to process the expression \c e. - - \remark \c depth is the recursion depth */ - virtual void pre_eh(unsigned depth, expr const & e) = 0; + virtual void pre_eh(ro_simplifier const & s, expr const & e) = 0; /** \brief This method is invoked to sign that \c e has be rewritten into \c new_e with proof \c pr. The proof is none if proof generation is disabled or if \c e and \c new_e are definitionally equal. - - \remark \c depth is the recursion depth */ - virtual void step_eh(unsigned depth, expr const & e, expr const & new_e, optional const & pr) = 0; + virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional const & pr) = 0; /** \brief This method is invoked to sign that \c e has be rewritten into \c new_e using the conditional equation \c ceq. - - \remark \c depth is the recursion depth */ - virtual void rewrite_eh(unsigned depth, expr const & e, expr const & new_e, expr const & ceq) = 0; + virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) = 0; enum class failure_kind { Unsupported, TypeMismatch, AssumptionNotProved, MissingArgument, PermutationGe, AbstractionBody }; @@ -44,10 +85,8 @@ public: \brief This method is invoked when the simplifier fails to rewrite an application \c e. \c i is the argument where the simplifier gave up, and \c k is the reason for failure. Two possible values are: Unsupported or TypeMismatch (may happen when simplifying terms that use dependent types). - - \remark \c depth is the recursion depth */ - virtual void failed_app_eh(unsigned depth, expr const & e, unsigned i, failure_kind k) = 0; + virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) = 0; /** \brief This method is invoked when the simplifier fails to apply a conditional equation \c ceq to \c e. @@ -55,19 +94,15 @@ public: The possible failure values are: AssumptionNotProved (failed to synthesize a proof for an assumption required by \c ceq), MissingArgument (failed to infer one of the arguments needed by the conditional equation), PermutationGe (the conditional equation is a permutation, and the result is not smaller in the term ordering, \c i is irrelevant in this case). - - \remark \c depth is the recursion depth */ - virtual void failed_rewrite_eh(unsigned depth, expr const & e, expr const & ceq, unsigned i, failure_kind k) = 0; + virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) = 0; /** \brief This method is invoked when the simplifier fails to simplify an abstraction (Pi or Lambda). The possible failure values are: Unsupported, TypeMismatch, and AbstractionBody (failed to rewrite the body of the abstraction, this may happen when we are using dependent types). - - \remark \c depth is the recursion depth */ - virtual void failed_abstraction_eh(unsigned depth, expr const & e, failure_kind k) = 0; + virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) = 0; }; expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts, diff --git a/tests/lean/simp11.lean b/tests/lean/simp11.lean index 41e46a4bb..fef2b0e9b 100644 --- a/tests/lean/simp11.lean +++ b/tests/lean/simp11.lean @@ -4,7 +4,7 @@ add_rewrite Nat::add_assoc Nat::add_comm Nat::add_left_comm Nat::distributer Nat (* local opts = options({"simplifier", "max_steps"}, 100) local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))") -local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +local t2, pr = simplify(t, "simple", opts) *) print "trying again with more steps" @@ -12,7 +12,7 @@ print "trying again with more steps" (* local opts = options({"simplifier", "max_steps"}, 100000) local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))") -local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +local t2, pr = simplify(t, "simple", opts) print(t) print("====>") print(t2) diff --git a/tests/lean/simp12.lean b/tests/lean/simp12.lean index 26becf6d0..f9f476314 100644 --- a/tests/lean/simp12.lean +++ b/tests/lean/simp12.lean @@ -6,7 +6,7 @@ local opts = options({"simplifier", "single_pass"}, true) local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))") print(t) for i = 1, 10 do - local t2, pr = simplify(t, "simple", get_environment(), context(), opts) + local t2, pr = simplify(t, "simple", opts) print("i: " .. i .. " ====>") print(t2) if t == t2 then diff --git a/tests/lean/simp15.lean b/tests/lean/simp15.lean index ef70d07e3..020401f6e 100644 --- a/tests/lean/simp15.lean +++ b/tests/lean/simp15.lean @@ -10,7 +10,7 @@ variables a b c : Nat (* local opts = options({"simplifier", "contextual"}, false) local t = parse_lean([[a = 1 ∧ (¬ b = 0 ∨ c ≠ 0 ∨ b + c > a)]]) -local s, pr = simplify(t, "simple", get_environment(), context(), opts) +local s, pr = simplify(t, "simple", opts) print(s) print(pr) print(get_environment():type_check(pr)) diff --git a/tests/lean/simp24.lean b/tests/lean/simp24.lean index d74dbbd6a..3962e035e 100644 --- a/tests/lean/simp24.lean +++ b/tests/lean/simp24.lean @@ -4,7 +4,7 @@ variables a b : Nat (* local opts = options({"simplifier", "contextual"}, false) local t = parse_lean('λ x, a = a → x = a') -local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +local t2, pr = simplify(t, "simple", opts) print(t2) print(pr) get_environment():type_check(pr) @@ -13,7 +13,7 @@ get_environment():type_check(pr) (* local opts = options({"simplifier", "contextual"}, false) local t = parse_lean('λ x, x = a → x = x') -local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +local t2, pr = simplify(t, "simple", opts) print(t2) print(pr) get_environment():type_check(pr) @@ -22,7 +22,7 @@ get_environment():type_check(pr) (* local opts = options({"simplifier", "contextual"}, false) local t = parse_lean('λ x, x = a + 0 → a = a') -local t2, pr = simplify(t, "simple", get_environment(), context(), opts) +local t2, pr = simplify(t, "simple", opts) print(t2) print(pr) *) diff --git a/tests/lean/simp3.lean b/tests/lean/simp3.lean index 089071fe9..991e48587 100644 --- a/tests/lean/simp3.lean +++ b/tests/lean/simp3.lean @@ -15,7 +15,7 @@ show(simplify(t4)) (* local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false) local t1 = parse_lean("double (double 2) + 1 ≥ 3") -show(simplify(t1, 'default', get_environment(), context(), opt)) +show(simplify(t1, 'default', opt)) *) set_opaque Nat::ge false @@ -26,7 +26,7 @@ add_rewrite Nat::distributel (* local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false) local t1 = parse_lean("2 * double (double 2) + 1 ≥ 3") -show(simplify(t1, 'default', get_environment(), context(), opt)) +show(simplify(t1, 'default', opt)) *) variables a b c d : Nat diff --git a/tests/lean/simp31.lean b/tests/lean/simp31.lean new file mode 100644 index 000000000..f3fc883ae --- /dev/null +++ b/tests/lean/simp31.lean @@ -0,0 +1,29 @@ +rewrite_set simple +add_rewrite Nat::add_comm Nat::add_left_comm Nat::add_assoc Nat::add_zeror : simple +variables a b c : Nat +(* + function indent(s) + for i = 1, s:depth()-1 do + io.write(" ") + end + end + local m = simplifier_monitor(function(s, e) + print("Visit, depth: " .. s:depth() .. ", " .. tostring(e)) + end, + function(s, e, new_e, pr) + print("Step: " .. tostring(e) .. " ===> " .. tostring(new_e)) + end, + function(s, e, new_e, ceq, ceq_id) + print("Rewrite using: " .. tostring(ceq_id)) + print(" " .. tostring(e) .. " ===> " .. tostring(new_e)) + end + ) + local s = simplifier("simple", options(), m) + local t = parse_lean('a + (b + 0) + a') + print(t) + print("=====>") + local t2, pr = s(t) + print(t2) + print(pr) + get_environment():type_check(pr) +*) diff --git a/tests/lean/simp31.lean.expected.out b/tests/lean/simp31.lean.expected.out new file mode 100644 index 000000000..0cf34d763 --- /dev/null +++ b/tests/lean/simp31.lean.expected.out @@ -0,0 +1,42 @@ + Set: pp::colors + Set: pp::unicode + Assumed: a + Assumed: b + Assumed: c +a + (b + 0) + a +=====> +Visit, depth: 1, a + (b + 0) + a +Visit, depth: 2, Nat::add +Visit, depth: 2, a + (b + 0) +Visit, depth: 3, Nat::add +Visit, depth: 3, a +Step: a ===> a +Visit, depth: 3, b + 0 +Visit, depth: 4, Nat::add +Visit, depth: 4, b +Step: b ===> b +Visit, depth: 4, 0 +Rewrite using: Nat::add_zeror + b + 0 ===> b +Step: b + 0 ===> b +Visit, depth: 3, a + b +Visit, depth: 4, Nat::add +Step: a + b ===> a + b +Step: a + (b + 0) ===> a + b +Rewrite using: Nat::add_assoc + a + b + a ===> a + (b + a) +Visit, depth: 2, a + (b + a) +Visit, depth: 3, Nat::add +Visit, depth: 3, b + a +Visit, depth: 4, Nat::add +Rewrite using: Nat::add_comm + b + a ===> a + b +Step: b + a ===> a + b +Visit, depth: 3, a + (a + b) +Visit, depth: 4, Nat::add +Step: a + (a + b) ===> a + (a + b) +Step: a + (b + a) ===> a + (a + b) +Step: a + (b + 0) + a ===> a + (a + b) +a + (a + b) +trans (trans (congr1 a (congr2 Nat::add (congr2 (Nat::add a) (Nat::add_zeror b)))) (Nat::add_assoc a b a)) + (congr2 (Nat::add a) (Nat::add_comm b a))