fix(library/definitional/equations): fixes #541

This commit allows recursive applications to have less or more arguments
than the equation left-hand-side.
We add two tests
   - 541a.lean  recursive call with more arguments
   - 542b.lean  recursive call with less arguments
This commit is contained in:
Leonardo de Moura 2015-05-10 20:37:44 -07:00
parent 8f711873a0
commit c5fb3ec6d0
3 changed files with 147 additions and 42 deletions

View file

@ -15,6 +15,7 @@ Author: Leonardo de Moura
#include "kernel/error_msgs.h"
#include "kernel/for_each_fn.h"
#include "kernel/find_fn.h"
#include "kernel/replace_fn.h"
#include "library/generic_exception.h"
#include "library/kernel_serializer.h"
#include "library/io_state_stream.h"
@ -1298,6 +1299,9 @@ class equation_compiler_fn {
return is_constant(fn) && std::find(m_below_cnsts.begin(), m_below_cnsts.end(), fn) != m_below_cnsts.end();
}
/** \brief Return the number of arguments in the left-hand-side of program prg_idx */
unsigned get_lhs_size(unsigned prg_idx) const { return length(m_prgs[prg_idx].m_context); }
/** \brief Retrieve \c a from the below dictionary \c d. \c d is a term made of products, and C's from (m_Cs_locals).
\c b is the below constant that was used to create the below dictionary \c d.
*/
@ -1344,13 +1348,19 @@ class equation_compiler_fn {
abst_below_args.append(m_Cs_locals);
for (unsigned i = m_nparams + m_Cs_locals.size(); i < below_args.size(); i++)
abst_below_args.push_back(below_args[i]);
expr abst_below = mk_app(below_cnst, abst_below_args);
expr below_dict = normalize(m_main.m_tc, abst_below);
expr rec_arg = normalize(m_main.m_tc, args[m_rec_arg_pos[prg_idx]]);
expr abst_below = mk_app(below_cnst, abst_below_args);
expr below_dict = normalize(m_main.m_tc, abst_below);
expr rec_arg = normalize(m_main.m_tc, args[m_rec_arg_pos[prg_idx]]);
unsigned lhs_size = get_lhs_size(prg_idx);
if (optional<expr> b = to_below(below_dict, rec_arg, below)) {
expr r = *b;
for (unsigned rest_pos : m_rest_pos[prg_idx])
r = mk_app(r, args[rest_pos], g);
for (unsigned rest_pos : m_rest_pos[prg_idx]) {
if (rest_pos < args.size())
r = mk_app(r, args[rest_pos], g);
}
for (unsigned i = lhs_size; i < args.size(); i++) {
r = mk_app(r, args[i], g);
}
return r;
} else {
m_main.throw_error(sstream() << "failed to compile recursive equations using "
@ -1411,6 +1421,35 @@ class equation_compiler_fn {
}
};
// Fix the i-th argument in the Pi-type t
expr fix_fn_type(expr const & t, unsigned i, expr const & p) {
if (!is_pi(t)) {
throw_error(sstream() << "invalid recursive equation, failed to move parameter '" << p << "'");
} else if (i == 0) {
return instantiate(binding_body(t), p);
} else {
expr local = mk_local(mk_fresh_name(), binding_name(t), binding_domain(t), binding_info(t));
expr body = fix_fn_type(instantiate(binding_body(t), local), i-1, p);
return Pi(local, body);
}
}
// For each function application (fn ...) in e, replace it with (new_fn ...) and remove the i-th
// argument.
expr fix_rec_arg(expr const & fn, expr const & new_fn, unsigned i, expr const & e) {
return ::lean::replace(e, [&](expr const & e) {
if (is_app(e) && get_app_fn(e) == fn) {
buffer<expr> args;
get_app_args(e, args);
if (i < args.size())
args.erase(i);
return some_expr(mk_app(new_fn, args));
} else {
return none_expr();
}
});
}
// Move inductive datatype parameters occuring in prg to m_additional_context
pair<program, unsigned> move_params(program const & prg, unsigned arg_pos) {
expr const & a_type = mlocal_type(get_ith(prg.m_context, arg_pos));
@ -1428,34 +1467,40 @@ class equation_compiler_fn {
buffer<eqn> new_eqns;
to_buffer(prg.m_var_stack, new_var_stack);
to_buffer(prg.m_eqns, new_eqns);
expr new_fn = prg.m_fn;
for (expr const & param : params) {
if (!contains_local(param, m_global_context)) {
m_additional_context.push_back(param);
new_context = remove(new_context, param);
unsigned i = 0;
for (; i < new_var_stack.size(); i++) {
if (*new_var_stack[i] == mlocal_name(param))
break;
}
lean_assert(i < new_var_stack.size());
lean_assert(i != arg_pos);
if (i < arg_pos)
arg_pos--;
new_var_stack.erase(i);
for (eqn & e : new_eqns) {
expr const & p = get_ith(e.m_patterns, i);
if (!is_local(p)) {
throw_error(sstream() << "invalid recursive equations, "
<< "trying to pattern match inductive datatype parameter '" << p << "'");
} else {
list<expr> new_local_ctx = remove(e.m_local_context, p);
list<expr> new_patterns = ::lean::remove(e.m_patterns, p);
e = replace(eqn(e, new_local_ctx, new_patterns), p, param);
}
if (contains_local(param, m_global_context))
continue; // parameter doesn't need to be moved
m_additional_context.push_back(param);
new_context = remove(new_context, param);
unsigned i = 0;
for (; i < new_var_stack.size(); i++) {
if (*new_var_stack[i] == mlocal_name(param))
break;
}
lean_assert(i < new_var_stack.size());
lean_assert(i != arg_pos);
expr new_fn_type = fix_fn_type(mlocal_type(new_fn), i, param);
expr new_new_fn = update_mlocal(new_fn, new_fn_type);
if (i < arg_pos)
arg_pos--;
new_var_stack.erase(i);
for (eqn & e : new_eqns) {
expr const & p = get_ith(e.m_patterns, i);
if (!is_local(p)) {
throw_error(sstream() << "invalid recursive equations, "
<< "trying to pattern match inductive datatype parameter '" << p << "'");
} else {
list<expr> new_local_ctx = remove(e.m_local_context, p);
list<expr> new_patterns = ::lean::remove(e.m_patterns, p);
expr new_rhs = fix_rec_arg(new_fn, new_new_fn, i, e.m_rhs);
e = replace(eqn(new_local_ctx, new_patterns, new_rhs), p, param);
}
}
new_fn = new_new_fn;
}
return mk_pair(program(prg, new_context, to_list(new_var_stack), to_list(new_eqns)), arg_pos);
return mk_pair(program(new_fn, new_context, to_list(new_var_stack), to_list(new_eqns), prg.m_type), arg_pos);
}
}
@ -1471,8 +1516,7 @@ class equation_compiler_fn {
lean_assert(check_program(prgs[0]));
}
expr compile_brec_on_core(buffer<program> const & prgs,
buffer<unsigned> const & arg_pos, buffer<unsigned> const & rec_arg_pos) {
expr compile_brec_on_core(buffer<program> const & prgs, buffer<unsigned> const & arg_pos) {
// Return the recursive argument of the i-th program
auto get_rec_arg = [&](unsigned i) -> expr {
program const & pi = prgs[i];
@ -1669,7 +1713,7 @@ class equation_compiler_fn {
brec_on = mk_app(brec_on, F);
i++;
}
expr r = elim_rec_apps_fn(*this, prgs, nparams, below_cnsts, Cs_locals, rec_arg_pos, rest_arg_pos)(brec_on);
expr r = elim_rec_apps_fn(*this, prgs, nparams, below_cnsts, Cs_locals, arg_pos, rest_arg_pos)(brec_on);
// add remaining arguments
r = mk_app(r, rest0);
@ -1681,17 +1725,12 @@ class equation_compiler_fn {
expr compile_brec_on(buffer<program> & prgs) {
lean_assert(!prgs.empty());
buffer<unsigned> rec_arg_pos;
if (!find_rec_args(prgs, rec_arg_pos)) {
buffer<unsigned> arg_pos;
if (!find_rec_args(prgs, arg_pos)) {
throw_error(sstream() << "invalid recursive equations, "
<< "failed to find recursive arguments that are structurally smaller "
<< "(possible solution: use well-founded recursion)");
}
// Remark: move_params updates argument positions.
// Thus, we copy rec_arg_pos to arg_pos.
// We use rec_arg_pos when invoking elim_rec_apps_fn
buffer<unsigned> arg_pos;
arg_pos.append(rec_arg_pos);
move_params(prgs, arg_pos);
buffer<expr> rs;
for (unsigned i = 0; i < prgs.size(); i++) {
@ -1702,11 +1741,9 @@ class equation_compiler_fn {
// This is wasteful. We should rewrite this.
std::swap(prgs[0], prgs[i]);
std::swap(arg_pos[0], arg_pos[i]);
std::swap(rec_arg_pos[0], rec_arg_pos[i]);
rs.push_back(compile_brec_on_core(prgs, arg_pos, rec_arg_pos));
rs.push_back(compile_brec_on_core(prgs, arg_pos));
std::swap(prgs[0], prgs[i]);
std::swap(arg_pos[0], arg_pos[i]);
std::swap(rec_arg_pos[0], rec_arg_pos[i]);
}
if (rs.size() > 1)

25
tests/lean/run/541a.lean Normal file
View file

@ -0,0 +1,25 @@
import data.list data.nat
open nat list eq.ops
theorem nat.le_of_eq {x y : } (H : x = y) : x ≤ y := H ▸ !le.refl
section
variable {Q : Type}
definition f : list Q → -- default if l is empty, else max l
| [] := 0
| (h :: t) := f t + 1
theorem f_foo : ∀{l : list Q}, ∀{q : Q}, q ∈ l → f l ≥ 1
| [] := take q, assume Hq, absurd Hq !not_mem_nil
| [h] := take q, assume Hq, nat.le_of_eq !rfl
| (h :: (h' :: t)) := take q, assume Hq,
have Hor : q = h q ∈ (h' :: t), from iff.mp !mem_cons_iff Hq,
have H : f (h' :: t) ≥ 1, from f_foo (mem_cons h' t),
have H1 : 1 + 1 ≤ f (h' :: t) + 1, from nat.add_le_add_right H 1,
calc
f (h :: h' :: t) = f (h' :: t) + 1 : rfl
... ≥ 1 + 1 : H1
... = 1 : sorry
end

43
tests/lean/run/541b.lean Normal file
View file

@ -0,0 +1,43 @@
import data.list
inductive typ : Type :=
| nat : typ
| arr : typ → typ → typ
inductive const : Type :=
| z | s
inductive exp : Type :=
| var : nat → exp
| cnst : const → exp
| lam : nat → typ → exp → exp
| ap : exp → exp → exp
open exp
inductive is_val : exp → Prop :=
| vcnst : Π c, is_val (cnst c)
| vlam : Π x t e, is_val (lam x t e)
| vcnst_ap : Π {e} c, is_val e → is_val (ap (cnst c) e)
inductive step : exp → exp → Prop :=
infix `➤`:50 := step
| stepl : Π {e1 e1'} e2, e1 ➤ e1' → ap e1 e2 ➤ ap e1' e2
| stepr : Π {e1 e2 e2'}, is_val e1 → e2 ➤ e2' → ap e1 e2 ➤ ap e1 e2'
| subst : Π {x e1 e1' e2} t, is_val e2 → ap (lam x t e1) e2 ➤ e1'
infix `➤`:50 := step
open is_val
open step
theorem nostep : ∀ {e} e', is_val e → e ➤ e' → false
| nostep e' (vcnst c) Hsteps := by cases Hsteps
| nostep e' (vlam x t e) Hsteps := by cases Hsteps
| nostep (ap e' e) (@vcnst_ap e c Hval) (stepl e Hbad) :=
have Hvalc : is_val (cnst c), from vcnst c,
have IH : not (cnst c ➤ e'), from nostep e' Hvalc,
absurd Hbad IH
| nostep (ap (cnst c) e') (@vcnst_ap e c Hvale) (stepr Hvalc Hbad) :=
have IH : not (e ➤ e'), from nostep e' Hvale,
absurd Hbad IH