refactor(library/type_inference): move new type class resolution procedure to genere type_inference

This commit is contained in:
Leonardo de Moura 2015-10-30 16:54:42 -07:00
parent 4c573380b2
commit 56c15f4fb5
4 changed files with 937 additions and 915 deletions

View file

@ -51,9 +51,6 @@ static name * g_class_trans_instances = nullptr;
static name * g_class_force_new = nullptr;
static name * g_prefix = nullptr;
LEAN_THREAD_PTR(ci_local_metavar_types, g_lm_types);
LEAN_THREAD_PTR(io_state, g_ios);
bool get_class_trace_instances(options const & o) {
return o.get_bool(*g_class_trace_instances, LEAN_DEFAULT_CLASS_TRACE_INSTANCES);
}
@ -70,884 +67,44 @@ bool get_class_force_new(options const & o) {
return o.get_bool(*g_class_force_new, false);
}
class default_ci_local_metavar_types : public ci_local_metavar_types {
public:
virtual expr infer_local(expr const & e) { return mlocal_type(e); }
virtual expr infer_metavar(expr const & e) { return mlocal_type(e); }
};
static void finalize_lm_types(void * p) {
delete reinterpret_cast<ci_local_metavar_types*>(p);
g_lm_types = nullptr;
}
static ci_local_metavar_types & get_lm_types() {
if (!g_lm_types) {
g_lm_types = new default_ci_local_metavar_types();
register_thread_finalizer(finalize_lm_types, g_lm_types);
}
return *g_lm_types;
}
static expr ci_infer_local(expr const & e) {
return get_lm_types().infer_local(e);
}
static expr ci_infer_metavar(expr const & e) {
return get_lm_types().infer_metavar(e);
}
/** \brief The following global thread local constant is a big hack for mk_subsingleton_instance.
When g_subsingleton_hack is true, the following type-class resolution problem fails
Given (A : Type{?u}), where ?u is a universe meta-variable created by an external module,
?Inst : subsingleton.{?u} A := subsingleton_prop ?M
This case generates the unification problem
subsingleton.{?u} A =?= subsingleton.{0} ?M
which can be solved by assigning (?u := 0) and (?M := A)
when the hack is enabled, the is_def_eq method in the type class module fails at the subproblem:
?u =?= 0
That is, when the hack is on, type-class resolution cannot succeed by instantiating an external universe
meta-variable with 0.
*/
LEAN_THREAD_VALUE(bool, g_subsingleton_hack, false);
struct cienv {
typedef rb_map<unsigned, level, unsigned_cmp> uassignment;
typedef rb_map<unsigned, expr, unsigned_cmp> eassignment;
typedef std::unique_ptr<type_inference> ti_ptr;
environment m_env;
pos_info_provider const * m_pip;
typedef std::unique_ptr<default_type_inference> ti_ptr;
ti_ptr m_ti_ptr;
optional<pos_info> m_pos;
expr_struct_map<expr> m_cache;
name_predicate m_not_reducible_pred;
list<expr> m_ctx;
buffer<pair<name, expr>> m_local_instances;
unsigned m_next_local_idx;
unsigned m_next_uvar_idx;
unsigned m_next_mvar_idx;
struct stack_entry {
// We only use transitive instances when we can solve the problem in a single step.
// That is, the transitive instance does not have any instance argument, OR
// it uses local instances to fill them.
// We accomplish that by not considering global instances when solving
// transitive instance subproblems.
expr m_mvar;
unsigned m_depth;
bool m_trans_inst_subproblem;
stack_entry(expr const & m, unsigned d, bool s = false):
m_mvar(m), m_depth(d), m_trans_inst_subproblem(s) {}
};
struct state {
bool m_trans_inst_subproblem;
list<stack_entry> m_stack; // stack of meta-variables that need to be synthesized;
uassignment m_uassignment;
eassignment m_eassignment;
state():m_trans_inst_subproblem(false) {}
};
state m_state; // active state
struct choice {
list<expr> m_local_instances;
list<name> m_trans_instances;
list<name> m_instances;
state m_state;
};
std::vector<choice> m_choices;
expr m_main_mvar;
bool m_multiple_instances;
bool m_displayed_trace_header;
// configuration
options m_options; // it is used for pretty printing
unsigned m_max_depth;
bool m_trans_instances;
bool m_trace_instances;
class ti : public type_inference {
cienv & m_cienv;
std::vector<state> m_stack;
public:
ti(cienv & e):type_inference(e.m_env), m_cienv(e) {}
virtual bool is_extra_opaque(name const & n) const { return m_cienv.is_not_reducible(n); }
virtual expr mk_tmp_local(expr const & type, binder_info const & bi) { return m_cienv.mk_local(type, bi); }
virtual bool is_tmp_local(expr const & e) const { return m_cienv.is_internal_local(e); }
virtual bool is_uvar(level const & l) const { return cienv::is_uvar(l); }
virtual bool is_mvar(expr const & e) const { return m_cienv.is_mvar(e); }
virtual level const * get_assignment(level const & u) const { return m_cienv.get_assignment(u); }
virtual expr const * get_assignment(expr const & m) const { return m_cienv.get_assignment(m); }
virtual void update_assignment(level const & u, level const & v) { return m_cienv.update_assignment(u, v); }
virtual void update_assignment(expr const & m, expr const & v) { return m_cienv.update_assignment(m, v); }
virtual expr infer_local(expr const & e) const { return ci_infer_local(e); }
virtual expr infer_metavar(expr const & e) const { return ci_infer_metavar(e); }
virtual void push() { m_stack.push_back(m_cienv.m_state); }
virtual void pop() { m_cienv.m_state = m_stack.back(); m_stack.pop_back(); }
virtual void commit() { m_stack.pop_back(); }
virtual bool on_is_def_eq_failure(expr & e1, expr & e2) {
if (is_app(e1) && is_app(e2)) {
if (auto p1 = m_cienv.find_unsynth_metavar(e1)) {
if (m_cienv.mk_nested_instance(p1->first, p1->second)) {
e1 = m_cienv.instantiate_uvars_mvars(e1);
return true;
}
}
if (auto p2 = m_cienv.find_unsynth_metavar(e2)) {
if (m_cienv.mk_nested_instance(p2->first, p2->second)) {
e2 = m_cienv.instantiate_uvars_mvars(e2);
return true;
}
}
}
return false;
void reset(environment const & env, io_state const & ios, list<expr> const & ctx) {
m_ti_ptr.reset(new default_type_inference(env, ios, ctx));
}
virtual bool ignore_universe_def_eq(level const & l1, level const & l2) const {
if (is_meta(l1) || is_meta(l2)) {
// The unifier may invoke this module before universe metavariables in the class
// have been instantiated. So, we just ignore and assume they will be solved by
// the unifier.
// See comment at g_subsingleton_hack declaration.
if (g_subsingleton_hack && (is_zero(l1) || is_zero(l2)))
return false;
return true; // we ignore
} else {
return false;
}
bool compatible_env(environment const & env) {
environment const & curr_env = m_ti_ptr->env();
return env.is_descendant(curr_env) && curr_env.is_descendant(env);
}
virtual bool validate_assignment(expr const & m, buffer<expr> const & locals, expr const & v) {
// We must check
// 1. Any (internal) local constant occurring in v occurs in locals
// 2. m does not occur in v
bool ok = true;
for_each(v, [&](expr const & e, unsigned) {
if (!ok)
return false; // stop search
if (is_tmp_local(e)) {
if (std::all_of(locals.begin(), locals.end(), [&](expr const & a) {
return mlocal_name(a) != mlocal_name(e); })) {
ok = false; // failed 1
return false;
}
} else if (is_mvar(e)) {
if (m == e) {
ok = false; // failed 2
return false;
}
return false;
}
return true;
});
return ok;
}
};
cienv(bool multiple_instances = false):
m_next_local_idx(0),
m_next_uvar_idx(0),
m_next_mvar_idx(0),
m_multiple_instances(multiple_instances),
m_max_depth(LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH),
m_trans_instances(LEAN_DEFAULT_CLASS_TRANS_INSTANCES),
m_trace_instances(LEAN_DEFAULT_CLASS_TRACE_INSTANCES) {}
bool is_not_reducible(name const & n) const {
return m_not_reducible_pred(n);
}
void clear_cache() {
expr_struct_map<expr> fresh;
fresh.swap(m_cache);
if (m_ti_ptr)
void ensure_compatible(environment const & env, io_state const & ios, list<expr> const & ctx) {
if (!m_ti_ptr || !compatible_env(env) || !m_ti_ptr->compatible_local_instances(ctx))
reset(env, ios, ctx);
if (!m_ti_ptr->update_options(ios.get_options()))
m_ti_ptr->clear_cache();
}
void clear_cache_and_ctx() {
m_next_local_idx = 0;
m_next_uvar_idx = 0;
m_next_mvar_idx = 0;
m_ctx = list<expr>();
m_local_instances.clear();
clear_cache();
}
optional<expr> check_cache(expr const & type) const {
if (m_multiple_instances) {
// We do not cache results when multiple instances have to be generated.
return none_expr();
}
auto it = m_cache.find(type);
if (it != m_cache.end())
return some_expr(it->second);
else
return none_expr();
}
void cache_result(expr const & type, expr const & inst) {
if (m_multiple_instances) {
// We do not cache results when multiple instances have to be generated.
return;
}
m_cache.insert(mk_pair(type, inst));
}
void set_options(options const & o) {
m_options = o;
unsigned max_depth = get_class_instance_max_depth(o);
bool trans_instances = get_class_trans_instances(o);
bool trace_instances = get_class_trace_instances(o);
if (trace_instances) {
m_options = m_options.update_if_undef(get_pp_purify_metavars_name(), false);
m_options = m_options.update_if_undef(get_pp_implicit_name(), true);
}
if (m_max_depth != max_depth ||
m_trans_instances != trans_instances ||
m_trace_instances != trace_instances) {
clear_cache_and_ctx();
}
m_max_depth = max_depth;
m_trans_instances = trans_instances;
m_trace_instances = trace_instances;
}
void set_env(environment const & env) {
// Remark: We can implement the following potential refinement.
// If env is a descendant of m_env, and env does not add new global instances,
// then we don't need to reset the cache
if (!m_env.is_descendant(m_env) || !m_env.is_descendant(env)) {
m_env = env;
m_not_reducible_pred = mk_not_reducible_pred(m_env);
m_ti_ptr = nullptr;
clear_cache_and_ctx();
}
if (!m_ti_ptr) {
m_ti_ptr.reset(new ti(*this));
clear_cache_and_ctx();
}
}
expr whnf(expr const & e) {
lean_assert(m_ti_ptr);
return m_ti_ptr->whnf(e);
}
expr infer_type(expr const & e) {
lean_assert(m_ti_ptr);
return m_ti_ptr->infer(e);
}
bool is_def_eq(expr const & e1, expr const & e2) {
lean_assert(m_ti_ptr);
return m_ti_ptr->is_def_eq(e1, e2);
}
expr instantiate_uvars_mvars(expr const & e) {
lean_assert(m_ti_ptr);
return m_ti_ptr->instantiate_uvars_mvars(e);
}
expr mk_local(expr const & type, binder_info const & bi = binder_info()) {
unsigned idx = m_next_local_idx;
m_next_local_idx++;
name n(*g_prefix, idx);
return lean::mk_local(n, n, type, bi);
}
bool is_internal_local(expr const & e) {
if (!is_local(e))
return false;
name const & n = mlocal_name(e);
return !n.is_atomic() && n.get_prefix() == *g_prefix;
}
// Helper function for find_unsynth_metavar
static bool has_meta_arg(expr e) {
while (is_app(e)) {
if (is_meta(app_arg(e)))
return true;
e = app_fn(e);
}
return false;
}
/** IF \c e is of the form (f ... (?m t_1 ... t_n) ...) where ?m is an unassigned
metavariable whose type is a type class, and (?m t_1 ... t_n) must be synthesized
by type class resolution, then we return ?m.
Otherwise, we return none */
optional<pair<expr, expr>> find_unsynth_metavar(expr const & e) {
if (!has_meta_arg(e))
return optional<pair<expr, expr>>();
buffer<expr> args;
expr const & fn = get_app_args(e, args);
expr type = infer_type(fn);
unsigned i = 0;
while (i < args.size()) {
type = whnf(type);
if (!is_pi(type))
return optional<pair<expr, expr>>();
expr const & arg = args[i];
if (binding_info(type).is_inst_implicit() && is_meta(arg)) {
expr const & m = get_app_fn(arg);
if (is_mvar(m)) {
expr m_type = instantiate_uvars_mvars(infer_type(m));
if (!has_expr_metavar_relaxed(m_type)) {
return some(mk_pair(m, m_type));
}
}
}
type = instantiate(binding_body(type), arg);
i++;
}
return optional<pair<expr, expr>>();
}
/** \brief If the constant \c e is a class, return its name */
optional<name> constant_is_class(expr const & e) {
name const & cls_name = const_name(e);
if (lean::is_class(m_env, cls_name)) {
return optional<name>(cls_name);
} else {
return optional<name>();
}
}
optional<name> is_full_class(expr type) {
type = whnf(type);
if (is_pi(type)) {
return is_full_class(instantiate(binding_body(type), mk_local(binding_domain(type))));
} else {
expr f = get_app_fn(type);
if (!is_constant(f))
return optional<name>();
return constant_is_class(f);
}
}
/** \brief Partial/Quick test for is_class. Result
l_true: \c type is a class, and the name of the class is stored in \c result.
l_false: \c type is not a class.
l_undef: procedure did not establish whether \c type is a class or not.
*/
lbool is_quick_class(expr const & type, name & result) {
expr const * it = &type;
while (true) {
switch (it->kind()) {
case expr_kind::Var: case expr_kind::Sort: case expr_kind::Local:
case expr_kind::Meta: case expr_kind::Lambda:
return l_false;
case expr_kind::Macro:
return l_undef;
case expr_kind::Constant:
if (auto r = constant_is_class(*it)) {
result = *r;
return l_true;
} else if (is_not_reducible(const_name(*it))) {
return l_false;
} else {
return l_undef;
}
case expr_kind::App: {
expr const & f = get_app_fn(*it);
if (is_constant(f)) {
if (auto r = constant_is_class(f)) {
result = *r;
return l_true;
} else if (is_not_reducible(const_name(f))) {
return l_false;
} else {
return l_undef;
}
} else if (is_lambda(f) || is_macro(f)) {
return l_undef;
} else {
return l_false;
}
}
case expr_kind::Pi:
it = &binding_body(*it);
break;
}
}
}
/** \brief Return true iff \c type is a class or Pi that produces a class. */
optional<name> is_class(expr const & type) {
name result;
switch (is_quick_class(type, result)) {
case l_true: return optional<name>(result);
case l_false: return optional<name>();
case l_undef: break;
}
return is_full_class(type);
}
// Auxiliary method for set_ctx
void set_local_instance(unsigned i, name const & cname, expr const & e) {
lean_assert(i <= m_local_instances.size());
if (i == m_local_instances.size()) {
clear_cache();
m_local_instances.push_back(mk_pair(cname, e));
} else if (e != m_local_instances[i].second) {
clear_cache();
m_local_instances[i] = mk_pair(cname, e);
} else {
// we don't need to reset the cache since this local instance
// is equal to the one used in a previous call
}
}
void set_ctx(list<expr> const & ctx) {
if (is_eqp(m_ctx, ctx)) {
// we can keep the cache because the local context
// is still pointing to the same object.
return;
}
m_ctx = ctx;
unsigned i = 0;
for (expr const & e : ctx) {
// Remark: we use infer_type(e) instead of mlocal_type because we want to allow
// customers (e.g., blast) of this class to store the type of local constants
// in a different place.
if (auto cname = is_class(infer_type(e))) {
set_local_instance(i, *cname, e);
i++;
}
}
if (i < m_local_instances.size()) {
// new ctx has fewer local instances than previous one
m_local_instances.resize(i);
clear_cache();
}
}
void set_pos_info(pos_info_provider const * pip, expr const & pos_ref) {
m_pip = pip;
if (m_pip)
m_pos = m_pip->get_pos_info(pos_ref);
}
// Create an internal universal metavariable
level mk_uvar() {
unsigned idx = m_next_uvar_idx;
m_next_uvar_idx++;
return mk_meta_univ(name(*g_prefix, idx));
}
// Return true iff \c l is an internal universe metavariable created by this module.
static bool is_uvar(level const & l) {
if (!is_meta(l))
return false;
name const & n = meta_id(l);
return !n.is_atomic() && n.get_prefix() == *g_prefix;
}
static unsigned uvar_idx(level const & l) {
lean_assert(is_uvar(l));
return meta_id(l).get_numeral();
}
level const * get_assignment(level const & u) const {
return m_state.m_uassignment.find(uvar_idx(u));
}
bool is_assigned(level const & u) const {
return get_assignment(u) != nullptr;
}
// Assign \c v to the universe metavariable \c u.
void update_assignment(level const & u, level const & v) {
m_state.m_uassignment.insert(uvar_idx(u), v);
}
// Create an internal metavariable.
expr mk_mvar(expr const & type) {
unsigned idx = m_next_mvar_idx;
m_next_mvar_idx++;
return mk_metavar(name(*g_prefix, idx), type);
}
// Return true iff \c e is an internal metavariable created by this module.
static bool is_mvar(expr const & e) {
if (!is_metavar(e))
return false;
name const & n = mlocal_name(e);
return !n.is_atomic() && n.get_prefix() == *g_prefix;
}
static unsigned mvar_idx(expr const & m) {
lean_assert(is_mvar(m));
return mlocal_name(m).get_numeral();
}
expr const * get_assignment(expr const & m) const {
return m_state.m_eassignment.find(mvar_idx(m));
}
bool is_assigned(expr const & m) const {
return get_assignment(m) != nullptr;
}
void update_assignment(expr const & m, expr const & v) {
m_state.m_eassignment.insert(mvar_idx(m), v);
lean_assert(is_assigned(m));
}
// Assign \c v to the metavariable \c m.
void assign(expr const & m, expr const & v) {
lean_assert(!is_assigned(m));
update_assignment(m, v);
}
io_state_stream diagnostic() {
io_state ios(*g_ios);
ios.set_options(m_options);
return lean::diagnostic(m_env, ios);
}
void trace(unsigned depth, expr const & mvar, expr const & mvar_type, expr const & r) {
lean_assert(m_trace_instances);
auto out = diagnostic();
if (!m_displayed_trace_header && m_choices.size() == 1) {
if (m_pip) {
if (auto fname = m_pip->get_file_name()) {
out << fname << ":";
}
if (m_pos)
out << m_pos->first << ":" << m_pos->second << ":";
}
out << " class-instance resolution trace" << endl;
m_displayed_trace_header = true;
}
for (unsigned i = 0; i < depth; i++)
out << " ";
if (depth > 0)
out << "[" << depth << "] ";
out << mvar << " : " << instantiate_uvars_mvars(mvar_type) << " := " << r << endl;
}
// Try to synthesize e.m_mvar using instance inst : inst_type.
// trans_inst is true if inst is a transitive instance.
bool try_instance(stack_entry const & e, expr const & inst, expr const & inst_type, bool trans_inst) {
try {
buffer<expr> locals;
expr const & mvar = e.m_mvar;
expr mvar_type = mlocal_type(mvar);
while (true) {
mvar_type = whnf(mvar_type);
if (!is_pi(mvar_type))
break;
expr local = mk_local(binding_domain(mvar_type));
locals.push_back(local);
mvar_type = instantiate(binding_body(mvar_type), local);
}
expr type = inst_type;
expr r = inst;
buffer<expr> new_inst_mvars;
while (true) {
type = whnf(type);
if (!is_pi(type))
break;
expr new_mvar = mk_mvar(Pi(locals, binding_domain(type)));
if (binding_info(type).is_inst_implicit()) {
new_inst_mvars.push_back(new_mvar);
}
expr new_arg = mk_app(new_mvar, locals);
r = mk_app(r, new_arg);
type = instantiate(binding_body(type), new_arg);
}
if (m_trace_instances) {
trace(e.m_depth, mk_app(mvar, locals), mvar_type, r);
}
if (!is_def_eq(mvar_type, type)) {
return false;
}
r = Fun(locals, r);
if (is_assigned(mvar)) {
// Remark: if the metavariable is already assigned, we should check whether
// the previous assignment (obtained by solving unification constraints) and the
// synthesized one are definitionally equal. We don't do that for performance reasons.
// Moreover, the is_def_eq defined here is not complete (e.g., it only unfolds reducible constants).
update_assignment(mvar, r);
} else {
assign(mvar, r);
}
// copy new_inst_mvars to stack
unsigned i = new_inst_mvars.size();
while (i > 0) {
--i;
m_state.m_stack = cons(stack_entry(new_inst_mvars[i], e.m_depth+1, trans_inst), m_state.m_stack);
}
return true;
} catch (exception &) {
return false;
}
}
bool try_instance(stack_entry const & e, name const & inst_name, bool trans_inst) {
if (auto decl = m_env.find(inst_name)) {
buffer<level> ls_buffer;
unsigned num_univ_ps = decl->get_num_univ_params();
for (unsigned i = 0; i < num_univ_ps; i++)
ls_buffer.push_back(mk_uvar());
levels ls = to_list(ls_buffer.begin(), ls_buffer.end());
expr inst_cnst = mk_constant(inst_name, ls);
expr inst_type = instantiate_type_univ_params(*decl, ls);
return try_instance(e, inst_cnst, inst_type, trans_inst);
} else {
return false;
}
}
list<expr> get_local_instances(name const & cname) {
buffer<expr> selected;
for (pair<name, expr> const & p : m_local_instances) {
if (p.first == cname)
selected.push_back(p.second);
}
return to_list(selected);
}
bool is_done() const {
return empty(m_state.m_stack);
}
bool mk_choice_point(expr const & mvar) {
lean_assert(is_mvar(mvar));
if (m_choices.size() > m_max_depth) {
throw_class_exception("maximum class-instance resolution depth has been reached "
"(the limit can be increased by setting option 'class.instance_max_depth') "
"(the class-instance resolution trace can be visualized by setting option 'class.trace_instances')",
mlocal_type(m_main_mvar));
}
// Remark: we initially tried to reject branches where mvar_type contained unassigned metavariables.
// The idea was to make the procedure easier to understand.
// However, it turns out this is too restrictive. The group_theory folder contains the following instance.
// nsubg_setoid : Π {A : Type} [s : group A] (N : set A) [is_nsubg : @is_normal_subgroup A s N], setoid A
// When it is used, it creates a subproblem for
// is_nsubg : @is_normal_subgroup A s ?N
// where ?N is not known. Actually, we can only find the value for ?N by constructing the instance is_nsubg.
expr mvar_type = instantiate_uvars_mvars(mlocal_type(mvar));
bool toplevel_choice = m_choices.empty();
m_choices.push_back(choice());
choice & r = m_choices.back();
auto cname = is_class(mvar_type);
if (!cname)
return false;
r.m_local_instances = get_local_instances(*cname);
if (m_trans_instances && toplevel_choice) {
// we only use transitive instances in the top-level
r.m_trans_instances = get_class_derived_trans_instances(m_env, *cname);
}
r.m_instances = get_class_instances(m_env, *cname);
if (empty(r.m_local_instances) && empty(r.m_trans_instances) && empty(r.m_instances))
return false;
r.m_state = m_state;
return true;
}
bool process_next_alt_core(stack_entry const & e, list<expr> & insts) {
while (!empty(insts)) {
expr inst = head(insts);
insts = tail(insts);
expr inst_type = infer_type(inst);
bool trans_inst = false;
if (try_instance(e, inst, inst_type, trans_inst))
return true;
}
return false;
}
bool process_next_alt_core(stack_entry const & e, list<name> & inst_names, bool trans_inst) {
while (!empty(inst_names)) {
name inst_name = head(inst_names);
inst_names = tail(inst_names);
if (try_instance(e, inst_name, trans_inst))
return true;
}
return false;
}
bool process_next_alt(stack_entry const & e) {
lean_assert(!m_choices.empty());
choice & c = m_choices.back();
if (process_next_alt_core(e, c.m_local_instances))
return true;
if (!e.m_trans_inst_subproblem) {
if (process_next_alt_core(e, c.m_trans_instances, true))
return true;
if (process_next_alt_core(e, c.m_instances, false))
return true;
}
return false;
}
bool process_next_mvar() {
lean_assert(!is_done());
stack_entry e = head(m_state.m_stack);
if (!mk_choice_point(e.m_mvar))
return false;
m_state.m_stack = tail(m_state.m_stack);
return process_next_alt(e);
}
bool backtrack() {
if (m_choices.empty())
return false;
while (true) {
m_choices.pop_back();
if (m_choices.empty())
return false;
m_state = m_choices.back().m_state;
stack_entry e = head(m_state.m_stack);
m_state.m_stack = tail(m_state.m_stack);
if (process_next_alt(e))
return true;
}
}
optional<expr> search() {
while (!is_done()) {
if (!process_next_mvar()) {
if (!backtrack())
return none_expr();
}
}
return some_expr(instantiate_uvars_mvars(m_main_mvar));
}
optional<expr> next_solution() {
if (m_choices.empty())
return none_expr();
m_state = m_choices.back().m_state;
stack_entry e = head(m_state.m_stack);
m_state.m_stack = tail(m_state.m_stack);
if (process_next_alt(e))
return search();
else if (backtrack())
return search();
else
return none_expr();
}
void init_search(expr const & type) {
m_state = state();
m_main_mvar = mk_mvar(type);
m_state.m_stack = cons(stack_entry(m_main_mvar, 0), m_state.m_stack);
m_choices.clear();
}
optional<expr> ensure_no_meta(optional<expr> r) {
while (true) {
if (!r)
return none_expr();
if (!has_expr_metavar_relaxed(*r)) {
cache_result(mlocal_type(m_main_mvar), *r);
return r;
}
r = next_solution();
}
}
optional<expr> mk_instance_core(expr const & type) {
if (auto r = check_cache(type)) {
if (m_trace_instances) {
auto out = diagnostic();
out << "cached instance for " << type << "\n" << *r << "\n";
}
return r;
}
init_search(type);
auto r = search();
return ensure_no_meta(r);
}
optional<expr> operator()(environment const & env, options const & o, pos_info_provider const * pip, list<expr> const & ctx, expr const & type,
optional<expr> operator()(environment const & env, io_state const & ios,
pos_info_provider const * pip, list<expr> const & ctx, expr const & type,
expr const & pos_ref) {
set_env(env);
set_options(o);
set_ctx(ctx);
set_pos_info(pip, pos_ref);
m_displayed_trace_header = false;
return mk_instance_core(type);
}
optional<expr> next() {
if (!m_multiple_instances)
return none_expr();
auto r = next_solution();
return ensure_no_meta(r);
}
/** \brief Create a nested type class instance of the given type
\remark This method is used to resolve nested type class resolution problems. */
optional<expr> mk_nested_instance(expr const & type) {
std::vector<choice> choices;
m_choices.swap(choices); // save choice stack
flet<state> save_state(m_state, state());
flet<expr> save_main_mvar(m_main_mvar, expr());
auto r = mk_instance_core(type);
m_choices.swap(choices); // restore choice stack
return r;
}
/** \brief Create a nested type class instance of the given type, and assign it to metavariable \c m.
Return true iff the instance was successfully created.
\remark This method is used to resolve nested type class resolution problems. */
bool mk_nested_instance(expr const & m, expr const & m_type) {
lean_assert(is_mvar(m));
if (auto r = mk_nested_instance(m_type)) {
update_assignment(m, *r);
return true;
} else {
return false;
}
ensure_compatible(env, ios, ctx);
type_inference::scope_pos_info scope(*m_ti_ptr, pip, pos_ref);
return m_ti_ptr->mk_class_instance(type);
}
};
MK_THREAD_LOCAL_GET_DEF(cienv, get_cienv);
static void clear_cache_and_ctx() {
get_cienv().clear_cache_and_ctx();
static optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx,
expr const & e, pos_info_provider const * pip, expr const & pos_ref) {
return get_cienv()(env, ios, pip, ctx, e, pos_ref);
}
ci_local_metavar_types_scope::ci_local_metavar_types_scope(ci_local_metavar_types & t):
m_old(g_lm_types) {
g_lm_types = &t;
clear_cache_and_ctx();
}
ci_local_metavar_types_scope::~ci_local_metavar_types_scope() {
clear_cache_and_ctx();
g_lm_types = m_old;
}
static optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider const * pip,
expr const & pos_ref) {
flet<io_state*> set_ios(g_ios, const_cast<io_state*>(&ios));
return get_cienv()(env, ios.get_options(), pip, ctx, e, pos_ref);
}
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider const * pip) {
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx,
expr const & e, pos_info_provider const * pip) {
return mk_class_instance(env, ios, ctx, e, pip, e);
}
@ -958,7 +115,8 @@ optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx
// Auxiliary class for generating a lazy-stream of instances.
class class_multi_instance_iterator : public choice_iterator {
io_state m_ios;
cienv m_cienv;
default_type_inference m_ti;
type_inference::scope_pos_info m_scope_pos_info;
expr m_new_meta;
justification m_new_j;
optional<expr> m_first;
@ -969,11 +127,11 @@ public:
bool is_strict):
choice_iterator(!is_strict),
m_ios(ios),
m_cienv(true),
m_ti(env, ios, ctx, true),
m_scope_pos_info(m_ti, pip, pos_ref),
m_new_meta(new_meta),
m_new_j(new_j) {
flet<io_state*> set_ios(g_ios, const_cast<io_state*>(&m_ios));
m_first = m_cienv(env, ios.get_options(), pip, ctx, e, pos_ref);
m_first = m_ti.mk_class_instance(e);
}
virtual ~class_multi_instance_iterator() {}
@ -984,8 +142,7 @@ public:
r = m_first;
m_first = none_expr();
} else {
flet<io_state*> set_ios(g_ios, const_cast<io_state*>(&m_ios));
r = m_cienv.next();
r = m_ti.next_class_instance();
}
if (r) {
constraint c = mk_eq_cnstr(m_new_meta, *r, m_new_j);
@ -1000,17 +157,17 @@ static constraint mk_class_instance_root_cnstr(environment const & env, io_state
bool use_local_instances, pos_info_provider const * pip) {
justification j = mk_failed_to_synthesize_jst(env, m);
auto choice_fn = [=](expr const & meta, expr const & meta_type, substitution const & s, name_generator &&) {
local_context ctx;
if (use_local_instances)
ctx = _ctx.instantiate(substitution(s));
cienv & cenv = get_cienv();
cenv.set_env(env);
auto cls_name = cenv.is_class(meta_type);
cenv.ensure_compatible(env, ios, ctx.get_data());
auto cls_name = cenv.m_ti_ptr->is_class(meta_type);
if (!cls_name) {
// do nothing, since type is not a class.
return lazy_list<constraints>(constraints());
}
bool multiple_insts = try_multiple_instances(env, *cls_name);
local_context ctx;
if (use_local_instances)
ctx = _ctx.instantiate(substitution(s));
pair<expr, justification> mj = update_meta(meta, s);
expr new_meta = mj.first;
justification new_j = mj.second;
@ -1066,14 +223,16 @@ optional<expr> mk_hset_instance(type_checker & tc, io_state const & ios, list<ex
}
optional<expr> mk_subsingleton_instance(type_checker & tc, io_state const & ios, list<expr> const & ctx, expr const & type) {
flet<bool> set(g_subsingleton_hack, true);
cienv & cenv = get_cienv();
cenv.ensure_compatible(tc.env(), ios, ctx);
flet<bool> set(cenv.m_ti_ptr->get_ignore_if_zero(), true);
level lvl = sort_level(tc.ensure_type(type).first);
expr subsingleton;
if (is_standard(tc.env()))
subsingleton = mk_app(mk_constant(get_subsingleton_name(), {lvl}), type);
else
subsingleton = tc.whnf(mk_app(mk_constant(get_is_trunc_is_hprop_name(), {lvl}), type)).first;
return mk_class_instance(tc.env(), ios, ctx, subsingleton);
return cenv.m_ti_ptr->mk_class_instance(subsingleton);
}
void initialize_class_instance_resolution() {

View file

@ -9,32 +9,10 @@ Author: Leonardo de Moura
#include "kernel/environment.h"
#include "kernel/pos_info_provider.h"
#include "library/io_state.h"
#include "library/type_inference.h"
#include "library/local_context.h"
namespace lean {
/** Auxiliary object used to customize type class resolution.
It allows us to specify how the types of local constants and metavariables are retrieved.
\remark We need this object because modules such as blast store
the types of some local constants (e.g., hypotheses) in a
different data-structure.
*/
class ci_local_metavar_types {
public:
virtual ~ci_local_metavar_types() {}
virtual expr infer_local(expr const & e) = 0;
virtual expr infer_metavar(expr const & e) = 0;
};
/** \brief Auxiliary object for changing the thread local storage that stores the auxiliary object
ci_local_metavar_types used by type class resolution. */
class ci_local_metavar_types_scope {
ci_local_metavar_types * m_old;
public:
ci_local_metavar_types_scope(ci_local_metavar_types & t);
~ci_local_metavar_types_scope();
};
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider const * pip = nullptr);
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e, pos_info_provider const * pip = nullptr);

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <vector>
#include <algorithm>
#include "util/interrupt.h"
#include "kernel/instantiate.h"
@ -13,6 +14,10 @@ Author: Leonardo de Moura
#include "library/normalize.h"
#include "library/replace_visitor.h"
#include "library/type_inference.h"
#include "library/pp_options.h"
#include "library/reducible.h"
#include "library/generic_exception.h"
#include "library/class.h"
namespace lean {
static name * g_prefix = nullptr;
@ -45,18 +50,41 @@ struct type_inference::ext_ctx : public extension_context {
}
};
type_inference::type_inference(environment const & env):
// TODO(Leo): move this methods to this module
bool get_class_trace_instances(options const & o);
unsigned get_class_instance_max_depth(options const & o);
bool get_class_trans_instances(options const & o);
type_inference::type_inference(environment const & env, io_state const & ios, bool multiple_instances):
m_env(env),
m_ios(ios),
m_ngen(*g_prefix),
m_ext_ctx(new ext_ctx(*this)),
m_proj_info(get_projection_info_map(env)) {
m_pip = nullptr;
m_ci_multiple_instances = multiple_instances;
m_ignore_external_mvars = false;
m_check_types = true;
// TODO(Leo): use compilation options for setting config
m_ci_max_depth = 32;
m_ci_trans_instances = true;
m_ci_trace_instances = false;
update_options(ios.get_options());
}
type_inference::~type_inference() {
}
void type_inference::set_context(list<expr> const & ctx) {
clear_cache();
m_ci_local_instances.clear();
for (expr const & e : ctx) {
if (auto cname = is_class(infer(e))) {
m_ci_local_instances.push_back(mk_pair(*cname, e));
}
}
}
bool type_inference::is_opaque(declaration const & d) const {
if (d.is_theorem())
return true;
@ -1031,6 +1059,670 @@ expr type_inference::infer(expr const & e) {
}
void type_inference::clear_cache() {
m_ci_cache.clear();
}
/** \brief If the constant \c e is a class, return its name */
optional<name> type_inference::constant_is_class(expr const & e) {
name const & cls_name = const_name(e);
if (lean::is_class(m_env, cls_name)) {
return optional<name>(cls_name);
} else {
return optional<name>();
}
}
optional<name> type_inference::is_full_class(expr type) {
type = whnf(type);
if (is_pi(type)) {
return is_full_class(instantiate(binding_body(type), mk_tmp_local(binding_domain(type))));
} else {
expr f = get_app_fn(type);
if (!is_constant(f))
return optional<name>();
return constant_is_class(f);
}
}
/** \brief Partial/Quick test for is_class. Result
l_true: \c type is a class, and the name of the class is stored in \c result.
l_false: \c type is not a class.
l_undef: procedure did not establish whether \c type is a class or not.
*/
lbool type_inference::is_quick_class(expr const & type, name & result) {
expr const * it = &type;
while (true) {
switch (it->kind()) {
case expr_kind::Var: case expr_kind::Sort: case expr_kind::Local:
case expr_kind::Meta: case expr_kind::Lambda:
return l_false;
case expr_kind::Macro:
return l_undef;
case expr_kind::Constant:
if (auto r = constant_is_class(*it)) {
result = *r;
return l_true;
} else if (is_extra_opaque(const_name(*it))) {
return l_false;
} else {
return l_undef;
}
case expr_kind::App: {
expr const & f = get_app_fn(*it);
if (is_constant(f)) {
if (auto r = constant_is_class(f)) {
result = *r;
return l_true;
} else if (is_extra_opaque(const_name(f))) {
return l_false;
} else {
return l_undef;
}
} else if (is_lambda(f) || is_macro(f)) {
return l_undef;
} else {
return l_false;
}
}
case expr_kind::Pi:
it = &binding_body(*it);
break;
}
}
}
/** \brief Return true iff \c type is a class or Pi that produces a class. */
optional<name> type_inference::is_class(expr const & type) {
name result;
switch (is_quick_class(type, result)) {
case l_true: return optional<name>(result);
case l_false: return optional<name>();
case l_undef: break;
}
return is_full_class(type);
}
bool type_inference::compatible_local_instances(list<expr> const & ctx) {
unsigned i = 0;
for (expr const & e : ctx) {
// Remark: we use infer_type(e) instead of mlocal_type because we want to allow
// customers (e.g., blast) of this class to store the type of local constants
// in a different place.
if (auto cname = is_class(infer(e))) {
if (i == m_ci_local_instances.size())
return false; // ctx has more local instances than m_ci_local_instances
if (e != m_ci_local_instances[i].second)
return false; // local instance in ctx is not compatible with one at m_ci_local_instances
i++;
}
}
return i == m_ci_local_instances.size();
}
// Helper function for find_unsynth_metavar
static bool has_meta_arg(expr e) {
while (is_app(e)) {
if (is_meta(app_arg(e)))
return true;
e = app_fn(e);
}
return false;
}
/** IF \c e is of the form (f ... (?m t_1 ... t_n) ...) where ?m is an unassigned
metavariable whose type is a type class, and (?m t_1 ... t_n) must be synthesized
by type class resolution, then we return ?m.
Otherwise, we return none */
optional<pair<expr, expr>> type_inference::find_unsynth_metavar(expr const & e) {
if (!has_meta_arg(e))
return optional<pair<expr, expr>>();
buffer<expr> args;
expr const & fn = get_app_args(e, args);
expr type = infer(fn);
unsigned i = 0;
while (i < args.size()) {
type = whnf(type);
if (!is_pi(type))
return optional<pair<expr, expr>>();
expr const & arg = args[i];
if (binding_info(type).is_inst_implicit() && is_meta(arg)) {
expr const & m = get_app_fn(arg);
if (is_mvar(m)) {
expr m_type = instantiate_uvars_mvars(infer(m));
if (!has_expr_metavar_relaxed(m_type)) {
return some(mk_pair(m, m_type));
}
}
}
type = instantiate(binding_body(type), arg);
i++;
}
return optional<pair<expr, expr>>();
}
bool type_inference::on_is_def_eq_failure(expr & e1, expr & e2) {
if (is_app(e1) && is_app(e2)) {
if (auto p1 = find_unsynth_metavar(e1)) {
if (mk_nested_instance(p1->first, p1->second)) {
e1 = instantiate_uvars_mvars(e1);
return true;
}
}
if (auto p2 = find_unsynth_metavar(e2)) {
if (mk_nested_instance(p2->first, p2->second)) {
e2 = instantiate_uvars_mvars(e2);
return true;
}
}
}
return false;
}
bool type_inference::validate_assignment(expr const & m, buffer<expr> const & locals, expr const & v) {
// We must check
// 1. Any (internal) local constant occurring in v occurs in locals
// 2. m does not occur in v
bool ok = true;
for_each(v, [&](expr const & e, unsigned) {
if (!ok)
return false; // stop search
if (is_tmp_local(e)) {
if (std::all_of(locals.begin(), locals.end(), [&](expr const & a) {
return mlocal_name(a) != mlocal_name(e); })) {
ok = false; // failed 1
return false;
}
} else if (is_mvar(e)) {
if (m == e) {
ok = false; // failed 2
return false;
}
return false;
}
return true;
});
return ok;
}
bool type_inference::update_options(options const & opts) {
options o = opts;
unsigned max_depth = get_class_instance_max_depth(o);
bool trans_instances = get_class_trans_instances(o);
bool trace_instances = get_class_trace_instances(o);
if (trace_instances) {
o = o.update_if_undef(get_pp_purify_metavars_name(), false);
o = o.update_if_undef(get_pp_implicit_name(), true);
}
bool r = true;
if (m_ci_max_depth != max_depth ||
m_ci_trans_instances != trans_instances ||
m_ci_trace_instances != trace_instances) {
r = false;
}
m_ci_max_depth = max_depth;
m_ci_trans_instances = trans_instances;
m_ci_trace_instances = trace_instances;
m_ios.set_options(o);
return r;
}
[[ noreturn ]] static void throw_class_exception(char const * msg, expr const & m) { throw_generic_exception(msg, m); }
[[ noreturn ]] static void throw_class_exception(expr const & m, pp_fn const & fn) { throw_generic_exception(m, fn); }
io_state_stream type_inference::diagnostic() {
return lean::diagnostic(m_env, m_ios);
}
void type_inference::trace(unsigned depth, expr const & mvar, expr const & mvar_type, expr const & r) {
lean_assert(m_ci_trace_instances);
auto out = diagnostic();
if (!m_ci_displayed_trace_header && m_ci_choices.size() == m_ci_choices_ini_sz + 1) {
if (m_pip) {
if (auto fname = m_pip->get_file_name()) {
out << fname << ":";
}
if (m_ci_pos)
out << m_ci_pos->first << ":" << m_ci_pos->second << ":";
}
out << " class-instance resolution trace" << endl;
m_ci_displayed_trace_header = true;
}
for (unsigned i = 0; i < depth; i++)
out << " ";
if (depth > 0)
out << "[" << depth << "] ";
out << mvar << " : " << instantiate_uvars_mvars(mvar_type) << " := " << r << endl;
}
// Try to synthesize e.m_mvar using instance inst : inst_type.
// trans_inst is true if inst is a transitive instance.
bool type_inference::try_instance(ci_stack_entry const & e, expr const & inst, expr const & inst_type, bool trans_inst) {
try {
buffer<expr> locals;
expr const & mvar = e.m_mvar;
expr mvar_type = mlocal_type(mvar);
while (true) {
mvar_type = whnf(mvar_type);
if (!is_pi(mvar_type))
break;
expr local = mk_tmp_local(binding_domain(mvar_type));
locals.push_back(local);
mvar_type = instantiate(binding_body(mvar_type), local);
}
expr type = inst_type;
expr r = inst;
buffer<expr> new_inst_mvars;
while (true) {
type = whnf(type);
if (!is_pi(type))
break;
expr new_mvar = mk_mvar(Pi(locals, binding_domain(type)));
if (binding_info(type).is_inst_implicit()) {
new_inst_mvars.push_back(new_mvar);
}
expr new_arg = mk_app(new_mvar, locals);
r = mk_app(r, new_arg);
type = instantiate(binding_body(type), new_arg);
}
if (m_ci_trace_instances) {
trace(e.m_depth, mk_app(mvar, locals), mvar_type, r);
}
if (!is_def_eq(mvar_type, type)) {
return false;
}
r = Fun(locals, r);
// Remark: if the metavariable is already assigned, we should check whether
// the previous assignment (obtained by solving unification constraints) and the
// synthesized one are definitionally equal. We don't do that for performance reasons.
// Moreover, the is_def_eq defined here is not complete (e.g., it only unfolds reducible constants).
update_assignment(mvar, r);
// copy new_inst_mvars to stack
unsigned i = new_inst_mvars.size();
while (i > 0) {
--i;
m_ci_state.m_stack = cons(ci_stack_entry(new_inst_mvars[i], e.m_depth+1, trans_inst), m_ci_state.m_stack);
}
return true;
} catch (exception &) {
return false;
}
}
bool type_inference::try_instance(ci_stack_entry const & e, name const & inst_name, bool trans_inst) {
if (auto decl = m_env.find(inst_name)) {
buffer<level> ls_buffer;
unsigned num_univ_ps = decl->get_num_univ_params();
for (unsigned i = 0; i < num_univ_ps; i++)
ls_buffer.push_back(mk_uvar());
levels ls = to_list(ls_buffer.begin(), ls_buffer.end());
expr inst_cnst = mk_constant(inst_name, ls);
expr inst_type = instantiate_type_univ_params(*decl, ls);
return try_instance(e, inst_cnst, inst_type, trans_inst);
} else {
return false;
}
}
list<expr> type_inference::get_local_instances(name const & cname) {
buffer<expr> selected;
for (pair<name, expr> const & p : m_ci_local_instances) {
if (p.first == cname)
selected.push_back(p.second);
}
return to_list(selected);
}
bool type_inference::is_ci_done() const {
return empty(m_ci_state.m_stack);
}
bool type_inference::mk_choice_point(expr const & mvar) {
lean_assert(is_mvar(mvar));
if (m_ci_choices.size() > m_ci_choices_ini_sz + m_ci_max_depth) {
throw_class_exception("maximum class-instance resolution depth has been reached "
"(the limit can be increased by setting option 'class.instance_max_depth') "
"(the class-instance resolution trace can be visualized "
"by setting option 'class.trace_instances')",
infer(m_ci_main_mvar));
}
// Remark: we initially tried to reject branches where mvar_type contained unassigned metavariables.
// The idea was to make the procedure easier to understand.
// However, it turns out this is too restrictive. The group_theory folder contains the following instance.
// nsubg_setoid : Π {A : Type} [s : group A] (N : set A) [is_nsubg : @is_normal_subgroup A s N], setoid A
// When it is used, it creates a subproblem for
// is_nsubg : @is_normal_subgroup A s ?N
// where ?N is not known. Actually, we can only find the value for ?N by constructing the instance is_nsubg.
expr mvar_type = instantiate_uvars_mvars(mlocal_type(mvar));
bool toplevel_choice = m_ci_choices.empty();
m_ci_choices.push_back(ci_choice());
push();
ci_choice & r = m_ci_choices.back();
auto cname = is_class(mvar_type);
if (!cname)
return false;
r.m_local_instances = get_local_instances(*cname);
if (m_ci_trans_instances && toplevel_choice) {
// we only use transitive instances in the top-level
r.m_trans_instances = get_class_derived_trans_instances(m_env, *cname);
}
r.m_instances = get_class_instances(m_env, *cname);
if (empty(r.m_local_instances) && empty(r.m_trans_instances) && empty(r.m_instances))
return false;
r.m_state = m_ci_state;
return true;
}
bool type_inference::process_next_alt_core(ci_stack_entry const & e, list<expr> & insts) {
while (!empty(insts)) {
expr inst = head(insts);
insts = tail(insts);
expr inst_type = infer(inst);
bool trans_inst = false;
if (try_instance(e, inst, inst_type, trans_inst))
return true;
}
return false;
}
bool type_inference::process_next_alt_core(ci_stack_entry const & e, list<name> & inst_names, bool trans_inst) {
while (!empty(inst_names)) {
name inst_name = head(inst_names);
inst_names = tail(inst_names);
if (try_instance(e, inst_name, trans_inst))
return true;
}
return false;
}
bool type_inference::process_next_alt(ci_stack_entry const & e) {
lean_assert(m_ci_choices.size() > m_ci_choices_ini_sz);
lean_assert(!m_ci_choices.empty());
std::vector<ci_choice> & cs = m_ci_choices;
list<expr> locals = cs.back().m_local_instances;
if (process_next_alt_core(e, locals)) {
cs.back().m_local_instances = locals;
return true;
}
cs.back().m_local_instances = list<expr>();
if (!e.m_trans_inst_subproblem) {
list<name> trans_insts = cs.back().m_trans_instances;
if (process_next_alt_core(e, trans_insts, true)) {
cs.back().m_trans_instances = trans_insts;
return true;
}
cs.back().m_trans_instances = list<name>();
list<name> insts = cs.back().m_instances;
if (process_next_alt_core(e, insts, false)) {
cs.back().m_instances = insts;
return true;
}
cs.back().m_instances = list<name>();
}
return false;
}
bool type_inference::process_next_mvar() {
lean_assert(!is_ci_done());
ci_stack_entry e = head(m_ci_state.m_stack);
if (!mk_choice_point(e.m_mvar))
return false;
m_ci_state.m_stack = tail(m_ci_state.m_stack);
return process_next_alt(e);
}
bool type_inference::backtrack() {
if (m_ci_choices.size() == m_ci_choices_ini_sz)
return false;
lean_assert(!m_ci_choices.empty());
while (true) {
m_ci_choices.pop_back();
pop();
if (m_ci_choices.size() == m_ci_choices_ini_sz)
return false;
m_ci_state = m_ci_choices.back().m_state;
ci_stack_entry e = head(m_ci_state.m_stack);
m_ci_state.m_stack = tail(m_ci_state.m_stack);
if (process_next_alt(e))
return true;
}
}
optional<expr> type_inference::search() {
while (!is_ci_done()) {
if (!process_next_mvar()) {
if (!backtrack())
return none_expr();
}
}
return some_expr(instantiate_uvars_mvars(m_ci_main_mvar));
}
optional<expr> type_inference::next_solution() {
if (m_ci_choices.size() == m_ci_choices_ini_sz)
return none_expr();
pop(); push(); // restore assignment
m_ci_state = m_ci_choices.back().m_state;
ci_stack_entry e = head(m_ci_state.m_stack);
m_ci_state.m_stack = tail(m_ci_state.m_stack);
if (process_next_alt(e))
return search();
else if (backtrack())
return search();
else
return none_expr();
}
void type_inference::init_search(expr const & type) {
m_ci_state = ci_state();
m_ci_main_mvar = mk_mvar(type);
m_ci_state.m_stack = to_list(ci_stack_entry(m_ci_main_mvar, 0));
m_ci_choices_ini_sz = m_ci_choices.size();
}
optional<expr> type_inference::check_ci_cache(expr const & type) const {
if (m_ci_multiple_instances) {
// We do not cache results when multiple instances have to be generated.
return none_expr();
}
auto it = m_ci_cache.find(type);
if (it != m_ci_cache.end())
return some_expr(it->second);
else
return none_expr();
}
void type_inference::cache_ci_result(expr const & type, expr const & inst) {
if (m_ci_multiple_instances) {
// We do not cache results when multiple instances have to be generated.
return;
}
m_ci_cache.insert(mk_pair(type, inst));
}
optional<expr> type_inference::ensure_no_meta(optional<expr> r) {
while (true) {
if (!r)
return none_expr();
if (!has_expr_metavar_relaxed(*r)) {
cache_ci_result(mlocal_type(m_ci_main_mvar), *r);
return r;
}
r = next_solution();
}
}
optional<expr> type_inference::mk_class_instance_core(expr const & type) {
if (auto r = check_ci_cache(type)) {
if (m_ci_trace_instances) {
diagnostic() << "cached instance for " << type << "\n" << *r << "\n";
}
return r;
}
init_search(type);
auto r = search();
return ensure_no_meta(r);
}
void type_inference::restore_choices(unsigned old_sz) {
lean_assert(old_sz <= m_ci_choices.size());
while (m_ci_choices.size() > old_sz) {
m_ci_choices.pop_back();
pop();
}
lean_assert(m_ci_choices.size() == old_sz);
}
optional<expr> type_inference::mk_class_instance(expr const & type) {
m_ci_choices.clear();
ci_choices_scope scope(*this);
m_ci_displayed_trace_header = false;
auto r = mk_class_instance_core(type);
if (r)
scope.commit();
return r;
}
optional<expr> type_inference::next_class_instance() {
if (!m_ci_multiple_instances)
return none_expr();
auto r = next_solution();
return ensure_no_meta(r);
}
/** \brief Create a nested type class instance of the given type
\remark This method is used to resolve nested type class resolution problems. */
optional<expr> type_inference::mk_nested_instance(expr const & type) {
ci_choices_scope scope(*this);
flet<unsigned> save_choice_sz(m_ci_choices_ini_sz, m_ci_choices_ini_sz);
flet<ci_state> save_state(m_ci_state, ci_state());
flet<expr> save_main_mvar(m_ci_main_mvar, expr());
unsigned old_choices_sz = m_ci_choices.size();
auto r = mk_class_instance_core(type);
if (r)
scope.commit();
m_ci_choices.resize(old_choices_sz); // cut search
return r;
}
/** \brief Create a nested type class instance of the given type, and assign it to metavariable \c m.
Return true iff the instance was successfully created.
\remark This method is used to resolve nested type class resolution problems. */
bool type_inference::mk_nested_instance(expr const & m, expr const & m_type) {
lean_assert(is_mvar(m));
if (auto r = mk_nested_instance(m_type)) {
update_assignment(m, *r);
return true;
} else {
return false;
}
}
type_inference::scope_pos_info::scope_pos_info(type_inference & o, pos_info_provider const * pip, expr const & pos_ref):
m_owner(o),
m_old_pip(m_owner.m_pip),
m_old_pos(m_owner.m_ci_pos) {
m_owner.m_pip = pip;
if (pip)
m_owner.m_ci_pos = pip->get_pos_info(pos_ref);
}
type_inference::scope_pos_info::~scope_pos_info() {
m_owner.m_pip = m_old_pip;
m_owner.m_ci_pos = m_old_pos;
}
default_type_inference::default_type_inference(environment const & env, io_state const & ios,
list<expr> const & ctx, bool multiple_instances):
type_inference(env, ios, multiple_instances),
m_not_reducible_pred(mk_not_reducible_pred(env)) {
m_ignore_if_zero = false;
m_next_local_idx = 0;
m_next_uvar_idx = 0;
m_next_mvar_idx = 0;
set_context(ctx);
}
default_type_inference::~default_type_inference() {}
expr default_type_inference::mk_tmp_local(expr const & type, binder_info const & bi) {
unsigned idx = m_next_local_idx;
m_next_local_idx++;
name n(*g_prefix, idx);
return lean::mk_local(n, n, type, bi);
}
bool default_type_inference::is_tmp_local(expr const & e) const {
if (!is_local(e))
return false;
name const & n = mlocal_name(e);
return !n.is_atomic() && n.get_prefix() == *g_prefix;
}
bool default_type_inference::is_uvar(level const & l) const {
if (!is_meta(l))
return false;
name const & n = meta_id(l);
return !n.is_atomic() && n.get_prefix() == *g_prefix;
}
bool default_type_inference::is_mvar(expr const & e) const {
if (!is_metavar(e))
return false;
name const & n = mlocal_name(e);
return !n.is_atomic() && n.get_prefix() == *g_prefix;
}
unsigned default_type_inference::uvar_idx(level const & l) const {
lean_assert(is_uvar(l));
return meta_id(l).get_numeral();
}
unsigned default_type_inference::mvar_idx(expr const & m) const {
lean_assert(is_mvar(m));
return mlocal_name(m).get_numeral();
}
level const * default_type_inference::get_assignment(level const & u) const {
return m_assignment.m_uassignment.find(uvar_idx(u));
}
expr const * default_type_inference::get_assignment(expr const & m) const {
return m_assignment.m_eassignment.find(mvar_idx(m));
}
void default_type_inference::update_assignment(level const & u, level const & v) {
m_assignment.m_uassignment.insert(uvar_idx(u), v);
}
void default_type_inference::update_assignment(expr const & m, expr const & v) {
m_assignment.m_eassignment.insert(mvar_idx(m), v);
}
level default_type_inference::mk_uvar() {
unsigned idx = m_next_uvar_idx;
m_next_uvar_idx++;
return mk_meta_univ(name(*g_prefix, idx));
}
expr default_type_inference::mk_mvar(expr const & type) {
unsigned idx = m_next_mvar_idx;
m_next_mvar_idx++;
return mk_metavar(name(*g_prefix, idx), type);
}
bool default_type_inference::ignore_universe_def_eq(level const & l1, level const & l2) const {
if (is_meta(l1) || is_meta(l2)) {
// The unifier may invoke this module before universe metavariables in the class
// have been instantiated. So, we just ignore and assume they will be solved by
// the unifier.
// See comment at m_ignore_if_zero declaration.
if (m_ignore_if_zero && (is_zero(l1) || is_zero(l2)))
return false;
return true; // we ignore
} else {
return false;
}
}
void initialize_type_inference() {

View file

@ -8,6 +8,8 @@ Author: Leonardo de Moura
#include <memory>
#include <vector>
#include "kernel/environment.h"
#include "library/io_state.h"
#include "library/io_state_stream.h"
#include "library/projection.h"
namespace lean {
@ -18,12 +20,15 @@ namespace lean {
are supported.
This is a generic class containing several virtual methods that must
be implemeneted by "customers" (e.g., type class resolution procedure, blast tactic).
be implemeneted by "customers" (e.g., blast tactic).
This class also implements type class resolution
*/
class type_inference {
struct ext_ctx;
friend struct ext_ctx;
environment m_env;
io_state m_ios;
name_generator m_ngen;
std::unique_ptr<ext_ctx> m_ext_ctx;
// postponed universe constraints
@ -117,10 +122,94 @@ class type_inference {
void commit() { m_postponed_sz = m_owner.m_postponed.size(); m_owner.commit(); m_keep = true; }
};
// Data-structures for type class resolution
struct ci_stack_entry {
// We only use transitive instances when we can solve the problem in a single step.
// That is, the transitive instance does not have any instance argument, OR
// it uses local instances to fill them.
// We accomplish that by not considering global instances when solving
// transitive instance subproblems.
expr m_mvar;
unsigned m_depth;
bool m_trans_inst_subproblem;
ci_stack_entry(expr const & m, unsigned d, bool s = false):
m_mvar(m), m_depth(d), m_trans_inst_subproblem(s) {}
};
struct ci_state {
bool m_trans_inst_subproblem;
list<ci_stack_entry> m_stack; // stack of meta-variables that need to be synthesized;
};
struct ci_choice {
list<expr> m_local_instances;
list<name> m_trans_instances;
list<name> m_instances;
ci_state m_state;
};
struct ci_choices_scope {
type_inference & m_owner;
unsigned m_ci_choices_sz;
bool m_keep{false};
ci_choices_scope(type_inference & o):m_owner(o), m_ci_choices_sz(o.m_ci_choices.size()) {}
~ci_choices_scope() { if (!m_keep) m_owner.restore_choices(m_ci_choices_sz); }
void commit() { m_keep = true; }
};
pos_info_provider const * m_pip;
std::vector<pair<name, expr>> m_ci_local_instances;
expr_struct_map<expr> m_ci_cache;
bool m_ci_multiple_instances;
expr m_ci_main_mvar;
ci_state m_ci_state; // active state
std::vector<ci_choice> m_ci_choices;
unsigned m_ci_choices_ini_sz;
bool m_ci_displayed_trace_header;
optional<pos_info> m_ci_pos;
// configuration options
unsigned m_ci_max_depth;
bool m_ci_trans_instances;
bool m_ci_trace_instances;
io_state_stream diagnostic();
optional<name> constant_is_class(expr const & e);
optional<name> is_full_class(expr type);
lbool is_quick_class(expr const & type, name & result);
optional<pair<expr, expr>> find_unsynth_metavar(expr const & e);
void trace(unsigned depth, expr const & mvar, expr const & mvar_type, expr const & r);
bool try_instance(ci_stack_entry const & e, expr const & inst, expr const & inst_type, bool trans_inst);
bool try_instance(ci_stack_entry const & e, name const & inst_name, bool trans_inst);
list<expr> get_local_instances(name const & cname);
bool is_ci_done() const;
bool mk_choice_point(expr const & mvar);
bool process_next_alt_core(ci_stack_entry const & e, list<expr> & insts);
bool process_next_alt_core(ci_stack_entry const & e, list<name> & inst_names, bool trans_inst);
bool process_next_alt(ci_stack_entry const & e);
bool process_next_mvar();
bool backtrack();
optional<expr> search();
optional<expr> next_solution();
void init_search(expr const & type);
void restore_choices(unsigned old_sz);
optional<expr> ensure_no_meta(optional<expr> r);
optional<expr> mk_nested_instance(expr const & type);
bool mk_nested_instance(expr const & m, expr const & m_type);
optional<expr> mk_class_instance_core(expr const & type);
optional<expr> check_ci_cache(expr const & type) const;
void cache_ci_result(expr const & type, expr const & inst);
public:
type_inference(environment const & env);
type_inference(environment const & env, io_state const & ios, bool multiple_instances = false);
virtual ~type_inference();
void set_context(list<expr> const & ctx);
environment const & env() const { return m_env; }
/** \brief Opaque constants are never unfolded by this procedure.
The is_def_eq method will lazily unfold non-opaque constants.
@ -169,8 +258,13 @@ public:
should return true if m can be assigned to an abstraction of \c v.
\remark This method should check at least if m does not occur in v,
and if all tmp locals in v are in locals. */
virtual bool validate_assignment(expr const & m, buffer<expr> const & locals, expr const & v) = 0;
and if all tmp locals in v are in locals.
The default implementation checks the following things:
1. Any (internal) local constant occurring in v occurs in locals
2. m does not occur in v
*/
virtual bool validate_assignment(expr const & m, buffer<expr> const & locals, expr const & v);
/** \brief Return the type of a local constant (local or not).
\remark This method allows the customer to store the type of local constants
@ -180,7 +274,11 @@ public:
/** \brief Return the type of a meta-variable (even if it is not a unification one) */
virtual expr infer_metavar(expr const & e) const = 0;
/** \brief Save the current assignment */
virtual level mk_uvar() = 0;
virtual expr mk_mvar(expr const &) = 0;
/** \brief Save the current assignment and metavariable declarations */
virtual void push() = 0;
/** \brief Retore assignment (inverse for push) */
virtual void pop() = 0;
@ -189,8 +287,11 @@ public:
/** \brief This method is invoked before failure.
The "customer" may try to assign unassigned mvars in the given expression.
The result is true to indicate that some metavariable has been assigned. */
virtual bool on_is_def_eq_failure(expr &, expr &) { return false; }
The result is true to indicate that some metavariable has been assigned.
The default implementation tries to invoke type class resolution to
assign unassigned metavariables in the given terms. */
virtual bool on_is_def_eq_failure(expr &, expr &);
bool is_assigned(level const & u) const { return get_assignment(u) != nullptr; }
bool is_assigned(expr const & m) const { return get_assignment(m) != nullptr; }
@ -222,8 +323,100 @@ public:
*/
bool is_def_eq(expr const & e1, expr const & e2);
/** \brief If \c type is a type class, return its name */
optional<name> is_class(expr const & type);
/** \brief Try to synthesize an instance of the type class \c type */
optional<expr> mk_class_instance(expr const & type);
optional<expr> next_class_instance();
/** \brief Clear internal caches used to speedup computation */
void clear_cache();
/** \brief Update configuration options.
Return true iff the new options do not change the behavior of the object.
\remark We assume pretty-printing options are irrelevant. */
bool update_options(options const & opts);
/** \brief Return true if the local instances at \c ctx are compatible with the ones
in the type inference object. This method is used to decide whether a type inference
object can be reused by the elaborator. */
bool compatible_local_instances(list<expr> const & ctx);
/** \brief Auxiliary object used to set position information for the type class resolution trace. */
class scope_pos_info {
type_inference & m_owner;
pos_info_provider const * m_old_pip;
optional<pos_info> m_old_pos;
public:
scope_pos_info(type_inference & o, pos_info_provider const * pip, expr const & pos_ref);
~scope_pos_info();
};
};
/** \brief Default implementation for the generic type_inference class.
It implements a simple meta-variable assignment.
We use this class to implement the interface with the elaborator. */
class default_type_inference : public type_inference {
typedef rb_map<unsigned, level, unsigned_cmp> uassignment;
typedef rb_map<unsigned, expr, unsigned_cmp> eassignment;
name_predicate m_not_reducible_pred;
struct assignment {
uassignment m_uassignment;
eassignment m_eassignment;
};
assignment m_assignment;
std::vector<assignment> m_trail;
unsigned m_next_local_idx;
unsigned m_next_uvar_idx;
unsigned m_next_mvar_idx;
/** \brief When m_ignore_if_zero is true, the following type-class resolution problem fails
Given (A : Type{?u}), where ?u is a universe meta-variable created by an external module,
?Inst : subsingleton.{?u} A := subsingleton_prop ?M
This case generates the unification problem
subsingleton.{?u} A =?= subsingleton.{0} ?M
which can be solved by assigning (?u := 0) and (?M := A)
when the hack is enabled, the is_def_eq method in the type class module fails at the subproblem:
?u =?= 0
That is, when the hack is on, type-class resolution cannot succeed by instantiating an external universe
meta-variable with 0.
*/
bool m_ignore_if_zero;
unsigned uvar_idx(level const & l) const;
unsigned mvar_idx(expr const & m) const;
public:
default_type_inference(environment const & env, io_state const & ios,
list<expr> const & ctx = list<expr>(), bool multiple_instances = false);
virtual ~default_type_inference();
virtual bool is_extra_opaque(name const & n) const { return m_not_reducible_pred(n); }
virtual bool ignore_universe_def_eq(level const & l1, level const & l2) const;
virtual expr mk_tmp_local(expr const & type, binder_info const & bi);
virtual bool is_tmp_local(expr const & e) const;
virtual bool is_uvar(level const & l) const;
virtual bool is_mvar(expr const & e) const;
virtual level const * get_assignment(level const & u) const;
virtual expr const * get_assignment(expr const & m) const;
virtual void update_assignment(level const & u, level const & v);
virtual void update_assignment(expr const & m, expr const & v);
virtual level mk_uvar();
virtual expr mk_mvar(expr const &);
virtual expr infer_local(expr const & e) const { return mlocal_type(e); }
virtual expr infer_metavar(expr const & e) const { return mlocal_type(e); }
virtual void push() { m_trail.push_back(m_assignment); }
virtual void pop() { lean_assert(!m_trail.empty()); m_assignment = m_trail.back(); m_trail.pop_back(); }
virtual void commit() { lean_assert(!m_trail.empty()); m_trail.pop_back(); }
bool & get_ignore_if_zero() { return m_ignore_if_zero; }
};
void initialize_type_inference();