fix(library/definitional/equations): allow a function to be the result of a match-with term or recursive definition
This commit is contained in:
parent
61cdec73f6
commit
4edd7b9099
2 changed files with 76 additions and 4 deletions
|
@ -351,6 +351,7 @@ class equation_compiler_fn {
|
|||
}
|
||||
|
||||
[[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
|
||||
[[ noreturn ]] static void throw_error(sstream const & ss, expr const & src) { throw_generic_exception(ss, src); }
|
||||
[[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
|
||||
[[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); }
|
||||
[[ noreturn ]] void throw_error(expr const & src, sstream const & ss) const { throw_generic_exception(ss, src); }
|
||||
|
@ -493,19 +494,66 @@ class equation_compiler_fn {
|
|||
}
|
||||
}
|
||||
|
||||
// Store in \c arities the number of arguments of each function being defined.
|
||||
// This procedure also makes sure that two different equations for the same function
|
||||
// contain the same number of arguments in the left-hand-side.
|
||||
// Remark: after executing this procedure the arity of m_fns[i] is stored in arities[i]
|
||||
// if there is at least one equation for m_fns[i].
|
||||
void initialize_arities(expr const & eqns, buffer<optional<unsigned>> & arities) {
|
||||
lean_assert(arities.empty());
|
||||
buffer<expr> eqs;
|
||||
to_equations(eqns, eqs);
|
||||
lean_assert(!eqs.empty());
|
||||
arities.resize(m_fns.size());
|
||||
for (expr eq : eqs) {
|
||||
if (is_lambda_equation(eq)) {
|
||||
for (expr const & fn : m_fns)
|
||||
eq = instantiate(binding_body(eq), fn);
|
||||
while (is_lambda(eq))
|
||||
eq = binding_body(eq);
|
||||
lean_assert(is_equation(eq));
|
||||
expr const & lhs = equation_lhs(eq);
|
||||
buffer<expr> lhs_args;
|
||||
expr const & lhs_fn = get_app_args(lhs, lhs_args);
|
||||
if (!is_local(lhs_fn))
|
||||
throw_error(sstream() << "invalid recursive equation, "
|
||||
<< "left-hand-side is not one of the functions being defined", eq);
|
||||
unsigned i = 0;
|
||||
for (; i < m_fns.size(); i++) {
|
||||
if (lhs_fn == m_fns[i]) {
|
||||
if (arities[i] && *arities[i] != lhs_args.size())
|
||||
throw_error(sstream() << "invalid recursive equation for '" << lhs_fn << "' "
|
||||
<< "left-hand-side of different equations have different number of arguments", eq);
|
||||
arities[i] = lhs_args.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the variable stack for each function that needs
|
||||
// to be compiled.
|
||||
// This method assumes m_fns has been already initialized.
|
||||
// This method also initialized the buffer prg, but the eqns
|
||||
// field of each program is not initialized by it.
|
||||
void initialize_var_stack(buffer<program> & prgs) {
|
||||
//
|
||||
// See initialize_arities for an explanation for \c arities.
|
||||
void initialize_var_stack(buffer<program> & prgs, buffer<optional<unsigned>> const & arities) {
|
||||
lean_assert(!m_fns.empty());
|
||||
lean_assert(prgs.empty());
|
||||
for (expr const & fn : m_fns) {
|
||||
for (unsigned i = 0; i < m_fns.size(); i++) {
|
||||
expr const & fn = m_fns[i];
|
||||
buffer<expr> args;
|
||||
expr r_type = to_telescope(mlocal_type(fn), args);
|
||||
expr r_type = to_telescope(mlocal_type(fn), args);
|
||||
for (expr & arg : args)
|
||||
arg = update_mlocal(arg, whnf(mlocal_type(arg)));
|
||||
if (arities[i]) {
|
||||
unsigned arity = *arities[i];
|
||||
if (args.size() > arity) {
|
||||
r_type = Pi(args.size() - arity, args.data() + arity, r_type);
|
||||
args.shrink(arity);
|
||||
}
|
||||
}
|
||||
list<expr> ctx = to_list(args);
|
||||
list<optional<name>> vstack = map2<optional<name>>(ctx, [](expr const & e) {
|
||||
return optional<name>(mlocal_name(e));
|
||||
|
@ -579,8 +627,10 @@ class equation_compiler_fn {
|
|||
// Create initial program state for each function being defined.
|
||||
void initialize(expr const & eqns, buffer<program> & prg) {
|
||||
lean_assert(is_equations(eqns));
|
||||
buffer<optional<unsigned>> arities;
|
||||
initialize_fns(eqns);
|
||||
initialize_var_stack(prg);
|
||||
initialize_arities(eqns, arities);
|
||||
initialize_var_stack(prg, arities);
|
||||
buffer<expr> eqs;
|
||||
to_equations(eqns, eqs);
|
||||
buffer<buffer<eqn>> res_eqns;
|
||||
|
|
22
tests/lean/run/match_fun.lean
Normal file
22
tests/lean/run/match_fun.lean
Normal file
|
@ -0,0 +1,22 @@
|
|||
open bool nat
|
||||
|
||||
definition foo (b : bool) : nat → nat :=
|
||||
match b with
|
||||
| tt := λ x : nat, zero
|
||||
| ff := λ y : nat, (succ zero)
|
||||
end
|
||||
|
||||
example : foo tt 1 = zero := rfl
|
||||
example : foo ff 1 = 1 := rfl
|
||||
|
||||
|
||||
definition zero_fn := λ x : nat, zero
|
||||
|
||||
definition foo2 : bool → nat → nat
|
||||
| foo2 tt := succ
|
||||
| foo2 ff := zero_fn
|
||||
|
||||
example : foo2 tt 1 = 2 := rfl
|
||||
example : foo2 tt 2 = 3 := rfl
|
||||
example : foo2 ff 1 = 0 := rfl
|
||||
example : foo2 ff 2 = 0 := rfl
|
Loading…
Reference in a new issue