feat(frontends/lean): elaborate recursive equations

Remark: we are not compiling them yet.
This commit is contained in:
Leonardo de Moura 2014-12-10 22:25:40 -08:00
parent 2867789bec
commit b8f665e561
8 changed files with 427 additions and 58 deletions

View file

@ -8,6 +8,7 @@ Author: Leonardo de Moura
#include "util/sstream.h"
#include "kernel/type_checker.h"
#include "kernel/abstract.h"
#include "kernel/replace_fn.h"
#include "kernel/for_each_fn.h"
#include "library/scoped_ext.h"
#include "library/aliases.h"
@ -315,13 +316,87 @@ static bool is_curr_with_or_comma(parser & p) {
return p.curr_is_token(get_with_tk()) || p.curr_is_token(get_comma_tk());
}
expr parse_equations(parser & p, name const & n, expr const & type, buffer<expr> & auxs) {
/**
For convenience, the left-hand-side of a recursive equation may contain
undeclared variables.
We use parser::undef_id_to_local_scope to force the parser to create a local constant for
each undefined identifier.
This method validates occurrences of these variables. They can only occur as an application
or macro argument.
*/
static void validate_equation_lhs(parser const & p, expr const & lhs, buffer<expr> const & locals) {
if (is_app(lhs)) {
validate_equation_lhs(p, app_fn(lhs), locals);
validate_equation_lhs(p, app_arg(lhs), locals);
} else if (is_macro(lhs)) {
for (unsigned i = 0; i < macro_num_args(lhs); i++)
validate_equation_lhs(p, macro_arg(lhs, i), locals);
} else if (!is_local(lhs)) {
for_each(lhs, [&](expr const & e, unsigned) {
if (is_local(e) &&
std::any_of(locals.begin(), locals.end(), [&](expr const & local) {
return mlocal_name(e) == mlocal_name(local);
})) {
throw parser_error(sstream() << "invalid occurrence of variable '" << mlocal_name(lhs) <<
"' in the left-hand-side of recursive equation", p.pos_of(lhs));
}
return has_local(e);
});
}
}
/**
\brief Merge multiple occurrences of a variable in the left-hand-side of a recursive equation.
\see validate_equation_lhs
*/
static expr merge_equation_lhs_vars(expr const & lhs, buffer<expr> & locals) {
expr_map<expr> m;
unsigned j = 0;
for (unsigned i = 0; i < locals.size(); i++) {
unsigned k;
for (k = 0; k < i; k++) {
if (mlocal_name(locals[k]) == mlocal_name(locals[i])) {
m.insert(mk_pair(locals[i], locals[k]));
break;
}
}
if (k == i) {
locals[j] = locals[i];
j++;
}
}
if (j == locals.size())
return lhs;
locals.shrink(j);
return replace(lhs, [&](expr const & e) {
if (!has_local(e))
return some_expr(e);
if (is_local(e)) {
auto it = m.find(e);
if (it != m.end())
return some_expr(it->second);
}
return none_expr();
});
}
static void throw_invalid_equation_lhs(name const & n, pos_info const & p) {
throw parser_error(sstream() << "invalid recursive equation, head symbol '"
<< n << "' in the left-hand-side does not correspond to function(s) being defined", p);
}
expr parse_equations(parser & p, name const & n, expr const & type, buffer<expr> & auxs,
optional<local_environment> const & lenv, buffer<expr> const & ps) {
buffer<expr> eqns;
buffer<expr> fns;
{
parser::local_scope scope1(p);
parser::undef_id_to_local_scope scope2(p);
parser::local_scope scope1(p, lenv);
for (expr const & param : ps)
p.add_local(param);
lean_assert(is_curr_with_or_comma(p));
expr f = mk_local(n, type);
fns.push_back(mk_local(n, type));
if (p.curr_is_token(get_with_tk())) {
while (p.curr_is_token(get_with_tk())) {
p.next();
@ -330,28 +405,57 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer<expr>
expr g_type = p.parse_expr();
expr g = mk_local(g_name, g_type);
auxs.push_back(g);
fns.push_back(g);
}
}
p.check_token_next(get_comma_tk(), "invalid declaration, ',' expected");
p.add_local(f);
for (expr const & g : auxs)
p.add_local(g);
for (expr const & fn : fns)
p.add_local(fn);
while (true) {
expr lhs = p.parse_expr();
expr lhs;
unsigned prev_num_undef_ids = p.get_num_undef_ids();
buffer<expr> locals;
{
parser::undef_id_to_local_scope scope2(p);
auto lhs_pos = p.pos();
lhs = p.parse_expr();
expr lhs_fn = get_app_fn(lhs);
if (is_explicit(lhs_fn))
lhs_fn = get_explicit_arg(lhs_fn);
if (is_constant(lhs_fn))
throw_invalid_equation_lhs(const_name(lhs_fn), lhs_pos);
if (is_local(lhs_fn) && std::all_of(fns.begin(), fns.end(), [&](expr const & fn) { return fn != lhs_fn; }))
throw_invalid_equation_lhs(local_pp_name(lhs_fn), lhs_pos);
if (!is_local(lhs_fn))
throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant", lhs_pos);
unsigned num_undef_ids = p.get_num_undef_ids();
for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) {
locals.push_back(p.get_undef_id(i));
}
}
validate_equation_lhs(p, lhs, locals);
lhs = merge_equation_lhs_vars(lhs, locals);
p.check_token_next(get_assign_tk(), "invalid declaration, ':=' expected");
expr rhs = p.parse_expr();
eqns.push_back(mk_equation(lhs, rhs));
{
parser::local_scope scope2(p);
for (expr const & local : locals)
p.add_local(local);
expr rhs = p.parse_expr();
eqns.push_back(Fun(fns, Fun(locals, mk_equation(lhs, rhs), p)));
}
if (!p.curr_is_token(get_comma_tk()))
break;
p.next();
}
}
if (p.curr_is_token(get_wf_tk())) {
auto pos = p.pos();
p.next();
expr R = p.save_pos(mk_expr_placeholder(), pos);
expr Hwf = p.parse_expr();
return mk_equations(eqns.size(), eqns.data(), Hwf);
return mk_equations(fns.size(), eqns.size(), eqns.data(), R, Hwf);
} else {
return mk_equations(eqns.size(), eqns.data());
return mk_equations(fns.size(), eqns.size(), eqns.data());
}
}
@ -409,7 +513,7 @@ class definition_cmd_fn {
auto pos = m_p.pos();
m_type = m_p.parse_expr();
if (is_curr_with_or_comma(m_p)) {
m_value = parse_equations(m_p, m_name, m_type, m_aux_decls);
m_value = parse_equations(m_p, m_name, m_type, m_aux_decls, optional<local_environment>(), buffer<expr>());
} else if (!is_definition() && !m_p.curr_is_token(get_assign_tk())) {
check_end_of_theorem(m_p);
m_value = m_p.save_pos(mk_expr_placeholder(), pos);
@ -427,7 +531,7 @@ class definition_cmd_fn {
m_p.next();
m_type = m_p.parse_scoped_expr(ps, *lenv);
if (is_curr_with_or_comma(m_p)) {
m_value = parse_equations(m_p, m_name, m_type, m_aux_decls);
m_value = parse_equations(m_p, m_name, m_type, m_aux_decls, lenv, ps);
} else if (!is_definition() && !m_p.curr_is_token(get_assign_tk())) {
check_end_of_theorem(m_p);
m_value = m_p.save_pos(mk_expr_placeholder(), pos);

View file

@ -33,6 +33,7 @@ Author: Leonardo de Moura
#include "library/local_context.h"
#include "library/tactic/expr_to_tactic.h"
#include "library/error_handling/error_handling.h"
#include "library/definitional/equations.h"
#include "frontends/lean/local_decls.h"
#include "frontends/lean/class.h"
#include "frontends/lean/tactic_hint.h"
@ -99,10 +100,11 @@ elaborator::elaborator(elaborator_context & ctx, name_generator const & ngen, bo
m_has_sorry = has_sorry(m_ctx.m_env);
m_relax_main_opaque = false;
m_use_tactic_hints = true;
m_no_info = false;
m_tc[0] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), false);
m_tc[1] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), true);
m_nice_mvar_names = nice_mvar_names;
m_no_info = false;
m_in_equation_lhs = false;
m_tc[0] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), false);
m_tc[1] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), true);
m_nice_mvar_names = nice_mvar_names;
}
expr elaborator::mk_local(name const & n, expr const & t, binder_info const & bi) {
@ -812,6 +814,155 @@ expr elaborator::visit_sorry(expr const & e) {
return mk_app(update_constant(e, to_list(u)), m, e.get_tag());
}
expr const & elaborator::get_equation_fn(expr const & eq) const {
expr it = eq;
while (is_lambda(it))
it = binding_body(it);
if (!is_equation(it))
throw_elaborator_exception(env(), "ill-formed equation", eq);
expr const & fn = get_app_fn(equation_lhs(it));
if (!is_local(fn))
throw_elaborator_exception(env(), "ill-formed equation", eq);
return fn;
}
static expr copy_domain(unsigned num, expr const & source, expr const & target) {
if (num == 0) {
return target;
} else {
lean_assert(is_binding(source) && is_binding(target));
return update_binding(source, mk_as_is(binding_domain(source)), copy_domain(num-1, binding_body(source), binding_body(target)));
}
}
static constraint mk_equations_cnstr(environment const & env, io_state const & ios, expr const & m, expr const & eqns) {
justification j = mk_failed_to_synthesize_jst(env, m);
auto choice_fn = [=](expr const & , expr const &, substitution const & s,
name_generator const &) {
expr new_eqns = substitution(s).instantiate(eqns);
regular(env, ios) << "Equations:\n" << new_eqns << "\n\n";
// TODO(Leo);
return lazy_list<constraints>(constraints());
};
bool owner = true;
bool relax = false;
return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::MaxDelayed), owner, j, relax);
}
expr elaborator::visit_equations(expr const & eqns, constraint_seq & cs) {
buffer<expr> eqs;
buffer<expr> new_eqs;
optional<expr> new_R;
optional<expr> new_Hwf;
to_equations(eqns, eqs);
if (eqs.empty())
throw_elaborator_exception(env(), "invalid empty set of recursive equations", eqns);
if (is_wf_equations(eqns)) {
new_R = visit(equations_wf_rel(eqns), cs);
new_Hwf = visit(equations_wf_proof(eqns), cs);
expr Hwf_type = infer_type(*new_Hwf, cs);
expr wf = visit(mk_constant("well_founded"), cs);
wf = ::lean::mk_app(wf, *new_R);
justification j = mk_type_mismatch_jst(*new_Hwf, Hwf_type, wf, equations_wf_proof(eqns));
auto new_Hwf_cs = ensure_has_type(*new_Hwf, Hwf_type, wf, j, m_relax_main_opaque);
new_Hwf = new_Hwf_cs.first;
cs += new_Hwf_cs.second;
}
flet<optional<expr>> set1(m_equation_R, new_R);
unsigned num_fns = equations_num_fns(eqns);
optional<expr> first_eq;
for (expr const & eq : eqs) {
expr new_eq;
if (first_eq) {
// Replace first num_fns domains of eq with the ones in first_eq.
// This is a trick/hack to ensure the fns in each equation have
// the same elaborated type.
new_eq = visit(copy_domain(num_fns, *first_eq, eq), cs);
} else {
new_eq = visit(eq, cs);
first_eq = new_eq;
}
new_eqs.push_back(new_eq);
}
expr new_eqns;
if (new_R) {
new_eqns = mk_equations(num_fns, new_eqs.size(), new_eqs.data(), *new_R, *new_Hwf);
} else {
new_eqns = mk_equations(num_fns, new_eqs.size(), new_eqs.data());
}
lean_assert(first_eq && is_lambda(*first_eq));
expr type = binding_domain(*first_eq);
expr m = m_full_context.mk_meta(m_ngen, some_expr(type), eqns.get_tag());
register_meta(m);
constraint c = mk_equations_cnstr(env(), ios(), m, new_eqns);
cs += c;
return m;
}
expr elaborator::visit_equation(expr const & eq, constraint_seq & cs) {
expr const & lhs = equation_lhs(eq);
expr const & rhs = equation_rhs(eq);
expr lhs_fn = get_app_fn(lhs);
if (is_explicit(lhs_fn))
lhs_fn = get_explicit_arg(lhs_fn);
if (!is_local(lhs_fn))
throw exception("ill-formed equation");
expr new_lhs, new_rhs;
{
flet<bool> set(m_in_equation_lhs, true);
new_lhs = visit(lhs, cs);
}
{
optional<expr> some_new_lhs(new_lhs);
flet<optional<expr>> set1(m_equation_lhs, some_new_lhs);
new_rhs = visit(rhs, cs);
}
expr lhs_type = infer_type(new_lhs, cs);
expr rhs_type = infer_type(new_rhs, cs);
justification j = mk_justification(eq, [=](formatter const & fmt, substitution const & subst) {
substitution s(subst);
return pp_def_type_mismatch(fmt, local_pp_name(lhs_fn), s.instantiate(lhs_type), s.instantiate(rhs_type));
});
pair<expr, constraint_seq> new_rhs_cs = ensure_has_type(new_rhs, rhs_type, lhs_type, j, m_relax_main_opaque);
new_rhs = new_rhs_cs.first;
cs += new_rhs_cs.second;
return mk_equation(new_lhs, new_rhs);
}
expr elaborator::visit_inaccessible(expr const & e, constraint_seq & cs) {
if (!m_in_equation_lhs)
throw_elaborator_exception(env(), "invalid occurrence of 'inaccessible' annotation, it must only occur in the "
"left-hand-side of recursive equations", e);
return mk_inaccessible(visit(get_annotation_arg(e), cs));
}
expr elaborator::visit_decreasing(expr const & e, constraint_seq & cs) {
if (!m_equation_lhs)
throw_elaborator_exception(env(), "invalid occurrence of 'decreasing' annotation, it must only occur in "
"the right-hand-side of recursive equations", e);
if (!m_equation_R)
throw_elaborator_exception(env(), "invalid occurrence of 'decreasing' annotation, it can only be used when "
"recursive equations are being defined by well-founded recursion", e);
expr const & lhs_fn = get_app_fn(*m_equation_lhs);
if (get_app_fn(decreasing_app(e)) != lhs_fn)
throw_elaborator_exception(env(), "invalid occurrence of 'decreasing' annotation, expression must be an "
"application of the recursive function being defined", e);
expr dec_app = visit(decreasing_app(e), cs);
expr dec_proof = visit(decreasing_proof(e), cs);
// Remark: perhaps we should enforce the type of dec_proof here.
// We may have enough information to wrap the arguments in a sigma type (reason: the type of the function being elaborated has holes).
// Possible solution: create a constraint that enforces the type as soon the type of function has been elaborated.
return mk_decreasing(dec_app, dec_proof);
}
expr elaborator::visit_core(expr const & e, constraint_seq & cs) {
if (is_placeholder(e)) {
return visit_placeholder(e, cs);
@ -841,6 +992,14 @@ expr elaborator::visit_core(expr const & e, constraint_seq & cs) {
return visit_core(get_explicit_arg(e), cs);
} else if (is_sorry(e)) {
return visit_sorry(e);
} else if (is_equations(e)) {
lean_unreachable();
} else if (is_equation(e)) {
return visit_equation(e, cs);
} else if (is_inaccessible(e)) {
return visit_inaccessible(e, cs);
} else if (is_decreasing(e)) {
return visit_decreasing(e, cs);
} else {
switch (e.kind()) {
case expr_kind::Local: return e;
@ -882,6 +1041,8 @@ pair<expr, constraint_seq> elaborator::visit(expr const & e) {
} else {
r = visit_core(b, cs);
}
} else if (is_equations(e)) {
r = visit_equations(e, cs);
} else if (is_explicit(get_app_fn(e))) {
r = visit_core(e, cs);
} else {

View file

@ -53,6 +53,14 @@ class elaborator : public coercion_info_manager {
// if m_no_info is true, we do not collect information when true,
// we set is to true whenever we find no_info annotation.
bool m_no_info;
// if m_in_equation_lhs is true, we are processing the left-hand-side of an equation
// and inaccessible expressions are allowed
bool m_in_equation_lhs;
// if m_equation_lhs is not none, we are processing the right-hand-side of an equation
// and decreasing expressions are allowed
optional<expr> m_equation_lhs;
// if m_equation_R is not none when elaborator is processing recursive equation using the well-founded relation R.
optional<expr> m_equation_R;
bool m_use_tactic_hints;
info_manager m_pre_info_data;
bool m_has_sorry;
@ -151,6 +159,13 @@ class elaborator : public coercion_info_manager {
std::tuple<expr, level_param_names> apply(substitution & s, expr const & e);
pair<expr, constraints> elaborate_nested(list<expr> const & g, expr const & e,
bool relax, bool use_tactic_hints, bool report_unassigned);
expr const & get_equation_fn(expr const & eq) const;
expr visit_equations(expr const & eqns, constraint_seq & cs);
expr visit_equation(expr const & e, constraint_seq & cs);
expr visit_inaccessible(expr const & e, constraint_seq & cs);
expr visit_decreasing(expr const & e, constraint_seq & cs);
public:
elaborator(elaborator_context & ctx, name_generator const & ngen, bool nice_mvar_names = false);
std::tuple<expr, level_param_names> operator()(list<expr> const & ctx, expr const & e, bool _ensure_type,

View file

@ -76,6 +76,12 @@ parser::local_scope::local_scope(parser & p, environment const & env):
m_p.m_env = env;
m_p.push_local_scope();
}
parser::local_scope::local_scope(parser & p, optional<environment> const & env):
m_p(p), m_env(p.env()) {
if (env)
m_p.m_env = *env;
m_p.push_local_scope();
}
parser::local_scope::~local_scope() {
m_p.pop_local_scope();
m_p.m_env = m_env;
@ -362,8 +368,11 @@ expr parser::propagate_levels(expr const & e, levels const & ls) {
}
}
pos_info parser::pos_of(expr const & e, pos_info default_pos) {
if (auto it = m_pos_table.find(get_tag(e)))
pos_info parser::pos_of(expr const & e, pos_info default_pos) const {
tag t = e.get_tag();
if (t == nulltag)
return default_pos;
if (auto it = m_pos_table.find(t))
return *it;
else
return default_pos;
@ -432,7 +441,7 @@ void parser::push_local_scope(bool save_options) {
optional<options> opts;
if (save_options)
opts = m_ios.get_options();
m_parser_scope_stack = cons(parser_scope_stack_elem(opts, m_level_variables, m_variables, m_include_vars),
m_parser_scope_stack = cons(parser_scope_stack_elem(opts, m_level_variables, m_variables, m_include_vars, m_undef_ids.size()),
m_parser_scope_stack);
}
@ -451,6 +460,7 @@ void parser::pop_local_scope() {
m_level_variables = s.m_level_variables;
m_variables = s.m_variables;
m_include_vars = s.m_include_vars;
m_undef_ids.shrink(s.m_num_undef_ids);
m_parser_scope_stack = tail(m_parser_scope_stack);
}
@ -1111,7 +1121,9 @@ expr parser::id_to_expr(name const & id, pos_info const & p) {
if (m_undef_id_behavior == undef_id_behavior::AssumeConstant) {
r = save_pos(mk_constant(get_namespace(m_env) + id, ls), p);
} else if (m_undef_id_behavior == undef_id_behavior::AssumeLocal) {
r = save_pos(mk_local(id, mk_expr_placeholder()), p);
expr local = mk_local(id, mk_expr_placeholder());
m_undef_ids.push_back(local);
r = save_pos(local, p);
}
}
if (!r)

View file

@ -52,8 +52,10 @@ struct parser_scope_stack_elem {
name_set m_level_variables;
name_set m_variables;
name_set m_include_vars;
parser_scope_stack_elem(optional<options> const & o, name_set const & lvs, name_set const & vs, name_set const & ivs):
m_options(o), m_level_variables(lvs), m_variables(vs), m_include_vars(ivs) {}
unsigned m_num_undef_ids;
parser_scope_stack_elem(optional<options> const & o, name_set const & lvs, name_set const & vs, name_set const & ivs,
unsigned num_undef_ids):
m_options(o), m_level_variables(lvs), m_variables(vs), m_include_vars(ivs), m_num_undef_ids(num_undef_ids) {}
};
typedef list<parser_scope_stack_elem> parser_scope_stack;
@ -130,6 +132,8 @@ class parser {
// curr command token
name m_cmd_token;
buffer<expr> m_undef_ids;
void display_warning_pos(unsigned line, unsigned pos);
void display_warning_pos(pos_info p);
void display_error_pos(unsigned line, unsigned pos);
@ -255,8 +259,8 @@ public:
pos_info pos() const { return pos_info(m_scanner.get_line(), m_scanner.get_pos()); }
expr save_pos(expr e, pos_info p);
expr rec_save_pos(expr const & e, pos_info p);
pos_info pos_of(expr const & e, pos_info default_pos);
pos_info pos_of(expr const & e) { return pos_of(e, pos()); }
pos_info pos_of(expr const & e, pos_info default_pos) const;
pos_info pos_of(expr const & e) const { return pos_of(e, pos()); }
pos_info cmd_pos() const { return m_last_cmd_pos; }
name const & get_cmd_token() const { return m_cmd_token; }
void set_line(unsigned p) { return m_scanner.set_line(p); }
@ -359,7 +363,10 @@ public:
expr parse_scoped_expr(buffer<expr> const & ps, unsigned rbp = 0) { return parse_scoped_expr(ps.size(), ps.data(), rbp); }
struct local_scope { parser & m_p; environment m_env;
local_scope(parser & p, bool save_options = false); local_scope(parser & p, environment const & env); ~local_scope();
local_scope(parser & p, bool save_options = false);
local_scope(parser & p, environment const & env);
local_scope(parser & p, optional<environment> const & env);
~local_scope();
};
bool has_locals() const { return !m_local_decls.empty() || !m_local_level_decls.empty(); }
void add_local_level(name const & n, level const & l, bool is_variable = false);
@ -395,6 +402,11 @@ public:
struct undef_id_to_const_scope : public flet<undef_id_behavior> { undef_id_to_const_scope(parser & p); };
struct undef_id_to_local_scope : public flet<undef_id_behavior> { undef_id_to_local_scope(parser &); };
/** \brief Return the size of the stack of undefined local constants */
unsigned get_num_undef_ids() const { return m_undef_ids.size(); }
/** \brief Return the i-th undefined local constants */
expr const & get_undef_id(unsigned i) const { return m_undef_ids[i]; }
/** \brief Elaborate \c e, and tolerate metavariables in the result. */
std::tuple<expr, level_param_names> elaborate_relaxed(expr const & e, list<expr> const & ctx = list<expr>());
/** \brief Elaborate \c e, and ensure it is a type. */

View file

@ -22,18 +22,27 @@ static std::string * g_decreasing_opcode = nullptr;
[[ noreturn ]] static void throw_eq_ex() { throw exception("unexpected occurrence of 'equation' expression"); }
class equations_macro_cell : public macro_definition_cell {
unsigned m_num_fns;
public:
equations_macro_cell(unsigned num_fns):m_num_fns(num_fns) {}
virtual name get_name() const { return *g_equations_name; }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_eqs_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_eqs_ex(); }
virtual void write(serializer & s) const { s.write_string(*g_equations_opcode); }
virtual void write(serializer & s) const { s << *g_equations_opcode << m_num_fns; }
unsigned get_num_fns() const { return m_num_fns; }
};
class equation_macro_cell : public macro_definition_cell {
public:
virtual name get_name() const { return *g_equation_name; }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const { throw_eq_ex(); }
virtual optional<expr> expand(expr const &, extension_context &) const { throw_eq_ex(); }
virtual pair<expr, constraint_seq> get_type(expr const &, extension_context &) const {
expr dummy = mk_Prop();
return mk_pair(dummy, constraint_seq());
}
virtual optional<expr> expand(expr const &, extension_context &) const {
expr dummy = mk_Type();
return some_expr(dummy);
}
virtual void write(serializer & s) const { s.write_string(*g_equation_opcode); }
};
@ -56,11 +65,18 @@ public:
virtual void write(serializer & s) const { s.write_string(*g_decreasing_opcode); }
};
static macro_definition * g_equations = nullptr;
static macro_definition * g_equation = nullptr;
static macro_definition * g_decreasing = nullptr;
bool is_equation(expr const & e) { return is_macro(e) && macro_def(e) == *g_equation; }
bool is_lambda_equation(expr const & e) {
if (is_lambda(e))
return is_lambda_equation(binding_body(e));
else
return is_equation(e);
}
expr const & equation_lhs(expr const & e) { lean_assert(is_equation(e)); return macro_arg(e, 0); }
expr const & equation_rhs(expr const & e) { lean_assert(is_equation(e)); return macro_arg(e, 1); }
expr mk_equation(expr const & lhs, expr const & rhs) {
@ -76,40 +92,54 @@ expr mk_decreasing(expr const & t, expr const & H) {
return mk_macro(*g_decreasing, 2, args);
}
bool is_equations(expr const & e) { return is_macro(e) && macro_def(e) == *g_equations; }
bool is_equations(expr const & e) { return is_macro(e) && macro_def(e).get_name() == *g_equations_name; }
bool is_wf_equations_core(expr const & e) {
lean_assert(is_equations(e));
return !is_equation(macro_arg(e, macro_num_args(e) - 1));
return macro_num_args(e) >= 3 && !is_lambda_equation(macro_arg(e, macro_num_args(e) - 1));
}
bool is_wf_equations(expr const & e) { return is_equations(e) && is_wf_equations_core(e); }
unsigned equations_size(expr const & e) {
lean_assert(is_equations(e));
if (is_wf_equations_core(e))
return macro_num_args(e) - 1;
return macro_num_args(e) - 2;
else
return macro_num_args(e);
}
void to_equations(expr const & e, buffer<expr> & eqns) {
lean_assert(is_equation(e));
unsigned sz = equations_size(e);
for (unsigned i = 0; i < sz; i++)
eqns.push_back(macro_arg(e, i));
unsigned equations_num_fns(expr const & e) {
lean_assert(is_equations(e));
return static_cast<equations_macro_cell const*>(macro_def(e).raw())->get_num_fns();
}
expr const & equations_wf_proof(expr const & e) {
lean_assert(is_wf_equations(e));
return macro_arg(e, macro_num_args(e) - 1);
}
expr mk_equations(unsigned num, expr const * eqns) {
lean_assert(std::all_of(eqns, eqns+num, is_equation));
lean_assert(num > 0);
return mk_macro(*g_equations, num, eqns);
expr const & equations_wf_rel(expr const & e) {
lean_assert(is_wf_equations(e));
return macro_arg(e, macro_num_args(e) - 2);
}
expr mk_equations(unsigned num, expr const * eqns, expr const & Hwf) {
lean_assert(std::all_of(eqns, eqns+num, is_equation));
lean_assert(num > 0);
void to_equations(expr const & e, buffer<expr> & eqns) {
lean_assert(is_equations(e));
unsigned sz = equations_size(e);
for (unsigned i = 0; i < sz; i++)
eqns.push_back(macro_arg(e, i));
}
expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs) {
lean_assert(num_fns > 0);
lean_assert(num_eqs > 0);
lean_assert(std::all_of(eqs, eqs+num_eqs, is_lambda_equation));
macro_definition def(new equations_macro_cell(num_fns));
return mk_macro(def, num_eqs, eqs);
}
expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf) {
lean_assert(num_fns > 0);
lean_assert(num_eqs > 0);
lean_assert(std::all_of(eqs, eqs+num_eqs, is_lambda_equation));
buffer<expr> args;
args.append(num, eqns);
args.append(num_eqs, eqs);
args.push_back(R);
args.push_back(Hwf);
return mk_macro(*g_equations, args.size(), args.data());
macro_definition def(new equations_macro_cell(num_fns));
return mk_macro(def, args.size(), args.data());
}
expr mk_inaccessible(expr const & e) { return mk_annotation(*g_inaccessible_name, e); }
@ -120,7 +150,6 @@ void initialize_equations() {
g_equation_name = new name("equation");
g_decreasing_name = new name("decreasing");
g_inaccessible_name = new name("innaccessible");
g_equations = new macro_definition(new equations_macro_cell());
g_equation = new macro_definition(new equation_macro_cell());
g_decreasing = new macro_definition(new decreasing_macro_cell());
g_equations_opcode = new std::string("Eqns");
@ -128,15 +157,17 @@ void initialize_equations() {
g_decreasing_opcode = new std::string("Decr");
register_annotation(*g_inaccessible_name);
register_macro_deserializer(*g_equations_opcode,
[](deserializer &, unsigned num, expr const * args) {
if (num == 0)
[](deserializer & d, unsigned num, expr const * args) {
unsigned num_fns;
d >> num_fns;
if (num == 0 || num_fns == 0)
throw corrupted_stream_exception();
if (!is_equation(args[num-1])) {
if (num == 1)
if (!is_lambda_equation(args[num-1])) {
if (num <= 2)
throw corrupted_stream_exception();
return mk_equations(num-1, args, args[num-1]);
return mk_equations(num_fns, num-2, args, args[num-2], args[num-1]);
} else {
return mk_equations(num, args);
return mk_equations(num_fns, num, args);
}
});
register_macro_deserializer(*g_equation_opcode,
@ -157,7 +188,6 @@ void finalize_equations() {
delete g_equation_opcode;
delete g_equations_opcode;
delete g_decreasing_opcode;
delete g_equations;
delete g_equation;
delete g_decreasing;
delete g_equations_name;

View file

@ -12,6 +12,8 @@ bool is_equation(expr const & e);
expr const & equation_lhs(expr const & e);
expr const & equation_rhs(expr const & e);
expr mk_equation(expr const & lhs, expr const & rhs);
/** \brief Return true if e is of the form <tt>fun a_1 ... a_n, equation</tt> */
bool is_lambda_equation(expr const & e);
bool is_decreasing(expr const & e);
expr const & decreasing_app(expr const & e);
@ -21,10 +23,12 @@ expr mk_decreasing(expr const & t, expr const & H);
bool is_equations(expr const & e);
bool is_wf_equations(expr const & e);
unsigned equations_size(expr const & e);
unsigned equations_num_fns(expr const & e);
void to_equations(expr const & e, buffer<expr> & eqns);
expr const & equations_wf_proof(expr const & e);
expr mk_equations(unsigned num, expr const * eqns);
expr mk_equations(unsigned num, expr const * eqns, expr const & Hwf);
expr const & equations_wf_rel(expr const & e);
expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs);
expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf);
expr mk_inaccessible(expr const & e);
bool is_inaccessible(expr const & e);

31
tests/lean/extra/rec.lean Normal file
View file

@ -0,0 +1,31 @@
import data.vector
open nat vector
check lt.base
set_option pp.implicit true
definition add : nat → nat → nat,
add zero b := b,
add (succ a) b := succ (add a b)
definition map {A B C : Type} (f : A → B → C) : Π {n}, vector A n → vector B n → vector C n,
map nil nil := nil,
map (a :: va) (b :: vb) := f a b :: map va vb
definition fib : nat → nat,
fib 0 := 1,
fib 1 := 1,
fib (a+2) := (fib a ↓ lt.step (lt.base a)) + (fib (a+1) ↓ lt.base (a+1))
[wf] lt.wf
definition half : nat → nat,
half 0 := 0,
half 1 := 0,
half (x+2) := half x + 1
variables {A B : Type}
inductive image_of (f : A → B) : B → Type :=
mk : Π a, image_of f (f a)
definition inv {f : A → B} : Π b, image_of f b → A,
inv ⌞f a⌟ (image_of.mk f a) := a