feat(library/simplifier): memoize intermediate results

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-20 17:03:44 -08:00
parent 97ead50a3e
commit cd19d4da01
3 changed files with 89 additions and 37 deletions

View file

@ -7,12 +7,14 @@ Author: Leonardo de Moura
#include <utility>
#include <vector>
#include "util/flet.h"
#include "util/freset.h"
#include "util/interrupt.h"
#include "kernel/type_checker.h"
#include "kernel/free_vars.h"
#include "kernel/instantiate.h"
#include "kernel/normalizer.h"
#include "kernel/kernel.h"
#include "kernel/max_sharing.h"
#include "library/heq_decls.h"
#include "library/kernel_bindings.h"
#include "library/expr_pair.h"
@ -52,6 +54,10 @@ Author: Leonardo de Moura
#define LEAN_SIMPLIFIER_CONDITIONAL true
#endif
#ifndef LEAN_SIMPLIFIER_MEMOIZE
#define LEAN_SIMPLIFIER_MEMOIZE true
#endif
#ifndef LEAN_SIMPLIFIER_MAX_STEPS
#define LEAN_SIMPLIFIER_MAX_STEPS std::numeric_limits<unsigned>::max()
#endif
@ -65,6 +71,7 @@ static name g_simplifier_eta {"simplifier", "eta"};
static name g_simplifier_eval {"simplifier", "eval"};
static name g_simplifier_unfold {"simplifier", "unfold"};
static name g_simplifier_conditional {"simplifier", "conditional"};
static name g_simplifier_memoize {"simplifier", "memoize"};
static name g_simplifier_max_steps {"simplifier", "max_steps"};
RegisterBoolOption(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS, "(simplifier) generate proofs");
@ -75,6 +82,7 @@ RegisterBoolOption(g_simplifier_eta, LEAN_SIMPLIFIER_ETA, "(simplifier) use eta-
RegisterBoolOption(g_simplifier_eval, LEAN_SIMPLIFIER_EVAL, "(simplifier) apply reductions based on computation");
RegisterBoolOption(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD, "(simplifier) unfolds non-opaque definitions");
RegisterBoolOption(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL, "(simplifier) conditional rewriting");
RegisterBoolOption(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE, "(simplifier) memoize/cache intermediate results");
RegisterUnsignedOption(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS, "(simplifier) maximum number of steps");
bool get_simplifier_proofs(options const & opts) {
@ -101,39 +109,14 @@ bool get_simplifier_unfold(options const & opts) {
bool get_simplifier_conditional(options const & opts) {
return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL);
}
bool get_simplifier_memoize(options const & opts) {
return opts.get_bool(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE);
}
unsigned get_simplifier_max_steps(options const & opts) {
return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS);
}
class simplifier_fn {
typedef std::vector<rewrite_rule_set> rule_sets;
ro_environment m_env;
type_checker m_tc;
bool m_has_heq;
context m_ctx;
rule_sets m_rule_sets;
// Configuration
bool m_proofs_enabled;
bool m_contextual;
bool m_single_pass;
bool m_beta;
bool m_eta;
bool m_eval;
bool m_unfold;
bool m_conditional;
unsigned m_max_steps;
struct match_fn {
simplifier_fn & m_simp;
match_fn(simplifier_fn & s):m_simp(s) {}
bool operator()(rewrite_rule const & rule) const {
return m_simp.match(rule);
}
};
match_fn m_match_fn;
struct result {
expr m_out; // the result of a simplification step
optional<expr> m_proof; // a proof that the result is equal to the input (when m_proofs_enabled)
@ -146,9 +129,42 @@ class simplifier_fn {
m_out(out), m_proof(pr), m_heq_proof(heq_proof) {}
};
typedef std::vector<rewrite_rule_set> rule_sets;
typedef expr_map<result> cache;
ro_environment m_env;
type_checker m_tc;
bool m_has_heq;
context m_ctx;
rule_sets m_rule_sets;
cache m_cache;
max_sharing_fn m_max_sharing;
// Configuration
bool m_proofs_enabled;
bool m_contextual;
bool m_single_pass;
bool m_beta;
bool m_eta;
bool m_eval;
bool m_unfold;
bool m_conditional;
bool m_memoize;
unsigned m_max_steps;
struct match_fn {
simplifier_fn & m_simp;
match_fn(simplifier_fn & s):m_simp(s) {}
bool operator()(rewrite_rule const & rule) const {
return m_simp.match(rule);
}
};
match_fn m_match_fn;
struct set_context {
flet<context> m_set;
set_context(simplifier_fn & s, context const & new_ctx):m_set(s.m_ctx, new_ctx) {}
freset<cache> m_reset_cache;
set_context(simplifier_fn & s, context const & new_ctx):m_set(s.m_ctx, new_ctx), m_reset_cache(s.m_cache) {}
};
/**
@ -676,18 +692,35 @@ class simplifier_fn {
}
}
result simplify(expr const & e) {
result save(expr const & e, result const & r) {
if (m_memoize) {
result new_r(m_max_sharing(r.m_out), r.m_proof, r.m_heq_proof);
m_cache.insert(mk_pair(e, new_r));
return new_r;
} else {
return r;
}
}
result simplify(expr e) {
check_system("simplifier");
if (m_memoize) {
e = m_max_sharing(e);
auto it = m_cache.find(e);
if (it != m_cache.end()) {
return it->second;
}
}
switch (e.kind()) {
case expr_kind::Var: return simplify_var(e);
case expr_kind::Constant: return simplify_constant(e);
case expr_kind::Var: return save(e, simplify_var(e));
case expr_kind::Constant: return save(e, simplify_constant(e));
case expr_kind::Type:
case expr_kind::MetaVar:
case expr_kind::Value: return result(e);
case expr_kind::App: return simplify_app(e);
case expr_kind::Lambda: return simplify_lambda(e);
case expr_kind::Pi: return simplify_pi(e);
case expr_kind::Let: return simplify(instantiate(let_body(e), let_value(e)));
case expr_kind::Value: return save(e, result(e));
case expr_kind::App: return save(e, simplify_app(e));
case expr_kind::Lambda: return save(e, simplify_lambda(e));
case expr_kind::Pi: return save(e, simplify_pi(e));
case expr_kind::Let: return save(e, simplify(instantiate(let_body(e), let_value(e))));
}
lean_unreachable();
}
@ -701,6 +734,7 @@ class simplifier_fn {
m_eval = get_simplifier_eval(o);
m_unfold = get_simplifier_unfold(o);
m_conditional = get_simplifier_conditional(o);
m_memoize = get_simplifier_memoize(o);
m_max_steps = get_simplifier_max_steps(o);
}

9
tests/lean/simp9.lean Normal file
View file

@ -0,0 +1,9 @@
variables a b c d e f : Nat
rewrite_set simple
add_rewrite Nat::mul_assoc Nat::mul_comm Nat::mul_left_comm Nat::add_assoc Nat::add_comm Nat::add_left_comm
Nat::distributer Nat::distributel : simple
(*
local t = parse_lean("(a + b) * (c + d) * (e + f) * (a + b) * (c + d) * (e + f)")
local t2, pr = simplify(t, "simple")
print(t)
*)

View file

@ -0,0 +1,9 @@
Set: pp::colors
Set: pp::unicode
Assumed: a
Assumed: b
Assumed: c
Assumed: d
Assumed: e
Assumed: f
(a + b) * (c + d) * (e + f) * (a + b) * (c + d) * (e + f)