diff --git a/library/hott/path.lean b/library/hott/path.lean index 4e992eab6..e52f27ed2 100644 --- a/library/hott/path.lean +++ b/library/hott/path.lean @@ -207,7 +207,7 @@ rec_on p idp calc_subst transport calc_trans concat calc_refl idpath - +calc_symm inverse -- More theorems for moving things around in equations -- --------------------------------------------------- diff --git a/library/logic/cast.lean b/library/logic/cast.lean index c461149b2..8909be365 100644 --- a/library/logic/cast.lean +++ b/library/logic/cast.lean @@ -63,6 +63,7 @@ end heq calc_trans heq.trans calc_trans heq.trans_left calc_trans heq.trans_right +calc_symm heq.symm theorem cast_heq {A B : Type} (H : A = B) (a : A) : cast H a == a := have H₁ : ∀ (H : A = A) (a : A), cast H a == a, from diff --git a/library/logic/eq.lean b/library/logic/eq.lean index 614c4780f..33f7b41f9 100644 --- a/library/logic/eq.lean +++ b/library/logic/eq.lean @@ -49,6 +49,7 @@ end eq calc_subst eq.subst calc_refl eq.refl calc_trans eq.trans +calc_symm eq.symm open eq.ops diff --git a/src/emacs/lean-syntax.el b/src/emacs/lean-syntax.el index c68956e11..f24b0b71f 100644 --- a/src/emacs/lean-syntax.el +++ b/src/emacs/lean-syntax.el @@ -12,7 +12,8 @@ "hiding" "exposing" "parameter" "parameters" "begin" "proof" "qed" "conjecture" "constant" "constants" "hypothesis" "lemma" "corollary" "variable" "variables" "print" "theorem" "context" "open" "as" "export" "axiom" "inductive" "with" "structure" "universe" "universes" - "alias" "help" "environment" "options" "precedence" "reserve" "postfix" "prefix" "calc_trans" "calc_subst" "calc_refl" + "alias" "help" "environment" "options" "precedence" "reserve" "postfix" "prefix" + "calc_trans" "calc_subst" "calc_refl" "calc_symm" "infix" "infixl" "infixr" "notation" "eval" "check" "exit" "coercion" "end" "using" "namespace" "instance" "class" "section" "set_option" "add_rewrite" "extends" "include" "omit" "classes" "instances" "coercions" "raw") diff --git a/src/frontends/lean/calc.cpp b/src/frontends/lean/calc.cpp index c90781419..d617a06e6 100644 --- a/src/frontends/lean/calc.cpp +++ b/src/frontends/lean/calc.cpp @@ -36,16 +36,21 @@ static name const & get_fn_const(expr const & e, char const * msg) { return const_name(fn); } -static expr extract_arg_types(environment const & env, name const & f, buffer & arg_types) { - expr f_type = env.get(f).get_type(); +static pair extract_arg_types_core(environment const & env, name const & f, buffer & arg_types) { + declaration d = env.get(f); + expr f_type = d.get_type(); while (is_pi(f_type)) { arg_types.push_back(binding_domain(f_type)); f_type = binding_body(f_type); } - return f_type; + return mk_pair(f_type, length(d.get_univ_params())); } -enum class calc_cmd { Subst, Trans, Refl }; +static expr extract_arg_types(environment const & env, name const & f, buffer & arg_types) { + return extract_arg_types_core(env, f, arg_types).first; +} + +enum class calc_cmd { Subst, Trans, Refl, Symm }; struct calc_entry { calc_cmd m_cmd; @@ -57,13 +62,14 @@ struct calc_entry { struct calc_state { typedef name_map> refl_table; typedef name_map> subst_table; + typedef name_map> symm_table; typedef rb_map, name_pair_quick_cmp> trans_table; trans_table m_trans_table; refl_table m_refl_table; subst_table m_subst_table; + symm_table m_symm_table; calc_state() {} - void add_calc_subst(environment const & env, name const & subst) { buffer arg_types; expr r_type = extract_arg_types(env, subst, arg_types); @@ -95,6 +101,18 @@ struct calc_state { name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application"); m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs)); } + + void add_calc_symm(environment const & env, name const & symm) { + buffer arg_types; + auto p = extract_arg_types_core(env, symm, arg_types); + expr r_type = p.first; + unsigned nunivs = p.second; + unsigned nargs = arg_types.size(); + if (nargs < 3) + throw exception("invalid calc symmetry rule, it must have at least 3 arguments"); + name const & rop = get_fn_const(r_type, "invalid calc symmetry rule, result type must be an operator application"); + m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs)); + } }; static name * g_calc_name = nullptr; @@ -108,6 +126,7 @@ struct calc_config { case calc_cmd::Refl: s.add_calc_refl(env, e.m_name); break; case calc_cmd::Subst: s.add_calc_subst(env, e.m_name); break; case calc_cmd::Trans: s.add_calc_trans(env, e.m_name); break; + case calc_cmd::Symm: s.add_calc_symm(env, e.m_name); break; } } static name const & get_class_name() { @@ -149,10 +168,25 @@ environment calc_trans_cmd(parser & p) { return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Trans, id)); } +environment calc_symm_cmd(parser & p) { + name id = p.check_constant_next("invalid 'calc_symm' command, constant expected"); + return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Symm, id)); +} + void register_calc_cmds(cmd_table & r) { add_cmd(r, cmd_info("calc_subst", "set the substitution rule that is used by the calculational proof '{...}' notation", calc_subst_cmd)); add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd)); add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd)); + add_cmd(r, cmd_info("calc_symm", "set the symmetry rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_symm_cmd)); +} + +optional> get_calc_symm_info(environment const & env, name const & rop) { + auto const & s = calc_ext::get_state(env); + if (auto it = s.m_symm_table.find(rop)) { + return optional>(*it); + } else { + return optional>(); + } } static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); } diff --git a/src/frontends/lean/calc.h b/src/frontends/lean/calc.h index 8d47b04d2..073a8125f 100644 --- a/src/frontends/lean/calc.h +++ b/src/frontends/lean/calc.h @@ -11,6 +11,10 @@ class parser; void register_calc_cmds(cmd_table & r); expr parse_calc(parser & p); bool is_calc_annotation(expr const & e); +/** \brief Given an operator name \c op, return the symmetry rule associated with, number of arguments, and universe parameters. + Return none if the operator does not have a symmetry rule associated with it. +*/ +optional> get_calc_symm_info(environment const & env, name const & op); void initialize_calc(); void finalize_calc(); } diff --git a/src/frontends/lean/calc_proof_elaborator.cpp b/src/frontends/lean/calc_proof_elaborator.cpp index cd1eb919f..f05b6d770 100644 --- a/src/frontends/lean/calc_proof_elaborator.cpp +++ b/src/frontends/lean/calc_proof_elaborator.cpp @@ -12,8 +12,40 @@ Author: Leonardo de Moura #include "frontends/lean/util.h" #include "frontends/lean/local_context.h" #include "frontends/lean/info_manager.h" +#include "frontends/lean/calc.h" namespace lean { +static optional> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen, + expr const & e, expr const & e_type, tag g) { + buffer args; + expr const & op = get_app_args(e_type, args); + if (is_constant(op) && args.size() >= 2) { + if (auto t = get_calc_symm_info(env, const_name(op))) { + name symm; unsigned nargs; unsigned nunivs; + std::tie(symm, nargs, nunivs) = *t; + unsigned sz = args.size(); + expr lhs = args[sz-2]; + expr rhs = args[sz-1]; + levels lvls; + for (unsigned i = 0; i < nunivs; i++) + lvls = levels(mk_meta_univ(ngen.next()), lvls); + expr symm_op = mk_constant(symm, lvls); + buffer inv_args; + for (unsigned i = 0; i < nargs - 3; i++) + inv_args.push_back(ctx.mk_meta(ngen, none_expr(), g)); + inv_args.push_back(lhs); + inv_args.push_back(rhs); + inv_args.push_back(e); + expr new_e = mk_app(symm_op, inv_args); + args[sz-2] = rhs; + args[sz-1] = lhs; + expr new_e_type = mk_app(op, args); + return some(mk_pair(new_e, new_e_type)); + } + } + return optional>(); +} + /** \brief Create a "choice" constraint that postpones the resolution of a calc proof step. By delaying it, we can perform quick fixes such as: @@ -36,6 +68,7 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c expr e_type = tc->infer(e, new_cs); e_type = tc->whnf(e_type, new_cs); tag g = e.get_tag(); + // add '!' is needed while (is_pi(e_type)) { binder_info bi = binding_info(e_type); if (!bi.is_implicit() && !bi.is_inst_implicit()) { @@ -50,31 +83,48 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c e_type = tc->whnf(instantiate(binding_body(e_type), imp_arg), new_cs); } - justification new_j = mk_type_mismatch_jst(e, e_type, meta_type); - if (!tc->is_def_eq(e_type, meta_type, new_j, new_cs)) - throw unifier_exception(new_j, s); - buffer cs_buffer; - new_cs.linearize(cs_buffer); - metavar_closure cls(meta); - cls.add(meta_type); - cls.mk_constraints(s, j, relax, cs_buffer); - cs_buffer.push_back(mk_eq_cnstr(meta, e, j, relax)); + auto try_alternative = [&](expr const & e, expr const & e_type) { + justification new_j = mk_type_mismatch_jst(e, e_type, meta_type); + constraint_seq fcs = new_cs; + if (!tc->is_def_eq(e_type, meta_type, new_j, fcs)) + throw unifier_exception(new_j, s); + buffer cs_buffer; + fcs.linearize(cs_buffer); + metavar_closure cls(meta); + cls.add(meta_type); + cls.mk_constraints(s, j, relax, cs_buffer); + cs_buffer.push_back(mk_eq_cnstr(meta, e, j, relax)); - unifier_config new_cfg(cfg); - new_cfg.m_discard = false; - unify_result_seq seq = unify(env, cs_buffer.size(), cs_buffer.data(), ngen, substitution(), new_cfg); - auto p = seq.pull(); - lean_assert(p); - substitution new_s = p->first.first; - constraints postponed = map(p->first.second, - [&](constraint const & c) { - // we erase internal justifications - return update_justification(c, j); - }); - if (im) - im->instantiate(new_s); - constraints r = cls.mk_constraints(new_s, j, relax); - return append(r, postponed); + unifier_config new_cfg(cfg); + new_cfg.m_discard = false; + unify_result_seq seq = unify(env, cs_buffer.size(), cs_buffer.data(), ngen, substitution(), new_cfg); + auto p = seq.pull(); + lean_assert(p); + substitution new_s = p->first.first; + constraints postponed = map(p->first.second, + [&](constraint const & c) { + // we erase internal justifications + return update_justification(c, j); + }); + if (im) + im->instantiate(new_s); + constraints r = cls.mk_constraints(new_s, j, relax); + return append(r, postponed); + }; + + std::unique_ptr saved_ex; + try { + return try_alternative(e, e_type); + } catch (exception & ex) { + saved_ex.reset(ex.clone()); + } + + if (auto p = apply_symmetry(env, ctx, ngen, e, e_type, g)) { + try { return try_alternative(p->first, p->second); } catch (exception &) {} + } + + saved_ex->rethrow(); + lean_unreachable(); }; bool owner = false; return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::Epilogue), owner, j, relax); diff --git a/src/frontends/lean/token_table.cpp b/src/frontends/lean/token_table.cpp index f5dfb6d75..439c3e8cb 100644 --- a/src/frontends/lean/token_table.cpp +++ b/src/frontends/lean/token_table.cpp @@ -85,7 +85,7 @@ void init_token_table(token_table & t) { "evaluate", "check", "eval", "[priority", "print", "end", "namespace", "section", "import", "inductive", "record", "renaming", "extends", "structure", "module", "universe", "universes", "precedence", "reserve", "infixl", "infixr", "infix", "postfix", "prefix", "notation", "context", - "exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "tactic_hint", + "exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "calc_symm", "tactic_hint", "add_begin_end_tactic", "set_begin_end_tactic", "instance", "class", "include", "omit", "#erase_cache", "#projections", nullptr}; diff --git a/tests/lean/run/imp_bang.lean b/tests/lean/run/imp_bang.lean index dc621aed8..2ce8529fb 100644 --- a/tests/lean/run/imp_bang.lean +++ b/tests/lean/run/imp_bang.lean @@ -1,4 +1,4 @@ -import algebra.category.basic +import logic algebra.category.basic open eq eq.ops category functor natural_transformation variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D} @@ -9,6 +9,10 @@ natural_transformation.mk (λ a b f, calc H f ∘ (η a ∘ θ a) = (H f ∘ η a) ∘ θ a : assoc ... = (η b ∘ G f) ∘ θ a : {naturality η f} - ... = η b ∘ (G f ∘ θ a) : symm !assoc + ... = η b ∘ (G f ∘ θ a) : assoc ... = η b ∘ (θ b ∘ F f) : {naturality θ f} ... = (η b ∘ θ b) ∘ F f : assoc) + +theorem tst (a b c : num) (H₁ : ∀ x, b = x) (H₂ : c = b) : a = c := +calc a = b : H₁ + ... = c : H₂