feat(library/definitional/equations): add dependent pattern matching compilation

This commit is contained in:
Leonardo de Moura 2015-01-02 22:05:07 -08:00
parent 762a515a5b
commit 7f7d318b22
4 changed files with 811 additions and 51 deletions

View file

@ -5,10 +5,15 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <string>
#include "util/sstream.h"
#include "util/list_fn.h"
#include "kernel/expr.h"
#include "kernel/type_checker.h"
#include "kernel/abstract.h"
#include "kernel/instantiate.h"
#include "kernel/error_msgs.h"
#include "kernel/for_each_fn.h"
#include "kernel/find_fn.h"
#include "library/generic_exception.h"
#include "library/kernel_serializer.h"
#include "library/io_state_stream.h"
@ -220,44 +225,233 @@ class equation_compiler_fn {
expr m_meta;
expr m_meta_type;
bool m_relax;
buffer<expr> m_global_context;
buffer<expr> m_fns; // functions being defined
environment const & env() const { return m_tc.env(); }
io_state const & ios() const { return m_ios; }
io_state_stream out() const { return regular(env(), ios()); }
name mk_fresh_name() { return m_tc.mk_fresh_name(); }
expr whnf(expr const & e) { return m_tc.whnf(e).first; }
expr infer_type(expr const & e) { return m_tc.infer(e).first; }
bool is_def_eq(expr const & e1, expr const & e2) { return m_tc.is_def_eq(e1, e2).first; }
optional<name> is_constructor(expr const & e) const {
if (!is_constant(e))
return optional<name>();
return inductive::is_intro_rule(env(), const_name(e));
}
expr to_telescope(expr const & e, buffer<expr> & tele) {
name_generator ngen = m_tc.mk_ngen();
return ::lean::to_telescope(ngen, e, tele, optional<binder_info>());
}
expr fun_to_telescope(expr const & e, buffer<expr> & tele) {
name_generator ngen = m_tc.mk_ngen();
return ::lean::fun_to_telescope(ngen, e, tele, optional<binder_info>());
}
// Similar to to_telescope, but uses normalization
expr to_telescope_ext(expr const & e, buffer<expr> & tele) {
return ::lean::to_telescope(m_tc, e, tele, optional<binder_info>());
}
[[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
[[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
[[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); }
void check_limitations(expr const & eqns) const {
if (is_wf_equations(eqns) && equations_num_fns(eqns) != 1)
throw_error("mutually recursive equations do not support well-founded recursion yet", eqns);
}
#ifdef LEAN_DEBUG
static bool disjoint(list<expr> const & l1, list<expr> const & l2) {
for (expr const & e1 : l1) {
for (expr const & e2 : l2) {
lean_assert(mlocal_name(e1) != mlocal_name(e2));
}
}
return true;
}
// Return true iff all names in s1 are names of local constants in s2.
static bool contained(list<optional<name>> const & s1, list<expr> const & s2) {
return std::all_of(s1.begin(), s1.end(), [&](optional<name> const & n) {
return
!n ||
std::any_of(s2.begin(), s2.end(), [&](expr const & e) {
return mlocal_name(e) == *n;
});
});
}
#endif
struct eqn {
// The local context for an equation is of additional
// local constants occurring in m_patterns and m_rhs
// which are not in m_global_context or
// in the function containing the equation.
// Remark: each function/program contains its own m_context.
// So, the variables occurring in m_patterns and m_rhs
// are in m_global_context, m_context, or m_local_context,
// or is one of the functions being defined.
// We say an equation is in "compiled" form
// if m_local_context and m_patterns are empty.
list<expr> m_local_context;
list<expr> m_patterns; // patterns to be processed
expr m_rhs; // right-hand-side
eqn(list<expr> const & c, list<expr> const & p, expr const & r):
m_local_context(c), m_patterns(p), m_rhs(r) {}
};
// Data-structure used to store for compiling pattern matching.
// We create a program object for each function being defined
struct program {
expr m_fn; // function being defined
list<expr> m_context; // local constants
list<optional<name>> m_var_stack; // variables that must be matched with the patterns it is a "subset" of m_context.
list<eqn> m_eqns; // equations
expr m_type; // result type
// Due to dependent pattern matching some elements in m_var_stack are "none", and are skipped
// during dependent pattern matching.
// The goal of the compiler is to process all variables in m_var_stack
program(expr const & fn, list<expr> const & ctx, list<optional<name>> const & s, list<eqn> const & e, expr const & t):
m_fn(fn), m_context(ctx), m_var_stack(s), m_eqns(e), m_type(t) {
lean_assert(contained(m_var_stack, m_context));
}
program(program const & p, list<optional<name>> const & new_s, list<eqn> const & new_e):
program(p.m_fn, p.m_context, new_s, new_e, p.m_type) {}
program() {}
expr const & get_var(name const & n) const {
for (expr const & v : m_context) {
if (mlocal_name(v) == n)
return v;
}
lean_unreachable();
}
};
// For debugging purposes
template<typename T>
static bool contains_local(T const & locals, expr const & l) {
return std::any_of(locals.begin(), locals.end(),
[&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); });
}
#ifdef LEAN_DEBUG
// For debugging purposes: checks whether all local constants occurring in \c e
// are in local_ctx or m_global_context
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
for_each(e, [&](expr const & e, unsigned) {
if (is_local(e) &&
!(contains_local(local_context, e) ||
contains_local(context, e) ||
contains_local(m_global_context, e) ||
contains_local(m_fns, e))) {
lean_unreachable();
return false;
}
return true;
});
return true;
}
// For debugging purposes: check if the program is well-formed
bool check_program(program const & s) const {
unsigned sz = length(s.m_var_stack);
lean_assert(contained(s.m_var_stack, s.m_context));
for (eqn const & e : s.m_eqns) {
// the number of patterns in each equation is equal to the variable stack size
if (length(e.m_patterns) != sz) {
lean_unreachable();
return false;
}
check_ctx(e.m_rhs, s.m_context, e.m_local_context);
for (expr const & p : e.m_patterns)
check_ctx(p, s.m_context, e.m_local_context);
lean_assert(disjoint(e.m_local_context, s.m_context));
}
return true;
}
#endif
// Initialize m_fns (the vector of functions to be compiled)
void initialize_fns(expr const & eqns) {
lean_assert(is_equations(eqns));
unsigned num_fns = equations_num_fns(eqns);
buffer<expr> eqs;
to_equations(eqns, eqs);
expr eq = eqs[0];
for (unsigned i = 0; i < num_fns; i++) {
expr fn = mk_local(mk_fresh_name(), binding_name(eq), binding_domain(eq), binder_info());
m_fns.push_back(fn);
eq = instantiate(binding_body(eq), fn);
}
}
// Initialize the variable stack for each function that needs
// to be compiled.
// This method assumes m_fns has been already initialized.
// This method also initialized the buffer prg, but the eqns
// field of each program is not initialized by it.
void initialize_var_stack(buffer<program> & prgs) {
lean_assert(!m_fns.empty());
lean_assert(prgs.empty());
for (expr const & fn : m_fns) {
buffer<expr> args;
expr r_type = to_telescope(mlocal_type(fn), args);
list<expr> ctx = to_list(args);
list<optional<name>> vstack = map2<optional<name>>(ctx, [](expr const & e) {
return optional<name>(mlocal_name(e));
});
prgs.push_back(program(fn, ctx, vstack, list<eqn>(), r_type));
}
}
struct validate_exception {
expr m_expr;
validate_exception(expr const & e):m_expr(e) {}
};
[[ noreturn ]] void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
[[ noreturn ]] void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
// --------------------------------
// Pattern validation/normalization
// --------------------------------
expr validate_lhs_arg(expr arg) {
if (is_inaccessible(arg))
return arg;
if (is_local(arg))
return arg;
expr new_arg = m_tc.whnf(arg).first;
if (is_local(new_arg))
return new_arg;
buffer<expr> arg_args;
expr const & fn = get_app_args(new_arg, arg_args);
if (!is_constant(fn) || !inductive::is_intro_rule(env(), const_name(fn)))
throw validate_exception(arg);
for (expr & arg_arg : arg_args)
arg_arg = validate_lhs_arg(arg_arg);
return mk_app(fn, arg_args);
// Validate/normalize the given pattern.
// It stores in reachable_vars any variable that does not occur
// in inaccessible terms.
expr validate_pattern(expr pat, name_set & reachable_vars) {
if (is_inaccessible(pat))
return pat;
if (is_local(pat)) {
reachable_vars.insert(mlocal_name(pat));
return pat;
}
expr new_pat = whnf(pat);
if (is_local(new_pat)) {
reachable_vars.insert(mlocal_name(new_pat));
return new_pat;
}
buffer<expr> pat_args;
expr const & fn = get_app_args(new_pat, pat_args);
if (auto in = is_constructor(fn)) {
unsigned num_params = *inductive::get_num_params(env(), *in);
for (unsigned i = num_params; i < pat_args.size(); i++)
pat_args[i] = validate_pattern(pat_args[i], reachable_vars);
return mk_app(fn, pat_args);
} else {
throw validate_exception(pat);
}
}
expr validate_lhs(expr const & lhs) {
buffer<expr> args;
expr fn = get_app_args(lhs, args);
for (expr & arg : args) {
// Validate/normalize the patterns associated with the given lhs.
// The lhs is only used to report errors.
// It stores in reachable_vars any variable that does not occur
// in inaccessible terms.
void validate_patterns(expr const & lhs, buffer<expr> & patterns, name_set & reachable_vars) {
for (expr & pat : patterns) {
try {
arg = validate_lhs_arg(arg);
pat = validate_pattern(pat, reachable_vars);
} catch (validate_exception & ex) {
expr problem_expr = ex.m_expr;
throw_error(lhs, [=](formatter const & fmt) {
@ -270,40 +464,586 @@ class equation_compiler_fn {
});
}
}
return mk_app(fn, args);
}
expr validate_patterns_core(expr eq) {
buffer<expr> args;
name_generator ngen = m_tc.mk_ngen();
eq = fun_to_telescope(ngen, eq, args, optional<binder_info>());
lean_assert(is_equation(eq));
expr new_lhs = validate_lhs(equation_lhs(eq));
return Fun(args, mk_equation(new_lhs, equation_rhs(eq)));
}
expr validate_patterns(expr const & eqns) {
// Create initial program state for each function being defined.
void initialize(expr const & eqns, buffer<program> & prg) {
lean_assert(is_equations(eqns));
initialize_fns(eqns);
initialize_var_stack(prg);
buffer<expr> eqs;
buffer<expr> new_eqs;
to_equations(eqns, eqs);
for (expr const & eq : eqs) {
new_eqs.push_back(validate_patterns_core(eq));
buffer<buffer<eqn>> res_eqns;
res_eqns.resize(m_fns.size());
for (expr eq : eqs) {
for (expr const & fn : m_fns)
eq = instantiate(binding_body(eq), fn);
buffer<expr> local_ctx;
eq = fun_to_telescope(eq, local_ctx);
expr const & lhs = equation_lhs(eq);
expr const & rhs = equation_rhs(eq);
buffer<expr> patterns;
expr const & fn = get_app_args(lhs, patterns);
name_set reachable_vars;
validate_patterns(lhs, patterns, reachable_vars);
for (expr const & v : local_ctx) {
// every variable in the local_ctx must be "reachable".
if (!reachable_vars.contains(mlocal_name(v))) {
throw_error(lhs, [=](formatter const & fmt) {
format r("invalid equation left-hand-side, variable '");
r += format(local_pp_name(v));
r += format("' only occurs in inaccessible terms in the following equation left-hand-side");
r += pp_indent_expr(fmt, lhs);
return r;
});
}
}
for (unsigned i = 0; i < m_fns.size(); i++) {
if (mlocal_name(fn) == mlocal_name(m_fns[i])) {
if (patterns.size() != length(prg[i].m_var_stack))
throw_error("ill-formed equation, number of provided arguments does not match function type", eq);
res_eqns[i].push_back(eqn(to_list(local_ctx), to_list(patterns), rhs));
}
}
}
return update_equations(eqns, new_eqs);
for (unsigned i = 0; i < m_fns.size(); i++) {
prg[i].m_eqns = to_list(res_eqns[i]);
lean_assert(check_program(prg[i]));
}
}
// For debugging purposes: display the context at m_ios
template<typename Ctx>
void display_ctx(Ctx const & ctx) const {
bool first = true;
for (expr const & e : ctx) {
out() << (first ? "" : ", ") << local_pp_name(e) << " : " << mlocal_type(e);
first = false;
}
}
// For debugging purposes: dump prg in m_ios
void display(program const & prg) const {
display_ctx(prg.m_context);
out() << " ;;";
for (optional<name> const & v : prg.m_var_stack) {
if (v)
out() << " " << local_pp_name(prg.get_var(*v));
else
out() << " <none>";
}
out() << " |- " << prg.m_type << "\n";
out() << "\n";
for (eqn const & e : prg.m_eqns) {
out() << "> ";
display_ctx(e.m_local_context);
out() << " |-";
for (expr const & p : e.m_patterns) {
if (is_atomic(p))
out() << " " << p;
else
out() << " (" << p << ")";
}
out() << " := " << e.m_rhs << "\n";
}
}
// Return true iff the next pattern in all equations is a variable or an inaccessible term
bool is_variable_transition(program const & p) const {
for (eqn const & e : p.m_eqns) {
lean_assert(e.m_patterns);
if (!is_local(head(e.m_patterns)) && !is_inaccessible(head(e.m_patterns)))
return false;
}
return true;
}
// Return true iff the next pattern in all equations is a constructor
bool is_constructor_transition(program const & p) const {
for (eqn const & e : p.m_eqns) {
lean_assert(e.m_patterns);
if (!is_constructor(get_app_fn(head(e.m_patterns))))
return false;
}
return true;
}
// Return true iff the next pattern of every equation is a constructor or variable,
// and there are at least one equation where it is a variable and another where it is a
// constructor.
bool is_complete_transition(program const & p) const {
bool has_variable = false;
bool has_constructor = false;
for (eqn const & e : p.m_eqns) {
lean_assert(e.m_patterns);
expr const & p = head(e.m_patterns);
if (is_local(p))
has_variable = true;
else if (is_constructor(get_app_fn(p)))
has_constructor = true;
else
return false;
}
return has_variable && has_constructor;
}
// Remove variable from local context
static list<expr> remove(list<expr> const & local_ctx, expr const & l) {
if (!local_ctx)
return local_ctx;
else if (mlocal_name(head(local_ctx)) == mlocal_name(l))
return tail(local_ctx);
else
return cons(head(local_ctx), remove(tail(local_ctx), l));
}
// Replace local constant \c from with \c to in the expression \c e.
static expr replace(expr const & e, expr const & from, expr const & to) {
return instantiate(abstract_local(e, from), to);
}
static expr replace(expr const & e, name const & from, expr const & to) {
return instantiate(abstract_local(e, from), to);
}
static expr replace(expr const & e, name_map<expr> const & subst) {
expr r = e;
subst.for_each([&](name const & l, expr const & v) {
r = replace(r, l, v);
});
return r;
}
static bool contains(list<expr> const & local_ctx, expr const & p) {
return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); });
}
expr compile_skip(program const & prg) {
lean_assert(!head(prg.m_var_stack));
auto new_stack = tail(prg.m_var_stack);
buffer<eqn> new_eqs;
for (eqn const & e : prg.m_eqns) {
auto new_patterns = tail(e.m_patterns);
new_eqs.emplace_back(e.m_local_context, new_patterns, e.m_rhs);
}
return compile_core(program(prg, new_stack, to_list(new_eqs)));
}
expr compile_variable(program const & prg) {
// The next pattern of every equation is a variable (or inaccessible term).
// Thus, we just rename them with the variable on
// the top of the variable stack.
// Remark: if the pattern is an inaccessible term, we just ignore it.
expr x = prg.get_var(*head(prg.m_var_stack));
auto new_stack = tail(prg.m_var_stack);
buffer<eqn> new_eqs;
for (eqn const & e : prg.m_eqns) {
expr p = head(e.m_patterns);
if (is_inaccessible(p)) {
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
} else {
lean_assert(is_local(p));
if (contains(e.m_local_context, p)) {
list<expr> new_local_ctx = e.m_local_context;
new_local_ctx = remove(new_local_ctx, p);
new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); });
auto new_patterns = map(tail(e.m_patterns), [&](expr const & p2) { return replace(p2, p, x); });
auto new_rhs = replace(e.m_rhs, p, x);
new_eqs.emplace_back(new_local_ctx, new_patterns, new_rhs);
} else {
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
}
}
}
return compile_core(program(prg, new_stack, to_list(new_eqs)));
}
class implementation : public inversion::implementation {
eqn m_eqn;
public:
implementation(eqn const & e):m_eqn(e) {}
eqn const & get_eqn() const { return m_eqn; }
virtual name const & get_constructor_name() const {
return const_name(get_app_fn(head(m_eqn.m_patterns)));
}
virtual void update_exprs(std::function<expr(expr const &)> const & fn) {
m_eqn.m_local_context = map(m_eqn.m_local_context, fn);
m_eqn.m_patterns = map(m_eqn.m_patterns, fn);
m_eqn.m_rhs = fn(m_eqn.m_rhs);
}
};
// Wrap the equations from \c p as an "implementation_list" for the inversion package.
inversion::implementation_list to_implementation_list(program const & p) {
return map2<inversion::implementation_ptr>(p.m_eqns, [&](eqn const & e) {
return std::shared_ptr<inversion::implementation>(new implementation(e));
});
}
// Convert program into a goal. We need that to be able to invoke the inversion package.
goal to_goal(program const & p) {
buffer<expr> hyps;
to_buffer(p.m_context, hyps);
expr new_type = p.m_type;
expr new_meta = mk_app(mk_metavar(mk_fresh_name(), Pi(hyps, new_type)), hyps);
return goal(new_meta, new_type);
}
// Convert goal and implementation_list back into a program.
// - nvars is the number of new variables in the variable stack.
program to_program(expr const & fn, goal const & g, unsigned nvars, list<optional<name>> const & new_var_stack, inversion::implementation_list const & imps) {
buffer<expr> new_context;
g.get_hyps(new_context);
expr new_type = g.get_type();
buffer<eqn> new_equations;
for (inversion::implementation_ptr const & imp : imps) {
eqn e = static_cast<implementation*>(imp.get())->get_eqn();
buffer<expr> pat_args;
get_app_args(head(e.m_patterns), pat_args);
lean_assert(pat_args.size() >= nvars);
list<expr> new_pats = to_list(pat_args.end() - nvars, pat_args.end(), tail(e.m_patterns));
new_equations.push_back(eqn(e.m_local_context, new_pats, e.m_rhs));
}
return program(fn, to_list(new_context), new_var_stack, to_list(new_equations), new_type);
}
expr compile_constructor(program const & p) {
expr h = p.get_var(*head(p.m_var_stack));
goal g = to_goal(p);
auto imps = to_implementation_list(p);
if (auto r = apply(env(), ios(), m_tc, g, h, imps)) {
substitution subst = r->m_subst;
list<list<expr>> args = r->m_args;
list<rename_map> rn_maps = r->m_renames;
list<inversion::implementation_list> imps_list = r->m_implementation_lists;
for (goal const & new_g : r->m_goals) {
list<optional<name>> new_vars = map2<optional<name>>(head(args),
[](expr const & a) {
if (is_local(a))
return optional<name>(mlocal_name(a));
else
return optional<name>();
});
rename_map const & rn = head(rn_maps);
list<optional<name>> new_var_stack = map(tail(p.m_var_stack),
[&](optional<name> const & n) {
if (n)
return optional<name>(rn.find(*n));
else
return n;
});
list<optional<name>> new_case_stack = append(new_vars, new_var_stack);
program new_p = to_program(p.m_fn, new_g, length(new_vars), new_case_stack, head(imps_list));
args = tail(args);
imps_list = tail(imps_list);
rn_maps = tail(rn_maps);
expr t = compile_core(new_p);
subst.assign(new_g.get_name(), new_g.abstract(t));
}
expr t = subst.instantiate_all(g.get_meta());
// out() << "RESULT: " << t << "\n";
return t;
} else {
throw_error(sstream() << "patter matching failed");
}
}
expr compile_complete(program const & /* p */) {
// The next pattern of every equation is a constructor or variable.
// We split the equations where the next pattern is a variable into cases.
// That is, we are reducing this case to the compile_constructor case.
// TODO(Leo)
return expr();
}
expr compile_core(program const & p) {
lean_assert(check_program(p));
// out() << "compile_core step\n";
// display(p);
// out() << "------------------\n";
if (p.m_var_stack) {
if (!head(p.m_var_stack)) {
return compile_skip(p);
} else if (is_variable_transition(p)) {
return compile_variable(p);
} else if (is_constructor_transition(p)) {
return compile_constructor(p);
} else if (is_complete_transition(p)) {
return compile_complete(p);
} else {
// In some equations the next pattern is an inaccessible term,
// and in others it is a constructor.
throw_error(sstream() << "invalid recursive equations for '" << local_pp_name(p.m_fn)
<< "', inconsistent use of inaccessible term annotation, "
<< "in some equations a pattern is a constructor, and in another it is an inaccessible term");
}
} else {
if (p.m_eqns) {
// variable stack is empty
expr r = head(p.m_eqns).m_rhs;
lean_assert(is_def_eq(infer_type(r), p.m_type));
return r;
} else {
throw_error(sstream() << "invalid non-exhaustive set of recursive equations");
}
}
}
expr compile(program const & p) {
buffer<expr> vars;
to_buffer(p.m_context, vars);
expr r = compile_core(p);
return Fun(vars, r);
}
/** \brief Return true iff \c e is one of the functions being defined */
bool is_fn(expr const & e) const {
return
is_local(e) &&
std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) {
return mlocal_name(fn) == mlocal_name(e);
});
}
/** \brief Return true iff the equations are recursive. */
bool is_recursive(buffer<program> const & prgs) const {
lean_assert(!prgs.empty());
for (program const & p : prgs) {
for (eqn const & e : p.m_eqns) {
if (find(e.m_rhs, [&](expr const & e, unsigned) { return is_fn(e); }))
return true;
}
}
if (prgs.size() > 1)
throw_error(sstream() << "mutual recursion is not needed when defining non-recursive functions");
return false;
}
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/
bool is_inductive(expr const & t) const {
expr const & fn = get_app_fn(t);
return
is_constant(fn) &&
env().find(name{const_name(fn), "brec_on"});
}
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
bool is_compatible_inductive(expr const & t1, expr const & t2) {
lean_assert(is_inductive(t1));
lean_assert(is_inductive(t2));
buffer<expr> args1, args2;
name const & I1 = const_name(get_app_args(t1, args1));
name const & I2 = const_name(get_app_args(t2, args2));
inductive::inductive_decls decls = *inductive::is_inductive_decl(env(), I1);
unsigned nparams = std::get<1>(decls);
for (auto decl : std::get<2>(decls)) {
if (inductive::inductive_decl_name(decl) == I2) {
// parameters must be definitionally equal
unsigned i = 0;
for (; i < nparams; i++) {
if (!is_def_eq(args1[i], args2[i]))
break;
}
if (i == nparams)
return true;
}
}
return false;
}
/** \brief Return true iff \c t1 and \c t2 are instances of the same inductive datatype */
static bool is_same_inductive(expr const & t1, expr const & t2) {
return const_name(get_app_fn(t1)) != const_name(get_app_fn(t2));
}
/** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */
bool is_le(expr const & s, expr const & t) {
return is_def_eq(s, t) || is_lt(s, t);
}
/** Return true iff \c s is structurally smaller than \c t */
bool is_lt(expr s, expr const & t) {
s = whnf(s);
if (is_app(s)) {
expr const & s_fn = get_app_fn(s);
if (!is_constructor(s_fn))
return is_lt(s_fn, t); // f < t ==> s := f a_1 ... a_n < t
}
buffer<expr> t_args;
expr const & t_fn = get_app_args(t, t_args);
if (!is_constructor(t_fn))
return false;
return std::any_of(t_args.begin(), t_args.end(), [&](expr const & t_arg) { return is_le(s, t_arg); });
}
/** \brief Auxiliary functional object for checking whether recursive application are structurally smaller or not */
struct check_rhs_fn {
equation_compiler_fn & m_main;
buffer<program> const & m_prgs;
buffer<unsigned> const & m_arg_pos;
check_rhs_fn(equation_compiler_fn & m, buffer<program> const & prgs, buffer<unsigned> const & arg_pos):
m_main(m), m_prgs(prgs), m_arg_pos(arg_pos) {}
/** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */
bool check_rhs(expr const & e, expr const & arg) const {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Meta:
case expr_kind::Local: case expr_kind::Constant:
case expr_kind::Sort:
return true;
case expr_kind::Macro:
for (unsigned i = 0; i < macro_num_args(e); i++)
if (!check_rhs(macro_arg(e, i), arg))
return false;
return true;
case expr_kind::App: {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (!check_rhs(fn, arg))
return false;
for (unsigned i = 0; i < args.size(); i++)
if (!check_rhs(args[i], arg))
return false;
if (is_local(fn)) {
for (unsigned j = 0; j < m_prgs.size(); j++) {
if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) {
// it is a recusive application
unsigned pos_j = m_arg_pos[j];
if (pos_j < args.size()) {
expr const & arg_j = args[pos_j];
// arg_j must be structurally smaller than arg
if (!m_main.is_lt(arg_j, arg))
return false;
}
break;
}
}
}
return true;
}
case expr_kind::Lambda:
case expr_kind::Pi:
if (!check_rhs(binding_domain(e), arg)) {
return false;
} else {
expr l = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
return check_rhs(instantiate(binding_body(e), l), arg);
}
}
lean_unreachable();
}
bool operator()(expr const & e, expr const & arg) const {
return check_rhs(e, arg);
}
};
/** \brief Return true iff the recursive equations in prgs are "admissible" with respect to
the following configuration of recursive arguments.
We say the equations are admissible when every recursive application of prgs[i]
is structurally smaller at arguments arg_pos[i].
*/
bool check_rec_args(buffer<program> const & prgs, buffer<unsigned> const & arg_pos) {
lean_assert(prgs.size() == arg_pos.size());
check_rhs_fn check_rhs(*this, prgs, arg_pos);
for (unsigned i = 0; i < prgs.size(); i++) {
program const & prg = prgs[i];
unsigned pos_i = arg_pos[i];
for (eqn const & e : prg.m_eqns) {
expr const & p_i = get_ith(e.m_patterns, pos_i);
if (!check_rhs(e.m_rhs, p_i))
return false;
}
}
return true;
}
bool find_rec_args(buffer<program> const & prgs, unsigned i, buffer<unsigned> & arg_pos, buffer<expr> & arg_types) {
if (i == prgs.size()) {
return check_rec_args(prgs, arg_pos);
} else {
program const & p = prgs[i];
unsigned j = 0;
for (optional<name> const & n : p.m_var_stack) {
lean_assert(n);
expr const & v = p.get_var(*n);
expr const & t = mlocal_type(v);
if (// argument must be an inductive datatype
is_inductive(t) &&
// argument must be an inductive datatype different from the ones in arg_types
std::all_of(arg_types.begin(), arg_types.end(),
[&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) &&
// argument type must be in the same mutually recursive declaration of previous argument types
(arg_types.empty() || is_compatible_inductive(t, arg_types[0]))) {
// Found candidate
arg_pos.push_back(j);
arg_types.push_back(t);
if (find_rec_args(prgs, i+1, arg_pos, arg_types))
return true;
arg_pos.pop_back();
arg_types.pop_back();
}
j++;
}
return false;
}
}
bool find_rec_args(buffer<program> const & prgs, buffer<unsigned> & arg_pos) {
buffer<expr> arg_types;
return find_rec_args(prgs, 0, arg_pos, arg_types);
}
void apply_brec_on(buffer<program> & prgs) {
lean_assert(!prgs.empty());
buffer<unsigned> arg_pos;
if (!find_rec_args(prgs, arg_pos)) {
throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller "
<< "(possible solution: use well-founded recursion)");
}
// out() << "Found recursive arguments: ";
// for (unsigned p : arg_pos) out() << " " << p; out() << "\n";
// TODO(Leo):
}
void apply_wf(buffer<program> & /* prgs */) {
// TODO(Leo)
}
public:
equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax):
m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) {
get_app_args(m_meta, m_global_context);
}
expr operator()(expr eqns) {
proof_state ps = to_proof_state(m_meta, m_meta_type, m_tc.mk_ngen(), m_relax);
eqns = validate_patterns(eqns);
regular(env(), m_ios) << "Equations:\n" << eqns << "\n";
regular(env(), m_ios) << ps.pp(env(), m_ios) << "\n\n";
return eqns;
check_limitations(eqns);
// out() << "Equations:\n" << eqns << "\n";
buffer<program> prgs;
initialize(eqns, prgs);
if (is_recursive(prgs)) {
if (is_wf_equations(eqns)) {
apply_wf(prgs);
} else {
apply_brec_on(prgs);
}
}
buffer<expr> rs;
for (program const & p : prgs) {
expr r = compile(p);
// out() << r << "\nType: " << m_tc.infer(r).first << "\n";
rs.push_back(r);
}
if (closed(rs[0]))
return rs[0];
else
return m_meta;
}
};

View file

@ -210,7 +210,7 @@ void initialize_library_util() {
g_heq_name = new name("heq");
g_sigma_name = new name("sigma");
g_sigma_mk_name = new name{"sigma", "dpair"};
g_sigma_mk_name = new name{"sigma", "mk"};
}
void finalize_library_util() {

View file

@ -23,7 +23,6 @@ definition map {A B C : Type} (f : A → B → C) : Π {n}, vector A n → vecto
map nil nil := nil,
map (a :: va) (b :: vb) := f a b :: map va vb
definition half : nat → nat,
half 0 := 0,
half 1 := 0,
@ -35,3 +34,24 @@ mk : Π a, image_of f (f a)
definition inv {f : A → B} : Π b, image_of f b → A,
inv ⌞f a⌟ (image_of.mk f a) := a
namespace tst
definition fib : nat → nat,
fib 0 := 1,
fib 1 := 1,
fib (a+2) := fib a + fib (a+1)
end tst
definition simple : nat → nat → nat,
simple x y := x + y
definition simple2 : nat → nat → nat,
simple2 (x+1) y := x + y,
simple2 ⌞y+1⌟ y := y
check @vector.brec_on
check @vector.cases_on

View file

@ -1,8 +1,8 @@
set_option pp.implicit true
set_option pp.notation false
definition symm : Π {A : Type} {a b : A}, a = b → b = a,
definition symm {A : Type} : Π {a b : A}, a = b → b = a,
symm rfl := rfl
definition trans : Π {A : Type} {a b c : A}, a = b → b = c → a = c,
definition trans {A : Type} : Π {a b c : A}, a = b → b = c → a = c,
trans rfl rfl := rfl