feat(library/class): transitive instances

see issue #666
This commit is contained in:
Leonardo de Moura 2015-06-21 16:02:24 -07:00
parent 54128eb45f
commit 859ef441a0
4 changed files with 165 additions and 54 deletions

View file

@ -821,7 +821,7 @@ struct structure_cmd_fn {
m_env = add_coercion(m_env, coercion_name, m_p.ios());
if (m_modifiers.is_class() && is_class(m_env, parent_name)) {
// if both are classes, then we also mark coercion_name as an instance
m_env = add_instance(m_env, coercion_name);
m_env = add_trans_instance(m_env, coercion_name);
}
}
}

View file

@ -12,32 +12,38 @@ Author: Leonardo de Moura
#include "library/kernel_serializer.h"
#include "library/reducible.h"
#include "library/aliases.h"
#include "library/tc_multigraph.h"
#include "library/protected.h"
#ifndef LEAN_INSTANCE_DEFAULT_PRIORITY
#define LEAN_INSTANCE_DEFAULT_PRIORITY 1000
#endif
namespace lean {
enum class class_entry_kind { ClassCmd, InstanceCmd, MultiCmd };
enum class class_entry_kind { Class, Multi, Instance, TransInstance, DerivedTransInstance };
struct class_entry {
class_entry_kind m_cmd_kind;
class_entry_kind m_kind;
name m_class;
name m_instance; // only relevant if m_cmd_kind == InstanceCmd
unsigned m_priority; // only relevant if m_cmd_kind == InstanceCmd
class_entry():m_cmd_kind(class_entry_kind::ClassCmd), m_priority(0) {}
explicit class_entry(name const & c):m_cmd_kind(class_entry_kind::ClassCmd), m_class(c), m_priority(0) {}
class_entry(name const & c, name const & i, unsigned p):
m_cmd_kind(class_entry_kind::InstanceCmd), m_class(c), m_instance(i), m_priority(p) {}
name m_instance; // only relevant if m_kind == Instance
unsigned m_priority; // only relevant if m_kind == Instance
class_entry():m_kind(class_entry_kind::Class), m_priority(0) {}
explicit class_entry(name const & c):m_kind(class_entry_kind::Class), m_class(c), m_priority(0) {}
class_entry(class_entry_kind k, name const & c, name const & i, unsigned p):
m_kind(k), m_class(c), m_instance(i), m_priority(p) {}
class_entry(name const & c, bool):
m_cmd_kind(class_entry_kind::MultiCmd), m_class(c) {}
m_kind(class_entry_kind::Multi), m_class(c) {}
};
struct class_state {
typedef name_map<list<name>> class_instances;
typedef name_map<unsigned> instance_priorities;
class_instances m_instances;
instance_priorities m_priorities;
name_set m_multiple; // set of classes that allow multiple solutions/instances
class_instances m_instances;
class_instances m_derived_trans_instances;
instance_priorities m_priorities;
name_set m_multiple; // set of classes that allow multiple solutions/instances
tc_multigraph m_mgraph;
class_state():m_mgraph("transitive instance") {}
unsigned get_priority(name const & i) const {
if (auto it = m_priorities.find(i))
@ -80,6 +86,22 @@ struct class_state {
m_priorities.insert(i, p);
}
void add_trans_instance(environment const & env, name const & c, name const & i, unsigned p) {
add_instance(c, i, p);
m_mgraph.add1(env, i);
}
void add_derived_trans_instance(environment const & env, name const & c, name const & i) {
auto it = m_derived_trans_instances.find(c);
if (!it) {
m_derived_trans_instances.insert(c, to_list(i));
} else {
auto lst = filter(*it, [&](name const & i1) { return i1 != i; });
m_derived_trans_instances.insert(c, cons(i, lst));
}
m_mgraph.add1(env, i);
}
void add_multiple(name const & c) {
add_class(c);
m_multiple.insert(c);
@ -92,16 +114,22 @@ static std::string * g_key = nullptr;
struct class_config {
typedef class_state state;
typedef class_entry entry;
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
switch (e.m_cmd_kind) {
case class_entry_kind::ClassCmd:
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
switch (e.m_kind) {
case class_entry_kind::Class:
s.add_class(e.m_class);
break;
case class_entry_kind::InstanceCmd:
case class_entry_kind::Multi:
s.add_multiple(e.m_class);
break;
case class_entry_kind::Instance:
s.add_instance(e.m_class, e.m_instance, e.m_priority);
break;
case class_entry_kind::MultiCmd:
s.add_multiple(e.m_class);
case class_entry_kind::TransInstance:
s.add_trans_instance(env, e.m_class, e.m_instance, e.m_priority);
break;
case class_entry_kind::DerivedTransInstance:
s.add_derived_trans_instance(env, e.m_class, e.m_instance);
break;
}
}
@ -112,36 +140,50 @@ struct class_config {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << static_cast<char>(e.m_cmd_kind);
switch (e.m_cmd_kind) {
case class_entry_kind::ClassCmd: case class_entry_kind::MultiCmd:
s << static_cast<char>(e.m_kind);
switch (e.m_kind) {
case class_entry_kind::Class:
case class_entry_kind::Multi:
s << e.m_class;
break;
case class_entry_kind::InstanceCmd:
case class_entry_kind::Instance:
case class_entry_kind::TransInstance:
s << e.m_class << e.m_instance << e.m_priority;
break;
case class_entry_kind::DerivedTransInstance:
s << e.m_class << e.m_instance;
break;
}
}
static entry read_entry(deserializer & d) {
entry e; char k;
d >> k;
e.m_cmd_kind = static_cast<class_entry_kind>(k);
switch (e.m_cmd_kind) {
case class_entry_kind::ClassCmd: case class_entry_kind::MultiCmd:
e.m_kind = static_cast<class_entry_kind>(k);
switch (e.m_kind) {
case class_entry_kind::Class:
case class_entry_kind::Multi:
d >> e.m_class;
break;
case class_entry_kind::InstanceCmd:
case class_entry_kind::Instance:
case class_entry_kind::TransInstance:
d >> e.m_class >> e.m_instance >> e.m_priority;
break;
case class_entry_kind::DerivedTransInstance:
d >> e.m_class >> e.m_instance;
break;
}
return e;
}
static optional<unsigned> get_fingerprint(entry const & e) {
switch (e.m_cmd_kind) {
case class_entry_kind::ClassCmd: case class_entry_kind::MultiCmd:
switch (e.m_kind) {
case class_entry_kind::Class:
case class_entry_kind::Multi:
return some(e.m_class.hash());
case class_entry_kind::InstanceCmd:
case class_entry_kind::Instance:
case class_entry_kind::TransInstance:
return some(hash(hash(e.m_class.hash(), e.m_instance.hash()), e.m_priority));
case class_entry_kind::DerivedTransInstance:
return some(hash(e.m_class.hash(), e.m_instance.hash()));
}
lean_unreachable();
}
@ -194,7 +236,7 @@ type_checker_ptr mk_class_type_checker(environment const & env, name_generator &
}
static name * g_tmp_prefix = nullptr;
environment add_instance(environment const & env, name const & n, unsigned priority, bool persistent) {
static environment add_instance_core(environment const & env, class_entry_kind k, name const & n, unsigned priority, bool persistent) {
declaration d = env.get(n);
expr type = d.get_type();
name_generator ngen(*g_tmp_prefix);
@ -207,13 +249,35 @@ environment add_instance(environment const & env, name const & n, unsigned prior
}
name c = get_class_name(env, get_app_fn(type));
check_is_class(env, c);
return class_ext::add_entry(env, get_dummy_ios(), class_entry(c, n, priority), persistent);
return class_ext::add_entry(env, get_dummy_ios(), class_entry(k, c, n, priority), persistent);
}
environment add_instance(environment const & env, name const & n, unsigned priority, bool persistent) {
return add_instance_core(env, class_entry_kind::Instance, n, priority, persistent);
}
environment add_instance(environment const & env, name const & n, bool persistent) {
return add_instance(env, n, LEAN_INSTANCE_DEFAULT_PRIORITY, persistent);
}
environment add_trans_instance(environment const & env, name const & n, unsigned priority, bool persistent) {
class_state const & s = class_ext::get_state(env);
tc_multigraph g = s.m_mgraph;
pair<environment, list<name>> new_env_insts = g.add(env, n);
environment new_env = new_env_insts.first;
new_env = add_instance_core(new_env, class_entry_kind::TransInstance, n, priority, persistent);
for (name const & tn : new_env_insts.second) {
new_env = add_instance_core(new_env, class_entry_kind::DerivedTransInstance, tn, 0, persistent);
new_env = set_reducible(new_env, tn, reducible_status::Reducible, persistent);
new_env = add_protected(new_env, tn);
}
return new_env;
}
environment add_trans_instance(environment const & env, name const & n, bool persistent) {
return add_trans_instance(env, n, LEAN_INSTANCE_DEFAULT_PRIORITY, persistent);
}
environment mark_multiple_instances(environment const & env, name const & n, bool persistent) {
check_class(env, n);
return class_ext::add_entry(env, get_dummy_ios(), class_entry(n, true), persistent);
@ -234,6 +298,11 @@ list<name> get_class_instances(environment const & env, name const & c) {
return ptr_to_list(s.m_instances.find(c));
}
list<name> get_class_derived_trans_instances(environment const & env, name const & c) {
class_state const & s = class_ext::get_state(env);
return ptr_to_list(s.m_derived_trans_instances.find(c));
}
/** \brief If the constant \c e is a class, return its name */
static optional<name> constant_is_ext_class(environment const & env, expr const & e) {
name const & cls_name = const_name(e);

View file

@ -15,12 +15,18 @@ environment add_class(environment const & env, name const & n, bool persistent =
environment add_instance(environment const & env, name const & n, bool persistent = true);
/** \brief Add a new 'class instance' to the environment. */
environment add_instance(environment const & env, name const & n, unsigned priority, bool persistent);
/** \brief Add a new 'class transitive instance' to the environment with default priority. */
environment add_trans_instance(environment const & env, name const & n, bool persistent = true);
/** \brief Add a new 'class transitive instance' to the environment. */
environment add_trans_instance(environment const & env, name const & n, unsigned priority, bool persistent);
/** \brief Return true iff \c c was declared with \c add_class. */
bool is_class(environment const & env, name const & c);
/** \brief Return true iff \c i was declared with \c add_instance. */
bool is_instance(environment const & env, name const & i);
/** \brief Return the instances of the given class. */
list<name> get_class_instances(environment const & env, name const & c);
/** \brief Return instances from the transitive closure graph of instances added using add_trans_instance */
list<name> get_class_derived_trans_instances(environment const & env, name const & c);
/** \brief Return the classes in the given environment. */
void get_classes(environment const & env, buffer<name> & classes);
name get_class_name(environment const & env, expr const & e);

View file

@ -40,11 +40,16 @@ Author: Leonardo de Moura
#define LEAN_DEFAULT_CLASS_CONSERVATIVE true
#endif
#ifndef LEAN_DEFAULT_CLASS_TRANS_INSTANCES
#define LEAN_DEFAULT_CLASS_TRANS_INSTANCES true
#endif
namespace lean {
static name * g_class_unique_class_instances = nullptr;
static name * g_class_trace_instances = nullptr;
static name * g_class_instance_max_depth = nullptr;
static name * g_class_conservative = nullptr;
static name * g_class_trans_instances = nullptr;
[[ noreturn ]] void throw_class_exception(char const * msg, expr const & m) { throw_generic_exception(msg, m); }
[[ noreturn ]] void throw_class_exception(expr const & m, pp_fn const & fn) { throw_generic_exception(m, fn); }
@ -54,6 +59,7 @@ void initialize_class_instance_elaborator() {
g_class_trace_instances = new name{"class", "trace_instances"};
g_class_instance_max_depth = new name{"class", "instance_max_depth"};
g_class_conservative = new name{"class", "conservative"};
g_class_trans_instances = new name{"class", "trans_instances"};
register_bool_option(*g_class_unique_class_instances, LEAN_DEFAULT_CLASS_UNIQUE_CLASS_INSTANCES,
"(class) generate an error if there is more than one solution "
@ -67,6 +73,9 @@ void initialize_class_instance_elaborator() {
register_bool_option(*g_class_conservative, LEAN_DEFAULT_CLASS_CONSERVATIVE,
"(class) use conservative unification (only unfold reducible definitions, and avoid delta-delta case splits)");
register_bool_option(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES,
"(class) use automatically derived instances from the transitive closure of the structure instance graph");
}
void finalize_class_instance_elaborator() {
@ -74,6 +83,7 @@ void finalize_class_instance_elaborator() {
delete g_class_trace_instances;
delete g_class_instance_max_depth;
delete g_class_conservative;
delete g_class_trans_instances;
}
bool get_class_unique_class_instances(options const & o) {
@ -92,6 +102,10 @@ bool get_class_conservative(options const & o) {
return o.get_bool(*g_class_conservative, LEAN_DEFAULT_CLASS_CONSERVATIVE);
}
bool get_class_trans_instances(options const & o) {
return o.get_bool(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES);
}
/** \brief Context for handling class-instance metavariable choice constraint */
struct class_instance_context {
io_state m_ios;
@ -102,6 +116,7 @@ struct class_instance_context {
bool m_trace_instances;
bool m_conservative;
unsigned m_max_depth;
bool m_trans_instances;
char const * m_fname;
optional<pos_info> m_pos;
class_instance_context(environment const & env, io_state const & ios,
@ -113,6 +128,7 @@ struct class_instance_context {
m_trace_instances = get_class_trace_instances(ios.get_options());
m_max_depth = get_class_instance_max_depth(ios.get_options());
m_conservative = get_class_conservative(ios.get_options());
m_trans_instances = get_class_trans_instances(ios.get_options());
m_tc = mk_class_type_checker(env, m_ngen.mk_child(), m_conservative);
options opts = m_ios.get_options();
opts = opts.update_if_undef(get_pp_purify_metavars_name(), false);
@ -134,10 +150,11 @@ struct class_instance_context {
optional<pos_info> const & get_pos() const { return m_pos; }
char const * get_file_name() const { return m_fname; }
unsigned get_max_depth() const { return m_max_depth; }
bool use_trans_instances() const { return m_trans_instances; }
};
pair<expr, constraint> mk_class_instance_elaborator(std::shared_ptr<class_instance_context> const & C, local_context const & ctx,
optional<expr> const & type, tag g, unsigned depth);
static pair<expr, constraint> mk_class_instance_elaborator(std::shared_ptr<class_instance_context> const & C, local_context const & ctx,
optional<expr> const & type, tag g, unsigned depth, bool use_globals);
/** \brief Choice function \c fn for synthesizing class instances.
@ -160,6 +177,7 @@ struct class_instance_elaborator : public choice_iterator {
list<expr> m_local_instances;
// global declaration names that are class instances.
// This information is retrieved using #get_class_instances.
list<name> m_trans_instances;
list<name> m_instances;
justification m_jst;
unsigned m_depth;
@ -167,10 +185,10 @@ struct class_instance_elaborator : public choice_iterator {
class_instance_elaborator(std::shared_ptr<class_instance_context> const & C, local_context const & ctx,
expr const & meta, expr const & meta_type,
list<expr> const & local_insts, list<name> const & instances,
list<expr> const & local_insts, list<name> const & trans_insts, list<name> const & instances,
justification const & j, unsigned depth):
choice_iterator(), m_C(C), m_ctx(ctx), m_meta(meta), m_meta_type(meta_type),
m_local_instances(local_insts), m_instances(instances), m_jst(j), m_depth(depth) {
m_local_instances(local_insts), m_trans_instances(trans_insts), m_instances(instances), m_jst(j), m_depth(depth) {
if (m_depth > m_C->get_max_depth()) {
throw_class_exception("maximum class-instance resolution depth has been reached "
"(the limit can be increased by setting option 'class.instance_max_depth') "
@ -205,7 +223,7 @@ struct class_instance_elaborator : public choice_iterator {
out << m_meta << " : " << t << " := " << r << endl;
}
optional<constraints> try_instance(expr const & inst, expr const & inst_type) {
optional<constraints> try_instance(expr const & inst, expr const & inst_type, bool use_globals) {
type_checker & tc = m_C->tc();
name_generator & ngen = m_C->m_ngen;
tag g = inst.get_tag();
@ -233,7 +251,7 @@ struct class_instance_elaborator : public choice_iterator {
expr arg;
if (binding_info(type).is_inst_implicit()) {
pair<expr, constraint> ac = mk_class_instance_elaborator(m_C, m_ctx, some_expr(binding_domain(type)),
g, m_depth+1);
g, m_depth+1, use_globals);
arg = ac.first;
cs.push_back(ac.second);
} else {
@ -251,7 +269,7 @@ struct class_instance_elaborator : public choice_iterator {
}
}
optional<constraints> try_instance(name const & inst) {
optional<constraints> try_instance(name const & inst, bool use_globals) {
environment const & env = m_C->env();
if (auto decl = env.find(inst)) {
name_generator & ngen = m_C->m_ngen;
@ -262,7 +280,7 @@ struct class_instance_elaborator : public choice_iterator {
levels ls = to_list(ls_buffer.begin(), ls_buffer.end());
expr inst_cnst = copy_tag(m_meta, mk_constant(inst, ls));
expr inst_type = instantiate_type_univ_params(*decl, ls);
return try_instance(inst_cnst, inst_type);
return try_instance(inst_cnst, inst_type, use_globals);
} else {
return optional<constraints>();
}
@ -274,20 +292,31 @@ struct class_instance_elaborator : public choice_iterator {
m_local_instances = tail(m_local_instances);
if (!is_local(inst))
continue;
if (auto r = try_instance(inst, mlocal_type(inst)))
bool use_globals = true;
if (auto r = try_instance(inst, mlocal_type(inst), use_globals))
return r;
}
while (!empty(m_trans_instances)) {
bool use_globals = false;
name inst = head(m_trans_instances);
m_trans_instances = tail(m_trans_instances);
if (auto cs = try_instance(inst, use_globals))
return cs;
}
while (!empty(m_instances)) {
bool use_globals = true;
name inst = head(m_instances);
m_instances = tail(m_instances);
if (auto cs = try_instance(inst))
if (auto cs = try_instance(inst, use_globals))
return cs;
}
return optional<constraints>();
}
};
constraint mk_class_instance_cnstr(std::shared_ptr<class_instance_context> const & C, local_context const & ctx, expr const & m, unsigned depth) {
// Remarks:
// - we only use get_class_instances and get_class_derived_trans_instances when use_globals is true
static constraint mk_class_instance_cnstr(std::shared_ptr<class_instance_context> const & C, local_context const & ctx, expr const & m, unsigned depth, bool use_globals) {
environment const & env = C->env();
justification j = mk_failed_to_synthesize_jst(env, m);
auto choice_fn = [=](expr const & meta, expr const & meta_type, substitution const &, name_generator const &) {
@ -297,11 +326,16 @@ constraint mk_class_instance_cnstr(std::shared_ptr<class_instance_context> const
list<expr> local_insts;
if (C->use_local_instances())
local_insts = get_local_instances(C->tc(), ctx_lst, cls_name);
list<name> insts = get_class_instances(env, cls_name);
list<name> trans_insts, insts;
if (use_globals) {
if (depth == 0 && C->use_trans_instances())
trans_insts = get_class_derived_trans_instances(env, cls_name);
insts = get_class_instances(env, cls_name);
}
if (empty(local_insts) && empty(insts))
return lazy_list<constraints>(); // nothing to be done
// we are always strict with placeholders associated with classes
return choose(std::make_shared<class_instance_elaborator>(C, ctx, meta, meta_type, local_insts, insts, j, depth));
return choose(std::make_shared<class_instance_elaborator>(C, ctx, meta, meta_type, local_insts, trans_insts, insts, j, depth));
} else {
// do nothing, type is not a class...
return lazy_list<constraints>(constraints());
@ -311,15 +345,15 @@ constraint mk_class_instance_cnstr(std::shared_ptr<class_instance_context> const
return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::Basic), owner, j);
}
pair<expr, constraint> mk_class_instance_elaborator(std::shared_ptr<class_instance_context> const & C, local_context const & ctx,
optional<expr> const & type, tag g, unsigned depth) {
static pair<expr, constraint> mk_class_instance_elaborator(std::shared_ptr<class_instance_context> const & C, local_context const & ctx,
optional<expr> const & type, tag g, unsigned depth, bool use_globals) {
expr m = ctx.mk_meta(C->m_ngen, type, g);
constraint c = mk_class_instance_cnstr(C, ctx, m, depth);
constraint c = mk_class_instance_cnstr(C, ctx, m, depth, use_globals);
return mk_pair(m, c);
}
constraint mk_class_instance_root_cnstr(std::shared_ptr<class_instance_context> const & C, local_context const & _ctx,
expr const & m, bool is_strict, unifier_config const & cfg, delay_factor const & factor) {
static constraint mk_class_instance_root_cnstr(std::shared_ptr<class_instance_context> const & C, local_context const & _ctx,
expr const & m, bool is_strict, unifier_config const & cfg, delay_factor const & factor) {
environment const & env = C->env();
justification j = mk_failed_to_synthesize_jst(env, m);
@ -336,7 +370,8 @@ constraint mk_class_instance_root_cnstr(std::shared_ptr<class_instance_context>
expr new_meta = mj.first;
justification new_j = mj.second;
unsigned depth = 0;
constraint c = mk_class_instance_cnstr(C, ctx, new_meta, depth);
bool use_globals = true;
constraint c = mk_class_instance_cnstr(C, ctx, new_meta, depth, use_globals);
unifier_config new_cfg(cfg);
new_cfg.m_discard = false;
new_cfg.m_use_exceptions = false;
@ -446,9 +481,10 @@ optional<expr> mk_class_instance(environment const & env, io_state const & ios,
auto C = std::make_shared<class_instance_context>(env, ios, prefix, use_local_instances);
if (!is_ext_class(C->tc(), type))
return none_expr();
expr meta = ctx.mk_meta(C->m_ngen, some_expr(type), type.get_tag());
unsigned depth = 0;
constraint c = mk_class_instance_cnstr(C, ctx, meta, depth);
expr meta = ctx.mk_meta(C->m_ngen, some_expr(type), type.get_tag());
unsigned depth = 0;
bool use_globals = true;
constraint c = mk_class_instance_cnstr(C, ctx, meta, depth, use_globals);
unifier_config new_cfg(cfg);
new_cfg.m_discard = true;
new_cfg.m_use_exceptions = true;