From f5819fab6029fdbeb5606237f249434f7ffd758b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 17 Oct 2015 11:24:19 -0700 Subject: [PATCH] feat(library/class_instance_resolution): new type class resolution procedure --- src/library/class_instance_resolution.cpp | 282 ++++++++++++++++++---- src/library/class_instance_resolution.h | 4 +- 2 files changed, 242 insertions(+), 44 deletions(-) diff --git a/src/library/class_instance_resolution.cpp b/src/library/class_instance_resolution.cpp index 8a3988603..39d1f7636 100644 --- a/src/library/class_instance_resolution.cpp +++ b/src/library/class_instance_resolution.cpp @@ -13,6 +13,7 @@ Author: Leonardo de Moura #include "library/normalize.h" #include "library/reducible.h" #include "library/class.h" +#include "library/io_state_stream.h" #include "library/replace_visitor.h" #include "library/class_instance_resolution.h" @@ -92,17 +93,19 @@ struct cienv { typedef rb_map uassignment; typedef rb_map eassignment; - environment m_env; - ci_type_inference_ptr m_tc_ptr; - expr_struct_map m_cache; - name_generator m_ngen; - name_predicate m_not_reducible_pred; + environment m_env; + pos_info_provider const * m_pip; + optional m_pos; + ci_type_inference_ptr m_tc_ptr; + expr_struct_map m_cache; + name_generator m_ngen; + name_predicate m_not_reducible_pred; - list m_ctx; - buffer m_local_instances; + list m_ctx; + buffer> m_local_instances; - unsigned m_next_uvar; - unsigned m_next_mvar; + unsigned m_next_uvar; + unsigned m_next_mvar; struct state { list m_stack; // stack of meta-variables that need to be synthesized; @@ -110,24 +113,27 @@ struct cienv { eassignment m_eassignment; }; - state m_state; // active state + state m_state; // active state struct choice { - list m_local_instances; - list m_trans_instances; - list m_instances; - state m_state; + list m_local_instances; + list m_trans_instances; + list m_instances; + state m_state; }; - list m_choices; + std::vector m_choices; + expr m_main_mvar; - bool m_multiple_instances; + bool m_multiple_instances; + + bool m_displayed_trace_header; // configuration - bool m_unique_instances; - unsigned m_max_depth; - bool m_trans_instances; - bool m_trace_instances; + bool m_unique_instances; + unsigned m_max_depth; + bool m_trans_instances; + bool m_trace_instances; cienv(bool multiple_instances = false): m_ngen(*g_prefix2), @@ -304,14 +310,14 @@ struct cienv { // Auxiliary method for set_ctx - void set_local_instance(unsigned i, expr const & e) { + void set_local_instance(unsigned i, name const & cname, expr const & e) { lean_assert(i <= m_local_instances.size()); if (i == m_local_instances.size()) { reset_cache(); - m_local_instances.push_back(e); - } else if (e != m_local_instances[i]) { + m_local_instances.push_back(mk_pair(cname, e)); + } else if (e != m_local_instances[i].second) { reset_cache(); - m_local_instances[i] = e; + m_local_instances[i] = mk_pair(cname, e); } else { // we don't need to reset the cache since this local instance // is equal to the one used in a previous call @@ -330,13 +336,19 @@ struct cienv { // Remark: we use infer_type(e) instead of mlocal_type because we want to allow // customers (e.g., blast) of this class to store the type of local constants // in a different place. - if (is_class(infer_type(e))) { - set_local_instance(i, e); + if (auto cname = is_class(infer_type(e))) { + set_local_instance(i, *cname, e); i++; } } } + void set_pos_info(pos_info_provider const * pip, expr const & type) { + m_pip = pip; + if (m_pip) + m_pos = m_pip->get_pos_info(type); + } + // Create an internal universal metavariable level mk_uvar() { unsigned idx = m_next_uvar; @@ -823,27 +835,207 @@ struct cienv { } } - expr init_search(expr const & type) { - m_state = state(); - expr m = mk_mvar(type); - m_state.m_stack = cons(m, m_state.m_stack); - return m; + void trace(expr const & mvar, expr const & r) { + if (!m_trace_instances) + return; + auto out = diagnostic(m_env, *g_ios); + unsigned depth = m_choices.size(); + if (!m_displayed_trace_header && depth == 1) { + if (m_pip) { + if (auto fname = m_pip->get_file_name()) { + out << fname << ":"; + } + if (m_pos) + out << m_pos->first << ":" << m_pos->second << ":"; + } + out << " class-instance resolution trace" << endl; + m_displayed_trace_header = true; + } + for (unsigned i = 0; i < depth; i++) + out << " "; + if (depth > 1) + out << "[" << depth << "] "; + out << mvar << " : " << mlocal_type(mvar) << " := " << r << endl; + } + + bool try_instance(expr const & mvar, expr const & inst, expr const & inst_type) { + try { + buffer locals; + expr mvar_type = mlocal_type(mvar); + while (true) { + mvar_type = whnf(mvar_type); + if (!is_pi(mvar_type)) + break; + expr local = mk_local(binding_domain(mvar_type)); + locals.push_back(local); + mvar_type = instantiate(binding_body(mvar_type), local); + } + expr type = inst_type; + expr r = inst; + buffer new_inst_mvars; + while (true) { + type = whnf(type); + if (!is_pi(type)) + break; + expr new_mvar = mk_mvar(binding_domain(type)); + if (binding_info(type).is_inst_implicit()) { + new_inst_mvars.push_back(new_mvar); + } + r = mk_app(r, new_mvar); + type = instantiate(binding_body(type), new_mvar); + } + trace(mvar, r); + if (!is_def_eq(mvar_type, type)) + return false; + r = Fun(locals, r); + assign(mvar, r); + // copy new_inst_mvars to stack + unsigned i = new_inst_mvars.size(); + while (i > 0) { + --i; + m_state.m_stack = cons(new_inst_mvars[i], m_state.m_stack); + } + return true; + } catch (exception &) { + return false; + } + } + + bool try_instance(expr const & mvar, name const & inst_name) { + if (auto decl = m_env.find(inst_name)) { + buffer ls_buffer; + unsigned num_univ_ps = decl->get_num_univ_params(); + for (unsigned i = 0; i < num_univ_ps; i++) + ls_buffer.push_back(mk_uvar()); + levels ls = to_list(ls_buffer.begin(), ls_buffer.end()); + expr inst_cnst = mk_constant(inst_name, ls); + expr inst_type = instantiate_type_univ_params(*decl, ls); + return try_instance(mvar, inst_cnst, inst_type); + } else { + return false; + } + } + + list get_local_instances(name const & cname) { + buffer selected; + for (pair const & p : m_local_instances) { + if (p.first == cname) + selected.push_back(p.second); + } + return to_list(selected); + } + + bool is_done() const { + return empty(m_state.m_stack); + } + + expr const & next_mvar() const { + lean_assert(!is_done()); + return head(m_state.m_stack); + } + + bool mk_choice_point(expr const & mvar) { + lean_assert(is_mvar(mvar)); + expr mvar_type = instantiate_uvars_mvars(mlocal_type(mvar)); + if (has_expr_metavar(mvar_type)) + return false; + auto cname = is_class(mvar_type); + if (!cname) + return false; + choice r; + r.m_local_instances = get_local_instances(*cname); + if (m_trans_instances && m_choices.empty()) { + // we only use transitive instances in the top-level + r.m_trans_instances = get_class_derived_trans_instances(m_env, *cname); + } + r.m_instances = get_class_instances(m_env, *cname); + if (empty(r.m_local_instances) && empty(r.m_trans_instances) && empty(r.m_instances)) + return false; + r.m_state = m_state; + m_choices.push_back(r); + return true; + } + + bool process_next_alt_core(list & insts) { + while (!empty(insts)) { + expr inst = head(insts); + insts = tail(insts); + expr inst_type = infer_type(inst); + if (try_instance(next_mvar(), inst, inst_type)) + return true; + } + return false; + } + + bool process_next_alt_core(list & inst_names) { + while (!empty(inst_names)) { + name inst_name = head(inst_names); + inst_names = tail(inst_names); + if (try_instance(next_mvar(), inst_name)) + return true; + } + return false; + } + + bool process_next_alt() { + lean_assert(!m_choices.empty()); + choice & c = m_choices.back(); + if (process_next_alt_core(c.m_local_instances)) + return true; + if (process_next_alt_core(c.m_trans_instances)) + return true; + if (process_next_alt_core(c.m_instances)) + return true; + return false; + } + + bool process_next_mvar() { + lean_assert(!is_done()); + expr mvar = next_mvar(); + if (!mk_choice_point(mvar)) + return false; + return process_next_alt(); + } + + bool backtrack() { + lean_assert(!m_choices.empty()); + while (true) { + m_choices.pop_back(); + if (m_choices.empty()) + return false; + m_state = m_choices.back().m_state; + if (process_next_alt()) + return true; + } } optional search() { - // TODO(Leo): - return none_expr(); + while (!is_done()) { + if (!process_next_mvar()) { + if (!backtrack()) + return none_expr(); + } + } + return some_expr(instantiate_uvars_mvars(m_main_mvar)); } - optional operator()(environment const & env, options const & o, list const & ctx, expr const & type) { + void init_search(expr const & type) { + m_state = state(); + m_main_mvar = mk_mvar(type); + m_state.m_stack = cons(m_main_mvar, m_state.m_stack); + m_displayed_trace_header = false; + } + + optional operator()(environment const & env, options const & o, pos_info_provider const * pip, list const & ctx, expr const & type) { set_env(env); set_options(o); set_ctx(ctx); + set_pos_info(pip, type); if (auto r = check_cache(type)) return r; - expr m = init_search(type); + init_search(type); if (auto r = search()) { cache_result(type, *r); @@ -856,9 +1048,15 @@ struct cienv { optional next() { if (!m_multiple_instances) return none_expr(); - - // TODO(Leo): backtrack and search - return none_expr(); + if (m_choices.empty()) + return none_expr(); + m_state = m_choices.back().m_state; + if (process_next_alt()) + return search(); + else if (backtrack()) + return search(); + else + return none_expr(); } }; @@ -879,13 +1077,13 @@ ci_type_inference_factory_scope::~ci_type_inference_factory_scope() { g_factory = m_old; } -optional mk_class_instance(environment const & env, io_state const & ios, list const & ctx, expr const & e) { +optional mk_class_instance(environment const & env, io_state const & ios, list const & ctx, expr const & e, pos_info_provider * pip) { flet set_ios(g_ios, const_cast(&ios)); - return get_cienv()(env, ios.get_options(), ctx, e); + return get_cienv()(env, ios.get_options(), pip, ctx, e); } -optional mk_class_instance(environment const & env, list const & ctx, expr const & e) { - return mk_class_instance(env, get_dummy_ios(), ctx, e); +optional mk_class_instance(environment const & env, list const & ctx, expr const & e, pos_info_provider * pip) { + return mk_class_instance(env, get_dummy_ios(), ctx, e, pip); } void initialize_class_instance_resolution() { diff --git a/src/library/class_instance_resolution.h b/src/library/class_instance_resolution.h index 5b25120a7..a6b1c75bf 100644 --- a/src/library/class_instance_resolution.h +++ b/src/library/class_instance_resolution.h @@ -30,8 +30,8 @@ public: ~ci_type_inference_factory_scope(); }; -optional mk_class_instance(environment const & env, io_state const & ios, list const & ctx, expr const & e); -optional mk_class_instance(environment const & env, list const & ctx, expr const & e); +optional mk_class_instance(environment const & env, io_state const & ios, list const & ctx, expr const & e, pos_info_provider const * pip = nullptr); +optional mk_class_instance(environment const & env, list const & ctx, expr const & e, pos_info_provider const * pip = nullptr); void initialize_class_instance_resolution(); void finalize_class_instance_resolution(); }