diff --git a/src/library/type_inferer.cpp b/src/library/type_inferer.cpp index 1cb361d2d..fbe0b4048 100644 --- a/src/library/type_inferer.cpp +++ b/src/library/type_inferer.cpp @@ -43,6 +43,10 @@ class type_inferer::imp { return ::lean::lift_free_vars(e, d, m_menv.to_some_menv()); } + expr lower_free_vars(expr const & e, unsigned s, unsigned n) { + return ::lean::lower_free_vars(e, s, n, m_menv.to_some_menv()); + } + expr instantiate(expr const & e, unsigned n, expr const * s) { return ::lean::instantiate(e, n, s, m_menv.to_some_menv()); } @@ -70,16 +74,28 @@ class type_inferer::imp { throw type_expected_exception(m_env, ctx, s); } + /** + \brief Given \c t (a Pi term), this method returns the body (aka range) + of the function space for the element e in the domain of the Pi. + */ + expr get_pi_body(expr const & t, expr const & e) { + lean_assert(is_pi(t)); + if (is_arrow(t)) + return lower_free_vars(abst_body(t), 1, 1); + else + return instantiate(abst_body(t), 1, &e); + } + expr get_range(expr t, expr const & e, context const & ctx) { - unsigned num = num_args(e) - 1; - while (num > 0) { - --num; + unsigned num = num_args(e); + for (unsigned i = 1; i < num; i++) { + expr const & a = arg(e, i); if (is_pi(t)) { - t = abst_body(t); + t = get_pi_body(t, a); } else { t = normalize(t, ctx, false); if (is_pi(t)) { - t = abst_body(t); + t = get_pi_body(t, a); } else if (has_metavar(t) && m_menv && m_uc) { // Create two fresh variables A and B, // and assign r == (Pi(x : A), B) @@ -88,21 +104,18 @@ class type_inferer::imp { expr p = mk_pi(g_x_name, A, B); justification jst = mk_function_expected_justification(ctx, e); m_uc->push_back(mk_eq_constraint(ctx, t, p, jst)); - t = abst_body(p); + t = get_pi_body(p, a); } else { t = normalize(t, ctx, true); if (is_pi(t)) { - t = abst_body(t); + t = get_pi_body(t, a); } else { throw function_expected_exception(m_env, ctx, e); } } } } - if (closed(t)) - return t; - else - return instantiate(t, num_args(e)-1, &arg(e, 1)); + return t; } expr infer_type(expr const & e, context const & ctx) { diff --git a/tests/lean/type_inf_bug1.lean b/tests/lean/type_inf_bug1.lean new file mode 100644 index 000000000..5256aad8f --- /dev/null +++ b/tests/lean/type_inf_bug1.lean @@ -0,0 +1,37 @@ +SetOption pp::colors false + +Definition TypeM := (Type M) +Definition TypeU := (Type U) +Variable MyCastEq {A : TypeU} {A' : TypeU} (H : A == A') (x : A) : x == cast H x + +Check fun (A A': TypeM) + (a : A) + (b : A') + (L2 : A' == A), + let b' : A := cast L2 b, + L3 : b == b' := MyCastEq L2 b + in L3 + +Check fun (A A': TypeM) + (B : A -> TypeM) + (B' : A' -> TypeM) + (f : Pi x : A, B x) + (g : Pi x : A', B' x) + (a : A) + (b : A') + (H1 : (Pi x : A, B x) == (Pi x : A', B' x)) + (H2 : f == g) + (H3 : a == b), + let L1 : A == A' := DomInj H1, + L2 : A' == A := Symm L1, + b' : A := cast L2 b, + L3 : b == b' := MyCastEq L2 b, + L4 : a == b' := TransExt H3 L3, + L5 : f a == f b' := Congr2 f L4, + S1 : (Pi x : A', B' x) == (Pi x : A, B x) := Symm H1, + g' : (Pi x : A, B x) := cast S1 g, + L6 : g == g' := MyCastEq S1 g, + L7 : f == g' := TransExt H2 L6, + L8 : f b' == g' b' := Congr1 b' L7, + L9 : f a == g' b' := TransExt L5 L8 + in L9 diff --git a/tests/lean/type_inf_bug1.lean.expected.out b/tests/lean/type_inf_bug1.lean.expected.out new file mode 100644 index 000000000..623613bbe --- /dev/null +++ b/tests/lean/type_inf_bug1.lean.expected.out @@ -0,0 +1,40 @@ + Set: pp::colors + Set: pp::unicode + Set: pp::colors + Defined: TypeM + Defined: TypeU + Assumed: MyCastEq +λ (A A' : TypeM) (a : A) (b : A') (L2 : A' == A), let b' : A := cast L2 b, L3 : b == b' := MyCastEq L2 b in L3 : + Π (A A' : TypeM) (a : A) (b : A') (L2 : A' == A), b == cast L2 b +λ (A A' : TypeM) + (B : A → TypeM) + (B' : A' → TypeM) + (f : Π x : A, B x) + (g : Π x : A', B' x) + (a : A) + (b : A') + (H1 : (Π x : A, B x) == (Π x : A', B' x)) + (H2 : f == g) + (H3 : a == b), + let L1 : A == A' := DomInj H1, + L2 : A' == A := Symm L1, + b' : A := cast L2 b, + L3 : b == b' := MyCastEq L2 b, + L4 : a == b' := TransExt H3 L3, + L5 : f a == f b' := Congr2 f L4, + S1 : (Π x : A', B' x) == (Π x : A, B x) := Symm H1, + g' : Π x : A, B x := cast S1 g, + L6 : g == g' := MyCastEq S1 g, + L7 : f == g' := TransExt H2 L6, + L8 : f b' == g' b' := Congr1 b' L7, + L9 : f a == g' b' := TransExt L5 L8 + in L9 : + Π (A A' : TypeM) + (B : A → TypeM) + (B' : A' → TypeM) + (f : Π x : A, B x) + (g : Π x : A', B' x) + (a : A) + (b : A') + (H1 : (Π x : A, B x) == (Π x : A', B' x)), + f == g → a == b → f a == Cast (Π x : A', B' x) (Π x : A, B x) (Symm H1) g (Cast A' A (Symm (DomInj H1)) b)