feat(library/definitional/equations): add 'equations_result' macro used to wrap multiple functions being defined by recursive equations

This is only useful when compiling "mutually recursive" functions.
This commit is contained in:
Leonardo de Moura 2015-01-05 19:08:06 -08:00
parent 3325d791de
commit 322cdb8a98
2 changed files with 97 additions and 31 deletions

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <algorithm>
#include <string>
#include "util/sstream.h"
#include "util/list_fn.h"
@ -24,13 +25,15 @@ Author: Leonardo de Moura
#include "library/tactic/inversion_tactic.h"
namespace lean {
static name * g_equations_name = nullptr;
static name * g_equation_name = nullptr;
static name * g_decreasing_name = nullptr;
static name * g_inaccessible_name = nullptr;
static std::string * g_equations_opcode = nullptr;
static std::string * g_equation_opcode = nullptr;
static std::string * g_decreasing_opcode = nullptr;
static name * g_equations_name = nullptr;
static name * g_equation_name = nullptr;
static name * g_decreasing_name = nullptr;
static name * g_inaccessible_name = nullptr;
static name * g_equations_result_name = nullptr;
static std::string * g_equations_opcode = nullptr;
static std::string * g_equation_opcode = nullptr;
static std::string * g_decreasing_opcode = nullptr;
static std::string * g_equations_result_opcode = nullptr;
[[ noreturn ]] static void throw_eqs_ex() { throw exception("unexpected occurrence of 'equations' expression"); }
@ -169,16 +172,43 @@ expr update_equations(expr const & eqns, buffer<expr> const & new_eqs) {
expr mk_inaccessible(expr const & e) { return mk_annotation(*g_inaccessible_name, e); }
bool is_inaccessible(expr const & e) { return is_annotation(e, *g_inaccessible_name); }
// Auxiliary macro used to store the result of a set of equations defining a mutually recursive
// definition.
class equations_result_macro_cell : public macro_definition_cell {
public:
virtual name get_name() const { return *g_equations_result_name; }
virtual pair<expr, constraint_seq> get_type(expr const & m, extension_context & ctx) const {
return ctx.infer_type(macro_arg(m, 0));
}
virtual optional<expr> expand(expr const & m, extension_context &) const {
return some_expr(macro_arg(m, 0));
}
virtual void write(serializer & s) const { s << *g_equations_result_opcode; }
};
static macro_definition * g_equations_result = nullptr;
static expr mk_equations_result(unsigned n, expr const * rs) {
return mk_macro(*g_equations_result, n, rs);
}
bool is_equations_result(expr const & e) { return is_macro(e) && macro_def(e) == *g_equations_result; }
unsigned get_equations_result_size(expr const & e) { return macro_num_args(e); }
expr const & get_equations_result(expr const & e, unsigned i) { return macro_arg(e, i); }
void initialize_equations() {
g_equations_name = new name("equations");
g_equation_name = new name("equation");
g_decreasing_name = new name("decreasing");
g_inaccessible_name = new name("innaccessible");
g_equation = new macro_definition(new equation_macro_cell());
g_decreasing = new macro_definition(new decreasing_macro_cell());
g_equations_opcode = new std::string("Eqns");
g_equation_opcode = new std::string("Eqn");
g_decreasing_opcode = new std::string("Decr");
g_equations_name = new name("equations");
g_equation_name = new name("equation");
g_decreasing_name = new name("decreasing");
g_inaccessible_name = new name("innaccessible");
g_equations_result_name = new name("equations_result");
g_equation = new macro_definition(new equation_macro_cell());
g_decreasing = new macro_definition(new decreasing_macro_cell());
g_equations_result = new macro_definition(new equations_result_macro_cell());
g_equations_opcode = new std::string("Eqns");
g_equation_opcode = new std::string("Eqn");
g_decreasing_opcode = new std::string("Decr");
g_equations_result_opcode = new std::string("EqnR");
register_annotation(*g_inaccessible_name);
register_macro_deserializer(*g_equations_opcode,
[](deserializer & d, unsigned num, expr const * args) {
@ -206,14 +236,21 @@ void initialize_equations() {
throw corrupted_stream_exception();
return mk_decreasing(args[0], args[1]);
});
register_macro_deserializer(*g_equations_result_opcode,
[](deserializer &, unsigned num, expr const * args) {
return mk_equations_result(num, args);
});
}
void finalize_equations() {
delete g_equations_result_opcode;
delete g_equation_opcode;
delete g_equations_opcode;
delete g_decreasing_opcode;
delete g_equations_result;
delete g_equation;
delete g_decreasing;
delete g_equations_result_name;
delete g_equations_name;
delete g_equation_name;
delete g_decreasing_name;
@ -1303,21 +1340,8 @@ class equation_compiler_fn {
lean_assert(check_program(prgs[0]));
}
expr compile_brec_on(buffer<program> & prgs) {
lean_assert(!prgs.empty());
buffer<unsigned> rec_arg_pos;
if (!find_rec_args(prgs, rec_arg_pos)) {
throw_error(sstream() << "invalid recursive equations, "
<< "failed to find recursive arguments that are structurally smaller "
<< "(possible solution: use well-founded recursion)");
}
// Remark: move_params updates argument positions.
// Thus, we copy rec_arg_pos to arg_pos.
// We use rec_arg_pos when invoking elim_rec_apps_fn
buffer<unsigned> arg_pos;
arg_pos.append(rec_arg_pos);
move_params(prgs, arg_pos);
expr compile_brec_on_core(buffer<program> const & prgs,
buffer<unsigned> const & arg_pos, buffer<unsigned> const & rec_arg_pos) {
// Return the recursive argument of the i-th program
auto get_rec_arg = [&](unsigned i) -> expr {
program const & pi = prgs[i];
@ -1513,6 +1537,40 @@ class equation_compiler_fn {
return r;
}
expr compile_brec_on(buffer<program> & prgs) {
lean_assert(!prgs.empty());
buffer<unsigned> rec_arg_pos;
if (!find_rec_args(prgs, rec_arg_pos)) {
throw_error(sstream() << "invalid recursive equations, "
<< "failed to find recursive arguments that are structurally smaller "
<< "(possible solution: use well-founded recursion)");
}
// Remark: move_params updates argument positions.
// Thus, we copy rec_arg_pos to arg_pos.
// We use rec_arg_pos when invoking elim_rec_apps_fn
buffer<unsigned> arg_pos;
arg_pos.append(rec_arg_pos);
move_params(prgs, arg_pos);
buffer<expr> rs;
for (unsigned i = 0; i < prgs.size(); i++) {
// Remark: this loop is very hackish.
// We are "compiling" the code prgs.size() times!
// This is wasteful. We should rewrite this.
std::swap(prgs[0], prgs[i]);
std::swap(arg_pos[0], arg_pos[i]);
std::swap(rec_arg_pos[0], rec_arg_pos[i]);
rs.push_back(compile_brec_on_core(prgs, arg_pos, rec_arg_pos));
std::swap(prgs[0], prgs[i]);
std::swap(arg_pos[0], arg_pos[i]);
std::swap(rec_arg_pos[0], rec_arg_pos[i]);
}
if (rs.size() > 1)
return mk_equations_result(rs.size(), rs.data());
else
return rs[0];
}
expr compile_wf(buffer<program> & /* prgs */) {
// TODO(Leo)
return expr();

View file

@ -40,6 +40,14 @@ bool is_inaccessible(expr const & e);
expr compile_equations(type_checker & tc, io_state const & ios, expr const & eqns,
expr const & meta, expr const & meta_type, bool relax);
/** \brief Return true if \c e is an auxiliary macro used to store the result of mutually recursive declarations.
For example, if a set of recursive equations is defining \c n mutually recursive functions, we wrap
the \c n resulting functions with an \c equations_result macro.
*/
bool is_equations_result(expr const & e);
unsigned get_equations_result_size(expr const & e);
expr const & get_equations_result(expr const & e, unsigned i);
void initialize_equations();
void finalize_equations();
}