feat(library/definitional/equations): brec_on compilation
This commit is contained in:
parent
98a856373d
commit
faf78ce3e6
2 changed files with 238 additions and 21 deletions
|
@ -20,6 +20,7 @@ Author: Leonardo de Moura
|
|||
#include "library/annotation.h"
|
||||
#include "library/util.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/normalize.h"
|
||||
#include "library/tactic/inversion_tactic.h"
|
||||
|
||||
namespace lean {
|
||||
|
@ -1016,7 +1017,136 @@ class equation_compiler_fn {
|
|||
return find_rec_args(prgs, 0, arg_pos, arg_types);
|
||||
}
|
||||
|
||||
expr compile_brec_on(buffer<program> & prgs) {
|
||||
// Auxiliary function object used to eliminate recursive applications using "below" applications
|
||||
struct elim_rec_apps_fn {
|
||||
equation_compiler_fn & m_main;
|
||||
buffer<program> const & m_prgs;
|
||||
unsigned m_nparams;
|
||||
buffer<expr> const & m_below_cnsts; // below constants
|
||||
buffer<expr> const & m_Cs_locals; // auxiliary local constants representing the "motives"
|
||||
buffer<unsigned> const & m_rec_arg_pos; // position of recursive argument for each program
|
||||
buffer<buffer<unsigned>> const & m_rest_pos; // position of remaining arguments for each program
|
||||
|
||||
elim_rec_apps_fn(equation_compiler_fn & m, buffer<program> const & prgs, unsigned nparams,
|
||||
buffer<expr> const & below_cnsts, buffer<expr> const & Cs_locals, buffer<unsigned> const & rec_arg_pos,
|
||||
buffer<buffer<unsigned>> const & rest_pos):
|
||||
m_main(m), m_prgs(prgs), m_nparams(nparams), m_below_cnsts(below_cnsts), m_Cs_locals(Cs_locals),
|
||||
m_rec_arg_pos(rec_arg_pos), m_rest_pos(rest_pos) {}
|
||||
|
||||
bool is_below_type(expr const & t) const {
|
||||
expr const & fn = get_app_fn(t);
|
||||
return is_constant(fn) && std::find(m_below_cnsts.begin(), m_below_cnsts.end(), fn) != m_below_cnsts.end();
|
||||
}
|
||||
|
||||
/** \brief Retrieve \c a from the below dictionary \c d. \c d is a term made of products, and C's from (m_Cs_locals).
|
||||
\c b is the below constant that was used to create the below dictionary \c d.
|
||||
*/
|
||||
optional<expr> to_below(expr const & d, expr const & a, expr const & b) {
|
||||
expr const & fn = get_app_fn(d);
|
||||
if (is_constant(fn) && const_name(fn) == "prod") {
|
||||
if (auto r = to_below(app_arg(app_fn(d)), a, mk_pr1(m_main.m_tc, b)))
|
||||
return r;
|
||||
else if (auto r = to_below(app_arg(d), a, mk_pr2(m_main.m_tc, b)))
|
||||
return r;
|
||||
else
|
||||
return none_expr();
|
||||
} else if (is_constant(fn) && const_name(fn) == "and") {
|
||||
// For ibelow, we use "and" instead of products
|
||||
if (auto r = to_below(app_arg(app_fn(d)), a, mk_and_elim_left(m_main.m_tc, b)))
|
||||
return r;
|
||||
else if (auto r = to_below(app_arg(d), a, mk_and_elim_right(m_main.m_tc, b)))
|
||||
return r;
|
||||
else
|
||||
return none_expr();
|
||||
} else if (is_local(fn)) {
|
||||
for (expr const & C : m_Cs_locals) {
|
||||
if (mlocal_name(C) == mlocal_name(fn) && app_arg(d) == a)
|
||||
return some_expr(b);
|
||||
}
|
||||
return none_expr();
|
||||
} else if (is_pi(d)) {
|
||||
// TODO(Leo)
|
||||
return none_expr();
|
||||
} else {
|
||||
return none_expr();
|
||||
}
|
||||
}
|
||||
|
||||
expr elim(unsigned prg_idx, buffer<expr> const & args, expr const & below) {
|
||||
// Replace motives with abstract ones. We use the abstract motives (m_Cs_locals) as "markers"
|
||||
buffer<expr> below_args;
|
||||
expr const & below_cnst = get_app_args(mlocal_type(below), below_args);
|
||||
buffer<expr> abst_below_args;
|
||||
abst_below_args.append(m_nparams, below_args.data());
|
||||
abst_below_args.append(m_Cs_locals);
|
||||
for (unsigned i = m_nparams + m_Cs_locals.size(); i < below_args.size(); i++)
|
||||
abst_below_args.push_back(below_args[i]);
|
||||
expr abst_below = mk_app(below_cnst, abst_below_args);
|
||||
expr below_dict = normalize(m_main.m_tc, abst_below);
|
||||
expr rec_arg = normalize(m_main.m_tc, args[m_rec_arg_pos[prg_idx]]);
|
||||
if (optional<expr> b = to_below(below_dict, rec_arg, below)) {
|
||||
expr r = *b;
|
||||
for (unsigned rest_pos : m_rest_pos[prg_idx])
|
||||
r = mk_app(r, args[rest_pos]);
|
||||
return r;
|
||||
} else {
|
||||
m_main.throw_error(sstream() << "failed to compile recursive equations using brec_on approach (possible solution: use well-founded recursion)");
|
||||
}
|
||||
}
|
||||
|
||||
/** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */
|
||||
expr elim(expr const & e, optional<expr> const & b) {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Meta:
|
||||
case expr_kind::Local: case expr_kind::Constant:
|
||||
case expr_kind::Sort:
|
||||
return e;
|
||||
case expr_kind::Macro: {
|
||||
buffer<expr> new_args;
|
||||
for (unsigned i = 0; i < macro_num_args(e); i++)
|
||||
new_args.push_back(elim(macro_arg(e, i), b));
|
||||
return update_macro(e, new_args.size(), new_args.data());
|
||||
}
|
||||
case expr_kind::App: {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
expr new_fn = elim(fn, b);
|
||||
buffer<expr> new_args;
|
||||
for (expr const & arg : args)
|
||||
new_args.push_back(elim(arg, b));
|
||||
if (is_local(fn) && b) {
|
||||
for (unsigned j = 0; j < m_prgs.size(); j++) {
|
||||
if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) {
|
||||
return elim(j, new_args, *b);
|
||||
}
|
||||
}
|
||||
}
|
||||
return mk_app(new_fn, new_args);
|
||||
}
|
||||
case expr_kind::Lambda: {
|
||||
expr local = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
|
||||
expr body = instantiate(binding_body(e), local);
|
||||
expr new_body;
|
||||
if (is_below_type(binding_domain(e)))
|
||||
new_body = elim(body, some_expr(local));
|
||||
else
|
||||
new_body = elim(body, b);
|
||||
return Fun(local, new_body);
|
||||
}
|
||||
case expr_kind::Pi: {
|
||||
expr local = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
|
||||
expr new_body = elim(instantiate(binding_body(e), local), b);
|
||||
return Pi(local, new_body);
|
||||
}}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
expr operator()(expr const & e) {
|
||||
return elim(e, none_expr());
|
||||
}
|
||||
};
|
||||
|
||||
expr compile_brec_on(buffer<program> const & prgs) {
|
||||
lean_assert(!prgs.empty());
|
||||
buffer<unsigned> arg_pos;
|
||||
if (!find_rec_args(prgs, arg_pos)) {
|
||||
|
@ -1054,9 +1184,15 @@ class equation_compiler_fn {
|
|||
unsigned nparams = std::get<1>(t);
|
||||
list<inductive::inductive_decl> decls = std::get<2>(t);
|
||||
|
||||
// TODO(Leo): move parameters to global context.
|
||||
// we should also check if the user tried to perform pattern matching on parameters
|
||||
|
||||
// Distribute parameters of the ith program intro three groups:
|
||||
// indices, major premise (arg), and remaining arguments (rest)
|
||||
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
|
||||
// We store the position of the rest arguments in the buffer rest_pos.
|
||||
// The buffer rest_pos is used to replace the recursive applications with below applications.
|
||||
auto distribute_context_core = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest,
|
||||
buffer<unsigned> & indices_pos, buffer<unsigned> & rest_pos) {
|
||||
program const & p = prgs[i];
|
||||
arg = get_rec_arg(i);
|
||||
list<expr> const & ctx = p.m_context;
|
||||
|
@ -1064,24 +1200,62 @@ class equation_compiler_fn {
|
|||
get_app_args(mlocal_type(arg), arg_args);
|
||||
lean_assert(nparams <= arg_args.size());
|
||||
indices.append(arg_args.size() - nparams, arg_args.data() + nparams);
|
||||
unsigned j = 0;
|
||||
for (expr const & l : ctx) {
|
||||
if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices))
|
||||
if (mlocal_name(l) == mlocal_name(arg)) {
|
||||
// do nothing
|
||||
} else if (contains_local(l, indices)) {
|
||||
indices_pos.push_back(j);
|
||||
} else {
|
||||
rest.push_back(l);
|
||||
rest_pos.push_back(j);
|
||||
}
|
||||
j++;
|
||||
}
|
||||
};
|
||||
|
||||
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
|
||||
buffer<unsigned> indices_pos, rest_pos;
|
||||
distribute_context_core(i, indices, arg, rest, indices_pos, rest_pos);
|
||||
};
|
||||
|
||||
// Compute the resulting universe level for brec_on
|
||||
auto get_brec_on_result_level = [&]() -> level {
|
||||
buffer<expr> indices, rest;
|
||||
expr arg;
|
||||
buffer<expr> indices, rest; expr arg;
|
||||
distribute_context(0, indices, arg, rest);
|
||||
expr r_type = Pi(indices, prgs[0].m_type);
|
||||
expr r_type = Pi(rest, prgs[0].m_type);
|
||||
return sort_level(m_tc.ensure_type(r_type).first);
|
||||
};
|
||||
|
||||
level rlvl = get_brec_on_result_level();
|
||||
levels brec_on_lvls = cons(rlvl, const_levels(I0));
|
||||
expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
|
||||
bool reflexive = env().prop_proof_irrel() && is_reflexive_datatype(m_tc, const_name(I0));
|
||||
bool use_ibelow = reflexive && is_zero(rlvl);
|
||||
if (reflexive) {
|
||||
if (!is_zero(rlvl) && !is_not_zero(rlvl))
|
||||
throw_error(sstream() << "invalid recursive equations, when trying to recurse over reflexive inductive datatype, "
|
||||
<< "the universe level of the resultant universe must be zero OR not zero for every level assignment");
|
||||
if (!is_zero(rlvl)) {
|
||||
// For reflexive type, the type of brec_on and ibelow perform a +1 on the motive universe.
|
||||
// Example: for a reflexive formula type, we have:
|
||||
// formula.below.{l_1} : Π {C : formula → Type.{l_1+1}}, formula → Type.{max (l_1+1) 1}
|
||||
if (auto dlvl = dec_level(rlvl)) {
|
||||
rlvl = *dlvl;
|
||||
} else {
|
||||
throw_error(sstream() << "invalid recursive equations, when trying to recurse over reflexive inductive datatype, "
|
||||
<< "the universe level of the resultant universe must be zero OR not zero for every level assignment, "
|
||||
<< "the compiler managed to establish that the resultant universe level L is never zero, but fail to comput L-1");
|
||||
}
|
||||
}
|
||||
}
|
||||
levels brec_on_lvls;
|
||||
expr brec_on;
|
||||
if (use_ibelow) {
|
||||
brec_on_lvls = const_levels(I0);
|
||||
brec_on = mk_constant(name{const_name(I0), "binduction_on"}, brec_on_lvls);
|
||||
} else {
|
||||
brec_on_lvls = cons(rlvl, const_levels(I0));
|
||||
brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
|
||||
}
|
||||
buffer<expr> params;
|
||||
// add parameters
|
||||
for (unsigned i = 0; i < nparams; i++) {
|
||||
|
@ -1090,12 +1264,19 @@ class equation_compiler_fn {
|
|||
}
|
||||
|
||||
buffer<expr> Cs; // brec_on "motives"
|
||||
// The following loop fills Cs_locals with auxiliary local constants that will be used to
|
||||
// convert recursive applications into "below applications".
|
||||
// These constants are essentially abstracting Cs.
|
||||
buffer<expr> Cs_locals;
|
||||
buffer<buffer<expr>> C_args_buffer;
|
||||
for (inductive::inductive_decl const & decl : decls) {
|
||||
name const & I_name = inductive::inductive_decl_name(decl);
|
||||
expr C;
|
||||
C_args_buffer.push_back(buffer<expr>());
|
||||
buffer<expr> & C_args = C_args_buffer.back();
|
||||
expr C_type = whnf(infer_type(brec_on));
|
||||
expr C_local = mk_local(mk_fresh_name(), "C", C_type, binder_info());
|
||||
Cs_locals.push_back(C_local);
|
||||
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
||||
buffer<expr> indices, rest; expr arg;
|
||||
distribute_context(*p_idx, indices, arg, rest);
|
||||
|
@ -1104,8 +1285,7 @@ class equation_compiler_fn {
|
|||
C_args.push_back(arg);
|
||||
C = Fun(C_args, type);
|
||||
} else {
|
||||
expr type = whnf(infer_type(brec_on));
|
||||
expr d = binding_domain(type);
|
||||
expr d = binding_domain(C_type);
|
||||
expr unit = mk_constant("unit", rlvl);
|
||||
to_telescope_ext(d, C_args);
|
||||
C = Fun(C_args, unit);
|
||||
|
@ -1114,18 +1294,32 @@ class equation_compiler_fn {
|
|||
Cs.push_back(C);
|
||||
}
|
||||
|
||||
// add indices and major
|
||||
buffer<expr> indices0, rest0; expr arg0;
|
||||
distribute_context(0, indices0, arg0, rest0);
|
||||
brec_on = mk_app(mk_app(brec_on, indices0), arg0);
|
||||
|
||||
// add functionals
|
||||
unsigned i = 0;
|
||||
buffer<expr> below_cnsts;
|
||||
buffer<buffer<unsigned>> rest_arg_pos;
|
||||
for (inductive::inductive_decl const & decl : decls) {
|
||||
name const & I_name = inductive::inductive_decl_name(decl);
|
||||
expr below = mk_constant(name{I_name, "below"}, brec_on_lvls);
|
||||
below = mk_app(mk_app(below, params), Cs);
|
||||
expr below_cnst;
|
||||
if (use_ibelow)
|
||||
below_cnst = mk_constant(name{I_name, "ibelow"}, brec_on_lvls);
|
||||
else
|
||||
below_cnst = mk_constant(name{I_name, "below"}, brec_on_lvls);
|
||||
below_cnsts.push_back(below_cnst);
|
||||
expr below = mk_app(mk_app(below_cnst, params), Cs);
|
||||
expr F;
|
||||
buffer<expr> & C_args = C_args_buffer[i];
|
||||
rest_arg_pos.push_back(buffer<unsigned>());
|
||||
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
||||
program & prg_i = prgs[*p_idx];
|
||||
buffer<expr> indices, rest; expr arg;
|
||||
distribute_context(*p_idx, indices, arg, rest);
|
||||
program const & prg_i = prgs[*p_idx];
|
||||
buffer<expr> indices, rest; expr arg; buffer<unsigned> indices_pos;
|
||||
buffer<unsigned> & rest_pos = rest_arg_pos.back();
|
||||
distribute_context_core(*p_idx, indices, arg, rest, indices_pos, rest_pos);
|
||||
below = mk_app(mk_app(below, indices), arg);
|
||||
expr b = mk_local(mk_fresh_name(), "b", below, binder_info());
|
||||
buffer<expr> new_ctx;
|
||||
|
@ -1133,9 +1327,7 @@ class equation_compiler_fn {
|
|||
new_ctx.push_back(arg);
|
||||
new_ctx.push_back(b);
|
||||
new_ctx.append(rest);
|
||||
prg_i.m_context = to_list(new_ctx);
|
||||
// TODO(Leo): replace recursive calls with "b" applications
|
||||
F = compile_pat_match(prg_i);
|
||||
F = compile_pat_match(program(prg_i, to_list(new_ctx)));
|
||||
} else {
|
||||
expr star = mk_constant(name{"unit", "star"}, rlvl);
|
||||
buffer<expr> F_args;
|
||||
|
@ -1147,10 +1339,14 @@ class equation_compiler_fn {
|
|||
brec_on = mk_app(brec_on, F);
|
||||
i++;
|
||||
}
|
||||
expr r = elim_rec_apps_fn(*this, prgs, nparams, below_cnsts, Cs_locals, arg_pos, rest_arg_pos)(brec_on);
|
||||
// add remaining arguments
|
||||
r = mk_app(r, rest0);
|
||||
|
||||
// out() << "brec_on: " << brec_on << "\n";
|
||||
|
||||
return brec_on;
|
||||
buffer<expr> ctx0_buffer;
|
||||
to_buffer(prgs[0].m_context, ctx0_buffer);
|
||||
r = Fun(ctx0_buffer, r);
|
||||
return r;
|
||||
}
|
||||
|
||||
expr compile_wf(buffer<program> & /* prgs */) {
|
||||
|
|
21
tests/lean/run/eq4.lean
Normal file
21
tests/lean/run/eq4.lean
Normal file
|
@ -0,0 +1,21 @@
|
|||
open nat
|
||||
|
||||
definition half : nat → nat,
|
||||
half 0 := 0,
|
||||
half 1 := 0,
|
||||
half (x+2) := half x + 1
|
||||
|
||||
theorem half0 : half 0 = 0 :=
|
||||
rfl
|
||||
|
||||
theorem half1 : half 1 = 0 :=
|
||||
rfl
|
||||
|
||||
theorem half_succ_succ (a : nat) : half (a + 2) = half a + 1 :=
|
||||
rfl
|
||||
|
||||
example : half 5 = 2 :=
|
||||
rfl
|
||||
|
||||
example : half 8 = 4 :=
|
||||
rfl
|
Loading…
Reference in a new issue