Refactor expression equality
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
f08c06d582
commit
18a195029b
2 changed files with 72 additions and 45 deletions
|
@ -9,7 +9,7 @@ Author: Leonardo de Moura
|
|||
#include <sstream>
|
||||
#include "expr.h"
|
||||
#include "free_vars.h"
|
||||
#include "expr_sets.h"
|
||||
#include "expr_eq.h"
|
||||
#include "hash.h"
|
||||
|
||||
namespace lean {
|
||||
|
@ -135,51 +135,8 @@ expr mk_type() {
|
|||
return r;
|
||||
}
|
||||
|
||||
class eq_fn {
|
||||
expr_cell_pair_set m_eq_visited;
|
||||
|
||||
bool apply(expr const & a, expr const & b) {
|
||||
if (is_eqp(a, b)) return true;
|
||||
if (a.hash() != b.hash()) return false;
|
||||
if (a.kind() != b.kind()) return false;
|
||||
if (is_var(a)) return var_idx(a) == var_idx(b);
|
||||
if (is_shared(a) && is_shared(b)) {
|
||||
auto p = std::make_pair(a.raw(), b.raw());
|
||||
if (m_eq_visited.find(p) != m_eq_visited.end())
|
||||
return true;
|
||||
m_eq_visited.insert(p);
|
||||
}
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var: lean_unreachable(); return true;
|
||||
case expr_kind::Constant: return const_name(a) == const_name(b);
|
||||
case expr_kind::App:
|
||||
if (num_args(a) != num_args(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < num_args(a); i++)
|
||||
if (!apply(arg(a, i), arg(b, i)))
|
||||
return false;
|
||||
return true;
|
||||
case expr_kind::Eq: return apply(eq_lhs(a), eq_lhs(b)) && apply(eq_rhs(a), eq_rhs(b));
|
||||
case expr_kind::Lambda:
|
||||
case expr_kind::Pi:
|
||||
// Lambda and Pi
|
||||
// Remark: we ignore get_abs_name because we want alpha-equivalence
|
||||
return apply(abst_domain(a), abst_domain(b)) && apply(abst_body(a), abst_body(b));
|
||||
case expr_kind::Type: return ty_level(a) == ty_level(b);
|
||||
case expr_kind::Value: return to_value(a) == to_value(b);
|
||||
case expr_kind::Let: return apply(let_value(a), let_value(b)) && apply(let_body(a), let_body(b));
|
||||
}
|
||||
lean_unreachable();
|
||||
return false;
|
||||
}
|
||||
public:
|
||||
bool operator()(expr const & a, expr const & b) {
|
||||
return apply(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
bool operator==(expr const & a, expr const & b) {
|
||||
return eq_fn()(a, b);
|
||||
return expr_eq_fn<>()(a, b);
|
||||
}
|
||||
|
||||
bool is_arrow(expr const & t) {
|
||||
|
|
70
src/kernel/expr_eq.h
Normal file
70
src/kernel/expr_eq.h
Normal file
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "expr.h"
|
||||
#include "expr_sets.h"
|
||||
|
||||
namespace lean {
|
||||
/** \brief Identity function for expressions. */
|
||||
struct id_expr_fn {
|
||||
expr const & operator()(expr const & e) const { return e; }
|
||||
};
|
||||
|
||||
/**
|
||||
\brief Functional object for comparing expressions.
|
||||
The parameter N is a normalization function that can be used
|
||||
to normalize sub-expressions before comparing them.
|
||||
The hashcode of expressions is used to optimize the comparison when
|
||||
parameter UseHash == true. We usually set UseHash to false when N
|
||||
is not the identity function.
|
||||
*/
|
||||
template<typename N = id_expr_fn, bool UseHash = false>
|
||||
class expr_eq_fn {
|
||||
expr_cell_pair_set m_eq_visited;
|
||||
N m_norm;
|
||||
|
||||
bool apply(expr const & a0, expr const & b0) {
|
||||
if (is_eqp(a0, b0)) return true;
|
||||
if (UseHash && a0.hash() != b0.hash()) return false;
|
||||
expr const & a = m_norm(a0);
|
||||
expr const & b = m_norm(b0);
|
||||
if (a.kind() != b.kind()) return false;
|
||||
if (is_var(a)) return var_idx(a) == var_idx(b);
|
||||
if (is_shared(a) && is_shared(b)) {
|
||||
auto p = std::make_pair(a.raw(), b.raw());
|
||||
if (m_eq_visited.find(p) != m_eq_visited.end())
|
||||
return true;
|
||||
m_eq_visited.insert(p);
|
||||
}
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var: lean_unreachable(); return true; // LCOV_EXCL_LINE
|
||||
case expr_kind::Constant: return const_name(a) == const_name(b);
|
||||
case expr_kind::App:
|
||||
if (num_args(a) != num_args(b))
|
||||
return false;
|
||||
for (unsigned i = 0; i < num_args(a); i++)
|
||||
if (!apply(arg(a, i), arg(b, i)))
|
||||
return false;
|
||||
return true;
|
||||
case expr_kind::Eq: return apply(eq_lhs(a), eq_lhs(b)) && apply(eq_rhs(a), eq_rhs(b));
|
||||
case expr_kind::Lambda: // Remark: we ignore get_abs_name because we want alpha-equivalence
|
||||
case expr_kind::Pi: return apply(abst_domain(a), abst_domain(b)) && apply(abst_body(a), abst_body(b));
|
||||
case expr_kind::Type: return ty_level(a) == ty_level(b);
|
||||
case expr_kind::Value: return to_value(a) == to_value(b);
|
||||
case expr_kind::Let: return apply(let_value(a), let_value(b)) && apply(let_body(a), let_body(b));
|
||||
}
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
return false; // LCOV_EXCL_LINE
|
||||
}
|
||||
public:
|
||||
expr_eq_fn(N const & norm = N()):m_norm(norm) {}
|
||||
bool operator()(expr const & a, expr const & b) {
|
||||
return apply(a, b);
|
||||
}
|
||||
void clear() { m_eq_visited.clear(); }
|
||||
};
|
||||
}
|
Loading…
Reference in a new issue