feat(library/definitional/equations): add brec_on based compilation

This commit is contained in:
Leonardo de Moura 2015-01-03 22:23:37 -08:00
parent 42354cd4fa
commit bdfa919098

View file

@ -19,6 +19,7 @@ Author: Leonardo de Moura
#include "library/io_state_stream.h"
#include "library/annotation.h"
#include "library/util.h"
#include "library/locals.h"
#include "library/tactic/inversion_tactic.h"
namespace lean {
@ -338,23 +339,16 @@ class equation_compiler_fn {
}
};
// For debugging purposes
template<typename T>
static bool contains_local(T const & locals, expr const & l) {
return std::any_of(locals.begin(), locals.end(),
[&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); });
}
#ifdef LEAN_DEBUG
// For debugging purposes: checks whether all local constants occurring in \c e
// are in local_ctx or m_global_context
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
for_each(e, [&](expr const & e, unsigned) {
if (is_local(e) &&
!(contains_local(local_context, e) ||
contains_local(context, e) ||
contains_local(m_global_context, e) ||
contains_local(m_fns, e))) {
!(contains_local(e, local_context) ||
contains_local(e, context) ||
contains_local(e, m_global_context) ||
contains_local(e, m_fns))) {
lean_unreachable();
return false;
}
@ -617,10 +611,6 @@ class equation_compiler_fn {
return r;
}
static bool contains(list<expr> const & local_ctx, expr const & p) {
return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); });
}
expr compile_skip(program const & prg) {
lean_assert(!head(prg.m_var_stack));
auto new_stack = tail(prg.m_var_stack);
@ -646,7 +636,7 @@ class equation_compiler_fn {
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
} else {
lean_assert(is_local(p));
if (contains(e.m_local_context, p)) {
if (contains_local(p, e.m_local_context)) {
list<expr> new_local_ctx = e.m_local_context;
new_local_ctx = remove(new_local_ctx, p);
new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); });
@ -796,7 +786,7 @@ class equation_compiler_fn {
}
}
expr compile(program const & p) {
expr compile_pat_match(program const & p) {
buffer<expr> vars;
to_buffer(p.m_context, vars);
if (!is_proof_irrelevant()) {
@ -812,11 +802,7 @@ class equation_compiler_fn {
/** \brief Return true iff \c e is one of the functions being defined */
bool is_fn(expr const & e) const {
return
is_local(e) &&
std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) {
return mlocal_name(fn) == mlocal_name(e);
});
return is_local(e) && contains_local(e, m_fns);
}
/** \brief Return true iff the equations are recursive. */
@ -833,18 +819,39 @@ class equation_compiler_fn {
return false;
}
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/
bool is_inductive(expr const & t) const {
expr const & fn = get_app_fn(t);
return
is_constant(fn) &&
env().find(name{const_name(fn), "brec_on"});
/** \brief Return true if all locals are distinct local constants. */
static bool all_distinct_locals(unsigned num, expr const * locals) {
for (unsigned i = 0; i < num; i++) {
if (!is_local(locals[i]))
return false;
if (contains_local(locals[i], locals, locals + i))
return false;
}
return true;
}
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition,
and all indices of t are in ctx. */
bool is_rec_inductive(list<expr> const & ctx, expr const & t) const {
expr const & I = get_app_fn(t);
if (is_constant(I) && env().find(name{const_name(I), "brec_on"})) {
unsigned nindices = *inductive::get_num_indices(env(), const_name(I));
if (nindices > 0) {
buffer<expr> args;
get_app_args(I, args);
return
all_distinct_locals(nindices, args.end() - nindices) &&
std::all_of(args.end() - nindices, args.end(), [&](expr const & idx) { return contains_local(idx, ctx); });
} else {
return true;
}
} else {
return false;
}
}
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
bool is_compatible_inductive(expr const & t1, expr const & t2) {
lean_assert(is_inductive(t1));
lean_assert(is_inductive(t2));
buffer<expr> args1, args2;
name const & I1 = const_name(get_app_args(t1, args1));
name const & I2 = const_name(get_app_args(t2, args2));
@ -986,7 +993,7 @@ class equation_compiler_fn {
expr const & v = p.get_var(*n);
expr const & t = mlocal_type(v);
if (// argument must be an inductive datatype
is_inductive(t) &&
is_rec_inductive(p.m_context, t) &&
// argument must be an inductive datatype different from the ones in arg_types
std::all_of(arg_types.begin(), arg_types.end(),
[&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) &&
@ -1011,23 +1018,149 @@ class equation_compiler_fn {
return find_rec_args(prgs, 0, arg_pos, arg_types);
}
void apply_brec_on(buffer<program> & prgs) {
expr compile_brec_on(buffer<program> & prgs) {
lean_assert(!prgs.empty());
buffer<unsigned> arg_pos;
if (!find_rec_args(prgs, arg_pos)) {
throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller "
<< "(possible solution: use well-founded recursion)");
}
// out() << "Found recursive arguments: ";
// for (unsigned p : arg_pos) out() << " " << p; out() << "\n";
// TODO(Leo):
// Return the recursive argument of the i-th program
auto get_rec_arg = [&](unsigned i) -> expr {
program const & pi = prgs[i];
return get_ith(pi.m_context, arg_pos[i]);
};
// Return the type of the recursive argument of the i-th program
auto get_rec_type = [&](unsigned i) -> expr {
return mlocal_type(get_rec_arg(i));
};
// Return the program associated with the inductive datatype named I_name.
// Return none if there isn't one.
auto get_prg_for = [&](name const & I_name) -> optional<unsigned> {
for (unsigned i = 0; i < prgs.size(); i++) {
expr const & t = get_rec_type(i);
if (const_name(get_app_fn(t)) == I_name)
return optional<unsigned>(i);
}
return optional<unsigned>();
};
expr const & a0_type = get_rec_type(0);
lean_assert(is_rec_inductive(prgs[0].m_context, a0_type));
buffer<expr> a0_type_args;
expr const & I0 = get_app_args(a0_type, a0_type_args);
inductive::inductive_decls t = *inductive::is_inductive_decl(env(), const_name(I0));
unsigned nparams = std::get<1>(t);
list<inductive::inductive_decl> decls = std::get<2>(t);
// Distribute parameters of the ith program intro three groups:
// indices, major premise (arg), and remaining arguments (rest)
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
program const & p = prgs[i];
arg = get_rec_arg(i);
list<expr> const & ctx = p.m_context;
buffer<expr> arg_args;
get_app_args(mlocal_type(arg), arg_args);
lean_assert(nparams <= arg_args.size());
indices.append(arg_args.size() - nparams, arg_args.data() + nparams);
for (expr const & l : ctx) {
if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices))
rest.push_back(l);
}
};
// Compute the resulting universe level for brec_on
auto get_brec_on_result_level = [&]() -> level {
buffer<expr> indices, rest;
expr arg;
distribute_context(0, indices, arg, rest);
expr r_type = Pi(indices, prgs[0].m_type);
return sort_level(m_tc.ensure_type(r_type).first);
};
level rlvl = get_brec_on_result_level();
levels brec_on_lvls = cons(rlvl, const_levels(I0));
expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
buffer<expr> params;
// add parameters
for (unsigned i = 0; i < nparams; i++) {
params.push_back(a0_type_args[i]);
brec_on = mk_app(brec_on, a0_type_args[i]);
}
buffer<expr> Cs; // brec_on "motives"
buffer<buffer<expr>> C_args_buffer;
for (inductive::inductive_decl const & decl : decls) {
name const & I_name = inductive::inductive_decl_name(decl);
expr C;
C_args_buffer.push_back(buffer<expr>());
buffer<expr> & C_args = C_args_buffer.back();
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
buffer<expr> indices, rest; expr arg;
distribute_context(*p_idx, indices, arg, rest);
expr type = Pi(rest, prgs[*p_idx].m_type);
C_args.append(indices);
C_args.push_back(arg);
C = Fun(C_args, type);
} else {
expr type = whnf(infer_type(brec_on));
expr d = binding_domain(type);
expr unit = mk_constant("unit", rlvl);
to_telescope_ext(d, C_args);
C = Fun(C_args, unit);
}
brec_on = mk_app(brec_on, C);
Cs.push_back(C);
}
// add functionals
unsigned i = 0;
for (inductive::inductive_decl const & decl : decls) {
name const & I_name = inductive::inductive_decl_name(decl);
expr below = mk_constant(name{I_name, "below"}, brec_on_lvls);
below = mk_app(mk_app(below, params), Cs);
expr F;
buffer<expr> & C_args = C_args_buffer[i];
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
program & prg_i = prgs[*p_idx];
buffer<expr> indices, rest; expr arg;
distribute_context(*p_idx, indices, arg, rest);
below = mk_app(mk_app(below, indices), arg);
expr b = mk_local(mk_fresh_name(), "b", below, binder_info());
buffer<expr> new_ctx;
new_ctx.append(indices);
new_ctx.push_back(arg);
new_ctx.push_back(b);
new_ctx.append(rest);
prg_i.m_context = to_list(new_ctx);
// TODO(Leo): replace recursive calls with "b" applications
F = compile_pat_match(prg_i);
} else {
expr star = mk_constant(name{"unit", "star"}, rlvl);
buffer<expr> F_args;
F_args.append(C_args);
below = mk_app(below, F_args);
F_args.push_back(mk_local(mk_fresh_name(), "b", below, binder_info()));
F = Fun(F_args, star);
}
brec_on = mk_app(brec_on, F);
i++;
}
// out() << "brec_on: " << brec_on << "\n";
return brec_on;
}
void apply_wf(buffer<program> & /* prgs */) {
expr compile_wf(buffer<program> & /* prgs */) {
// TODO(Leo)
return expr();
}
public:
equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax):
m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) {
@ -1041,21 +1174,14 @@ public:
initialize(eqns, prgs);
if (is_recursive(prgs)) {
if (is_wf_equations(eqns)) {
apply_wf(prgs);
return compile_wf(prgs);
} else {
apply_brec_on(prgs);
return compile_brec_on(prgs);
}
} else {
lean_assert(prgs.size() == 1);
return compile_pat_match(prgs[0]);
}
buffer<expr> rs;
for (program const & p : prgs) {
expr r = compile(p);
// out() << r << "\nType: " << m_tc.infer(r).first << "\n";
rs.push_back(r);
}
if (closed(rs[0]))
return rs[0];
else
return m_meta;
}
};