Cache results of the normalizer. Add example that demonstrates the exponential performance improvement.
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
984c4149fa
commit
2d74ff5fe0
4 changed files with 90 additions and 28 deletions
|
@ -9,6 +9,7 @@ Author: Leonardo de Moura
|
|||
#include "expr.h"
|
||||
#include "context.h"
|
||||
#include "environment.h"
|
||||
#include "scoped_map.h"
|
||||
#include "builtin.h"
|
||||
#include "free_vars.h"
|
||||
#include "list.h"
|
||||
|
@ -28,6 +29,7 @@ class svalue {
|
|||
expr m_expr;
|
||||
value_stack m_ctx;
|
||||
public:
|
||||
svalue() {}
|
||||
explicit svalue(expr const & e): m_kind(svalue_kind::Expr), m_expr(e) {}
|
||||
explicit svalue(unsigned k): m_kind(svalue_kind::BoundedVar), m_bvar(k) {}
|
||||
svalue(expr const & e, value_stack const & c):m_kind(svalue_kind::Closure), m_expr(e), m_ctx(c) { lean_assert(is_lambda(e)); }
|
||||
|
@ -52,8 +54,11 @@ value_stack extend(value_stack const & s, svalue const & v) { return cons(v, s);
|
|||
|
||||
/** \brief Expression normalizer. */
|
||||
class normalize_fn {
|
||||
typedef scoped_map<expr, svalue, expr_hash, expr_eqp> cache;
|
||||
|
||||
environment const & m_env;
|
||||
context const & m_ctx;
|
||||
cache m_cache;
|
||||
|
||||
svalue lookup(value_stack const & s, unsigned i, unsigned k) {
|
||||
unsigned j = i;
|
||||
|
@ -127,20 +132,33 @@ class normalize_fn {
|
|||
/** \brief Normalize the expression \c a in a context composed of stack \c s and \c k binders. */
|
||||
svalue normalize(expr const & a, value_stack const & s, unsigned k) {
|
||||
lean_trace("normalize", tout << "Normalize, k: " << k << "\n" << a << "\n";);
|
||||
|
||||
bool shared = false;
|
||||
if (is_shared(a)) {
|
||||
shared = true;
|
||||
auto it = m_cache.find(a);
|
||||
if (it != m_cache.end())
|
||||
return it->second;
|
||||
}
|
||||
|
||||
svalue r;
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var:
|
||||
return lookup(s, var_idx(a), k);
|
||||
r = lookup(s, var_idx(a), k);
|
||||
break;
|
||||
case expr_kind::Constant: {
|
||||
named_object const & obj = m_env.get_object(const_name(a));
|
||||
if (obj.is_definition() && !obj.is_opaque()) {
|
||||
return normalize(obj.get_value(), value_stack(), 0);
|
||||
r = normalize(obj.get_value(), value_stack(), 0);
|
||||
}
|
||||
else {
|
||||
return svalue(a);
|
||||
r = svalue(a);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case expr_kind::Type: case expr_kind::Value:
|
||||
return svalue(a);
|
||||
r = svalue(a);
|
||||
break;
|
||||
case expr_kind::App: {
|
||||
svalue f = normalize(arg(a, 0), s, k);
|
||||
unsigned i = 1;
|
||||
|
@ -150,10 +168,15 @@ class normalize_fn {
|
|||
// beta reduction
|
||||
expr const & fv = to_expr(f);
|
||||
lean_trace("normalize", tout << "beta reduction...\n" << fv << "\n";);
|
||||
{
|
||||
cache::mk_scope sc(m_cache);
|
||||
value_stack new_s = extend(stack_of(f), normalize(arg(a, i), s, k));
|
||||
f = normalize(abst_body(fv), new_s, k);
|
||||
if (i == n - 1)
|
||||
return f;
|
||||
}
|
||||
if (i == n - 1) {
|
||||
r = f;
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
} else {
|
||||
buffer<expr> new_args;
|
||||
|
@ -162,37 +185,55 @@ class normalize_fn {
|
|||
for (; i < n; i++)
|
||||
new_args.push_back(reify(normalize(arg(a, i), s, k), k));
|
||||
if (is_value(new_f)) {
|
||||
expr r;
|
||||
if (to_value(new_f).normalize(new_args.size(), new_args.data(), r))
|
||||
return svalue(r);
|
||||
}
|
||||
return svalue(mk_app(new_args.size(), new_args.data()));
|
||||
expr m;
|
||||
if (to_value(new_f).normalize(new_args.size(), new_args.data(), m)) {
|
||||
r = svalue(m);
|
||||
break;
|
||||
}
|
||||
}
|
||||
r = svalue(mk_app(new_args.size(), new_args.data()));
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case expr_kind::Eq: {
|
||||
expr new_l = reify(normalize(eq_lhs(a), s, k), k);
|
||||
expr new_r = reify(normalize(eq_rhs(a), s, k), k);
|
||||
if (new_l == new_r) {
|
||||
return svalue(mk_bool_value(true));
|
||||
} else if (is_value(new_l) && is_value(new_r)) {
|
||||
return svalue(mk_bool_value(false));
|
||||
expr new_lhs = reify(normalize(eq_lhs(a), s, k), k);
|
||||
expr new_rhs = reify(normalize(eq_rhs(a), s, k), k);
|
||||
if (new_lhs == new_rhs) {
|
||||
r = svalue(mk_bool_value(true));
|
||||
} else if (is_value(new_lhs) && is_value(new_rhs)) {
|
||||
r = svalue(mk_bool_value(false));
|
||||
} else {
|
||||
return svalue(mk_eq(new_l, new_r));
|
||||
r = svalue(mk_eq(new_lhs, new_rhs));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case expr_kind::Lambda:
|
||||
return svalue(a, s);
|
||||
r = svalue(a, s);
|
||||
break;
|
||||
case expr_kind::Pi: {
|
||||
expr new_t = reify(normalize(abst_domain(a), s, k), k);
|
||||
expr new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1);
|
||||
return svalue(mk_pi(abst_name(a), new_t, new_b));
|
||||
expr new_b;
|
||||
{
|
||||
cache::mk_scope sc(m_cache);
|
||||
new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1);
|
||||
}
|
||||
case expr_kind::Let:
|
||||
return normalize(let_body(a), extend(s, normalize(let_value(a), s, k)), k+1);
|
||||
r = svalue(mk_pi(abst_name(a), new_t, new_b));
|
||||
break;
|
||||
}
|
||||
lean_unreachable();
|
||||
return svalue(a);
|
||||
case expr_kind::Let: {
|
||||
svalue v = normalize(let_value(a), s, k);
|
||||
{
|
||||
cache::mk_scope sc(m_cache);
|
||||
r = normalize(let_body(a), extend(s, v), k+1);
|
||||
}
|
||||
break;
|
||||
}}
|
||||
if (shared) {
|
||||
m_cache.insert(a, r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
|
@ -125,7 +125,7 @@ struct infer_type_fn {
|
|||
lean_trace("type_check", tout << "infer type\n" << e << "\n" << ctx << "\n";);
|
||||
|
||||
bool shared = false;
|
||||
if (true && is_shared(e)) {
|
||||
if (is_shared(e)) {
|
||||
shared = true;
|
||||
auto it = m_cache.find(e);
|
||||
if (it != m_cache.end())
|
||||
|
|
|
@ -184,6 +184,27 @@ static void tst10() {
|
|||
std::cout << env.get_object("simp_eq").pp(env) << "\n";
|
||||
}
|
||||
|
||||
static void tst11() {
|
||||
environment env = mk_toplevel();
|
||||
env.add_var("f", Int >> (Int >> Int));
|
||||
env.add_var("a", Int);
|
||||
unsigned n = 1000;
|
||||
expr f = Const("f");
|
||||
expr a = Const("a");
|
||||
expr t1 = f(a,a);
|
||||
expr b = Const("a");
|
||||
expr t2 = f(a,a);
|
||||
expr t3 = f(b,b);
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
t1 = f(t1,t1);
|
||||
t2 = mk_let("x", t2, f(Var(0), Var(0)));
|
||||
t3 = f(t3,t3);
|
||||
}
|
||||
lean_assert(t1 != t2);
|
||||
env.add_theorem("eqs1", Eq(t1,t2), Refl(Int, t1));
|
||||
env.add_theorem("eqs2", Eq(t1,t3), Refl(Int, t1));
|
||||
}
|
||||
|
||||
int main() {
|
||||
tst1();
|
||||
tst2();
|
||||
|
@ -195,5 +216,6 @@ int main() {
|
|||
tst8();
|
||||
tst9();
|
||||
tst10();
|
||||
tst11();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
|
|
@ -98,7 +98,6 @@ public:
|
|||
m_actions.push_back(std::make_pair(action_kind::Replace, *it));
|
||||
it->second = v;
|
||||
}
|
||||
lean_assert(m_map.find(k)->second == v);
|
||||
}
|
||||
|
||||
void insert(value_type const & p) {
|
||||
|
|
Loading…
Reference in a new issue