feat(frontends/lean): try to inject symmetry (if needed) in calc proofs, add calc_symm command for configuring the symmetry theorem for a given operator
This is part of #268
This commit is contained in:
parent
c5a62f8abb
commit
591e566472
9 changed files with 129 additions and 34 deletions
|
@ -207,7 +207,7 @@ rec_on p idp
|
|||
calc_subst transport
|
||||
calc_trans concat
|
||||
calc_refl idpath
|
||||
|
||||
calc_symm inverse
|
||||
|
||||
-- More theorems for moving things around in equations
|
||||
-- ---------------------------------------------------
|
||||
|
|
|
@ -63,6 +63,7 @@ end heq
|
|||
calc_trans heq.trans
|
||||
calc_trans heq.trans_left
|
||||
calc_trans heq.trans_right
|
||||
calc_symm heq.symm
|
||||
|
||||
theorem cast_heq {A B : Type} (H : A = B) (a : A) : cast H a == a :=
|
||||
have H₁ : ∀ (H : A = A) (a : A), cast H a == a, from
|
||||
|
|
|
@ -49,6 +49,7 @@ end eq
|
|||
calc_subst eq.subst
|
||||
calc_refl eq.refl
|
||||
calc_trans eq.trans
|
||||
calc_symm eq.symm
|
||||
|
||||
open eq.ops
|
||||
|
||||
|
|
|
@ -12,7 +12,8 @@
|
|||
"hiding" "exposing" "parameter" "parameters" "begin" "proof" "qed" "conjecture" "constant" "constants"
|
||||
"hypothesis" "lemma" "corollary" "variable" "variables" "print" "theorem"
|
||||
"context" "open" "as" "export" "axiom" "inductive" "with" "structure" "universe" "universes"
|
||||
"alias" "help" "environment" "options" "precedence" "reserve" "postfix" "prefix" "calc_trans" "calc_subst" "calc_refl"
|
||||
"alias" "help" "environment" "options" "precedence" "reserve" "postfix" "prefix"
|
||||
"calc_trans" "calc_subst" "calc_refl" "calc_symm"
|
||||
"infix" "infixl" "infixr" "notation" "eval" "check" "exit" "coercion" "end"
|
||||
"using" "namespace" "instance" "class" "section"
|
||||
"set_option" "add_rewrite" "extends" "include" "omit" "classes" "instances" "coercions" "raw")
|
||||
|
|
|
@ -36,16 +36,21 @@ static name const & get_fn_const(expr const & e, char const * msg) {
|
|||
return const_name(fn);
|
||||
}
|
||||
|
||||
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
|
||||
expr f_type = env.get(f).get_type();
|
||||
static pair<expr, unsigned> extract_arg_types_core(environment const & env, name const & f, buffer<expr> & arg_types) {
|
||||
declaration d = env.get(f);
|
||||
expr f_type = d.get_type();
|
||||
while (is_pi(f_type)) {
|
||||
arg_types.push_back(binding_domain(f_type));
|
||||
f_type = binding_body(f_type);
|
||||
}
|
||||
return f_type;
|
||||
return mk_pair(f_type, length(d.get_univ_params()));
|
||||
}
|
||||
|
||||
enum class calc_cmd { Subst, Trans, Refl };
|
||||
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
|
||||
return extract_arg_types_core(env, f, arg_types).first;
|
||||
}
|
||||
|
||||
enum class calc_cmd { Subst, Trans, Refl, Symm };
|
||||
|
||||
struct calc_entry {
|
||||
calc_cmd m_cmd;
|
||||
|
@ -57,13 +62,14 @@ struct calc_entry {
|
|||
struct calc_state {
|
||||
typedef name_map<pair<name, unsigned>> refl_table;
|
||||
typedef name_map<pair<name, unsigned>> subst_table;
|
||||
typedef name_map<std::tuple<name, unsigned, unsigned>> symm_table;
|
||||
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
|
||||
trans_table m_trans_table;
|
||||
refl_table m_refl_table;
|
||||
subst_table m_subst_table;
|
||||
symm_table m_symm_table;
|
||||
calc_state() {}
|
||||
|
||||
|
||||
void add_calc_subst(environment const & env, name const & subst) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, subst, arg_types);
|
||||
|
@ -95,6 +101,18 @@ struct calc_state {
|
|||
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
|
||||
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
|
||||
}
|
||||
|
||||
void add_calc_symm(environment const & env, name const & symm) {
|
||||
buffer<expr> arg_types;
|
||||
auto p = extract_arg_types_core(env, symm, arg_types);
|
||||
expr r_type = p.first;
|
||||
unsigned nunivs = p.second;
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 3)
|
||||
throw exception("invalid calc symmetry rule, it must have at least 3 arguments");
|
||||
name const & rop = get_fn_const(r_type, "invalid calc symmetry rule, result type must be an operator application");
|
||||
m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs));
|
||||
}
|
||||
};
|
||||
|
||||
static name * g_calc_name = nullptr;
|
||||
|
@ -108,6 +126,7 @@ struct calc_config {
|
|||
case calc_cmd::Refl: s.add_calc_refl(env, e.m_name); break;
|
||||
case calc_cmd::Subst: s.add_calc_subst(env, e.m_name); break;
|
||||
case calc_cmd::Trans: s.add_calc_trans(env, e.m_name); break;
|
||||
case calc_cmd::Symm: s.add_calc_symm(env, e.m_name); break;
|
||||
}
|
||||
}
|
||||
static name const & get_class_name() {
|
||||
|
@ -149,10 +168,25 @@ environment calc_trans_cmd(parser & p) {
|
|||
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Trans, id));
|
||||
}
|
||||
|
||||
environment calc_symm_cmd(parser & p) {
|
||||
name id = p.check_constant_next("invalid 'calc_symm' command, constant expected");
|
||||
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Symm, id));
|
||||
}
|
||||
|
||||
void register_calc_cmds(cmd_table & r) {
|
||||
add_cmd(r, cmd_info("calc_subst", "set the substitution rule that is used by the calculational proof '{...}' notation", calc_subst_cmd));
|
||||
add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd));
|
||||
add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd));
|
||||
add_cmd(r, cmd_info("calc_symm", "set the symmetry rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_symm_cmd));
|
||||
}
|
||||
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & rop) {
|
||||
auto const & s = calc_ext::get_state(env);
|
||||
if (auto it = s.m_symm_table.find(rop)) {
|
||||
return optional<std::tuple<name, unsigned, unsigned>>(*it);
|
||||
} else {
|
||||
return optional<std::tuple<name, unsigned, unsigned>>();
|
||||
}
|
||||
}
|
||||
|
||||
static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); }
|
||||
|
|
|
@ -11,6 +11,10 @@ class parser;
|
|||
void register_calc_cmds(cmd_table & r);
|
||||
expr parse_calc(parser & p);
|
||||
bool is_calc_annotation(expr const & e);
|
||||
/** \brief Given an operator name \c op, return the symmetry rule associated with, number of arguments, and universe parameters.
|
||||
Return none if the operator does not have a symmetry rule associated with it.
|
||||
*/
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op);
|
||||
void initialize_calc();
|
||||
void finalize_calc();
|
||||
}
|
||||
|
|
|
@ -12,8 +12,40 @@ Author: Leonardo de Moura
|
|||
#include "frontends/lean/util.h"
|
||||
#include "frontends/lean/local_context.h"
|
||||
#include "frontends/lean/info_manager.h"
|
||||
#include "frontends/lean/calc.h"
|
||||
|
||||
namespace lean {
|
||||
static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen,
|
||||
expr const & e, expr const & e_type, tag g) {
|
||||
buffer<expr> args;
|
||||
expr const & op = get_app_args(e_type, args);
|
||||
if (is_constant(op) && args.size() >= 2) {
|
||||
if (auto t = get_calc_symm_info(env, const_name(op))) {
|
||||
name symm; unsigned nargs; unsigned nunivs;
|
||||
std::tie(symm, nargs, nunivs) = *t;
|
||||
unsigned sz = args.size();
|
||||
expr lhs = args[sz-2];
|
||||
expr rhs = args[sz-1];
|
||||
levels lvls;
|
||||
for (unsigned i = 0; i < nunivs; i++)
|
||||
lvls = levels(mk_meta_univ(ngen.next()), lvls);
|
||||
expr symm_op = mk_constant(symm, lvls);
|
||||
buffer<expr> inv_args;
|
||||
for (unsigned i = 0; i < nargs - 3; i++)
|
||||
inv_args.push_back(ctx.mk_meta(ngen, none_expr(), g));
|
||||
inv_args.push_back(lhs);
|
||||
inv_args.push_back(rhs);
|
||||
inv_args.push_back(e);
|
||||
expr new_e = mk_app(symm_op, inv_args);
|
||||
args[sz-2] = rhs;
|
||||
args[sz-1] = lhs;
|
||||
expr new_e_type = mk_app(op, args);
|
||||
return some(mk_pair(new_e, new_e_type));
|
||||
}
|
||||
}
|
||||
return optional<pair<expr, expr>>();
|
||||
}
|
||||
|
||||
/** \brief Create a "choice" constraint that postpones the resolution of a calc proof step.
|
||||
|
||||
By delaying it, we can perform quick fixes such as:
|
||||
|
@ -36,6 +68,7 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c
|
|||
expr e_type = tc->infer(e, new_cs);
|
||||
e_type = tc->whnf(e_type, new_cs);
|
||||
tag g = e.get_tag();
|
||||
// add '!' is needed
|
||||
while (is_pi(e_type)) {
|
||||
binder_info bi = binding_info(e_type);
|
||||
if (!bi.is_implicit() && !bi.is_inst_implicit()) {
|
||||
|
@ -50,31 +83,48 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c
|
|||
e_type = tc->whnf(instantiate(binding_body(e_type), imp_arg), new_cs);
|
||||
}
|
||||
|
||||
justification new_j = mk_type_mismatch_jst(e, e_type, meta_type);
|
||||
if (!tc->is_def_eq(e_type, meta_type, new_j, new_cs))
|
||||
throw unifier_exception(new_j, s);
|
||||
buffer<constraint> cs_buffer;
|
||||
new_cs.linearize(cs_buffer);
|
||||
metavar_closure cls(meta);
|
||||
cls.add(meta_type);
|
||||
cls.mk_constraints(s, j, relax, cs_buffer);
|
||||
cs_buffer.push_back(mk_eq_cnstr(meta, e, j, relax));
|
||||
auto try_alternative = [&](expr const & e, expr const & e_type) {
|
||||
justification new_j = mk_type_mismatch_jst(e, e_type, meta_type);
|
||||
constraint_seq fcs = new_cs;
|
||||
if (!tc->is_def_eq(e_type, meta_type, new_j, fcs))
|
||||
throw unifier_exception(new_j, s);
|
||||
buffer<constraint> cs_buffer;
|
||||
fcs.linearize(cs_buffer);
|
||||
metavar_closure cls(meta);
|
||||
cls.add(meta_type);
|
||||
cls.mk_constraints(s, j, relax, cs_buffer);
|
||||
cs_buffer.push_back(mk_eq_cnstr(meta, e, j, relax));
|
||||
|
||||
unifier_config new_cfg(cfg);
|
||||
new_cfg.m_discard = false;
|
||||
unify_result_seq seq = unify(env, cs_buffer.size(), cs_buffer.data(), ngen, substitution(), new_cfg);
|
||||
auto p = seq.pull();
|
||||
lean_assert(p);
|
||||
substitution new_s = p->first.first;
|
||||
constraints postponed = map(p->first.second,
|
||||
[&](constraint const & c) {
|
||||
// we erase internal justifications
|
||||
return update_justification(c, j);
|
||||
});
|
||||
if (im)
|
||||
im->instantiate(new_s);
|
||||
constraints r = cls.mk_constraints(new_s, j, relax);
|
||||
return append(r, postponed);
|
||||
unifier_config new_cfg(cfg);
|
||||
new_cfg.m_discard = false;
|
||||
unify_result_seq seq = unify(env, cs_buffer.size(), cs_buffer.data(), ngen, substitution(), new_cfg);
|
||||
auto p = seq.pull();
|
||||
lean_assert(p);
|
||||
substitution new_s = p->first.first;
|
||||
constraints postponed = map(p->first.second,
|
||||
[&](constraint const & c) {
|
||||
// we erase internal justifications
|
||||
return update_justification(c, j);
|
||||
});
|
||||
if (im)
|
||||
im->instantiate(new_s);
|
||||
constraints r = cls.mk_constraints(new_s, j, relax);
|
||||
return append(r, postponed);
|
||||
};
|
||||
|
||||
std::unique_ptr<exception> saved_ex;
|
||||
try {
|
||||
return try_alternative(e, e_type);
|
||||
} catch (exception & ex) {
|
||||
saved_ex.reset(ex.clone());
|
||||
}
|
||||
|
||||
if (auto p = apply_symmetry(env, ctx, ngen, e, e_type, g)) {
|
||||
try { return try_alternative(p->first, p->second); } catch (exception &) {}
|
||||
}
|
||||
|
||||
saved_ex->rethrow();
|
||||
lean_unreachable();
|
||||
};
|
||||
bool owner = false;
|
||||
return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::Epilogue), owner, j, relax);
|
||||
|
|
|
@ -85,7 +85,7 @@ void init_token_table(token_table & t) {
|
|||
"evaluate", "check", "eval", "[priority", "print", "end", "namespace", "section", "import",
|
||||
"inductive", "record", "renaming", "extends", "structure", "module", "universe", "universes",
|
||||
"precedence", "reserve", "infixl", "infixr", "infix", "postfix", "prefix", "notation", "context",
|
||||
"exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "tactic_hint",
|
||||
"exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "calc_symm", "tactic_hint",
|
||||
"add_begin_end_tactic", "set_begin_end_tactic", "instance", "class",
|
||||
"include", "omit", "#erase_cache", "#projections", nullptr};
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import algebra.category.basic
|
||||
import logic algebra.category.basic
|
||||
open eq eq.ops category functor natural_transformation
|
||||
|
||||
variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D}
|
||||
|
@ -9,6 +9,10 @@ natural_transformation.mk
|
|||
(λ a b f, calc
|
||||
H f ∘ (η a ∘ θ a) = (H f ∘ η a) ∘ θ a : assoc
|
||||
... = (η b ∘ G f) ∘ θ a : {naturality η f}
|
||||
... = η b ∘ (G f ∘ θ a) : symm !assoc
|
||||
... = η b ∘ (G f ∘ θ a) : assoc
|
||||
... = η b ∘ (θ b ∘ F f) : {naturality θ f}
|
||||
... = (η b ∘ θ b) ∘ F f : assoc)
|
||||
|
||||
theorem tst (a b c : num) (H₁ : ∀ x, b = x) (H₂ : c = b) : a = c :=
|
||||
calc a = b : H₁
|
||||
... = c : H₂
|
||||
|
|
Loading…
Reference in a new issue