diff --git a/src/library/definitional/equations.cpp b/src/library/definitional/equations.cpp index 36a29211f..22a589b67 100644 --- a/src/library/definitional/equations.cpp +++ b/src/library/definitional/equations.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include #include "util/sstream.h" #include "util/list_fn.h" @@ -24,13 +25,15 @@ Author: Leonardo de Moura #include "library/tactic/inversion_tactic.h" namespace lean { -static name * g_equations_name = nullptr; -static name * g_equation_name = nullptr; -static name * g_decreasing_name = nullptr; -static name * g_inaccessible_name = nullptr; -static std::string * g_equations_opcode = nullptr; -static std::string * g_equation_opcode = nullptr; -static std::string * g_decreasing_opcode = nullptr; +static name * g_equations_name = nullptr; +static name * g_equation_name = nullptr; +static name * g_decreasing_name = nullptr; +static name * g_inaccessible_name = nullptr; +static name * g_equations_result_name = nullptr; +static std::string * g_equations_opcode = nullptr; +static std::string * g_equation_opcode = nullptr; +static std::string * g_decreasing_opcode = nullptr; +static std::string * g_equations_result_opcode = nullptr; [[ noreturn ]] static void throw_eqs_ex() { throw exception("unexpected occurrence of 'equations' expression"); } @@ -169,16 +172,43 @@ expr update_equations(expr const & eqns, buffer const & new_eqs) { expr mk_inaccessible(expr const & e) { return mk_annotation(*g_inaccessible_name, e); } bool is_inaccessible(expr const & e) { return is_annotation(e, *g_inaccessible_name); } +// Auxiliary macro used to store the result of a set of equations defining a mutually recursive +// definition. +class equations_result_macro_cell : public macro_definition_cell { +public: + virtual name get_name() const { return *g_equations_result_name; } + virtual pair get_type(expr const & m, extension_context & ctx) const { + return ctx.infer_type(macro_arg(m, 0)); + } + virtual optional expand(expr const & m, extension_context &) const { + return some_expr(macro_arg(m, 0)); + } + virtual void write(serializer & s) const { s << *g_equations_result_opcode; } +}; + +static macro_definition * g_equations_result = nullptr; + +static expr mk_equations_result(unsigned n, expr const * rs) { + return mk_macro(*g_equations_result, n, rs); +} + +bool is_equations_result(expr const & e) { return is_macro(e) && macro_def(e) == *g_equations_result; } +unsigned get_equations_result_size(expr const & e) { return macro_num_args(e); } +expr const & get_equations_result(expr const & e, unsigned i) { return macro_arg(e, i); } + void initialize_equations() { - g_equations_name = new name("equations"); - g_equation_name = new name("equation"); - g_decreasing_name = new name("decreasing"); - g_inaccessible_name = new name("innaccessible"); - g_equation = new macro_definition(new equation_macro_cell()); - g_decreasing = new macro_definition(new decreasing_macro_cell()); - g_equations_opcode = new std::string("Eqns"); - g_equation_opcode = new std::string("Eqn"); - g_decreasing_opcode = new std::string("Decr"); + g_equations_name = new name("equations"); + g_equation_name = new name("equation"); + g_decreasing_name = new name("decreasing"); + g_inaccessible_name = new name("innaccessible"); + g_equations_result_name = new name("equations_result"); + g_equation = new macro_definition(new equation_macro_cell()); + g_decreasing = new macro_definition(new decreasing_macro_cell()); + g_equations_result = new macro_definition(new equations_result_macro_cell()); + g_equations_opcode = new std::string("Eqns"); + g_equation_opcode = new std::string("Eqn"); + g_decreasing_opcode = new std::string("Decr"); + g_equations_result_opcode = new std::string("EqnR"); register_annotation(*g_inaccessible_name); register_macro_deserializer(*g_equations_opcode, [](deserializer & d, unsigned num, expr const * args) { @@ -206,14 +236,21 @@ void initialize_equations() { throw corrupted_stream_exception(); return mk_decreasing(args[0], args[1]); }); + register_macro_deserializer(*g_equations_result_opcode, + [](deserializer &, unsigned num, expr const * args) { + return mk_equations_result(num, args); + }); } void finalize_equations() { + delete g_equations_result_opcode; delete g_equation_opcode; delete g_equations_opcode; delete g_decreasing_opcode; + delete g_equations_result; delete g_equation; delete g_decreasing; + delete g_equations_result_name; delete g_equations_name; delete g_equation_name; delete g_decreasing_name; @@ -1303,21 +1340,8 @@ class equation_compiler_fn { lean_assert(check_program(prgs[0])); } - expr compile_brec_on(buffer & prgs) { - lean_assert(!prgs.empty()); - buffer rec_arg_pos; - if (!find_rec_args(prgs, rec_arg_pos)) { - throw_error(sstream() << "invalid recursive equations, " - << "failed to find recursive arguments that are structurally smaller " - << "(possible solution: use well-founded recursion)"); - } - // Remark: move_params updates argument positions. - // Thus, we copy rec_arg_pos to arg_pos. - // We use rec_arg_pos when invoking elim_rec_apps_fn - buffer arg_pos; - arg_pos.append(rec_arg_pos); - move_params(prgs, arg_pos); - + expr compile_brec_on_core(buffer const & prgs, + buffer const & arg_pos, buffer const & rec_arg_pos) { // Return the recursive argument of the i-th program auto get_rec_arg = [&](unsigned i) -> expr { program const & pi = prgs[i]; @@ -1513,6 +1537,40 @@ class equation_compiler_fn { return r; } + expr compile_brec_on(buffer & prgs) { + lean_assert(!prgs.empty()); + buffer rec_arg_pos; + if (!find_rec_args(prgs, rec_arg_pos)) { + throw_error(sstream() << "invalid recursive equations, " + << "failed to find recursive arguments that are structurally smaller " + << "(possible solution: use well-founded recursion)"); + } + // Remark: move_params updates argument positions. + // Thus, we copy rec_arg_pos to arg_pos. + // We use rec_arg_pos when invoking elim_rec_apps_fn + buffer arg_pos; + arg_pos.append(rec_arg_pos); + move_params(prgs, arg_pos); + buffer rs; + for (unsigned i = 0; i < prgs.size(); i++) { + // Remark: this loop is very hackish. + // We are "compiling" the code prgs.size() times! + // This is wasteful. We should rewrite this. + std::swap(prgs[0], prgs[i]); + std::swap(arg_pos[0], arg_pos[i]); + std::swap(rec_arg_pos[0], rec_arg_pos[i]); + rs.push_back(compile_brec_on_core(prgs, arg_pos, rec_arg_pos)); + std::swap(prgs[0], prgs[i]); + std::swap(arg_pos[0], arg_pos[i]); + std::swap(rec_arg_pos[0], rec_arg_pos[i]); + } + + if (rs.size() > 1) + return mk_equations_result(rs.size(), rs.data()); + else + return rs[0]; + } + expr compile_wf(buffer & /* prgs */) { // TODO(Leo) return expr(); diff --git a/src/library/definitional/equations.h b/src/library/definitional/equations.h index 222dbea92..7a45fa758 100644 --- a/src/library/definitional/equations.h +++ b/src/library/definitional/equations.h @@ -40,6 +40,14 @@ bool is_inaccessible(expr const & e); expr compile_equations(type_checker & tc, io_state const & ios, expr const & eqns, expr const & meta, expr const & meta_type, bool relax); +/** \brief Return true if \c e is an auxiliary macro used to store the result of mutually recursive declarations. + For example, if a set of recursive equations is defining \c n mutually recursive functions, we wrap + the \c n resulting functions with an \c equations_result macro. +*/ +bool is_equations_result(expr const & e); +unsigned get_equations_result_size(expr const & e); +expr const & get_equations_result(expr const & e, unsigned i); + void initialize_equations(); void finalize_equations(); }