feat(frontends/lean/calc_proof_elaborator): add '{...}' if needed in calc proof steps

This is part of #268
This commit is contained in:
Leonardo de Moura 2014-10-31 00:55:19 -07:00
parent 42dba5cc98
commit 17df85f592
2 changed files with 72 additions and 25 deletions

View file

@ -15,8 +15,38 @@ Author: Leonardo de Moura
#include "frontends/lean/calc.h"
namespace lean {
static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen,
expr const & e, expr const & e_type, tag g) {
static optional<pair<expr, expr>> mk_op(environment const & env, local_context & ctx, name_generator & ngen, type_checker_ptr & tc,
name const & op, unsigned nunivs, unsigned nargs, std::initializer_list<expr> const & explicit_args,
constraint_seq & cs, tag g) {
levels lvls;
for (unsigned i = 0; i < nunivs; i++)
lvls = levels(mk_meta_univ(ngen.next()), lvls);
expr c = mk_constant(op, lvls);
expr op_type = instantiate_type_univ_params(env.get(op), lvls);
buffer<expr> args;
for (unsigned i = 0; i < nargs; i++) {
if (!is_pi(op_type))
return optional<pair<expr, expr>>();
expr arg = ctx.mk_meta(ngen, some_expr(binding_domain(op_type)), g);
args.push_back(arg);
op_type = instantiate(binding_body(op_type), arg);
}
expr r = mk_app(c, args, g);
for (expr const & explicit_arg : explicit_args) {
if (!is_pi(op_type))
return optional<pair<expr, expr>>();
r = mk_app(r, explicit_arg);
expr type = tc->infer(explicit_arg, cs);
justification j = mk_app_justification(r, explicit_arg, binding_domain(op_type), type);
if (!tc->is_def_eq(binding_domain(op_type), type, j, cs))
return optional<pair<expr, expr>>();
op_type = instantiate(binding_body(op_type), explicit_arg);
}
return some(mk_pair(r, op_type));
}
static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen, type_checker_ptr & tc,
expr const & e, expr const & e_type, constraint_seq & cs, tag g) {
buffer<expr> args;
expr const & op = get_app_args(e_type, args);
if (is_constant(op) && args.size() >= 2) {
@ -26,21 +56,33 @@ static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_
unsigned sz = args.size();
expr lhs = args[sz-2];
expr rhs = args[sz-1];
levels lvls;
for (unsigned i = 0; i < nunivs; i++)
lvls = levels(mk_meta_univ(ngen.next()), lvls);
expr symm_op = mk_constant(symm, lvls);
buffer<expr> inv_args;
for (unsigned i = 0; i < nargs - 3; i++)
inv_args.push_back(ctx.mk_meta(ngen, none_expr(), g));
inv_args.push_back(lhs);
inv_args.push_back(rhs);
inv_args.push_back(e);
expr new_e = mk_app(symm_op, inv_args);
args[sz-2] = rhs;
args[sz-1] = lhs;
expr new_e_type = mk_app(op, args);
return some(mk_pair(new_e, new_e_type));
return mk_op(env, ctx, ngen, tc, symm, nunivs, nargs-3, {lhs, rhs, e}, cs, g);
}
}
return optional<pair<expr, expr>>();
}
static optional<pair<expr, expr>> apply_subst(environment const & env, local_context & ctx, name_generator & ngen,
type_checker_ptr & tc, expr const & e, expr const & e_type,
expr const & pred, constraint_seq & cs, tag g) {
buffer<expr> pred_args;
get_app_args(pred, pred_args);
unsigned npargs = pred_args.size();
if (npargs < 2)
return optional<pair<expr, expr>>();
buffer<expr> args;
expr const & op = get_app_args(e_type, args);
if (is_constant(op) && args.size() >= 2) {
if (auto subst_it = get_calc_subst_info(env, const_name(op))) {
name subst; unsigned subst_nargs; unsigned subst_univs;
std::tie(subst, subst_nargs, subst_univs) = *subst_it;
if (auto refl_it = get_calc_refl_info(env, const_name(op))) {
name refl; unsigned refl_nargs; unsigned refl_univs;
std::tie(refl, refl_nargs, refl_univs) = *refl_it;
if (auto refl_pair = mk_op(env, ctx, ngen, tc, refl, refl_univs, refl_nargs-1, { pred_args[npargs-2] }, cs, g)) {
return mk_op(env, ctx, ngen, tc, subst, subst_univs, subst_nargs-2, {e, refl_pair->first}, cs, g);
}
}
}
}
return optional<pair<expr, expr>>();
@ -83,9 +125,8 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c
e_type = tc->whnf(instantiate(binding_body(e_type), imp_arg), new_cs);
}
auto try_alternative = [&](expr const & e, expr const & e_type) {
auto try_alternative = [&](expr const & e, expr const & e_type, constraint_seq fcs) {
justification new_j = mk_type_mismatch_jst(e, e_type, meta_type);
constraint_seq fcs = new_cs;
if (!tc->is_def_eq(e_type, meta_type, new_j, fcs))
throw unifier_exception(new_j, s);
buffer<constraint> cs_buffer;
@ -114,13 +155,19 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c
std::unique_ptr<exception> saved_ex;
try {
return try_alternative(e, e_type);
return try_alternative(e, e_type, new_cs);
} catch (exception & ex) {
saved_ex.reset(ex.clone());
}
if (auto p = apply_symmetry(env, ctx, ngen, e, e_type, g)) {
try { return try_alternative(p->first, p->second); } catch (exception &) {}
constraint_seq symm_cs = new_cs;
if (auto symm = apply_symmetry(env, ctx, ngen, tc, e, e_type, symm_cs, g)) {
try { return try_alternative(symm->first, symm->second, symm_cs); } catch (exception &) {}
}
constraint_seq subst_cs = new_cs;
if (auto subst = apply_subst(env, ctx, ngen, tc, e, e_type, meta_type, subst_cs, g)) {
try { return try_alternative(subst->first, subst->second, subst_cs); } catch (exception&) {}
}
saved_ex->rethrow();

View file

@ -1,16 +1,16 @@
import logic algebra.category.basic
open eq eq.ops category functor natural_transformation
variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D}
variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D}
protected definition compose2 (η : G ⟹ H) (θ : F ⟹ G) : F ⟹ H :=
natural_transformation.mk
(λ a, η a ∘ θ a)
(λ a b f, calc
H f ∘ (η a ∘ θ a) = (H f ∘ η a) ∘ θ a : assoc
... = (η b ∘ G f) ∘ θ a : {naturality η f}
... = (η b ∘ G f) ∘ θ a : naturality η f
... = η b ∘ (G f ∘ θ a) : assoc
... = η b ∘ (θ b ∘ F f) : {naturality θ f}
... = η b ∘ (θ b ∘ F f) : naturality θ f
... = (η b ∘ θ b) ∘ F f : assoc)
theorem tst (a b c : num) (H₁ : ∀ x, b = x) (H₂ : c = b) : a = c :=