From 7f7d318b22a443f5ac8ca1ed1eda0a5fe74f123e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 2 Jan 2015 22:05:07 -0800 Subject: [PATCH] feat(library/definitional/equations): add dependent pattern matching compilation --- src/library/definitional/equations.cpp | 834 +++++++++++++++++++++++-- src/library/util.cpp | 2 +- tests/lean/extra/rec.lean | 22 +- tests/lean/extra/rec3.lean | 4 +- 4 files changed, 811 insertions(+), 51 deletions(-) diff --git a/src/library/definitional/equations.cpp b/src/library/definitional/equations.cpp index 4c2933602..2245122ee 100644 --- a/src/library/definitional/equations.cpp +++ b/src/library/definitional/equations.cpp @@ -5,10 +5,15 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include "util/sstream.h" +#include "util/list_fn.h" #include "kernel/expr.h" #include "kernel/type_checker.h" #include "kernel/abstract.h" +#include "kernel/instantiate.h" #include "kernel/error_msgs.h" +#include "kernel/for_each_fn.h" +#include "kernel/find_fn.h" #include "library/generic_exception.h" #include "library/kernel_serializer.h" #include "library/io_state_stream.h" @@ -220,44 +225,233 @@ class equation_compiler_fn { expr m_meta; expr m_meta_type; bool m_relax; + buffer m_global_context; + buffer m_fns; // functions being defined environment const & env() const { return m_tc.env(); } + io_state const & ios() const { return m_ios; } + io_state_stream out() const { return regular(env(), ios()); } + name mk_fresh_name() { return m_tc.mk_fresh_name(); } + expr whnf(expr const & e) { return m_tc.whnf(e).first; } + expr infer_type(expr const & e) { return m_tc.infer(e).first; } + bool is_def_eq(expr const & e1, expr const & e2) { return m_tc.is_def_eq(e1, e2).first; } + + optional is_constructor(expr const & e) const { + if (!is_constant(e)) + return optional(); + return inductive::is_intro_rule(env(), const_name(e)); + } + + expr to_telescope(expr const & e, buffer & tele) { + name_generator ngen = m_tc.mk_ngen(); + return ::lean::to_telescope(ngen, e, tele, optional()); + } + + expr fun_to_telescope(expr const & e, buffer & tele) { + name_generator ngen = m_tc.mk_ngen(); + return ::lean::fun_to_telescope(ngen, e, tele, optional()); + } + + // Similar to to_telescope, but uses normalization + expr to_telescope_ext(expr const & e, buffer & tele) { + return ::lean::to_telescope(m_tc, e, tele, optional()); + } + + [[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); } + [[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); } + [[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); } + + void check_limitations(expr const & eqns) const { + if (is_wf_equations(eqns) && equations_num_fns(eqns) != 1) + throw_error("mutually recursive equations do not support well-founded recursion yet", eqns); + } + +#ifdef LEAN_DEBUG + static bool disjoint(list const & l1, list const & l2) { + for (expr const & e1 : l1) { + for (expr const & e2 : l2) { + lean_assert(mlocal_name(e1) != mlocal_name(e2)); + } + } + return true; + } + + // Return true iff all names in s1 are names of local constants in s2. + static bool contained(list> const & s1, list const & s2) { + return std::all_of(s1.begin(), s1.end(), [&](optional const & n) { + return + !n || + std::any_of(s2.begin(), s2.end(), [&](expr const & e) { + return mlocal_name(e) == *n; + }); + }); + } +#endif + + struct eqn { + // The local context for an equation is of additional + // local constants occurring in m_patterns and m_rhs + // which are not in m_global_context or + // in the function containing the equation. + // Remark: each function/program contains its own m_context. + // So, the variables occurring in m_patterns and m_rhs + // are in m_global_context, m_context, or m_local_context, + // or is one of the functions being defined. + // We say an equation is in "compiled" form + // if m_local_context and m_patterns are empty. + list m_local_context; + list m_patterns; // patterns to be processed + expr m_rhs; // right-hand-side + eqn(list const & c, list const & p, expr const & r): + m_local_context(c), m_patterns(p), m_rhs(r) {} + }; + + // Data-structure used to store for compiling pattern matching. + // We create a program object for each function being defined + struct program { + expr m_fn; // function being defined + list m_context; // local constants + list> m_var_stack; // variables that must be matched with the patterns it is a "subset" of m_context. + list m_eqns; // equations + expr m_type; // result type + + // Due to dependent pattern matching some elements in m_var_stack are "none", and are skipped + // during dependent pattern matching. + + // The goal of the compiler is to process all variables in m_var_stack + program(expr const & fn, list const & ctx, list> const & s, list const & e, expr const & t): + m_fn(fn), m_context(ctx), m_var_stack(s), m_eqns(e), m_type(t) { + lean_assert(contained(m_var_stack, m_context)); + } + program(program const & p, list> const & new_s, list const & new_e): + program(p.m_fn, p.m_context, new_s, new_e, p.m_type) {} + program() {} + expr const & get_var(name const & n) const { + for (expr const & v : m_context) { + if (mlocal_name(v) == n) + return v; + } + lean_unreachable(); + } + }; + + // For debugging purposes + template + static bool contains_local(T const & locals, expr const & l) { + return std::any_of(locals.begin(), locals.end(), + [&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); }); + } + +#ifdef LEAN_DEBUG + // For debugging purposes: checks whether all local constants occurring in \c e + // are in local_ctx or m_global_context + bool check_ctx(expr const & e, list const & context, list const & local_context) const { + for_each(e, [&](expr const & e, unsigned) { + if (is_local(e) && + !(contains_local(local_context, e) || + contains_local(context, e) || + contains_local(m_global_context, e) || + contains_local(m_fns, e))) { + lean_unreachable(); + return false; + } + return true; + }); + return true; + } + + // For debugging purposes: check if the program is well-formed + bool check_program(program const & s) const { + unsigned sz = length(s.m_var_stack); + lean_assert(contained(s.m_var_stack, s.m_context)); + for (eqn const & e : s.m_eqns) { + // the number of patterns in each equation is equal to the variable stack size + if (length(e.m_patterns) != sz) { + lean_unreachable(); + return false; + } + check_ctx(e.m_rhs, s.m_context, e.m_local_context); + for (expr const & p : e.m_patterns) + check_ctx(p, s.m_context, e.m_local_context); + lean_assert(disjoint(e.m_local_context, s.m_context)); + } + return true; + } +#endif + + // Initialize m_fns (the vector of functions to be compiled) + void initialize_fns(expr const & eqns) { + lean_assert(is_equations(eqns)); + unsigned num_fns = equations_num_fns(eqns); + buffer eqs; + to_equations(eqns, eqs); + expr eq = eqs[0]; + for (unsigned i = 0; i < num_fns; i++) { + expr fn = mk_local(mk_fresh_name(), binding_name(eq), binding_domain(eq), binder_info()); + m_fns.push_back(fn); + eq = instantiate(binding_body(eq), fn); + } + } + + // Initialize the variable stack for each function that needs + // to be compiled. + // This method assumes m_fns has been already initialized. + // This method also initialized the buffer prg, but the eqns + // field of each program is not initialized by it. + void initialize_var_stack(buffer & prgs) { + lean_assert(!m_fns.empty()); + lean_assert(prgs.empty()); + for (expr const & fn : m_fns) { + buffer args; + expr r_type = to_telescope(mlocal_type(fn), args); + list ctx = to_list(args); + list> vstack = map2>(ctx, [](expr const & e) { + return optional(mlocal_name(e)); + }); + prgs.push_back(program(fn, ctx, vstack, list(), r_type)); + } + } struct validate_exception { expr m_expr; validate_exception(expr const & e):m_expr(e) {} }; - [[ noreturn ]] void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); } - [[ noreturn ]] void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); } - - // -------------------------------- - // Pattern validation/normalization - // -------------------------------- - - expr validate_lhs_arg(expr arg) { - if (is_inaccessible(arg)) - return arg; - if (is_local(arg)) - return arg; - expr new_arg = m_tc.whnf(arg).first; - if (is_local(new_arg)) - return new_arg; - buffer arg_args; - expr const & fn = get_app_args(new_arg, arg_args); - if (!is_constant(fn) || !inductive::is_intro_rule(env(), const_name(fn))) - throw validate_exception(arg); - for (expr & arg_arg : arg_args) - arg_arg = validate_lhs_arg(arg_arg); - return mk_app(fn, arg_args); + // Validate/normalize the given pattern. + // It stores in reachable_vars any variable that does not occur + // in inaccessible terms. + expr validate_pattern(expr pat, name_set & reachable_vars) { + if (is_inaccessible(pat)) + return pat; + if (is_local(pat)) { + reachable_vars.insert(mlocal_name(pat)); + return pat; + } + expr new_pat = whnf(pat); + if (is_local(new_pat)) { + reachable_vars.insert(mlocal_name(new_pat)); + return new_pat; + } + buffer pat_args; + expr const & fn = get_app_args(new_pat, pat_args); + if (auto in = is_constructor(fn)) { + unsigned num_params = *inductive::get_num_params(env(), *in); + for (unsigned i = num_params; i < pat_args.size(); i++) + pat_args[i] = validate_pattern(pat_args[i], reachable_vars); + return mk_app(fn, pat_args); + } else { + throw validate_exception(pat); + } } - expr validate_lhs(expr const & lhs) { - buffer args; - expr fn = get_app_args(lhs, args); - for (expr & arg : args) { + // Validate/normalize the patterns associated with the given lhs. + // The lhs is only used to report errors. + // It stores in reachable_vars any variable that does not occur + // in inaccessible terms. + void validate_patterns(expr const & lhs, buffer & patterns, name_set & reachable_vars) { + for (expr & pat : patterns) { try { - arg = validate_lhs_arg(arg); + pat = validate_pattern(pat, reachable_vars); } catch (validate_exception & ex) { expr problem_expr = ex.m_expr; throw_error(lhs, [=](formatter const & fmt) { @@ -270,40 +464,586 @@ class equation_compiler_fn { }); } } - return mk_app(fn, args); } - expr validate_patterns_core(expr eq) { - buffer args; - name_generator ngen = m_tc.mk_ngen(); - eq = fun_to_telescope(ngen, eq, args, optional()); - lean_assert(is_equation(eq)); - expr new_lhs = validate_lhs(equation_lhs(eq)); - return Fun(args, mk_equation(new_lhs, equation_rhs(eq))); - } - - expr validate_patterns(expr const & eqns) { + // Create initial program state for each function being defined. + void initialize(expr const & eqns, buffer & prg) { lean_assert(is_equations(eqns)); + initialize_fns(eqns); + initialize_var_stack(prg); buffer eqs; - buffer new_eqs; to_equations(eqns, eqs); - for (expr const & eq : eqs) { - new_eqs.push_back(validate_patterns_core(eq)); + buffer> res_eqns; + res_eqns.resize(m_fns.size()); + for (expr eq : eqs) { + for (expr const & fn : m_fns) + eq = instantiate(binding_body(eq), fn); + buffer local_ctx; + eq = fun_to_telescope(eq, local_ctx); + expr const & lhs = equation_lhs(eq); + expr const & rhs = equation_rhs(eq); + buffer patterns; + expr const & fn = get_app_args(lhs, patterns); + name_set reachable_vars; + validate_patterns(lhs, patterns, reachable_vars); + for (expr const & v : local_ctx) { + // every variable in the local_ctx must be "reachable". + if (!reachable_vars.contains(mlocal_name(v))) { + throw_error(lhs, [=](formatter const & fmt) { + format r("invalid equation left-hand-side, variable '"); + r += format(local_pp_name(v)); + r += format("' only occurs in inaccessible terms in the following equation left-hand-side"); + r += pp_indent_expr(fmt, lhs); + return r; + }); + } + } + for (unsigned i = 0; i < m_fns.size(); i++) { + if (mlocal_name(fn) == mlocal_name(m_fns[i])) { + if (patterns.size() != length(prg[i].m_var_stack)) + throw_error("ill-formed equation, number of provided arguments does not match function type", eq); + res_eqns[i].push_back(eqn(to_list(local_ctx), to_list(patterns), rhs)); + } + } } - return update_equations(eqns, new_eqs); + for (unsigned i = 0; i < m_fns.size(); i++) { + prg[i].m_eqns = to_list(res_eqns[i]); + lean_assert(check_program(prg[i])); + } + } + + // For debugging purposes: display the context at m_ios + template + void display_ctx(Ctx const & ctx) const { + bool first = true; + for (expr const & e : ctx) { + out() << (first ? "" : ", ") << local_pp_name(e) << " : " << mlocal_type(e); + first = false; + } + } + + // For debugging purposes: dump prg in m_ios + void display(program const & prg) const { + display_ctx(prg.m_context); + out() << " ;;"; + for (optional const & v : prg.m_var_stack) { + if (v) + out() << " " << local_pp_name(prg.get_var(*v)); + else + out() << " "; + } + out() << " |- " << prg.m_type << "\n"; + out() << "\n"; + for (eqn const & e : prg.m_eqns) { + out() << "> "; + display_ctx(e.m_local_context); + out() << " |-"; + for (expr const & p : e.m_patterns) { + if (is_atomic(p)) + out() << " " << p; + else + out() << " (" << p << ")"; + } + out() << " := " << e.m_rhs << "\n"; + } + } + + // Return true iff the next pattern in all equations is a variable or an inaccessible term + bool is_variable_transition(program const & p) const { + for (eqn const & e : p.m_eqns) { + lean_assert(e.m_patterns); + if (!is_local(head(e.m_patterns)) && !is_inaccessible(head(e.m_patterns))) + return false; + } + return true; + } + + // Return true iff the next pattern in all equations is a constructor + bool is_constructor_transition(program const & p) const { + for (eqn const & e : p.m_eqns) { + lean_assert(e.m_patterns); + if (!is_constructor(get_app_fn(head(e.m_patterns)))) + return false; + } + return true; + } + + // Return true iff the next pattern of every equation is a constructor or variable, + // and there are at least one equation where it is a variable and another where it is a + // constructor. + bool is_complete_transition(program const & p) const { + bool has_variable = false; + bool has_constructor = false; + for (eqn const & e : p.m_eqns) { + lean_assert(e.m_patterns); + expr const & p = head(e.m_patterns); + if (is_local(p)) + has_variable = true; + else if (is_constructor(get_app_fn(p))) + has_constructor = true; + else + return false; + } + return has_variable && has_constructor; + } + + // Remove variable from local context + static list remove(list const & local_ctx, expr const & l) { + if (!local_ctx) + return local_ctx; + else if (mlocal_name(head(local_ctx)) == mlocal_name(l)) + return tail(local_ctx); + else + return cons(head(local_ctx), remove(tail(local_ctx), l)); + } + + // Replace local constant \c from with \c to in the expression \c e. + static expr replace(expr const & e, expr const & from, expr const & to) { + return instantiate(abstract_local(e, from), to); + } + + static expr replace(expr const & e, name const & from, expr const & to) { + return instantiate(abstract_local(e, from), to); + } + + static expr replace(expr const & e, name_map const & subst) { + expr r = e; + subst.for_each([&](name const & l, expr const & v) { + r = replace(r, l, v); + }); + return r; + } + + static bool contains(list const & local_ctx, expr const & p) { + return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); }); + } + + expr compile_skip(program const & prg) { + lean_assert(!head(prg.m_var_stack)); + auto new_stack = tail(prg.m_var_stack); + buffer new_eqs; + for (eqn const & e : prg.m_eqns) { + auto new_patterns = tail(e.m_patterns); + new_eqs.emplace_back(e.m_local_context, new_patterns, e.m_rhs); + } + return compile_core(program(prg, new_stack, to_list(new_eqs))); + } + + expr compile_variable(program const & prg) { + // The next pattern of every equation is a variable (or inaccessible term). + // Thus, we just rename them with the variable on + // the top of the variable stack. + // Remark: if the pattern is an inaccessible term, we just ignore it. + expr x = prg.get_var(*head(prg.m_var_stack)); + auto new_stack = tail(prg.m_var_stack); + buffer new_eqs; + for (eqn const & e : prg.m_eqns) { + expr p = head(e.m_patterns); + if (is_inaccessible(p)) { + new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs); + } else { + lean_assert(is_local(p)); + if (contains(e.m_local_context, p)) { + list new_local_ctx = e.m_local_context; + new_local_ctx = remove(new_local_ctx, p); + new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); }); + auto new_patterns = map(tail(e.m_patterns), [&](expr const & p2) { return replace(p2, p, x); }); + auto new_rhs = replace(e.m_rhs, p, x); + new_eqs.emplace_back(new_local_ctx, new_patterns, new_rhs); + } else { + new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs); + } + } + } + return compile_core(program(prg, new_stack, to_list(new_eqs))); + } + + class implementation : public inversion::implementation { + eqn m_eqn; + public: + implementation(eqn const & e):m_eqn(e) {} + + eqn const & get_eqn() const { return m_eqn; } + + virtual name const & get_constructor_name() const { + return const_name(get_app_fn(head(m_eqn.m_patterns))); + } + + virtual void update_exprs(std::function const & fn) { + m_eqn.m_local_context = map(m_eqn.m_local_context, fn); + m_eqn.m_patterns = map(m_eqn.m_patterns, fn); + m_eqn.m_rhs = fn(m_eqn.m_rhs); + } + }; + + // Wrap the equations from \c p as an "implementation_list" for the inversion package. + inversion::implementation_list to_implementation_list(program const & p) { + return map2(p.m_eqns, [&](eqn const & e) { + return std::shared_ptr(new implementation(e)); + }); + } + + // Convert program into a goal. We need that to be able to invoke the inversion package. + goal to_goal(program const & p) { + buffer hyps; + to_buffer(p.m_context, hyps); + expr new_type = p.m_type; + expr new_meta = mk_app(mk_metavar(mk_fresh_name(), Pi(hyps, new_type)), hyps); + return goal(new_meta, new_type); + } + + // Convert goal and implementation_list back into a program. + // - nvars is the number of new variables in the variable stack. + program to_program(expr const & fn, goal const & g, unsigned nvars, list> const & new_var_stack, inversion::implementation_list const & imps) { + buffer new_context; + g.get_hyps(new_context); + expr new_type = g.get_type(); + buffer new_equations; + for (inversion::implementation_ptr const & imp : imps) { + eqn e = static_cast(imp.get())->get_eqn(); + buffer pat_args; + get_app_args(head(e.m_patterns), pat_args); + lean_assert(pat_args.size() >= nvars); + list new_pats = to_list(pat_args.end() - nvars, pat_args.end(), tail(e.m_patterns)); + new_equations.push_back(eqn(e.m_local_context, new_pats, e.m_rhs)); + } + return program(fn, to_list(new_context), new_var_stack, to_list(new_equations), new_type); + } + + expr compile_constructor(program const & p) { + expr h = p.get_var(*head(p.m_var_stack)); + goal g = to_goal(p); + auto imps = to_implementation_list(p); + if (auto r = apply(env(), ios(), m_tc, g, h, imps)) { + substitution subst = r->m_subst; + list> args = r->m_args; + list rn_maps = r->m_renames; + list imps_list = r->m_implementation_lists; + for (goal const & new_g : r->m_goals) { + list> new_vars = map2>(head(args), + [](expr const & a) { + if (is_local(a)) + return optional(mlocal_name(a)); + else + return optional(); + }); + rename_map const & rn = head(rn_maps); + list> new_var_stack = map(tail(p.m_var_stack), + [&](optional const & n) { + if (n) + return optional(rn.find(*n)); + else + return n; + }); + list> new_case_stack = append(new_vars, new_var_stack); + program new_p = to_program(p.m_fn, new_g, length(new_vars), new_case_stack, head(imps_list)); + args = tail(args); + imps_list = tail(imps_list); + rn_maps = tail(rn_maps); + expr t = compile_core(new_p); + subst.assign(new_g.get_name(), new_g.abstract(t)); + } + expr t = subst.instantiate_all(g.get_meta()); + // out() << "RESULT: " << t << "\n"; + return t; + } else { + throw_error(sstream() << "patter matching failed"); + } + } + + expr compile_complete(program const & /* p */) { + // The next pattern of every equation is a constructor or variable. + // We split the equations where the next pattern is a variable into cases. + // That is, we are reducing this case to the compile_constructor case. + + // TODO(Leo) + return expr(); + } + + expr compile_core(program const & p) { + lean_assert(check_program(p)); + // out() << "compile_core step\n"; + // display(p); + // out() << "------------------\n"; + if (p.m_var_stack) { + if (!head(p.m_var_stack)) { + return compile_skip(p); + } else if (is_variable_transition(p)) { + return compile_variable(p); + } else if (is_constructor_transition(p)) { + return compile_constructor(p); + } else if (is_complete_transition(p)) { + return compile_complete(p); + } else { + // In some equations the next pattern is an inaccessible term, + // and in others it is a constructor. + throw_error(sstream() << "invalid recursive equations for '" << local_pp_name(p.m_fn) + << "', inconsistent use of inaccessible term annotation, " + << "in some equations a pattern is a constructor, and in another it is an inaccessible term"); + } + } else { + if (p.m_eqns) { + // variable stack is empty + expr r = head(p.m_eqns).m_rhs; + lean_assert(is_def_eq(infer_type(r), p.m_type)); + return r; + } else { + throw_error(sstream() << "invalid non-exhaustive set of recursive equations"); + } + } + } + + expr compile(program const & p) { + buffer vars; + to_buffer(p.m_context, vars); + expr r = compile_core(p); + return Fun(vars, r); + } + + /** \brief Return true iff \c e is one of the functions being defined */ + bool is_fn(expr const & e) const { + return + is_local(e) && + std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) { + return mlocal_name(fn) == mlocal_name(e); + }); + } + + /** \brief Return true iff the equations are recursive. */ + bool is_recursive(buffer const & prgs) const { + lean_assert(!prgs.empty()); + for (program const & p : prgs) { + for (eqn const & e : p.m_eqns) { + if (find(e.m_rhs, [&](expr const & e, unsigned) { return is_fn(e); })) + return true; + } + } + if (prgs.size() > 1) + throw_error(sstream() << "mutual recursion is not needed when defining non-recursive functions"); + return false; + } + + /** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/ + bool is_inductive(expr const & t) const { + expr const & fn = get_app_fn(t); + return + is_constant(fn) && + env().find(name{const_name(fn), "brec_on"}); + } + + /** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */ + bool is_compatible_inductive(expr const & t1, expr const & t2) { + lean_assert(is_inductive(t1)); + lean_assert(is_inductive(t2)); + buffer args1, args2; + name const & I1 = const_name(get_app_args(t1, args1)); + name const & I2 = const_name(get_app_args(t2, args2)); + inductive::inductive_decls decls = *inductive::is_inductive_decl(env(), I1); + unsigned nparams = std::get<1>(decls); + for (auto decl : std::get<2>(decls)) { + if (inductive::inductive_decl_name(decl) == I2) { + // parameters must be definitionally equal + unsigned i = 0; + for (; i < nparams; i++) { + if (!is_def_eq(args1[i], args2[i])) + break; + } + if (i == nparams) + return true; + } + } + return false; + } + + /** \brief Return true iff \c t1 and \c t2 are instances of the same inductive datatype */ + static bool is_same_inductive(expr const & t1, expr const & t2) { + return const_name(get_app_fn(t1)) != const_name(get_app_fn(t2)); + } + + /** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */ + bool is_le(expr const & s, expr const & t) { + return is_def_eq(s, t) || is_lt(s, t); + } + + /** Return true iff \c s is structurally smaller than \c t */ + bool is_lt(expr s, expr const & t) { + s = whnf(s); + if (is_app(s)) { + expr const & s_fn = get_app_fn(s); + if (!is_constructor(s_fn)) + return is_lt(s_fn, t); // f < t ==> s := f a_1 ... a_n < t + } + buffer t_args; + expr const & t_fn = get_app_args(t, t_args); + if (!is_constructor(t_fn)) + return false; + return std::any_of(t_args.begin(), t_args.end(), [&](expr const & t_arg) { return is_le(s, t_arg); }); + } + + /** \brief Auxiliary functional object for checking whether recursive application are structurally smaller or not */ + struct check_rhs_fn { + equation_compiler_fn & m_main; + buffer const & m_prgs; + buffer const & m_arg_pos; + + check_rhs_fn(equation_compiler_fn & m, buffer const & prgs, buffer const & arg_pos): + m_main(m), m_prgs(prgs), m_arg_pos(arg_pos) {} + + /** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */ + bool check_rhs(expr const & e, expr const & arg) const { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Meta: + case expr_kind::Local: case expr_kind::Constant: + case expr_kind::Sort: + return true; + case expr_kind::Macro: + for (unsigned i = 0; i < macro_num_args(e); i++) + if (!check_rhs(macro_arg(e, i), arg)) + return false; + return true; + case expr_kind::App: { + buffer args; + expr const & fn = get_app_args(e, args); + if (!check_rhs(fn, arg)) + return false; + for (unsigned i = 0; i < args.size(); i++) + if (!check_rhs(args[i], arg)) + return false; + if (is_local(fn)) { + for (unsigned j = 0; j < m_prgs.size(); j++) { + if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) { + // it is a recusive application + unsigned pos_j = m_arg_pos[j]; + if (pos_j < args.size()) { + expr const & arg_j = args[pos_j]; + // arg_j must be structurally smaller than arg + if (!m_main.is_lt(arg_j, arg)) + return false; + } + break; + } + } + } + return true; + } + case expr_kind::Lambda: + case expr_kind::Pi: + if (!check_rhs(binding_domain(e), arg)) { + return false; + } else { + expr l = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e)); + return check_rhs(instantiate(binding_body(e), l), arg); + } + } + lean_unreachable(); + } + + bool operator()(expr const & e, expr const & arg) const { + return check_rhs(e, arg); + } + }; + + /** \brief Return true iff the recursive equations in prgs are "admissible" with respect to + the following configuration of recursive arguments. + We say the equations are admissible when every recursive application of prgs[i] + is structurally smaller at arguments arg_pos[i]. + */ + bool check_rec_args(buffer const & prgs, buffer const & arg_pos) { + lean_assert(prgs.size() == arg_pos.size()); + check_rhs_fn check_rhs(*this, prgs, arg_pos); + for (unsigned i = 0; i < prgs.size(); i++) { + program const & prg = prgs[i]; + unsigned pos_i = arg_pos[i]; + for (eqn const & e : prg.m_eqns) { + expr const & p_i = get_ith(e.m_patterns, pos_i); + if (!check_rhs(e.m_rhs, p_i)) + return false; + } + } + return true; + } + + bool find_rec_args(buffer const & prgs, unsigned i, buffer & arg_pos, buffer & arg_types) { + if (i == prgs.size()) { + return check_rec_args(prgs, arg_pos); + } else { + program const & p = prgs[i]; + unsigned j = 0; + for (optional const & n : p.m_var_stack) { + lean_assert(n); + expr const & v = p.get_var(*n); + expr const & t = mlocal_type(v); + if (// argument must be an inductive datatype + is_inductive(t) && + // argument must be an inductive datatype different from the ones in arg_types + std::all_of(arg_types.begin(), arg_types.end(), + [&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) && + // argument type must be in the same mutually recursive declaration of previous argument types + (arg_types.empty() || is_compatible_inductive(t, arg_types[0]))) { + // Found candidate + arg_pos.push_back(j); + arg_types.push_back(t); + if (find_rec_args(prgs, i+1, arg_pos, arg_types)) + return true; + arg_pos.pop_back(); + arg_types.pop_back(); + } + j++; + } + return false; + } + } + + bool find_rec_args(buffer const & prgs, buffer & arg_pos) { + buffer arg_types; + return find_rec_args(prgs, 0, arg_pos, arg_types); + } + + void apply_brec_on(buffer & prgs) { + lean_assert(!prgs.empty()); + buffer arg_pos; + if (!find_rec_args(prgs, arg_pos)) { + throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller " + << "(possible solution: use well-founded recursion)"); + } + // out() << "Found recursive arguments: "; + // for (unsigned p : arg_pos) out() << " " << p; out() << "\n"; + + // TODO(Leo): + } + + void apply_wf(buffer & /* prgs */) { + // TODO(Leo) } public: equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax): m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) { + get_app_args(m_meta, m_global_context); } expr operator()(expr eqns) { - proof_state ps = to_proof_state(m_meta, m_meta_type, m_tc.mk_ngen(), m_relax); - eqns = validate_patterns(eqns); - regular(env(), m_ios) << "Equations:\n" << eqns << "\n"; - regular(env(), m_ios) << ps.pp(env(), m_ios) << "\n\n"; - return eqns; + check_limitations(eqns); + // out() << "Equations:\n" << eqns << "\n"; + buffer prgs; + initialize(eqns, prgs); + if (is_recursive(prgs)) { + if (is_wf_equations(eqns)) { + apply_wf(prgs); + } else { + apply_brec_on(prgs); + } + } + buffer rs; + for (program const & p : prgs) { + expr r = compile(p); + // out() << r << "\nType: " << m_tc.infer(r).first << "\n"; + rs.push_back(r); + } + if (closed(rs[0])) + return rs[0]; + else + return m_meta; } }; diff --git a/src/library/util.cpp b/src/library/util.cpp index 589fd7526..67dad2c19 100644 --- a/src/library/util.cpp +++ b/src/library/util.cpp @@ -210,7 +210,7 @@ void initialize_library_util() { g_heq_name = new name("heq"); g_sigma_name = new name("sigma"); - g_sigma_mk_name = new name{"sigma", "dpair"}; + g_sigma_mk_name = new name{"sigma", "mk"}; } void finalize_library_util() { diff --git a/tests/lean/extra/rec.lean b/tests/lean/extra/rec.lean index 459a839ec..5bd465439 100644 --- a/tests/lean/extra/rec.lean +++ b/tests/lean/extra/rec.lean @@ -23,7 +23,6 @@ definition map {A B C : Type} (f : A → B → C) : Π {n}, vector A n → vecto map nil nil := nil, map (a :: va) (b :: vb) := f a b :: map va vb - definition half : nat → nat, half 0 := 0, half 1 := 0, @@ -35,3 +34,24 @@ mk : Π a, image_of f (f a) definition inv {f : A → B} : Π b, image_of f b → A, inv ⌞f a⌟ (image_of.mk f a) := a + +namespace tst + +definition fib : nat → nat, +fib 0 := 1, +fib 1 := 1, +fib (a+2) := fib a + fib (a+1) + +end tst + +definition simple : nat → nat → nat, +simple x y := x + y + +definition simple2 : nat → nat → nat, +simple2 (x+1) y := x + y, +simple2 ⌞y+1⌟ y := y + + + +check @vector.brec_on +check @vector.cases_on diff --git a/tests/lean/extra/rec3.lean b/tests/lean/extra/rec3.lean index 63da983a7..e52ba103e 100644 --- a/tests/lean/extra/rec3.lean +++ b/tests/lean/extra/rec3.lean @@ -1,8 +1,8 @@ set_option pp.implicit true set_option pp.notation false -definition symm : Π {A : Type} {a b : A}, a = b → b = a, +definition symm {A : Type} : Π {a b : A}, a = b → b = a, symm rfl := rfl -definition trans : Π {A : Type} {a b c : A}, a = b → b = c → a = c, +definition trans {A : Type} : Π {a b c : A}, a = b → b = c → a = c, trans rfl rfl := rfl