feat(library/simplifier): expose simplier and simplifier_monitor objects in the Lua API
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
c088825ef0
commit
8bccfb947a
9 changed files with 441 additions and 76 deletions
|
@ -9,6 +9,8 @@ Author: Leonardo de Moura
|
|||
#include "util/flet.h"
|
||||
#include "util/freset.h"
|
||||
#include "util/interrupt.h"
|
||||
#include "util/luaref.h"
|
||||
#include "util/script_state.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/free_vars.h"
|
||||
#include "kernel/instantiate.h"
|
||||
|
@ -106,7 +108,9 @@ static name g_H("H");
|
|||
static name g_x("x");
|
||||
static name g_unique = name::mk_internal_unique_name();
|
||||
|
||||
class simplifier_fn {
|
||||
class simplifier_cell::imp {
|
||||
friend class simplifier_cell;
|
||||
friend class simplifier;
|
||||
struct result {
|
||||
expr m_expr; // the result of a simplification step
|
||||
optional<expr> m_proof; // a proof that the result is equal to the input (when m_proofs_enabled)
|
||||
|
@ -126,7 +130,9 @@ class simplifier_fn {
|
|||
typedef expr_map<result> cache;
|
||||
typedef std::vector<congr_theorem_info const *> congr_thms;
|
||||
typedef cache const_map;
|
||||
std::weak_ptr<simplifier_cell> m_this;
|
||||
ro_environment m_env;
|
||||
options m_options;
|
||||
type_checker m_tc;
|
||||
bool m_has_heq;
|
||||
bool m_has_cast;
|
||||
|
@ -154,7 +160,7 @@ class simplifier_fn {
|
|||
unsigned m_max_steps;
|
||||
|
||||
struct updt_rule_set {
|
||||
simplifier_fn & m_fn;
|
||||
imp & m_fn;
|
||||
rewrite_rule_set m_old;
|
||||
freset<cache> m_reset_cache; // must reset the cache whenever we update the rule set.
|
||||
/**
|
||||
|
@ -162,7 +168,7 @@ class simplifier_fn {
|
|||
|
||||
\pre const_type(H)
|
||||
*/
|
||||
updt_rule_set(simplifier_fn & fn, expr const & H):
|
||||
updt_rule_set(imp & fn, expr const & H):
|
||||
m_fn(fn), m_old(m_fn.m_rule_sets[0]), m_reset_cache(m_fn.m_cache) {
|
||||
lean_assert(const_type(H));
|
||||
m_fn.m_rule_sets[0].insert(g_local, *const_type(H), H);
|
||||
|
@ -174,9 +180,9 @@ class simplifier_fn {
|
|||
};
|
||||
|
||||
struct updt_const_map {
|
||||
simplifier_fn & m_fn;
|
||||
imp & m_fn;
|
||||
expr const & m_old_x;
|
||||
updt_const_map(simplifier_fn & fn, expr const & old_x, expr const & new_x, expr const & H):
|
||||
updt_const_map(imp & fn, expr const & old_x, expr const & new_x, expr const & H):
|
||||
m_fn(fn), m_old_x(old_x) {
|
||||
m_fn.m_const_map[old_x] = result(new_x, H, true);
|
||||
}
|
||||
|
@ -304,7 +310,7 @@ class simplifier_fn {
|
|||
if (is_TypeU(A)) {
|
||||
if (!is_definitionally_equal(f, new_f)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported);
|
||||
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported);
|
||||
return none_expr(); // can't handle
|
||||
}
|
||||
// The congruence axiom cannot be used in this case.
|
||||
|
@ -324,7 +330,7 @@ class simplifier_fn {
|
|||
expr a_type = infer_type(a);
|
||||
if (!is_convertible(a_type, A)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
return none_expr(); // can't handle
|
||||
}
|
||||
expr a_prime = new_a.m_expr;
|
||||
|
@ -347,7 +353,7 @@ class simplifier_fn {
|
|||
expr new_a_type = infer_type(new_a.m_expr);
|
||||
if (!is_convertible(new_a_type, new_A)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
return none_expr(); // failed
|
||||
}
|
||||
expr Heq_a = get_proof(new_a);
|
||||
|
@ -359,7 +365,7 @@ class simplifier_fn {
|
|||
is_heq_proof = false;
|
||||
} else {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_app_eh(m_depth, e, i, simplifier_monitor::failure_kind::Unsupported);
|
||||
m_monitor->failed_app_eh(ro_simplifier(m_this), e, i, simplifier_monitor::failure_kind::Unsupported);
|
||||
return none_expr(); // we don't know how to handle this case
|
||||
}
|
||||
}
|
||||
|
@ -867,7 +873,7 @@ class simplifier_fn {
|
|||
}
|
||||
}
|
||||
if (m_monitor)
|
||||
m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq());
|
||||
m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id());
|
||||
return true;
|
||||
} else {
|
||||
// Conditional rewriting: we try to fill the missing
|
||||
|
@ -906,19 +912,22 @@ class simplifier_fn {
|
|||
// but proof generation is not enabled.
|
||||
// So, we should fail
|
||||
if (m_monitor)
|
||||
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::Unsupported);
|
||||
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
|
||||
i, simplifier_monitor::failure_kind::Unsupported);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// failed to prove proposition
|
||||
if (m_monitor)
|
||||
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::AssumptionNotProved);
|
||||
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
|
||||
i, simplifier_monitor::failure_kind::AssumptionNotProved);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// failed, the argument is not a proposition
|
||||
if (m_monitor)
|
||||
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), i, simplifier_monitor::failure_kind::MissingArgument);
|
||||
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
|
||||
i, simplifier_monitor::failure_kind::MissingArgument);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -927,11 +936,12 @@ class simplifier_fn {
|
|||
new_rhs = arg(ceq, num_args(ceq) - 1);
|
||||
if (rule.is_permutation() && !is_lt(new_rhs, target, false)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_rewrite_eh(m_depth, target, rule.get_ceq(), 0, simplifier_monitor::failure_kind::PermutationGe);
|
||||
m_monitor->failed_rewrite_eh(ro_simplifier(m_this), target, rule.get_ceq(), rule.get_id(),
|
||||
0, simplifier_monitor::failure_kind::PermutationGe);
|
||||
return false;
|
||||
}
|
||||
if (m_monitor)
|
||||
m_monitor->rewrite_eh(m_depth, target, new_rhs, rule.get_ceq());
|
||||
m_monitor->rewrite_eh(ro_simplifier(m_this), target, new_rhs, rule.get_ceq(), rule.get_id());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -1099,7 +1109,7 @@ class simplifier_fn {
|
|||
if (occurs(x_old, new_bi)) {
|
||||
// failed, simplifier didn't manage to replace x_old with x_new
|
||||
if (m_monitor)
|
||||
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody);
|
||||
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody);
|
||||
return rewrite(e, result(e));
|
||||
}
|
||||
expr new_e = update_lambda(e, new_d, abstract(new_bi, x_new));
|
||||
|
@ -1202,7 +1212,7 @@ class simplifier_fn {
|
|||
expr e_type = infer_type(e);
|
||||
if (is_TypeU(e_type) || !ensure_homogeneous(A, res_A)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
return result(e); // failed, we can't use subst theorem
|
||||
} else {
|
||||
expr H = get_proof(res_A);
|
||||
|
@ -1237,7 +1247,7 @@ class simplifier_fn {
|
|||
expr e_type = infer_type(e);
|
||||
if (is_TypeU(e_type) || !ensure_homogeneous(B, res_B)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
return result(e); // failed, we can't use subst theorem
|
||||
} else {
|
||||
expr H = get_proof(res_B);
|
||||
|
@ -1318,13 +1328,13 @@ class simplifier_fn {
|
|||
if (occurs(x_old, new_bi)) {
|
||||
// failed, simplifier didn't manage to replace x_old with x_new
|
||||
if (m_monitor)
|
||||
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::AbstractionBody);
|
||||
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::AbstractionBody);
|
||||
return rewrite(e, result(e));
|
||||
}
|
||||
expr new_e = update_pi(e, new_d, abstract(new_bi, x_new));
|
||||
if (!m_proofs_enabled || is_definitionally_equal(e, new_e)) {
|
||||
if (m_monitor)
|
||||
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::TypeMismatch);
|
||||
return rewrite(e, result(new_e));
|
||||
}
|
||||
ensure_homogeneous(d, res_d);
|
||||
|
@ -1364,7 +1374,7 @@ class simplifier_fn {
|
|||
} else {
|
||||
// We currently do simplify (forall x : A, B x) when it is not a proposition.
|
||||
if (m_monitor)
|
||||
m_monitor->failed_abstraction_eh(m_depth, e, simplifier_monitor::failure_kind::Unsupported);
|
||||
m_monitor->failed_abstraction_eh(ro_simplifier(m_this), e, simplifier_monitor::failure_kind::Unsupported);
|
||||
return result(e);
|
||||
}
|
||||
}
|
||||
|
@ -1374,7 +1384,7 @@ class simplifier_fn {
|
|||
result new_r = r.update_expr(m_max_sharing(r.m_expr));
|
||||
m_cache.insert(mk_pair(e, new_r));
|
||||
if (m_monitor)
|
||||
m_monitor->step_eh(m_depth, e, new_r.m_expr, new_r.m_proof);
|
||||
m_monitor->step_eh(ro_simplifier(m_this), e, new_r.m_expr, new_r.m_proof);
|
||||
return new_r;
|
||||
} else {
|
||||
return r;
|
||||
|
@ -1395,7 +1405,7 @@ class simplifier_fn {
|
|||
}
|
||||
}
|
||||
if (m_monitor)
|
||||
m_monitor->pre_eh(m_depth, e);
|
||||
m_monitor->pre_eh(ro_simplifier(m_this), e);
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: return result(e);
|
||||
case expr_kind::Constant: return save(e, simplify_constant(e));
|
||||
|
@ -1445,9 +1455,9 @@ class simplifier_fn {
|
|||
}
|
||||
|
||||
public:
|
||||
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
|
||||
imp(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
|
||||
std::shared_ptr<simplifier_monitor> const & monitor):
|
||||
m_env(env), m_tc(env), m_monitor(monitor) {
|
||||
m_env(env), m_options(o), m_tc(env), m_monitor(monitor) {
|
||||
m_has_heq = m_env->imported("heq");
|
||||
m_has_cast = m_env->imported("cast");
|
||||
set_options(o);
|
||||
|
@ -1469,10 +1479,38 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
simplifier_cell::simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
|
||||
std::shared_ptr<simplifier_monitor> const & monitor):
|
||||
m_ptr(new imp(env, o, num_rs, rs, monitor)) {
|
||||
}
|
||||
|
||||
expr_pair simplifier_cell::operator()(expr const & e, context const & ctx) { return m_ptr->operator()(e, ctx); }
|
||||
void simplifier_cell::clear() { return m_ptr->m_cache.clear(); }
|
||||
unsigned simplifier_cell::get_depth() const { return m_ptr->m_depth; }
|
||||
context const & simplifier_cell::get_context() const { return m_ptr->m_ctx; }
|
||||
ro_environment const & simplifier_cell::get_environment() const { return m_ptr->m_env; }
|
||||
options const & simplifier_cell::get_options() const { return m_ptr->m_options; }
|
||||
|
||||
simplifier::simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
|
||||
std::shared_ptr<simplifier_monitor> const & monitor):
|
||||
m_ptr(std::make_shared<simplifier_cell>(env, o, num_rs, rs, monitor)) {
|
||||
m_ptr->m_ptr->m_this = m_ptr;
|
||||
}
|
||||
|
||||
ro_simplifier::ro_simplifier(simplifier const & env):
|
||||
m_ptr(env.m_ptr) {
|
||||
}
|
||||
|
||||
ro_simplifier::ro_simplifier(weak_ref const & r) {
|
||||
if (r.expired())
|
||||
throw exception("weak reference to simplifier object has expired (i.e., the simplifier has been deleted)");
|
||||
m_ptr = r.lock();
|
||||
}
|
||||
|
||||
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
|
||||
unsigned num_rs, rewrite_rule_set const * rs,
|
||||
std::shared_ptr<simplifier_monitor> const & monitor) {
|
||||
return simplifier_fn(env, opts, num_rs, rs, monitor)(e, ctx);
|
||||
return simplifier(env, opts, num_rs, rs, monitor)(e, ctx);
|
||||
}
|
||||
|
||||
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts,
|
||||
|
@ -1484,32 +1522,222 @@ expr_pair simplify(expr const & e, ro_environment const & env, context const & c
|
|||
return simplify(e, env, ctx, opts, num_ns, rules.data(), monitor);
|
||||
}
|
||||
|
||||
static int simplify_core(lua_State * L, ro_shared_environment const & env) {
|
||||
int nargs = lua_gettop(L);
|
||||
expr const & e = to_expr(L, 1);
|
||||
buffer<rewrite_rule_set> rules;
|
||||
if (nargs == 1) {
|
||||
rules.push_back(get_rewrite_rule_set(env));
|
||||
DECL_UDATA(simplifier)
|
||||
DECL_UDATA(ro_simplifier)
|
||||
|
||||
/**
|
||||
\brief Simplifier monitor implemented using Lua functions
|
||||
*/
|
||||
class lua_simplifier_monitor : public simplifier_monitor {
|
||||
optional<luaref> m_pre_eh;
|
||||
optional<luaref> m_step_eh;
|
||||
optional<luaref> m_rewrite_eh;
|
||||
optional<luaref> m_failed_app_eh;
|
||||
optional<luaref> m_failed_rewrite_eh;
|
||||
optional<luaref> m_failed_abstraction_eh;
|
||||
public:
|
||||
lua_simplifier_monitor(optional<luaref> const & pre_eh, optional<luaref> const & step_eh, optional<luaref> const & rewrite_eh,
|
||||
optional<luaref> const & failed_app_eh, optional<luaref> const & failed_rewrite_eh, optional<luaref> const & failed_abstraction_eh):
|
||||
m_pre_eh(pre_eh), m_step_eh(step_eh), m_rewrite_eh(rewrite_eh),
|
||||
m_failed_app_eh(failed_app_eh), m_failed_rewrite_eh(failed_rewrite_eh), m_failed_abstraction_eh(failed_abstraction_eh) {
|
||||
}
|
||||
virtual ~lua_simplifier_monitor() {}
|
||||
|
||||
virtual void pre_eh(ro_simplifier const & s, expr const & e) {
|
||||
if (m_pre_eh) {
|
||||
lua_State * L = m_pre_eh->get_state();
|
||||
m_pre_eh->push();
|
||||
push_ro_simplifier(L, s);
|
||||
push_expr(L, e);
|
||||
pcall(L, 2, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional<expr> const & pr) {
|
||||
if (m_step_eh) {
|
||||
lua_State * L = m_step_eh->get_state();
|
||||
m_step_eh->push();
|
||||
push_ro_simplifier(L, s);
|
||||
push_expr(L, e);
|
||||
push_expr(L, new_e);
|
||||
push_optional_expr(L, pr);
|
||||
pcall(L, 4, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) {
|
||||
if (m_rewrite_eh) {
|
||||
lua_State * L = m_rewrite_eh->get_state();
|
||||
m_rewrite_eh->push();
|
||||
push_ro_simplifier(L, s);
|
||||
push_expr(L, e);
|
||||
push_expr(L, new_e);
|
||||
push_expr(L, ceq);
|
||||
push_name(L, ceq_id);
|
||||
pcall(L, 5, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) {
|
||||
if (m_failed_app_eh) {
|
||||
lua_State * L = m_failed_app_eh->get_state();
|
||||
m_failed_app_eh->push();
|
||||
push_ro_simplifier(L, s);
|
||||
push_expr(L, e);
|
||||
lua_pushinteger(L, i);
|
||||
lua_pushinteger(L, static_cast<unsigned>(k));
|
||||
pcall(L, 4, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) {
|
||||
if (m_failed_rewrite_eh) {
|
||||
lua_State * L = m_failed_rewrite_eh->get_state();
|
||||
m_failed_rewrite_eh->push();
|
||||
push_ro_simplifier(L, s);
|
||||
push_expr(L, e);
|
||||
push_expr(L, ceq);
|
||||
push_name(L, ceq_id);
|
||||
lua_pushinteger(L, i);
|
||||
lua_pushinteger(L, static_cast<unsigned>(k));
|
||||
pcall(L, 6, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) {
|
||||
if (m_failed_abstraction_eh) {
|
||||
lua_State * L = m_failed_abstraction_eh->get_state();
|
||||
m_failed_abstraction_eh->push();
|
||||
push_ro_simplifier(L, s);
|
||||
push_expr(L, e);
|
||||
lua_pushinteger(L, static_cast<unsigned>(k));
|
||||
pcall(L, 3, 0, 0);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<simplifier_monitor> simplifier_monitor_ptr;
|
||||
|
||||
DECL_UDATA(simplifier_monitor_ptr)
|
||||
|
||||
static const struct luaL_Reg simplifier_monitor_ptr_m[] = {
|
||||
{"__gc", simplifier_monitor_ptr_gc},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
static optional<luaref> get_opt_callback(lua_State * L, int i) {
|
||||
if (i > lua_gettop(L) || lua_isnil(L, i)) {
|
||||
return optional<luaref>();
|
||||
} else {
|
||||
if (lua_isstring(L, 2)) {
|
||||
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, 2)));
|
||||
luaL_checktype(L, i, LUA_TFUNCTION); // user-fun
|
||||
return optional<luaref>(luaref(L, i));
|
||||
}
|
||||
}
|
||||
|
||||
static int mk_simplifier_monitor(lua_State * L) {
|
||||
simplifier_monitor_ptr r = std::make_shared<lua_simplifier_monitor>(get_opt_callback(L, 1),
|
||||
get_opt_callback(L, 2),
|
||||
get_opt_callback(L, 3),
|
||||
get_opt_callback(L, 4),
|
||||
get_opt_callback(L, 5),
|
||||
get_opt_callback(L, 6));
|
||||
return push_simplifier_monitor_ptr(L, r);
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Fill the the rewrite_rule_set \c rs using the object at position \c i in the Lua stack.
|
||||
*/
|
||||
static void get_rewrite_rule_set(lua_State * L, int i, ro_environment const & env, buffer<rewrite_rule_set> & rs) {
|
||||
if (i > lua_gettop(L)) {
|
||||
rs.push_back(get_rewrite_rule_set(env));
|
||||
} else if (lua_isstring(L, i)) {
|
||||
rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, i)));
|
||||
} else {
|
||||
luaL_checktype(L, 2, LUA_TTABLE);
|
||||
luaL_checktype(L, i, LUA_TTABLE);
|
||||
name r;
|
||||
int n = objlen(L, 2);
|
||||
for (int i = 1; i <= n; i++) {
|
||||
lua_rawgeti(L, 2, i);
|
||||
rules.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
|
||||
int n = objlen(L, i);
|
||||
for (int j = 1; j <= n; j++) {
|
||||
lua_rawgeti(L, i, j);
|
||||
rs.push_back(get_rewrite_rule_set(env, to_name_ext(L, -1)));
|
||||
lua_pop(L, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static int mk_simplifier(lua_State * L, ro_environment const & env) {
|
||||
int nargs = lua_gettop(L);
|
||||
buffer<rewrite_rule_set> rules;
|
||||
get_rewrite_rule_set(L, 1, env, rules);
|
||||
options opts;
|
||||
if (nargs >= 2)
|
||||
opts = to_options(L, 2);
|
||||
simplifier_monitor_ptr monitor;
|
||||
if (nargs >= 3 && !lua_isnil(L, 3))
|
||||
monitor = to_simplifier_monitor_ptr(L, 3);
|
||||
return push_simplifier(L, simplifier(env, opts, rules.size(), rules.data(), monitor));
|
||||
}
|
||||
|
||||
static int mk_simplifier(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs <= 3)
|
||||
return mk_simplifier(L, ro_shared_environment(L));
|
||||
else
|
||||
return mk_simplifier(L, ro_shared_environment(L, 4));
|
||||
}
|
||||
|
||||
static int simplifier_apply(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
expr_pair r;
|
||||
if (nargs == 2)
|
||||
r = to_simplifier(L, 1)(to_expr(L, 2), context());
|
||||
else
|
||||
r = to_simplifier(L, 1)(to_expr(L, 2), to_context(L, 3));
|
||||
push_expr(L, r.first);
|
||||
push_expr(L, r.second);
|
||||
return 2;
|
||||
}
|
||||
|
||||
static int simplifier_clear(lua_State * L) { to_simplifier(L, 1)->clear(); return 0; }
|
||||
static int simplifier_depth(lua_State * L) { lua_pushinteger(L, to_simplifier(L, 1)->get_depth()); return 1; }
|
||||
static int simplifier_context(lua_State * L) { return push_context(L, to_simplifier(L, 1)->get_context()); }
|
||||
static int simplifier_environment(lua_State * L) { return push_environment(L, to_simplifier(L, 1)->get_environment()); }
|
||||
static int simplifier_options(lua_State * L) { return push_options(L, to_simplifier(L, 1)->get_options()); }
|
||||
static int ro_simplifier_depth(lua_State * L) { lua_pushinteger(L, to_ro_simplifier(L, 1)->get_depth()); return 1; }
|
||||
static int ro_simplifier_context(lua_State * L) { return push_context(L, to_ro_simplifier(L, 1)->get_context()); }
|
||||
static int ro_simplifier_environment(lua_State * L) { return push_environment(L, to_ro_simplifier(L, 1)->get_environment()); }
|
||||
static int ro_simplifier_options(lua_State * L) { return push_options(L, to_ro_simplifier(L, 1)->get_options()); }
|
||||
|
||||
static const struct luaL_Reg simplifier_m[] = {
|
||||
{"__gc", simplifier_gc},
|
||||
{"__call", safe_function<simplifier_apply>},
|
||||
{"clear", safe_function<simplifier_clear>},
|
||||
{"depth", safe_function<simplifier_depth>},
|
||||
{"get_environment", safe_function<simplifier_environment>},
|
||||
{"get_context", safe_function<simplifier_context>},
|
||||
{"get_options", safe_function<simplifier_options>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
static const struct luaL_Reg ro_simplifier_m[] = {
|
||||
{"__gc", ro_simplifier_gc},
|
||||
{"depth", safe_function<ro_simplifier_depth>},
|
||||
{"get_environment", safe_function<ro_simplifier_environment>},
|
||||
{"get_context", safe_function<ro_simplifier_context>},
|
||||
{"get_options", safe_function<ro_simplifier_options>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
static int simplify_core(lua_State * L, ro_shared_environment const & env) {
|
||||
int nargs = lua_gettop(L);
|
||||
expr const & e = to_expr(L, 1);
|
||||
buffer<rewrite_rule_set> rules;
|
||||
get_rewrite_rule_set(L, 2, env, rules);
|
||||
context ctx;
|
||||
options opts;
|
||||
if (nargs >= 4)
|
||||
ctx = to_context(L, 4);
|
||||
if (nargs >= 3)
|
||||
opts = to_options(L, 3);
|
||||
if (nargs >= 5)
|
||||
opts = to_options(L, 5);
|
||||
ctx = to_context(L, 5);
|
||||
auto r = simplify(e, env, ctx, opts, rules.size(), rules.data());
|
||||
push_expr(L, r.first);
|
||||
push_expr(L, r.second);
|
||||
|
@ -1518,13 +1746,44 @@ static int simplify_core(lua_State * L, ro_shared_environment const & env) {
|
|||
|
||||
static int simplify(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs <= 2)
|
||||
if (nargs <= 4)
|
||||
return simplify_core(L, ro_shared_environment(L));
|
||||
else
|
||||
return simplify_core(L, ro_shared_environment(L, 3));
|
||||
return simplify_core(L, ro_shared_environment(L, 4));
|
||||
}
|
||||
|
||||
void open_simplifier(lua_State * L) {
|
||||
luaL_newmetatable(L, simplifier_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, simplifier_m, 0);
|
||||
SET_GLOBAL_FUN(simplifier_pred, "is_simplifier");
|
||||
|
||||
SET_GLOBAL_FUN(mk_simplifier, "simplifier");
|
||||
|
||||
luaL_newmetatable(L, ro_simplifier_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, ro_simplifier_m, 0);
|
||||
SET_GLOBAL_FUN(ro_simplifier_pred, "is_ro_simplifier");
|
||||
|
||||
luaL_newmetatable(L, simplifier_monitor_ptr_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, simplifier_monitor_ptr_m, 0);
|
||||
SET_GLOBAL_FUN(simplifier_monitor_ptr_pred, "is_simplifier_monitor");
|
||||
|
||||
SET_GLOBAL_FUN(mk_simplifier_monitor, "simplifier_monitor");
|
||||
|
||||
lua_newtable(L);
|
||||
SET_ENUM("Unsupported", simplifier_monitor::failure_kind::Unsupported);
|
||||
SET_ENUM("TypeMismatch", simplifier_monitor::failure_kind::TypeMismatch);
|
||||
SET_ENUM("AssumptionNotProved", simplifier_monitor::failure_kind::AssumptionNotProved);
|
||||
SET_ENUM("MissingArgument", simplifier_monitor::failure_kind::MissingArgument);
|
||||
SET_ENUM("PermutationGe", simplifier_monitor::failure_kind::PermutationGe);
|
||||
SET_ENUM("AbstractionBody", simplifier_monitor::failure_kind::AbstractionBody);
|
||||
lua_setglobal(L, "simplifier_failure");
|
||||
|
||||
SET_GLOBAL_FUN(simplify, "simplify");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,38 +5,79 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include "util/lua.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "library/expr_pair.h"
|
||||
#include "library/simplifier/rewrite_rule_set.h"
|
||||
|
||||
namespace lean {
|
||||
class simplifier_monitor;
|
||||
|
||||
/** \brief Simplifier object cell. */
|
||||
class simplifier_cell {
|
||||
friend class simplifier;
|
||||
struct imp;
|
||||
std::unique_ptr<imp> m_ptr;
|
||||
public:
|
||||
simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
|
||||
std::shared_ptr<simplifier_monitor> const & monitor);
|
||||
|
||||
expr_pair operator()(expr const & e, context const & ctx);
|
||||
void clear();
|
||||
|
||||
unsigned get_depth() const;
|
||||
context const & get_context() const;
|
||||
ro_environment const & get_environment() const;
|
||||
options const & get_options() const;
|
||||
};
|
||||
|
||||
/** \brief Reference to simplifier object */
|
||||
class simplifier {
|
||||
friend class ro_simplifier;
|
||||
std::shared_ptr<simplifier_cell> m_ptr;
|
||||
public:
|
||||
simplifier(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs,
|
||||
std::shared_ptr<simplifier_monitor> const & monitor);
|
||||
simplifier_cell * operator->() const { return m_ptr.get(); }
|
||||
simplifier_cell & operator*() const { return *(m_ptr.get()); }
|
||||
expr_pair operator()(expr const & e, context const & ctx) { return (*m_ptr)(e, ctx); }
|
||||
};
|
||||
|
||||
/** \brief Read only reference to simplifier object */
|
||||
class ro_simplifier {
|
||||
std::shared_ptr<simplifier_cell const> m_ptr;
|
||||
public:
|
||||
typedef std::weak_ptr<simplifier_cell const> weak_ref;
|
||||
ro_simplifier(simplifier const & s);
|
||||
ro_simplifier(weak_ref const & s);
|
||||
explicit operator weak_ref() const { return weak_ref(m_ptr); }
|
||||
weak_ref to_weak_ref() const { return weak_ref(m_ptr); }
|
||||
simplifier_cell const * operator->() const { return m_ptr.get(); }
|
||||
simplifier_cell const & operator*() const { return *(m_ptr.get()); }
|
||||
};
|
||||
|
||||
/**
|
||||
\brief Abstract class that specifies the interface for monitoring
|
||||
the behavior of the simplifier.
|
||||
*/
|
||||
class simplifier_monitor {
|
||||
public:
|
||||
virtual ~simplifier_monitor() {}
|
||||
/**
|
||||
\brief This method is invoked to sign that the simplifier is starting to process the expression \c e.
|
||||
|
||||
\remark \c depth is the recursion depth
|
||||
*/
|
||||
virtual void pre_eh(unsigned depth, expr const & e) = 0;
|
||||
virtual void pre_eh(ro_simplifier const & s, expr const & e) = 0;
|
||||
|
||||
/**
|
||||
\brief This method is invoked to sign that \c e has be rewritten into \c new_e with proof \c pr.
|
||||
The proof is none if proof generation is disabled or if \c e and \c new_e are definitionally equal.
|
||||
|
||||
\remark \c depth is the recursion depth
|
||||
*/
|
||||
virtual void step_eh(unsigned depth, expr const & e, expr const & new_e, optional<expr> const & pr) = 0;
|
||||
virtual void step_eh(ro_simplifier const & s, expr const & e, expr const & new_e, optional<expr> const & pr) = 0;
|
||||
/**
|
||||
\brief This method is invoked to sign that \c e has be rewritten into \c new_e using the conditional equation \c ceq.
|
||||
|
||||
\remark \c depth is the recursion depth
|
||||
*/
|
||||
virtual void rewrite_eh(unsigned depth, expr const & e, expr const & new_e, expr const & ceq) = 0;
|
||||
virtual void rewrite_eh(ro_simplifier const & s, expr const & e, expr const & new_e, expr const & ceq, name const & ceq_id) = 0;
|
||||
|
||||
enum class failure_kind { Unsupported, TypeMismatch, AssumptionNotProved, MissingArgument, PermutationGe, AbstractionBody };
|
||||
|
||||
|
@ -44,10 +85,8 @@ public:
|
|||
\brief This method is invoked when the simplifier fails to rewrite an application \c e.
|
||||
\c i is the argument where the simplifier gave up, and \c k is the reason for failure.
|
||||
Two possible values are: Unsupported or TypeMismatch (may happen when simplifying terms that use dependent types).
|
||||
|
||||
\remark \c depth is the recursion depth
|
||||
*/
|
||||
virtual void failed_app_eh(unsigned depth, expr const & e, unsigned i, failure_kind k) = 0;
|
||||
virtual void failed_app_eh(ro_simplifier const & s, expr const & e, unsigned i, failure_kind k) = 0;
|
||||
|
||||
/**
|
||||
\brief This method is invoked when the simplifier fails to apply a conditional equation \c ceq to \c e.
|
||||
|
@ -55,19 +94,15 @@ public:
|
|||
The possible failure values are: AssumptionNotProved (failed to synthesize a proof for an assumption required by \c ceq),
|
||||
MissingArgument (failed to infer one of the arguments needed by the conditional equation), PermutationGe
|
||||
(the conditional equation is a permutation, and the result is not smaller in the term ordering, \c i is irrelevant in this case).
|
||||
|
||||
\remark \c depth is the recursion depth
|
||||
*/
|
||||
virtual void failed_rewrite_eh(unsigned depth, expr const & e, expr const & ceq, unsigned i, failure_kind k) = 0;
|
||||
virtual void failed_rewrite_eh(ro_simplifier const & s, expr const & e, expr const & ceq, name const & ceq_id, unsigned i, failure_kind k) = 0;
|
||||
|
||||
/**
|
||||
\brief This method is invoked when the simplifier fails to simplify an abstraction (Pi or Lambda).
|
||||
The possible failure values are: Unsupported, TypeMismatch, and AbstractionBody (failed to rewrite the body of the abstraction,
|
||||
this may happen when we are using dependent types).
|
||||
|
||||
\remark \c depth is the recursion depth
|
||||
*/
|
||||
virtual void failed_abstraction_eh(unsigned depth, expr const & e, failure_kind k) = 0;
|
||||
virtual void failed_abstraction_eh(ro_simplifier const & s, expr const & e, failure_kind k) = 0;
|
||||
};
|
||||
|
||||
expr_pair simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts,
|
||||
|
|
|
@ -4,7 +4,7 @@ add_rewrite Nat::add_assoc Nat::add_comm Nat::add_left_comm Nat::distributer Nat
|
|||
(*
|
||||
local opts = options({"simplifier", "max_steps"}, 100)
|
||||
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
|
||||
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local t2, pr = simplify(t, "simple", opts)
|
||||
*)
|
||||
|
||||
print "trying again with more steps"
|
||||
|
@ -12,7 +12,7 @@ print "trying again with more steps"
|
|||
(*
|
||||
local opts = options({"simplifier", "max_steps"}, 100000)
|
||||
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
|
||||
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local t2, pr = simplify(t, "simple", opts)
|
||||
print(t)
|
||||
print("====>")
|
||||
print(t2)
|
||||
|
|
|
@ -6,7 +6,7 @@ local opts = options({"simplifier", "single_pass"}, true)
|
|||
local t = parse_lean("f + (c + f + d) + (e * (a + c) + (d + a))")
|
||||
print(t)
|
||||
for i = 1, 10 do
|
||||
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local t2, pr = simplify(t, "simple", opts)
|
||||
print("i: " .. i .. " ====>")
|
||||
print(t2)
|
||||
if t == t2 then
|
||||
|
|
|
@ -10,7 +10,7 @@ variables a b c : Nat
|
|||
(*
|
||||
local opts = options({"simplifier", "contextual"}, false)
|
||||
local t = parse_lean([[a = 1 ∧ (¬ b = 0 ∨ c ≠ 0 ∨ b + c > a)]])
|
||||
local s, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local s, pr = simplify(t, "simple", opts)
|
||||
print(s)
|
||||
print(pr)
|
||||
print(get_environment():type_check(pr))
|
||||
|
|
|
@ -4,7 +4,7 @@ variables a b : Nat
|
|||
(*
|
||||
local opts = options({"simplifier", "contextual"}, false)
|
||||
local t = parse_lean('λ x, a = a → x = a')
|
||||
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local t2, pr = simplify(t, "simple", opts)
|
||||
print(t2)
|
||||
print(pr)
|
||||
get_environment():type_check(pr)
|
||||
|
@ -13,7 +13,7 @@ get_environment():type_check(pr)
|
|||
(*
|
||||
local opts = options({"simplifier", "contextual"}, false)
|
||||
local t = parse_lean('λ x, x = a → x = x')
|
||||
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local t2, pr = simplify(t, "simple", opts)
|
||||
print(t2)
|
||||
print(pr)
|
||||
get_environment():type_check(pr)
|
||||
|
@ -22,7 +22,7 @@ get_environment():type_check(pr)
|
|||
(*
|
||||
local opts = options({"simplifier", "contextual"}, false)
|
||||
local t = parse_lean('λ x, x = a + 0 → a = a')
|
||||
local t2, pr = simplify(t, "simple", get_environment(), context(), opts)
|
||||
local t2, pr = simplify(t, "simple", opts)
|
||||
print(t2)
|
||||
print(pr)
|
||||
*)
|
||||
|
|
|
@ -15,7 +15,7 @@ show(simplify(t4))
|
|||
(*
|
||||
local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false)
|
||||
local t1 = parse_lean("double (double 2) + 1 ≥ 3")
|
||||
show(simplify(t1, 'default', get_environment(), context(), opt))
|
||||
show(simplify(t1, 'default', opt))
|
||||
*)
|
||||
|
||||
set_opaque Nat::ge false
|
||||
|
@ -26,7 +26,7 @@ add_rewrite Nat::distributel
|
|||
(*
|
||||
local opt = options({"simplifier", "unfold"}, true, {"simplifier", "eval"}, false)
|
||||
local t1 = parse_lean("2 * double (double 2) + 1 ≥ 3")
|
||||
show(simplify(t1, 'default', get_environment(), context(), opt))
|
||||
show(simplify(t1, 'default', opt))
|
||||
*)
|
||||
|
||||
variables a b c d : Nat
|
||||
|
|
29
tests/lean/simp31.lean
Normal file
29
tests/lean/simp31.lean
Normal file
|
@ -0,0 +1,29 @@
|
|||
rewrite_set simple
|
||||
add_rewrite Nat::add_comm Nat::add_left_comm Nat::add_assoc Nat::add_zeror : simple
|
||||
variables a b c : Nat
|
||||
(*
|
||||
function indent(s)
|
||||
for i = 1, s:depth()-1 do
|
||||
io.write(" ")
|
||||
end
|
||||
end
|
||||
local m = simplifier_monitor(function(s, e)
|
||||
print("Visit, depth: " .. s:depth() .. ", " .. tostring(e))
|
||||
end,
|
||||
function(s, e, new_e, pr)
|
||||
print("Step: " .. tostring(e) .. " ===> " .. tostring(new_e))
|
||||
end,
|
||||
function(s, e, new_e, ceq, ceq_id)
|
||||
print("Rewrite using: " .. tostring(ceq_id))
|
||||
print(" " .. tostring(e) .. " ===> " .. tostring(new_e))
|
||||
end
|
||||
)
|
||||
local s = simplifier("simple", options(), m)
|
||||
local t = parse_lean('a + (b + 0) + a')
|
||||
print(t)
|
||||
print("=====>")
|
||||
local t2, pr = s(t)
|
||||
print(t2)
|
||||
print(pr)
|
||||
get_environment():type_check(pr)
|
||||
*)
|
42
tests/lean/simp31.lean.expected.out
Normal file
42
tests/lean/simp31.lean.expected.out
Normal file
|
@ -0,0 +1,42 @@
|
|||
Set: pp::colors
|
||||
Set: pp::unicode
|
||||
Assumed: a
|
||||
Assumed: b
|
||||
Assumed: c
|
||||
a + (b + 0) + a
|
||||
=====>
|
||||
Visit, depth: 1, a + (b + 0) + a
|
||||
Visit, depth: 2, Nat::add
|
||||
Visit, depth: 2, a + (b + 0)
|
||||
Visit, depth: 3, Nat::add
|
||||
Visit, depth: 3, a
|
||||
Step: a ===> a
|
||||
Visit, depth: 3, b + 0
|
||||
Visit, depth: 4, Nat::add
|
||||
Visit, depth: 4, b
|
||||
Step: b ===> b
|
||||
Visit, depth: 4, 0
|
||||
Rewrite using: Nat::add_zeror
|
||||
b + 0 ===> b
|
||||
Step: b + 0 ===> b
|
||||
Visit, depth: 3, a + b
|
||||
Visit, depth: 4, Nat::add
|
||||
Step: a + b ===> a + b
|
||||
Step: a + (b + 0) ===> a + b
|
||||
Rewrite using: Nat::add_assoc
|
||||
a + b + a ===> a + (b + a)
|
||||
Visit, depth: 2, a + (b + a)
|
||||
Visit, depth: 3, Nat::add
|
||||
Visit, depth: 3, b + a
|
||||
Visit, depth: 4, Nat::add
|
||||
Rewrite using: Nat::add_comm
|
||||
b + a ===> a + b
|
||||
Step: b + a ===> a + b
|
||||
Visit, depth: 3, a + (a + b)
|
||||
Visit, depth: 4, Nat::add
|
||||
Step: a + (a + b) ===> a + (a + b)
|
||||
Step: a + (b + a) ===> a + (a + b)
|
||||
Step: a + (b + 0) + a ===> a + (a + b)
|
||||
a + (a + b)
|
||||
trans (trans (congr1 a (congr2 Nat::add (congr2 (Nat::add a) (Nat::add_zeror b)))) (Nat::add_assoc a b a))
|
||||
(congr2 (Nat::add a) (Nat::add_comm b a))
|
Loading…
Reference in a new issue