feat(library/definitional/equations): add dependent pattern matching compilation
This commit is contained in:
parent
762a515a5b
commit
7f7d318b22
4 changed files with 811 additions and 51 deletions
|
@ -5,10 +5,15 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <string>
|
||||
#include "util/sstream.h"
|
||||
#include "util/list_fn.h"
|
||||
#include "kernel/expr.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/error_msgs.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/find_fn.h"
|
||||
#include "library/generic_exception.h"
|
||||
#include "library/kernel_serializer.h"
|
||||
#include "library/io_state_stream.h"
|
||||
|
@ -220,44 +225,233 @@ class equation_compiler_fn {
|
|||
expr m_meta;
|
||||
expr m_meta_type;
|
||||
bool m_relax;
|
||||
buffer<expr> m_global_context;
|
||||
buffer<expr> m_fns; // functions being defined
|
||||
|
||||
environment const & env() const { return m_tc.env(); }
|
||||
io_state const & ios() const { return m_ios; }
|
||||
io_state_stream out() const { return regular(env(), ios()); }
|
||||
name mk_fresh_name() { return m_tc.mk_fresh_name(); }
|
||||
expr whnf(expr const & e) { return m_tc.whnf(e).first; }
|
||||
expr infer_type(expr const & e) { return m_tc.infer(e).first; }
|
||||
bool is_def_eq(expr const & e1, expr const & e2) { return m_tc.is_def_eq(e1, e2).first; }
|
||||
|
||||
optional<name> is_constructor(expr const & e) const {
|
||||
if (!is_constant(e))
|
||||
return optional<name>();
|
||||
return inductive::is_intro_rule(env(), const_name(e));
|
||||
}
|
||||
|
||||
expr to_telescope(expr const & e, buffer<expr> & tele) {
|
||||
name_generator ngen = m_tc.mk_ngen();
|
||||
return ::lean::to_telescope(ngen, e, tele, optional<binder_info>());
|
||||
}
|
||||
|
||||
expr fun_to_telescope(expr const & e, buffer<expr> & tele) {
|
||||
name_generator ngen = m_tc.mk_ngen();
|
||||
return ::lean::fun_to_telescope(ngen, e, tele, optional<binder_info>());
|
||||
}
|
||||
|
||||
// Similar to to_telescope, but uses normalization
|
||||
expr to_telescope_ext(expr const & e, buffer<expr> & tele) {
|
||||
return ::lean::to_telescope(m_tc, e, tele, optional<binder_info>());
|
||||
}
|
||||
|
||||
[[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
|
||||
[[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
|
||||
[[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); }
|
||||
|
||||
void check_limitations(expr const & eqns) const {
|
||||
if (is_wf_equations(eqns) && equations_num_fns(eqns) != 1)
|
||||
throw_error("mutually recursive equations do not support well-founded recursion yet", eqns);
|
||||
}
|
||||
|
||||
#ifdef LEAN_DEBUG
|
||||
static bool disjoint(list<expr> const & l1, list<expr> const & l2) {
|
||||
for (expr const & e1 : l1) {
|
||||
for (expr const & e2 : l2) {
|
||||
lean_assert(mlocal_name(e1) != mlocal_name(e2));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return true iff all names in s1 are names of local constants in s2.
|
||||
static bool contained(list<optional<name>> const & s1, list<expr> const & s2) {
|
||||
return std::all_of(s1.begin(), s1.end(), [&](optional<name> const & n) {
|
||||
return
|
||||
!n ||
|
||||
std::any_of(s2.begin(), s2.end(), [&](expr const & e) {
|
||||
return mlocal_name(e) == *n;
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
struct eqn {
|
||||
// The local context for an equation is of additional
|
||||
// local constants occurring in m_patterns and m_rhs
|
||||
// which are not in m_global_context or
|
||||
// in the function containing the equation.
|
||||
// Remark: each function/program contains its own m_context.
|
||||
// So, the variables occurring in m_patterns and m_rhs
|
||||
// are in m_global_context, m_context, or m_local_context,
|
||||
// or is one of the functions being defined.
|
||||
// We say an equation is in "compiled" form
|
||||
// if m_local_context and m_patterns are empty.
|
||||
list<expr> m_local_context;
|
||||
list<expr> m_patterns; // patterns to be processed
|
||||
expr m_rhs; // right-hand-side
|
||||
eqn(list<expr> const & c, list<expr> const & p, expr const & r):
|
||||
m_local_context(c), m_patterns(p), m_rhs(r) {}
|
||||
};
|
||||
|
||||
// Data-structure used to store for compiling pattern matching.
|
||||
// We create a program object for each function being defined
|
||||
struct program {
|
||||
expr m_fn; // function being defined
|
||||
list<expr> m_context; // local constants
|
||||
list<optional<name>> m_var_stack; // variables that must be matched with the patterns it is a "subset" of m_context.
|
||||
list<eqn> m_eqns; // equations
|
||||
expr m_type; // result type
|
||||
|
||||
// Due to dependent pattern matching some elements in m_var_stack are "none", and are skipped
|
||||
// during dependent pattern matching.
|
||||
|
||||
// The goal of the compiler is to process all variables in m_var_stack
|
||||
program(expr const & fn, list<expr> const & ctx, list<optional<name>> const & s, list<eqn> const & e, expr const & t):
|
||||
m_fn(fn), m_context(ctx), m_var_stack(s), m_eqns(e), m_type(t) {
|
||||
lean_assert(contained(m_var_stack, m_context));
|
||||
}
|
||||
program(program const & p, list<optional<name>> const & new_s, list<eqn> const & new_e):
|
||||
program(p.m_fn, p.m_context, new_s, new_e, p.m_type) {}
|
||||
program() {}
|
||||
expr const & get_var(name const & n) const {
|
||||
for (expr const & v : m_context) {
|
||||
if (mlocal_name(v) == n)
|
||||
return v;
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
};
|
||||
|
||||
// For debugging purposes
|
||||
template<typename T>
|
||||
static bool contains_local(T const & locals, expr const & l) {
|
||||
return std::any_of(locals.begin(), locals.end(),
|
||||
[&](expr const & l1) { return mlocal_name(l1) == mlocal_name(l); });
|
||||
}
|
||||
|
||||
#ifdef LEAN_DEBUG
|
||||
// For debugging purposes: checks whether all local constants occurring in \c e
|
||||
// are in local_ctx or m_global_context
|
||||
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
|
||||
for_each(e, [&](expr const & e, unsigned) {
|
||||
if (is_local(e) &&
|
||||
!(contains_local(local_context, e) ||
|
||||
contains_local(context, e) ||
|
||||
contains_local(m_global_context, e) ||
|
||||
contains_local(m_fns, e))) {
|
||||
lean_unreachable();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
// For debugging purposes: check if the program is well-formed
|
||||
bool check_program(program const & s) const {
|
||||
unsigned sz = length(s.m_var_stack);
|
||||
lean_assert(contained(s.m_var_stack, s.m_context));
|
||||
for (eqn const & e : s.m_eqns) {
|
||||
// the number of patterns in each equation is equal to the variable stack size
|
||||
if (length(e.m_patterns) != sz) {
|
||||
lean_unreachable();
|
||||
return false;
|
||||
}
|
||||
check_ctx(e.m_rhs, s.m_context, e.m_local_context);
|
||||
for (expr const & p : e.m_patterns)
|
||||
check_ctx(p, s.m_context, e.m_local_context);
|
||||
lean_assert(disjoint(e.m_local_context, s.m_context));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Initialize m_fns (the vector of functions to be compiled)
|
||||
void initialize_fns(expr const & eqns) {
|
||||
lean_assert(is_equations(eqns));
|
||||
unsigned num_fns = equations_num_fns(eqns);
|
||||
buffer<expr> eqs;
|
||||
to_equations(eqns, eqs);
|
||||
expr eq = eqs[0];
|
||||
for (unsigned i = 0; i < num_fns; i++) {
|
||||
expr fn = mk_local(mk_fresh_name(), binding_name(eq), binding_domain(eq), binder_info());
|
||||
m_fns.push_back(fn);
|
||||
eq = instantiate(binding_body(eq), fn);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the variable stack for each function that needs
|
||||
// to be compiled.
|
||||
// This method assumes m_fns has been already initialized.
|
||||
// This method also initialized the buffer prg, but the eqns
|
||||
// field of each program is not initialized by it.
|
||||
void initialize_var_stack(buffer<program> & prgs) {
|
||||
lean_assert(!m_fns.empty());
|
||||
lean_assert(prgs.empty());
|
||||
for (expr const & fn : m_fns) {
|
||||
buffer<expr> args;
|
||||
expr r_type = to_telescope(mlocal_type(fn), args);
|
||||
list<expr> ctx = to_list(args);
|
||||
list<optional<name>> vstack = map2<optional<name>>(ctx, [](expr const & e) {
|
||||
return optional<name>(mlocal_name(e));
|
||||
});
|
||||
prgs.push_back(program(fn, ctx, vstack, list<eqn>(), r_type));
|
||||
}
|
||||
}
|
||||
|
||||
struct validate_exception {
|
||||
expr m_expr;
|
||||
validate_exception(expr const & e):m_expr(e) {}
|
||||
};
|
||||
|
||||
[[ noreturn ]] void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
|
||||
[[ noreturn ]] void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
|
||||
|
||||
// --------------------------------
|
||||
// Pattern validation/normalization
|
||||
// --------------------------------
|
||||
|
||||
expr validate_lhs_arg(expr arg) {
|
||||
if (is_inaccessible(arg))
|
||||
return arg;
|
||||
if (is_local(arg))
|
||||
return arg;
|
||||
expr new_arg = m_tc.whnf(arg).first;
|
||||
if (is_local(new_arg))
|
||||
return new_arg;
|
||||
buffer<expr> arg_args;
|
||||
expr const & fn = get_app_args(new_arg, arg_args);
|
||||
if (!is_constant(fn) || !inductive::is_intro_rule(env(), const_name(fn)))
|
||||
throw validate_exception(arg);
|
||||
for (expr & arg_arg : arg_args)
|
||||
arg_arg = validate_lhs_arg(arg_arg);
|
||||
return mk_app(fn, arg_args);
|
||||
// Validate/normalize the given pattern.
|
||||
// It stores in reachable_vars any variable that does not occur
|
||||
// in inaccessible terms.
|
||||
expr validate_pattern(expr pat, name_set & reachable_vars) {
|
||||
if (is_inaccessible(pat))
|
||||
return pat;
|
||||
if (is_local(pat)) {
|
||||
reachable_vars.insert(mlocal_name(pat));
|
||||
return pat;
|
||||
}
|
||||
expr new_pat = whnf(pat);
|
||||
if (is_local(new_pat)) {
|
||||
reachable_vars.insert(mlocal_name(new_pat));
|
||||
return new_pat;
|
||||
}
|
||||
buffer<expr> pat_args;
|
||||
expr const & fn = get_app_args(new_pat, pat_args);
|
||||
if (auto in = is_constructor(fn)) {
|
||||
unsigned num_params = *inductive::get_num_params(env(), *in);
|
||||
for (unsigned i = num_params; i < pat_args.size(); i++)
|
||||
pat_args[i] = validate_pattern(pat_args[i], reachable_vars);
|
||||
return mk_app(fn, pat_args);
|
||||
} else {
|
||||
throw validate_exception(pat);
|
||||
}
|
||||
}
|
||||
|
||||
expr validate_lhs(expr const & lhs) {
|
||||
buffer<expr> args;
|
||||
expr fn = get_app_args(lhs, args);
|
||||
for (expr & arg : args) {
|
||||
// Validate/normalize the patterns associated with the given lhs.
|
||||
// The lhs is only used to report errors.
|
||||
// It stores in reachable_vars any variable that does not occur
|
||||
// in inaccessible terms.
|
||||
void validate_patterns(expr const & lhs, buffer<expr> & patterns, name_set & reachable_vars) {
|
||||
for (expr & pat : patterns) {
|
||||
try {
|
||||
arg = validate_lhs_arg(arg);
|
||||
pat = validate_pattern(pat, reachable_vars);
|
||||
} catch (validate_exception & ex) {
|
||||
expr problem_expr = ex.m_expr;
|
||||
throw_error(lhs, [=](formatter const & fmt) {
|
||||
|
@ -270,40 +464,586 @@ class equation_compiler_fn {
|
|||
});
|
||||
}
|
||||
}
|
||||
return mk_app(fn, args);
|
||||
}
|
||||
|
||||
expr validate_patterns_core(expr eq) {
|
||||
buffer<expr> args;
|
||||
name_generator ngen = m_tc.mk_ngen();
|
||||
eq = fun_to_telescope(ngen, eq, args, optional<binder_info>());
|
||||
lean_assert(is_equation(eq));
|
||||
expr new_lhs = validate_lhs(equation_lhs(eq));
|
||||
return Fun(args, mk_equation(new_lhs, equation_rhs(eq)));
|
||||
}
|
||||
|
||||
expr validate_patterns(expr const & eqns) {
|
||||
// Create initial program state for each function being defined.
|
||||
void initialize(expr const & eqns, buffer<program> & prg) {
|
||||
lean_assert(is_equations(eqns));
|
||||
initialize_fns(eqns);
|
||||
initialize_var_stack(prg);
|
||||
buffer<expr> eqs;
|
||||
buffer<expr> new_eqs;
|
||||
to_equations(eqns, eqs);
|
||||
for (expr const & eq : eqs) {
|
||||
new_eqs.push_back(validate_patterns_core(eq));
|
||||
buffer<buffer<eqn>> res_eqns;
|
||||
res_eqns.resize(m_fns.size());
|
||||
for (expr eq : eqs) {
|
||||
for (expr const & fn : m_fns)
|
||||
eq = instantiate(binding_body(eq), fn);
|
||||
buffer<expr> local_ctx;
|
||||
eq = fun_to_telescope(eq, local_ctx);
|
||||
expr const & lhs = equation_lhs(eq);
|
||||
expr const & rhs = equation_rhs(eq);
|
||||
buffer<expr> patterns;
|
||||
expr const & fn = get_app_args(lhs, patterns);
|
||||
name_set reachable_vars;
|
||||
validate_patterns(lhs, patterns, reachable_vars);
|
||||
for (expr const & v : local_ctx) {
|
||||
// every variable in the local_ctx must be "reachable".
|
||||
if (!reachable_vars.contains(mlocal_name(v))) {
|
||||
throw_error(lhs, [=](formatter const & fmt) {
|
||||
format r("invalid equation left-hand-side, variable '");
|
||||
r += format(local_pp_name(v));
|
||||
r += format("' only occurs in inaccessible terms in the following equation left-hand-side");
|
||||
r += pp_indent_expr(fmt, lhs);
|
||||
return r;
|
||||
});
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < m_fns.size(); i++) {
|
||||
if (mlocal_name(fn) == mlocal_name(m_fns[i])) {
|
||||
if (patterns.size() != length(prg[i].m_var_stack))
|
||||
throw_error("ill-formed equation, number of provided arguments does not match function type", eq);
|
||||
res_eqns[i].push_back(eqn(to_list(local_ctx), to_list(patterns), rhs));
|
||||
}
|
||||
}
|
||||
}
|
||||
return update_equations(eqns, new_eqs);
|
||||
for (unsigned i = 0; i < m_fns.size(); i++) {
|
||||
prg[i].m_eqns = to_list(res_eqns[i]);
|
||||
lean_assert(check_program(prg[i]));
|
||||
}
|
||||
}
|
||||
|
||||
// For debugging purposes: display the context at m_ios
|
||||
template<typename Ctx>
|
||||
void display_ctx(Ctx const & ctx) const {
|
||||
bool first = true;
|
||||
for (expr const & e : ctx) {
|
||||
out() << (first ? "" : ", ") << local_pp_name(e) << " : " << mlocal_type(e);
|
||||
first = false;
|
||||
}
|
||||
}
|
||||
|
||||
// For debugging purposes: dump prg in m_ios
|
||||
void display(program const & prg) const {
|
||||
display_ctx(prg.m_context);
|
||||
out() << " ;;";
|
||||
for (optional<name> const & v : prg.m_var_stack) {
|
||||
if (v)
|
||||
out() << " " << local_pp_name(prg.get_var(*v));
|
||||
else
|
||||
out() << " <none>";
|
||||
}
|
||||
out() << " |- " << prg.m_type << "\n";
|
||||
out() << "\n";
|
||||
for (eqn const & e : prg.m_eqns) {
|
||||
out() << "> ";
|
||||
display_ctx(e.m_local_context);
|
||||
out() << " |-";
|
||||
for (expr const & p : e.m_patterns) {
|
||||
if (is_atomic(p))
|
||||
out() << " " << p;
|
||||
else
|
||||
out() << " (" << p << ")";
|
||||
}
|
||||
out() << " := " << e.m_rhs << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Return true iff the next pattern in all equations is a variable or an inaccessible term
|
||||
bool is_variable_transition(program const & p) const {
|
||||
for (eqn const & e : p.m_eqns) {
|
||||
lean_assert(e.m_patterns);
|
||||
if (!is_local(head(e.m_patterns)) && !is_inaccessible(head(e.m_patterns)))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return true iff the next pattern in all equations is a constructor
|
||||
bool is_constructor_transition(program const & p) const {
|
||||
for (eqn const & e : p.m_eqns) {
|
||||
lean_assert(e.m_patterns);
|
||||
if (!is_constructor(get_app_fn(head(e.m_patterns))))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return true iff the next pattern of every equation is a constructor or variable,
|
||||
// and there are at least one equation where it is a variable and another where it is a
|
||||
// constructor.
|
||||
bool is_complete_transition(program const & p) const {
|
||||
bool has_variable = false;
|
||||
bool has_constructor = false;
|
||||
for (eqn const & e : p.m_eqns) {
|
||||
lean_assert(e.m_patterns);
|
||||
expr const & p = head(e.m_patterns);
|
||||
if (is_local(p))
|
||||
has_variable = true;
|
||||
else if (is_constructor(get_app_fn(p)))
|
||||
has_constructor = true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
return has_variable && has_constructor;
|
||||
}
|
||||
|
||||
// Remove variable from local context
|
||||
static list<expr> remove(list<expr> const & local_ctx, expr const & l) {
|
||||
if (!local_ctx)
|
||||
return local_ctx;
|
||||
else if (mlocal_name(head(local_ctx)) == mlocal_name(l))
|
||||
return tail(local_ctx);
|
||||
else
|
||||
return cons(head(local_ctx), remove(tail(local_ctx), l));
|
||||
}
|
||||
|
||||
// Replace local constant \c from with \c to in the expression \c e.
|
||||
static expr replace(expr const & e, expr const & from, expr const & to) {
|
||||
return instantiate(abstract_local(e, from), to);
|
||||
}
|
||||
|
||||
static expr replace(expr const & e, name const & from, expr const & to) {
|
||||
return instantiate(abstract_local(e, from), to);
|
||||
}
|
||||
|
||||
static expr replace(expr const & e, name_map<expr> const & subst) {
|
||||
expr r = e;
|
||||
subst.for_each([&](name const & l, expr const & v) {
|
||||
r = replace(r, l, v);
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
static bool contains(list<expr> const & local_ctx, expr const & p) {
|
||||
return std::any_of(local_ctx.begin(), local_ctx.end(), [&](expr const & l) { return mlocal_name(l) == mlocal_name(p); });
|
||||
}
|
||||
|
||||
expr compile_skip(program const & prg) {
|
||||
lean_assert(!head(prg.m_var_stack));
|
||||
auto new_stack = tail(prg.m_var_stack);
|
||||
buffer<eqn> new_eqs;
|
||||
for (eqn const & e : prg.m_eqns) {
|
||||
auto new_patterns = tail(e.m_patterns);
|
||||
new_eqs.emplace_back(e.m_local_context, new_patterns, e.m_rhs);
|
||||
}
|
||||
return compile_core(program(prg, new_stack, to_list(new_eqs)));
|
||||
}
|
||||
|
||||
expr compile_variable(program const & prg) {
|
||||
// The next pattern of every equation is a variable (or inaccessible term).
|
||||
// Thus, we just rename them with the variable on
|
||||
// the top of the variable stack.
|
||||
// Remark: if the pattern is an inaccessible term, we just ignore it.
|
||||
expr x = prg.get_var(*head(prg.m_var_stack));
|
||||
auto new_stack = tail(prg.m_var_stack);
|
||||
buffer<eqn> new_eqs;
|
||||
for (eqn const & e : prg.m_eqns) {
|
||||
expr p = head(e.m_patterns);
|
||||
if (is_inaccessible(p)) {
|
||||
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
|
||||
} else {
|
||||
lean_assert(is_local(p));
|
||||
if (contains(e.m_local_context, p)) {
|
||||
list<expr> new_local_ctx = e.m_local_context;
|
||||
new_local_ctx = remove(new_local_ctx, p);
|
||||
new_local_ctx = map(new_local_ctx, [&](expr const & l) { return replace(l, p, x); });
|
||||
auto new_patterns = map(tail(e.m_patterns), [&](expr const & p2) { return replace(p2, p, x); });
|
||||
auto new_rhs = replace(e.m_rhs, p, x);
|
||||
new_eqs.emplace_back(new_local_ctx, new_patterns, new_rhs);
|
||||
} else {
|
||||
new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
return compile_core(program(prg, new_stack, to_list(new_eqs)));
|
||||
}
|
||||
|
||||
class implementation : public inversion::implementation {
|
||||
eqn m_eqn;
|
||||
public:
|
||||
implementation(eqn const & e):m_eqn(e) {}
|
||||
|
||||
eqn const & get_eqn() const { return m_eqn; }
|
||||
|
||||
virtual name const & get_constructor_name() const {
|
||||
return const_name(get_app_fn(head(m_eqn.m_patterns)));
|
||||
}
|
||||
|
||||
virtual void update_exprs(std::function<expr(expr const &)> const & fn) {
|
||||
m_eqn.m_local_context = map(m_eqn.m_local_context, fn);
|
||||
m_eqn.m_patterns = map(m_eqn.m_patterns, fn);
|
||||
m_eqn.m_rhs = fn(m_eqn.m_rhs);
|
||||
}
|
||||
};
|
||||
|
||||
// Wrap the equations from \c p as an "implementation_list" for the inversion package.
|
||||
inversion::implementation_list to_implementation_list(program const & p) {
|
||||
return map2<inversion::implementation_ptr>(p.m_eqns, [&](eqn const & e) {
|
||||
return std::shared_ptr<inversion::implementation>(new implementation(e));
|
||||
});
|
||||
}
|
||||
|
||||
// Convert program into a goal. We need that to be able to invoke the inversion package.
|
||||
goal to_goal(program const & p) {
|
||||
buffer<expr> hyps;
|
||||
to_buffer(p.m_context, hyps);
|
||||
expr new_type = p.m_type;
|
||||
expr new_meta = mk_app(mk_metavar(mk_fresh_name(), Pi(hyps, new_type)), hyps);
|
||||
return goal(new_meta, new_type);
|
||||
}
|
||||
|
||||
// Convert goal and implementation_list back into a program.
|
||||
// - nvars is the number of new variables in the variable stack.
|
||||
program to_program(expr const & fn, goal const & g, unsigned nvars, list<optional<name>> const & new_var_stack, inversion::implementation_list const & imps) {
|
||||
buffer<expr> new_context;
|
||||
g.get_hyps(new_context);
|
||||
expr new_type = g.get_type();
|
||||
buffer<eqn> new_equations;
|
||||
for (inversion::implementation_ptr const & imp : imps) {
|
||||
eqn e = static_cast<implementation*>(imp.get())->get_eqn();
|
||||
buffer<expr> pat_args;
|
||||
get_app_args(head(e.m_patterns), pat_args);
|
||||
lean_assert(pat_args.size() >= nvars);
|
||||
list<expr> new_pats = to_list(pat_args.end() - nvars, pat_args.end(), tail(e.m_patterns));
|
||||
new_equations.push_back(eqn(e.m_local_context, new_pats, e.m_rhs));
|
||||
}
|
||||
return program(fn, to_list(new_context), new_var_stack, to_list(new_equations), new_type);
|
||||
}
|
||||
|
||||
expr compile_constructor(program const & p) {
|
||||
expr h = p.get_var(*head(p.m_var_stack));
|
||||
goal g = to_goal(p);
|
||||
auto imps = to_implementation_list(p);
|
||||
if (auto r = apply(env(), ios(), m_tc, g, h, imps)) {
|
||||
substitution subst = r->m_subst;
|
||||
list<list<expr>> args = r->m_args;
|
||||
list<rename_map> rn_maps = r->m_renames;
|
||||
list<inversion::implementation_list> imps_list = r->m_implementation_lists;
|
||||
for (goal const & new_g : r->m_goals) {
|
||||
list<optional<name>> new_vars = map2<optional<name>>(head(args),
|
||||
[](expr const & a) {
|
||||
if (is_local(a))
|
||||
return optional<name>(mlocal_name(a));
|
||||
else
|
||||
return optional<name>();
|
||||
});
|
||||
rename_map const & rn = head(rn_maps);
|
||||
list<optional<name>> new_var_stack = map(tail(p.m_var_stack),
|
||||
[&](optional<name> const & n) {
|
||||
if (n)
|
||||
return optional<name>(rn.find(*n));
|
||||
else
|
||||
return n;
|
||||
});
|
||||
list<optional<name>> new_case_stack = append(new_vars, new_var_stack);
|
||||
program new_p = to_program(p.m_fn, new_g, length(new_vars), new_case_stack, head(imps_list));
|
||||
args = tail(args);
|
||||
imps_list = tail(imps_list);
|
||||
rn_maps = tail(rn_maps);
|
||||
expr t = compile_core(new_p);
|
||||
subst.assign(new_g.get_name(), new_g.abstract(t));
|
||||
}
|
||||
expr t = subst.instantiate_all(g.get_meta());
|
||||
// out() << "RESULT: " << t << "\n";
|
||||
return t;
|
||||
} else {
|
||||
throw_error(sstream() << "patter matching failed");
|
||||
}
|
||||
}
|
||||
|
||||
expr compile_complete(program const & /* p */) {
|
||||
// The next pattern of every equation is a constructor or variable.
|
||||
// We split the equations where the next pattern is a variable into cases.
|
||||
// That is, we are reducing this case to the compile_constructor case.
|
||||
|
||||
// TODO(Leo)
|
||||
return expr();
|
||||
}
|
||||
|
||||
expr compile_core(program const & p) {
|
||||
lean_assert(check_program(p));
|
||||
// out() << "compile_core step\n";
|
||||
// display(p);
|
||||
// out() << "------------------\n";
|
||||
if (p.m_var_stack) {
|
||||
if (!head(p.m_var_stack)) {
|
||||
return compile_skip(p);
|
||||
} else if (is_variable_transition(p)) {
|
||||
return compile_variable(p);
|
||||
} else if (is_constructor_transition(p)) {
|
||||
return compile_constructor(p);
|
||||
} else if (is_complete_transition(p)) {
|
||||
return compile_complete(p);
|
||||
} else {
|
||||
// In some equations the next pattern is an inaccessible term,
|
||||
// and in others it is a constructor.
|
||||
throw_error(sstream() << "invalid recursive equations for '" << local_pp_name(p.m_fn)
|
||||
<< "', inconsistent use of inaccessible term annotation, "
|
||||
<< "in some equations a pattern is a constructor, and in another it is an inaccessible term");
|
||||
}
|
||||
} else {
|
||||
if (p.m_eqns) {
|
||||
// variable stack is empty
|
||||
expr r = head(p.m_eqns).m_rhs;
|
||||
lean_assert(is_def_eq(infer_type(r), p.m_type));
|
||||
return r;
|
||||
} else {
|
||||
throw_error(sstream() << "invalid non-exhaustive set of recursive equations");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expr compile(program const & p) {
|
||||
buffer<expr> vars;
|
||||
to_buffer(p.m_context, vars);
|
||||
expr r = compile_core(p);
|
||||
return Fun(vars, r);
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c e is one of the functions being defined */
|
||||
bool is_fn(expr const & e) const {
|
||||
return
|
||||
is_local(e) &&
|
||||
std::any_of(m_fns.begin(), m_fns.end(), [&](expr const & fn) {
|
||||
return mlocal_name(fn) == mlocal_name(e);
|
||||
});
|
||||
}
|
||||
|
||||
/** \brief Return true iff the equations are recursive. */
|
||||
bool is_recursive(buffer<program> const & prgs) const {
|
||||
lean_assert(!prgs.empty());
|
||||
for (program const & p : prgs) {
|
||||
for (eqn const & e : p.m_eqns) {
|
||||
if (find(e.m_rhs, [&](expr const & e, unsigned) { return is_fn(e); }))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (prgs.size() > 1)
|
||||
throw_error(sstream() << "mutual recursion is not needed when defining non-recursive functions");
|
||||
return false;
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition*/
|
||||
bool is_inductive(expr const & t) const {
|
||||
expr const & fn = get_app_fn(t);
|
||||
return
|
||||
is_constant(fn) &&
|
||||
env().find(name{const_name(fn), "brec_on"});
|
||||
}
|
||||
|
||||
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
|
||||
bool is_compatible_inductive(expr const & t1, expr const & t2) {
|
||||
lean_assert(is_inductive(t1));
|
||||
lean_assert(is_inductive(t2));
|
||||
buffer<expr> args1, args2;
|
||||
name const & I1 = const_name(get_app_args(t1, args1));
|
||||
name const & I2 = const_name(get_app_args(t2, args2));
|
||||
inductive::inductive_decls decls = *inductive::is_inductive_decl(env(), I1);
|
||||
unsigned nparams = std::get<1>(decls);
|
||||
for (auto decl : std::get<2>(decls)) {
|
||||
if (inductive::inductive_decl_name(decl) == I2) {
|
||||
// parameters must be definitionally equal
|
||||
unsigned i = 0;
|
||||
for (; i < nparams; i++) {
|
||||
if (!is_def_eq(args1[i], args2[i]))
|
||||
break;
|
||||
}
|
||||
if (i == nparams)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c t1 and \c t2 are instances of the same inductive datatype */
|
||||
static bool is_same_inductive(expr const & t1, expr const & t2) {
|
||||
return const_name(get_app_fn(t1)) != const_name(get_app_fn(t2));
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */
|
||||
bool is_le(expr const & s, expr const & t) {
|
||||
return is_def_eq(s, t) || is_lt(s, t);
|
||||
}
|
||||
|
||||
/** Return true iff \c s is structurally smaller than \c t */
|
||||
bool is_lt(expr s, expr const & t) {
|
||||
s = whnf(s);
|
||||
if (is_app(s)) {
|
||||
expr const & s_fn = get_app_fn(s);
|
||||
if (!is_constructor(s_fn))
|
||||
return is_lt(s_fn, t); // f < t ==> s := f a_1 ... a_n < t
|
||||
}
|
||||
buffer<expr> t_args;
|
||||
expr const & t_fn = get_app_args(t, t_args);
|
||||
if (!is_constructor(t_fn))
|
||||
return false;
|
||||
return std::any_of(t_args.begin(), t_args.end(), [&](expr const & t_arg) { return is_le(s, t_arg); });
|
||||
}
|
||||
|
||||
/** \brief Auxiliary functional object for checking whether recursive application are structurally smaller or not */
|
||||
struct check_rhs_fn {
|
||||
equation_compiler_fn & m_main;
|
||||
buffer<program> const & m_prgs;
|
||||
buffer<unsigned> const & m_arg_pos;
|
||||
|
||||
check_rhs_fn(equation_compiler_fn & m, buffer<program> const & prgs, buffer<unsigned> const & arg_pos):
|
||||
m_main(m), m_prgs(prgs), m_arg_pos(arg_pos) {}
|
||||
|
||||
/** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */
|
||||
bool check_rhs(expr const & e, expr const & arg) const {
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Meta:
|
||||
case expr_kind::Local: case expr_kind::Constant:
|
||||
case expr_kind::Sort:
|
||||
return true;
|
||||
case expr_kind::Macro:
|
||||
for (unsigned i = 0; i < macro_num_args(e); i++)
|
||||
if (!check_rhs(macro_arg(e, i), arg))
|
||||
return false;
|
||||
return true;
|
||||
case expr_kind::App: {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (!check_rhs(fn, arg))
|
||||
return false;
|
||||
for (unsigned i = 0; i < args.size(); i++)
|
||||
if (!check_rhs(args[i], arg))
|
||||
return false;
|
||||
if (is_local(fn)) {
|
||||
for (unsigned j = 0; j < m_prgs.size(); j++) {
|
||||
if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) {
|
||||
// it is a recusive application
|
||||
unsigned pos_j = m_arg_pos[j];
|
||||
if (pos_j < args.size()) {
|
||||
expr const & arg_j = args[pos_j];
|
||||
// arg_j must be structurally smaller than arg
|
||||
if (!m_main.is_lt(arg_j, arg))
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case expr_kind::Lambda:
|
||||
case expr_kind::Pi:
|
||||
if (!check_rhs(binding_domain(e), arg)) {
|
||||
return false;
|
||||
} else {
|
||||
expr l = mk_local(m_main.mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
|
||||
return check_rhs(instantiate(binding_body(e), l), arg);
|
||||
}
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
bool operator()(expr const & e, expr const & arg) const {
|
||||
return check_rhs(e, arg);
|
||||
}
|
||||
};
|
||||
|
||||
/** \brief Return true iff the recursive equations in prgs are "admissible" with respect to
|
||||
the following configuration of recursive arguments.
|
||||
We say the equations are admissible when every recursive application of prgs[i]
|
||||
is structurally smaller at arguments arg_pos[i].
|
||||
*/
|
||||
bool check_rec_args(buffer<program> const & prgs, buffer<unsigned> const & arg_pos) {
|
||||
lean_assert(prgs.size() == arg_pos.size());
|
||||
check_rhs_fn check_rhs(*this, prgs, arg_pos);
|
||||
for (unsigned i = 0; i < prgs.size(); i++) {
|
||||
program const & prg = prgs[i];
|
||||
unsigned pos_i = arg_pos[i];
|
||||
for (eqn const & e : prg.m_eqns) {
|
||||
expr const & p_i = get_ith(e.m_patterns, pos_i);
|
||||
if (!check_rhs(e.m_rhs, p_i))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool find_rec_args(buffer<program> const & prgs, unsigned i, buffer<unsigned> & arg_pos, buffer<expr> & arg_types) {
|
||||
if (i == prgs.size()) {
|
||||
return check_rec_args(prgs, arg_pos);
|
||||
} else {
|
||||
program const & p = prgs[i];
|
||||
unsigned j = 0;
|
||||
for (optional<name> const & n : p.m_var_stack) {
|
||||
lean_assert(n);
|
||||
expr const & v = p.get_var(*n);
|
||||
expr const & t = mlocal_type(v);
|
||||
if (// argument must be an inductive datatype
|
||||
is_inductive(t) &&
|
||||
// argument must be an inductive datatype different from the ones in arg_types
|
||||
std::all_of(arg_types.begin(), arg_types.end(),
|
||||
[&](expr const & prev_type) { return !is_same_inductive(t, prev_type); }) &&
|
||||
// argument type must be in the same mutually recursive declaration of previous argument types
|
||||
(arg_types.empty() || is_compatible_inductive(t, arg_types[0]))) {
|
||||
// Found candidate
|
||||
arg_pos.push_back(j);
|
||||
arg_types.push_back(t);
|
||||
if (find_rec_args(prgs, i+1, arg_pos, arg_types))
|
||||
return true;
|
||||
arg_pos.pop_back();
|
||||
arg_types.pop_back();
|
||||
}
|
||||
j++;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool find_rec_args(buffer<program> const & prgs, buffer<unsigned> & arg_pos) {
|
||||
buffer<expr> arg_types;
|
||||
return find_rec_args(prgs, 0, arg_pos, arg_types);
|
||||
}
|
||||
|
||||
void apply_brec_on(buffer<program> & prgs) {
|
||||
lean_assert(!prgs.empty());
|
||||
buffer<unsigned> arg_pos;
|
||||
if (!find_rec_args(prgs, arg_pos)) {
|
||||
throw_error(sstream() << "invalid recursive equations, failed to find recursive arguments that are structurally smaller "
|
||||
<< "(possible solution: use well-founded recursion)");
|
||||
}
|
||||
// out() << "Found recursive arguments: ";
|
||||
// for (unsigned p : arg_pos) out() << " " << p; out() << "\n";
|
||||
|
||||
// TODO(Leo):
|
||||
}
|
||||
|
||||
void apply_wf(buffer<program> & /* prgs */) {
|
||||
// TODO(Leo)
|
||||
}
|
||||
|
||||
public:
|
||||
equation_compiler_fn(type_checker & tc, io_state const & ios, expr const & meta, expr const & meta_type, bool relax):
|
||||
m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type), m_relax(relax) {
|
||||
get_app_args(m_meta, m_global_context);
|
||||
}
|
||||
|
||||
expr operator()(expr eqns) {
|
||||
proof_state ps = to_proof_state(m_meta, m_meta_type, m_tc.mk_ngen(), m_relax);
|
||||
eqns = validate_patterns(eqns);
|
||||
regular(env(), m_ios) << "Equations:\n" << eqns << "\n";
|
||||
regular(env(), m_ios) << ps.pp(env(), m_ios) << "\n\n";
|
||||
return eqns;
|
||||
check_limitations(eqns);
|
||||
// out() << "Equations:\n" << eqns << "\n";
|
||||
buffer<program> prgs;
|
||||
initialize(eqns, prgs);
|
||||
if (is_recursive(prgs)) {
|
||||
if (is_wf_equations(eqns)) {
|
||||
apply_wf(prgs);
|
||||
} else {
|
||||
apply_brec_on(prgs);
|
||||
}
|
||||
}
|
||||
buffer<expr> rs;
|
||||
for (program const & p : prgs) {
|
||||
expr r = compile(p);
|
||||
// out() << r << "\nType: " << m_tc.infer(r).first << "\n";
|
||||
rs.push_back(r);
|
||||
}
|
||||
if (closed(rs[0]))
|
||||
return rs[0];
|
||||
else
|
||||
return m_meta;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -210,7 +210,7 @@ void initialize_library_util() {
|
|||
g_heq_name = new name("heq");
|
||||
|
||||
g_sigma_name = new name("sigma");
|
||||
g_sigma_mk_name = new name{"sigma", "dpair"};
|
||||
g_sigma_mk_name = new name{"sigma", "mk"};
|
||||
}
|
||||
|
||||
void finalize_library_util() {
|
||||
|
|
|
@ -23,7 +23,6 @@ definition map {A B C : Type} (f : A → B → C) : Π {n}, vector A n → vecto
|
|||
map nil nil := nil,
|
||||
map (a :: va) (b :: vb) := f a b :: map va vb
|
||||
|
||||
|
||||
definition half : nat → nat,
|
||||
half 0 := 0,
|
||||
half 1 := 0,
|
||||
|
@ -35,3 +34,24 @@ mk : Π a, image_of f (f a)
|
|||
|
||||
definition inv {f : A → B} : Π b, image_of f b → A,
|
||||
inv ⌞f a⌟ (image_of.mk f a) := a
|
||||
|
||||
namespace tst
|
||||
|
||||
definition fib : nat → nat,
|
||||
fib 0 := 1,
|
||||
fib 1 := 1,
|
||||
fib (a+2) := fib a + fib (a+1)
|
||||
|
||||
end tst
|
||||
|
||||
definition simple : nat → nat → nat,
|
||||
simple x y := x + y
|
||||
|
||||
definition simple2 : nat → nat → nat,
|
||||
simple2 (x+1) y := x + y,
|
||||
simple2 ⌞y+1⌟ y := y
|
||||
|
||||
|
||||
|
||||
check @vector.brec_on
|
||||
check @vector.cases_on
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
set_option pp.implicit true
|
||||
set_option pp.notation false
|
||||
|
||||
definition symm : Π {A : Type} {a b : A}, a = b → b = a,
|
||||
definition symm {A : Type} : Π {a b : A}, a = b → b = a,
|
||||
symm rfl := rfl
|
||||
|
||||
definition trans : Π {A : Type} {a b c : A}, a = b → b = c → a = c,
|
||||
definition trans {A : Type} : Π {a b c : A}, a = b → b = c → a = c,
|
||||
trans rfl rfl := rfl
|
||||
|
|
Loading…
Reference in a new issue