diff --git a/src/library/elaborator/elaborator.cpp b/src/library/elaborator/elaborator.cpp index 7e4dcf0d0..0529bf23d 100644 --- a/src/library/elaborator/elaborator.cpp +++ b/src/library/elaborator/elaborator.cpp @@ -296,6 +296,15 @@ class elaborator::imp { push_new_constraint(q, true, new_ctx, new_a, new_b, new_tr); } + /** + \brief Auxiliary method for pushing a new constraint to the current constraint queue. + If \c is_eq is true, then a equality constraint is created, otherwise a convertability constraint is created. + */ + void push_new_constraint(bool is_eq, context const & new_ctx, expr const & new_a, expr const & new_b, trace const & new_tr) { + reset_quota(); + push_new_constraint(m_state.m_queue, is_eq, new_ctx, new_a, new_b, new_tr); + } + /** \brief Auxiliary method for pushing a new constraint to the current constraint queue. The new constraint is based on the constraint \c c. The constraint \c c may be a equality or convertability constraint. @@ -405,15 +414,23 @@ class elaborator::imp { } } else { local_entry const & me = head(metavar_lctx(a)); - if (me.is_lift() && !has_free_var(b, me.s(), me.s() + me.n())) { - // Case 3 - trace new_tr(new normalize_trace(c)); - expr new_a = pop_meta_context(a); - expr new_b = lower_free_vars(b, me.s() + me.n(), me.n()); - if (!is_lhs) - swap(new_a, new_b); - push_updated_constraint(c, new_a, new_b, new_tr); - return Processed; + if (me.is_lift()) { + if (!has_free_var(b, me.s(), me.s() + me.n())) { + // Case 3 + trace new_tr(new normalize_trace(c)); + expr new_a = pop_meta_context(a); + expr new_b = lower_free_vars(b, me.s() + me.n(), me.n()); + context new_ctx = get_context(c).remove(me.s(), me.n()); + if (!is_lhs) + swap(new_a, new_b); + push_new_constraint(is_eq(c), new_ctx, new_a, new_b, new_tr); + return Processed; + } else if (is_var(b)) { + // Failure, there is no way to unify + // ?m[lift:s:n, ...] with a variable in [s, s+n] + m_conflict = trace(new unification_failure_trace(c)); + return Failed; + } } } } @@ -691,7 +708,7 @@ class elaborator::imp { expr proj = mk_lambda(arg_types, mk_var(num_a - i - 1)); expr new_a = arg(a, i); expr new_b = b; - if (is_lhs) + if (!is_lhs) swap(new_a, new_b); push_new_constraint(new_state.m_queue, is_eq(c), ctx, new_a, new_b, new_assumption); push_new_eq_constraint(new_state.m_queue, ctx, f_a, proj, new_assumption); @@ -755,6 +772,82 @@ class elaborator::imp { } } + /** \brief Return true if \c a is of the form ?m[inst:i t, ...] */ + bool is_metavar_inst(expr const & a) const { + return is_metavar(a) && has_local_context(a) && head(metavar_lctx(a)).is_inst(); + } + + /** + \brief Process a constraint ctx |- a == b where \c a is of the form ?m[(inst:i t), ...]. + We perform a "case split", + Case 1) ?m[...] == #i and t == b + Case 2) imitate b + */ + bool process_metavar_inst(expr const & a, expr const & b, bool is_lhs, unification_constraint const & c) { + if (is_metavar_inst(a) && !is_metavar_inst(b) && !is_meta_app(b)) { + context const & ctx = get_context(c); + local_context lctx = metavar_lctx(a); + unsigned i = head(lctx).s(); + expr t = head(lctx).v(); + std::unique_ptr new_cs(new generic_case_split(c, m_state)); + { + // Case 1 + state new_state(m_state); + trace new_assumption = mk_assumption(); + // add ?m[...] == #1 + push_new_eq_constraint(new_state.m_queue, ctx, pop_meta_context(a), mk_var(i), new_assumption); + // add t == b (t << b) + expr new_a = t; + expr new_b = b; + if (!is_lhs) + swap(new_a, new_b); + push_new_constraint(new_state.m_queue, is_eq(c), ctx, new_a, new_b, new_assumption); + new_cs->push_back(new_state, new_assumption); + } + { + // Case 2 + state new_state(m_state); + trace new_assumption = mk_assumption(); + expr imitation; + if (is_app(b)) { + // Imitation for applications b == f(s_1, ..., s_k) + // mname <- f(?h_1, ..., ?h_k) + expr f_b = arg(b, 0); + unsigned num_b = num_args(b); + buffer imitation_args; + imitation_args.push_back(f_b); + for (unsigned i = 1; i < num_b; i++) + imitation_args.push_back(new_state.m_menv.mk_metavar(ctx)); + imitation = mk_app(imitation_args.size(), imitation_args.data()); + } else if (is_eq(b)) { + // Imitation for equality b == Eq(s1, s2) + // mname <- Eq(?h_1, ?h_2) + expr h_1 = new_state.m_menv.mk_metavar(ctx); + expr h_2 = new_state.m_menv.mk_metavar(ctx); + imitation = mk_eq(h_1, h_2); + } else if (is_abstraction(b)) { + // Lambdas and Pis + // Imitation for Lambdas and Pis, b == Fun(x:T) B + // mname <- Fun (x:?h_1) ?h_2 x) + expr h_1 = new_state.m_menv.mk_metavar(ctx); + expr h_2 = new_state.m_menv.mk_metavar(ctx); + imitation = update_abstraction(b, h_1, mk_app(h_2, Var(0))); + } else { + imitation = lift_free_vars(b, i, 1); + } + push_new_eq_constraint(new_state.m_queue, ctx, pop_meta_context(a), imitation, new_assumption); + new_cs->push_back(new_state, new_assumption); + } + bool r = new_cs->next(*this); + lean_assert(r); + m_case_splits.push_back(std::move(new_cs)); + reset_quota(); + return r; + } else { + return false; + } + } + /** \brief Process constraint of the form ctx |- a << ?m, where \c a is Type of Bool */ bool process_lower(expr const & a, expr const & b, unification_constraint const & c) { if (is_convertible(c) && is_metavar(b) && (a == Bool || is_type(a))) { @@ -776,24 +869,6 @@ class elaborator::imp { } } - /** - \brief Process a constraints of the form: - - true == (t1 = t2) - - true << (t1 = t2) - - \remark This method should be removed if we remove T == T ==> true normalization rule from the - kernel. - */ - bool process_true_eq(expr const & a, expr const & b, unification_constraint const & c) { - if (a == True && is_eq(b)) { - trace new_tr(new normalize_trace(c)); - push_front(mk_eq_constraint(get_context(c), eq_lhs(b), eq_rhs(b), new_tr)); - return true; - } else { - return false; - } - } - bool process_eq_convertible(context const & ctx, expr const & a, expr const & b, unification_constraint const & c) { bool eq = is_eq(c); if (a == b) { @@ -818,8 +893,7 @@ class elaborator::imp { process_simple_ho_match(ctx, b, a, false, c)) return true; - if (process_true_eq(a, b, c) || - process_true_eq(b, a, c)) + if (!eq && a == Bool && is_type(b)) return true; if (a.kind() == b.kind()) { @@ -895,6 +969,8 @@ class elaborator::imp { // process expensive cases if (process_meta_app(a, b, true, c) || process_meta_app(b, a, false, c)) return true; + if (process_metavar_inst(a, b, true, c) || process_metavar_inst(b, a, false, c)) + return true; } if (m_quota < - static_cast(m_state.m_queue.size())) { @@ -903,7 +979,7 @@ class elaborator::imp { return true; } - std::cout << "Postponed: "; display(std::cout, c); + // std::cout << "Postponed: "; display(std::cout, c); push_back(c); return true; @@ -932,7 +1008,7 @@ class elaborator::imp { while (!m_case_splits.empty()) { std::unique_ptr & d = m_case_splits.back(); - std::cout << "Assumption " << d->m_curr_assumption.pp(fmt, options(), nullptr, true) << "\n"; + // std::cout << "Assumption " << d->m_curr_assumption.pp(fmt, options(), nullptr, true) << "\n"; if (depends_on(m_conflict, d->m_curr_assumption)) { d->m_failed_traces.push_back(m_conflict); if (d->next(*this)) { @@ -1035,7 +1111,7 @@ public: } } else { unification_constraint c = q.front(); - std::cout << "Processing, quota: " << m_quota << ", depth: " << m_case_splits.size() << " "; display(std::cout, c); + // std::cout << "Processing, quota: " << m_quota << ", depth: " << m_case_splits.size() << " "; display(std::cout, c); q.pop_front(); if (!process(c)) { resolve_conflict(); diff --git a/src/tests/library/elaborator/elaborator.cpp b/src/tests/library/elaborator/elaborator.cpp index 0c50ac0e5..d95a1d1a2 100644 --- a/src/tests/library/elaborator/elaborator.cpp +++ b/src/tests/library/elaborator/elaborator.cpp @@ -280,12 +280,29 @@ static expr elaborate(expr const & e, environment const & env) { // Check elaborator success static void success(expr const & e, expr const & expected, environment const & env) { - std::cout << "\n" << e << "\n------>\n"; + std::cout << "\n" << e << "\n\n"; expr r = elaborate(e, env); - std::cout << r << "\n"; + std::cout << "\n" << e << "\n------>\n" << r << "\n"; lean_assert(r == expected); } +// Check elaborator failure +static void fails(expr const & e, environment const & env) { + try { + expr new_e = elaborate(e, env); + std::cout << "new_e: " << new_e << std::endl; + lean_unreachable(); + } catch (exception &) { + } +} + +// Check elaborator partial success (i.e., result still contain some metavariables */ +static void unsolved(expr const & e, environment const & env) { + expr r = elaborate(e, env); + std::cout << "\n" << e << "\n------>\n" << r << "\n"; + lean_assert(has_metavar(r)); +} + static void tst7() { environment env; import_all(env); @@ -362,6 +379,243 @@ static void tst9() { success(Refl(_, a), Refl(Nat, a), env); } +static void tst10() { + environment env; + import_all(env); + expr Nat = Const("N"); + env.add_var("N", Type()); + expr R = Const("R"); + env.add_var("R", Type()); + env.add_var("a", Nat); + expr a = Const("a"); + expr f = Const("f"); + env.add_var("f", Nat >> ((R >> Nat) >> R)); + expr x = Const("x"); + expr y = Const("y"); + expr z = Const("z"); + success(Fun({{x, _}, {y, _}}, f(x, y)), + Fun({{x, Nat}, {y, R >> Nat}}, f(x, y)), env); + success(Fun({{x, _}, {y, _}, {z, _}}, Eq(f(x, y), f(x, z))), + Fun({{x, Nat}, {y, R >> Nat}, {z, R >> Nat}}, Eq(f(x, y), f(x, z))), env); + expr A = Const("A"); + success(Fun({{A, Type()}, {x, _}, {y, _}, {z, _}}, Eq(f(x, y), f(x, z))), + Fun({{A, Type()}, {x, Nat}, {y, R >> Nat}, {z, R >> Nat}}, Eq(f(x, y), f(x, z))), env); +} + +static void tst11() { + environment env; + import_all(env); + expr A = Const("A"); + expr B = Const("B"); + expr a = Const("a"); + expr b = Const("b"); + expr f = Const("f"); + expr g = Const("g"); + expr Nat = Const("N"); + env.add_var("N", Type()); + env.add_var("f", Pi({{A, Type()}, {a, A}, {b, A}}, A)); + env.add_var("g", Nat >> Nat); + success(Fun({{a, _}, {b, _}}, g(f(_, a, b))), + Fun({{a, Nat}, {b, Nat}}, g(f(Nat, a, b))), env); +} + +static void tst12() { + environment env; + import_all(env); + expr lst = Const("list"); + expr nil = Const("nil"); + expr cons = Const("cons"); + expr N = Const("N"); + expr A = Const("A"); + expr f = Const("f"); + expr l = Const("l"); + expr a = Const("a"); + env.add_var("N", Type()); + env.add_var("list", Type() >> Type()); + env.add_var("nil", Pi({A, Type()}, lst(A))); + env.add_var("cons", Pi({{A, Type()}, {a, A}, {l, lst(A)}}, lst(A))); + env.add_var("f", lst(N >> N) >> Bool); + success(Fun({a, _}, f(cons(_, a, cons(_, a, nil(_))))), + Fun({a, N >> N}, f(cons(N >> N, a, cons(N >> N, a, nil(N >> N))))), env); +} + +static void tst13() { + environment env; + import_all(env); + expr B = Const("B"); + expr A = Const("A"); + expr x = Const("x"); + expr f = Const("f"); + env.add_var("f", Pi({B, Type()}, B >> B)); + success(Fun({{A, Type()}, {B, Type()}, {x, _}}, f(B, x)), + Fun({{A, Type()}, {B, Type()}, {x, B}}, f(B, x)), env); + fails(Fun({{x, _}, {A, Type()}}, f(A, x)), env); + success(Fun({{A, Type()}, {x, _}}, f(A, x)), + Fun({{A, Type()}, {x, A}}, f(A, x)), env); + success(Fun({{A, Type()}, {B, Type()}, {x, _}}, f(A, x)), + Fun({{A, Type()}, {B, Type()}, {x, A}}, f(A, x)), env); + success(Fun({{A, Type()}, {B, Type()}, {x, _}}, Eq(f(B, x), f(_, x))), + Fun({{A, Type()}, {B, Type()}, {x, B}}, Eq(f(B, x), f(B, x))), env); + success(Fun({{A, Type()}, {B, Type()}, {x, _}}, Eq(f(B, x), f(_, x))), + Fun({{A, Type()}, {B, Type()}, {x, B}}, Eq(f(B, x), f(B, x))), env); + unsolved(Fun({{A, _}, {B, _}, {x, _}}, Eq(f(B, x), f(_, x))), env); +} + +static void tst14() { + environment env; + import_all(env); + expr A = Const("A"); + expr B = Const("B"); + expr f = Const("f"); + expr g = Const("g"); + expr x = Const("x"); + expr y = Const("y"); + env.add_var("N", Type()); + env.add_var("f", Pi({A, Type()}, A >> A)); + expr N = Const("N"); + success(Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, g(_, True, False)), + Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, g(Bool, True, False)), + env); + success(Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(_, Bool, Bool)), + Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(Type(), Bool, Bool)), + env); + success(Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(_, Bool, N)), + Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(Type(), Bool, N)), + env); + success(Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, + g(_, + Fun({{x, _}, {y, _}}, Eq(f(_, x), f(_, y))), + Fun({{x, N}, {y, Bool}}, True))), + Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, + g((N >> (Bool >> Bool)), + Fun({{x, N}, {y, Bool}}, Eq(f(N, x), f(Bool, y))), + Fun({{x, N}, {y, Bool}}, True))), env); + + success(Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, + g(_, + Fun({{x, N}, {y, _}}, Eq(f(_, x), f(_, y))), + Fun({{x, _}, {y, Bool}}, True))), + Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, + g((N >> (Bool >> Bool)), + Fun({{x, N}, {y, Bool}}, Eq(f(N, x), f(Bool, y))), + Fun({{x, N}, {y, Bool}}, True))), env); +} + +static void tst15() { + environment env; + import_all(env); + expr A = Const("A"); + expr B = Const("B"); + expr C = Const("C"); + expr a = Const("a"); + expr b = Const("b"); + expr eq = Const("eq"); + env.add_var("eq", Pi({A, Type()}, A >> (A >> Bool))); + success(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, B}}, eq(_, a, b)), + Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}}, eq(B, a, b)), env); + success(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, A}}, eq(_, a, b)), + Fun({{A, Type()}, {B, Type()}, {a, A}, {b, A}}, eq(A, a, b)), env); + success(Fun({{A, Type()}, {B, Type()}, {a, A}, {b, _}}, eq(_, a, b)), + Fun({{A, Type()}, {B, Type()}, {a, A}, {b, A}}, eq(A, a, b)), env); + success(Fun({{A, Type()}, {B, Type()}, {a, B}, {b, _}}, eq(_, a, b)), + Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}}, eq(B, a, b)), env); + success(Fun({{A, Type()}, {B, Type()}, {a, B}, {b, _}, {C, Type()}}, eq(_, a, b)), + Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}, {C, Type()}}, eq(B, a, b)), env); + fails(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, _}, {C, Type()}}, eq(C, a, b)), env); + success(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, _}, {C, Type()}}, eq(B, a, b)), + Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}, {C, Type()}}, eq(B, a, b)), env); +} + +static void tst16() { + environment env; + import_all(env); + expr a = Const("a"); + expr b = Const("b"); + expr c = Const("c"); + expr H1 = Const("H1"); + expr H2 = Const("H2"); + env.add_var("a", Bool); + env.add_var("b", Bool); + env.add_var("c", Bool); + success(Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}}, + Trans(_, _, _, _, H1, H2)), + Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}}, + Trans(Bool, a, b, c, H1, H2)), + env); + expr H3 = Const("H3"); + success(Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}}, + EqTIntro(_, EqMP(_, _, Symm(_, _, _, Trans(_, _, _, _, Symm(_, _, _, H2), Symm(_, _, _, H1))), H3))), + Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}}, + EqTIntro(c, EqMP(a, c, Symm(Bool, c, a, Trans(Bool, c, b, a, Symm(Bool, b, c, H2), Symm(Bool, a, b, H1))), H3))), + env); + environment env2; + import_all(env2); + success(Fun({{a, Bool}, {b, Bool}, {c, Bool}, {H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}}, + EqTIntro(_, EqMP(_, _, Symm(_, _, _, Trans(_, _, _, _, Symm(_, _, _, H2), Symm(_, _, _, H1))), H3))), + Fun({{a, Bool}, {b, Bool}, {c, Bool}, {H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}}, + EqTIntro(c, EqMP(a, c, Symm(Bool, c, a, Trans(Bool, c, b, a, Symm(Bool, b, c, H2), Symm(Bool, a, b, H1))), H3))), + env2); + expr A = Const("A"); + success(Fun({{A, Type()}, {a, A}, {b, A}, {c, A}, {H1, Eq(a, b)}, {H2, Eq(b, c)}}, + Symm(_, _, _, Trans(_, _, _, _, Symm(_, _, _, H2), Symm(_, _, _, H1)))), + Fun({{A, Type()}, {a, A}, {b, A}, {c, A}, {H1, Eq(a, b)}, {H2, Eq(b, c)}}, + Symm(A, c, a, Trans(A, c, b, a, Symm(A, b, c, H2), Symm(A, a, b, H1)))), + env2); +} + +void tst17() { + environment env; + import_all(env); + expr A = Const("A"); + expr B = Const("B"); + expr a = Const("a"); + expr b = Const("b"); + expr eq = Const("eq"); + env.add_var("eq", Pi({A, Type(level()+1)}, A >> (A >> Bool))); + success(eq(_, Fun({{A, Type()}, {a, _}}, a), Fun({{B, Type()}, {b, B}}, b)), + eq(Pi({A, Type()}, A >> A), Fun({{A, Type()}, {a, A}}, a), Fun({{B, Type()}, {b, B}}, b)), + env); +} + +void tst18() { + environment env; + import_all(env); + expr A = Const("A"); + expr h = Const("h"); + expr f = Const("f"); + expr a = Const("a"); + env.add_var("h", Pi({A, Type()}, A) >> Bool); + success(Fun({{f, Pi({A, Type()}, _)}, {a, Bool}}, h(f)), + Fun({{f, Pi({A, Type()}, A)}, {a, Bool}}, h(f)), + env); +} + +void tst19() { + environment env; + import_all(env); + expr R = Const("R"); + expr A = Const("A"); + expr r = Const("r"); + expr eq = Const("eq"); + expr f = Const("f"); + expr g = Const("g"); + expr h = Const("h"); + expr D = Const("D"); + env.add_var("R", Type() >> Bool); + env.add_var("r", Pi({A, Type()}, R(A))); + env.add_var("h", Pi({A, Type()}, R(A)) >> Bool); + env.add_var("eq", Pi({A, Type(level()+1)}, A >> (A >> Bool))); + success(Let({{f, Fun({A, Type()}, r(_))}, + {g, Fun({A, Type()}, r(_))}, + {D, Fun({A, Type()}, eq(_, f(A), g(_)))}}, + h(f)), + Let({{f, Fun({A, Type()}, r(A))}, + {g, Fun({A, Type()}, r(A))}, + {D, Fun({A, Type()}, eq(R(A), f(A), g(A)))}}, + h(f)), + env); +} + int main() { tst1(); tst2(); @@ -371,6 +625,17 @@ int main() { tst6(); tst7(); tst8(); + tst9(); + tst10(); + tst11(); + tst12(); + tst13(); + tst14(); + tst15(); + tst16(); + tst17(); + tst18(); + tst19(); return has_violations() ? 1 : 0; }