diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index de00e9ce6..4524c0536 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -116,55 +116,58 @@ void expr_cell::dealloc() { } } -namespace expr_eq_ns { -static thread_local expr_cell_pair_set g_eq_visited; -bool apply(expr const & a, expr const & b) { - if (eqp(a, b)) return true; - if (a.hash() != b.hash()) return false; - if (a.kind() != b.kind()) return false; - if (is_var(a)) return get_var_idx(a) == get_var_idx(b); - if (is_prop(a)) return true; - if (is_shared(a) && is_shared(b)) { - auto p = std::make_pair(a.raw(), b.raw()); - if (g_eq_visited.find(p) != g_eq_visited.end()) - return true; - g_eq_visited.insert(p); - } - switch (a.kind()) { - case expr_kind::Var: lean_unreachable(); return true; - case expr_kind::Constant: return get_const_name(a) == get_const_name(b); - case expr_kind::App: - if (get_num_args(a) != get_num_args(b)) - return false; - for (unsigned i = 0; i < get_num_args(a); i++) - if (!apply(get_arg(a, i), get_arg(b, i))) - return false; - return true; - case expr_kind::Lambda: - case expr_kind::Pi: - // Lambda and Pi - // Remark: we ignore get_abs_name because we want alpha-equivalence - return apply(get_abs_type(a), get_abs_type(b)) && apply(get_abs_expr(a), get_abs_expr(b)); - case expr_kind::Prop: lean_unreachable(); return true; - case expr_kind::Type: - if (get_ty_num_vars(a) != get_ty_num_vars(b)) - return false; - for (unsigned i = 0; i < get_ty_num_vars(a); i++) { - uvar v1 = get_ty_var(a, i); - uvar v2 = get_ty_var(b, i); - if (v1.first != v2.first || v1.second != v2.second) - return false; + +class eq_functor { + expr_cell_pair_set m_eq_visited; +public: + bool apply(expr const & a, expr const & b) { + if (eqp(a, b)) return true; + if (a.hash() != b.hash()) return false; + if (a.kind() != b.kind()) return false; + if (is_var(a)) return get_var_idx(a) == get_var_idx(b); + if (is_prop(a)) return true; + if (is_shared(a) && is_shared(b)) { + auto p = std::make_pair(a.raw(), b.raw()); + if (m_eq_visited.find(p) != m_eq_visited.end()) + return true; + m_eq_visited.insert(p); } - return true; - case expr_kind::Numeral: return get_numeral(a) == get_numeral(b); + switch (a.kind()) { + case expr_kind::Var: lean_unreachable(); return true; + case expr_kind::Constant: return get_const_name(a) == get_const_name(b); + case expr_kind::App: + if (get_num_args(a) != get_num_args(b)) + return false; + for (unsigned i = 0; i < get_num_args(a); i++) + if (!apply(get_arg(a, i), get_arg(b, i))) + return false; + return true; + case expr_kind::Lambda: + case expr_kind::Pi: + // Lambda and Pi + // Remark: we ignore get_abs_name because we want alpha-equivalence + return apply(get_abs_type(a), get_abs_type(b)) && apply(get_abs_expr(a), get_abs_expr(b)); + case expr_kind::Prop: lean_unreachable(); return true; + case expr_kind::Type: + if (get_ty_num_vars(a) != get_ty_num_vars(b)) + return false; + for (unsigned i = 0; i < get_ty_num_vars(a); i++) { + uvar v1 = get_ty_var(a, i); + uvar v2 = get_ty_var(b, i); + if (v1.first != v2.first || v1.second != v2.second) + return false; + } + return true; + case expr_kind::Numeral: return get_numeral(a) == get_numeral(b); + } + lean_unreachable(); + return false; } - lean_unreachable(); - return false; -} -} // namespace expr_eq +}; + bool operator==(expr const & a, expr const & b) { - expr_eq_ns::g_eq_visited.clear(); - return expr_eq_ns::apply(a, b); + eq_functor f; + return f.apply(a, b); } // Low-level pretty printer