From 0446c43ebfa61fe89cbfa95570c4a51a61fa3ffa Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 20 Oct 2015 10:03:26 -0700 Subject: [PATCH] refactor(library/class_instance_resolution): use new generic type_inference module to implement type class resolution --- src/library/class_instance_resolution.cpp | 646 ++++------------------ src/library/class_instance_resolution.h | 31 +- src/library/init_module.cpp | 3 + 3 files changed, 139 insertions(+), 541 deletions(-) diff --git a/src/library/class_instance_resolution.cpp b/src/library/class_instance_resolution.cpp index 61b9441ec..fcffe974e 100644 --- a/src/library/class_instance_resolution.cpp +++ b/src/library/class_instance_resolution.cpp @@ -22,6 +22,7 @@ Author: Leonardo de Moura #include "library/constants.h" #include "library/pp_options.h" #include "library/choice_iterator.h" +#include "library/type_inference.h" #include "library/class_instance_resolution.h" #ifndef LEAN_DEFAULT_CLASS_TRACE_INSTANCES @@ -40,18 +41,13 @@ namespace lean { [[ noreturn ]] void throw_class_exception(char const * msg, expr const & m) { throw_generic_exception(msg, m); } [[ noreturn ]] void throw_class_exception(expr const & m, pp_fn const & fn) { throw_generic_exception(m, fn); } -typedef std::shared_ptr ci_type_inference_ptr; - static name * g_class_trace_instances = nullptr; static name * g_class_instance_max_depth = nullptr; static name * g_class_trans_instances = nullptr; -static name * g_prefix1 = nullptr; -static name * g_prefix2 = nullptr; +static name * g_prefix = nullptr; -static ci_type_inference_factory * g_default_factory = nullptr; - -LEAN_THREAD_PTR(ci_type_inference_factory, g_factory); -LEAN_THREAD_PTR(io_state, g_ios); +LEAN_THREAD_PTR(ci_local_metavar_types, g_lm_types); +LEAN_THREAD_PTR(io_state, g_ios); bool get_class_trace_instances(options const & o) { return o.get_bool(*g_class_trace_instances, LEAN_DEFAULT_CLASS_TRACE_INSTANCES); @@ -65,31 +61,18 @@ bool get_class_trans_instances(options const & o) { return o.get_bool(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES); } -class default_ci_type_inference : public ci_type_inference { - type_checker_ptr m_tc; +class default_ci_local_metavar_types : public ci_local_metavar_types { public: - default_ci_type_inference(environment const & env): - m_tc(mk_type_checker(env, name_generator(*g_prefix1), UnfoldReducible)) {} - - virtual ~default_ci_type_inference() {} - - virtual expr whnf(expr const & e) { - return m_tc->whnf(e).first; - } - - virtual expr infer_type(expr const & e) { - return m_tc->infer(e).first; - } + virtual expr infer_local(expr const & e) { return mlocal_type(e); } + virtual expr infer_metavar(expr const & e) { return mlocal_type(e); } }; -ci_type_inference_factory::~ci_type_inference_factory() {} - -std::shared_ptr ci_type_inference_factory::operator()(environment const & env) const { - return std::shared_ptr(new default_ci_type_inference(env)); +static expr ci_infer_local(expr const & e) { + return g_lm_types->infer_local(e); } -static ci_type_inference_factory & get_ci_factory() { - return g_factory ? *g_factory : *g_default_factory; +static expr ci_infer_metavar(expr const & e) { + return g_lm_types->infer_metavar(e); } /** \brief The following global thread local constant is a big hack for mk_subsingleton_instance. @@ -114,11 +97,11 @@ LEAN_THREAD_VALUE(bool, g_subsingleton_hack, false); struct cienv { typedef rb_map uassignment; typedef rb_map eassignment; - + typedef std::unique_ptr ti_ptr; environment m_env; pos_info_provider const * m_pip; + ti_ptr m_ti_ptr; optional m_pos; - ci_type_inference_ptr m_tc_ptr; expr_struct_map m_cache; name_predicate m_not_reducible_pred; @@ -157,6 +140,69 @@ struct cienv { bool m_trans_instances; bool m_trace_instances; + class ti : public type_inference { + cienv & m_cienv; + std::vector m_stack; + public: + ti(cienv & e):type_inference(e.m_env), m_cienv(e) {} + virtual bool is_extra_opaque(name const & n) const { return m_cienv.is_not_reducible(n); } + virtual expr mk_tmp_local(expr const & type, binder_info const & bi) { return m_cienv.mk_local(type, bi); } + virtual bool is_tmp_local(expr const & e) const { return m_cienv.is_internal_local(e); } + virtual bool is_uvar(level const & l) const { return cienv::is_uvar(l); } + virtual bool is_mvar(expr const & e) const { return m_cienv.is_mvar(e); } + virtual level const * get_assignment(level const & u) const { return m_cienv.get_assignment(u); } + virtual expr const * get_assignment(expr const & m) const { return m_cienv.get_assignment(m); } + virtual void update_assignment(level const & u, level const & v) { return m_cienv.update_assignment(u, v); } + virtual void update_assignment(expr const & m, expr const & v) { return m_cienv.update_assignment(m, v); } + virtual expr infer_local(expr const & e) const { return ci_infer_local(e); } + virtual expr infer_metavar(expr const & e) const { return ci_infer_metavar(e); } + virtual void push() { m_stack.push_back(m_cienv.m_state); } + virtual void pop() { m_cienv.m_state = m_stack.back(); m_stack.pop_back(); } + virtual void commit() { m_stack.pop_back(); } + + virtual bool ignore_universe_def_eq(level const & l1, level const & l2) const { + if (is_meta(l1) || is_meta(l2)) { + // The unifier may invoke this module before universe metavariables in the class + // have been instantiated. So, we just ignore and assume they will be solved by + // the unifier. + + // See comment at g_subsingleton_hack declaration. + if (g_subsingleton_hack && (is_zero(l1) || is_zero(l2))) + return false; + return true; // we ignore + } else { + return false; + } + } + + + virtual bool validate_assignment(expr const & m, buffer const & locals, expr const & v) { + // We must check + // 1. Any (internal) local constant occurring in v occurs in locals + // 2. m does not occur in v + bool ok = true; + for_each(v, [&](expr const & e, unsigned) { + if (!ok) + return false; // stop search + if (is_tmp_local(e)) { + if (std::all_of(locals.begin(), locals.end(), [&](expr const & a) { + return mlocal_name(a) != mlocal_name(e); })) { + ok = false; // failed 1 + return false; + } + } else if (is_mvar(e)) { + if (m == e) { + ok = false; // failed 2 + return false; + } + return false; + } + return true; + }); + return ok; + } + }; + cienv(bool multiple_instances = false): m_next_local_idx(0), m_next_uvar_idx(0), @@ -167,18 +213,20 @@ struct cienv { return m_not_reducible_pred(n); } - void reset_cache() { + void clear_cache() { expr_struct_map fresh; fresh.swap(m_cache); + if (m_ti_ptr) + m_ti_ptr->clear_cache(); } - void reset_cache_and_ctx() { + void clear_cache_and_ctx() { m_next_local_idx = 0; m_next_uvar_idx = 0; m_next_mvar_idx = 0; m_ctx = list(); m_local_instances.clear(); - reset_cache(); + clear_cache(); } optional check_cache(expr const & type) const { @@ -214,7 +262,7 @@ struct cienv { if (m_max_depth != max_depth || m_trans_instances != trans_instances || m_trace_instances != trace_instances) { - reset_cache_and_ctx(); + clear_cache_and_ctx(); } m_max_depth = max_depth; m_trans_instances = trans_instances; @@ -228,44 +276,44 @@ struct cienv { if (!m_env.is_descendant(m_env) || !m_env.is_descendant(env)) { m_env = env; m_not_reducible_pred = mk_not_reducible_pred(m_env); - m_tc_ptr = nullptr; - reset_cache_and_ctx(); + m_ti_ptr = nullptr; + clear_cache_and_ctx(); } - if (!m_tc_ptr) { - ci_type_inference_factory & factory = get_ci_factory(); - m_tc_ptr = factory(m_env); + if (!m_ti_ptr) { + m_ti_ptr.reset(new ti(*this)); + clear_cache_and_ctx(); } } expr whnf(expr const & e) { - return m_tc_ptr->whnf(e); + return m_ti_ptr->whnf(e); } expr infer_type(expr const & e) { - return m_tc_ptr->infer_type(e); + return m_ti_ptr->infer(e); } - bool is_prop(expr const & e) { - if (m_env.impredicative()) { - expr t = whnf(infer_type(e)); - return t == mk_Prop(); - } else { - return false; - } + bool is_def_eq(expr const & e1, expr const & e2) { + return m_ti_ptr->is_def_eq(e1, e2); } - expr mk_local(expr const & type) { + expr instantiate_uvars_mvars(expr const & e) { + return m_ti_ptr->instantiate_uvars_mvars(e); + } + + expr mk_local(expr const & type, binder_info const & bi = binder_info()) { unsigned idx = m_next_local_idx; m_next_local_idx++; - return lean::mk_local(name(*g_prefix2, idx), type); + name n(*g_prefix, idx); + return lean::mk_local(n, n, type, bi); } bool is_internal_local(expr const & e) { if (!is_local(e)) return false; name const & n = mlocal_name(e); - return !n.is_atomic() && n.get_prefix() == *g_prefix2; + return !n.is_atomic() && n.get_prefix() == *g_prefix; } /** \brief If the constant \c e is a class, return its name */ @@ -353,10 +401,10 @@ struct cienv { void set_local_instance(unsigned i, name const & cname, expr const & e) { lean_assert(i <= m_local_instances.size()); if (i == m_local_instances.size()) { - reset_cache(); + clear_cache(); m_local_instances.push_back(mk_pair(cname, e)); } else if (e != m_local_instances[i].second) { - reset_cache(); + clear_cache(); m_local_instances[i] = mk_pair(cname, e); } else { // we don't need to reset the cache since this local instance @@ -384,7 +432,7 @@ struct cienv { if (i < m_local_instances.size()) { // new ctx has fewer local instances than previous one m_local_instances.resize(i); - reset_cache(); + clear_cache(); } } @@ -398,7 +446,7 @@ struct cienv { level mk_uvar() { unsigned idx = m_next_uvar_idx; m_next_uvar_idx++; - return mk_meta_univ(name(*g_prefix2, idx)); + return mk_meta_univ(name(*g_prefix, idx)); } // Return true iff \c l is an internal universe metavariable created by this module. @@ -406,7 +454,7 @@ struct cienv { if (!is_meta(l)) return false; name const & n = meta_id(l); - return !n.is_atomic() && n.get_prefix() == *g_prefix2; + return !n.is_atomic() && n.get_prefix() == *g_prefix; } static unsigned uvar_idx(level const & l) { @@ -427,17 +475,11 @@ struct cienv { m_state.m_uassignment.insert(uvar_idx(u), v); } - // Assign \c v to the universe metavariable \c u. - void assign(level const & u, level const & v) { - lean_assert(!is_assigned(u)); - update_assignment(u, v); - } - // Create an internal metavariable. expr mk_mvar(expr const & type) { unsigned idx = m_next_mvar_idx; m_next_mvar_idx++; - return mk_metavar(name(*g_prefix2, idx), type); + return mk_metavar(name(*g_prefix, idx), type); } // Return true iff \c e is an internal metavariable created by this module. @@ -445,7 +487,7 @@ struct cienv { if (!is_metavar(e)) return false; name const & n = mlocal_name(e); - return !n.is_atomic() && n.get_prefix() == *g_prefix2; + return !n.is_atomic() && n.get_prefix() == *g_prefix; } static unsigned mvar_idx(expr const & m) { @@ -472,454 +514,6 @@ struct cienv { update_assignment(m, v); } - bool is_def_eq(level const & l1, level const & l2) { - if (is_equivalent(l1, l2)) { - return true; - } - - if (is_uvar(l1)) { - if (auto v = get_assignment(l1)) { - return is_def_eq(*v, l2); - } else { - assign(l1, l2); - return true; - } - } - if (is_uvar(l2)) { - if (auto v = get_assignment(l2)) { - return is_def_eq(l1, *v); - } else { - assign(l2, l1); - return true; - } - } - if (is_meta(l1) || is_meta(l2)) { - // The unifier may invoke this module before universe metavariables in the class - // have been instantiated. So, we just ignore and assume they will be solved by - // the unifier. - - // See comment at g_subsingleton_hack declaration. - if (g_subsingleton_hack && (is_zero(l1) || is_zero(l2))) - return false; - return true; // we ignore - } - - level new_l1 = normalize(l1); - level new_l2 = normalize(l2); - if (l1 != new_l1 || l2 != new_l2) - return is_def_eq(new_l1, new_l2); - if (l1.kind() != l2.kind()) - return false; - - switch (l1.kind()) { - case level_kind::Max: - return - is_def_eq(max_lhs(l1), max_lhs(l2)) && - is_def_eq(max_rhs(l1), max_rhs(l2)); - case level_kind::IMax: - return - is_def_eq(imax_lhs(l1), imax_lhs(l2)) && - is_def_eq(imax_rhs(l1), imax_rhs(l2)); - case level_kind::Succ: - return is_def_eq(succ_of(l1), succ_of(l2)); - case level_kind::Param: - case level_kind::Global: - return false; - case level_kind::Zero: - case level_kind::Meta: - lean_unreachable(); - } - lean_unreachable(); - } - - bool is_def_eq(levels const & ls1, levels const & ls2) { - if (is_nil(ls1) && is_nil(ls2)) { - return true; - } else if (!is_nil(ls1) && !is_nil(ls2)) { - return - is_def_eq(head(ls1), head(ls2)) && - is_def_eq(tail(ls1), tail(ls2)); - } else { - return false; - } - } - - /** \brief Given \c e of the form ?m t_1 ... t_n, where - ?m is an assigned mvar, substitute \c ?m with its assignment. */ - expr subst_mvar(expr const & e) { - buffer args; - expr const & m = get_app_args(e, args); - lean_assert(is_mvar(m)); - expr const * v = get_assignment(m); - lean_assert(v); - return apply_beta(*v, args.size(), args.data()); - } - - bool has_assigned_uvar(level const & l) const { - if (!has_meta(l)) - return false; - if (m_state.m_uassignment.empty()) - return false; - bool found = false; - for_each(l, [&](level const & l) { - if (!has_meta(l)) - return false; // stop search - if (found) - return false; // stop search - if (is_uvar(l) && is_assigned(l)) { - found = true; - return false; // stop search - } - return true; // continue search - }); - return found; - } - - bool has_assigned_uvar(levels const & ls) const { - for (level const & l : ls) { - if (has_assigned_uvar(l)) - return true; - } - return false; - } - - bool has_assigned_uvar_mvar(expr const & e) const { - if (!has_expr_metavar(e) && !has_univ_metavar(e)) - return false; - if (m_state.m_eassignment.empty() && m_state.m_uassignment.empty()) - return false; - bool found = false; - for_each(e, [&](expr const & e, unsigned) { - if (!has_expr_metavar(e) && !has_univ_metavar(e)) - return false; // stop search - if (found) - return false; // stop search - if ((is_mvar(e) && is_assigned(e)) || - (is_constant(e) && has_assigned_uvar(const_levels(e))) || - (is_sort(e) && has_assigned_uvar(sort_level(e)))) { - found = true; - return false; // stop search - } - return true; // continue search - }); - return found; - } - - level instantiate_uvars(level const & l) { - if (!has_assigned_uvar(l)) - return l; - return replace(l, [&](level const & l) { - if (!has_meta(l)) { - return some_level(l); - } else if (is_uvar(l)) { - if (auto v1 = get_assignment(l)) { - level v2 = instantiate_uvars(*v1); - if (*v1 != v2) { - update_assignment(l, v2); - return some_level(v2); - } else { - return some_level(*v1); - } - } - } - return none_level(); - }); - } - - struct instantiate_uvars_mvars_fn : public replace_visitor { - cienv & m_owner; - - level visit_level(level const & l) { - return m_owner.instantiate_uvars(l); - } - - levels visit_levels(levels const & ls) { - return map_reuse(ls, - [&](level const & l) { return visit_level(l); }, - [](level const & l1, level const & l2) { return is_eqp(l1, l2); }); - } - - virtual expr visit_sort(expr const & s) { - return update_sort(s, visit_level(sort_level(s))); - } - - virtual expr visit_constant(expr const & c) { - return update_constant(c, visit_levels(const_levels(c))); - } - - virtual expr visit_local(expr const & e) { - return update_mlocal(e, visit(mlocal_type(e))); - } - - virtual expr visit_meta(expr const & m) { - if (is_mvar(m)) { - if (auto v1 = m_owner.get_assignment(m)) { - if (!has_expr_metavar(*v1)) { - return *v1; - } else { - expr v2 = m_owner.instantiate_uvars_mvars(*v1); - if (v2 != *v1) - m_owner.update_assignment(m, v2); - return v2; - } - } else { - return m; - } - } else { - return m; - } - } - - virtual expr visit_app(expr const & e) { - buffer args; - expr const & f = get_app_rev_args(e, args); - if (is_mvar(f)) { - if (auto v = m_owner.get_assignment(f)) { - expr new_app = apply_beta(*v, args.size(), args.data()); - if (has_expr_metavar(new_app)) - return visit(new_app); - else - return new_app; - } - } - expr new_f = visit(f); - buffer new_args; - bool modified = !is_eqp(new_f, f); - for (expr const & arg : args) { - expr new_arg = visit(arg); - if (!is_eqp(arg, new_arg)) - modified = true; - new_args.push_back(new_arg); - } - if (!modified) - return e; - else - return mk_rev_app(new_f, new_args, e.get_tag()); - } - - virtual expr visit_macro(expr const & e) { - lean_assert(is_macro(e)); - buffer new_args; - for (unsigned i = 0; i < macro_num_args(e); i++) - new_args.push_back(visit(macro_arg(e, i))); - return update_macro(e, new_args.size(), new_args.data()); - } - - virtual expr visit(expr const & e) { - if (!has_expr_metavar(e) && !has_univ_metavar(e)) - return e; - else - return replace_visitor::visit(e); - } - - public: - instantiate_uvars_mvars_fn(cienv & o):m_owner(o) {} - - expr operator()(expr const & e) { return visit(e); } - }; - - expr instantiate_uvars_mvars(expr const & e) { - if (!has_assigned_uvar_mvar(e)) - return e; - else - return instantiate_uvars_mvars_fn(*this)(e); - } - - /** \brief Given \c ma of the form ?m t_1 ... t_n, (try to) assign - ?m to (an abstraction of) v. Return true if success and false otherwise. */ - bool assign_mvar(expr const & ma, expr const & v) { - buffer args; - expr const & m = get_app_args(ma, args); - buffer locals; - for (expr const & arg : args) { - if (!is_internal_local(arg)) - break; // is not local - if (std::any_of(locals.begin(), locals.end(), [&](expr const & local) { - return mlocal_name(local) == mlocal_name(arg); })) - break; // duplicate local - locals.push_back(arg); - } - lean_assert(is_mvar(m)); - expr new_v = instantiate_uvars_mvars(v); - - // We must check - // 1. Any local constant occurring in new_v occurs in locals - // 2. m does not occur in new_v - bool ok = true; - for_each(new_v, [&](expr const & e, unsigned) { - if (!ok) - return false; // stop search - if (is_internal_local(e)) { - if (std::all_of(locals.begin(), locals.end(), [&](expr const & a) { - return mlocal_name(a) != mlocal_name(e); })) { - ok = false; // failed 1 - return false; - } - } else if (is_mvar(e)) { - if (m == e) { - ok = false; // failed 2 - return false; - } - return false; - } - return true; - }); - if (!ok) - return false; - if (args.empty()) { - // easy case - assign(m, new_v); - return true; - } else if (args.size() == locals.size()) { - assign(m, Fun(locals, new_v)); - return true; - } else { - // This case is imprecise since it is not a higher order pattern. - // That the term \c ma is of the form (?m t_1 ... t_n) and the t_i's are not pairwise - // distinct local constants. - expr m_type = mlocal_type(m); - for (unsigned i = 0; i < args.size(); i++) { - m_type = whnf(m_type); - if (!is_pi(m_type)) - return false; - lean_assert(i <= locals.size()); - if (i == locals.size()) - locals.push_back(mk_local(binding_domain(m_type))); - lean_assert(i < locals.size()); - m_type = instantiate(binding_body(m_type), locals[i]); - } - lean_assert(locals.size() == args.size()); - assign(m, Fun(locals, new_v)); - return true; - } - } - - bool is_def_eq_binding(expr e1, expr e2) { - lean_assert(e1.kind() == e2.kind()); - lean_assert(is_binding(e1)); - expr_kind k = e1.kind(); - buffer subst; - do { - optional var_e1_type; - if (binding_domain(e1) != binding_domain(e2)) { - var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data()); - expr var_e2_type = instantiate_rev(binding_domain(e2), subst.size(), subst.data()); - if (!is_def_eq_core(var_e2_type, *var_e1_type)) - return false; - } - if (!closed(binding_body(e1)) || !closed(binding_body(e2))) { - // local is used inside t or s - if (!var_e1_type) - var_e1_type = instantiate_rev(binding_domain(e1), subst.size(), subst.data()); - subst.push_back(mk_local(*var_e1_type)); - } else { - expr const & dont_care = mk_Prop(); - subst.push_back(dont_care); - } - e1 = binding_body(e1); - e2 = binding_body(e2); - } while (e1.kind() == k && e2.kind() == k); - return is_def_eq_core(instantiate_rev(e1, subst.size(), subst.data()), - instantiate_rev(e2, subst.size(), subst.data())); - } - - bool is_def_eq_app(expr const & e1, expr const & e2) { - lean_assert(is_app(e1) && is_app(e2)); - buffer args1, args2; - expr const & f1 = get_app_args(e1, args1); - expr const & f2 = get_app_args(e2, args2); - if (args1.size() != args2.size() || !is_def_eq_core(f1, f2)) - return false; - for (unsigned i = 0; i < args1.size(); i++) { - if (!is_def_eq_core(args1[i], args2[i])) - return false; - } - return true; - } - - bool is_def_eq_eta(expr const & e1, expr const & e2) { - expr new_e1 = try_eta(e1); - expr new_e2 = try_eta(e2); - if (e1 != new_e1 || e2 != new_e2) - return is_def_eq_core(new_e1, new_e2); - return false; - } - - bool is_def_eq_proof_irrel(expr const & e1, expr const & e2) { - if (!m_env.prop_proof_irrel()) - return false; - expr e1_type = infer_type(e1); - expr e2_type = infer_type(e2); - return is_prop(e1_type) && is_def_eq_core(e1_type, e2_type); - } - - bool is_def_eq_core(expr const & e1, expr const & e2) { - check_system("is_def_eq"); - if (e1 == e2) - return true; - expr const & f1 = get_app_fn(e1); - if (is_mvar(f1)) { - if (is_assigned(f1)) { - return is_def_eq_core(subst_mvar(e1), e2); - } else { - return assign_mvar(e1, e2); - } - } - expr const & f2 = get_app_fn(e2); - if (is_mvar(f2)) { - if (is_assigned(f2)) { - return is_def_eq_core(e1, subst_mvar(e2)); - } else { - return assign_mvar(e2, e1); - } - } - expr e1_n = whnf(e1); - expr e2_n = whnf(e2); - if (e1 != e1_n || e2 != e2_n) - return is_def_eq_core(e1_n, e2_n); - if (e1.kind() == e2.kind()) { - switch (e1.kind()) { - case expr_kind::Lambda: - case expr_kind::Pi: - if (is_def_eq_binding(e1, e2)) - return true; - break; - case expr_kind::Sort: - if (is_def_eq(sort_level(e1), sort_level(e2))) - return true; - break; - case expr_kind::Meta: - case expr_kind::Var: - lean_unreachable(); // LCOV_EXCL_LINE - case expr_kind::Local: - case expr_kind::Macro: - break; - case expr_kind::Constant: - if (const_name(e1) == const_name(e2) && - is_def_eq(const_levels(e1), const_levels(e2))) - return true; - break; - case expr_kind::App: - if (is_def_eq_app(e1, e2)) - return true; - break; - } - } - if (is_def_eq_eta(e1, e2)) - return true; - return is_def_eq_proof_irrel(e1, e2); - } - - bool is_def_eq(expr const & e1, expr const & e2) { - state saved_state = m_state; - if (!is_def_eq_core(e1, e2)) { - m_state = saved_state; - return false; - } else { - return true; - } - } - io_state_stream diagnostic() { io_state ios(*g_ios); ios.set_options(m_options); @@ -1196,19 +790,19 @@ struct cienv { MK_THREAD_LOCAL_GET_DEF(cienv, get_cienv); -static void reset_cache_and_ctx() { - get_cienv().reset_cache_and_ctx(); +static void clear_cache_and_ctx() { + get_cienv().clear_cache_and_ctx(); } -ci_type_inference_factory_scope::ci_type_inference_factory_scope(ci_type_inference_factory & factory): - m_old(g_factory) { - g_factory = &factory; - reset_cache_and_ctx(); +ci_local_metavar_types_scope::ci_local_metavar_types_scope(ci_local_metavar_types & t): + m_old(g_lm_types) { + g_lm_types = &t; + clear_cache_and_ctx(); } -ci_type_inference_factory_scope::~ci_type_inference_factory_scope() { - reset_cache_and_ctx(); - g_factory = m_old; +ci_local_metavar_types_scope::~ci_local_metavar_types_scope() { + clear_cache_and_ctx(); + g_lm_types = m_old; } static optional mk_class_instance(environment const & env, io_state const & ios, list const & ctx, expr const & e, pos_info_provider const * pip, @@ -1347,8 +941,7 @@ optional mk_subsingleton_instance(type_checker & tc, io_state const & ios, } void initialize_class_instance_resolution() { - g_prefix1 = new name(name::mk_internal_unique_name()); - g_prefix2 = new name(name::mk_internal_unique_name()); + g_prefix = new name(name::mk_internal_unique_name()); g_class_trace_instances = new name{"class", "trace_instances"}; g_class_instance_max_depth = new name{"class", "instance_max_depth"}; g_class_trans_instances = new name{"class", "trans_instances"}; @@ -1362,13 +955,12 @@ void initialize_class_instance_resolution() { register_bool_option(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES, "(class) use automatically derived instances from the transitive closure of " "the structure instance graph"); - g_default_factory = new ci_type_inference_factory(); + g_lm_types = new default_ci_local_metavar_types(); } void finalize_class_instance_resolution() { - delete g_default_factory; - delete g_prefix1; - delete g_prefix2; + delete g_lm_types; + delete g_prefix; delete g_class_trace_instances; delete g_class_instance_max_depth; delete g_class_trans_instances; diff --git a/src/library/class_instance_resolution.h b/src/library/class_instance_resolution.h index 9691e8a9a..50ee72a35 100644 --- a/src/library/class_instance_resolution.h +++ b/src/library/class_instance_resolution.h @@ -12,24 +12,27 @@ Author: Leonardo de Moura #include "library/local_context.h" namespace lean { -class ci_type_inference { +/** Auxiliary object used to customize type class resolution. + It allows us to specify how the types of local constants and metavariables are retrieved. + + \remark We need this object because modules such as blast store + the types of some local constants (e.g., hypotheses) in a + different data-structure. + */ +class ci_local_metavar_types { public: - virtual ~ci_type_inference() {} - virtual expr whnf(expr const & e) = 0; - virtual expr infer_type(expr const & e) = 0; + virtual ~ci_local_metavar_types() {} + virtual expr infer_local(expr const & e) = 0; + virtual expr infer_metavar(expr const & e) = 0; }; -class ci_type_inference_factory { +/** \brief Auxiliary object for changing the thread local storage that stores the auxiliary object + ci_local_metavar_types used by type class resolution. */ +class ci_local_metavar_types_scope { + ci_local_metavar_types * m_old; public: - virtual ~ci_type_inference_factory(); - virtual std::shared_ptr operator()(environment const & env) const; -}; - -class ci_type_inference_factory_scope { - ci_type_inference_factory * m_old; -public: - ci_type_inference_factory_scope(ci_type_inference_factory & factory); - ~ci_type_inference_factory_scope(); + ci_local_metavar_types_scope(ci_local_metavar_types & t); + ~ci_local_metavar_types_scope(); }; optional mk_class_instance(environment const & env, io_state const & ios, list const & ctx, expr const & e, pos_info_provider const * pip = nullptr); diff --git a/src/library/init_module.cpp b/src/library/init_module.cpp index ed0abd3df..781532ca3 100644 --- a/src/library/init_module.cpp +++ b/src/library/init_module.cpp @@ -44,6 +44,7 @@ Author: Leonardo de Moura #include "library/meng_paulson.h" #include "library/norm_num.h" #include "library/class_instance_resolution.h" +#include "library/type_inference.h" namespace lean { void initialize_library_module() { @@ -87,9 +88,11 @@ void initialize_library_module() { initialize_meng_paulson(); initialize_norm_num(); initialize_class_instance_resolution(); + initialize_type_inference(); } void finalize_library_module() { + finalize_type_inference(); finalize_class_instance_resolution(); finalize_norm_num(); finalize_meng_paulson();