feat(library/definitional/equations): add brec_on based compilation
This commit is contained in:
parent
42354cd4fa
commit
bdfa919098
1 changed files with 174 additions and 48 deletions
|
@ -19,6 +19,7 @@ Author: Leonardo de Moura
|
|||
#include "library/io_state_stream.h"
|
||||
#include "library/annotation.h"
|
||||
#include "library/util.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/tactic/inversion_tactic.h"
|
||||
|
||||
namespace lean {
|
||||
|
@ -338,23 +339,16 @@ class equation_compiler_fn {
|
|||
}
|
||||
};
|
||||
|
||||
// For debugging purposes
|
||||
template<typename T>
|
||||
static bool contains_local(T const & locals, expr const & l) {
|
||||
return std::any_of(locals.begin(), locals.end(),
|
||||
[&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); });
|
||||
}
|
||||
|
||||
#ifdef LEAN_DEBUG
|
||||
// For debugging purposes: checks whether all local constants occurring in \c e
|
||||
// are in local_ctx or m_global_context
|
||||
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
|
||||
for_each(e, [&](expr const & e, unsigned) {
|
||||
if (is_local(e) &&
|
||||
!(contains_local(local_context, e) ||
|
||||
contains_local(context, e) ||
|
||||
contains_local(m_global_context, e) ||
|
||||
contains_local(m_fns, e))) {
|
||||
!(contains_local(e, local_context) ||
|
||||
contains_local(e, context) ||
|
||||
contains_local(e, m_global_context) ||
|
||||
contains_local(e, m_fns))) {
|
||||
lean_unreachable();
|
||||
return false;
|
||||
}
|
||||
|
@ -617,10 +611,6 @@ class equation_compiler_fn {
|
|||
return r;
|
||||
}
|
||||
|
||||
static bool contains(list<expr> const & local_ctx, expr const & p) {
|
||||
return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); });
|
||||
}
|
||||
|
||||
expr compile_skip(program const & prg) {
|
||||
lean_assert(!head(prg.m_var_stack));
|
||||
auto new_stack = tail(prg.m_var_stack);
|
||||
|
@ -646,7 +636,7 @@ class equation_compiler_fn {
|
|||
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
|
||||
} else {
|
||||
lean_assert(is_local(p));
|
||||
if (contains(e.m_local_context, p)) {
|
||||
if (contains_local(p, e.m_local_context)) {
|
||||
list<expr> new_local_ctx = e.m_local_context;
|
||||
new_local_ctx = remove(new_local_ctx, p);
|
||||
new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); });
|
||||
|
@ -796,7 +786,7 @@ class equation_compiler_fn {
|
|||
}
|
||||
}
|
||||
|
||||
expr compile(program const & p) {
|
||||
expr compile_pat_match(program const & p) {
|
||||
buffer<expr> vars;
|
||||
to_buffer(p.m_context, vars);
|
||||
if (!is_proof_irrelevant()) {
|
||||
|
@ -812,11 +802,7 @@ class equation_compiler_fn {
|
|||
|
||||
/** \brief Return true iff \c e is one of the functions being defined */
|
||||
bool is_fn(expr const & e) const {
|
||||
return
|
||||
is_local(e) &&
|
||||
std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) {
|
||||
return mlocal_name(fn) == mlocal_name(e);
|
||||
});
|
||||
return is_local(e) && contains_local(e, m_fns);
|
||||
}
|
||||
|
||||
/** \brief Return true iff the equations are recursive. */
|
||||
|
@ -833,18 +819,39 @@ class equation_compiler_fn {
|
|||
return false;
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/
|
||||
bool is_inductive(expr const & t) const {
|
||||
expr const & fn = get_app_fn(t);
|
||||
/** \brief Return true if all locals are distinct local constants. */
|
||||
static bool all_distinct_locals(unsigned num, expr const * locals) {
|
||||
for (unsigned i = 0; i < num; i++) {
|
||||
if (!is_local(locals[i]))
|
||||
return false;
|
||||
if (contains_local(locals[i], locals, locals + i))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition,
|
||||
and all indices of t are in ctx. */
|
||||
bool is_rec_inductive(list<expr> const & ctx, expr const & t) const {
|
||||
expr const & I = get_app_fn(t);
|
||||
if (is_constant(I) && env().find(name{const_name(I), "brec_on"})) {
|
||||
unsigned nindices = *inductive::get_num_indices(env(), const_name(I));
|
||||
if (nindices > 0) {
|
||||
buffer<expr> args;
|
||||
get_app_args(I, args);
|
||||
return
|
||||
is_constant(fn) &&
|
||||
env().find(name{const_name(fn), "brec_on"});
|
||||
all_distinct_locals(nindices, args.end() - nindices) &&
|
||||
std::all_of(args.end() - nindices, args.end(), [&](expr const & idx) { return contains_local(idx, ctx); });
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
|
||||
bool is_compatible_inductive(expr const & t1, expr const & t2) {
|
||||
lean_assert(is_inductive(t1));
|
||||
lean_assert(is_inductive(t2));
|
||||
buffer<expr> args1, args2;
|
||||
name const & I1 = const_name(get_app_args(t1, args1));
|
||||
name const & I2 = const_name(get_app_args(t2, args2));
|
||||
|
@ -986,7 +993,7 @@ class equation_compiler_fn {
|
|||
expr const & v = p.get_var(*n);
|
||||
expr const & t = mlocal_type(v);
|
||||
if (// argument must be an inductive datatype
|
||||
is_inductive(t) &&
|
||||
is_rec_inductive(p.m_context, t) &&
|
||||
// argument must be an inductive datatype different from the ones in arg_types
|
||||
std::all_of(arg_types.begin(), arg_types.end(),
|
||||
[&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) &&
|
||||
|
@ -1011,23 +1018,149 @@ class equation_compiler_fn {
|
|||
return find_rec_args(prgs, 0, arg_pos, arg_types);
|
||||
}
|
||||
|
||||
void apply_brec_on(buffer<program> & prgs) {
|
||||
expr compile_brec_on(buffer<program> & prgs) {
|
||||
lean_assert(!prgs.empty());
|
||||
buffer<unsigned> arg_pos;
|
||||
if (!find_rec_args(prgs, arg_pos)) {
|
||||
throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller "
|
||||
<< "(possible solution: use well-founded recursion)");
|
||||
}
|
||||
// out() << "Found recursive arguments: ";
|
||||
// for (unsigned p : arg_pos) out() << " " << p; out() << "\n";
|
||||
|
||||
// TODO(Leo):
|
||||
// Return the recursive argument of the i-th program
|
||||
auto get_rec_arg = [&](unsigned i) -> expr {
|
||||
program const & pi = prgs[i];
|
||||
return get_ith(pi.m_context, arg_pos[i]);
|
||||
};
|
||||
|
||||
// Return the type of the recursive argument of the i-th program
|
||||
auto get_rec_type = [&](unsigned i) -> expr {
|
||||
return mlocal_type(get_rec_arg(i));
|
||||
};
|
||||
|
||||
// Return the program associated with the inductive datatype named I_name.
|
||||
// Return none if there isn't one.
|
||||
auto get_prg_for = [&](name const & I_name) -> optional<unsigned> {
|
||||
for (unsigned i = 0; i < prgs.size(); i++) {
|
||||
expr const & t = get_rec_type(i);
|
||||
if (const_name(get_app_fn(t)) == I_name)
|
||||
return optional<unsigned>(i);
|
||||
}
|
||||
return optional<unsigned>();
|
||||
};
|
||||
|
||||
expr const & a0_type = get_rec_type(0);
|
||||
lean_assert(is_rec_inductive(prgs[0].m_context, a0_type));
|
||||
buffer<expr> a0_type_args;
|
||||
expr const & I0 = get_app_args(a0_type, a0_type_args);
|
||||
inductive::inductive_decls t = *inductive::is_inductive_decl(env(), const_name(I0));
|
||||
unsigned nparams = std::get<1>(t);
|
||||
list<inductive::inductive_decl> decls = std::get<2>(t);
|
||||
|
||||
// Distribute parameters of the ith program intro three groups:
|
||||
// indices, major premise (arg), and remaining arguments (rest)
|
||||
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
|
||||
program const & p = prgs[i];
|
||||
arg = get_rec_arg(i);
|
||||
list<expr> const & ctx = p.m_context;
|
||||
buffer<expr> arg_args;
|
||||
get_app_args(mlocal_type(arg), arg_args);
|
||||
lean_assert(nparams <= arg_args.size());
|
||||
indices.append(arg_args.size() - nparams, arg_args.data() + nparams);
|
||||
for (expr const & l : ctx) {
|
||||
if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices))
|
||||
rest.push_back(l);
|
||||
}
|
||||
};
|
||||
|
||||
// Compute the resulting universe level for brec_on
|
||||
auto get_brec_on_result_level = [&]() -> level {
|
||||
buffer<expr> indices, rest;
|
||||
expr arg;
|
||||
distribute_context(0, indices, arg, rest);
|
||||
expr r_type = Pi(indices, prgs[0].m_type);
|
||||
return sort_level(m_tc.ensure_type(r_type).first);
|
||||
};
|
||||
|
||||
level rlvl = get_brec_on_result_level();
|
||||
levels brec_on_lvls = cons(rlvl, const_levels(I0));
|
||||
expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
|
||||
buffer<expr> params;
|
||||
// add parameters
|
||||
for (unsigned i = 0; i < nparams; i++) {
|
||||
params.push_back(a0_type_args[i]);
|
||||
brec_on = mk_app(brec_on, a0_type_args[i]);
|
||||
}
|
||||
|
||||
void apply_wf(buffer<program> & /* prgs */) {
|
||||
buffer<expr> Cs; // brec_on "motives"
|
||||
buffer<buffer<expr>> C_args_buffer;
|
||||
for (inductive::inductive_decl const & decl : decls) {
|
||||
name const & I_name = inductive::inductive_decl_name(decl);
|
||||
expr C;
|
||||
C_args_buffer.push_back(buffer<expr>());
|
||||
buffer<expr> & C_args = C_args_buffer.back();
|
||||
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
||||
buffer<expr> indices, rest; expr arg;
|
||||
distribute_context(*p_idx, indices, arg, rest);
|
||||
expr type = Pi(rest, prgs[*p_idx].m_type);
|
||||
C_args.append(indices);
|
||||
C_args.push_back(arg);
|
||||
C = Fun(C_args, type);
|
||||
} else {
|
||||
expr type = whnf(infer_type(brec_on));
|
||||
expr d = binding_domain(type);
|
||||
expr unit = mk_constant("unit", rlvl);
|
||||
to_telescope_ext(d, C_args);
|
||||
C = Fun(C_args, unit);
|
||||
}
|
||||
brec_on = mk_app(brec_on, C);
|
||||
Cs.push_back(C);
|
||||
}
|
||||
|
||||
// add functionals
|
||||
unsigned i = 0;
|
||||
for (inductive::inductive_decl const & decl : decls) {
|
||||
name const & I_name = inductive::inductive_decl_name(decl);
|
||||
expr below = mk_constant(name{I_name, "below"}, brec_on_lvls);
|
||||
below = mk_app(mk_app(below, params), Cs);
|
||||
expr F;
|
||||
buffer<expr> & C_args = C_args_buffer[i];
|
||||
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
||||
program & prg_i = prgs[*p_idx];
|
||||
buffer<expr> indices, rest; expr arg;
|
||||
distribute_context(*p_idx, indices, arg, rest);
|
||||
below = mk_app(mk_app(below, indices), arg);
|
||||
expr b = mk_local(mk_fresh_name(), "b", below, binder_info());
|
||||
buffer<expr> new_ctx;
|
||||
new_ctx.append(indices);
|
||||
new_ctx.push_back(arg);
|
||||
new_ctx.push_back(b);
|
||||
new_ctx.append(rest);
|
||||
prg_i.m_context = to_list(new_ctx);
|
||||
// TODO(Leo): replace recursive calls with "b" applications
|
||||
F = compile_pat_match(prg_i);
|
||||
} else {
|
||||
expr star = mk_constant(name{"unit", "star"}, rlvl);
|
||||
buffer<expr> F_args;
|
||||
F_args.append(C_args);
|
||||
below = mk_app(below, F_args);
|
||||
F_args.push_back(mk_local(mk_fresh_name(), "b", below, binder_info()));
|
||||
F = Fun(F_args, star);
|
||||
}
|
||||
brec_on = mk_app(brec_on, F);
|
||||
i++;
|
||||
}
|
||||
|
||||
// out() << "brec_on: " << brec_on << "\n";
|
||||
|
||||
return brec_on;
|
||||
}
|
||||
|
||||
expr compile_wf(buffer<program> & /* prgs */) {
|
||||
// TODO(Leo)
|
||||
return expr();
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax):
|
||||
m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) {
|
||||
|
@ -1041,21 +1174,14 @@ public:
|
|||
initialize(eqns, prgs);
|
||||
if (is_recursive(prgs)) {
|
||||
if (is_wf_equations(eqns)) {
|
||||
apply_wf(prgs);
|
||||
return compile_wf(prgs);
|
||||
} else {
|
||||
apply_brec_on(prgs);
|
||||
return compile_brec_on(prgs);
|
||||
}
|
||||
} else {
|
||||
lean_assert(prgs.size() == 1);
|
||||
return compile_pat_match(prgs[0]);
|
||||
}
|
||||
buffer<expr> rs;
|
||||
for (program const & p : prgs) {
|
||||
expr r = compile(p);
|
||||
// out() << r << "\nType: " << m_tc.infer(r).first << "\n";
|
||||
rs.push_back(r);
|
||||
}
|
||||
if (closed(rs[0]))
|
||||
return rs[0];
|
||||
else
|
||||
return m_meta;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in a new issue