diff --git a/src/frontends/lean/frontend.cpp b/src/frontends/lean/frontend.cpp index ac053af78..301807bd9 100644 --- a/src/frontends/lean/frontend.cpp +++ b/src/frontends/lean/frontend.cpp @@ -379,47 +379,75 @@ struct lean_extension : public environment_extension { return !n.is_atomic() && get_explicit_version(n.get_prefix()) == n; } + /** + \brief It is too expensive to normalize \c t when checking if there is a coercion for it. + So, we just do a 'quick' normalization following a chain of definitions. + */ + expr coercion_type_normalization(expr t, ro_environment const & env) const { + while (true) { + if (is_constant(t)) { + auto obj = env->find_object(const_name(t)); + if (obj && obj->is_definition()) { + t = obj->get_value(); + } else { + return t; + } + } else { + return t; + } + } + } + void add_coercion(expr const & f, environment const & env) { expr type = env->infer_type(f); expr norm_type = env->normalize(type); if (!is_arrow(norm_type)) throw exception("invalid coercion declaration, a coercion must have an arrow type (i.e., a non-dependent functional type)"); - expr from = abst_domain(norm_type); - expr to = abst_body(norm_type); + expr from = coercion_type_normalization(abst_domain(norm_type), env); + expr to = coercion_type_normalization(abst_body(norm_type), env); if (from == to) throw exception("invalid coercion declaration, 'from' and 'to' types are the same"); - if (get_coercion(from, to)) + if (get_coercion_core(from, to)) throw exception("invalid coercion declaration, frontend already has a coercion for the given types"); m_coercion_map[expr_pair(from, to)] = f; m_coercion_set.insert(f); - list l = get_coercions(from); + list l = get_coercions_core(from); insert(m_type_coercions, from, cons(expr_pair(to, f), l)); env->add_neutral_object(new coercion_declaration(f)); } - optional get_coercion(expr const & from_type, expr const & to_type) const { + optional get_coercion_core(expr const & from_type, expr const & to_type) const { expr_pair p(from_type, to_type); auto it = m_coercion_map.find(p); if (it != m_coercion_map.end()) return some_expr(it->second); lean_extension const * parent = get_parent(); if (parent) - return parent->get_coercion(from_type, to_type); + return parent->get_coercion_core(from_type, to_type); else return none_expr(); } - list get_coercions(expr const & from_type) const { + optional get_coercion(expr const & from_type, expr const & to_type, ro_environment const & env) const { + return get_coercion_core(coercion_type_normalization(from_type, env), + coercion_type_normalization(to_type, env)); + } + + list get_coercions_core(expr const & from_type) const { auto r = m_type_coercions.find(from_type); if (r != m_type_coercions.end()) return r->second; lean_extension const * parent = get_parent(); if (parent) - return parent->get_coercions(from_type); + return parent->get_coercions_core(from_type); else return list(); } + list get_coercions(expr const & from_type, ro_environment const & env) const { + return get_coercions_core(coercion_type_normalization(from_type, env)); + } + bool is_coercion(expr const & f) const { if (m_coercion_set.find(f) != m_coercion_set.end()) return true; @@ -555,10 +583,10 @@ void add_coercion(environment const & env, expr const & f) { to_ext(env).add_coercion(f, env); } optional get_coercion(ro_environment const & env, expr const & from_type, expr const & to_type) { - return to_ext(env).get_coercion(from_type, to_type); + return to_ext(env).get_coercion(from_type, to_type, env); } list get_coercions(ro_environment const & env, expr const & from_type) { - return to_ext(env).get_coercions(from_type); + return to_ext(env).get_coercions(from_type, env); } bool is_coercion(ro_environment const & env, expr const & f) { return to_ext(env).is_coercion(f); diff --git a/tests/lean/bad8.lean b/tests/lean/bad8.lean new file mode 100644 index 000000000..b0179a29b --- /dev/null +++ b/tests/lean/bad8.lean @@ -0,0 +1,9 @@ +Variable list : Type → Type +Variable nil {A : Type} : list A +Variable cons {A : Type} (head : A) (tail : list A) : list A +Variable a : ℤ +Variable b : ℤ +Variable n : ℕ +Variable m : ℕ +Definition l1 : list ℤ := cons a (cons b (cons n nil)) +Definition l2 : list ℤ := cons a (cons n (cons b nil)) diff --git a/tests/lean/bad8.lean.expected.out b/tests/lean/bad8.lean.expected.out new file mode 100644 index 000000000..ac31c7569 --- /dev/null +++ b/tests/lean/bad8.lean.expected.out @@ -0,0 +1,11 @@ + Set: pp::colors + Set: pp::unicode + Assumed: list + Assumed: nil + Assumed: cons + Assumed: a + Assumed: b + Assumed: n + Assumed: m + Defined: l1 + Defined: l2