feat(library/definitional/equations): add 'equations_result' macro used to wrap multiple functions being defined by recursive equations
This is only useful when compiling "mutually recursive" functions.
This commit is contained in:
parent
3325d791de
commit
322cdb8a98
2 changed files with 97 additions and 31 deletions
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include "util/sstream.h"
|
||||
#include "util/list_fn.h"
|
||||
|
@ -24,13 +25,15 @@ Author: Leonardo de Moura
|
|||
#include "library/tactic/inversion_tactic.h"
|
||||
|
||||
namespace lean {
|
||||
static name * g_equations_name = nullptr;
|
||||
static name * g_equation_name = nullptr;
|
||||
static name * g_decreasing_name = nullptr;
|
||||
static name * g_inaccessible_name = nullptr;
|
||||
static std::string * g_equations_opcode = nullptr;
|
||||
static std::string * g_equation_opcode = nullptr;
|
||||
static std::string * g_decreasing_opcode = nullptr;
|
||||
static name * g_equations_name = nullptr;
|
||||
static name * g_equation_name = nullptr;
|
||||
static name * g_decreasing_name = nullptr;
|
||||
static name * g_inaccessible_name = nullptr;
|
||||
static name * g_equations_result_name = nullptr;
|
||||
static std::string * g_equations_opcode = nullptr;
|
||||
static std::string * g_equation_opcode = nullptr;
|
||||
static std::string * g_decreasing_opcode = nullptr;
|
||||
static std::string * g_equations_result_opcode = nullptr;
|
||||
|
||||
[[ noreturn ]] static void throw_eqs_ex() { throw exception("unexpected occurrence of 'equations' expression"); }
|
||||
|
||||
|
@ -169,16 +172,43 @@ expr update_equations(expr const & eqns, buffer<expr> const & new_eqs) {
|
|||
expr mk_inaccessible(expr const & e) { return mk_annotation(*g_inaccessible_name, e); }
|
||||
bool is_inaccessible(expr const & e) { return is_annotation(e, *g_inaccessible_name); }
|
||||
|
||||
// Auxiliary macro used to store the result of a set of equations defining a mutually recursive
|
||||
// definition.
|
||||
class equations_result_macro_cell : public macro_definition_cell {
|
||||
public:
|
||||
virtual name get_name() const { return *g_equations_result_name; }
|
||||
virtual pair<expr, constraint_seq> get_type(expr const & m, extension_context & ctx) const {
|
||||
return ctx.infer_type(macro_arg(m, 0));
|
||||
}
|
||||
virtual optional<expr> expand(expr const & m, extension_context &) const {
|
||||
return some_expr(macro_arg(m, 0));
|
||||
}
|
||||
virtual void write(serializer & s) const { s << *g_equations_result_opcode; }
|
||||
};
|
||||
|
||||
static macro_definition * g_equations_result = nullptr;
|
||||
|
||||
static expr mk_equations_result(unsigned n, expr const * rs) {
|
||||
return mk_macro(*g_equations_result, n, rs);
|
||||
}
|
||||
|
||||
bool is_equations_result(expr const & e) { return is_macro(e) && macro_def(e) == *g_equations_result; }
|
||||
unsigned get_equations_result_size(expr const & e) { return macro_num_args(e); }
|
||||
expr const & get_equations_result(expr const & e, unsigned i) { return macro_arg(e, i); }
|
||||
|
||||
void initialize_equations() {
|
||||
g_equations_name = new name("equations");
|
||||
g_equation_name = new name("equation");
|
||||
g_decreasing_name = new name("decreasing");
|
||||
g_inaccessible_name = new name("innaccessible");
|
||||
g_equation = new macro_definition(new equation_macro_cell());
|
||||
g_decreasing = new macro_definition(new decreasing_macro_cell());
|
||||
g_equations_opcode = new std::string("Eqns");
|
||||
g_equation_opcode = new std::string("Eqn");
|
||||
g_decreasing_opcode = new std::string("Decr");
|
||||
g_equations_name = new name("equations");
|
||||
g_equation_name = new name("equation");
|
||||
g_decreasing_name = new name("decreasing");
|
||||
g_inaccessible_name = new name("innaccessible");
|
||||
g_equations_result_name = new name("equations_result");
|
||||
g_equation = new macro_definition(new equation_macro_cell());
|
||||
g_decreasing = new macro_definition(new decreasing_macro_cell());
|
||||
g_equations_result = new macro_definition(new equations_result_macro_cell());
|
||||
g_equations_opcode = new std::string("Eqns");
|
||||
g_equation_opcode = new std::string("Eqn");
|
||||
g_decreasing_opcode = new std::string("Decr");
|
||||
g_equations_result_opcode = new std::string("EqnR");
|
||||
register_annotation(*g_inaccessible_name);
|
||||
register_macro_deserializer(*g_equations_opcode,
|
||||
[](deserializer & d, unsigned num, expr const * args) {
|
||||
|
@ -206,14 +236,21 @@ void initialize_equations() {
|
|||
throw corrupted_stream_exception();
|
||||
return mk_decreasing(args[0], args[1]);
|
||||
});
|
||||
register_macro_deserializer(*g_equations_result_opcode,
|
||||
[](deserializer &, unsigned num, expr const * args) {
|
||||
return mk_equations_result(num, args);
|
||||
});
|
||||
}
|
||||
|
||||
void finalize_equations() {
|
||||
delete g_equations_result_opcode;
|
||||
delete g_equation_opcode;
|
||||
delete g_equations_opcode;
|
||||
delete g_decreasing_opcode;
|
||||
delete g_equations_result;
|
||||
delete g_equation;
|
||||
delete g_decreasing;
|
||||
delete g_equations_result_name;
|
||||
delete g_equations_name;
|
||||
delete g_equation_name;
|
||||
delete g_decreasing_name;
|
||||
|
@ -1303,21 +1340,8 @@ class equation_compiler_fn {
|
|||
lean_assert(check_program(prgs[0]));
|
||||
}
|
||||
|
||||
expr compile_brec_on(buffer<program> & prgs) {
|
||||
lean_assert(!prgs.empty());
|
||||
buffer<unsigned> rec_arg_pos;
|
||||
if (!find_rec_args(prgs, rec_arg_pos)) {
|
||||
throw_error(sstream() << "invalid recursive equations, "
|
||||
<< "failed to find recursive arguments that are structurally smaller "
|
||||
<< "(possible solution: use well-founded recursion)");
|
||||
}
|
||||
// Remark: move_params updates argument positions.
|
||||
// Thus, we copy rec_arg_pos to arg_pos.
|
||||
// We use rec_arg_pos when invoking elim_rec_apps_fn
|
||||
buffer<unsigned> arg_pos;
|
||||
arg_pos.append(rec_arg_pos);
|
||||
move_params(prgs, arg_pos);
|
||||
|
||||
expr compile_brec_on_core(buffer<program> const & prgs,
|
||||
buffer<unsigned> const & arg_pos, buffer<unsigned> const & rec_arg_pos) {
|
||||
// Return the recursive argument of the i-th program
|
||||
auto get_rec_arg = [&](unsigned i) -> expr {
|
||||
program const & pi = prgs[i];
|
||||
|
@ -1513,6 +1537,40 @@ class equation_compiler_fn {
|
|||
return r;
|
||||
}
|
||||
|
||||
expr compile_brec_on(buffer<program> & prgs) {
|
||||
lean_assert(!prgs.empty());
|
||||
buffer<unsigned> rec_arg_pos;
|
||||
if (!find_rec_args(prgs, rec_arg_pos)) {
|
||||
throw_error(sstream() << "invalid recursive equations, "
|
||||
<< "failed to find recursive arguments that are structurally smaller "
|
||||
<< "(possible solution: use well-founded recursion)");
|
||||
}
|
||||
// Remark: move_params updates argument positions.
|
||||
// Thus, we copy rec_arg_pos to arg_pos.
|
||||
// We use rec_arg_pos when invoking elim_rec_apps_fn
|
||||
buffer<unsigned> arg_pos;
|
||||
arg_pos.append(rec_arg_pos);
|
||||
move_params(prgs, arg_pos);
|
||||
buffer<expr> rs;
|
||||
for (unsigned i = 0; i < prgs.size(); i++) {
|
||||
// Remark: this loop is very hackish.
|
||||
// We are "compiling" the code prgs.size() times!
|
||||
// This is wasteful. We should rewrite this.
|
||||
std::swap(prgs[0], prgs[i]);
|
||||
std::swap(arg_pos[0], arg_pos[i]);
|
||||
std::swap(rec_arg_pos[0], rec_arg_pos[i]);
|
||||
rs.push_back(compile_brec_on_core(prgs, arg_pos, rec_arg_pos));
|
||||
std::swap(prgs[0], prgs[i]);
|
||||
std::swap(arg_pos[0], arg_pos[i]);
|
||||
std::swap(rec_arg_pos[0], rec_arg_pos[i]);
|
||||
}
|
||||
|
||||
if (rs.size() > 1)
|
||||
return mk_equations_result(rs.size(), rs.data());
|
||||
else
|
||||
return rs[0];
|
||||
}
|
||||
|
||||
expr compile_wf(buffer<program> & /* prgs */) {
|
||||
// TODO(Leo)
|
||||
return expr();
|
||||
|
|
|
@ -40,6 +40,14 @@ bool is_inaccessible(expr const & e);
|
|||
expr compile_equations(type_checker & tc, io_state const & ios, expr const & eqns,
|
||||
expr const & meta, expr const & meta_type, bool relax);
|
||||
|
||||
/** \brief Return true if \c e is an auxiliary macro used to store the result of mutually recursive declarations.
|
||||
For example, if a set of recursive equations is defining \c n mutually recursive functions, we wrap
|
||||
the \c n resulting functions with an \c equations_result macro.
|
||||
*/
|
||||
bool is_equations_result(expr const & e);
|
||||
unsigned get_equations_result_size(expr const & e);
|
||||
expr const & get_equations_result(expr const & e, unsigned i);
|
||||
|
||||
void initialize_equations();
|
||||
void finalize_equations();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue