feat(library/normalize,frontends/lean): allow multiple arguments in [unfold] hint
closes #693
This commit is contained in:
parent
a27b20cd9c
commit
26574e29a9
7 changed files with 134 additions and 55 deletions
|
@ -112,11 +112,17 @@ void decl_attributes::parse(buffer<name> const & ns, parser & p) {
|
|||
m_constructor_hint = true;
|
||||
} else if (p.curr_is_token(get_unfold_tk())) {
|
||||
p.next();
|
||||
buffer<unsigned> idxs;
|
||||
while (true) {
|
||||
unsigned r = p.parse_small_nat();
|
||||
if (r == 0)
|
||||
throw parser_error("invalid '[unfold]' attribute, value must be greater than 0", pos);
|
||||
m_unfold_hint = r - 1;
|
||||
p.check_token_next(get_rbracket_tk(), "invalid 'unfold', ']' expected");
|
||||
idxs.push_back(r-1);
|
||||
if (p.curr_is_token(get_rbracket_tk()))
|
||||
break;
|
||||
}
|
||||
p.next();
|
||||
m_unfold_hint = to_list(idxs);
|
||||
} else if (p.curr_is_token(get_symm_tk())) {
|
||||
p.next();
|
||||
m_symm = true;
|
||||
|
@ -193,7 +199,7 @@ environment decl_attributes::apply(environment env, io_state const & ios, name c
|
|||
if (m_is_quasireducible)
|
||||
env = set_reducible(env, d, reducible_status::Quasireducible, m_persistent);
|
||||
if (m_unfold_hint)
|
||||
env = add_unfold_hint(env, d, *m_unfold_hint, m_persistent);
|
||||
env = add_unfold_hint(env, d, m_unfold_hint, m_persistent);
|
||||
if (m_unfold_full_hint)
|
||||
env = add_unfold_full_hint(env, d, m_persistent);
|
||||
}
|
||||
|
@ -223,7 +229,8 @@ void decl_attributes::write(serializer & s) const {
|
|||
<< m_is_reducible << m_is_irreducible << m_is_semireducible << m_is_quasireducible
|
||||
<< m_is_class << m_is_parsing_only << m_has_multiple_instances << m_unfold_full_hint
|
||||
<< m_constructor_hint << m_symm << m_trans << m_refl << m_subst << m_recursor
|
||||
<< m_rewrite << m_recursor_major_pos << m_priority << m_unfold_hint;
|
||||
<< m_rewrite << m_recursor_major_pos << m_priority;
|
||||
write_list(s, m_unfold_hint);
|
||||
}
|
||||
|
||||
void decl_attributes::read(deserializer & d) {
|
||||
|
@ -231,6 +238,7 @@ void decl_attributes::read(deserializer & d) {
|
|||
>> m_is_reducible >> m_is_irreducible >> m_is_semireducible >> m_is_quasireducible
|
||||
>> m_is_class >> m_is_parsing_only >> m_has_multiple_instances >> m_unfold_full_hint
|
||||
>> m_constructor_hint >> m_symm >> m_trans >> m_refl >> m_subst >> m_recursor
|
||||
>> m_rewrite >> m_recursor_major_pos >> m_priority >> m_unfold_hint;
|
||||
>> m_rewrite >> m_recursor_major_pos >> m_priority;
|
||||
m_unfold_hint = read_list<unsigned>(d);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ class decl_attributes {
|
|||
bool m_rewrite;
|
||||
optional<unsigned> m_recursor_major_pos;
|
||||
optional<unsigned> m_priority;
|
||||
optional<unsigned> m_unfold_hint;
|
||||
list<unsigned> m_unfold_hint;
|
||||
|
||||
void parse(name const & n, parser & p);
|
||||
public:
|
||||
|
|
|
@ -724,8 +724,8 @@ struct structure_cmd_fn {
|
|||
rec_on_decl.get_type(), rec_on_decl.get_value());
|
||||
m_env = module::add(m_env, check(m_env, new_decl));
|
||||
m_env = set_reducible(m_env, n, reducible_status::Reducible);
|
||||
if (optional<unsigned> idx = has_unfold_hint(m_env, rec_on_name))
|
||||
m_env = add_unfold_hint(m_env, n, *idx);
|
||||
if (list<unsigned> idx = has_unfold_hint(m_env, rec_on_name))
|
||||
m_env = add_unfold_hint(m_env, n, idx);
|
||||
save_def_info(n);
|
||||
add_alias(n);
|
||||
}
|
||||
|
|
|
@ -33,24 +33,26 @@ struct unfold_hint_entry {
|
|||
kind m_kind; //!< true if it is an unfold_c hint
|
||||
bool m_add; //!< add/remove hint
|
||||
name m_decl_name;
|
||||
unsigned m_arg_idx;
|
||||
unfold_hint_entry():m_kind(Unfold), m_add(false), m_arg_idx(0) {}
|
||||
unfold_hint_entry(kind k, bool add, name const & n, unsigned idx):
|
||||
m_kind(k), m_add(add), m_decl_name(n), m_arg_idx(idx) {}
|
||||
list<unsigned> m_arg_idxs; //!< only relevant if m_kind == Unfold
|
||||
unfold_hint_entry():m_kind(Unfold), m_add(false) {}
|
||||
unfold_hint_entry(kind k, bool add, name const & n):
|
||||
m_kind(k), m_add(add), m_decl_name(n) {}
|
||||
unfold_hint_entry(bool add, name const & n, list<unsigned> const & idxs):
|
||||
m_kind(Unfold), m_add(add), m_decl_name(n), m_arg_idxs(idxs) {}
|
||||
};
|
||||
|
||||
unfold_hint_entry mk_add_unfold_entry(name const & n, unsigned idx) { return unfold_hint_entry(unfold_hint_entry::Unfold, true, n, idx); }
|
||||
unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n, 0); }
|
||||
unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n, 0); }
|
||||
unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n, 0); }
|
||||
unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n, 0); }
|
||||
unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n, 0); }
|
||||
unfold_hint_entry mk_add_unfold_entry(name const & n, list<unsigned> const & idxs) { return unfold_hint_entry(true, n, idxs); }
|
||||
unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n); }
|
||||
unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n); }
|
||||
unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n); }
|
||||
unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n); }
|
||||
unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n); }
|
||||
|
||||
static name * g_unfold_hint_name = nullptr;
|
||||
static std::string * g_key = nullptr;
|
||||
|
||||
struct unfold_hint_state {
|
||||
name_map<unsigned> m_unfold;
|
||||
name_map<list<unsigned>> m_unfold;
|
||||
name_set m_unfold_full;
|
||||
name_set m_constructor;
|
||||
};
|
||||
|
@ -63,7 +65,7 @@ struct unfold_hint_config {
|
|||
switch (e.m_kind) {
|
||||
case unfold_hint_entry::Unfold:
|
||||
if (e.m_add)
|
||||
s.m_unfold.insert(e.m_decl_name, e.m_arg_idx);
|
||||
s.m_unfold.insert(e.m_decl_name, e.m_arg_idxs);
|
||||
else
|
||||
s.m_unfold.erase(e.m_decl_name);
|
||||
break;
|
||||
|
@ -88,13 +90,17 @@ struct unfold_hint_config {
|
|||
return *g_key;
|
||||
}
|
||||
static void write_entry(serializer & s, entry const & e) {
|
||||
s << static_cast<char>(e.m_kind) << e.m_add << e.m_decl_name << e.m_arg_idx;
|
||||
s << static_cast<char>(e.m_kind) << e.m_add << e.m_decl_name;
|
||||
if (e.m_kind == unfold_hint_entry::Unfold)
|
||||
write_list(s, e.m_arg_idxs);
|
||||
}
|
||||
static entry read_entry(deserializer & d) {
|
||||
char k;
|
||||
entry e;
|
||||
d >> k >> e.m_add >> e.m_decl_name >> e.m_arg_idx;
|
||||
d >> k >> e.m_add >> e.m_decl_name;
|
||||
e.m_kind = static_cast<unfold_hint_entry::kind>(k);
|
||||
if (e.m_kind == unfold_hint_entry::Unfold)
|
||||
e.m_arg_idxs = read_list<unsigned>(d);
|
||||
return e;
|
||||
}
|
||||
static optional<unsigned> get_fingerprint(entry const & e) {
|
||||
|
@ -105,19 +111,20 @@ struct unfold_hint_config {
|
|||
template class scoped_ext<unfold_hint_config>;
|
||||
typedef scoped_ext<unfold_hint_config> unfold_hint_ext;
|
||||
|
||||
environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent) {
|
||||
environment add_unfold_hint(environment const & env, name const & n, list<unsigned> const & idxs, bool persistent) {
|
||||
lean_assert(idxs);
|
||||
declaration const & d = env.get(n);
|
||||
if (!d.is_definition())
|
||||
throw exception("invalid [unfold] hint, declaration must be a non-opaque definition");
|
||||
return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idx), persistent);
|
||||
return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idxs), persistent);
|
||||
}
|
||||
|
||||
optional<unsigned> has_unfold_hint(environment const & env, name const & d) {
|
||||
list<unsigned> has_unfold_hint(environment const & env, name const & d) {
|
||||
unfold_hint_state const & s = unfold_hint_ext::get_state(env);
|
||||
if (auto it = s.m_unfold.find(d))
|
||||
return optional<unsigned>(*it);
|
||||
return list<unsigned>(*it);
|
||||
else
|
||||
return optional<unsigned>();
|
||||
return list<unsigned>();
|
||||
}
|
||||
|
||||
environment erase_unfold_hint(environment const & env, name const & n, bool persistent) {
|
||||
|
@ -246,9 +253,9 @@ class normalize_fn {
|
|||
return update_binding(e, d, b);
|
||||
}
|
||||
|
||||
optional<unsigned> has_unfold_hint(expr const & f) {
|
||||
list<unsigned> has_unfold_hint(expr const & f) {
|
||||
if (!is_constant(f))
|
||||
return optional<unsigned>();
|
||||
return list<unsigned>();
|
||||
return ::lean::has_unfold_hint(env(), const_name(f));
|
||||
}
|
||||
|
||||
|
@ -270,27 +277,39 @@ class normalize_fn {
|
|||
}
|
||||
}
|
||||
|
||||
optional<expr> unfold_recursor_core(expr const & f, unsigned idx, buffer<expr> & args, bool is_rec) {
|
||||
if (idx < args.size()) {
|
||||
expr & arg = args[args.size() - idx - 1];
|
||||
if (optional<expr> new_arg = is_constructor_like(arg)) {
|
||||
flet<expr> set_arg(arg, *new_arg);
|
||||
optional<expr> unfold_recursor_core(expr const & f, unsigned i,
|
||||
buffer<unsigned> const & idxs, buffer<expr> & args, bool is_rec) {
|
||||
if (i == idxs.size()) {
|
||||
expr new_app = mk_rev_app(f, args);
|
||||
if (is_rec)
|
||||
return some_expr(normalize(new_app));
|
||||
else if (optional<expr> r = unfold_app(env(), new_app))
|
||||
return some_expr(normalize(*r));
|
||||
}
|
||||
}
|
||||
else
|
||||
return none_expr();
|
||||
} else {
|
||||
unsigned idx = idxs[i];
|
||||
if (idx >= args.size())
|
||||
return none_expr();
|
||||
expr & arg = args[args.size() - idx - 1];
|
||||
optional<expr> new_arg = is_constructor_like(arg);
|
||||
if (!new_arg)
|
||||
return none_expr();
|
||||
flet<expr> set_arg(arg, *new_arg);
|
||||
return unfold_recursor_core(f, i+1, idxs, args, is_rec);
|
||||
}
|
||||
}
|
||||
|
||||
optional<expr> unfold_recursor_like(expr const & f, unsigned idx, buffer<expr> & args) {
|
||||
return unfold_recursor_core(f, idx, args, false);
|
||||
optional<expr> unfold_recursor_like(expr const & f, list<unsigned> const & idx_lst, buffer<expr> & args) {
|
||||
buffer<unsigned> idxs;
|
||||
to_buffer(idx_lst, idxs);
|
||||
return unfold_recursor_core(f, 0, idxs, args, false);
|
||||
}
|
||||
|
||||
optional<expr> unfold_recursor_major(expr const & f, unsigned idx, buffer<expr> & args) {
|
||||
return unfold_recursor_core(f, idx, args, true);
|
||||
buffer<unsigned> idxs;
|
||||
idxs.push_back(idx);
|
||||
return unfold_recursor_core(f, 0, idxs, args, true);
|
||||
}
|
||||
|
||||
expr normalize_app(expr const & e) {
|
||||
|
@ -311,8 +330,8 @@ class normalize_fn {
|
|||
return normalize(*r);
|
||||
}
|
||||
}
|
||||
if (auto idx = has_unfold_hint(f)) {
|
||||
if (auto r = unfold_recursor_like(f, *idx, args))
|
||||
if (auto idxs = has_unfold_hint(f)) {
|
||||
if (auto r = unfold_recursor_like(f, idxs, args))
|
||||
return *r;
|
||||
}
|
||||
if (is_constant(f)) {
|
||||
|
|
|
@ -40,10 +40,10 @@ expr normalize(type_checker & tc, expr const & e, std::function<bool(expr const&
|
|||
|
||||
Of course, kernel opaque constants are not unfolded.
|
||||
*/
|
||||
environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent = true);
|
||||
environment add_unfold_hint(environment const & env, name const & n, list<unsigned> const & idxs, bool persistent = true);
|
||||
environment erase_unfold_hint(environment const & env, name const & n, bool persistent = true);
|
||||
/** \brief Retrieve the hint added with the procedure add_unfold_hint. */
|
||||
optional<unsigned> has_unfold_hint(environment const & env, name const & d);
|
||||
list<unsigned> has_unfold_hint(environment const & env, name const & d);
|
||||
|
||||
/** \brief [unfold-full] hint instructs normalizer (and simplifier) that function application
|
||||
(f a_1 ... a_n) should be unfolded when it is fully applied */
|
||||
|
|
36
tests/lean/693.lean
Normal file
36
tests/lean/693.lean
Normal file
|
@ -0,0 +1,36 @@
|
|||
open nat
|
||||
|
||||
definition foo [unfold 1 3] (a : nat) (b : nat) (c :nat) : nat :=
|
||||
(a + c) * b
|
||||
|
||||
example (c : nat) : c = 1 → foo 1 c 0 = foo 1 1 0 :=
|
||||
begin
|
||||
intro h,
|
||||
esimp,
|
||||
state,
|
||||
subst c
|
||||
end
|
||||
|
||||
example (b c : nat) : c = 1 → foo 1 c b = foo 1 1 b :=
|
||||
begin
|
||||
intro h,
|
||||
esimp, -- should not unfold foo
|
||||
state,
|
||||
subst c
|
||||
end
|
||||
|
||||
example (b c : nat) : c = 1 → foo b c 0 = foo b 1 0 :=
|
||||
begin
|
||||
intro h,
|
||||
esimp, -- should not unfold foo
|
||||
state,
|
||||
subst c
|
||||
end
|
||||
|
||||
example (b c : nat) : c = 1 → foo 1 c 1 = foo c 1 1 :=
|
||||
begin
|
||||
intro h,
|
||||
esimp, -- should fold only first foo
|
||||
state,
|
||||
subst c
|
||||
end
|
16
tests/lean/693.lean.expected.out
Normal file
16
tests/lean/693.lean.expected.out
Normal file
|
@ -0,0 +1,16 @@
|
|||
693.lean:10:2: proof state
|
||||
c : ℕ,
|
||||
h : c = 1
|
||||
⊢ (1 + 0) * c = (1 + 0) * 1
|
||||
693.lean:18:2: proof state
|
||||
b c : ℕ,
|
||||
h : c = 1
|
||||
⊢ foo 1 c b = foo 1 1 b
|
||||
693.lean:26:2: proof state
|
||||
b c : ℕ,
|
||||
h : c = 1
|
||||
⊢ foo b c 0 = foo b 1 0
|
||||
693.lean:34:2: proof state
|
||||
b c : ℕ,
|
||||
h : c = 1
|
||||
⊢ (1 + 1) * c = foo c 1 1
|
Loading…
Reference in a new issue