feat(library/simplifier): expose simplier and simplifier_monitor objects in the Lua API

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-27 15:02:05 -08:00
parent c088825ef0
commit 8bccfb947a
9 changed files with 441 additions and 76 deletions

View file

@ -9,6 +9,8 @@ Author: Leonardo de Moura
#include "util/flet.h" #include "util/flet.h"
#include "util/freset.h" #include "util/freset.h"
#include "util/interrupt.h" #include "util/interrupt.h"
#include "util/luaref.h"
#include "util/script_state.h"
#include "kernel/type_checker.h" #include "kernel/type_checker.h"
#include "kernel/free_vars.h" #include "kernel/free_vars.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
@ -106,7 +108,9 @@ static name g_H("H");
static name g_x("x"); static name g_x("x");
static name g_unique = name::mk_internal_unique_name(); static name g_unique = name::mk_internal_unique_name();
class simplifier_fn { class simplifier_cell::imp {
friend class simplifier_cell;
friend class simplifier;
struct result { struct result {
expr m_expr; // the result of a simplification step expr m_expr; // the result of a simplification step
optional<expr> m_proof; // a proof that the result is equal to the input (when m_proofs_enabled) optional<expr> m_proof; // a proof that the result is equal to the input (when m_proofs_enabled)
@ -126,7 +130,9 @@ class simplifier_fn {
typedef expr_map<result> cache; typedef expr_map<result> cache;
typedef std::vector<congr_theorem_info const *> congr_thms; typedef std::vector<congr_theorem_info const *> congr_thms;
typedef cache const_map; typedef cache const_map;
std::weak_ptr<simplifier_cell> m_this;
ro_environment m_env; ro_environment m_env;
options m_options;
type_checker m_tc; type_checker m_tc;
bool m_has_heq; bool m_has_heq;
bool m_has_cast; bool m_has_cast;
@ -154,7 +160,7 @@ class simplifier_fn {
unsigned m_max_steps; unsigned m_max_steps;
struct updt_rule_set { struct updt_rule_set {
simplifier_fn & m_fn; imp & m_fn;
rewrite_rule_set m_old; rewrite_rule_set m_old;
freset<cache> m_reset_cache; // must reset the cache whenever we update the rule set. freset<cache> m_reset_cache; // must reset the cache whenever we update the rule set.
/** /**
@ -162,7 +168,7 @@ class simplifier_fn {
\pre const_type(H) \pre const_type(H)
*/ */
updt_rule_set(simplifier_fn & fn, expr const & H): updt_rule_set(imp & fn, expr const & H):
m_fn(fn), m_old(m_fn.m_rule_sets[0]), m_reset_cache(m_fn.m_cache) { m_fn(fn), m_old(m_fn.m_rule_sets[0]), m_reset_cache(m_fn.m_cache) {
lean_assert(const_type(H)); lean_assert(const_type(H));
m_fn.m_rule_sets[0].insert(g_local, *const_type(H), H); m_fn.m_rule_sets[0].insert(g_local, *const_type(H), H);
@ -174,9 +180,9 @@ class simplifier_fn {
}; };
struct updt_const_map { struct updt_const_map {
simplifier_fn & m_fn; imp & m_fn;
expr const & m_old_x; expr const & m_old_x;
updt_const_map(simplifier_fn & fn, expr const & old_x, expr const & new_x, expr const & H): updt_const_map(imp & fn, expr const & old_x, expr const & new_x, expr const & H):
m_fn(fn), m_old_x(old_x) { m_fn(fn), m_old_x(old_x) {
m_fn.m_const_map[old_x] = result(new_x, H, true); m_fn.m_const_map[old_x] = result(new_x, H, true);
} }
@ -304,7 +310,7 @@ class simplifier_fn {
if (is_TypeU(A)) { if (is_TypeU(A)) {
if (!is_definitionally_equal(f, new_f)) { if (!is_definitionally_equal(f, new_f)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported); m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported);
return none_expr(); // can't handle return none_expr(); // can't handle
} }
// The congruence axiom cannot be used in this case. // The congruence axiom cannot be used in this case.
@ -324,7 +330,7 @@ class simplifier_fn {
expr a_type = infer_type(a); expr a_type = infer_type(a);
if (!is_convertible(a_type, A)) { if (!is_convertible(a_type, A)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch); m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch);
return none_expr(); // can't handle return none_expr(); // can't handle
} }
expr a_prime = new_a.m_expr; expr a_prime = new_a.m_expr;
@ -347,7 +353,7 @@ class simplifier_fn {
expr new_a_type = infer_type(new_a.m_expr); expr new_a_type = infer_type(new_a.m_expr);
if (!is_convertible(new_a_type, new_A)) { if (!is_convertible(new_a_type, new_A)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch); m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch);
return none_expr(); // failed return none_expr(); // failed
} }
expr Heq_a = get_proof(new_a); expr Heq_a = get_proof(new_a);
@ -359,7 +365,7 @@ class simplifier_fn {
is_heq_proof = false; is_heq_proof = false;
} else { } else {
if (m_monitor) if (m_monitor)
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported); m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported);
return none_expr(); // we don't know how to handle this case return none_expr(); // we don't know how to handle this case
} }
} }
@ -867,7 +873,7 @@ class simplifier_fn {
} }
} }
if (m_monitor) if (m_monitor)
m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq()); m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id());
return true; return true;
} else { } else {
// Conditional rewriting: we try to fill the missing // Conditional rewriting: we try to fill the missing
@ -906,19 +912,22 @@ class simplifier_fn {
// but proof generation is not enabled. // but proof generation is not enabled.
// So, we should fail // So, we should fail
if (m_monitor) if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::Unsupported); m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
i, simplifier_monitor::failure_kind::Unsupported);
return false; return false;
} }
} else { } else {
// failed to prove proposition // failed to prove proposition
if (m_monitor) if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::AssumptionNotProved); m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
i, simplifier_monitor::failure_kind::AssumptionNotProved);
return false; return false;
} }
} else { } else {
// failed, the argument is not a proposition // failed, the argument is not a proposition
if (m_monitor) if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::MissingArgument); m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
i, simplifier_monitor::failure_kind::MissingArgument);
return false; return false;
} }
} }
@ -927,11 +936,12 @@ class simplifier_fn {
new_rhs = arg(ceq, num_args(ceq) - 1); new_rhs = arg(ceq, num_args(ceq) - 1);
if (rule.is_permutation() && !is_lt(new_rhs, target, false)) { if (rule.is_permutation() && !is_lt(new_rhs, target, false)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), 0, simplifier_monitor::failure_kind::PermutationGe); m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
0, simplifier_monitor::failure_kind::PermutationGe);
return false; return false;
} }
if (m_monitor) if (m_monitor)
m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq()); m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id());
return true; return true;
} }
} }
@ -1099,7 +1109,7 @@ class simplifier_fn {
if (occurs(x_old, new_bi)) { if (occurs(x_old, new_bi)) {
// failed, simplifier didn't manage to replace x_old with x_new // failed, simplifier didn't manage to replace x_old with x_new
if (m_monitor) if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody); m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody);
return rewrite(e, result(e)); return rewrite(e, result(e));
} }
expr new_e = update_lambda(e, new_d, abstract(new_bi, x_new)); expr new_e = update_lambda(e, new_d, abstract(new_bi, x_new));
@ -1202,7 +1212,7 @@ class simplifier_fn {
expr e_type = infer_type(e); expr e_type = infer_type(e);
if (is_TypeU(e_type) || !ensure_homogeneous(A, res_A)) { if (is_TypeU(e_type) || !ensure_homogeneous(A, res_A)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch); m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
return result(e); // failed, we can't use subst theorem return result(e); // failed, we can't use subst theorem
} else { } else {
expr H = get_proof(res_A); expr H = get_proof(res_A);
@ -1237,7 +1247,7 @@ class simplifier_fn {
expr e_type = infer_type(e); expr e_type = infer_type(e);
if (is_TypeU(e_type) || !ensure_homogeneous(B, res_B)) { if (is_TypeU(e_type) || !ensure_homogeneous(B, res_B)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch); m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
return result(e); // failed, we can't use subst theorem return result(e); // failed, we can't use subst theorem
} else { } else {
expr H = get_proof(res_B); expr H = get_proof(res_B);
@ -1318,13 +1328,13 @@ class simplifier_fn {
if (occurs(x_old, new_bi)) { if (occurs(x_old, new_bi)) {
// failed, simplifier didn't manage to replace x_old with x_new // failed, simplifier didn't manage to replace x_old with x_new
if (m_monitor) if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody); m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody);
return rewrite(e, result(e)); return rewrite(e, result(e));
} }
expr new_e = update_pi(e, new_d, abstract(new_bi, x_new)); expr new_e = update_pi(e, new_d, abstract(new_bi, x_new));
if (!m_proofs_enabled || is_definitionally_equal(e, new_e)) { if (!m_proofs_enabled || is_definitionally_equal(e, new_e)) {
if (m_monitor) if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch); m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
return rewrite(e, result(new_e)); return rewrite(e, result(new_e));
} }
ensure_homogeneous(d, res_d); ensure_homogeneous(d, res_d);
@ -1364,7 +1374,7 @@ class simplifier_fn {
} else { } else {
// We currently do simplify (forall x : A, B x) when it is not a proposition. // We currently do simplify (forall x : A, B x) when it is not a proposition.
if (m_monitor) if (m_monitor)
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::Unsupported); m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::Unsupported);
return result(e); return result(e);
} }
} }
@ -1374,7 +1384,7 @@ class simplifier_fn {
result new_r = r.update_expr(m_max_sharing(r.m_expr)); result new_r = r.update_expr(m_max_sharing(r.m_expr));
m_cache.insert(mk_pair(e, new_r)); m_cache.insert(mk_pair(e, new_r));
if (m_monitor) if (m_monitor)
m_monitor->step_eh(m_depth, e, new_r.m_expr, new_r.m_proof); m_monitor->step_eh(ro_simplifier(m_this), e, new_r.m_expr, new_r.m_proof);
return new_r; return new_r;
} else { } else {
return r; return r;
@ -1395,7 +1405,7 @@ class simplifier_fn {
} }
} }
if (m_monitor) if (m_monitor)
m_monitor->pre_eh(m_depth, e); m_monitor->pre_eh(ro_simplifier(m_this), e);
switch (e.kind()) { switch (e.kind()) {
case expr_kind::Var: return result(e); case expr_kind::Var: return result(e);
case expr_kind::Constant: return save(e, simplify_constant(e)); case expr_kind::Constant: return save(e, simplify_constant(e));
@ -1445,9 +1455,9 @@ class simplifier_fn {
} }
public: public:
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, imp(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor): std::shared_ptr<simplifier_monitor> const & monitor):
m_env(env), m_tc(env), m_monitor(monitor) { m_env(env), m_options(o), m_tc(env), m_monitor(monitor) {
m_has_heq = m_env->imported("heq"); m_has_heq = m_env->imported("heq");
m_has_cast = m_env->imported("cast"); m_has_cast = m_env->imported("cast");
set_options(o); set_options(o);
@ -1469,10 +1479,38 @@ public:
} }
}; };
simplifier_cell::simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor):
m_ptr(new imp(env, o, num_rs, rs, monitor)) {
}
expr_pair simplifier_cell::operator()(expr const & e, context const & ctx) { return m_ptr->operator()(e, ctx); }
void simplifier_cell::clear() { return m_ptr->m_cache.clear(); }
unsigned simplifier_cell::get_depth() const { return m_ptr->m_depth; }
context const & simplifier_cell::get_context() const { return m_ptr->m_ctx; }
ro_environment const & simplifier_cell::get_environment() const { return m_ptr->m_env; }
options const & simplifier_cell::get_options() const { return m_ptr->m_options; }
simplifier::simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor):
m_ptr(std::make_shared<simplifier_cell>(env, o, num_rs, rs, monitor)) {
m_ptr->m_ptr->m_this = m_ptr;
}
ro_simplifier::ro_simplifier(simplifier const & env):
m_ptr(env.m_ptr) {
}
ro_simplifier::ro_simplifier(weak_ref const & r) {
if (r.expired())
throw exception("weak reference to simplifier object has expired (i.e., the simplifier has been deleted)");
m_ptr = r.lock();
}
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
unsigned num_rs, rewrite_rule_set const * rs, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor) { std::shared_ptr<simplifier_monitor> const & monitor) {
return simplifier_fn(env, opts, num_rs, rs, monitor)(e, ctx); return simplifier(env, opts, num_rs, rs, monitor)(e, ctx);
} }
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
@ -1484,32 +1522,222 @@ expr_pair simplify(expr const & e, ro_environment const & env, context const & c
return simplify(e, env, ctx, opts, num_ns, rules.data(), monitor); return simplify(e, env, ctx, opts, num_ns, rules.data(), monitor);
} }
DECL_UDATA(simplifier)
DECL_UDATA(ro_simplifier)
/**
\brief Simplifier monitor implemented using Lua functions
*/
class lua_simplifier_monitor : public simplifier_monitor {
optional<luaref> m_pre_eh;
optional<luaref> m_step_eh;
optional<luaref> m_rewrite_eh;
optional<luaref> m_failed_app_eh;
optional<luaref> m_failed_rewrite_eh;
optional<luaref> m_failed_abstraction_eh;
public:
lua_simplifier_monitor(optional<luaref> const & pre_eh, optional<luaref> const & step_eh, optional<luaref> const & rewrite_eh,
optional<luaref> const & failed_app_eh, optional<luaref> const & failed_rewrite_eh, optional<luaref> const & failed_abstraction_eh):
m_pre_eh(pre_eh), m_step_eh(step_eh), m_rewrite_eh(rewrite_eh),
m_failed_app_eh(failed_app_eh), m_failed_rewrite_eh(failed_rewrite_eh), m_failed_abstraction_eh(failed_abstraction_eh) {
}
virtual ~lua_simplifier_monitor() {}
virtual void pre_eh(ro_simplifier const & s, expr const & e) {
if (m_pre_eh) {
lua_State * L = m_pre_eh->get_state();
m_pre_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
pcall(L, 2, 0, 0);
}
}
virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional<expr> const & pr) {
if (m_step_eh) {
lua_State * L = m_step_eh->get_state();
m_step_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
push_expr(L, new_e);
push_optional_expr(L, pr);
pcall(L, 4, 0, 0);
}
}
virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) {
if (m_rewrite_eh) {
lua_State * L = m_rewrite_eh->get_state();
m_rewrite_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
push_expr(L, new_e);
push_expr(L, ceq);
push_name(L, ceq_id);
pcall(L, 5, 0, 0);
}
}
virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) {
if (m_failed_app_eh) {
lua_State * L = m_failed_app_eh->get_state();
m_failed_app_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
lua_pushinteger(L, i);
lua_pushinteger(L, static_cast<unsigned>(k));
pcall(L, 4, 0, 0);
}
}
virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) {
if (m_failed_rewrite_eh) {
lua_State * L = m_failed_rewrite_eh->get_state();
m_failed_rewrite_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
push_expr(L, ceq);
push_name(L, ceq_id);
lua_pushinteger(L, i);
lua_pushinteger(L, static_cast<unsigned>(k));
pcall(L, 6, 0, 0);
}
}
virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) {
if (m_failed_abstraction_eh) {
lua_State * L = m_failed_abstraction_eh->get_state();
m_failed_abstraction_eh->push();
push_ro_simplifier(L, s);
push_expr(L, e);
lua_pushinteger(L, static_cast<unsigned>(k));
pcall(L, 3, 0, 0);
}
}
};
typedef std::shared_ptr<simplifier_monitor> simplifier_monitor_ptr;
DECL_UDATA(simplifier_monitor_ptr)
static const struct luaL_Reg simplifier_monitor_ptr_m[] = {
{"__gc", simplifier_monitor_ptr_gc},
{0, 0}
};
static optional<luaref> get_opt_callback(lua_State * L, int i) {
if (i > lua_gettop(L) || lua_isnil(L, i)) {
return optional<luaref>();
} else {
luaL_checktype(L, i, LUA_TFUNCTION); // user-fun
return optional<luaref>(luaref(L, i));
}
}
static int mk_simplifier_monitor(lua_State * L) {
simplifier_monitor_ptr r = std::make_shared<lua_simplifier_monitor>(get_opt_callback(L, 1),
get_opt_callback(L, 2),
get_opt_callback(L, 3),
get_opt_callback(L, 4),
get_opt_callback(L, 5),
get_opt_callback(L, 6));
return push_simplifier_monitor_ptr(L, r);
}
/**
\brief Fill the the rewrite_rule_set \c rs using the object at position \c i in the Lua stack.
*/
static void get_rewrite_rule_set(lua_State * L, int i, ro_environment const & env, buffer<rewrite_rule_set> & rs) {
if (i > lua_gettop(L)) {
rs.push_back(get_rewrite_rule_set(env));
} else if (lua_isstring(L, i)) {
rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, i)));
} else {
luaL_checktype(L, i, LUA_TTABLE);
name r;
int n = objlen(L, i);
for (int j = 1; j <= n; j++) {
lua_rawgeti(L, i, j);
rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
lua_pop(L, 1);
}
}
}
static int mk_simplifier(lua_State * L, ro_environment const & env) {
int nargs = lua_gettop(L);
buffer<rewrite_rule_set> rules;
get_rewrite_rule_set(L, 1, env, rules);
options opts;
if (nargs >= 2)
opts = to_options(L, 2);
simplifier_monitor_ptr monitor;
if (nargs >= 3 && !lua_isnil(L, 3))
monitor = to_simplifier_monitor_ptr(L, 3);
return push_simplifier(L, simplifier(env, opts, rules.size(), rules.data(), monitor));
}
static int mk_simplifier(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs <= 3)
return mk_simplifier(L, ro_shared_environment(L));
else
return mk_simplifier(L, ro_shared_environment(L, 4));
}
static int simplifier_apply(lua_State * L) {
int nargs = lua_gettop(L);
expr_pair r;
if (nargs == 2)
r = to_simplifier(L, 1)(to_expr(L, 2), context());
else
r = to_simplifier(L, 1)(to_expr(L, 2), to_context(L, 3));
push_expr(L, r.first);
push_expr(L, r.second);
return 2;
}
static int simplifier_clear(lua_State * L) { to_simplifier(L, 1)->clear(); return 0; }
static int simplifier_depth(lua_State * L) { lua_pushinteger(L, to_simplifier(L, 1)->get_depth()); return 1; }
static int simplifier_context(lua_State * L) { return push_context(L, to_simplifier(L, 1)->get_context()); }
static int simplifier_environment(lua_State * L) { return push_environment(L, to_simplifier(L, 1)->get_environment()); }
static int simplifier_options(lua_State * L) { return push_options(L, to_simplifier(L, 1)->get_options()); }
static int ro_simplifier_depth(lua_State * L) { lua_pushinteger(L, to_ro_simplifier(L, 1)->get_depth()); return 1; }
static int ro_simplifier_context(lua_State * L) { return push_context(L, to_ro_simplifier(L, 1)->get_context()); }
static int ro_simplifier_environment(lua_State * L) { return push_environment(L, to_ro_simplifier(L, 1)->get_environment()); }
static int ro_simplifier_options(lua_State * L) { return push_options(L, to_ro_simplifier(L, 1)->get_options()); }
static const struct luaL_Reg simplifier_m[] = {
{"__gc", simplifier_gc},
{"__call", safe_function<simplifier_apply>},
{"clear", safe_function<simplifier_clear>},
{"depth", safe_function<simplifier_depth>},
{"get_environment", safe_function<simplifier_environment>},
{"get_context", safe_function<simplifier_context>},
{"get_options", safe_function<simplifier_options>},
{0, 0}
};
static const struct luaL_Reg ro_simplifier_m[] = {
{"__gc", ro_simplifier_gc},
{"depth", safe_function<ro_simplifier_depth>},
{"get_environment", safe_function<ro_simplifier_environment>},
{"get_context", safe_function<ro_simplifier_context>},
{"get_options", safe_function<ro_simplifier_options>},
{0, 0}
};
static int simplify_core(lua_State * L, ro_shared_environment const & env) { static int simplify_core(lua_State * L, ro_shared_environment const & env) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
expr const & e = to_expr(L, 1); expr const & e = to_expr(L, 1);
buffer<rewrite_rule_set> rules; buffer<rewrite_rule_set> rules;
if (nargs == 1) { get_rewrite_rule_set(L, 2, env, rules);
rules.push_back(get_rewrite_rule_set(env));
} else {
if (lua_isstring(L, 2)) {
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, 2)));
} else {
luaL_checktype(L, 2, LUA_TTABLE);
name r;
int n = objlen(L, 2);
for (int i = 1; i <= n; i++) {
lua_rawgeti(L, 2, i);
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
lua_pop(L, 1);
}
}
}
context ctx; context ctx;
options opts; options opts;
if (nargs >= 4) if (nargs >= 3)
ctx = to_context(L, 4); opts = to_options(L, 3);
if (nargs >= 5) if (nargs >= 5)
opts = to_options(L, 5); ctx = to_context(L, 5);
auto r = simplify(e, env, ctx, opts, rules.size(), rules.data()); auto r = simplify(e, env, ctx, opts, rules.size(), rules.data());
push_expr(L, r.first); push_expr(L, r.first);
push_expr(L, r.second); push_expr(L, r.second);
@ -1518,13 +1746,44 @@ static int simplify_core(lua_State * L, ro_shared_environment const & env) {
static int simplify(lua_State * L) { static int simplify(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
if (nargs <= 2) if (nargs <= 4)
return simplify_core(L, ro_shared_environment(L)); return simplify_core(L, ro_shared_environment(L));
else else
return simplify_core(L, ro_shared_environment(L, 3)); return simplify_core(L, ro_shared_environment(L, 4));
} }
void open_simplifier(lua_State * L) { void open_simplifier(lua_State * L) {
luaL_newmetatable(L, simplifier_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, simplifier_m, 0);
SET_GLOBAL_FUN(simplifier_pred, "is_simplifier");
SET_GLOBAL_FUN(mk_simplifier, "simplifier");
luaL_newmetatable(L, ro_simplifier_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, ro_simplifier_m, 0);
SET_GLOBAL_FUN(ro_simplifier_pred, "is_ro_simplifier");
luaL_newmetatable(L, simplifier_monitor_ptr_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, simplifier_monitor_ptr_m, 0);
SET_GLOBAL_FUN(simplifier_monitor_ptr_pred, "is_simplifier_monitor");
SET_GLOBAL_FUN(mk_simplifier_monitor, "simplifier_monitor");
lua_newtable(L);
SET_ENUM("Unsupported", simplifier_monitor::failure_kind::Unsupported);
SET_ENUM("TypeMismatch", simplifier_monitor::failure_kind::TypeMismatch);
SET_ENUM("AssumptionNotProved", simplifier_monitor::failure_kind::AssumptionNotProved);
SET_ENUM("MissingArgument", simplifier_monitor::failure_kind::MissingArgument);
SET_ENUM("PermutationGe", simplifier_monitor::failure_kind::PermutationGe);
SET_ENUM("AbstractionBody", simplifier_monitor::failure_kind::AbstractionBody);
lua_setglobal(L, "simplifier_failure");
SET_GLOBAL_FUN(simplify, "simplify"); SET_GLOBAL_FUN(simplify, "simplify");
} }
} }

View file

@ -5,38 +5,79 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura Author: Leonardo de Moura
*/ */
#pragma once #pragma once
#include <memory>
#include "util/lua.h" #include "util/lua.h"
#include "kernel/environment.h" #include "kernel/environment.h"
#include "library/expr_pair.h" #include "library/expr_pair.h"
#include "library/simplifier/rewrite_rule_set.h" #include "library/simplifier/rewrite_rule_set.h"
namespace lean { namespace lean {
class simplifier_monitor;
/** \brief Simplifier object cell. */
class simplifier_cell {
friend class simplifier;
struct imp;
std::unique_ptr<imp> m_ptr;
public:
simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor);
expr_pair operator()(expr const & e, context const & ctx);
void clear();
unsigned get_depth() const;
context const & get_context() const;
ro_environment const & get_environment() const;
options const & get_options() const;
};
/** \brief Reference to simplifier object */
class simplifier {
friend class ro_simplifier;
std::shared_ptr<simplifier_cell> m_ptr;
public:
simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
std::shared_ptr<simplifier_monitor> const & monitor);
simplifier_cell * operator->() const { return m_ptr.get(); }
simplifier_cell & operator*() const { return *(m_ptr.get()); }
expr_pair operator()(expr const & e, context const & ctx) { return (*m_ptr)(e, ctx); }
};
/** \brief Read only reference to simplifier object */
class ro_simplifier {
std::shared_ptr<simplifier_cell const> m_ptr;
public:
typedef std::weak_ptr<simplifier_cell const> weak_ref;
ro_simplifier(simplifier const & s);
ro_simplifier(weak_ref const & s);
explicit operator weak_ref() const { return weak_ref(m_ptr); }
weak_ref to_weak_ref() const { return weak_ref(m_ptr); }
simplifier_cell const * operator->() const { return m_ptr.get(); }
simplifier_cell const & operator*() const { return *(m_ptr.get()); }
};
/** /**
\brief Abstract class that specifies the interface for monitoring \brief Abstract class that specifies the interface for monitoring
the behavior of the simplifier. the behavior of the simplifier.
*/ */
class simplifier_monitor { class simplifier_monitor {
public: public:
virtual ~simplifier_monitor() {}
/** /**
\brief This method is invoked to sign that the simplifier is starting to process the expression \c e. \brief This method is invoked to sign that the simplifier is starting to process the expression \c e.
\remark \c depth is the recursion depth
*/ */
virtual void pre_eh(unsigned depth, expr const & e) = 0; virtual void pre_eh(ro_simplifier const & s, expr const & e) = 0;
/** /**
\brief This method is invoked to sign that \c e has be rewritten into \c new_e with proof \c pr. \brief This method is invoked to sign that \c e has be rewritten into \c new_e with proof \c pr.
The proof is none if proof generation is disabled or if \c e and \c new_e are definitionally equal. The proof is none if proof generation is disabled or if \c e and \c new_e are definitionally equal.
\remark \c depth is the recursion depth
*/ */
virtual void step_eh(unsigned depth, expr const & e, expr const & new_e, optional<expr> const & pr) = 0; virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional<expr> const & pr) = 0;
/** /**
\brief This method is invoked to sign that \c e has be rewritten into \c new_e using the conditional equation \c ceq. \brief This method is invoked to sign that \c e has be rewritten into \c new_e using the conditional equation \c ceq.
\remark \c depth is the recursion depth
*/ */
virtual void rewrite_eh(unsigned depth, expr const & e, expr const & new_e, expr const & ceq) = 0; virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) = 0;
enum class failure_kind { Unsupported, TypeMismatch, AssumptionNotProved, MissingArgument, PermutationGe, AbstractionBody }; enum class failure_kind { Unsupported, TypeMismatch, AssumptionNotProved, MissingArgument, PermutationGe, AbstractionBody };
@ -44,10 +85,8 @@ public:
\brief This method is invoked when the simplifier fails to rewrite an application \c e. \brief This method is invoked when the simplifier fails to rewrite an application \c e.
\c i is the argument where the simplifier gave up, and \c k is the reason for failure. \c i is the argument where the simplifier gave up, and \c k is the reason for failure.
Two possible values are: Unsupported or TypeMismatch (may happen when simplifying terms that use dependent types). Two possible values are: Unsupported or TypeMismatch (may happen when simplifying terms that use dependent types).
\remark \c depth is the recursion depth
*/ */
virtual void failed_app_eh(unsigned depth, expr const & e, unsigned i, failure_kind k) = 0; virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) = 0;
/** /**
\brief This method is invoked when the simplifier fails to apply a conditional equation \c ceq to \c e. \brief This method is invoked when the simplifier fails to apply a conditional equation \c ceq to \c e.
@ -55,19 +94,15 @@ public:
The possible failure values are: AssumptionNotProved (failed to synthesize a proof for an assumption required by \c ceq), The possible failure values are: AssumptionNotProved (failed to synthesize a proof for an assumption required by \c ceq),
MissingArgument (failed to infer one of the arguments needed by the conditional equation), PermutationGe MissingArgument (failed to infer one of the arguments needed by the conditional equation), PermutationGe
(the conditional equation is a permutation, and the result is not smaller in the term ordering, \c i is irrelevant in this case). (the conditional equation is a permutation, and the result is not smaller in the term ordering, \c i is irrelevant in this case).
\remark \c depth is the recursion depth
*/ */
virtual void failed_rewrite_eh(unsigned depth, expr const & e, expr const & ceq, unsigned i, failure_kind k) = 0; virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) = 0;
/** /**
\brief This method is invoked when the simplifier fails to simplify an abstraction (Pi or Lambda). \brief This method is invoked when the simplifier fails to simplify an abstraction (Pi or Lambda).
The possible failure values are: Unsupported, TypeMismatch, and AbstractionBody (failed to rewrite the body of the abstraction, The possible failure values are: Unsupported, TypeMismatch, and AbstractionBody (failed to rewrite the body of the abstraction,
this may happen when we are using dependent types). this may happen when we are using dependent types).
\remark \c depth is the recursion depth
*/ */
virtual void failed_abstraction_eh(unsigned depth, expr const & e, failure_kind k) = 0; virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) = 0;
}; };
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts, expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts,

View file

@ -4,7 +4,7 @@ add_rewrite Nat::add_assoc Nat::add_comm Nat::add_left_comm Nat::distributer Nat
(* (*
local opts = options({"simplifier", "max_steps"}, 100) local opts = options({"simplifier", "max_steps"}, 100)
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))") local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
local t2, pr = simplify(t, "simple", get_environment(), context(), opts) local t2, pr = simplify(t, "simple", opts)
*) *)
print "trying again with more steps" print "trying again with more steps"
@ -12,7 +12,7 @@ print "trying again with more steps"
(* (*
local opts = options({"simplifier", "max_steps"}, 100000) local opts = options({"simplifier", "max_steps"}, 100000)
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))") local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
local t2, pr = simplify(t, "simple", get_environment(), context(), opts) local t2, pr = simplify(t, "simple", opts)
print(t) print(t)
print("====>") print("====>")
print(t2) print(t2)

View file

@ -6,7 +6,7 @@ local opts = options({"simplifier", "single_pass"}, true)
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))") local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
print(t) print(t)
for i = 1, 10 do for i = 1, 10 do
local t2, pr = simplify(t, "simple", get_environment(), context(), opts) local t2, pr = simplify(t, "simple", opts)
print("i: " .. i .. " ====>") print("i: " .. i .. " ====>")
print(t2) print(t2)
if t == t2 then if t == t2 then

View file

@ -10,7 +10,7 @@ variables a b c : Nat
(* (*
local opts = options({"simplifier", "contextual"}, false) local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean([[a = 1 ∧ (¬ b = 0 c ≠ 0 b + c > a)]]) local t = parse_lean([[a = 1 ∧ (¬ b = 0 c ≠ 0 b + c > a)]])
local s, pr = simplify(t, "simple", get_environment(), context(), opts) local s, pr = simplify(t, "simple", opts)
print(s) print(s)
print(pr) print(pr)
print(get_environment():type_check(pr)) print(get_environment():type_check(pr))

View file

@ -4,7 +4,7 @@ variables a b : Nat
(* (*
local opts = options({"simplifier", "contextual"}, false) local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean('λ x, a = a → x = a') local t = parse_lean('λ x, a = a → x = a')
local t2, pr = simplify(t, "simple", get_environment(), context(), opts) local t2, pr = simplify(t, "simple", opts)
print(t2) print(t2)
print(pr) print(pr)
get_environment():type_check(pr) get_environment():type_check(pr)
@ -13,7 +13,7 @@ get_environment():type_check(pr)
(* (*
local opts = options({"simplifier", "contextual"}, false) local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean('λ x, x = a → x = x') local t = parse_lean('λ x, x = a → x = x')
local t2, pr = simplify(t, "simple", get_environment(), context(), opts) local t2, pr = simplify(t, "simple", opts)
print(t2) print(t2)
print(pr) print(pr)
get_environment():type_check(pr) get_environment():type_check(pr)
@ -22,7 +22,7 @@ get_environment():type_check(pr)
(* (*
local opts = options({"simplifier", "contextual"}, false) local opts = options({"simplifier", "contextual"}, false)
local t = parse_lean('λ x, x = a + 0 → a = a') local t = parse_lean('λ x, x = a + 0 → a = a')
local t2, pr = simplify(t, "simple", get_environment(), context(), opts) local t2, pr = simplify(t, "simple", opts)
print(t2) print(t2)
print(pr) print(pr)
*) *)

View file

@ -15,7 +15,7 @@ show(simplify(t4))
(* (*
local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false) local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false)
local t1 = parse_lean("double (double 2) + 1 ≥ 3") local t1 = parse_lean("double (double 2) + 1 ≥ 3")
show(simplify(t1, 'default', get_environment(), context(), opt)) show(simplify(t1, 'default', opt))
*) *)
set_opaque Nat::ge false set_opaque Nat::ge false
@ -26,7 +26,7 @@ add_rewrite Nat::distributel
(* (*
local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false) local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false)
local t1 = parse_lean("2 * double (double 2) + 1 ≥ 3") local t1 = parse_lean("2 * double (double 2) + 1 ≥ 3")
show(simplify(t1, 'default', get_environment(), context(), opt)) show(simplify(t1, 'default', opt))
*) *)
variables a b c d : Nat variables a b c d : Nat

29
tests/lean/simp31.lean Normal file
View file

@ -0,0 +1,29 @@
rewrite_set simple
add_rewrite Nat::add_comm Nat::add_left_comm Nat::add_assoc Nat::add_zeror : simple
variables a b c : Nat
(*
function indent(s)
for i = 1, s:depth()-1 do
io.write(" ")
end
end
local m = simplifier_monitor(function(s, e)
print("Visit, depth: " .. s:depth() .. ", " .. tostring(e))
end,
function(s, e, new_e, pr)
print("Step: " .. tostring(e) .. " ===> " .. tostring(new_e))
end,
function(s, e, new_e, ceq, ceq_id)
print("Rewrite using: " .. tostring(ceq_id))
print(" " .. tostring(e) .. " ===> " .. tostring(new_e))
end
)
local s = simplifier("simple", options(), m)
local t = parse_lean('a + (b + 0) + a')
print(t)
print("=====>")
local t2, pr = s(t)
print(t2)
print(pr)
get_environment():type_check(pr)
*)

View file

@ -0,0 +1,42 @@
Set: pp::colors
Set: pp::unicode
Assumed: a
Assumed: b
Assumed: c
a + (b + 0) + a
=====>
Visit, depth: 1, a + (b + 0) + a
Visit, depth: 2, Nat::add
Visit, depth: 2, a + (b + 0)
Visit, depth: 3, Nat::add
Visit, depth: 3, a
Step: a ===> a
Visit, depth: 3, b + 0
Visit, depth: 4, Nat::add
Visit, depth: 4, b
Step: b ===> b
Visit, depth: 4, 0
Rewrite using: Nat::add_zeror
b + 0 ===> b
Step: b + 0 ===> b
Visit, depth: 3, a + b
Visit, depth: 4, Nat::add
Step: a + b ===> a + b
Step: a + (b + 0) ===> a + b
Rewrite using: Nat::add_assoc
a + b + a ===> a + (b + a)
Visit, depth: 2, a + (b + a)
Visit, depth: 3, Nat::add
Visit, depth: 3, b + a
Visit, depth: 4, Nat::add
Rewrite using: Nat::add_comm
b + a ===> a + b
Step: b + a ===> a + b
Visit, depth: 3, a + (a + b)
Visit, depth: 4, Nat::add
Step: a + (a + b) ===> a + (a + b)
Step: a + (b + a) ===> a + (a + b)
Step: a + (b + 0) + a ===> a + (a + b)
a + (a + b)
trans (trans (congr1 a (congr2 Nat::add (congr2 (Nat::add a) (Nat::add_zeror b)))) (Nat::add_assoc a b a))
(congr2 (Nat::add a) (Nat::add_comm b a))