fix(library/definitional/equations): allow a function to be the result of a match-with term or recursive definition

This commit is contained in:
Leonardo de Moura 2015-03-06 15:08:52 -08:00
parent 61cdec73f6
commit 4edd7b9099
2 changed files with 76 additions and 4 deletions

View file

@ -351,6 +351,7 @@ class equation_compiler_fn {
}
[[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
[[ noreturn ]] static void throw_error(sstream const & ss, expr const & src) { throw_generic_exception(ss, src); }
[[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
[[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); }
[[ noreturn ]] void throw_error(expr const & src, sstream const & ss) const { throw_generic_exception(ss, src); }
@ -493,19 +494,66 @@ class equation_compiler_fn {
}
}
// Store in \c arities the number of arguments of each function being defined.
// This procedure also makes sure that two different equations for the same function
// contain the same number of arguments in the left-hand-side.
// Remark: after executing this procedure the arity of m_fns[i] is stored in arities[i]
// if there is at least one equation for m_fns[i].
void initialize_arities(expr const & eqns, buffer<optional<unsigned>> & arities) {
lean_assert(arities.empty());
buffer<expr> eqs;
to_equations(eqns, eqs);
lean_assert(!eqs.empty());
arities.resize(m_fns.size());
for (expr eq : eqs) {
if (is_lambda_equation(eq)) {
for (expr const & fn : m_fns)
eq = instantiate(binding_body(eq), fn);
while (is_lambda(eq))
eq = binding_body(eq);
lean_assert(is_equation(eq));
expr const & lhs = equation_lhs(eq);
buffer<expr> lhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
if (!is_local(lhs_fn))
throw_error(sstream() << "invalid recursive equation, "
<< "left-hand-side is not one of the functions being defined", eq);
unsigned i = 0;
for (; i < m_fns.size(); i++) {
if (lhs_fn == m_fns[i]) {
if (arities[i] && *arities[i] != lhs_args.size())
throw_error(sstream() << "invalid recursive equation for '" << lhs_fn << "' "
<< "left-hand-side of different equations have different number of arguments", eq);
arities[i] = lhs_args.size();
}
}
}
}
}
// Initialize the variable stack for each function that needs
// to be compiled.
// This method assumes m_fns has been already initialized.
// This method also initialized the buffer prg, but the eqns
// field of each program is not initialized by it.
void initialize_var_stack(buffer<program> & prgs) {
//
// See initialize_arities for an explanation for \c arities.
void initialize_var_stack(buffer<program> & prgs, buffer<optional<unsigned>> const & arities) {
lean_assert(!m_fns.empty());
lean_assert(prgs.empty());
for (expr const & fn : m_fns) {
for (unsigned i = 0; i < m_fns.size(); i++) {
expr const & fn = m_fns[i];
buffer<expr> args;
expr r_type = to_telescope(mlocal_type(fn), args);
for (expr & arg : args)
arg = update_mlocal(arg, whnf(mlocal_type(arg)));
if (arities[i]) {
unsigned arity = *arities[i];
if (args.size() > arity) {
r_type = Pi(args.size() - arity, args.data() + arity, r_type);
args.shrink(arity);
}
}
list<expr> ctx = to_list(args);
list<optional<name>> vstack = map2<optional<name>>(ctx, [](expr const & e) {
return optional<name>(mlocal_name(e));
@ -579,8 +627,10 @@ class equation_compiler_fn {
// Create initial program state for each function being defined.
void initialize(expr const & eqns, buffer<program> & prg) {
lean_assert(is_equations(eqns));
buffer<optional<unsigned>> arities;
initialize_fns(eqns);
initialize_var_stack(prg);
initialize_arities(eqns, arities);
initialize_var_stack(prg, arities);
buffer<expr> eqs;
to_equations(eqns, eqs);
buffer<buffer<eqn>> res_eqns;

View file

@ -0,0 +1,22 @@
open bool nat
definition foo (b : bool) : nat → nat :=
match b with
| tt := λ x : nat, zero
| ff := λ y : nat, (succ zero)
end
example : foo tt 1 = zero := rfl
example : foo ff 1 = 1 := rfl
definition zero_fn := λ x : nat, zero
definition foo2 : bool → nat → nat
| foo2 tt := succ
| foo2 ff := zero_fn
example : foo2 tt 1 = 2 := rfl
example : foo2 tt 2 = 3 := rfl
example : foo2 ff 1 = 0 := rfl
example : foo2 ff 2 = 0 := rfl