feat(library/definitional/equations): add brec_on based compilation
This commit is contained in:
parent
42354cd4fa
commit
bdfa919098
1 changed files with 174 additions and 48 deletions
|
@ -19,6 +19,7 @@ Author: Leonardo de Moura
|
||||||
#include "library/io_state_stream.h"
|
#include "library/io_state_stream.h"
|
||||||
#include "library/annotation.h"
|
#include "library/annotation.h"
|
||||||
#include "library/util.h"
|
#include "library/util.h"
|
||||||
|
#include "library/locals.h"
|
||||||
#include "library/tactic/inversion_tactic.h"
|
#include "library/tactic/inversion_tactic.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
|
@ -338,23 +339,16 @@ class equation_compiler_fn {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// For debugging purposes
|
|
||||||
template<typename T>
|
|
||||||
static bool contains_local(T const & locals, expr const & l) {
|
|
||||||
return std::any_of(locals.begin(), locals.end(),
|
|
||||||
[&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); });
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef LEAN_DEBUG
|
#ifdef LEAN_DEBUG
|
||||||
// For debugging purposes: checks whether all local constants occurring in \c e
|
// For debugging purposes: checks whether all local constants occurring in \c e
|
||||||
// are in local_ctx or m_global_context
|
// are in local_ctx or m_global_context
|
||||||
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
|
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
|
||||||
for_each(e, [&](expr const & e, unsigned) {
|
for_each(e, [&](expr const & e, unsigned) {
|
||||||
if (is_local(e) &&
|
if (is_local(e) &&
|
||||||
!(contains_local(local_context, e) ||
|
!(contains_local(e, local_context) ||
|
||||||
contains_local(context, e) ||
|
contains_local(e, context) ||
|
||||||
contains_local(m_global_context, e) ||
|
contains_local(e, m_global_context) ||
|
||||||
contains_local(m_fns, e))) {
|
contains_local(e, m_fns))) {
|
||||||
lean_unreachable();
|
lean_unreachable();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -617,10 +611,6 @@ class equation_compiler_fn {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool contains(list<expr> const & local_ctx, expr const & p) {
|
|
||||||
return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); });
|
|
||||||
}
|
|
||||||
|
|
||||||
expr compile_skip(program const & prg) {
|
expr compile_skip(program const & prg) {
|
||||||
lean_assert(!head(prg.m_var_stack));
|
lean_assert(!head(prg.m_var_stack));
|
||||||
auto new_stack = tail(prg.m_var_stack);
|
auto new_stack = tail(prg.m_var_stack);
|
||||||
|
@ -646,7 +636,7 @@ class equation_compiler_fn {
|
||||||
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
|
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
|
||||||
} else {
|
} else {
|
||||||
lean_assert(is_local(p));
|
lean_assert(is_local(p));
|
||||||
if (contains(e.m_local_context, p)) {
|
if (contains_local(p, e.m_local_context)) {
|
||||||
list<expr> new_local_ctx = e.m_local_context;
|
list<expr> new_local_ctx = e.m_local_context;
|
||||||
new_local_ctx = remove(new_local_ctx, p);
|
new_local_ctx = remove(new_local_ctx, p);
|
||||||
new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); });
|
new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); });
|
||||||
|
@ -796,7 +786,7 @@ class equation_compiler_fn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
expr compile(program const & p) {
|
expr compile_pat_match(program const & p) {
|
||||||
buffer<expr> vars;
|
buffer<expr> vars;
|
||||||
to_buffer(p.m_context, vars);
|
to_buffer(p.m_context, vars);
|
||||||
if (!is_proof_irrelevant()) {
|
if (!is_proof_irrelevant()) {
|
||||||
|
@ -812,11 +802,7 @@ class equation_compiler_fn {
|
||||||
|
|
||||||
/** \brief Return true iff \c e is one of the functions being defined */
|
/** \brief Return true iff \c e is one of the functions being defined */
|
||||||
bool is_fn(expr const & e) const {
|
bool is_fn(expr const & e) const {
|
||||||
return
|
return is_local(e) && contains_local(e, m_fns);
|
||||||
is_local(e) &&
|
|
||||||
std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) {
|
|
||||||
return mlocal_name(fn) == mlocal_name(e);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \brief Return true iff the equations are recursive. */
|
/** \brief Return true iff the equations are recursive. */
|
||||||
|
@ -833,18 +819,39 @@ class equation_compiler_fn {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/
|
/** \brief Return true if all locals are distinct local constants. */
|
||||||
bool is_inductive(expr const & t) const {
|
static bool all_distinct_locals(unsigned num, expr const * locals) {
|
||||||
expr const & fn = get_app_fn(t);
|
for (unsigned i = 0; i < num; i++) {
|
||||||
|
if (!is_local(locals[i]))
|
||||||
|
return false;
|
||||||
|
if (contains_local(locals[i], locals, locals + i))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition,
|
||||||
|
and all indices of t are in ctx. */
|
||||||
|
bool is_rec_inductive(list<expr> const & ctx, expr const & t) const {
|
||||||
|
expr const & I = get_app_fn(t);
|
||||||
|
if (is_constant(I) && env().find(name{const_name(I), "brec_on"})) {
|
||||||
|
unsigned nindices = *inductive::get_num_indices(env(), const_name(I));
|
||||||
|
if (nindices > 0) {
|
||||||
|
buffer<expr> args;
|
||||||
|
get_app_args(I, args);
|
||||||
return
|
return
|
||||||
is_constant(fn) &&
|
all_distinct_locals(nindices, args.end() - nindices) &&
|
||||||
env().find(name{const_name(fn), "brec_on"});
|
std::all_of(args.end() - nindices, args.end(), [&](expr const & idx) { return contains_local(idx, ctx); });
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
|
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
|
||||||
bool is_compatible_inductive(expr const & t1, expr const & t2) {
|
bool is_compatible_inductive(expr const & t1, expr const & t2) {
|
||||||
lean_assert(is_inductive(t1));
|
|
||||||
lean_assert(is_inductive(t2));
|
|
||||||
buffer<expr> args1, args2;
|
buffer<expr> args1, args2;
|
||||||
name const & I1 = const_name(get_app_args(t1, args1));
|
name const & I1 = const_name(get_app_args(t1, args1));
|
||||||
name const & I2 = const_name(get_app_args(t2, args2));
|
name const & I2 = const_name(get_app_args(t2, args2));
|
||||||
|
@ -986,7 +993,7 @@ class equation_compiler_fn {
|
||||||
expr const & v = p.get_var(*n);
|
expr const & v = p.get_var(*n);
|
||||||
expr const & t = mlocal_type(v);
|
expr const & t = mlocal_type(v);
|
||||||
if (// argument must be an inductive datatype
|
if (// argument must be an inductive datatype
|
||||||
is_inductive(t) &&
|
is_rec_inductive(p.m_context, t) &&
|
||||||
// argument must be an inductive datatype different from the ones in arg_types
|
// argument must be an inductive datatype different from the ones in arg_types
|
||||||
std::all_of(arg_types.begin(), arg_types.end(),
|
std::all_of(arg_types.begin(), arg_types.end(),
|
||||||
[&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) &&
|
[&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) &&
|
||||||
|
@ -1011,23 +1018,149 @@ class equation_compiler_fn {
|
||||||
return find_rec_args(prgs, 0, arg_pos, arg_types);
|
return find_rec_args(prgs, 0, arg_pos, arg_types);
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply_brec_on(buffer<program> & prgs) {
|
expr compile_brec_on(buffer<program> & prgs) {
|
||||||
lean_assert(!prgs.empty());
|
lean_assert(!prgs.empty());
|
||||||
buffer<unsigned> arg_pos;
|
buffer<unsigned> arg_pos;
|
||||||
if (!find_rec_args(prgs, arg_pos)) {
|
if (!find_rec_args(prgs, arg_pos)) {
|
||||||
throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller "
|
throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller "
|
||||||
<< "(possible solution: use well-founded recursion)");
|
<< "(possible solution: use well-founded recursion)");
|
||||||
}
|
}
|
||||||
// out() << "Found recursive arguments: ";
|
|
||||||
// for (unsigned p : arg_pos) out() << " " << p; out() << "\n";
|
|
||||||
|
|
||||||
// TODO(Leo):
|
// Return the recursive argument of the i-th program
|
||||||
|
auto get_rec_arg = [&](unsigned i) -> expr {
|
||||||
|
program const & pi = prgs[i];
|
||||||
|
return get_ith(pi.m_context, arg_pos[i]);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return the type of the recursive argument of the i-th program
|
||||||
|
auto get_rec_type = [&](unsigned i) -> expr {
|
||||||
|
return mlocal_type(get_rec_arg(i));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return the program associated with the inductive datatype named I_name.
|
||||||
|
// Return none if there isn't one.
|
||||||
|
auto get_prg_for = [&](name const & I_name) -> optional<unsigned> {
|
||||||
|
for (unsigned i = 0; i < prgs.size(); i++) {
|
||||||
|
expr const & t = get_rec_type(i);
|
||||||
|
if (const_name(get_app_fn(t)) == I_name)
|
||||||
|
return optional<unsigned>(i);
|
||||||
|
}
|
||||||
|
return optional<unsigned>();
|
||||||
|
};
|
||||||
|
|
||||||
|
expr const & a0_type = get_rec_type(0);
|
||||||
|
lean_assert(is_rec_inductive(prgs[0].m_context, a0_type));
|
||||||
|
buffer<expr> a0_type_args;
|
||||||
|
expr const & I0 = get_app_args(a0_type, a0_type_args);
|
||||||
|
inductive::inductive_decls t = *inductive::is_inductive_decl(env(), const_name(I0));
|
||||||
|
unsigned nparams = std::get<1>(t);
|
||||||
|
list<inductive::inductive_decl> decls = std::get<2>(t);
|
||||||
|
|
||||||
|
// Distribute parameters of the ith program intro three groups:
|
||||||
|
// indices, major premise (arg), and remaining arguments (rest)
|
||||||
|
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
|
||||||
|
program const & p = prgs[i];
|
||||||
|
arg = get_rec_arg(i);
|
||||||
|
list<expr> const & ctx = p.m_context;
|
||||||
|
buffer<expr> arg_args;
|
||||||
|
get_app_args(mlocal_type(arg), arg_args);
|
||||||
|
lean_assert(nparams <= arg_args.size());
|
||||||
|
indices.append(arg_args.size() - nparams, arg_args.data() + nparams);
|
||||||
|
for (expr const & l : ctx) {
|
||||||
|
if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices))
|
||||||
|
rest.push_back(l);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compute the resulting universe level for brec_on
|
||||||
|
auto get_brec_on_result_level = [&]() -> level {
|
||||||
|
buffer<expr> indices, rest;
|
||||||
|
expr arg;
|
||||||
|
distribute_context(0, indices, arg, rest);
|
||||||
|
expr r_type = Pi(indices, prgs[0].m_type);
|
||||||
|
return sort_level(m_tc.ensure_type(r_type).first);
|
||||||
|
};
|
||||||
|
|
||||||
|
level rlvl = get_brec_on_result_level();
|
||||||
|
levels brec_on_lvls = cons(rlvl, const_levels(I0));
|
||||||
|
expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
|
||||||
|
buffer<expr> params;
|
||||||
|
// add parameters
|
||||||
|
for (unsigned i = 0; i < nparams; i++) {
|
||||||
|
params.push_back(a0_type_args[i]);
|
||||||
|
brec_on = mk_app(brec_on, a0_type_args[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply_wf(buffer<program> & /* prgs */) {
|
buffer<expr> Cs; // brec_on "motives"
|
||||||
|
buffer<buffer<expr>> C_args_buffer;
|
||||||
|
for (inductive::inductive_decl const & decl : decls) {
|
||||||
|
name const & I_name = inductive::inductive_decl_name(decl);
|
||||||
|
expr C;
|
||||||
|
C_args_buffer.push_back(buffer<expr>());
|
||||||
|
buffer<expr> & C_args = C_args_buffer.back();
|
||||||
|
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
||||||
|
buffer<expr> indices, rest; expr arg;
|
||||||
|
distribute_context(*p_idx, indices, arg, rest);
|
||||||
|
expr type = Pi(rest, prgs[*p_idx].m_type);
|
||||||
|
C_args.append(indices);
|
||||||
|
C_args.push_back(arg);
|
||||||
|
C = Fun(C_args, type);
|
||||||
|
} else {
|
||||||
|
expr type = whnf(infer_type(brec_on));
|
||||||
|
expr d = binding_domain(type);
|
||||||
|
expr unit = mk_constant("unit", rlvl);
|
||||||
|
to_telescope_ext(d, C_args);
|
||||||
|
C = Fun(C_args, unit);
|
||||||
|
}
|
||||||
|
brec_on = mk_app(brec_on, C);
|
||||||
|
Cs.push_back(C);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add functionals
|
||||||
|
unsigned i = 0;
|
||||||
|
for (inductive::inductive_decl const & decl : decls) {
|
||||||
|
name const & I_name = inductive::inductive_decl_name(decl);
|
||||||
|
expr below = mk_constant(name{I_name, "below"}, brec_on_lvls);
|
||||||
|
below = mk_app(mk_app(below, params), Cs);
|
||||||
|
expr F;
|
||||||
|
buffer<expr> & C_args = C_args_buffer[i];
|
||||||
|
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
||||||
|
program & prg_i = prgs[*p_idx];
|
||||||
|
buffer<expr> indices, rest; expr arg;
|
||||||
|
distribute_context(*p_idx, indices, arg, rest);
|
||||||
|
below = mk_app(mk_app(below, indices), arg);
|
||||||
|
expr b = mk_local(mk_fresh_name(), "b", below, binder_info());
|
||||||
|
buffer<expr> new_ctx;
|
||||||
|
new_ctx.append(indices);
|
||||||
|
new_ctx.push_back(arg);
|
||||||
|
new_ctx.push_back(b);
|
||||||
|
new_ctx.append(rest);
|
||||||
|
prg_i.m_context = to_list(new_ctx);
|
||||||
|
// TODO(Leo): replace recursive calls with "b" applications
|
||||||
|
F = compile_pat_match(prg_i);
|
||||||
|
} else {
|
||||||
|
expr star = mk_constant(name{"unit", "star"}, rlvl);
|
||||||
|
buffer<expr> F_args;
|
||||||
|
F_args.append(C_args);
|
||||||
|
below = mk_app(below, F_args);
|
||||||
|
F_args.push_back(mk_local(mk_fresh_name(), "b", below, binder_info()));
|
||||||
|
F = Fun(F_args, star);
|
||||||
|
}
|
||||||
|
brec_on = mk_app(brec_on, F);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// out() << "brec_on: " << brec_on << "\n";
|
||||||
|
|
||||||
|
return brec_on;
|
||||||
|
}
|
||||||
|
|
||||||
|
expr compile_wf(buffer<program> & /* prgs */) {
|
||||||
// TODO(Leo)
|
// TODO(Leo)
|
||||||
|
return expr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax):
|
equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax):
|
||||||
m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) {
|
m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) {
|
||||||
|
@ -1041,21 +1174,14 @@ public:
|
||||||
initialize(eqns, prgs);
|
initialize(eqns, prgs);
|
||||||
if (is_recursive(prgs)) {
|
if (is_recursive(prgs)) {
|
||||||
if (is_wf_equations(eqns)) {
|
if (is_wf_equations(eqns)) {
|
||||||
apply_wf(prgs);
|
return compile_wf(prgs);
|
||||||
} else {
|
} else {
|
||||||
apply_brec_on(prgs);
|
return compile_brec_on(prgs);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
lean_assert(prgs.size() == 1);
|
||||||
|
return compile_pat_match(prgs[0]);
|
||||||
}
|
}
|
||||||
buffer<expr> rs;
|
|
||||||
for (program const & p : prgs) {
|
|
||||||
expr r = compile(p);
|
|
||||||
// out() << r << "\nType: " << m_tc.infer(r).first << "\n";
|
|
||||||
rs.push_back(r);
|
|
||||||
}
|
|
||||||
if (closed(rs[0]))
|
|
||||||
return rs[0];
|
|
||||||
else
|
|
||||||
return m_meta;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue