feat(library/simplifier): expose simplier and simplifier_monitor objects in the Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-27 15:02:05 -08:00
parent c088825ef0
commit 8bccfb947a
9 changed files with 441 additions and 76 deletions

View file

@ -9,6 +9,8 @@ Author: Leonardo de Moura
#include "util/flet.h"
#include "util/freset.h"
#include "util/interrupt.h"
#include "util/luaref.h"
#include "util/script_state.h"
#include "kernel/type_checker.h"
#include "kernel/free_vars.h"
#include "kernel/instantiate.h"
@ -106,7 +108,9 @@ static name g_H("H");
static name g_x("x");
static name g_unique = name::mk_internal_unique_name();
class simplifier_fn {
class simplifier_cell::imp {
friend class simplifier_cell;
friend class simplifier;
struct result {
expr m_expr; // the result of a simplification step
optional<expr> m_proof; // a proof that the result is equal to the input (when m_proofs_enabled)
@ -126,7 +130,9 @@ class simplifier_fn {
typedef expr_map<result> cache;
typedef std::vector<congr_theorem_info const *> congr_thms;
typedef cache const_map;
std::weak_ptr<simplifier_cell> m_this;
ro_environment m_env;
options m_options;
type_checker m_tc;
bool m_has_heq;
bool m_has_cast;
@ -154,7 +160,7 @@ class simplifier_fn {
unsigned m_max_steps;
struct updt_rule_set {
simplifier_fn & m_fn;
imp & m_fn;
rewrite_rule_set m_old;
freset<cache> m_reset_cache; // must reset the cache whenever we update the rule set.
/**
@ -162,7 +168,7 @@ class simplifier_fn {
\pre const_type(H)
*/
updt_rule_set(simplifier_fn & fn, expr const & H):
updt_rule_set(imp & fn, expr const & H):
m_fn(fn), m_old(m_fn.m_rule_sets[0]), m_reset_cache(m_fn.m_cache) {
lean_assert(const_type(H));
m_fn.m_rule_sets[0].insert(g_local, *const_type(H), H);
@ -174,9 +180,9 @@ class simplifier_fn {
};
struct updt_const_map {
simplifier_fn & m_fn;
expr const & m_old_x;
updt_const_map(simplifier_fn & fn, expr const & old_x, expr const & new_x, expr const & H):
imp & m_fn;
expr const & m_old_x;
updt_const_map(imp & fn, expr const & old_x, expr const & new_x, expr const & H):
m_fn(fn), m_old_x(old_x) {
m_fn.m_const_map[old_x] = result(new_x, H, true);
}
@ -304,7 +310,7 @@ class simplifier_fn {
if (is_TypeU(A)) {
if (!is_definitionally_equal(f, new_f)) {
if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported);
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported);
return none_expr(); // can't handle
}
// The congruence axiom cannot be used in this case.
@ -324,7 +330,7 @@ class simplifier_fn {
expr a_type = infer_type(a);
if (!is_convertible(a_type, A)) {
if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch);
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch);
return none_expr(); // can't handle
}
expr a_prime = new_a.m_expr;
@ -347,7 +353,7 @@ class simplifier_fn {
expr new_a_type = infer_type(new_a.m_expr);
if (!is_convertible(new_a_type, new_A)) {
if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch);
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch);
return none_expr(); // failed
}
expr Heq_a = get_proof(new_a);
@ -359,7 +365,7 @@ class simplifier_fn {
is_heq_proof = false;
} else {
if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported);
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported);
return none_expr(); // we don't know how to handle this case
}
}
@ -867,7 +873,7 @@ class simplifier_fn {
}
}
if (m_monitor)
m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq());
m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id());
return true;
} else {
// Conditional rewriting: we try to fill the missing
@ -906,19 +912,22 @@ class simplifier_fn {
// but proof generation is not enabled.
// So, we should fail
if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::Unsupported);
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
i, simplifier_monitor::failure_kind::Unsupported);
return false;
}
} else {
// failed to prove proposition
if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::AssumptionNotProved);
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
i, simplifier_monitor::failure_kind::AssumptionNotProved);
return false;
}
} else {
// failed, the argument is not a proposition
if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::MissingArgument);
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
i, simplifier_monitor::failure_kind::MissingArgument);
return false;
}
}
@ -927,11 +936,12 @@ class simplifier_fn {
new_rhs = arg(ceq, num_args(ceq) - 1);
if (rule.is_permutation() && !is_lt(new_rhs, target, false)) {
if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), 0, simplifier_monitor::failure_kind::PermutationGe);
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
0, simplifier_monitor::failure_kind::PermutationGe);
return false;
}
if (m_monitor)
m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq());
m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id());
return true;
}
}
@ -1099,7 +1109,7 @@ class simplifier_fn {
if (occurs(x_old, new_bi)) {
// failed, simplifier didn't manage to replace x_old with x_new
if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody);
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody);
return rewrite(e, result(e));
}
expr new_e = update_lambda(e, new_d, abstract(new_bi, x_new));
@ -1202,7 +1212,7 @@ class simplifier_fn {
expr e_type = infer_type(e);
if (is_TypeU(e_type) || !ensure_homogeneous(A, res_A)) {
if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch);
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
return result(e); // failed, we can't use subst theorem
} else {
expr H = get_proof(res_A);
@ -1237,7 +1247,7 @@ class simplifier_fn {
expr e_type = infer_type(e);
if (is_TypeU(e_type) || !ensure_homogeneous(B, res_B)) {
if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch);
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
return result(e); // failed, we can't use subst theorem
} else {
expr H = get_proof(res_B);
@ -1318,13 +1328,13 @@ class simplifier_fn {
if (occurs(x_old, new_bi)) {
// failed, simplifier didn't manage to replace x_old with x_new
if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody);
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody);
return rewrite(e, result(e));
}
expr new_e = update_pi(e, new_d, abstract(new_bi, x_new));
if (!m_proofs_enabled || is_definitionally_equal(e, new_e)) {
if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch);
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
return rewrite(e, result(new_e));
}
ensure_homogeneous(d, res_d);
@ -1364,7 +1374,7 @@ class simplifier_fn {
} else {
// We currently do simplify (forall x : A, B x) when it is not a proposition.
if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::Unsupported);
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::Unsupported);
return result(e);
}
}
@ -1374,7 +1384,7 @@ class simplifier_fn {
result new_r = r.update_expr(m_max_sharing(r.m_expr));
m_cache.insert(mk_pair(e, new_r));
if (m_monitor)
m_monitor->step_eh(m_depth, e, new_r.m_expr, new_r.m_proof);
m_monitor->step_eh(ro_simplifier(m_this), e, new_r.m_expr, new_r.m_proof);
return new_r;
} else {
return r;
@ -1395,7 +1405,7 @@ class simplifier_fn {
}
}
if (m_monitor)
m_monitor->pre_eh(m_depth, e);
m_monitor->pre_eh(ro_simplifier(m_this), e);
switch (e.kind()) {
case expr_kind::Var: return result(e);
case expr_kind::Constant: return save(e, simplify_constant(e));
@ -1445,9 +1455,9 @@ class simplifier_fn {
}
public:
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor):
m_env(env), m_tc(env), m_monitor(monitor) {
imp(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor):
m_env(env), m_options(o), m_tc(env), m_monitor(monitor) {
m_has_heq = m_env->imported("heq");
m_has_cast = m_env->imported("cast");
set_options(o);
@ -1469,10 +1479,38 @@ public:
}
};
simplifier_cell::simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor):
m_ptr(new imp(env, o, num_rs, rs, monitor)) {
}
expr_pair simplifier_cell::operator()(expr const & e, context const & ctx) { return m_ptr->operator()(e, ctx); }
void simplifier_cell::clear() { return m_ptr->m_cache.clear(); }
unsigned simplifier_cell::get_depth() const { return m_ptr->m_depth; }
context const & simplifier_cell::get_context() const { return m_ptr->m_ctx; }
ro_environment const & simplifier_cell::get_environment() const { return m_ptr->m_env; }
options const & simplifier_cell::get_options() const { return m_ptr->m_options; }
simplifier::simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor):
m_ptr(std::make_shared<simplifier_cell>(env, o, num_rs, rs, monitor)) {
m_ptr->m_ptr->m_this = m_ptr;
}
ro_simplifier::ro_simplifier(simplifier const & env):
m_ptr(env.m_ptr) {
}
ro_simplifier::ro_simplifier(weak_ref const & r) {
if (r.expired())
throw exception("weak reference to simplifier object has expired (i.e., the simplifier has been deleted)");
m_ptr = r.lock();
}
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor) {
return simplifier_fn(env, opts, num_rs, rs, monitor)(e, ctx);
return simplifier(env, opts, num_rs, rs, monitor)(e, ctx);
}
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
@ -1484,32 +1522,222 @@ expr_pair simplify(expr const & e, ro_environment const & env, context const & c
return simplify(e, env, ctx, opts, num_ns, rules.data(), monitor);
}
DECL_UDATA(simplifier)
DECL_UDATA(ro_simplifier)
/**
\brief Simplifier monitor implemented using Lua functions
*/
class lua_simplifier_monitor : public simplifier_monitor {
optional<luaref> m_pre_eh;
optional<luaref> m_step_eh;
optional<luaref> m_rewrite_eh;
optional<luaref> m_failed_app_eh;
optional<luaref> m_failed_rewrite_eh;
optional<luaref> m_failed_abstraction_eh;
public:
lua_simplifier_monitor(optional<luaref> const & pre_eh, optional<luaref> const & step_eh, optional<luaref> const & rewrite_eh,
optional<luaref> const & failed_app_eh, optional<luaref> const & failed_rewrite_eh, optional<luaref> const & failed_abstraction_eh):
m_pre_eh(pre_eh), m_step_eh(step_eh), m_rewrite_eh(rewrite_eh),
m_failed_app_eh(failed_app_eh), m_failed_rewrite_eh(failed_rewrite_eh), m_failed_abstraction_eh(failed_abstraction_eh) {
}
virtual ~lua_simplifier_monitor() {}
virtual void pre_eh(ro_simplifier const & s, expr const & e) {
if (m_pre_eh) {
lua_State * L = m_pre_eh->get_state();
m_pre_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
pcall(L, 2, 0, 0);
}
}
virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional<expr> const & pr) {
if (m_step_eh) {
lua_State * L = m_step_eh->get_state();
m_step_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
push_expr(L, new_e);
push_optional_expr(L, pr);
pcall(L, 4, 0, 0);
}
}
virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) {
if (m_rewrite_eh) {
lua_State * L = m_rewrite_eh->get_state();
m_rewrite_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
push_expr(L, new_e);
push_expr(L, ceq);
push_name(L, ceq_id);
pcall(L, 5, 0, 0);
}
}
virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) {
if (m_failed_app_eh) {
lua_State * L = m_failed_app_eh->get_state();
m_failed_app_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
lua_pushinteger(L, i);
lua_pushinteger(L, static_cast<unsigned>(k));
pcall(L, 4, 0, 0);
}
}
virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) {
if (m_failed_rewrite_eh) {
lua_State * L = m_failed_rewrite_eh->get_state();
m_failed_rewrite_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
push_expr(L, ceq);
push_name(L, ceq_id);
lua_pushinteger(L, i);
lua_pushinteger(L, static_cast<unsigned>(k));
pcall(L, 6, 0, 0);
}
}
virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) {
if (m_failed_abstraction_eh) {
lua_State * L = m_failed_abstraction_eh->get_state();
m_failed_abstraction_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
lua_pushinteger(L, static_cast<unsigned>(k));
pcall(L, 3, 0, 0);
}
}
};
typedef std::shared_ptr<simplifier_monitor> simplifier_monitor_ptr;
DECL_UDATA(simplifier_monitor_ptr)
static const struct luaL_Reg simplifier_monitor_ptr_m[] = {
{"__gc", simplifier_monitor_ptr_gc},
{0, 0}
};
static optional<luaref> get_opt_callback(lua_State * L, int i) {
if (i > lua_gettop(L) || lua_isnil(L, i)) {
return optional<luaref>();
} else {
luaL_checktype(L, i, LUA_TFUNCTION); // user-fun
return optional<luaref>(luaref(L, i));
}
}
static int mk_simplifier_monitor(lua_State * L) {
simplifier_monitor_ptr r = std::make_shared<lua_simplifier_monitor>(get_opt_callback(L, 1),
get_opt_callback(L, 2),
get_opt_callback(L, 3),
get_opt_callback(L, 4),
get_opt_callback(L, 5),
get_opt_callback(L, 6));
return push_simplifier_monitor_ptr(L, r);
}
/**
\brief Fill the the rewrite_rule_set \c rs using the object at position \c i in the Lua stack.
*/
static void get_rewrite_rule_set(lua_State * L, int i, ro_environment const & env, buffer<rewrite_rule_set> & rs) {
if (i > lua_gettop(L)) {
rs.push_back(get_rewrite_rule_set(env));
} else if (lua_isstring(L, i)) {
rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, i)));
} else {
luaL_checktype(L, i, LUA_TTABLE);
name r;
int n = objlen(L, i);
for (int j = 1; j <= n; j++) {
lua_rawgeti(L, i, j);
rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
lua_pop(L, 1);
}
}
}
static int mk_simplifier(lua_State * L, ro_environment const & env) {
int nargs = lua_gettop(L);
buffer<rewrite_rule_set> rules;
get_rewrite_rule_set(L, 1, env, rules);
options opts;
if (nargs >= 2)
opts = to_options(L, 2);
simplifier_monitor_ptr monitor;
if (nargs >= 3 && !lua_isnil(L, 3))
monitor = to_simplifier_monitor_ptr(L, 3);
return push_simplifier(L, simplifier(env, opts, rules.size(), rules.data(), monitor));
}
static int mk_simplifier(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs <= 3)
return mk_simplifier(L, ro_shared_environment(L));
else
return mk_simplifier(L, ro_shared_environment(L, 4));
}
static int simplifier_apply(lua_State * L) {
int nargs = lua_gettop(L);
expr_pair r;
if (nargs == 2)
r = to_simplifier(L, 1)(to_expr(L, 2), context());
else
r = to_simplifier(L, 1)(to_expr(L, 2), to_context(L, 3));
push_expr(L, r.first);
push_expr(L, r.second);
return 2;
}
static int simplifier_clear(lua_State * L) { to_simplifier(L, 1)->clear(); return 0; }
static int simplifier_depth(lua_State * L) { lua_pushinteger(L, to_simplifier(L, 1)->get_depth()); return 1; }
static int simplifier_context(lua_State * L) { return push_context(L, to_simplifier(L, 1)->get_context()); }
static int simplifier_environment(lua_State * L) { return push_environment(L, to_simplifier(L, 1)->get_environment()); }
static int simplifier_options(lua_State * L) { return push_options(L, to_simplifier(L, 1)->get_options()); }
static int ro_simplifier_depth(lua_State * L) { lua_pushinteger(L, to_ro_simplifier(L, 1)->get_depth()); return 1; }
static int ro_simplifier_context(lua_State * L) { return push_context(L, to_ro_simplifier(L, 1)->get_context()); }
static int ro_simplifier_environment(lua_State * L) { return push_environment(L, to_ro_simplifier(L, 1)->get_environment()); }
static int ro_simplifier_options(lua_State * L) { return push_options(L, to_ro_simplifier(L, 1)->get_options()); }
static const struct luaL_Reg simplifier_m[] = {
{"__gc", simplifier_gc},
{"__call", safe_function<simplifier_apply>},
{"clear", safe_function<simplifier_clear>},
{"depth", safe_function<simplifier_depth>},
{"get_environment", safe_function<simplifier_environment>},
{"get_context", safe_function<simplifier_context>},
{"get_options", safe_function<simplifier_options>},
{0, 0}
};
static const struct luaL_Reg ro_simplifier_m[] = {
{"__gc", ro_simplifier_gc},
{"depth", safe_function<ro_simplifier_depth>},
{"get_environment", safe_function<ro_simplifier_environment>},
{"get_context", safe_function<ro_simplifier_context>},
{"get_options", safe_function<ro_simplifier_options>},
{0, 0}
};
static int simplify_core(lua_State * L, ro_shared_environment const & env) {
int nargs = lua_gettop(L);
expr const & e = to_expr(L, 1);
buffer<rewrite_rule_set> rules;
if (nargs == 1) {
rules.push_back(get_rewrite_rule_set(env));
} else {
if (lua_isstring(L, 2)) {
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, 2)));
} else {
luaL_checktype(L, 2, LUA_TTABLE);
name r;
int n = objlen(L, 2);
for (int i = 1; i <= n; i++) {
lua_rawgeti(L, 2, i);
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
lua_pop(L, 1);
}
}
}
get_rewrite_rule_set(L, 2, env, rules);
context ctx;
options opts;
if (nargs >= 4)
ctx = to_context(L, 4);
if (nargs >= 3)
opts = to_options(L, 3);
if (nargs >= 5)
opts = to_options(L, 5);
ctx = to_context(L, 5);
auto r = simplify(e, env, ctx, opts, rules.size(), rules.data());
push_expr(L, r.first);
push_expr(L, r.second);
@ -1518,13 +1746,44 @@ static int simplify_core(lua_State * L, ro_shared_environment const & env) {
static int simplify(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs <= 2)
if (nargs <= 4)
return simplify_core(L, ro_shared_environment(L));
else
return simplify_core(L, ro_shared_environment(L, 3));
return simplify_core(L, ro_shared_environment(L, 4));
}
void open_simplifier(lua_State * L) {
luaL_newmetatable(L, simplifier_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, simplifier_m, 0);
SET_GLOBAL_FUN(simplifier_pred, "is_simplifier");
SET_GLOBAL_FUN(mk_simplifier, "simplifier");
luaL_newmetatable(L, ro_simplifier_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, ro_simplifier_m, 0);
SET_GLOBAL_FUN(ro_simplifier_pred, "is_ro_simplifier");
luaL_newmetatable(L, simplifier_monitor_ptr_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, simplifier_monitor_ptr_m, 0);
SET_GLOBAL_FUN(simplifier_monitor_ptr_pred, "is_simplifier_monitor");
SET_GLOBAL_FUN(mk_simplifier_monitor, "simplifier_monitor");
lua_newtable(L);
SET_ENUM("Unsupported", simplifier_monitor::failure_kind::Unsupported);
SET_ENUM("TypeMismatch", simplifier_monitor::failure_kind::TypeMismatch);
SET_ENUM("AssumptionNotProved", simplifier_monitor::failure_kind::AssumptionNotProved);
SET_ENUM("MissingArgument", simplifier_monitor::failure_kind::MissingArgument);
SET_ENUM("PermutationGe", simplifier_monitor::failure_kind::PermutationGe);
SET_ENUM("AbstractionBody", simplifier_monitor::failure_kind::AbstractionBody);
lua_setglobal(L, "simplifier_failure");
SET_GLOBAL_FUN(simplify, "simplify");
}
}

View file

@ -5,38 +5,79 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <memory>
#include "util/lua.h"
#include "kernel/environment.h"
#include "library/expr_pair.h"
#include "library/simplifier/rewrite_rule_set.h"
namespace lean {
class simplifier_monitor;
/** \brief Simplifier object cell. */
class simplifier_cell {
friend class simplifier;
struct imp;
std::unique_ptr<imp> m_ptr;
public:
simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor);
expr_pair operator()(expr const & e, context const & ctx);
void clear();
unsigned get_depth() const;
context const & get_context() const;
ro_environment const & get_environment() const;
options const & get_options() const;
};
/** \brief Reference to simplifier object */
class simplifier {
friend class ro_simplifier;
std::shared_ptr<simplifier_cell> m_ptr;
public:
simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor);
simplifier_cell * operator->() const { return m_ptr.get(); }
simplifier_cell & operator*() const { return *(m_ptr.get()); }
expr_pair operator()(expr const & e, context const & ctx) { return (*m_ptr)(e, ctx); }
};
/** \brief Read only reference to simplifier object */
class ro_simplifier {
std::shared_ptr<simplifier_cell const> m_ptr;
public:
typedef std::weak_ptr<simplifier_cell const> weak_ref;
ro_simplifier(simplifier const & s);
ro_simplifier(weak_ref const & s);
explicit operator weak_ref() const { return weak_ref(m_ptr); }
weak_ref to_weak_ref() const { return weak_ref(m_ptr); }
simplifier_cell const * operator->() const { return m_ptr.get(); }
simplifier_cell const & operator*() const { return *(m_ptr.get()); }
};
/**
\brief Abstract class that specifies the interface for monitoring
the behavior of the simplifier.
*/
class simplifier_monitor {
public:
virtual ~simplifier_monitor() {}
/**
\brief This method is invoked to sign that the simplifier is starting to process the expression \c e.
\remark \c depth is the recursion depth
*/
virtual void pre_eh(unsigned depth, expr const & e) = 0;
virtual void pre_eh(ro_simplifier const & s, expr const & e) = 0;
/**
\brief This method is invoked to sign that \c e has be rewritten into \c new_e with proof \c pr.
The proof is none if proof generation is disabled or if \c e and \c new_e are definitionally equal.
\remark \c depth is the recursion depth
*/
virtual void step_eh(unsigned depth, expr const & e, expr const & new_e, optional<expr> const & pr) = 0;
virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional<expr> const & pr) = 0;
/**
\brief This method is invoked to sign that \c e has be rewritten into \c new_e using the conditional equation \c ceq.
\remark \c depth is the recursion depth
*/
virtual void rewrite_eh(unsigned depth, expr const & e, expr const & new_e, expr const & ceq) = 0;
virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) = 0;
enum class failure_kind { Unsupported, TypeMismatch, AssumptionNotProved, MissingArgument, PermutationGe, AbstractionBody };
@ -44,10 +85,8 @@ public:
\brief This method is invoked when the simplifier fails to rewrite an application \c e.
\c i is the argument where the simplifier gave up, and \c k is the reason for failure.
Two possible values are: Unsupported or TypeMismatch (may happen when simplifying terms that use dependent types).
\remark \c depth is the recursion depth
*/
virtual void failed_app_eh(unsigned depth, expr const & e, unsigned i, failure_kind k) = 0;
virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) = 0;
/**
\brief This method is invoked when the simplifier fails to apply a conditional equation \c ceq to \c e.
@ -55,19 +94,15 @@ public:
The possible failure values are: AssumptionNotProved (failed to synthesize a proof for an assumption required by \c ceq),
MissingArgument (failed to infer one of the arguments needed by the conditional equation), PermutationGe
(the conditional equation is a permutation, and the result is not smaller in the term ordering, \c i is irrelevant in this case).
\remark \c depth is the recursion depth
*/
virtual void failed_rewrite_eh(unsigned depth, expr const & e, expr const & ceq, unsigned i, failure_kind k) = 0;
virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) = 0;
/**
\brief This method is invoked when the simplifier fails to simplify an abstraction (Pi or Lambda).
The possible failure values are: Unsupported, TypeMismatch, and AbstractionBody (failed to rewrite the body of the abstraction,
this may happen when we are using dependent types).
\remark \c depth is the recursion depth
*/
virtual void failed_abstraction_eh(unsigned depth, expr const & e, failure_kind k) = 0;
virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) = 0;
};
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts,

View file

@ -4,7 +4,7 @@ add_rewrite Nat::add_assoc Nat::add_comm Nat::add_left_comm Nat::distributer Nat
(*
local opts = options({"simplifier", "max_steps"}, 100)
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
local t2, pr = simplify(t, "simple", opts)
*)
print "trying again with more steps"
@ -12,7 +12,7 @@ print "trying again with more steps"
(*
local opts = options({"simplifier", "max_steps"}, 100000)
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
local t2, pr = simplify(t, "simple", opts)
print(t)
print("====>")
print(t2)

View file

@ -6,7 +6,7 @@ local opts = options({"simplifier", "single_pass"}, true)
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
print(t)
for i = 1, 10 do
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
local t2, pr = simplify(t, "simple", opts)
print("i: " .. i .. " ====>")
print(t2)
if t == t2 then

View file

@ -10,7 +10,7 @@ variables a b c : Nat
(*
local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean([[a = 1 ∧ (¬ b = 0 c ≠ 0 b + c > a)]])
local s, pr = simplify(t, "simple", get_environment(), context(), opts)
local s, pr = simplify(t, "simple", opts)
print(s)
print(pr)
print(get_environment():type_check(pr))

View file

@ -4,7 +4,7 @@ variables a b : Nat
(*
local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean('λ x, a = a → x = a')
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
local t2, pr = simplify(t, "simple", opts)
print(t2)
print(pr)
get_environment():type_check(pr)
@ -13,7 +13,7 @@ get_environment():type_check(pr)
(*
local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean('λ x, x = a → x = x')
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
local t2, pr = simplify(t, "simple", opts)
print(t2)
print(pr)
get_environment():type_check(pr)
@ -22,7 +22,7 @@ get_environment():type_check(pr)
(*
local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean('λ x, x = a + 0 → a = a')
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
local t2, pr = simplify(t, "simple", opts)
print(t2)
print(pr)
*)

View file

@ -15,7 +15,7 @@ show(simplify(t4))
(*
local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false)
local t1 = parse_lean("double (double 2) + 1 ≥ 3")
show(simplify(t1, 'default', get_environment(), context(), opt))
show(simplify(t1, 'default', opt))
*)
set_opaque Nat::ge false
@ -26,7 +26,7 @@ add_rewrite Nat::distributel
(*
local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false)
local t1 = parse_lean("2 * double (double 2) + 1 ≥ 3")
show(simplify(t1, 'default', get_environment(), context(), opt))
show(simplify(t1, 'default', opt))
*)
variables a b c d : Nat

29
tests/lean/simp31.lean Normal file
View file

@ -0,0 +1,29 @@
rewrite_set simple
add_rewrite Nat::add_comm Nat::add_left_comm Nat::add_assoc Nat::add_zeror : simple
variables a b c : Nat
(*
function indent(s)
for i = 1, s:depth()-1 do
io.write(" ")
end
end
local m = simplifier_monitor(function(s, e)
print("Visit, depth: " .. s:depth() .. ", " .. tostring(e))
end,
function(s, e, new_e, pr)
print("Step: " .. tostring(e) .. " ===> " .. tostring(new_e))
end,
function(s, e, new_e, ceq, ceq_id)
print("Rewrite using: " .. tostring(ceq_id))
print(" " .. tostring(e) .. " ===> " .. tostring(new_e))
end
)
local s = simplifier("simple", options(), m)
local t = parse_lean('a + (b + 0) + a')
print(t)
print("=====>")
local t2, pr = s(t)
print(t2)
print(pr)
get_environment():type_check(pr)
*)

View file

@ -0,0 +1,42 @@
Set: pp::colors
Set: pp::unicode
Assumed: a
Assumed: b
Assumed: c
a + (b + 0) + a
=====>
Visit, depth: 1, a + (b + 0) + a
Visit, depth: 2, Nat::add
Visit, depth: 2, a + (b + 0)
Visit, depth: 3, Nat::add
Visit, depth: 3, a
Step: a ===> a
Visit, depth: 3, b + 0
Visit, depth: 4, Nat::add
Visit, depth: 4, b
Step: b ===> b
Visit, depth: 4, 0
Rewrite using: Nat::add_zeror
b + 0 ===> b
Step: b + 0 ===> b
Visit, depth: 3, a + b
Visit, depth: 4, Nat::add
Step: a + b ===> a + b
Step: a + (b + 0) ===> a + b
Rewrite using: Nat::add_assoc
a + b + a ===> a + (b + a)
Visit, depth: 2, a + (b + a)
Visit, depth: 3, Nat::add
Visit, depth: 3, b + a
Visit, depth: 4, Nat::add
Rewrite using: Nat::add_comm
b + a ===> a + b
Step: b + a ===> a + b
Visit, depth: 3, a + (a + b)
Visit, depth: 4, Nat::add
Step: a + (a + b) ===> a + (a + b)
Step: a + (b + a) ===> a + (a + b)
Step: a + (b + 0) + a ===> a + (a + b)
a + (a + b)
trans (trans (congr1 a (congr2 Nat::add (congr2 (Nat::add a) (Nat::add_zeror b)))) (Nat::add_assoc a b a))
(congr2 (Nat::add a) (Nat::add_comm b a))