feat(library/class_instance_resolution): new type class resolution procedure
This commit is contained in:
parent
c69bbd4eb7
commit
f5819fab60
2 changed files with 242 additions and 44 deletions
|
@ -13,6 +13,7 @@ Author: Leonardo de Moura
|
|||
#include "library/normalize.h"
|
||||
#include "library/reducible.h"
|
||||
#include "library/class.h"
|
||||
#include "library/io_state_stream.h"
|
||||
#include "library/replace_visitor.h"
|
||||
#include "library/class_instance_resolution.h"
|
||||
|
||||
|
@ -92,17 +93,19 @@ struct cienv {
|
|||
typedef rb_map<unsigned, level, unsigned_cmp> uassignment;
|
||||
typedef rb_map<unsigned, expr, unsigned_cmp> eassignment;
|
||||
|
||||
environment m_env;
|
||||
ci_type_inference_ptr m_tc_ptr;
|
||||
expr_struct_map<expr> m_cache;
|
||||
name_generator m_ngen;
|
||||
name_predicate m_not_reducible_pred;
|
||||
environment m_env;
|
||||
pos_info_provider const * m_pip;
|
||||
optional<pos_info> m_pos;
|
||||
ci_type_inference_ptr m_tc_ptr;
|
||||
expr_struct_map<expr> m_cache;
|
||||
name_generator m_ngen;
|
||||
name_predicate m_not_reducible_pred;
|
||||
|
||||
list<expr> m_ctx;
|
||||
buffer<expr> m_local_instances;
|
||||
list<expr> m_ctx;
|
||||
buffer<pair<name, expr>> m_local_instances;
|
||||
|
||||
unsigned m_next_uvar;
|
||||
unsigned m_next_mvar;
|
||||
unsigned m_next_uvar;
|
||||
unsigned m_next_mvar;
|
||||
|
||||
struct state {
|
||||
list<expr> m_stack; // stack of meta-variables that need to be synthesized;
|
||||
|
@ -110,24 +113,27 @@ struct cienv {
|
|||
eassignment m_eassignment;
|
||||
};
|
||||
|
||||
state m_state; // active state
|
||||
state m_state; // active state
|
||||
|
||||
struct choice {
|
||||
list<expr> m_local_instances;
|
||||
list<name> m_trans_instances;
|
||||
list<name> m_instances;
|
||||
state m_state;
|
||||
list<expr> m_local_instances;
|
||||
list<name> m_trans_instances;
|
||||
list<name> m_instances;
|
||||
state m_state;
|
||||
};
|
||||
|
||||
list<choice> m_choices;
|
||||
std::vector<choice> m_choices;
|
||||
expr m_main_mvar;
|
||||
|
||||
bool m_multiple_instances;
|
||||
bool m_multiple_instances;
|
||||
|
||||
bool m_displayed_trace_header;
|
||||
|
||||
// configuration
|
||||
bool m_unique_instances;
|
||||
unsigned m_max_depth;
|
||||
bool m_trans_instances;
|
||||
bool m_trace_instances;
|
||||
bool m_unique_instances;
|
||||
unsigned m_max_depth;
|
||||
bool m_trans_instances;
|
||||
bool m_trace_instances;
|
||||
|
||||
cienv(bool multiple_instances = false):
|
||||
m_ngen(*g_prefix2),
|
||||
|
@ -304,14 +310,14 @@ struct cienv {
|
|||
|
||||
|
||||
// Auxiliary method for set_ctx
|
||||
void set_local_instance(unsigned i, expr const & e) {
|
||||
void set_local_instance(unsigned i, name const & cname, expr const & e) {
|
||||
lean_assert(i <= m_local_instances.size());
|
||||
if (i == m_local_instances.size()) {
|
||||
reset_cache();
|
||||
m_local_instances.push_back(e);
|
||||
} else if (e != m_local_instances[i]) {
|
||||
m_local_instances.push_back(mk_pair(cname, e));
|
||||
} else if (e != m_local_instances[i].second) {
|
||||
reset_cache();
|
||||
m_local_instances[i] = e;
|
||||
m_local_instances[i] = mk_pair(cname, e);
|
||||
} else {
|
||||
// we don't need to reset the cache since this local instance
|
||||
// is equal to the one used in a previous call
|
||||
|
@ -330,13 +336,19 @@ struct cienv {
|
|||
// Remark: we use infer_type(e) instead of mlocal_type because we want to allow
|
||||
// customers (e.g., blast) of this class to store the type of local constants
|
||||
// in a different place.
|
||||
if (is_class(infer_type(e))) {
|
||||
set_local_instance(i, e);
|
||||
if (auto cname = is_class(infer_type(e))) {
|
||||
set_local_instance(i, *cname, e);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void set_pos_info(pos_info_provider const * pip, expr const & type) {
|
||||
m_pip = pip;
|
||||
if (m_pip)
|
||||
m_pos = m_pip->get_pos_info(type);
|
||||
}
|
||||
|
||||
// Create an internal universal metavariable
|
||||
level mk_uvar() {
|
||||
unsigned idx = m_next_uvar;
|
||||
|
@ -823,27 +835,207 @@ struct cienv {
|
|||
}
|
||||
}
|
||||
|
||||
expr init_search(expr const & type) {
|
||||
m_state = state();
|
||||
expr m = mk_mvar(type);
|
||||
m_state.m_stack = cons(m, m_state.m_stack);
|
||||
return m;
|
||||
void trace(expr const & mvar, expr const & r) {
|
||||
if (!m_trace_instances)
|
||||
return;
|
||||
auto out = diagnostic(m_env, *g_ios);
|
||||
unsigned depth = m_choices.size();
|
||||
if (!m_displayed_trace_header && depth == 1) {
|
||||
if (m_pip) {
|
||||
if (auto fname = m_pip->get_file_name()) {
|
||||
out << fname << ":";
|
||||
}
|
||||
if (m_pos)
|
||||
out << m_pos->first << ":" << m_pos->second << ":";
|
||||
}
|
||||
out << " class-instance resolution trace" << endl;
|
||||
m_displayed_trace_header = true;
|
||||
}
|
||||
for (unsigned i = 0; i < depth; i++)
|
||||
out << " ";
|
||||
if (depth > 1)
|
||||
out << "[" << depth << "] ";
|
||||
out << mvar << " : " << mlocal_type(mvar) << " := " << r << endl;
|
||||
}
|
||||
|
||||
bool try_instance(expr const & mvar, expr const & inst, expr const & inst_type) {
|
||||
try {
|
||||
buffer<expr> locals;
|
||||
expr mvar_type = mlocal_type(mvar);
|
||||
while (true) {
|
||||
mvar_type = whnf(mvar_type);
|
||||
if (!is_pi(mvar_type))
|
||||
break;
|
||||
expr local = mk_local(binding_domain(mvar_type));
|
||||
locals.push_back(local);
|
||||
mvar_type = instantiate(binding_body(mvar_type), local);
|
||||
}
|
||||
expr type = inst_type;
|
||||
expr r = inst;
|
||||
buffer<expr> new_inst_mvars;
|
||||
while (true) {
|
||||
type = whnf(type);
|
||||
if (!is_pi(type))
|
||||
break;
|
||||
expr new_mvar = mk_mvar(binding_domain(type));
|
||||
if (binding_info(type).is_inst_implicit()) {
|
||||
new_inst_mvars.push_back(new_mvar);
|
||||
}
|
||||
r = mk_app(r, new_mvar);
|
||||
type = instantiate(binding_body(type), new_mvar);
|
||||
}
|
||||
trace(mvar, r);
|
||||
if (!is_def_eq(mvar_type, type))
|
||||
return false;
|
||||
r = Fun(locals, r);
|
||||
assign(mvar, r);
|
||||
// copy new_inst_mvars to stack
|
||||
unsigned i = new_inst_mvars.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
m_state.m_stack = cons(new_inst_mvars[i], m_state.m_stack);
|
||||
}
|
||||
return true;
|
||||
} catch (exception &) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool try_instance(expr const & mvar, name const & inst_name) {
|
||||
if (auto decl = m_env.find(inst_name)) {
|
||||
buffer<level> ls_buffer;
|
||||
unsigned num_univ_ps = decl->get_num_univ_params();
|
||||
for (unsigned i = 0; i < num_univ_ps; i++)
|
||||
ls_buffer.push_back(mk_uvar());
|
||||
levels ls = to_list(ls_buffer.begin(), ls_buffer.end());
|
||||
expr inst_cnst = mk_constant(inst_name, ls);
|
||||
expr inst_type = instantiate_type_univ_params(*decl, ls);
|
||||
return try_instance(mvar, inst_cnst, inst_type);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
list<expr> get_local_instances(name const & cname) {
|
||||
buffer<expr> selected;
|
||||
for (pair<name, expr> const & p : m_local_instances) {
|
||||
if (p.first == cname)
|
||||
selected.push_back(p.second);
|
||||
}
|
||||
return to_list(selected);
|
||||
}
|
||||
|
||||
bool is_done() const {
|
||||
return empty(m_state.m_stack);
|
||||
}
|
||||
|
||||
expr const & next_mvar() const {
|
||||
lean_assert(!is_done());
|
||||
return head(m_state.m_stack);
|
||||
}
|
||||
|
||||
bool mk_choice_point(expr const & mvar) {
|
||||
lean_assert(is_mvar(mvar));
|
||||
expr mvar_type = instantiate_uvars_mvars(mlocal_type(mvar));
|
||||
if (has_expr_metavar(mvar_type))
|
||||
return false;
|
||||
auto cname = is_class(mvar_type);
|
||||
if (!cname)
|
||||
return false;
|
||||
choice r;
|
||||
r.m_local_instances = get_local_instances(*cname);
|
||||
if (m_trans_instances && m_choices.empty()) {
|
||||
// we only use transitive instances in the top-level
|
||||
r.m_trans_instances = get_class_derived_trans_instances(m_env, *cname);
|
||||
}
|
||||
r.m_instances = get_class_instances(m_env, *cname);
|
||||
if (empty(r.m_local_instances) && empty(r.m_trans_instances) && empty(r.m_instances))
|
||||
return false;
|
||||
r.m_state = m_state;
|
||||
m_choices.push_back(r);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool process_next_alt_core(list<expr> & insts) {
|
||||
while (!empty(insts)) {
|
||||
expr inst = head(insts);
|
||||
insts = tail(insts);
|
||||
expr inst_type = infer_type(inst);
|
||||
if (try_instance(next_mvar(), inst, inst_type))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool process_next_alt_core(list<name> & inst_names) {
|
||||
while (!empty(inst_names)) {
|
||||
name inst_name = head(inst_names);
|
||||
inst_names = tail(inst_names);
|
||||
if (try_instance(next_mvar(), inst_name))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool process_next_alt() {
|
||||
lean_assert(!m_choices.empty());
|
||||
choice & c = m_choices.back();
|
||||
if (process_next_alt_core(c.m_local_instances))
|
||||
return true;
|
||||
if (process_next_alt_core(c.m_trans_instances))
|
||||
return true;
|
||||
if (process_next_alt_core(c.m_instances))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool process_next_mvar() {
|
||||
lean_assert(!is_done());
|
||||
expr mvar = next_mvar();
|
||||
if (!mk_choice_point(mvar))
|
||||
return false;
|
||||
return process_next_alt();
|
||||
}
|
||||
|
||||
bool backtrack() {
|
||||
lean_assert(!m_choices.empty());
|
||||
while (true) {
|
||||
m_choices.pop_back();
|
||||
if (m_choices.empty())
|
||||
return false;
|
||||
m_state = m_choices.back().m_state;
|
||||
if (process_next_alt())
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
optional<expr> search() {
|
||||
// TODO(Leo):
|
||||
return none_expr();
|
||||
while (!is_done()) {
|
||||
if (!process_next_mvar()) {
|
||||
if (!backtrack())
|
||||
return none_expr();
|
||||
}
|
||||
}
|
||||
return some_expr(instantiate_uvars_mvars(m_main_mvar));
|
||||
}
|
||||
|
||||
optional<expr> operator()(environment const & env, options const & o, list<expr> const & ctx, expr const & type) {
|
||||
void init_search(expr const & type) {
|
||||
m_state = state();
|
||||
m_main_mvar = mk_mvar(type);
|
||||
m_state.m_stack = cons(m_main_mvar, m_state.m_stack);
|
||||
m_displayed_trace_header = false;
|
||||
}
|
||||
|
||||
optional<expr> operator()(environment const & env, options const & o, pos_info_provider const * pip, list<expr> const & ctx, expr const & type) {
|
||||
set_env(env);
|
||||
set_options(o);
|
||||
set_ctx(ctx);
|
||||
set_pos_info(pip, type);
|
||||
|
||||
if (auto r = check_cache(type))
|
||||
return r;
|
||||
|
||||
expr m = init_search(type);
|
||||
init_search(type);
|
||||
|
||||
if (auto r = search()) {
|
||||
cache_result(type, *r);
|
||||
|
@ -856,9 +1048,15 @@ struct cienv {
|
|||
optional<expr> next() {
|
||||
if (!m_multiple_instances)
|
||||
return none_expr();
|
||||
|
||||
// TODO(Leo): backtrack and search
|
||||
return none_expr();
|
||||
if (m_choices.empty())
|
||||
return none_expr();
|
||||
m_state = m_choices.back().m_state;
|
||||
if (process_next_alt())
|
||||
return search();
|
||||
else if (backtrack())
|
||||
return search();
|
||||
else
|
||||
return none_expr();
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -879,13 +1077,13 @@ ci_type_inference_factory_scope::~ci_type_inference_factory_scope() {
|
|||
g_factory = m_old;
|
||||
}
|
||||
|
||||
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e) {
|
||||
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider * pip) {
|
||||
flet<io_state*> set_ios(g_ios, const_cast<io_state*>(&ios));
|
||||
return get_cienv()(env, ios.get_options(), ctx, e);
|
||||
return get_cienv()(env, ios.get_options(), pip, ctx, e);
|
||||
}
|
||||
|
||||
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e) {
|
||||
return mk_class_instance(env, get_dummy_ios(), ctx, e);
|
||||
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e, pos_info_provider * pip) {
|
||||
return mk_class_instance(env, get_dummy_ios(), ctx, e, pip);
|
||||
}
|
||||
|
||||
void initialize_class_instance_resolution() {
|
||||
|
|
|
@ -30,8 +30,8 @@ public:
|
|||
~ci_type_inference_factory_scope();
|
||||
};
|
||||
|
||||
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e);
|
||||
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e);
|
||||
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider const * pip = nullptr);
|
||||
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e, pos_info_provider const * pip = nullptr);
|
||||
void initialize_class_instance_resolution();
|
||||
void finalize_class_instance_resolution();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue