diff --git a/src/library/definitional/equations.cpp b/src/library/definitional/equations.cpp index ee35c93b7..ad91d9694 100644 --- a/src/library/definitional/equations.cpp +++ b/src/library/definitional/equations.cpp @@ -19,6 +19,7 @@ Author: Leonardo de Moura #include "library/io_state_stream.h" #include "library/annotation.h" #include "library/util.h" +#include "library/locals.h" #include "library/tactic/inversion_tactic.h" namespace lean { @@ -338,23 +339,16 @@ class equation_compiler_fn { } }; - // For debugging purposes - template - static bool contains_local(T const & locals, expr const & l) { - return std::any_of(locals.begin(), locals.end(), - [&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); }); - } - #ifdef LEAN_DEBUG // For debugging purposes: checks whether all local constants occurring in \c e // are in local_ctx or m_global_context bool check_ctx(expr const & e, list const & context, list const & local_context) const { for_each(e, [&](expr const & e, unsigned) { if (is_local(e) && - !(contains_local(local_context, e) || - contains_local(context, e) || - contains_local(m_global_context, e) || - contains_local(m_fns, e))) { + !(contains_local(e, local_context) || + contains_local(e, context) || + contains_local(e, m_global_context) || + contains_local(e, m_fns))) { lean_unreachable(); return false; } @@ -617,10 +611,6 @@ class equation_compiler_fn { return r; } - static bool contains(list const & local_ctx, expr const & p) { - return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); }); - } - expr compile_skip(program const & prg) { lean_assert(!head(prg.m_var_stack)); auto new_stack = tail(prg.m_var_stack); @@ -646,7 +636,7 @@ class equation_compiler_fn { new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs); } else { lean_assert(is_local(p)); - if (contains(e.m_local_context, p)) { + if (contains_local(p, e.m_local_context)) { list new_local_ctx = e.m_local_context; new_local_ctx = remove(new_local_ctx, p); new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); }); @@ -796,7 +786,7 @@ class equation_compiler_fn { } } - expr compile(program const & p) { + expr compile_pat_match(program const & p) { buffer vars; to_buffer(p.m_context, vars); if (!is_proof_irrelevant()) { @@ -812,11 +802,7 @@ class equation_compiler_fn { /** \brief Return true iff \c e is one of the functions being defined */ bool is_fn(expr const & e) const { - return - is_local(e) && - std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) { - return mlocal_name(fn) == mlocal_name(e); - }); + return is_local(e) && contains_local(e, m_fns); } /** \brief Return true iff the equations are recursive. */ @@ -833,18 +819,39 @@ class equation_compiler_fn { return false; } - /** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/ - bool is_inductive(expr const & t) const { - expr const & fn = get_app_fn(t); - return - is_constant(fn) && - env().find(name{const_name(fn), "brec_on"}); + /** \brief Return true if all locals are distinct local constants. */ + static bool all_distinct_locals(unsigned num, expr const * locals) { + for (unsigned i = 0; i < num; i++) { + if (!is_local(locals[i])) + return false; + if (contains_local(locals[i], locals, locals + i)) + return false; + } + return true; + } + + /** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition, + and all indices of t are in ctx. */ + bool is_rec_inductive(list const & ctx, expr const & t) const { + expr const & I = get_app_fn(t); + if (is_constant(I) && env().find(name{const_name(I), "brec_on"})) { + unsigned nindices = *inductive::get_num_indices(env(), const_name(I)); + if (nindices > 0) { + buffer args; + get_app_args(I, args); + return + all_distinct_locals(nindices, args.end() - nindices) && + std::all_of(args.end() - nindices, args.end(), [&](expr const & idx) { return contains_local(idx, ctx); }); + } else { + return true; + } + } else { + return false; + } } /** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */ bool is_compatible_inductive(expr const & t1, expr const & t2) { - lean_assert(is_inductive(t1)); - lean_assert(is_inductive(t2)); buffer args1, args2; name const & I1 = const_name(get_app_args(t1, args1)); name const & I2 = const_name(get_app_args(t2, args2)); @@ -986,7 +993,7 @@ class equation_compiler_fn { expr const & v = p.get_var(*n); expr const & t = mlocal_type(v); if (// argument must be an inductive datatype - is_inductive(t) && + is_rec_inductive(p.m_context, t) && // argument must be an inductive datatype different from the ones in arg_types std::all_of(arg_types.begin(), arg_types.end(), [&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) && @@ -1011,23 +1018,149 @@ class equation_compiler_fn { return find_rec_args(prgs, 0, arg_pos, arg_types); } - void apply_brec_on(buffer & prgs) { + expr compile_brec_on(buffer & prgs) { lean_assert(!prgs.empty()); buffer arg_pos; if (!find_rec_args(prgs, arg_pos)) { throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller " << "(possible solution: use well-founded recursion)"); } - // out() << "Found recursive arguments: "; - // for (unsigned p : arg_pos) out() << " " << p; out() << "\n"; - // TODO(Leo): + // Return the recursive argument of the i-th program + auto get_rec_arg = [&](unsigned i) -> expr { + program const & pi = prgs[i]; + return get_ith(pi.m_context, arg_pos[i]); + }; + + // Return the type of the recursive argument of the i-th program + auto get_rec_type = [&](unsigned i) -> expr { + return mlocal_type(get_rec_arg(i)); + }; + + // Return the program associated with the inductive datatype named I_name. + // Return none if there isn't one. + auto get_prg_for = [&](name const & I_name) -> optional { + for (unsigned i = 0; i < prgs.size(); i++) { + expr const & t = get_rec_type(i); + if (const_name(get_app_fn(t)) == I_name) + return optional(i); + } + return optional(); + }; + + expr const & a0_type = get_rec_type(0); + lean_assert(is_rec_inductive(prgs[0].m_context, a0_type)); + buffer a0_type_args; + expr const & I0 = get_app_args(a0_type, a0_type_args); + inductive::inductive_decls t = *inductive::is_inductive_decl(env(), const_name(I0)); + unsigned nparams = std::get<1>(t); + list decls = std::get<2>(t); + + // Distribute parameters of the ith program intro three groups: + // indices, major premise (arg), and remaining arguments (rest) + auto distribute_context = [&](unsigned i, buffer & indices, expr & arg, buffer & rest) { + program const & p = prgs[i]; + arg = get_rec_arg(i); + list const & ctx = p.m_context; + buffer arg_args; + get_app_args(mlocal_type(arg), arg_args); + lean_assert(nparams <= arg_args.size()); + indices.append(arg_args.size() - nparams, arg_args.data() + nparams); + for (expr const & l : ctx) { + if (mlocal_name(l) != mlocal_name(arg) && !contains_local(l, indices)) + rest.push_back(l); + } + }; + + // Compute the resulting universe level for brec_on + auto get_brec_on_result_level = [&]() -> level { + buffer indices, rest; + expr arg; + distribute_context(0, indices, arg, rest); + expr r_type = Pi(indices, prgs[0].m_type); + return sort_level(m_tc.ensure_type(r_type).first); + }; + + level rlvl = get_brec_on_result_level(); + levels brec_on_lvls = cons(rlvl, const_levels(I0)); + expr brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls); + buffer params; + // add parameters + for (unsigned i = 0; i < nparams; i++) { + params.push_back(a0_type_args[i]); + brec_on = mk_app(brec_on, a0_type_args[i]); + } + + buffer Cs; // brec_on "motives" + buffer> C_args_buffer; + for (inductive::inductive_decl const & decl : decls) { + name const & I_name = inductive::inductive_decl_name(decl); + expr C; + C_args_buffer.push_back(buffer()); + buffer & C_args = C_args_buffer.back(); + if (optional p_idx = get_prg_for(I_name)) { + buffer indices, rest; expr arg; + distribute_context(*p_idx, indices, arg, rest); + expr type = Pi(rest, prgs[*p_idx].m_type); + C_args.append(indices); + C_args.push_back(arg); + C = Fun(C_args, type); + } else { + expr type = whnf(infer_type(brec_on)); + expr d = binding_domain(type); + expr unit = mk_constant("unit", rlvl); + to_telescope_ext(d, C_args); + C = Fun(C_args, unit); + } + brec_on = mk_app(brec_on, C); + Cs.push_back(C); + } + + // add functionals + unsigned i = 0; + for (inductive::inductive_decl const & decl : decls) { + name const & I_name = inductive::inductive_decl_name(decl); + expr below = mk_constant(name{I_name, "below"}, brec_on_lvls); + below = mk_app(mk_app(below, params), Cs); + expr F; + buffer & C_args = C_args_buffer[i]; + if (optional p_idx = get_prg_for(I_name)) { + program & prg_i = prgs[*p_idx]; + buffer indices, rest; expr arg; + distribute_context(*p_idx, indices, arg, rest); + below = mk_app(mk_app(below, indices), arg); + expr b = mk_local(mk_fresh_name(), "b", below, binder_info()); + buffer new_ctx; + new_ctx.append(indices); + new_ctx.push_back(arg); + new_ctx.push_back(b); + new_ctx.append(rest); + prg_i.m_context = to_list(new_ctx); + // TODO(Leo): replace recursive calls with "b" applications + F = compile_pat_match(prg_i); + } else { + expr star = mk_constant(name{"unit", "star"}, rlvl); + buffer F_args; + F_args.append(C_args); + below = mk_app(below, F_args); + F_args.push_back(mk_local(mk_fresh_name(), "b", below, binder_info())); + F = Fun(F_args, star); + } + brec_on = mk_app(brec_on, F); + i++; + } + + // out() << "brec_on: " << brec_on << "\n"; + + return brec_on; } - void apply_wf(buffer & /* prgs */) { + expr compile_wf(buffer & /* prgs */) { // TODO(Leo) + return expr(); } + public: equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax): m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) { @@ -1041,21 +1174,14 @@ public: initialize(eqns, prgs); if (is_recursive(prgs)) { if (is_wf_equations(eqns)) { - apply_wf(prgs); + return compile_wf(prgs); } else { - apply_brec_on(prgs); + return compile_brec_on(prgs); } + } else { + lean_assert(prgs.size() == 1); + return compile_pat_match(prgs[0]); } - buffer rs; - for (program const & p : prgs) { - expr r = compile(p); - // out() << r << "\nType: " << m_tc.infer(r).first << "\n"; - rs.push_back(r); - } - if (closed(rs[0])) - return rs[0]; - else - return m_meta; } };