diff --git a/src/kernel/unification_constraint.cpp b/src/kernel/unification_constraint.cpp index fc84547c2..d4b24ca5d 100644 --- a/src/kernel/unification_constraint.cpp +++ b/src/kernel/unification_constraint.cpp @@ -62,6 +62,10 @@ format unification_constraint_eq::pp(formatter const & fmt, options const & opts return add_justification(fmt, opts, body, m_justification, p, include_justification, menv); } +unification_constraint unification_constraint_eq::updt_justification(justification const & j) const { + return mk_eq_constraint(m_ctx, m_lhs, m_rhs, j); +} + unification_constraint_convertible::unification_constraint_convertible(context const & c, expr const & from, expr const & to, justification const & j): unification_constraint_cell(unification_constraint_kind::Convertible, c, j), m_from(from), @@ -80,6 +84,10 @@ format unification_constraint_convertible::pp(formatter const & fmt, options con return add_justification(fmt, opts, body, m_justification, p, include_justification, menv); } +unification_constraint unification_constraint_convertible::updt_justification(justification const & j) const { + return mk_convertible_constraint(m_ctx, m_from, m_to, j); +} + unification_constraint_max::unification_constraint_max(context const & c, expr const & lhs1, expr const & lhs2, expr const & rhs, justification const & j): unification_constraint_cell(unification_constraint_kind::Max, c, j), m_lhs1(lhs1), @@ -101,6 +109,10 @@ format unification_constraint_max::pp(formatter const & fmt, options const & opt return add_justification(fmt, opts, body, m_justification, p, include_justification, menv); } +unification_constraint unification_constraint_max::updt_justification(justification const & j) const { + return mk_max_constraint(m_ctx, m_lhs1, m_lhs2, m_rhs, j); +} + unification_constraint_choice::unification_constraint_choice(context const & c, expr const & mvar, unsigned num, expr const * choices, justification const & j): unification_constraint_cell(unification_constraint_kind::Choice, c, j), m_mvar(mvar), @@ -127,6 +139,10 @@ format unification_constraint_choice::pp(formatter const & fmt, options const & return add_justification(fmt, opts, body, m_justification, p, include_justification, menv); } +unification_constraint unification_constraint_choice::updt_justification(justification const & j) const { + return mk_choice_constraint(m_ctx, m_mvar, m_choices.size(), m_choices.data(), j); +} + format unification_constraint::pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool include_justification, optional const & menv) const { return m_ptr->pp(fmt, opts, p, include_justification, menv); diff --git a/src/kernel/unification_constraint.h b/src/kernel/unification_constraint.h index ca7b8d4a7..14d3ab9a0 100644 --- a/src/kernel/unification_constraint.h +++ b/src/kernel/unification_constraint.h @@ -52,6 +52,8 @@ public: virtual format pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool include_justification, optional const & menv) const = 0; void set_justification(justification const & j) { lean_assert(!m_justification); m_justification = j; } + /** \brief Return a new constraint equal to this one, but with the new justification */ + virtual unification_constraint updt_justification(justification const & j) const = 0; }; class unification_constraint_eq; @@ -83,6 +85,8 @@ public: context const & get_context() const { return m_ptr->get_context(); } + virtual unification_constraint updt_justification(justification const & j) const { lean_assert(m_ptr); return m_ptr->updt_justification(j); } + friend unification_constraint mk_eq_constraint(context const & c, expr const & lhs, expr const & rhs, justification const & j); friend unification_constraint mk_convertible_constraint(context const & c, expr const & from, expr const & to, justification const & j); friend unification_constraint mk_max_constraint(context const & c, expr const & lhs1, expr const & lhs2, expr const & rhs, justification const & j); @@ -107,6 +111,7 @@ public: expr const & get_rhs() const { return m_rhs; } virtual format pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool include_justification, optional const & menv) const; + virtual unification_constraint updt_justification(justification const & j) const; }; /** @@ -124,6 +129,7 @@ public: expr const & get_to() const { return m_to; } virtual format pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool include_justification, optional const & menv) const; + virtual unification_constraint updt_justification(justification const & j) const; }; /** @@ -141,6 +147,7 @@ public: expr const & get_rhs() const { return m_rhs; } virtual format pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool include_justification, optional const & menv) const; + virtual unification_constraint updt_justification(justification const & j) const; }; /** @@ -159,6 +166,7 @@ public: std::vector::const_iterator end_choices() const { return m_choices.end(); } virtual format pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool include_justification, optional const & menv) const; + virtual unification_constraint updt_justification(justification const & j) const; }; unification_constraint mk_eq_constraint(context const & c, expr const & lhs, expr const & rhs, justification const & j); diff --git a/src/library/elaborator/elaborator_justification.cpp b/src/library/elaborator/elaborator_justification.cpp index 62617f72b..e1e701045 100644 --- a/src/library/elaborator/elaborator_justification.cpp +++ b/src/library/elaborator/elaborator_justification.cpp @@ -121,18 +121,18 @@ bool is_derived_constraint(unification_constraint const & uc) { return j && dynamic_cast(j.raw()); } -unification_constraint const & get_non_derived_constraint(unification_constraint const & uc) { - auto jcell = uc.get_justification().raw(); +unification_constraint get_non_derived_constraint(unification_constraint const & uc) { + auto j = uc.get_justification(); + auto jcell = j.raw(); if (auto pcell = dynamic_cast(jcell)) { return get_non_derived_constraint(pcell->get_constraint()); } else { - return uc; + return uc.updt_justification(remove_detail(j)); } } justification remove_detail(justification const & j) { auto jcell = j.raw(); - if (auto fc_cell = dynamic_cast(jcell)) { auto uc = fc_cell->get_constraint(); if (is_derived_constraint(uc)) { @@ -145,12 +145,13 @@ justification remove_detail(justification const & j) { new_js.push_back(remove_detail(j)); return justification(new unification_failure_by_cases_justification(uc, new_js.size(), new_js.data(), fc_cell->get_menv())); } - return j; } else if (auto f_cell = dynamic_cast(jcell)) { unification_constraint const & new_uc = get_non_derived_constraint(f_cell->get_constraint()); return justification(new unification_failure_justification(new_uc, f_cell->get_menv())); } else if (auto p_cell = dynamic_cast(jcell)) { return remove_detail(p_cell->get_constraint().get_justification()); + } else if (auto t_cell = dynamic_cast(jcell)) { + return remove_detail(t_cell->get_justification()); } else { return j; } diff --git a/src/library/elaborator/elaborator_justification.h b/src/library/elaborator/elaborator_justification.h index 35cc55d0f..31609bbe0 100644 --- a/src/library/elaborator/elaborator_justification.h +++ b/src/library/elaborator/elaborator_justification.h @@ -164,6 +164,7 @@ public: virtual ~typeof_mvar_justification(); virtual format pp_header(formatter const &, options const &, optional const & menv) const; virtual void get_children(buffer & r) const; + justification const & get_justification() const { return m_justification; } }; /** diff --git a/tests/lean/induction2.lean b/tests/lean/induction2.lean new file mode 100644 index 000000000..2663f77f3 --- /dev/null +++ b/tests/lean/induction2.lean @@ -0,0 +1,21 @@ +import macros -- loads the take, assume, obtain macros + +using Nat -- using the Nat namespace (it allows us to suppress the Nat:: prefix) + +axiom Induction : ∀ P : Nat → Bool, P 0 ⇒ (∀ n, P n ⇒ P (n + 1)) ⇒ ∀ n, P n. + +-- induction on n + +theorem Comm1 : ∀ n m, n + m = m + n +:= Induction + ◂ _ -- I use a placeholder because I do not want to write the P + ◂ (take m, -- Base case + calc 0 + m = m : add::zerol m + ... = m + 0 : symm (add::zeror m)) + ◂ (take n, -- Inductive case + assume iH, + take m, + calc n + 1 + m = (n + m) + 1 : add::succl n m + ... = (m + n) + 1 : { iH } -- Error is here + ... = m + (n + 1) : symm (add::succr m n)) + diff --git a/tests/lean/induction2.lean.expected.out b/tests/lean/induction2.lean.expected.out new file mode 100644 index 000000000..a028dad8a --- /dev/null +++ b/tests/lean/induction2.lean.expected.out @@ -0,0 +1,104 @@ + Set: pp::colors + Set: pp::unicode + Imported 'macros' + Using: Nat + Assumed: Induction +Failed to solve + ⊢ (?M::10 ≈ @mp) ⊕ (?M::10 ≈ eq::@mp) ⊕ (?M::10 ≈ forall::@elim) + (line: 11: pos: 5) Overloading at + (forall::@elim | eq::@mp | @mp) _ _ Induction _ + Failed to solve + ⊢ (ℕ → Bool) → Bool ≺ Bool + (line: 11: pos: 5) Type of argument 3 must be convertible to the expected type in the application of + @mp + with arguments: + ?M::7 + λ P : ℕ → Bool, P 0 ⇒ (∀ n : ℕ, P n ⇒ P (n + 1)) ⇒ (∀ n : ℕ, P n) + Induction + ?M::9 + Failed to solve + ⊢ ∀ P : ℕ → Bool, P 0 ⇒ (∀ n : ℕ, P n ⇒ P (n + 1)) ⇒ (∀ n : ℕ, P n) ≺ ?M::7 == ?M::8 + (line: 11: pos: 5) Type of argument 3 must be convertible to the expected type in the application of + eq::@mp + with arguments: + ?M::7 + ?M::8 + Induction + ?M::9 + Failed to solve + ⊢ (?M::17 ≈ @mp) ⊕ (?M::17 ≈ eq::@mp) ⊕ (?M::17 ≈ forall::@elim) + (line: 12: pos: 6) Overloading at + (forall::@elim | eq::@mp | @mp) + _ + _ + ((forall::@elim | eq::@mp | @mp) _ _ Induction _) + (forall::intro (λ m : _, Nat::add::zerol m ⋈ symm (Nat::add::zeror m))) + Failed to solve + ⊢ (?M::34 ≈ @mp) ⊕ (?M::34 ≈ eq::@mp) ⊕ (?M::34 ≈ forall::@elim) + (line: 15: pos: 5) Overloading at + let κ::1 := (forall::@elim | eq::@mp | @mp) + _ + _ + ((forall::@elim | eq::@mp | @mp) _ _ Induction _) + (forall::intro (λ m : _, Nat::add::zerol m ⋈ symm (Nat::add::zeror m))), + κ::2 := λ n : _, + discharge + (λ iH : _, + forall::intro + (λ m : _, + Nat::add::succl n m ⋈ subst (refl (n + m + 1)) iH ⋈ + symm (Nat::add::succr m n))) + in (forall::@elim | eq::@mp | @mp) _ _ κ::1 (forall::intro κ::2) + Failed to solve + ⊢ ∀ n : ℕ, ?M::9 n ≺ ∀ n m : ℕ, n + m = m + n + (line: 15: pos: 5) Type of definition 'Comm1' must be convertible to expected type. + Failed to solve + ⊢ (∀ n : ℕ, ?M::9 n ⇒ ?M::9 (n + 1)) ⇒ (∀ n : ℕ, ?M::9 n) ≺ ?M::3 == ?M::4 + (line: 15: pos: 5) Type of argument 3 must be convertible to the expected type in the application of + eq::@mp + with arguments: + ?M::3 + ?M::4 + Induction ◂ ?M::9 ◂ forall::intro (λ m : ℕ, Nat::add::zerol m ⋈ symm (Nat::add::zeror m)) + forall::intro + (λ n : ℕ, + discharge + (λ iH : ?M::20, + forall::intro + (λ m : ℕ, + Nat::add::succl n m ⋈ subst (refl (n + m + 1)) iH ⋈ + symm (Nat::add::succr m n)))) + Failed to solve + ⊢ Bool ≺ ?M::3 → Bool + (line: 15: pos: 5) Type of argument 3 must be convertible to the expected type in the application of + forall::@elim + with arguments: + ?M::3 + ∀ n : ℕ, ?M::9 n + Induction ◂ ?M::9 ◂ forall::intro (λ m : ℕ, Nat::add::zerol m ⋈ symm (Nat::add::zeror m)) + forall::intro + (λ n : ℕ, + discharge + (λ iH : ?M::20, + forall::intro + (λ m : ℕ, + Nat::add::succl n m ⋈ subst (refl (n + m + 1)) iH ⋈ + symm (Nat::add::succr m n)))) + Failed to solve + ⊢ ?M::9 0 ⇒ (∀ n : ℕ, ?M::9 n ⇒ ?M::9 (n + 1)) ⇒ (∀ n : ℕ, ?M::9 n) ≺ ?M::5 == ?M::6 + (line: 12: pos: 6) Type of argument 3 must be convertible to the expected type in the application of + eq::@mp + with arguments: + ?M::5 + ?M::6 + Induction ◂ ?M::9 + forall::intro (λ m : ℕ, Nat::add::zerol m ⋈ symm (Nat::add::zeror m)) + Failed to solve + ⊢ Bool ≺ ?M::5 → Bool + (line: 12: pos: 6) Type of argument 3 must be convertible to the expected type in the application of + forall::@elim + with arguments: + ?M::5 + (∀ n : ℕ, ?M::9 n ⇒ ?M::9 (n + 1)) ⇒ (∀ n : ℕ, ?M::9 n) + Induction ◂ ?M::9 + forall::intro (λ m : ℕ, Nat::add::zerol m ⋈ symm (Nat::add::zeror m))