diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 0645516e5..4a32662f2 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -137,6 +137,10 @@ void expr_app::dealloc(buffer & todelete) { static unsigned dec(unsigned k) { return k == 0 ? 0 : k - 1; } +bool operator==(expr_binder_info const & i1, expr_binder_info const & i2) { + return i1.is_implicit() == i2.is_implicit() && i1.is_cast() == i2.is_cast() && i1.is_contextual() == i2.is_contextual(); +} + // Expr binders (Lambda, Pi) expr_binder::expr_binder(expr_kind k, name const & n, expr const & t, expr const & b, expr_binder_info const & i): expr_composite(k, ::lean::hash(t.hash(), b.hash()), @@ -390,6 +394,7 @@ unsigned get_free_var_range(expr const & e) { } bool operator==(expr const & a, expr const & b) { return expr_eq_fn()(a, b); } +bool is_bi_equal(expr const & a, expr const & b) { return expr_eq_fn(true)(a, b); } static expr copy_tag(expr const & e, expr && new_e) { tag t = e.get_tag(); diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 7de639ee1..f55c3daa7 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -150,8 +150,11 @@ public: // ======================================= // Structural equality - bool operator==(expr const & a, expr const & b); +/** \brief Binder information is ignored in the following predicate */ +bool operator==(expr const & a, expr const & b); inline bool operator!=(expr const & a, expr const & b) { return !operator==(a, b); } +/** \brief Similar to ==, but it also compares binder information */ +bool is_bi_equal(expr const & a, expr const & b); // ======================================= SPECIALIZE_OPTIONAL_FOR_SMART_PTR(expr) @@ -232,6 +235,9 @@ public: bool is_contextual() const { return m_contextual; } }; +bool operator==(expr_binder_info const & i1, expr_binder_info const & i2); +inline bool operator!=(expr_binder_info const & i1, expr_binder_info const & i2) { return !(i1 == i2); } + /** \brief Super class for lambda and pi */ class expr_binder : public expr_composite { name m_name; diff --git a/src/kernel/expr_eq_fn.cpp b/src/kernel/expr_eq_fn.cpp index f5a81d303..5f4828024 100644 --- a/src/kernel/expr_eq_fn.cpp +++ b/src/kernel/expr_eq_fn.cpp @@ -39,7 +39,8 @@ bool expr_eq_fn::apply(expr const & a, expr const & b) { case expr_kind::Lambda: case expr_kind::Pi: return apply(binder_domain(a), binder_domain(b)) && - apply(binder_body(a), binder_body(b)); + apply(binder_body(a), binder_body(b)) && + (!m_compare_binder_info || binder_info(a) == binder_info(b)); case expr_kind::Sort: return sort_level(a) == sort_level(b); case expr_kind::Macro: diff --git a/src/kernel/expr_eq_fn.h b/src/kernel/expr_eq_fn.h index 1ecb1e828..e37eefcf4 100644 --- a/src/kernel/expr_eq_fn.h +++ b/src/kernel/expr_eq_fn.h @@ -15,9 +15,12 @@ namespace lean { \brief Functional object for comparing expressions. */ class expr_eq_fn { + bool m_compare_binder_info; std::unique_ptr m_eq_visited; bool apply(expr const & a, expr const & b); public: + /** \brief If \c is true, then functional object will also compare binder information attached to lambda and Pi expressions */ + expr_eq_fn(bool c = false):m_compare_binder_info(c) {} bool operator()(expr const & a, expr const & b) { return apply(a, b); } void clear() { m_eq_visited.reset(); } }; diff --git a/src/kernel/expr_maps.h b/src/kernel/expr_maps.h index 2050eed32..74b04b14e 100644 --- a/src/kernel/expr_maps.h +++ b/src/kernel/expr_maps.h @@ -26,4 +26,8 @@ using expr_cell_offset_map = typename std::unordered_map using expr_struct_map = typename std::unordered_map>; +// The following map also takes into account binder information +struct is_bi_equal_proc { bool operator()(expr const & e1, expr const & e2) const { return is_bi_equal(e1, e2); } }; +template +using expr_bi_struct_map = typename std::unordered_map; }; diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 3c37faa07..beb680f79 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -60,7 +60,11 @@ struct type_checker::imp { name_generator m_gen; constraint_handler & m_chandler; std::unique_ptr m_conv; - expr_struct_map m_infer_type_cache; + // In the type checker cache, we must take into account binder information. + // Examples: + // The type of (lambda x : A, t) is (Pi x : A, typeof(t)) + // The type of (lambda {x : A}, t) is (Pi {x : A}, typeof(t)) + expr_bi_struct_map m_infer_type_cache; converter_context m_conv_ctx; type_checker_context m_tc_ctx; bool m_memoize; diff --git a/tests/lua/tc3.lua b/tests/lua/tc3.lua new file mode 100644 index 000000000..d9c8ba8c2 --- /dev/null +++ b/tests/lua/tc3.lua @@ -0,0 +1,14 @@ +local env = empty_environment() +local t1 = mk_lambda("A", Type, mk_lambda("a", Var(0), Var(0)), binder_info(true)) +local t2 = mk_lambda("A", Type, mk_lambda("a", Var(0), Var(0))) +print(t1) +print(t2) +local tc = type_checker(env) +local T1 = mk_pi("A", Type, mk_arrow(Var(0), Var(1)), binder_info(true)) +local T2 = mk_pi("A", Type, mk_arrow(Var(0), Var(1))) +print(T1) +print(T2) +print(tc:check(t1)) +print(tc:check(t2)) +assert(tc:check(t1):binder_info():is_implicit()) +assert(not tc:check(t2):binder_info():is_implicit())