diff --git a/src/library/elaborator/elaborator.cpp b/src/library/elaborator/elaborator.cpp index e74c1e48d..b1da9d9c9 100644 --- a/src/library/elaborator/elaborator.cpp +++ b/src/library/elaborator/elaborator.cpp @@ -82,13 +82,13 @@ class elaborator::imp { struct ho_match_case_split : public case_split { unification_constraint m_constraint; unsigned m_idx; // current alternative - std::vector m_states; // set of alternatives + std::vector m_states; // alternatives + std::vector m_assumptions; // assumption for each alternative - ho_match_case_split(unification_constraint const & cnstr, unsigned num_states, state const * states, state const & prev_state): + ho_match_case_split(unification_constraint const & cnstr, state const & prev_state): case_split(prev_state), m_constraint(cnstr), - m_idx(0), - m_states(states, states + num_states) { + m_idx(0) { } virtual ~ho_match_case_split() {} @@ -96,6 +96,11 @@ class elaborator::imp { virtual bool next(imp & owner) { return owner.next_ho_case(*this); } + + void push_back(state const & s, trace const & tr) { + m_states.push_back(s); + m_assumptions.push_back(tr); + } }; struct synthesizer_case_split : public case_split { @@ -259,7 +264,36 @@ class elaborator::imp { } /** - \brief Auxiliary method for pushing a new constraint to the constraint queue. + \brief Return (f x_{num_vars - 1} ... x_0) + */ + expr mk_app_vars(expr const & f, unsigned num_vars) { + buffer args; + args.push_back(f); + unsigned i = num_vars; + while (i > 0) { + --i; + args.push_back(mk_var(i)); + } + return mk_app(args.size(), args.data()); + } + + /** + \brief Auxiliary method for pushing a new constraint to the given constraint queue. + If \c is_eq is true, then a equality constraint is created, otherwise a convertability constraint is created. + */ + void push_new_constraint(cnstr_queue & q, bool is_eq, context const & new_ctx, expr const & new_a, expr const & new_b, trace const & new_tr) { + if (is_eq) + q.push_front(mk_eq_constraint(new_ctx, new_a, new_b, new_tr)); + else + q.push_front(mk_convertible_constraint(new_ctx, new_a, new_b, new_tr)); + } + + void push_new_eq_constraint(cnstr_queue & q, context const & new_ctx, expr const & new_a, expr const & new_b, trace const & new_tr) { + push_new_constraint(q, true, new_ctx, new_a, new_b, new_tr); + } + + /** + \brief Auxiliary method for pushing a new constraint to the current constraint queue. The new constraint is based on the constraint \c c. The constraint \c c may be a equality or convertability constraint. The update is justified by \c new_tr. */ @@ -273,7 +307,7 @@ class elaborator::imp { } /** - \brief Auxiliary method for pushing a new constraint to the constraint queue. + \brief Auxiliary method for pushing a new constraint to the current constraint queue. The new constraint is based on the constraint \c c. The constraint \c c may be a equality or convertability constraint. The flag \c is_lhs says if the left-hand-side or right-hand-side are being updated with \c new_a. The update is justified by \c new_tr. @@ -608,6 +642,98 @@ class elaborator::imp { } } + /** + \brief Process a constraint ctx |- a = b where \c a is of the form (?m ...). + We perform a "case split" using "projection" or "imitation". See Huet&Lang's paper on higher order matching + for further details. + */ + bool process_meta_app(expr const & a, expr const & b, bool is_lhs, unification_constraint const & c) { + if (is_meta_app(a) && !has_local_context(arg(a, 0)) && !is_meta_app(b)) { + context const & ctx = get_context(c); + metavar_env & menv = m_state.m_menv; + expr f_a = arg(a, 0); + lean_assert(is_metavar(f_a)); + unsigned num_a = num_args(a); + buffer arg_types; + buffer ucs; + for (unsigned i = 1; i < num_a; i++) { + arg_types.push_back(m_type_inferer(arg(a, i), ctx, &menv, ucs)); + for (auto uc : ucs) + push_front(uc); + } + std::unique_ptr new_cs(new ho_match_case_split(c, m_state)); + // Add projections + for (unsigned i = 1; i < num_a; i++) { + // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), x_i + state new_state(m_state); + trace new_assumption = mk_assumption(); + expr proj = mk_lambda(arg_types, mk_var(num_a - i - 1)); + expr new_a = arg(a, i); + expr new_b = b; + if (is_lhs) + swap(new_a, new_b); + push_new_constraint(new_state.m_queue, is_eq(c), ctx, new_a, new_b, new_assumption); + push_new_eq_constraint(new_state.m_queue, ctx, f_a, proj, new_assumption); + new_cs->push_back(new_state, new_assumption); + } + // Add imitation + state new_state(m_state); + trace new_assumption = mk_assumption(); + expr imitation; + if (is_app(b)) { + // Imitation for applications + expr f_b = arg(b, 0); + unsigned num_b = num_args(b); + // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), f_b (h_1 x_1 ... x_{num_a-1}) ... (h_{num_b-1} x_1 ... x_{num_a-1}) + // New constraints (h_i a_1 ... a_{num_a-1}) == arg(b, i) + buffer imitation_args; // arguments for the imitation + imitation_args.push_back(f_b); + for (unsigned i = 1; i < num_b; i++) { + expr h_i = new_state.m_menv.mk_metavar(ctx); + imitation_args.push_back(mk_app_vars(h_i, num_a - 1)); + push_new_eq_constraint(new_state.m_queue, ctx, update_app(a, 0, h_i), arg(b, i), new_assumption); + } + imitation = mk_lambda(arg_types, mk_app(imitation_args.size(), imitation_args.data())); + } else if (is_eq(b)) { + // Imitation for equality + // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), (h_1 x_1 ... x_{num_a-1}) = (h_2 x_1 ... x_{num_a-1}) + // New constraints (h_1 a_1 ... a_{num_a-1}) == eq_lhs(b) + // (h_2 a_1 ... a_{num_a-1}) == eq_rhs(b) + expr h_1 = new_state.m_menv.mk_metavar(ctx); + expr h_2 = new_state.m_menv.mk_metavar(ctx); + push_new_eq_constraint(new_state.m_queue, ctx, update_app(a, 0, h_1), eq_lhs(b), new_assumption); + push_new_eq_constraint(new_state.m_queue, ctx, update_app(a, 0, h_2), eq_rhs(b), new_assumption); + imitation = mk_lambda(arg_types, mk_eq(mk_app_vars(h_1, num_a - 1), mk_app_vars(h_2, num_a - 1))); + } else if (is_abstraction(b)) { + // Imitation for lambdas and Pis + // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), + // fun (x_b : (?h_1 x_1 ... x_{num_a-1})), (?h_2 x_1 ... x_{num_a-1} x_b) + // New constraints (h_1 a_1 ... a_{num_a-1}) == abst_domain(b) + // (h_2 a_1 ... a_{num_a-1} x_b) == abst_body(b) + expr h_1 = new_state.m_menv.mk_metavar(ctx); + expr h_2 = new_state.m_menv.mk_metavar(ctx); + push_new_eq_constraint(new_state.m_queue, ctx, update_app(a, 0, h_1), abst_domain(b), new_assumption); + push_new_eq_constraint(new_state.m_queue, extend(ctx, abst_name(b), abst_domain(b)), + mk_app(update_app(a, 0, h_2), Var(0)), abst_body(b), new_assumption); + imitation = mk_lambda(arg_types, update_abstraction(b, mk_app_vars(h_1, num_a - 1), mk_app_vars(h_2, num_a))); + } else { + // "Dumb imitation" aka the constant function + // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), b + imitation = mk_lambda(arg_types, lift_free_vars(b, 0, num_a - 1)); + } + lean_assert(imitation); + push_new_eq_constraint(new_state.m_queue, ctx, f_a, imitation, new_assumption); + new_cs->push_back(new_state, new_assumption); + bool r = new_cs->next(*this); + lean_assert(r); + m_case_splits.push_back(std::move(new_cs)); + reset_quota(); + return r; + } else { + return false; + } + } + bool process_eq_convertible(context const & ctx, expr const & a, expr const & b, unification_constraint const & c) { bool eq = is_eq(c); if (a == b) { @@ -701,6 +827,12 @@ class elaborator::imp { return true; } + if (m_quota < 0) { + // process expensive cases + if (process_meta_app(a, b, true, c) || process_meta_app(b, a, false, c)) + return true; + } + std::cout << "Postponed: "; display(std::cout, c); push_back(c); @@ -759,28 +891,18 @@ class elaborator::imp { } } - bool next_ho_case(ho_match_case_split &) { -#if 0 - unification_constraint & cnstr = s.m_constraint; - context const & ctx = get_context(cnstr); - expr const & a = eq_lhs(cnstr); - expr const & b = eq_rhs(cnstr); - lean_assert(is_meta_app(a)); - lean_assert(!has_local_context(arg(a, 0))); - lean_assert(!is_meta_app(b)); - expr f_a = arg(a, 0); - lean_assert(is_metavar(f_a)); - unsigned num_a = num_args(a); - - - - // unification_constraints_wrapper ucw; - buffer arg_types; - for (unsigned i = 1; i < num_a; i++) { - arg_types.push_back(m_type_inferer(arg(a, i), ctx, &s, &ucw)); + bool next_ho_case(ho_match_case_split & s) { + unsigned idx = s.m_idx; + unsigned sz = s.m_states.size(); + if (idx < sz) { + s.m_idx++; + s.m_curr_assumption = s.m_assumptions[sz - idx - 1]; + m_state = s.m_states[sz - idx - 1]; + return true; + } else { + m_conflict = trace(new unification_failure_by_cases_trace(s.m_constraint, s.m_failed_traces.size(), s.m_failed_traces.data())); + return false; } -#endif - return true; } bool next_plugin_case(plugin_case_split & s) { @@ -831,7 +953,7 @@ public: while (true) { check_interrupted(m_interrupted); cnstr_queue & q = m_state.m_queue; - if (q.empty() || m_quota < -static_cast(q.size())) { + if (q.empty() || m_quota < - static_cast(q.size()) - 10) { name m = find_unassigned_metavar(); std::cout << "Queue is empty\n"; display(std::cout); if (m) { @@ -843,7 +965,7 @@ public: } } else { unification_constraint c = q.front(); - std::cout << "Processing, quota: " << m_quota << " "; display(std::cout, c); + std::cout << "Processing, quota: " << m_quota << ", depth: " << m_case_splits.size() << " "; display(std::cout, c); q.pop_front(); if (!process(c)) { resolve_conflict(); diff --git a/src/tests/library/elaborator/elaborator.cpp b/src/tests/library/elaborator/elaborator.cpp index 4ad8658f5..f1658c27c 100644 --- a/src/tests/library/elaborator/elaborator.cpp +++ b/src/tests/library/elaborator/elaborator.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "kernel/environment.h" #include "kernel/type_checker.h" #include "kernel/abstract.h" +#include "library/reduce.h" #include "library/arith/arith.h" #include "library/all/all.h" #include "library/elaborator/elaborator.h" @@ -258,6 +259,7 @@ static void tst6() { ucs.push_back(mk_eq_constraint(context(), expected, given, trace())); elaborator elb(env, menv, ucs.size(), ucs.data()); substitution s = elb.next(); + std::cout << beta_reduce(instantiate_metavars(V, s)) << "\n"; } int main() {