feat(library/definitional/equations): brec_on compilation

This commit is contained in:
Leonardo de Moura 2015-01-04 17:45:13 -08:00
parent 98a856373d
commit faf78ce3e6
2 changed files with 238 additions and 21 deletions

View file

@ -20,6 +20,7 @@ Author: Leonardo de Moura
#include "library/annotation.h"
#include "library/util.h"
#include "library/locals.h"
#include "library/normalize.h"
#include "library/tactic/inversion_tactic.h"
namespace lean {
@ -1016,7 +1017,136 @@ class equation_compiler_fn {
return find_rec_args(prgs, 0, arg_pos, arg_types);
}
expr compile_brec_on(buffer<program> & prgs) {
// Auxiliary function object used to eliminate recursive applications using "below" applications
struct elim_rec_apps_fn {
equation_compiler_fn & m_main;
buffer<program> const & m_prgs;
unsigned m_nparams;
buffer<expr> const & m_below_cnsts; // below constants
buffer<expr> const & m_Cs_locals; // auxiliary local constants representing the "motives"
buffer<unsigned> const & m_rec_arg_pos; // position of recursive argument for each program
buffer<buffer<unsigned>> const & m_rest_pos; // position of remaining arguments for each program
elim_rec_apps_fn(equation_compiler_fn & m, buffer<program> const & prgs, unsigned nparams,
buffer<expr> const & below_cnsts, buffer<expr> const & Cs_locals, buffer<unsigned> const & rec_arg_pos,
buffer<buffer<unsigned>> const & rest_pos):
m_main(m), m_prgs(prgs), m_nparams(nparams), m_below_cnsts(below_cnsts), m_Cs_locals(Cs_locals),
m_rec_arg_pos(rec_arg_pos), m_rest_pos(rest_pos) {}
bool is_below_type(expr const & t) const {
expr const & fn = get_app_fn(t);
return is_constant(fn) && std::find(m_below_cnsts.begin(), m_below_cnsts.end(), fn) != m_below_cnsts.end();
}
/** \brief Retrieve \c a from the below dictionary \c d. \c d is a term made of products, and C's from (m_Cs_locals).
\c b is the below constant that was used to create the below dictionary \c d.
*/
optional<expr> to_below(expr const & d, expr const & a, expr const & b) {
expr const & fn = get_app_fn(d);
if (is_constant(fn) && const_name(fn) == "prod") {
if (auto r = to_below(app_arg(app_fn(d)), a, mk_pr1(m_main.m_tc, b)))
return r;
else if (auto r = to_below(app_arg(d), a, mk_pr2(m_main.m_tc, b)))
return r;
else
return none_expr();
} else if (is_constant(fn) && const_name(fn) == "and") {
// For ibelow, we use "and" instead of products
if (auto r = to_below(app_arg(app_fn(d)), a, mk_and_elim_left(m_main.m_tc, b)))
return r;
else if (auto r = to_below(app_arg(d), a, mk_and_elim_right(m_main.m_tc, b)))
return r;
else
return none_expr();
} else if (is_local(fn)) {
for (expr const & C : m_Cs_locals) {
if (mlocal_name(C) == mlocal_name(fn) && app_arg(d) == a)
return some_expr(b);
}
return none_expr();
} else if (is_pi(d)) {
// TODO(Leo)
return none_expr();
} else {
return none_expr();
}
}
expr elim(unsigned prg_idx, buffer<expr> const & args, expr const & below) {
// Replace motives with abstract ones. We use the abstract motives (m_Cs_locals) as "markers"
buffer<expr> below_args;
expr const & below_cnst = get_app_args(mlocal_type(below), below_args);
buffer<expr> abst_below_args;
abst_below_args.append(m_nparams, below_args.data());
abst_below_args.append(m_Cs_locals);
for (unsigned i = m_nparams + m_Cs_locals.size(); i < below_args.size(); i++)
abst_below_args.push_back(below_args[i]);
expr abst_below = mk_app(below_cnst, abst_below_args);
expr below_dict = normalize(m_main.m_tc, abst_below);
expr rec_arg = normalize(m_main.m_tc, args[m_rec_arg_pos[prg_idx]]);
if (optional<expr> b = to_below(below_dict, rec_arg, below)) {
expr r = *b;
for (unsigned rest_pos : m_rest_pos[prg_idx])
r = mk_app(r, args[rest_pos]);
return r;
} else {
m_main.throw_error(sstream() << "failed to compile recursive equations using brec_on approach (possible solution: use well-founded recursion)");
}
}
/** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */
expr elim(expr const & e, optional<expr> const & b) {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Meta:
case expr_kind::Local: case expr_kind::Constant:
case expr_kind::Sort:
return e;
case expr_kind::Macro: {
buffer<expr> new_args;
for (unsigned i = 0; i < macro_num_args(e); i++)
new_args.push_back(elim(macro_arg(e, i), b));
return update_macro(e, new_args.size(), new_args.data());
}
case expr_kind::App: {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
expr new_fn = elim(fn, b);
buffer<expr> new_args;
for (expr const & arg : args)
new_args.push_back(elim(arg, b));
if (is_local(fn) && b) {
for (unsigned j = 0; j < m_prgs.size(); j++) {
if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) {
return elim(j, new_args, *b);
}
}
}
return mk_app(new_fn, new_args);
}
case expr_kind::Lambda: {
expr local = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
expr body = instantiate(binding_body(e), local);
expr new_body;
if (is_below_type(binding_domain(e)))
new_body = elim(body, some_expr(local));
else
new_body = elim(body, b);
return Fun(local, new_body);
}
case expr_kind::Pi: {
expr local = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
expr new_body = elim(instantiate(binding_body(e), local), b);
return Pi(local, new_body);
}}
lean_unreachable();
}
expr operator()(expr const & e) {
return elim(e, none_expr());
}
};
expr compile_brec_on(buffer<program> const & prgs) {
lean_assert(!prgs.empty());
buffer<unsigned> arg_pos;
if (!find_rec_args(prgs, arg_pos)) {
@ -1054,9 +1184,15 @@ class equation_compiler_fn {
unsigned nparams = std::get<1>(t);
list<inductive::inductive_decl> decls = std::get<2>(t);
// TODO(Leo): move parameters to global context.
// we should also check if the user tried to perform pattern matching on parameters
// Distribute parameters of the ith program intro three groups:
// indices, major premise (arg), and remaining arguments (rest)
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
// We store the position of the rest arguments in the buffer rest_pos.
// The buffer rest_pos is used to replace the recursive applications with below applications.
auto distribute_context_core = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest,
buffer<unsigned> & indices_pos, buffer<unsigned> & rest_pos) {
program const & p = prgs[i];
arg = get_rec_arg(i);
list<expr> const & ctx = p.m_context;
@ -1064,24 +1200,62 @@ class equation_compiler_fn {
get_app_args(mlocal_type(arg), arg_args);
lean_assert(nparams <= arg_args.size());
indices.append(arg_args.size() - nparams, arg_args.data() + nparams);
unsigned j = 0;
for (expr const & l : ctx) {
if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices))
if (mlocal_name(l) == mlocal_name(arg)) {
// do nothing
} else if (contains_local(l, indices)) {
indices_pos.push_back(j);
} else {
rest.push_back(l);
rest_pos.push_back(j);
}
j++;
}
};
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
buffer<unsigned> indices_pos, rest_pos;
distribute_context_core(i, indices, arg, rest, indices_pos, rest_pos);
};
// Compute the resulting universe level for brec_on
auto get_brec_on_result_level = [&]() -> level {
buffer<expr> indices, rest;
expr arg;
buffer<expr> indices, rest; expr arg;
distribute_context(0, indices, arg, rest);
expr r_type = Pi(indices, prgs[0].m_type);
expr r_type = Pi(rest, prgs[0].m_type);
return sort_level(m_tc.ensure_type(r_type).first);
};
level rlvl = get_brec_on_result_level();
levels brec_on_lvls = cons(rlvl, const_levels(I0));
expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
bool reflexive = env().prop_proof_irrel() && is_reflexive_datatype(m_tc, const_name(I0));
bool use_ibelow = reflexive && is_zero(rlvl);
if (reflexive) {
if (!is_zero(rlvl) && !is_not_zero(rlvl))
throw_error(sstream() << "invalid recursive equations, when trying to recurse over reflexive inductive datatype, "
<< "the universe level of the resultant universe must be zero OR not zero for every level assignment");
if (!is_zero(rlvl)) {
// For reflexive type, the type of brec_on and ibelow perform a +1 on the motive universe.
// Example: for a reflexive formula type, we have:
// formula.below.{l_1} : Π {C : formula → Type.{l_1+1}}, formula → Type.{max (l_1+1) 1}
if (auto dlvl = dec_level(rlvl)) {
rlvl = *dlvl;
} else {
throw_error(sstream() << "invalid recursive equations, when trying to recurse over reflexive inductive datatype, "
<< "the universe level of the resultant universe must be zero OR not zero for every level assignment, "
<< "the compiler managed to establish that the resultant universe level L is never zero, but fail to comput L-1");
}
}
}
levels brec_on_lvls;
expr brec_on;
if (use_ibelow) {
brec_on_lvls = const_levels(I0);
brec_on = mk_constant(name{const_name(I0), "binduction_on"}, brec_on_lvls);
} else {
brec_on_lvls = cons(rlvl, const_levels(I0));
brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
}
buffer<expr> params;
// add parameters
for (unsigned i = 0; i < nparams; i++) {
@ -1090,12 +1264,19 @@ class equation_compiler_fn {
}
buffer<expr> Cs; // brec_on "motives"
// The following loop fills Cs_locals with auxiliary local constants that will be used to
// convert recursive applications into "below applications".
// These constants are essentially abstracting Cs.
buffer<expr> Cs_locals;
buffer<buffer<expr>> C_args_buffer;
for (inductive::inductive_decl const & decl : decls) {
name const & I_name = inductive::inductive_decl_name(decl);
expr C;
C_args_buffer.push_back(buffer<expr>());
buffer<expr> & C_args = C_args_buffer.back();
expr C_type = whnf(infer_type(brec_on));
expr C_local = mk_local(mk_fresh_name(), "C", C_type, binder_info());
Cs_locals.push_back(C_local);
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
buffer<expr> indices, rest; expr arg;
distribute_context(*p_idx, indices, arg, rest);
@ -1104,8 +1285,7 @@ class equation_compiler_fn {
C_args.push_back(arg);
C = Fun(C_args, type);
} else {
expr type = whnf(infer_type(brec_on));
expr d = binding_domain(type);
expr d = binding_domain(C_type);
expr unit = mk_constant("unit", rlvl);
to_telescope_ext(d, C_args);
C = Fun(C_args, unit);
@ -1114,18 +1294,32 @@ class equation_compiler_fn {
Cs.push_back(C);
}
// add indices and major
buffer<expr> indices0, rest0; expr arg0;
distribute_context(0, indices0, arg0, rest0);
brec_on = mk_app(mk_app(brec_on, indices0), arg0);
// add functionals
unsigned i = 0;
buffer<expr> below_cnsts;
buffer<buffer<unsigned>> rest_arg_pos;
for (inductive::inductive_decl const & decl : decls) {
name const & I_name = inductive::inductive_decl_name(decl);
expr below = mk_constant(name{I_name, "below"}, brec_on_lvls);
below = mk_app(mk_app(below, params), Cs);
expr below_cnst;
if (use_ibelow)
below_cnst = mk_constant(name{I_name, "ibelow"}, brec_on_lvls);
else
below_cnst = mk_constant(name{I_name, "below"}, brec_on_lvls);
below_cnsts.push_back(below_cnst);
expr below = mk_app(mk_app(below_cnst, params), Cs);
expr F;
buffer<expr> & C_args = C_args_buffer[i];
rest_arg_pos.push_back(buffer<unsigned>());
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
program & prg_i = prgs[*p_idx];
buffer<expr> indices, rest; expr arg;
distribute_context(*p_idx, indices, arg, rest);
program const & prg_i = prgs[*p_idx];
buffer<expr> indices, rest; expr arg; buffer<unsigned> indices_pos;
buffer<unsigned> & rest_pos = rest_arg_pos.back();
distribute_context_core(*p_idx, indices, arg, rest, indices_pos, rest_pos);
below = mk_app(mk_app(below, indices), arg);
expr b = mk_local(mk_fresh_name(), "b", below, binder_info());
buffer<expr> new_ctx;
@ -1133,9 +1327,7 @@ class equation_compiler_fn {
new_ctx.push_back(arg);
new_ctx.push_back(b);
new_ctx.append(rest);
prg_i.m_context = to_list(new_ctx);
// TODO(Leo): replace recursive calls with "b" applications
F = compile_pat_match(prg_i);
F = compile_pat_match(program(prg_i, to_list(new_ctx)));
} else {
expr star = mk_constant(name{"unit", "star"}, rlvl);
buffer<expr> F_args;
@ -1147,10 +1339,14 @@ class equation_compiler_fn {
brec_on = mk_app(brec_on, F);
i++;
}
expr r = elim_rec_apps_fn(*this, prgs, nparams, below_cnsts, Cs_locals, arg_pos, rest_arg_pos)(brec_on);
// add remaining arguments
r = mk_app(r, rest0);
// out() << "brec_on: " << brec_on << "\n";
return brec_on;
buffer<expr> ctx0_buffer;
to_buffer(prgs[0].m_context, ctx0_buffer);
r = Fun(ctx0_buffer, r);
return r;
}
expr compile_wf(buffer<program> & /* prgs */) {

21
tests/lean/run/eq4.lean Normal file
View file

@ -0,0 +1,21 @@
open nat
definition half : nat → nat,
half 0 := 0,
half 1 := 0,
half (x+2) := half x + 1
theorem half0 : half 0 = 0 :=
rfl
theorem half1 : half 1 = 0 :=
rfl
theorem half_succ_succ (a : nat) : half (a + 2) = half a + 1 :=
rfl
example : half 5 = 2 :=
rfl
example : half 8 = 4 :=
rfl