diff --git a/src/library/class_instance_resolution.cpp b/src/library/class_instance_resolution.cpp index 3f985745c..fc0d3dd6c 100644 --- a/src/library/class_instance_resolution.cpp +++ b/src/library/class_instance_resolution.cpp @@ -160,6 +160,67 @@ struct cienv { virtual void pop() { m_cienv.m_state = m_stack.back(); m_stack.pop_back(); } virtual void commit() { m_stack.pop_back(); } + static bool has_meta_arg(expr e) { + while (is_app(e)) { + if (is_meta(app_arg(e))) + return true; + e = app_fn(e); + } + return false; + } + + /** IF \c e is of the form (f ... (?m t_1 ... t_n) ...) where ?m is an unassigned + metavariable whose type is a type class, and (?m t_1 ... t_n) must be synthesized + by type class resolution, then we return ?m. + Otherwise, we return none */ + optional find_unsynth_metavar(expr const & e) { + if (!has_meta_arg(e)) + return none_expr(); + buffer args; + expr const & fn = get_app_args(e, args); + expr type = m_cienv.infer_type(fn); + unsigned i = 0; + while (i < args.size()) { + type = m_cienv.whnf(type); + if (!is_pi(type)) + return none_expr(); + expr const & arg = args[i]; + if (binding_info(type).is_inst_implicit() && is_meta(arg)) { + expr const & m = get_app_fn(arg); + if (is_mvar(m)) { + expr m_type = instantiate_uvars_mvars(infer_metavar(m)); + if (!has_expr_metavar_relaxed(m_type)) { + return some_expr(m); + } + } + } + type = instantiate(binding_body(type), arg); + i++; + } + return none_expr(); + } + + bool mk_instance(expr const & m) { + lean_assert(m); + std::cout << "\n\nFOUND CANDIDATE: " << m << "\n\n"; + // TODO(Leo) + return false; + } + + virtual bool on_is_def_eq_failure(expr const & e1, expr const & e2) { + if (is_app(e1) && is_app(e2)) { + if (auto m1 = find_unsynth_metavar(e1)) { + if (mk_instance(*m1)) + return true; + } + if (auto m2 = find_unsynth_metavar(e2)) { + if (mk_instance(*m2)) + return true; + } + } + return false; + } + virtual bool ignore_universe_def_eq(level const & l1, level const & l2) const { if (is_meta(l1) || is_meta(l2)) { // The unifier may invoke this module before universe metavariables in the class diff --git a/src/library/type_inference.cpp b/src/library/type_inference.cpp index 0bb5270ee..a714ffdd8 100644 --- a/src/library/type_inference.cpp +++ b/src/library/type_inference.cpp @@ -498,14 +498,18 @@ bool type_inference::is_def_eq_binding(expr e1, expr e2) { instantiate_rev(e2, subst.size(), subst.data())); } -bool type_inference::is_def_eq_args(expr t, expr s) { - while (is_app(t) && is_app(s)) { - if (!is_def_eq_core(app_arg(t), app_arg(s))) +bool type_inference::is_def_eq_args(expr const & e1, expr const & e2) { + lean_assert(is_app(e1) && is_app(e2)); + buffer args1, args2; + get_app_args(e1, args1); + get_app_args(e2, args2); + if (args1.size() != args2.size()) + return false; + for (unsigned i = 0; i < args1.size(); i++) { + if (!is_def_eq_core(args1[i], args2[i])) return false; - t = app_fn(t); - s = app_fn(s); } - return !is_app(t) && !is_app(s); + return true; } bool type_inference::is_def_eq_eta(expr const & e1, expr const & e2) { @@ -745,7 +749,10 @@ bool type_inference::is_def_eq_core(expr const & t, expr const & s) { if (is_def_eq_proof_irrel(t_n, s_n)) return true; - return false; + if (on_is_def_eq_failure(t_n, s_n)) + return is_def_eq_core(t_n, s_n); + else + return false; } bool type_inference::is_def_eq(expr const & e1, expr const & e2) { diff --git a/src/library/type_inference.h b/src/library/type_inference.h index ee9de4069..b62df50da 100644 --- a/src/library/type_inference.h +++ b/src/library/type_inference.h @@ -42,7 +42,7 @@ class type_inference { lbool quick_is_def_eq(expr const & e1, expr const & e2); bool is_def_eq_core(expr const & e1, expr const & e2); - bool is_def_eq_args(expr t, expr s); + bool is_def_eq_args(expr const & e1, expr const & e2); bool is_def_eq_binding(expr e1, expr e2); bool is_def_eq_eta(expr const & e1, expr const & e2); bool is_def_eq_proof_irrel(expr const & e1, expr const & e2); @@ -141,6 +141,11 @@ public: /** \brief Keep the changes since last push */ virtual void commit() = 0; + /** \brief This method is invoked before failure. + The "customer" may try to assign unassigned mvars in the given expression. + The result is true to indicate that some metavariable has been assigned. */ + virtual bool on_is_def_eq_failure(expr const &, expr const &) { return false; } + bool is_assigned(level const & u) const { return get_assignment(u) != nullptr; } bool is_assigned(expr const & m) const { return get_assignment(m) != nullptr; }