diff --git a/src/frontends/lean/frontend.cpp b/src/frontends/lean/frontend.cpp index c97a7d225..b17de6ab5 100644 --- a/src/frontends/lean/frontend.cpp +++ b/src/frontends/lean/frontend.cpp @@ -33,6 +33,7 @@ struct frontend::imp { typedef std::unordered_map implicit_table; typedef std::unordered_map, expr_hash, std::equal_to> expr_to_operators; typedef std::unordered_map coercion_map; + typedef std::unordered_map, expr_hash, std::equal_to> expr_to_coercions; typedef std::unordered_set> coercion_set; std::atomic m_num_children; @@ -44,6 +45,7 @@ struct frontend::imp { implicit_table m_implicit_table; // track the number of implicit arguments for a symbol. coercion_map m_coercion_map; // mapping from (given_type, expected_type) -> coercion coercion_set m_coercion_set; // Set of coercions + expr_to_coercions m_type_coercions; // mapping type -> list (to-type, function) state m_state; bool has_children() const { return m_num_children > 0; } @@ -314,28 +316,34 @@ struct frontend::imp { expr to = abst_body(norm_type); if (from == to) throw exception("invalid coercion declaration, 'from' and 'to' types are the same"); - if (get_coercion_core(from, to)) + if (get_coercion(from, to)) throw exception("invalid coercion declaration, frontend already has a coercion for the given types"); m_coercion_map[expr_pair(from, to)] = f; m_coercion_set.insert(f); + list l = get_coercions(from); + insert(m_type_coercions, from, cons(expr_pair(to, f), l)); m_env.add_neutral_object(new coercion_declaration(f)); } - expr get_coercion_core(expr const & given_type, expr const & expected_type) { - expr_pair p(given_type, expected_type); + expr get_coercion(expr const & from_type, expr const & to_type) { + expr_pair p(from_type, to_type); auto it = m_coercion_map.find(p); if (it != m_coercion_map.end()) return it->second; else if (has_parent()) - return m_parent->get_coercion_core(given_type, expected_type); + return m_parent->get_coercion(from_type, to_type); else return expr(); } - expr get_coercion(expr const & given_type, expr const & expected_type, context const & ctx) { - expr norm_given_type = m_env.normalize(given_type, ctx); - expr norm_expected_type = m_env.normalize(expected_type, ctx); - return get_coercion_core(norm_given_type, norm_expected_type); + list get_coercions(expr const & from_type) const { + auto r = m_type_coercions.find(from_type); + if (r != m_type_coercions.end()) + return r->second; + else if (has_parent()) + return m_parent->get_coercions(from_type); + else + return list(); } bool is_coercion(expr const & f) { @@ -423,9 +431,8 @@ name const & frontend::get_explicit_version(name const & n) const { return m_ptr bool frontend::is_explicit(name const & n) const { return m_ptr->is_explicit(n); } void frontend::add_coercion(expr const & f) { m_ptr->add_coercion(f); } -expr frontend::get_coercion(expr const & given_type, expr const & expected_type, context const & ctx) const { - return m_ptr->get_coercion(given_type, expected_type, ctx); -} +expr frontend::get_coercion(expr const & from_type, expr const & to_type) const { return m_ptr->get_coercion(from_type, to_type); } +list frontend::get_coercions(expr const & from_type) const { return m_ptr->get_coercions(from_type); } bool frontend::is_coercion(expr const & f) const { return m_ptr->is_coercion(f); } state const & frontend::get_state() const { return m_ptr->m_state; } diff --git a/src/frontends/lean/frontend.h b/src/frontends/lean/frontend.h index 0ed208d31..9def83427 100644 --- a/src/frontends/lean/frontend.h +++ b/src/frontends/lean/frontend.h @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include #include "kernel/environment.h" #include "library/state.h" +#include "library/expr_pair.h" #include "frontends/lean/operator_info.h" namespace lean { @@ -169,12 +170,20 @@ public: coercion from T1 to T2. */ void add_coercion(expr const & f); + /** - \brief Return a coercion from given_type to expected_type if it exists. - Return the null expression if there is no coercion from \c given_type to - \c expected_type. + \brief Return a coercion from given_type to expected_type if it exists. + Return the null expression if there is no coercion from \c from_type to + \c to_type. */ - expr get_coercion(expr const & given_type, expr const & expected_type, context const & ctx) const; + expr get_coercion(expr const & from_type, expr const & to_type) const; + + /** + \brief Return the list of coercions for the given type. + The result is a list of pairs (to_type, function). + */ + list get_coercions(expr const & from_type) const; + /** \brief Return true iff the given expression is a coercion. That is, it was added using \c add_coercion.