diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index bb782ac91..8a9be14ca 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -482,9 +482,6 @@ unsigned get_weight(expr const & e) { lean_unreachable(); // LCOV_EXCL_LINE } -bool operator==(expr const & a, expr const & b) { return expr_eq_fn()(a, b); } -bool is_bi_equal(expr const & a, expr const & b) { return expr_eq_fn(true)(a, b); } - expr copy_tag(expr const & e, expr && new_e) { tag t = e.get_tag(); if (t != nulltag) diff --git a/src/kernel/expr.h b/src/kernel/expr.h index f0908aa71..38474d086 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -25,6 +25,7 @@ Author: Leonardo de Moura #include "kernel/level.h" #include "kernel/formatter.h" #include "kernel/extension_context.h" +#include "kernel/expr_eq_fn.h" namespace lean { // Tags are used by frontends to mark expressions. They are automatically propagated by @@ -148,16 +149,6 @@ public: expr copy_tag(expr const & e, expr && new_e); -// ======================================= -// Structural equality -/** \brief Binder information is ignored in the following predicate */ -bool operator==(expr const & a, expr const & b); -inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b); } -/** \brief Similar to ==, but it also compares binder information */ -bool is_bi_equal(expr const & a, expr const & b); -struct is_bi_equal_proc { bool operator()(expr const & e1, expr const & e2) const { return is_bi_equal(e1, e2); } }; -// ======================================= - SPECIALIZE_OPTIONAL_FOR_SMART_PTR(expr) inline optional none_expr() { return optional(); } diff --git a/src/kernel/expr_eq_fn.cpp b/src/kernel/expr_eq_fn.cpp index ecfe444cb..6d091b4dd 100644 --- a/src/kernel/expr_eq_fn.cpp +++ b/src/kernel/expr_eq_fn.cpp @@ -4,67 +4,116 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ -#include "kernel/expr_eq_fn.h" +#include +#include +#include "util/interrupt.h" +#include "util/thread.h" +#include "kernel/expr.h" +#include "kernel/expr_sets.h" -#ifndef LEAN_EQ_CACHE_THRESHOLD -#define LEAN_EQ_CACHE_THRESHOLD 4 +#ifndef LEAN_EQ_CACHE_CAPACITY +#define LEAN_EQ_CACHE_CAPACITY 1024*8 #endif namespace lean { -bool expr_eq_fn::apply(expr const & a, expr const & b) { - if (is_eqp(a, b)) return true; - if (a.hash() != b.hash()) return false; - if (a.kind() != b.kind()) return false; - if (is_var(a)) return var_idx(a) == var_idx(b); - if (m_counter >= LEAN_EQ_CACHE_THRESHOLD && is_shared(a) && is_shared(b)) { - auto p = std::make_pair(a.raw(), b.raw()); - if (!m_eq_visited) - m_eq_visited.reset(new expr_cell_pair_set); - if (m_eq_visited->find(p) != m_eq_visited->end()) +struct eq_cache { + struct entry { + expr_ptr m_a; + expr_ptr m_b; + entry():m_a(nullptr), m_b(nullptr) {} + }; + unsigned m_capacity; + std::vector m_cache; + std::vector m_used; + eq_cache():m_capacity(LEAN_EQ_CACHE_CAPACITY), m_cache(LEAN_EQ_CACHE_CAPACITY) {} + + bool check(expr const & a, expr const & b) { + unsigned i = hash(a.hash_alloc(), b.hash_alloc()) % m_capacity; + if (m_cache[i].m_a == a.raw() && m_cache[i].m_b == b.raw()) { return true; - m_eq_visited->insert(p); - } - check_system("expression equality test"); - switch (a.kind()) { - case expr_kind::Var: - lean_unreachable(); // LCOV_EXCL_LINE - case expr_kind::Constant: - return - const_name(a) == const_name(b) && - compare(const_levels(a), const_levels(b), [](level const & l1, level const & l2) { return l1 == l2; }); - case expr_kind::Meta: - return - mlocal_name(a) == mlocal_name(b) && - apply(mlocal_type(a), mlocal_type(b)); - case expr_kind::Local: - return - mlocal_name(a) == mlocal_name(b) && - apply(mlocal_type(a), mlocal_type(b)) && - (!m_compare_binder_info || local_pp_name(a) == local_pp_name(b)) && - (!m_compare_binder_info || local_info(a) == local_info(b)); - case expr_kind::App: - m_counter++; - return - apply(app_fn(a), app_fn(b)) && - apply(app_arg(a), app_arg(b)); - case expr_kind::Lambda: case expr_kind::Pi: - m_counter++; - return - apply(binding_domain(a), binding_domain(b)) && - apply(binding_body(a), binding_body(b)) && - (!m_compare_binder_info || binding_info(a) == binding_info(b)); - case expr_kind::Sort: - return sort_level(a) == sort_level(b); - case expr_kind::Macro: - m_counter++; - if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b)) + } else { + if (m_cache[i].m_a == nullptr) + m_used.push_back(i); + m_cache[i].m_a = a.raw(); + m_cache[i].m_b = b.raw(); return false; - for (unsigned i = 0; i < macro_num_args(a); i++) { - if (!apply(macro_arg(a, i), macro_arg(b, i))) - return false; } - return true; } - lean_unreachable(); // LCOV_EXCL_LINE + + void clear() { + for (unsigned i : m_used) + m_cache[i].m_a = nullptr; + m_used.clear(); + } +}; + +MK_THREAD_LOCAL_GET_DEF(eq_cache, get_eq_cache); + +/** \brief Functional object for comparing expressions. + + Remark if CompareBinderInfo is true, then functional object will also compare + binder information attached to lambda and Pi expressions */ +template +class expr_eq_fn { + eq_cache & m_cache; + + bool apply(expr const & a, expr const & b) { + if (is_eqp(a, b)) return true; + if (a.hash() != b.hash()) return false; + if (a.kind() != b.kind()) return false; + if (is_var(a)) return var_idx(a) == var_idx(b); + if (m_cache.check(a, b)) + return true; + check_system("expression equality test"); + switch (a.kind()) { + case expr_kind::Var: + lean_unreachable(); // LCOV_EXCL_LINE + case expr_kind::Constant: + return + const_name(a) == const_name(b) && + compare(const_levels(a), const_levels(b), [](level const & l1, level const & l2) { return l1 == l2; }); + case expr_kind::Meta: + return + mlocal_name(a) == mlocal_name(b) && + apply(mlocal_type(a), mlocal_type(b)); + case expr_kind::Local: + return + mlocal_name(a) == mlocal_name(b) && + apply(mlocal_type(a), mlocal_type(b)) && + (!CompareBinderInfo || local_pp_name(a) == local_pp_name(b)) && + (!CompareBinderInfo || local_info(a) == local_info(b)); + case expr_kind::App: + return + apply(app_fn(a), app_fn(b)) && + apply(app_arg(a), app_arg(b)); + case expr_kind::Lambda: case expr_kind::Pi: + return + apply(binding_domain(a), binding_domain(b)) && + apply(binding_body(a), binding_body(b)) && + (!CompareBinderInfo || binding_info(a) == binding_info(b)); + case expr_kind::Sort: + return sort_level(a) == sort_level(b); + case expr_kind::Macro: + if (macro_def(a) != macro_def(b) || macro_num_args(a) != macro_num_args(b)) + return false; + for (unsigned i = 0; i < macro_num_args(a); i++) { + if (!apply(macro_arg(a, i), macro_arg(b, i))) + return false; + } + return true; + } + lean_unreachable(); // LCOV_EXCL_LINE + } +public: + expr_eq_fn():m_cache(get_eq_cache()) {} + ~expr_eq_fn() { m_cache.clear(); } + bool operator()(expr const & a, expr const & b) { return apply(a, b); } +}; + +bool is_equal(expr const & a, expr const & b) { + return expr_eq_fn()(a, b); +} +bool is_bi_equal(expr const & a, expr const & b) { + return expr_eq_fn()(a, b); } } diff --git a/src/kernel/expr_eq_fn.h b/src/kernel/expr_eq_fn.h index 48a0fea8a..2724d27eb 100644 --- a/src/kernel/expr_eq_fn.h +++ b/src/kernel/expr_eq_fn.h @@ -1,31 +1,22 @@ /* -Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Copyright (c) 2014 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once -#include -#include "util/interrupt.h" -#include "kernel/expr.h" -#include "kernel/expr_sets.h" namespace lean { -/** - \brief Functional object for comparing expressions. -*/ -class expr_eq_fn { - bool m_compare_binder_info; - // We only use the cache m_eq_visited when m_counter > LEAN_EQ_CACHE_THRESHOLD. - // The idea is that most queries fail quickly, and it is a wast of time - // to create the cache. - unsigned m_counter; - std::unique_ptr m_eq_visited; - bool apply(expr const & a, expr const & b); -public: - /** \brief If \c is true, then functional object will also compare binder information attached to lambda and Pi expressions */ - expr_eq_fn(bool c = false):m_compare_binder_info(c), m_counter(0) {} - bool operator()(expr const & a, expr const & b) { m_counter = 0; return apply(a, b); } - void clear() { m_eq_visited.reset(); } -}; +class expr; +// ======================================= +// Structural equality +/** \brief Binder information is ignored in the following predicate */ +bool is_equal(expr const & a, expr const & b); +inline bool operator==(expr const & a, expr const & b) { return is_equal(a, b); } +inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b); } +// ======================================= + +/** \brief Similar to ==, but it also compares binder information */ +bool is_bi_equal(expr const & a, expr const & b); +struct is_bi_equal_proc { bool operator()(expr const & e1, expr const & e2) const { return is_bi_equal(e1, e2); } }; } diff --git a/src/kernel/expr_sets.h b/src/kernel/expr_sets.h index 8418cf478..edd690574 100644 --- a/src/kernel/expr_sets.h +++ b/src/kernel/expr_sets.h @@ -12,32 +12,12 @@ Author: Leonardo de Moura #include "kernel/expr.h" namespace lean { - // ======================================= // Expression Set // Remark: to expressions are assumed to be equal if they are "pointer-equal" typedef std::unordered_set expr_set; // ======================================= -// ======================================= -// (low level) Expression Cell pair Set -// Remark: to expressions are assumed to be equal if they are "pointer-equal" -// -// WARNING: use with care, this kind of set -// does not prevent an expression from being -// garbage collected. -typedef pair expr_cell_pair; -struct expr_cell_pair_hash { - unsigned operator()(expr_cell_pair const & p) const { return hash(p.first->hash_alloc(), p.second->hash_alloc()); } -}; -struct expr_cell_pair_eqp { - bool operator()(expr_cell_pair const & p1, expr_cell_pair const & p2) const { - return p1.first == p2.first && p1.second == p2.second; - } -}; -typedef std::unordered_set expr_cell_pair_set; -// ======================================= - // Similar to expr_set, but using structural equality typedef std::unordered_set> expr_struct_set; }