feat(library/class_instance_resolution): new type class resolution procedure

This commit is contained in:
Leonardo de Moura 2015-10-17 11:24:19 -07:00
parent c69bbd4eb7
commit f5819fab60
2 changed files with 242 additions and 44 deletions

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "library/normalize.h"
#include "library/reducible.h"
#include "library/class.h"
#include "library/io_state_stream.h"
#include "library/replace_visitor.h"
#include "library/class_instance_resolution.h"
@ -93,13 +94,15 @@ struct cienv {
typedef rb_map<unsigned, expr, unsigned_cmp> eassignment;
environment m_env;
pos_info_provider const * m_pip;
optional<pos_info> m_pos;
ci_type_inference_ptr m_tc_ptr;
expr_struct_map<expr> m_cache;
name_generator m_ngen;
name_predicate m_not_reducible_pred;
list<expr> m_ctx;
buffer<expr> m_local_instances;
buffer<pair<name, expr>> m_local_instances;
unsigned m_next_uvar;
unsigned m_next_mvar;
@ -119,10 +122,13 @@ struct cienv {
state m_state;
};
list<choice> m_choices;
std::vector<choice> m_choices;
expr m_main_mvar;
bool m_multiple_instances;
bool m_displayed_trace_header;
// configuration
bool m_unique_instances;
unsigned m_max_depth;
@ -304,14 +310,14 @@ struct cienv {
// Auxiliary method for set_ctx
void set_local_instance(unsigned i, expr const & e) {
void set_local_instance(unsigned i, name const & cname, expr const & e) {
lean_assert(i <= m_local_instances.size());
if (i == m_local_instances.size()) {
reset_cache();
m_local_instances.push_back(e);
} else if (e != m_local_instances[i]) {
m_local_instances.push_back(mk_pair(cname, e));
} else if (e != m_local_instances[i].second) {
reset_cache();
m_local_instances[i] = e;
m_local_instances[i] = mk_pair(cname, e);
} else {
// we don't need to reset the cache since this local instance
// is equal to the one used in a previous call
@ -330,13 +336,19 @@ struct cienv {
// Remark: we use infer_type(e) instead of mlocal_type because we want to allow
// customers (e.g., blast) of this class to store the type of local constants
// in a different place.
if (is_class(infer_type(e))) {
set_local_instance(i, e);
if (auto cname = is_class(infer_type(e))) {
set_local_instance(i, *cname, e);
i++;
}
}
}
void set_pos_info(pos_info_provider const * pip, expr const & type) {
m_pip = pip;
if (m_pip)
m_pos = m_pip->get_pos_info(type);
}
// Create an internal universal metavariable
level mk_uvar() {
unsigned idx = m_next_uvar;
@ -823,27 +835,207 @@ struct cienv {
}
}
expr init_search(expr const & type) {
m_state = state();
expr m = mk_mvar(type);
m_state.m_stack = cons(m, m_state.m_stack);
return m;
void trace(expr const & mvar, expr const & r) {
if (!m_trace_instances)
return;
auto out = diagnostic(m_env, *g_ios);
unsigned depth = m_choices.size();
if (!m_displayed_trace_header && depth == 1) {
if (m_pip) {
if (auto fname = m_pip->get_file_name()) {
out << fname << ":";
}
if (m_pos)
out << m_pos->first << ":" << m_pos->second << ":";
}
out << " class-instance resolution trace" << endl;
m_displayed_trace_header = true;
}
for (unsigned i = 0; i < depth; i++)
out << " ";
if (depth > 1)
out << "[" << depth << "] ";
out << mvar << " : " << mlocal_type(mvar) << " := " << r << endl;
}
bool try_instance(expr const & mvar, expr const & inst, expr const & inst_type) {
try {
buffer<expr> locals;
expr mvar_type = mlocal_type(mvar);
while (true) {
mvar_type = whnf(mvar_type);
if (!is_pi(mvar_type))
break;
expr local = mk_local(binding_domain(mvar_type));
locals.push_back(local);
mvar_type = instantiate(binding_body(mvar_type), local);
}
expr type = inst_type;
expr r = inst;
buffer<expr> new_inst_mvars;
while (true) {
type = whnf(type);
if (!is_pi(type))
break;
expr new_mvar = mk_mvar(binding_domain(type));
if (binding_info(type).is_inst_implicit()) {
new_inst_mvars.push_back(new_mvar);
}
r = mk_app(r, new_mvar);
type = instantiate(binding_body(type), new_mvar);
}
trace(mvar, r);
if (!is_def_eq(mvar_type, type))
return false;
r = Fun(locals, r);
assign(mvar, r);
// copy new_inst_mvars to stack
unsigned i = new_inst_mvars.size();
while (i > 0) {
--i;
m_state.m_stack = cons(new_inst_mvars[i], m_state.m_stack);
}
return true;
} catch (exception &) {
return false;
}
}
bool try_instance(expr const & mvar, name const & inst_name) {
if (auto decl = m_env.find(inst_name)) {
buffer<level> ls_buffer;
unsigned num_univ_ps = decl->get_num_univ_params();
for (unsigned i = 0; i < num_univ_ps; i++)
ls_buffer.push_back(mk_uvar());
levels ls = to_list(ls_buffer.begin(), ls_buffer.end());
expr inst_cnst = mk_constant(inst_name, ls);
expr inst_type = instantiate_type_univ_params(*decl, ls);
return try_instance(mvar, inst_cnst, inst_type);
} else {
return false;
}
}
list<expr> get_local_instances(name const & cname) {
buffer<expr> selected;
for (pair<name, expr> const & p : m_local_instances) {
if (p.first == cname)
selected.push_back(p.second);
}
return to_list(selected);
}
bool is_done() const {
return empty(m_state.m_stack);
}
expr const & next_mvar() const {
lean_assert(!is_done());
return head(m_state.m_stack);
}
bool mk_choice_point(expr const & mvar) {
lean_assert(is_mvar(mvar));
expr mvar_type = instantiate_uvars_mvars(mlocal_type(mvar));
if (has_expr_metavar(mvar_type))
return false;
auto cname = is_class(mvar_type);
if (!cname)
return false;
choice r;
r.m_local_instances = get_local_instances(*cname);
if (m_trans_instances && m_choices.empty()) {
// we only use transitive instances in the top-level
r.m_trans_instances = get_class_derived_trans_instances(m_env, *cname);
}
r.m_instances = get_class_instances(m_env, *cname);
if (empty(r.m_local_instances) && empty(r.m_trans_instances) && empty(r.m_instances))
return false;
r.m_state = m_state;
m_choices.push_back(r);
return true;
}
bool process_next_alt_core(list<expr> & insts) {
while (!empty(insts)) {
expr inst = head(insts);
insts = tail(insts);
expr inst_type = infer_type(inst);
if (try_instance(next_mvar(), inst, inst_type))
return true;
}
return false;
}
bool process_next_alt_core(list<name> & inst_names) {
while (!empty(inst_names)) {
name inst_name = head(inst_names);
inst_names = tail(inst_names);
if (try_instance(next_mvar(), inst_name))
return true;
}
return false;
}
bool process_next_alt() {
lean_assert(!m_choices.empty());
choice & c = m_choices.back();
if (process_next_alt_core(c.m_local_instances))
return true;
if (process_next_alt_core(c.m_trans_instances))
return true;
if (process_next_alt_core(c.m_instances))
return true;
return false;
}
bool process_next_mvar() {
lean_assert(!is_done());
expr mvar = next_mvar();
if (!mk_choice_point(mvar))
return false;
return process_next_alt();
}
bool backtrack() {
lean_assert(!m_choices.empty());
while (true) {
m_choices.pop_back();
if (m_choices.empty())
return false;
m_state = m_choices.back().m_state;
if (process_next_alt())
return true;
}
}
optional<expr> search() {
// TODO(Leo):
while (!is_done()) {
if (!process_next_mvar()) {
if (!backtrack())
return none_expr();
}
}
return some_expr(instantiate_uvars_mvars(m_main_mvar));
}
optional<expr> operator()(environment const & env, options const & o, list<expr> const & ctx, expr const & type) {
void init_search(expr const & type) {
m_state = state();
m_main_mvar = mk_mvar(type);
m_state.m_stack = cons(m_main_mvar, m_state.m_stack);
m_displayed_trace_header = false;
}
optional<expr> operator()(environment const & env, options const & o, pos_info_provider const * pip, list<expr> const & ctx, expr const & type) {
set_env(env);
set_options(o);
set_ctx(ctx);
set_pos_info(pip, type);
if (auto r = check_cache(type))
return r;
expr m = init_search(type);
init_search(type);
if (auto r = search()) {
cache_result(type, *r);
@ -856,8 +1048,14 @@ struct cienv {
optional<expr> next() {
if (!m_multiple_instances)
return none_expr();
// TODO(Leo): backtrack and search
if (m_choices.empty())
return none_expr();
m_state = m_choices.back().m_state;
if (process_next_alt())
return search();
else if (backtrack())
return search();
else
return none_expr();
}
};
@ -879,13 +1077,13 @@ ci_type_inference_factory_scope::~ci_type_inference_factory_scope() {
g_factory = m_old;
}
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e) {
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider * pip) {
flet<io_state*> set_ios(g_ios, const_cast<io_state*>(&ios));
return get_cienv()(env, ios.get_options(), ctx, e);
return get_cienv()(env, ios.get_options(), pip, ctx, e);
}
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e) {
return mk_class_instance(env, get_dummy_ios(), ctx, e);
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e, pos_info_provider * pip) {
return mk_class_instance(env, get_dummy_ios(), ctx, e, pip);
}
void initialize_class_instance_resolution() {

View file

@ -30,8 +30,8 @@ public:
~ci_type_inference_factory_scope();
};
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e);
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e);
optional<expr> mk_class_instance(environment const & env, io_state const & ios, list<expr> const & ctx, expr const & e, pos_info_provider const * pip = nullptr);
optional<expr> mk_class_instance(environment const & env, list<expr> const & ctx, expr const & e, pos_info_provider const * pip = nullptr);
void initialize_class_instance_resolution();
void finalize_class_instance_resolution();
}