diff --git a/src/frontends/lean/calc_proof_elaborator.cpp b/src/frontends/lean/calc_proof_elaborator.cpp index f05b6d770..015ba9917 100644 --- a/src/frontends/lean/calc_proof_elaborator.cpp +++ b/src/frontends/lean/calc_proof_elaborator.cpp @@ -15,8 +15,38 @@ Author: Leonardo de Moura #include "frontends/lean/calc.h" namespace lean { -static optional> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen, - expr const & e, expr const & e_type, tag g) { +static optional> mk_op(environment const & env, local_context & ctx, name_generator & ngen, type_checker_ptr & tc, + name const & op, unsigned nunivs, unsigned nargs, std::initializer_list const & explicit_args, + constraint_seq & cs, tag g) { + levels lvls; + for (unsigned i = 0; i < nunivs; i++) + lvls = levels(mk_meta_univ(ngen.next()), lvls); + expr c = mk_constant(op, lvls); + expr op_type = instantiate_type_univ_params(env.get(op), lvls); + buffer args; + for (unsigned i = 0; i < nargs; i++) { + if (!is_pi(op_type)) + return optional>(); + expr arg = ctx.mk_meta(ngen, some_expr(binding_domain(op_type)), g); + args.push_back(arg); + op_type = instantiate(binding_body(op_type), arg); + } + expr r = mk_app(c, args, g); + for (expr const & explicit_arg : explicit_args) { + if (!is_pi(op_type)) + return optional>(); + r = mk_app(r, explicit_arg); + expr type = tc->infer(explicit_arg, cs); + justification j = mk_app_justification(r, explicit_arg, binding_domain(op_type), type); + if (!tc->is_def_eq(binding_domain(op_type), type, j, cs)) + return optional>(); + op_type = instantiate(binding_body(op_type), explicit_arg); + } + return some(mk_pair(r, op_type)); +} + +static optional> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen, type_checker_ptr & tc, + expr const & e, expr const & e_type, constraint_seq & cs, tag g) { buffer args; expr const & op = get_app_args(e_type, args); if (is_constant(op) && args.size() >= 2) { @@ -26,21 +56,33 @@ static optional> apply_symmetry(environment const & env, local_ unsigned sz = args.size(); expr lhs = args[sz-2]; expr rhs = args[sz-1]; - levels lvls; - for (unsigned i = 0; i < nunivs; i++) - lvls = levels(mk_meta_univ(ngen.next()), lvls); - expr symm_op = mk_constant(symm, lvls); - buffer inv_args; - for (unsigned i = 0; i < nargs - 3; i++) - inv_args.push_back(ctx.mk_meta(ngen, none_expr(), g)); - inv_args.push_back(lhs); - inv_args.push_back(rhs); - inv_args.push_back(e); - expr new_e = mk_app(symm_op, inv_args); - args[sz-2] = rhs; - args[sz-1] = lhs; - expr new_e_type = mk_app(op, args); - return some(mk_pair(new_e, new_e_type)); + return mk_op(env, ctx, ngen, tc, symm, nunivs, nargs-3, {lhs, rhs, e}, cs, g); + } + } + return optional>(); +} + +static optional> apply_subst(environment const & env, local_context & ctx, name_generator & ngen, + type_checker_ptr & tc, expr const & e, expr const & e_type, + expr const & pred, constraint_seq & cs, tag g) { + buffer pred_args; + get_app_args(pred, pred_args); + unsigned npargs = pred_args.size(); + if (npargs < 2) + return optional>(); + buffer args; + expr const & op = get_app_args(e_type, args); + if (is_constant(op) && args.size() >= 2) { + if (auto subst_it = get_calc_subst_info(env, const_name(op))) { + name subst; unsigned subst_nargs; unsigned subst_univs; + std::tie(subst, subst_nargs, subst_univs) = *subst_it; + if (auto refl_it = get_calc_refl_info(env, const_name(op))) { + name refl; unsigned refl_nargs; unsigned refl_univs; + std::tie(refl, refl_nargs, refl_univs) = *refl_it; + if (auto refl_pair = mk_op(env, ctx, ngen, tc, refl, refl_univs, refl_nargs-1, { pred_args[npargs-2] }, cs, g)) { + return mk_op(env, ctx, ngen, tc, subst, subst_univs, subst_nargs-2, {e, refl_pair->first}, cs, g); + } + } } } return optional>(); @@ -83,9 +125,8 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c e_type = tc->whnf(instantiate(binding_body(e_type), imp_arg), new_cs); } - auto try_alternative = [&](expr const & e, expr const & e_type) { + auto try_alternative = [&](expr const & e, expr const & e_type, constraint_seq fcs) { justification new_j = mk_type_mismatch_jst(e, e_type, meta_type); - constraint_seq fcs = new_cs; if (!tc->is_def_eq(e_type, meta_type, new_j, fcs)) throw unifier_exception(new_j, s); buffer cs_buffer; @@ -114,13 +155,19 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c std::unique_ptr saved_ex; try { - return try_alternative(e, e_type); + return try_alternative(e, e_type, new_cs); } catch (exception & ex) { saved_ex.reset(ex.clone()); } - if (auto p = apply_symmetry(env, ctx, ngen, e, e_type, g)) { - try { return try_alternative(p->first, p->second); } catch (exception &) {} + constraint_seq symm_cs = new_cs; + if (auto symm = apply_symmetry(env, ctx, ngen, tc, e, e_type, symm_cs, g)) { + try { return try_alternative(symm->first, symm->second, symm_cs); } catch (exception &) {} + } + + constraint_seq subst_cs = new_cs; + if (auto subst = apply_subst(env, ctx, ngen, tc, e, e_type, meta_type, subst_cs, g)) { + try { return try_alternative(subst->first, subst->second, subst_cs); } catch (exception&) {} } saved_ex->rethrow(); diff --git a/tests/lean/run/imp_bang.lean b/tests/lean/run/imp_bang.lean index 2ce8529fb..074558d96 100644 --- a/tests/lean/run/imp_bang.lean +++ b/tests/lean/run/imp_bang.lean @@ -1,16 +1,16 @@ import logic algebra.category.basic open eq eq.ops category functor natural_transformation -variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D} +variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D} protected definition compose2 (η : G ⟹ H) (θ : F ⟹ G) : F ⟹ H := natural_transformation.mk (λ a, η a ∘ θ a) (λ a b f, calc H f ∘ (η a ∘ θ a) = (H f ∘ η a) ∘ θ a : assoc - ... = (η b ∘ G f) ∘ θ a : {naturality η f} + ... = (η b ∘ G f) ∘ θ a : naturality η f ... = η b ∘ (G f ∘ θ a) : assoc - ... = η b ∘ (θ b ∘ F f) : {naturality θ f} + ... = η b ∘ (θ b ∘ F f) : naturality θ f ... = (η b ∘ θ b) ∘ F f : assoc) theorem tst (a b c : num) (H₁ : ∀ x, b = x) (H₂ : c = b) : a = c :=