diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index 7522b5205..660833144 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -821,7 +821,7 @@ struct structure_cmd_fn { m_env = add_coercion(m_env, coercion_name, m_p.ios()); if (m_modifiers.is_class() && is_class(m_env, parent_name)) { // if both are classes, then we also mark coercion_name as an instance - m_env = add_instance(m_env, coercion_name); + m_env = add_trans_instance(m_env, coercion_name); } } } diff --git a/src/library/class.cpp b/src/library/class.cpp index e51056163..014caea1a 100644 --- a/src/library/class.cpp +++ b/src/library/class.cpp @@ -12,32 +12,38 @@ Author: Leonardo de Moura #include "library/kernel_serializer.h" #include "library/reducible.h" #include "library/aliases.h" +#include "library/tc_multigraph.h" +#include "library/protected.h" #ifndef LEAN_INSTANCE_DEFAULT_PRIORITY #define LEAN_INSTANCE_DEFAULT_PRIORITY 1000 #endif namespace lean { -enum class class_entry_kind { ClassCmd, InstanceCmd, MultiCmd }; +enum class class_entry_kind { Class, Multi, Instance, TransInstance, DerivedTransInstance }; struct class_entry { - class_entry_kind m_cmd_kind; + class_entry_kind m_kind; name m_class; - name m_instance; // only relevant if m_cmd_kind == InstanceCmd - unsigned m_priority; // only relevant if m_cmd_kind == InstanceCmd - class_entry():m_cmd_kind(class_entry_kind::ClassCmd), m_priority(0) {} - explicit class_entry(name const & c):m_cmd_kind(class_entry_kind::ClassCmd), m_class(c), m_priority(0) {} - class_entry(name const & c, name const & i, unsigned p): - m_cmd_kind(class_entry_kind::InstanceCmd), m_class(c), m_instance(i), m_priority(p) {} + name m_instance; // only relevant if m_kind == Instance + unsigned m_priority; // only relevant if m_kind == Instance + class_entry():m_kind(class_entry_kind::Class), m_priority(0) {} + explicit class_entry(name const & c):m_kind(class_entry_kind::Class), m_class(c), m_priority(0) {} + class_entry(class_entry_kind k, name const & c, name const & i, unsigned p): + m_kind(k), m_class(c), m_instance(i), m_priority(p) {} class_entry(name const & c, bool): - m_cmd_kind(class_entry_kind::MultiCmd), m_class(c) {} + m_kind(class_entry_kind::Multi), m_class(c) {} }; struct class_state { typedef name_map> class_instances; typedef name_map instance_priorities; - class_instances m_instances; - instance_priorities m_priorities; - name_set m_multiple; // set of classes that allow multiple solutions/instances + class_instances m_instances; + class_instances m_derived_trans_instances; + instance_priorities m_priorities; + name_set m_multiple; // set of classes that allow multiple solutions/instances + tc_multigraph m_mgraph; + + class_state():m_mgraph("transitive instance") {} unsigned get_priority(name const & i) const { if (auto it = m_priorities.find(i)) @@ -80,6 +86,22 @@ struct class_state { m_priorities.insert(i, p); } + void add_trans_instance(environment const & env, name const & c, name const & i, unsigned p) { + add_instance(c, i, p); + m_mgraph.add1(env, i); + } + + void add_derived_trans_instance(environment const & env, name const & c, name const & i) { + auto it = m_derived_trans_instances.find(c); + if (!it) { + m_derived_trans_instances.insert(c, to_list(i)); + } else { + auto lst = filter(*it, [&](name const & i1) { return i1 != i; }); + m_derived_trans_instances.insert(c, cons(i, lst)); + } + m_mgraph.add1(env, i); + } + void add_multiple(name const & c) { add_class(c); m_multiple.insert(c); @@ -92,16 +114,22 @@ static std::string * g_key = nullptr; struct class_config { typedef class_state state; typedef class_entry entry; - static void add_entry(environment const &, io_state const &, state & s, entry const & e) { - switch (e.m_cmd_kind) { - case class_entry_kind::ClassCmd: + static void add_entry(environment const & env, io_state const &, state & s, entry const & e) { + switch (e.m_kind) { + case class_entry_kind::Class: s.add_class(e.m_class); break; - case class_entry_kind::InstanceCmd: + case class_entry_kind::Multi: + s.add_multiple(e.m_class); + break; + case class_entry_kind::Instance: s.add_instance(e.m_class, e.m_instance, e.m_priority); break; - case class_entry_kind::MultiCmd: - s.add_multiple(e.m_class); + case class_entry_kind::TransInstance: + s.add_trans_instance(env, e.m_class, e.m_instance, e.m_priority); + break; + case class_entry_kind::DerivedTransInstance: + s.add_derived_trans_instance(env, e.m_class, e.m_instance); break; } } @@ -112,36 +140,50 @@ struct class_config { return *g_key; } static void write_entry(serializer & s, entry const & e) { - s << static_cast(e.m_cmd_kind); - switch (e.m_cmd_kind) { - case class_entry_kind::ClassCmd: case class_entry_kind::MultiCmd: + s << static_cast(e.m_kind); + switch (e.m_kind) { + case class_entry_kind::Class: + case class_entry_kind::Multi: s << e.m_class; break; - case class_entry_kind::InstanceCmd: + case class_entry_kind::Instance: + case class_entry_kind::TransInstance: s << e.m_class << e.m_instance << e.m_priority; break; + case class_entry_kind::DerivedTransInstance: + s << e.m_class << e.m_instance; + break; } } static entry read_entry(deserializer & d) { entry e; char k; d >> k; - e.m_cmd_kind = static_cast(k); - switch (e.m_cmd_kind) { - case class_entry_kind::ClassCmd: case class_entry_kind::MultiCmd: + e.m_kind = static_cast(k); + switch (e.m_kind) { + case class_entry_kind::Class: + case class_entry_kind::Multi: d >> e.m_class; break; - case class_entry_kind::InstanceCmd: + case class_entry_kind::Instance: + case class_entry_kind::TransInstance: d >> e.m_class >> e.m_instance >> e.m_priority; break; + case class_entry_kind::DerivedTransInstance: + d >> e.m_class >> e.m_instance; + break; } return e; } static optional get_fingerprint(entry const & e) { - switch (e.m_cmd_kind) { - case class_entry_kind::ClassCmd: case class_entry_kind::MultiCmd: + switch (e.m_kind) { + case class_entry_kind::Class: + case class_entry_kind::Multi: return some(e.m_class.hash()); - case class_entry_kind::InstanceCmd: + case class_entry_kind::Instance: + case class_entry_kind::TransInstance: return some(hash(hash(e.m_class.hash(), e.m_instance.hash()), e.m_priority)); + case class_entry_kind::DerivedTransInstance: + return some(hash(e.m_class.hash(), e.m_instance.hash())); } lean_unreachable(); } @@ -194,7 +236,7 @@ type_checker_ptr mk_class_type_checker(environment const & env, name_generator & } static name * g_tmp_prefix = nullptr; -environment add_instance(environment const & env, name const & n, unsigned priority, bool persistent) { +static environment add_instance_core(environment const & env, class_entry_kind k, name const & n, unsigned priority, bool persistent) { declaration d = env.get(n); expr type = d.get_type(); name_generator ngen(*g_tmp_prefix); @@ -207,13 +249,35 @@ environment add_instance(environment const & env, name const & n, unsigned prior } name c = get_class_name(env, get_app_fn(type)); check_is_class(env, c); - return class_ext::add_entry(env, get_dummy_ios(), class_entry(c, n, priority), persistent); + return class_ext::add_entry(env, get_dummy_ios(), class_entry(k, c, n, priority), persistent); +} + +environment add_instance(environment const & env, name const & n, unsigned priority, bool persistent) { + return add_instance_core(env, class_entry_kind::Instance, n, priority, persistent); } environment add_instance(environment const & env, name const & n, bool persistent) { return add_instance(env, n, LEAN_INSTANCE_DEFAULT_PRIORITY, persistent); } +environment add_trans_instance(environment const & env, name const & n, unsigned priority, bool persistent) { + class_state const & s = class_ext::get_state(env); + tc_multigraph g = s.m_mgraph; + pair> new_env_insts = g.add(env, n); + environment new_env = new_env_insts.first; + new_env = add_instance_core(new_env, class_entry_kind::TransInstance, n, priority, persistent); + for (name const & tn : new_env_insts.second) { + new_env = add_instance_core(new_env, class_entry_kind::DerivedTransInstance, tn, 0, persistent); + new_env = set_reducible(new_env, tn, reducible_status::Reducible, persistent); + new_env = add_protected(new_env, tn); + } + return new_env; +} + +environment add_trans_instance(environment const & env, name const & n, bool persistent) { + return add_trans_instance(env, n, LEAN_INSTANCE_DEFAULT_PRIORITY, persistent); +} + environment mark_multiple_instances(environment const & env, name const & n, bool persistent) { check_class(env, n); return class_ext::add_entry(env, get_dummy_ios(), class_entry(n, true), persistent); @@ -234,6 +298,11 @@ list get_class_instances(environment const & env, name const & c) { return ptr_to_list(s.m_instances.find(c)); } +list get_class_derived_trans_instances(environment const & env, name const & c) { + class_state const & s = class_ext::get_state(env); + return ptr_to_list(s.m_derived_trans_instances.find(c)); +} + /** \brief If the constant \c e is a class, return its name */ static optional constant_is_ext_class(environment const & env, expr const & e) { name const & cls_name = const_name(e); diff --git a/src/library/class.h b/src/library/class.h index fb67ae289..a66804e26 100644 --- a/src/library/class.h +++ b/src/library/class.h @@ -15,12 +15,18 @@ environment add_class(environment const & env, name const & n, bool persistent = environment add_instance(environment const & env, name const & n, bool persistent = true); /** \brief Add a new 'class instance' to the environment. */ environment add_instance(environment const & env, name const & n, unsigned priority, bool persistent); +/** \brief Add a new 'class transitive instance' to the environment with default priority. */ +environment add_trans_instance(environment const & env, name const & n, bool persistent = true); +/** \brief Add a new 'class transitive instance' to the environment. */ +environment add_trans_instance(environment const & env, name const & n, unsigned priority, bool persistent); /** \brief Return true iff \c c was declared with \c add_class. */ bool is_class(environment const & env, name const & c); /** \brief Return true iff \c i was declared with \c add_instance. */ bool is_instance(environment const & env, name const & i); /** \brief Return the instances of the given class. */ list get_class_instances(environment const & env, name const & c); +/** \brief Return instances from the transitive closure graph of instances added using add_trans_instance */ +list get_class_derived_trans_instances(environment const & env, name const & c); /** \brief Return the classes in the given environment. */ void get_classes(environment const & env, buffer & classes); name get_class_name(environment const & env, expr const & e); diff --git a/src/library/class_instance_synth.cpp b/src/library/class_instance_synth.cpp index f89bc7472..03475b754 100644 --- a/src/library/class_instance_synth.cpp +++ b/src/library/class_instance_synth.cpp @@ -40,11 +40,16 @@ Author: Leonardo de Moura #define LEAN_DEFAULT_CLASS_CONSERVATIVE true #endif +#ifndef LEAN_DEFAULT_CLASS_TRANS_INSTANCES +#define LEAN_DEFAULT_CLASS_TRANS_INSTANCES true +#endif + namespace lean { static name * g_class_unique_class_instances = nullptr; static name * g_class_trace_instances = nullptr; static name * g_class_instance_max_depth = nullptr; static name * g_class_conservative = nullptr; +static name * g_class_trans_instances = nullptr; [[ noreturn ]] void throw_class_exception(char const * msg, expr const & m) { throw_generic_exception(msg, m); } [[ noreturn ]] void throw_class_exception(expr const & m, pp_fn const & fn) { throw_generic_exception(m, fn); } @@ -54,6 +59,7 @@ void initialize_class_instance_elaborator() { g_class_trace_instances = new name{"class", "trace_instances"}; g_class_instance_max_depth = new name{"class", "instance_max_depth"}; g_class_conservative = new name{"class", "conservative"}; + g_class_trans_instances = new name{"class", "trans_instances"}; register_bool_option(*g_class_unique_class_instances, LEAN_DEFAULT_CLASS_UNIQUE_CLASS_INSTANCES, "(class) generate an error if there is more than one solution " @@ -67,6 +73,9 @@ void initialize_class_instance_elaborator() { register_bool_option(*g_class_conservative, LEAN_DEFAULT_CLASS_CONSERVATIVE, "(class) use conservative unification (only unfold reducible definitions, and avoid delta-delta case splits)"); + + register_bool_option(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES, + "(class) use automatically derived instances from the transitive closure of the structure instance graph"); } void finalize_class_instance_elaborator() { @@ -74,6 +83,7 @@ void finalize_class_instance_elaborator() { delete g_class_trace_instances; delete g_class_instance_max_depth; delete g_class_conservative; + delete g_class_trans_instances; } bool get_class_unique_class_instances(options const & o) { @@ -92,6 +102,10 @@ bool get_class_conservative(options const & o) { return o.get_bool(*g_class_conservative, LEAN_DEFAULT_CLASS_CONSERVATIVE); } +bool get_class_trans_instances(options const & o) { + return o.get_bool(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES); +} + /** \brief Context for handling class-instance metavariable choice constraint */ struct class_instance_context { io_state m_ios; @@ -102,6 +116,7 @@ struct class_instance_context { bool m_trace_instances; bool m_conservative; unsigned m_max_depth; + bool m_trans_instances; char const * m_fname; optional m_pos; class_instance_context(environment const & env, io_state const & ios, @@ -113,6 +128,7 @@ struct class_instance_context { m_trace_instances = get_class_trace_instances(ios.get_options()); m_max_depth = get_class_instance_max_depth(ios.get_options()); m_conservative = get_class_conservative(ios.get_options()); + m_trans_instances = get_class_trans_instances(ios.get_options()); m_tc = mk_class_type_checker(env, m_ngen.mk_child(), m_conservative); options opts = m_ios.get_options(); opts = opts.update_if_undef(get_pp_purify_metavars_name(), false); @@ -134,10 +150,11 @@ struct class_instance_context { optional const & get_pos() const { return m_pos; } char const * get_file_name() const { return m_fname; } unsigned get_max_depth() const { return m_max_depth; } + bool use_trans_instances() const { return m_trans_instances; } }; -pair mk_class_instance_elaborator(std::shared_ptr const & C, local_context const & ctx, - optional const & type, tag g, unsigned depth); +static pair mk_class_instance_elaborator(std::shared_ptr const & C, local_context const & ctx, + optional const & type, tag g, unsigned depth, bool use_globals); /** \brief Choice function \c fn for synthesizing class instances. @@ -160,6 +177,7 @@ struct class_instance_elaborator : public choice_iterator { list m_local_instances; // global declaration names that are class instances. // This information is retrieved using #get_class_instances. + list m_trans_instances; list m_instances; justification m_jst; unsigned m_depth; @@ -167,10 +185,10 @@ struct class_instance_elaborator : public choice_iterator { class_instance_elaborator(std::shared_ptr const & C, local_context const & ctx, expr const & meta, expr const & meta_type, - list const & local_insts, list const & instances, + list const & local_insts, list const & trans_insts, list const & instances, justification const & j, unsigned depth): choice_iterator(), m_C(C), m_ctx(ctx), m_meta(meta), m_meta_type(meta_type), - m_local_instances(local_insts), m_instances(instances), m_jst(j), m_depth(depth) { + m_local_instances(local_insts), m_trans_instances(trans_insts), m_instances(instances), m_jst(j), m_depth(depth) { if (m_depth > m_C->get_max_depth()) { throw_class_exception("maximum class-instance resolution depth has been reached " "(the limit can be increased by setting option 'class.instance_max_depth') " @@ -205,7 +223,7 @@ struct class_instance_elaborator : public choice_iterator { out << m_meta << " : " << t << " := " << r << endl; } - optional try_instance(expr const & inst, expr const & inst_type) { + optional try_instance(expr const & inst, expr const & inst_type, bool use_globals) { type_checker & tc = m_C->tc(); name_generator & ngen = m_C->m_ngen; tag g = inst.get_tag(); @@ -233,7 +251,7 @@ struct class_instance_elaborator : public choice_iterator { expr arg; if (binding_info(type).is_inst_implicit()) { pair ac = mk_class_instance_elaborator(m_C, m_ctx, some_expr(binding_domain(type)), - g, m_depth+1); + g, m_depth+1, use_globals); arg = ac.first; cs.push_back(ac.second); } else { @@ -251,7 +269,7 @@ struct class_instance_elaborator : public choice_iterator { } } - optional try_instance(name const & inst) { + optional try_instance(name const & inst, bool use_globals) { environment const & env = m_C->env(); if (auto decl = env.find(inst)) { name_generator & ngen = m_C->m_ngen; @@ -262,7 +280,7 @@ struct class_instance_elaborator : public choice_iterator { levels ls = to_list(ls_buffer.begin(), ls_buffer.end()); expr inst_cnst = copy_tag(m_meta, mk_constant(inst, ls)); expr inst_type = instantiate_type_univ_params(*decl, ls); - return try_instance(inst_cnst, inst_type); + return try_instance(inst_cnst, inst_type, use_globals); } else { return optional(); } @@ -274,20 +292,31 @@ struct class_instance_elaborator : public choice_iterator { m_local_instances = tail(m_local_instances); if (!is_local(inst)) continue; - if (auto r = try_instance(inst, mlocal_type(inst))) + bool use_globals = true; + if (auto r = try_instance(inst, mlocal_type(inst), use_globals)) return r; } + while (!empty(m_trans_instances)) { + bool use_globals = false; + name inst = head(m_trans_instances); + m_trans_instances = tail(m_trans_instances); + if (auto cs = try_instance(inst, use_globals)) + return cs; + } while (!empty(m_instances)) { + bool use_globals = true; name inst = head(m_instances); m_instances = tail(m_instances); - if (auto cs = try_instance(inst)) + if (auto cs = try_instance(inst, use_globals)) return cs; } return optional(); } }; -constraint mk_class_instance_cnstr(std::shared_ptr const & C, local_context const & ctx, expr const & m, unsigned depth) { +// Remarks: +// - we only use get_class_instances and get_class_derived_trans_instances when use_globals is true +static constraint mk_class_instance_cnstr(std::shared_ptr const & C, local_context const & ctx, expr const & m, unsigned depth, bool use_globals) { environment const & env = C->env(); justification j = mk_failed_to_synthesize_jst(env, m); auto choice_fn = [=](expr const & meta, expr const & meta_type, substitution const &, name_generator const &) { @@ -297,11 +326,16 @@ constraint mk_class_instance_cnstr(std::shared_ptr const list local_insts; if (C->use_local_instances()) local_insts = get_local_instances(C->tc(), ctx_lst, cls_name); - list insts = get_class_instances(env, cls_name); + list trans_insts, insts; + if (use_globals) { + if (depth == 0 && C->use_trans_instances()) + trans_insts = get_class_derived_trans_instances(env, cls_name); + insts = get_class_instances(env, cls_name); + } if (empty(local_insts) && empty(insts)) return lazy_list(); // nothing to be done // we are always strict with placeholders associated with classes - return choose(std::make_shared(C, ctx, meta, meta_type, local_insts, insts, j, depth)); + return choose(std::make_shared(C, ctx, meta, meta_type, local_insts, trans_insts, insts, j, depth)); } else { // do nothing, type is not a class... return lazy_list(constraints()); @@ -311,15 +345,15 @@ constraint mk_class_instance_cnstr(std::shared_ptr const return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::Basic), owner, j); } -pair mk_class_instance_elaborator(std::shared_ptr const & C, local_context const & ctx, - optional const & type, tag g, unsigned depth) { +static pair mk_class_instance_elaborator(std::shared_ptr const & C, local_context const & ctx, + optional const & type, tag g, unsigned depth, bool use_globals) { expr m = ctx.mk_meta(C->m_ngen, type, g); - constraint c = mk_class_instance_cnstr(C, ctx, m, depth); + constraint c = mk_class_instance_cnstr(C, ctx, m, depth, use_globals); return mk_pair(m, c); } -constraint mk_class_instance_root_cnstr(std::shared_ptr const & C, local_context const & _ctx, - expr const & m, bool is_strict, unifier_config const & cfg, delay_factor const & factor) { +static constraint mk_class_instance_root_cnstr(std::shared_ptr const & C, local_context const & _ctx, + expr const & m, bool is_strict, unifier_config const & cfg, delay_factor const & factor) { environment const & env = C->env(); justification j = mk_failed_to_synthesize_jst(env, m); @@ -336,7 +370,8 @@ constraint mk_class_instance_root_cnstr(std::shared_ptr expr new_meta = mj.first; justification new_j = mj.second; unsigned depth = 0; - constraint c = mk_class_instance_cnstr(C, ctx, new_meta, depth); + bool use_globals = true; + constraint c = mk_class_instance_cnstr(C, ctx, new_meta, depth, use_globals); unifier_config new_cfg(cfg); new_cfg.m_discard = false; new_cfg.m_use_exceptions = false; @@ -446,9 +481,10 @@ optional mk_class_instance(environment const & env, io_state const & ios, auto C = std::make_shared(env, ios, prefix, use_local_instances); if (!is_ext_class(C->tc(), type)) return none_expr(); - expr meta = ctx.mk_meta(C->m_ngen, some_expr(type), type.get_tag()); - unsigned depth = 0; - constraint c = mk_class_instance_cnstr(C, ctx, meta, depth); + expr meta = ctx.mk_meta(C->m_ngen, some_expr(type), type.get_tag()); + unsigned depth = 0; + bool use_globals = true; + constraint c = mk_class_instance_cnstr(C, ctx, meta, depth, use_globals); unifier_config new_cfg(cfg); new_cfg.m_discard = true; new_cfg.m_use_exceptions = true;