diff --git a/src/library/definitional/equations.cpp b/src/library/definitional/equations.cpp index dbd689ea4..2603007b3 100644 --- a/src/library/definitional/equations.cpp +++ b/src/library/definitional/equations.cpp @@ -20,6 +20,7 @@ Author: Leonardo de Moura #include "library/annotation.h" #include "library/util.h" #include "library/locals.h" +#include "library/normalize.h" #include "library/tactic/inversion_tactic.h" namespace lean { @@ -1016,7 +1017,136 @@ class equation_compiler_fn { return find_rec_args(prgs, 0, arg_pos, arg_types); } - expr compile_brec_on(buffer & prgs) { + // Auxiliary function object used to eliminate recursive applications using "below" applications + struct elim_rec_apps_fn { + equation_compiler_fn & m_main; + buffer const & m_prgs; + unsigned m_nparams; + buffer const & m_below_cnsts; // below constants + buffer const & m_Cs_locals; // auxiliary local constants representing the "motives" + buffer const & m_rec_arg_pos; // position of recursive argument for each program + buffer> const & m_rest_pos; // position of remaining arguments for each program + + elim_rec_apps_fn(equation_compiler_fn & m, buffer const & prgs, unsigned nparams, + buffer const & below_cnsts, buffer const & Cs_locals, buffer const & rec_arg_pos, + buffer> const & rest_pos): + m_main(m), m_prgs(prgs), m_nparams(nparams), m_below_cnsts(below_cnsts), m_Cs_locals(Cs_locals), + m_rec_arg_pos(rec_arg_pos), m_rest_pos(rest_pos) {} + + bool is_below_type(expr const & t) const { + expr const & fn = get_app_fn(t); + return is_constant(fn) && std::find(m_below_cnsts.begin(), m_below_cnsts.end(), fn) != m_below_cnsts.end(); + } + + /** \brief Retrieve \c a from the below dictionary \c d. \c d is a term made of products, and C's from (m_Cs_locals). + \c b is the below constant that was used to create the below dictionary \c d. + */ + optional to_below(expr const & d, expr const & a, expr const & b) { + expr const & fn = get_app_fn(d); + if (is_constant(fn) && const_name(fn) == "prod") { + if (auto r = to_below(app_arg(app_fn(d)), a, mk_pr1(m_main.m_tc, b))) + return r; + else if (auto r = to_below(app_arg(d), a, mk_pr2(m_main.m_tc, b))) + return r; + else + return none_expr(); + } else if (is_constant(fn) && const_name(fn) == "and") { + // For ibelow, we use "and" instead of products + if (auto r = to_below(app_arg(app_fn(d)), a, mk_and_elim_left(m_main.m_tc, b))) + return r; + else if (auto r = to_below(app_arg(d), a, mk_and_elim_right(m_main.m_tc, b))) + return r; + else + return none_expr(); + } else if (is_local(fn)) { + for (expr const & C : m_Cs_locals) { + if (mlocal_name(C) == mlocal_name(fn) && app_arg(d) == a) + return some_expr(b); + } + return none_expr(); + } else if (is_pi(d)) { + // TODO(Leo) + return none_expr(); + } else { + return none_expr(); + } + } + + expr elim(unsigned prg_idx, buffer const & args, expr const & below) { + // Replace motives with abstract ones. We use the abstract motives (m_Cs_locals) as "markers" + buffer below_args; + expr const & below_cnst = get_app_args(mlocal_type(below), below_args); + buffer abst_below_args; + abst_below_args.append(m_nparams, below_args.data()); + abst_below_args.append(m_Cs_locals); + for (unsigned i = m_nparams + m_Cs_locals.size(); i < below_args.size(); i++) + abst_below_args.push_back(below_args[i]); + expr abst_below = mk_app(below_cnst, abst_below_args); + expr below_dict = normalize(m_main.m_tc, abst_below); + expr rec_arg = normalize(m_main.m_tc, args[m_rec_arg_pos[prg_idx]]); + if (optional b = to_below(below_dict, rec_arg, below)) { + expr r = *b; + for (unsigned rest_pos : m_rest_pos[prg_idx]) + r = mk_app(r, args[rest_pos]); + return r; + } else { + m_main.throw_error(sstream() << "failed to compile recursive equations using brec_on approach (possible solution: use well-founded recursion)"); + } + } + + /** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */ + expr elim(expr const & e, optional const & b) { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Meta: + case expr_kind::Local: case expr_kind::Constant: + case expr_kind::Sort: + return e; + case expr_kind::Macro: { + buffer new_args; + for (unsigned i = 0; i < macro_num_args(e); i++) + new_args.push_back(elim(macro_arg(e, i), b)); + return update_macro(e, new_args.size(), new_args.data()); + } + case expr_kind::App: { + buffer args; + expr const & fn = get_app_args(e, args); + expr new_fn = elim(fn, b); + buffer new_args; + for (expr const & arg : args) + new_args.push_back(elim(arg, b)); + if (is_local(fn) && b) { + for (unsigned j = 0; j < m_prgs.size(); j++) { + if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) { + return elim(j, new_args, *b); + } + } + } + return mk_app(new_fn, new_args); + } + case expr_kind::Lambda: { + expr local = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e)); + expr body = instantiate(binding_body(e), local); + expr new_body; + if (is_below_type(binding_domain(e))) + new_body = elim(body, some_expr(local)); + else + new_body = elim(body, b); + return Fun(local, new_body); + } + case expr_kind::Pi: { + expr local = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e)); + expr new_body = elim(instantiate(binding_body(e), local), b); + return Pi(local, new_body); + }} + lean_unreachable(); + } + + expr operator()(expr const & e) { + return elim(e, none_expr()); + } + }; + + expr compile_brec_on(buffer const & prgs) { lean_assert(!prgs.empty()); buffer arg_pos; if (!find_rec_args(prgs, arg_pos)) { @@ -1054,9 +1184,15 @@ class equation_compiler_fn { unsigned nparams = std::get<1>(t); list decls = std::get<2>(t); + // TODO(Leo): move parameters to global context. + // we should also check if the user tried to perform pattern matching on parameters + // Distribute parameters of the ith program intro three groups: // indices, major premise (arg), and remaining arguments (rest) - auto distribute_context = [&](unsigned i, buffer & indices, expr & arg, buffer & rest) { + // We store the position of the rest arguments in the buffer rest_pos. + // The buffer rest_pos is used to replace the recursive applications with below applications. + auto distribute_context_core = [&](unsigned i, buffer & indices, expr & arg, buffer & rest, + buffer & indices_pos, buffer & rest_pos) { program const & p = prgs[i]; arg = get_rec_arg(i); list const & ctx = p.m_context; @@ -1064,24 +1200,62 @@ class equation_compiler_fn { get_app_args(mlocal_type(arg), arg_args); lean_assert(nparams <= arg_args.size()); indices.append(arg_args.size() - nparams, arg_args.data() + nparams); + unsigned j = 0; for (expr const & l : ctx) { - if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices)) + if (mlocal_name(l) == mlocal_name(arg)) { + // do nothing + } else if (contains_local(l, indices)) { + indices_pos.push_back(j); + } else { rest.push_back(l); + rest_pos.push_back(j); + } + j++; } }; + auto distribute_context = [&](unsigned i, buffer & indices, expr & arg, buffer & rest) { + buffer indices_pos, rest_pos; + distribute_context_core(i, indices, arg, rest, indices_pos, rest_pos); + }; + // Compute the resulting universe level for brec_on auto get_brec_on_result_level = [&]() -> level { - buffer indices, rest; - expr arg; + buffer indices, rest; expr arg; distribute_context(0, indices, arg, rest); - expr r_type = Pi(indices, prgs[0].m_type); + expr r_type = Pi(rest, prgs[0].m_type); return sort_level(m_tc.ensure_type(r_type).first); }; level rlvl = get_brec_on_result_level(); - levels brec_on_lvls = cons(rlvl, const_levels(I0)); - expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls); + bool reflexive = env().prop_proof_irrel() && is_reflexive_datatype(m_tc, const_name(I0)); + bool use_ibelow = reflexive && is_zero(rlvl); + if (reflexive) { + if (!is_zero(rlvl) && !is_not_zero(rlvl)) + throw_error(sstream() << "invalid recursive equations, when trying to recurse over reflexive inductive datatype, " + << "the universe level of the resultant universe must be zero OR not zero for every level assignment"); + if (!is_zero(rlvl)) { + // For reflexive type, the type of brec_on and ibelow perform a +1 on the motive universe. + // Example: for a reflexive formula type, we have: + // formula.below.{l_1} : Π {C : formula → Type.{l_1+1}}, formula → Type.{max (l_1+1) 1} + if (auto dlvl = dec_level(rlvl)) { + rlvl = *dlvl; + } else { + throw_error(sstream() << "invalid recursive equations, when trying to recurse over reflexive inductive datatype, " + << "the universe level of the resultant universe must be zero OR not zero for every level assignment, " + << "the compiler managed to establish that the resultant universe level L is never zero, but fail to comput L-1"); + } + } + } + levels brec_on_lvls; + expr brec_on; + if (use_ibelow) { + brec_on_lvls = const_levels(I0); + brec_on = mk_constant(name{const_name(I0), "binduction_on"}, brec_on_lvls); + } else { + brec_on_lvls = cons(rlvl, const_levels(I0)); + brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls); + } buffer params; // add parameters for (unsigned i = 0; i < nparams; i++) { @@ -1090,12 +1264,19 @@ class equation_compiler_fn { } buffer Cs; // brec_on "motives" + // The following loop fills Cs_locals with auxiliary local constants that will be used to + // convert recursive applications into "below applications". + // These constants are essentially abstracting Cs. + buffer Cs_locals; buffer> C_args_buffer; for (inductive::inductive_decl const & decl : decls) { name const & I_name = inductive::inductive_decl_name(decl); expr C; C_args_buffer.push_back(buffer()); buffer & C_args = C_args_buffer.back(); + expr C_type = whnf(infer_type(brec_on)); + expr C_local = mk_local(mk_fresh_name(), "C", C_type, binder_info()); + Cs_locals.push_back(C_local); if (optional p_idx = get_prg_for(I_name)) { buffer indices, rest; expr arg; distribute_context(*p_idx, indices, arg, rest); @@ -1104,8 +1285,7 @@ class equation_compiler_fn { C_args.push_back(arg); C = Fun(C_args, type); } else { - expr type = whnf(infer_type(brec_on)); - expr d = binding_domain(type); + expr d = binding_domain(C_type); expr unit = mk_constant("unit", rlvl); to_telescope_ext(d, C_args); C = Fun(C_args, unit); @@ -1114,18 +1294,32 @@ class equation_compiler_fn { Cs.push_back(C); } + // add indices and major + buffer indices0, rest0; expr arg0; + distribute_context(0, indices0, arg0, rest0); + brec_on = mk_app(mk_app(brec_on, indices0), arg0); + // add functionals unsigned i = 0; + buffer below_cnsts; + buffer> rest_arg_pos; for (inductive::inductive_decl const & decl : decls) { name const & I_name = inductive::inductive_decl_name(decl); - expr below = mk_constant(name{I_name, "below"}, brec_on_lvls); - below = mk_app(mk_app(below, params), Cs); + expr below_cnst; + if (use_ibelow) + below_cnst = mk_constant(name{I_name, "ibelow"}, brec_on_lvls); + else + below_cnst = mk_constant(name{I_name, "below"}, brec_on_lvls); + below_cnsts.push_back(below_cnst); + expr below = mk_app(mk_app(below_cnst, params), Cs); expr F; buffer & C_args = C_args_buffer[i]; + rest_arg_pos.push_back(buffer()); if (optional p_idx = get_prg_for(I_name)) { - program & prg_i = prgs[*p_idx]; - buffer indices, rest; expr arg; - distribute_context(*p_idx, indices, arg, rest); + program const & prg_i = prgs[*p_idx]; + buffer indices, rest; expr arg; buffer indices_pos; + buffer & rest_pos = rest_arg_pos.back(); + distribute_context_core(*p_idx, indices, arg, rest, indices_pos, rest_pos); below = mk_app(mk_app(below, indices), arg); expr b = mk_local(mk_fresh_name(), "b", below, binder_info()); buffer new_ctx; @@ -1133,9 +1327,7 @@ class equation_compiler_fn { new_ctx.push_back(arg); new_ctx.push_back(b); new_ctx.append(rest); - prg_i.m_context = to_list(new_ctx); - // TODO(Leo): replace recursive calls with "b" applications - F = compile_pat_match(prg_i); + F = compile_pat_match(program(prg_i, to_list(new_ctx))); } else { expr star = mk_constant(name{"unit", "star"}, rlvl); buffer F_args; @@ -1147,10 +1339,14 @@ class equation_compiler_fn { brec_on = mk_app(brec_on, F); i++; } + expr r = elim_rec_apps_fn(*this, prgs, nparams, below_cnsts, Cs_locals, arg_pos, rest_arg_pos)(brec_on); + // add remaining arguments + r = mk_app(r, rest0); - // out() << "brec_on: " << brec_on << "\n"; - - return brec_on; + buffer ctx0_buffer; + to_buffer(prgs[0].m_context, ctx0_buffer); + r = Fun(ctx0_buffer, r); + return r; } expr compile_wf(buffer & /* prgs */) { diff --git a/tests/lean/run/eq4.lean b/tests/lean/run/eq4.lean new file mode 100644 index 000000000..221adb043 --- /dev/null +++ b/tests/lean/run/eq4.lean @@ -0,0 +1,21 @@ +open nat + +definition half : nat → nat, +half 0 := 0, +half 1 := 0, +half (x+2) := half x + 1 + +theorem half0 : half 0 = 0 := +rfl + +theorem half1 : half 1 = 0 := +rfl + +theorem half_succ_succ (a : nat) : half (a + 2) = half a + 1 := +rfl + +example : half 5 = 2 := +rfl + +example : half 8 = 4 := +rfl