feat(library/simplifier): memoize intermediate results
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
97ead50a3e
commit
cd19d4da01
3 changed files with 89 additions and 37 deletions
|
@ -7,12 +7,14 @@ Author: Leonardo de Moura
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
#include "util/flet.h"
|
||||
#include "util/freset.h"
|
||||
#include "util/interrupt.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/free_vars.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/normalizer.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "kernel/max_sharing.h"
|
||||
#include "library/heq_decls.h"
|
||||
#include "library/kernel_bindings.h"
|
||||
#include "library/expr_pair.h"
|
||||
|
@ -52,6 +54,10 @@ Author: Leonardo de Moura
|
|||
#define LEAN_SIMPLIFIER_CONDITIONAL true
|
||||
#endif
|
||||
|
||||
#ifndef LEAN_SIMPLIFIER_MEMOIZE
|
||||
#define LEAN_SIMPLIFIER_MEMOIZE true
|
||||
#endif
|
||||
|
||||
#ifndef LEAN_SIMPLIFIER_MAX_STEPS
|
||||
#define LEAN_SIMPLIFIER_MAX_STEPS std::numeric_limits<unsigned>::max()
|
||||
#endif
|
||||
|
@ -65,6 +71,7 @@ static name g_simplifier_eta {"simplifier", "eta"};
|
|||
static name g_simplifier_eval {"simplifier", "eval"};
|
||||
static name g_simplifier_unfold {"simplifier", "unfold"};
|
||||
static name g_simplifier_conditional {"simplifier", "conditional"};
|
||||
static name g_simplifier_memoize {"simplifier", "memoize"};
|
||||
static name g_simplifier_max_steps {"simplifier", "max_steps"};
|
||||
|
||||
RegisterBoolOption(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS, "(simplifier) generate proofs");
|
||||
|
@ -75,6 +82,7 @@ RegisterBoolOption(g_simplifier_eta, LEAN_SIMPLIFIER_ETA, "(simplifier) use eta-
|
|||
RegisterBoolOption(g_simplifier_eval, LEAN_SIMPLIFIER_EVAL, "(simplifier) apply reductions based on computation");
|
||||
RegisterBoolOption(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD, "(simplifier) unfolds non-opaque definitions");
|
||||
RegisterBoolOption(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL, "(simplifier) conditional rewriting");
|
||||
RegisterBoolOption(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE, "(simplifier) memoize/cache intermediate results");
|
||||
RegisterUnsignedOption(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS, "(simplifier) maximum number of steps");
|
||||
|
||||
bool get_simplifier_proofs(options const & opts) {
|
||||
|
@ -101,39 +109,14 @@ bool get_simplifier_unfold(options const & opts) {
|
|||
bool get_simplifier_conditional(options const & opts) {
|
||||
return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL);
|
||||
}
|
||||
bool get_simplifier_memoize(options const & opts) {
|
||||
return opts.get_bool(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE);
|
||||
}
|
||||
unsigned get_simplifier_max_steps(options const & opts) {
|
||||
return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS);
|
||||
}
|
||||
|
||||
class simplifier_fn {
|
||||
typedef std::vector<rewrite_rule_set> rule_sets;
|
||||
ro_environment m_env;
|
||||
type_checker m_tc;
|
||||
bool m_has_heq;
|
||||
context m_ctx;
|
||||
rule_sets m_rule_sets;
|
||||
|
||||
// Configuration
|
||||
bool m_proofs_enabled;
|
||||
bool m_contextual;
|
||||
bool m_single_pass;
|
||||
bool m_beta;
|
||||
bool m_eta;
|
||||
bool m_eval;
|
||||
bool m_unfold;
|
||||
bool m_conditional;
|
||||
unsigned m_max_steps;
|
||||
|
||||
struct match_fn {
|
||||
simplifier_fn & m_simp;
|
||||
match_fn(simplifier_fn & s):m_simp(s) {}
|
||||
bool operator()(rewrite_rule const & rule) const {
|
||||
return m_simp.match(rule);
|
||||
}
|
||||
};
|
||||
|
||||
match_fn m_match_fn;
|
||||
|
||||
struct result {
|
||||
expr m_out; // the result of a simplification step
|
||||
optional<expr> m_proof; // a proof that the result is equal to the input (when m_proofs_enabled)
|
||||
|
@ -146,9 +129,42 @@ class simplifier_fn {
|
|||
m_out(out), m_proof(pr), m_heq_proof(heq_proof) {}
|
||||
};
|
||||
|
||||
typedef std::vector<rewrite_rule_set> rule_sets;
|
||||
typedef expr_map<result> cache;
|
||||
ro_environment m_env;
|
||||
type_checker m_tc;
|
||||
bool m_has_heq;
|
||||
context m_ctx;
|
||||
rule_sets m_rule_sets;
|
||||
cache m_cache;
|
||||
max_sharing_fn m_max_sharing;
|
||||
|
||||
// Configuration
|
||||
bool m_proofs_enabled;
|
||||
bool m_contextual;
|
||||
bool m_single_pass;
|
||||
bool m_beta;
|
||||
bool m_eta;
|
||||
bool m_eval;
|
||||
bool m_unfold;
|
||||
bool m_conditional;
|
||||
bool m_memoize;
|
||||
unsigned m_max_steps;
|
||||
|
||||
struct match_fn {
|
||||
simplifier_fn & m_simp;
|
||||
match_fn(simplifier_fn & s):m_simp(s) {}
|
||||
bool operator()(rewrite_rule const & rule) const {
|
||||
return m_simp.match(rule);
|
||||
}
|
||||
};
|
||||
|
||||
match_fn m_match_fn;
|
||||
|
||||
struct set_context {
|
||||
flet<context> m_set;
|
||||
set_context(simplifier_fn & s, context const & new_ctx):m_set(s.m_ctx, new_ctx) {}
|
||||
freset<cache> m_reset_cache;
|
||||
set_context(simplifier_fn & s, context const & new_ctx):m_set(s.m_ctx, new_ctx), m_reset_cache(s.m_cache) {}
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -676,18 +692,35 @@ class simplifier_fn {
|
|||
}
|
||||
}
|
||||
|
||||
result simplify(expr const & e) {
|
||||
result save(expr const & e, result const & r) {
|
||||
if (m_memoize) {
|
||||
result new_r(m_max_sharing(r.m_out), r.m_proof, r.m_heq_proof);
|
||||
m_cache.insert(mk_pair(e, new_r));
|
||||
return new_r;
|
||||
} else {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
result simplify(expr e) {
|
||||
check_system("simplifier");
|
||||
if (m_memoize) {
|
||||
e = m_max_sharing(e);
|
||||
auto it = m_cache.find(e);
|
||||
if (it != m_cache.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: return simplify_var(e);
|
||||
case expr_kind::Constant: return simplify_constant(e);
|
||||
case expr_kind::Var: return save(e, simplify_var(e));
|
||||
case expr_kind::Constant: return save(e, simplify_constant(e));
|
||||
case expr_kind::Type:
|
||||
case expr_kind::MetaVar:
|
||||
case expr_kind::Value: return result(e);
|
||||
case expr_kind::App: return simplify_app(e);
|
||||
case expr_kind::Lambda: return simplify_lambda(e);
|
||||
case expr_kind::Pi: return simplify_pi(e);
|
||||
case expr_kind::Let: return simplify(instantiate(let_body(e), let_value(e)));
|
||||
case expr_kind::Value: return save(e, result(e));
|
||||
case expr_kind::App: return save(e, simplify_app(e));
|
||||
case expr_kind::Lambda: return save(e, simplify_lambda(e));
|
||||
case expr_kind::Pi: return save(e, simplify_pi(e));
|
||||
case expr_kind::Let: return save(e, simplify(instantiate(let_body(e), let_value(e))));
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
@ -701,6 +734,7 @@ class simplifier_fn {
|
|||
m_eval = get_simplifier_eval(o);
|
||||
m_unfold = get_simplifier_unfold(o);
|
||||
m_conditional = get_simplifier_conditional(o);
|
||||
m_memoize = get_simplifier_memoize(o);
|
||||
m_max_steps = get_simplifier_max_steps(o);
|
||||
}
|
||||
|
||||
|
|
9
tests/lean/simp9.lean
Normal file
9
tests/lean/simp9.lean
Normal file
|
@ -0,0 +1,9 @@
|
|||
variables a b c d e f : Nat
|
||||
rewrite_set simple
|
||||
add_rewrite Nat::mul_assoc Nat::mul_comm Nat::mul_left_comm Nat::add_assoc Nat::add_comm Nat::add_left_comm
|
||||
Nat::distributer Nat::distributel : simple
|
||||
(*
|
||||
local t = parse_lean("(a + b) * (c + d) * (e + f) * (a + b) * (c + d) * (e + f)")
|
||||
local t2, pr = simplify(t, "simple")
|
||||
print(t)
|
||||
*)
|
9
tests/lean/simp9.lean.expected.out
Normal file
9
tests/lean/simp9.lean.expected.out
Normal file
|
@ -0,0 +1,9 @@
|
|||
Set: pp::colors
|
||||
Set: pp::unicode
|
||||
Assumed: a
|
||||
Assumed: b
|
||||
Assumed: c
|
||||
Assumed: d
|
||||
Assumed: e
|
||||
Assumed: f
|
||||
(a + b) * (c + d) * (e + f) * (a + b) * (c + d) * (e + f)
|
Loading…
Add table
Reference in a new issue