diff --git a/src/kernel/max_sharing.cpp b/src/kernel/max_sharing.cpp index 7b3f8eb2f..c49d15062 100644 --- a/src/kernel/max_sharing.cpp +++ b/src/kernel/max_sharing.cpp @@ -25,13 +25,15 @@ class max_sharing_functor { public: expr apply(expr const & a) { - if (a.raw()->max_shared()) - return a; auto r = m_cache.find(a); if (r != m_cache.end()) { lean_assert((*r).raw()->max_shared()); return *r; } + if (a.raw()->max_shared()) { + m_cache.insert(a); + return a; + } switch (a.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: cache(a); diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index c6d7a1e6e..bdfe92760 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -187,6 +187,16 @@ void tst6() { } } +void tst7() { + expr f = constant(name("f")); + expr v = var(0); + expr a1 = max_sharing(app(f, v, v)); + expr a2 = max_sharing(app(f, v, v)); + lean_assert(!eqp(a1, a2)); + expr b = max_sharing(app(f, a1, a2)); + lean_assert(eqp(get_arg(b, 1), get_arg(b, 2))); +} + int main() { continue_on_violation(true); std::cout << "sizeof(expr): " << sizeof(expr) << "\n"; @@ -197,6 +207,7 @@ int main() { tst4(); tst5(); tst6(); + tst7(); std::cout << "done" << "\n"; return has_violations() ? 1 : 0; }