feat(frontends/lean): try to inject symmetry (if needed) in calc proofs, add calc_symm command for configuring the symmetry theorem for a given operator

This is part of #268
This commit is contained in:
Leonardo de Moura 2014-10-30 23:24:09 -07:00
parent c5a62f8abb
commit 591e566472
9 changed files with 129 additions and 34 deletions

View file

@ -207,7 +207,7 @@ rec_on p idp
calc_subst transport calc_subst transport
calc_trans concat calc_trans concat
calc_refl idpath calc_refl idpath
calc_symm inverse
-- More theorems for moving things around in equations -- More theorems for moving things around in equations
-- --------------------------------------------------- -- ---------------------------------------------------

View file

@ -63,6 +63,7 @@ end heq
calc_trans heq.trans calc_trans heq.trans
calc_trans heq.trans_left calc_trans heq.trans_left
calc_trans heq.trans_right calc_trans heq.trans_right
calc_symm heq.symm
theorem cast_heq {A B : Type} (H : A = B) (a : A) : cast H a == a := theorem cast_heq {A B : Type} (H : A = B) (a : A) : cast H a == a :=
have H₁ : ∀ (H : A = A) (a : A), cast H a == a, from have H₁ : ∀ (H : A = A) (a : A), cast H a == a, from

View file

@ -49,6 +49,7 @@ end eq
calc_subst eq.subst calc_subst eq.subst
calc_refl eq.refl calc_refl eq.refl
calc_trans eq.trans calc_trans eq.trans
calc_symm eq.symm
open eq.ops open eq.ops

View file

@ -12,7 +12,8 @@
"hiding" "exposing" "parameter" "parameters" "begin" "proof" "qed" "conjecture" "constant" "constants" "hiding" "exposing" "parameter" "parameters" "begin" "proof" "qed" "conjecture" "constant" "constants"
"hypothesis" "lemma" "corollary" "variable" "variables" "print" "theorem" "hypothesis" "lemma" "corollary" "variable" "variables" "print" "theorem"
"context" "open" "as" "export" "axiom" "inductive" "with" "structure" "universe" "universes" "context" "open" "as" "export" "axiom" "inductive" "with" "structure" "universe" "universes"
"alias" "help" "environment" "options" "precedence" "reserve" "postfix" "prefix" "calc_trans" "calc_subst" "calc_refl" "alias" "help" "environment" "options" "precedence" "reserve" "postfix" "prefix"
"calc_trans" "calc_subst" "calc_refl" "calc_symm"
"infix" "infixl" "infixr" "notation" "eval" "check" "exit" "coercion" "end" "infix" "infixl" "infixr" "notation" "eval" "check" "exit" "coercion" "end"
"using" "namespace" "instance" "class" "section" "using" "namespace" "instance" "class" "section"
"set_option" "add_rewrite" "extends" "include" "omit" "classes" "instances" "coercions" "raw") "set_option" "add_rewrite" "extends" "include" "omit" "classes" "instances" "coercions" "raw")

View file

@ -36,16 +36,21 @@ static name const & get_fn_const(expr const & e, char const * msg) {
return const_name(fn); return const_name(fn);
} }
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) { static pair<expr, unsigned> extract_arg_types_core(environment const & env, name const & f, buffer<expr> & arg_types) {
expr f_type = env.get(f).get_type(); declaration d = env.get(f);
expr f_type = d.get_type();
while (is_pi(f_type)) { while (is_pi(f_type)) {
arg_types.push_back(binding_domain(f_type)); arg_types.push_back(binding_domain(f_type));
f_type = binding_body(f_type); f_type = binding_body(f_type);
} }
return f_type; return mk_pair(f_type, length(d.get_univ_params()));
} }
enum class calc_cmd { Subst, Trans, Refl }; static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
return extract_arg_types_core(env, f, arg_types).first;
}
enum class calc_cmd { Subst, Trans, Refl, Symm };
struct calc_entry { struct calc_entry {
calc_cmd m_cmd; calc_cmd m_cmd;
@ -57,13 +62,14 @@ struct calc_entry {
struct calc_state { struct calc_state {
typedef name_map<pair<name, unsigned>> refl_table; typedef name_map<pair<name, unsigned>> refl_table;
typedef name_map<pair<name, unsigned>> subst_table; typedef name_map<pair<name, unsigned>> subst_table;
typedef name_map<std::tuple<name, unsigned, unsigned>> symm_table;
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table; typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
trans_table m_trans_table; trans_table m_trans_table;
refl_table m_refl_table; refl_table m_refl_table;
subst_table m_subst_table; subst_table m_subst_table;
symm_table m_symm_table;
calc_state() {} calc_state() {}
void add_calc_subst(environment const & env, name const & subst) { void add_calc_subst(environment const & env, name const & subst) {
buffer<expr> arg_types; buffer<expr> arg_types;
expr r_type = extract_arg_types(env, subst, arg_types); expr r_type = extract_arg_types(env, subst, arg_types);
@ -95,6 +101,18 @@ struct calc_state {
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application"); name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs)); m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
} }
void add_calc_symm(environment const & env, name const & symm) {
buffer<expr> arg_types;
auto p = extract_arg_types_core(env, symm, arg_types);
expr r_type = p.first;
unsigned nunivs = p.second;
unsigned nargs = arg_types.size();
if (nargs < 3)
throw exception("invalid calc symmetry rule, it must have at least 3 arguments");
name const & rop = get_fn_const(r_type, "invalid calc symmetry rule, result type must be an operator application");
m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs));
}
}; };
static name * g_calc_name = nullptr; static name * g_calc_name = nullptr;
@ -108,6 +126,7 @@ struct calc_config {
case calc_cmd::Refl: s.add_calc_refl(env, e.m_name); break; case calc_cmd::Refl: s.add_calc_refl(env, e.m_name); break;
case calc_cmd::Subst: s.add_calc_subst(env, e.m_name); break; case calc_cmd::Subst: s.add_calc_subst(env, e.m_name); break;
case calc_cmd::Trans: s.add_calc_trans(env, e.m_name); break; case calc_cmd::Trans: s.add_calc_trans(env, e.m_name); break;
case calc_cmd::Symm: s.add_calc_symm(env, e.m_name); break;
} }
} }
static name const & get_class_name() { static name const & get_class_name() {
@ -149,10 +168,25 @@ environment calc_trans_cmd(parser & p) {
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Trans, id)); return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Trans, id));
} }
environment calc_symm_cmd(parser & p) {
name id = p.check_constant_next("invalid 'calc_symm' command, constant expected");
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Symm, id));
}
void register_calc_cmds(cmd_table & r) { void register_calc_cmds(cmd_table & r) {
add_cmd(r, cmd_info("calc_subst", "set the substitution rule that is used by the calculational proof '{...}' notation", calc_subst_cmd)); add_cmd(r, cmd_info("calc_subst", "set the substitution rule that is used by the calculational proof '{...}' notation", calc_subst_cmd));
add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd)); add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd));
add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd)); add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd));
add_cmd(r, cmd_info("calc_symm", "set the symmetry rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_symm_cmd));
}
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & rop) {
auto const & s = calc_ext::get_state(env);
if (auto it = s.m_symm_table.find(rop)) {
return optional<std::tuple<name, unsigned, unsigned>>(*it);
} else {
return optional<std::tuple<name, unsigned, unsigned>>();
}
} }
static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); } static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); }

View file

@ -11,6 +11,10 @@ class parser;
void register_calc_cmds(cmd_table & r); void register_calc_cmds(cmd_table & r);
expr parse_calc(parser & p); expr parse_calc(parser & p);
bool is_calc_annotation(expr const & e); bool is_calc_annotation(expr const & e);
/** \brief Given an operator name \c op, return the symmetry rule associated with, number of arguments, and universe parameters.
Return none if the operator does not have a symmetry rule associated with it.
*/
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op);
void initialize_calc(); void initialize_calc();
void finalize_calc(); void finalize_calc();
} }

View file

@ -12,8 +12,40 @@ Author: Leonardo de Moura
#include "frontends/lean/util.h" #include "frontends/lean/util.h"
#include "frontends/lean/local_context.h" #include "frontends/lean/local_context.h"
#include "frontends/lean/info_manager.h" #include "frontends/lean/info_manager.h"
#include "frontends/lean/calc.h"
namespace lean { namespace lean {
static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_context & ctx, name_generator & ngen,
expr const & e, expr const & e_type, tag g) {
buffer<expr> args;
expr const & op = get_app_args(e_type, args);
if (is_constant(op) && args.size() >= 2) {
if (auto t = get_calc_symm_info(env, const_name(op))) {
name symm; unsigned nargs; unsigned nunivs;
std::tie(symm, nargs, nunivs) = *t;
unsigned sz = args.size();
expr lhs = args[sz-2];
expr rhs = args[sz-1];
levels lvls;
for (unsigned i = 0; i < nunivs; i++)
lvls = levels(mk_meta_univ(ngen.next()), lvls);
expr symm_op = mk_constant(symm, lvls);
buffer<expr> inv_args;
for (unsigned i = 0; i < nargs - 3; i++)
inv_args.push_back(ctx.mk_meta(ngen, none_expr(), g));
inv_args.push_back(lhs);
inv_args.push_back(rhs);
inv_args.push_back(e);
expr new_e = mk_app(symm_op, inv_args);
args[sz-2] = rhs;
args[sz-1] = lhs;
expr new_e_type = mk_app(op, args);
return some(mk_pair(new_e, new_e_type));
}
}
return optional<pair<expr, expr>>();
}
/** \brief Create a "choice" constraint that postpones the resolution of a calc proof step. /** \brief Create a "choice" constraint that postpones the resolution of a calc proof step.
By delaying it, we can perform quick fixes such as: By delaying it, we can perform quick fixes such as:
@ -36,6 +68,7 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c
expr e_type = tc->infer(e, new_cs); expr e_type = tc->infer(e, new_cs);
e_type = tc->whnf(e_type, new_cs); e_type = tc->whnf(e_type, new_cs);
tag g = e.get_tag(); tag g = e.get_tag();
// add '!' is needed
while (is_pi(e_type)) { while (is_pi(e_type)) {
binder_info bi = binding_info(e_type); binder_info bi = binding_info(e_type);
if (!bi.is_implicit() && !bi.is_inst_implicit()) { if (!bi.is_implicit() && !bi.is_inst_implicit()) {
@ -50,31 +83,48 @@ constraint mk_calc_proof_cnstr(environment const & env, local_context const & _c
e_type = tc->whnf(instantiate(binding_body(e_type), imp_arg), new_cs); e_type = tc->whnf(instantiate(binding_body(e_type), imp_arg), new_cs);
} }
justification new_j = mk_type_mismatch_jst(e, e_type, meta_type); auto try_alternative = [&](expr const & e, expr const & e_type) {
if (!tc->is_def_eq(e_type, meta_type, new_j, new_cs)) justification new_j = mk_type_mismatch_jst(e, e_type, meta_type);
throw unifier_exception(new_j, s); constraint_seq fcs = new_cs;
buffer<constraint> cs_buffer; if (!tc->is_def_eq(e_type, meta_type, new_j, fcs))
new_cs.linearize(cs_buffer); throw unifier_exception(new_j, s);
metavar_closure cls(meta); buffer<constraint> cs_buffer;
cls.add(meta_type); fcs.linearize(cs_buffer);
cls.mk_constraints(s, j, relax, cs_buffer); metavar_closure cls(meta);
cs_buffer.push_back(mk_eq_cnstr(meta, e, j, relax)); cls.add(meta_type);
cls.mk_constraints(s, j, relax, cs_buffer);
cs_buffer.push_back(mk_eq_cnstr(meta, e, j, relax));
unifier_config new_cfg(cfg); unifier_config new_cfg(cfg);
new_cfg.m_discard = false; new_cfg.m_discard = false;
unify_result_seq seq = unify(env, cs_buffer.size(), cs_buffer.data(), ngen, substitution(), new_cfg); unify_result_seq seq = unify(env, cs_buffer.size(), cs_buffer.data(), ngen, substitution(), new_cfg);
auto p = seq.pull(); auto p = seq.pull();
lean_assert(p); lean_assert(p);
substitution new_s = p->first.first; substitution new_s = p->first.first;
constraints postponed = map(p->first.second, constraints postponed = map(p->first.second,
[&](constraint const & c) { [&](constraint const & c) {
// we erase internal justifications // we erase internal justifications
return update_justification(c, j); return update_justification(c, j);
}); });
if (im) if (im)
im->instantiate(new_s); im->instantiate(new_s);
constraints r = cls.mk_constraints(new_s, j, relax); constraints r = cls.mk_constraints(new_s, j, relax);
return append(r, postponed); return append(r, postponed);
};
std::unique_ptr<exception> saved_ex;
try {
return try_alternative(e, e_type);
} catch (exception & ex) {
saved_ex.reset(ex.clone());
}
if (auto p = apply_symmetry(env, ctx, ngen, e, e_type, g)) {
try { return try_alternative(p->first, p->second); } catch (exception &) {}
}
saved_ex->rethrow();
lean_unreachable();
}; };
bool owner = false; bool owner = false;
return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::Epilogue), owner, j, relax); return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::Epilogue), owner, j, relax);

View file

@ -85,7 +85,7 @@ void init_token_table(token_table & t) {
"evaluate", "check", "eval", "[priority", "print", "end", "namespace", "section", "import", "evaluate", "check", "eval", "[priority", "print", "end", "namespace", "section", "import",
"inductive", "record", "renaming", "extends", "structure", "module", "universe", "universes", "inductive", "record", "renaming", "extends", "structure", "module", "universe", "universes",
"precedence", "reserve", "infixl", "infixr", "infix", "postfix", "prefix", "notation", "context", "precedence", "reserve", "infixl", "infixr", "infix", "postfix", "prefix", "notation", "context",
"exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "tactic_hint", "exit", "set_option", "open", "export", "calc_subst", "calc_refl", "calc_trans", "calc_symm", "tactic_hint",
"add_begin_end_tactic", "set_begin_end_tactic", "instance", "class", "add_begin_end_tactic", "set_begin_end_tactic", "instance", "class",
"include", "omit", "#erase_cache", "#projections", nullptr}; "include", "omit", "#erase_cache", "#projections", nullptr};

View file

@ -1,4 +1,4 @@
import algebra.category.basic import logic algebra.category.basic
open eq eq.ops category functor natural_transformation open eq eq.ops category functor natural_transformation
variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D} variables {obC obD : Type} {C : category obC} {D : category obD} {F G H : C ⇒ D}
@ -9,6 +9,10 @@ natural_transformation.mk
(λ a b f, calc (λ a b f, calc
H f ∘ (η a ∘ θ a) = (H f ∘ η a) ∘ θ a : assoc H f ∘ (η a ∘ θ a) = (H f ∘ η a) ∘ θ a : assoc
... = (η b ∘ G f) ∘ θ a : {naturality η f} ... = (η b ∘ G f) ∘ θ a : {naturality η f}
... = η b ∘ (G f ∘ θ a) : symm !assoc ... = η b ∘ (G f ∘ θ a) : assoc
... = η b ∘ (θ b ∘ F f) : {naturality θ f} ... = η b ∘ (θ b ∘ F f) : {naturality θ f}
... = (η b ∘ θ b) ∘ F f : assoc) ... = (η b ∘ θ b) ∘ F f : assoc)
theorem tst (a b c : num) (H₁ : ∀ x, b = x) (H₂ : c = b) : a = c :=
calc a = b : H₁
... = c : H₂