diff --git a/src/frontends/lean/class.cpp b/src/frontends/lean/class.cpp index f313d1441..9ce195df9 100644 --- a/src/frontends/lean/class.cpp +++ b/src/frontends/lean/class.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include "util/lbool.h" #include "util/sstream.h" #include "kernel/instantiate.h" #include "library/scoped_ext.h" @@ -160,23 +161,88 @@ void register_class_cmds(cmd_table & r) { add_cmd(r, cmd_info("class", "add a new class", add_class_cmd)); } -/** \brief Return true iff \c type is a class or Pi that produces a class. */ -optional is_ext_class(type_checker & tc, expr type) { +/** \brief If the constant \c e is a class, return its name */ +optional constant_is_ext_class(environment const & env, expr const & e) { + name const & cls_name = const_name(e); + if (is_class(env, cls_name) || !empty(get_tactic_hints(env, cls_name))) { + return optional(cls_name); + } else { + return optional(); + } +} + +/** \brief Partial/Quick test for is_ext_class. Result + l_true: \c type is a class, and the name of the class is stored in \c result. + l_false: \c type is not a class. + l_undef: procedure did not establish whether \c type is a class or not. +*/ +lbool is_quick_ext_class(type_checker const & tc, expr const & type, name & result) { + environment const & env = tc.env(); + expr const * it = &type; + while (true) { + switch (it->kind()) { + case expr_kind::Var: case expr_kind::Sort: case expr_kind::Local: + case expr_kind::Meta: case expr_kind::Lambda: + return l_false; + case expr_kind::Macro: + return l_undef; + case expr_kind::Constant: + if (auto r = constant_is_ext_class(env, *it)) { + result = *r; + return l_true; + } else if (tc.is_opaque(*it)) { + return l_false; + } else { + return l_undef; + } + case expr_kind::App: { + expr const & f = get_app_fn(*it); + if (is_constant(f)) { + if (auto r = constant_is_ext_class(env, f)) { + result = *r; + return l_true; + } else if (tc.is_opaque(f)) { + return l_false; + } else { + return l_undef; + } + } else if (is_lambda(f) || is_macro(f)) { + return l_undef; + } else { + return l_false; + } + } + case expr_kind::Pi: + it = &binding_body(*it); + break; + } + } +} + +/** \brief Full/Expensive test for \c is_ext_class */ +optional is_full_ext_class(type_checker & tc, expr type) { type = tc.whnf(type).first; if (is_pi(type)) { - return is_ext_class(tc, instantiate(binding_body(type), mk_local(tc.mk_fresh_name(), binding_domain(type)))); + return is_full_ext_class(tc, instantiate(binding_body(type), mk_local(tc.mk_fresh_name(), binding_domain(type)))); } else { expr f = get_app_fn(type); if (!is_constant(f)) return optional(); - name const & cls_name = const_name(f); - if (is_class(tc.env(), cls_name) || !empty(get_tactic_hints(tc.env(), cls_name))) - return optional(cls_name); - else - return optional(); + return constant_is_ext_class(tc.env(), f); } } +/** \brief Return true iff \c type is a class or Pi that produces a class. */ +optional is_ext_class(type_checker & tc, expr const & type) { + name result; + switch (is_quick_ext_class(tc, type, result)) { + case l_true: return optional(result); + case l_false: return optional(); + case l_undef: break; + } + return is_full_ext_class(tc, type); +} + /** \brief Return a list of instances of the class \c cls_name that occur in \c ctx */ list get_local_instances(type_checker & tc, list const & ctx, name const & cls_name) { buffer buffer; diff --git a/src/frontends/lean/class.h b/src/frontends/lean/class.h index 70976a8dc..5c3a716eb 100644 --- a/src/frontends/lean/class.h +++ b/src/frontends/lean/class.h @@ -23,7 +23,7 @@ name get_class_name(environment const & env, expr const & e); void register_class_cmds(cmd_table & r); /** \brief Return true iff \c type is a class or Pi that produces a class. */ -optional is_ext_class(type_checker & tc, expr type); +optional is_ext_class(type_checker & tc, expr const & type); /** \brief Return a list of instances of the class \c cls_name that occur in \c ctx */ list get_local_instances(type_checker & tc, list const & ctx, name const & cls_name);