feat(library/normalize,frontends/lean): allow multiple arguments in [unfold] hint

closes #693
This commit is contained in:
Leonardo de Moura 2015-07-07 18:01:57 -07:00
parent a27b20cd9c
commit 26574e29a9
7 changed files with 134 additions and 55 deletions

View file

@ -112,11 +112,17 @@ void decl_attributes::parse(buffer<name> const & ns, parser & p) {
m_constructor_hint = true;
} else if (p.curr_is_token(get_unfold_tk())) {
p.next();
buffer<unsigned> idxs;
while (true) {
unsigned r = p.parse_small_nat();
if (r == 0)
throw parser_error("invalid '[unfold]' attribute, value must be greater than 0", pos);
m_unfold_hint = r - 1;
p.check_token_next(get_rbracket_tk(), "invalid 'unfold', ']' expected");
idxs.push_back(r-1);
if (p.curr_is_token(get_rbracket_tk()))
break;
}
p.next();
m_unfold_hint = to_list(idxs);
} else if (p.curr_is_token(get_symm_tk())) {
p.next();
m_symm = true;
@ -193,7 +199,7 @@ environment decl_attributes::apply(environment env, io_state const & ios, name c
if (m_is_quasireducible)
env = set_reducible(env, d, reducible_status::Quasireducible, m_persistent);
if (m_unfold_hint)
env = add_unfold_hint(env, d, *m_unfold_hint, m_persistent);
env = add_unfold_hint(env, d, m_unfold_hint, m_persistent);
if (m_unfold_full_hint)
env = add_unfold_full_hint(env, d, m_persistent);
}
@ -223,7 +229,8 @@ void decl_attributes::write(serializer & s) const {
<< m_is_reducible << m_is_irreducible << m_is_semireducible << m_is_quasireducible
<< m_is_class << m_is_parsing_only << m_has_multiple_instances << m_unfold_full_hint
<< m_constructor_hint << m_symm << m_trans << m_refl << m_subst << m_recursor
<< m_rewrite << m_recursor_major_pos << m_priority << m_unfold_hint;
<< m_rewrite << m_recursor_major_pos << m_priority;
write_list(s, m_unfold_hint);
}
void decl_attributes::read(deserializer & d) {
@ -231,6 +238,7 @@ void decl_attributes::read(deserializer & d) {
>> m_is_reducible >> m_is_irreducible >> m_is_semireducible >> m_is_quasireducible
>> m_is_class >> m_is_parsing_only >> m_has_multiple_instances >> m_unfold_full_hint
>> m_constructor_hint >> m_symm >> m_trans >> m_refl >> m_subst >> m_recursor
>> m_rewrite >> m_recursor_major_pos >> m_priority >> m_unfold_hint;
>> m_rewrite >> m_recursor_major_pos >> m_priority;
m_unfold_hint = read_list<unsigned>(d);
}
}

View file

@ -32,7 +32,7 @@ class decl_attributes {
bool m_rewrite;
optional<unsigned> m_recursor_major_pos;
optional<unsigned> m_priority;
optional<unsigned> m_unfold_hint;
list<unsigned> m_unfold_hint;
void parse(name const & n, parser & p);
public:

View file

@ -724,8 +724,8 @@ struct structure_cmd_fn {
rec_on_decl.get_type(), rec_on_decl.get_value());
m_env = module::add(m_env, check(m_env, new_decl));
m_env = set_reducible(m_env, n, reducible_status::Reducible);
if (optional<unsigned> idx = has_unfold_hint(m_env, rec_on_name))
m_env = add_unfold_hint(m_env, n, *idx);
if (list<unsigned> idx = has_unfold_hint(m_env, rec_on_name))
m_env = add_unfold_hint(m_env, n, idx);
save_def_info(n);
add_alias(n);
}

View file

@ -33,24 +33,26 @@ struct unfold_hint_entry {
kind m_kind; //!< true if it is an unfold_c hint
bool m_add; //!< add/remove hint
name m_decl_name;
unsigned m_arg_idx;
unfold_hint_entry():m_kind(Unfold), m_add(false), m_arg_idx(0) {}
unfold_hint_entry(kind k, bool add, name const & n, unsigned idx):
m_kind(k), m_add(add), m_decl_name(n), m_arg_idx(idx) {}
list<unsigned> m_arg_idxs; //!< only relevant if m_kind == Unfold
unfold_hint_entry():m_kind(Unfold), m_add(false) {}
unfold_hint_entry(kind k, bool add, name const & n):
m_kind(k), m_add(add), m_decl_name(n) {}
unfold_hint_entry(bool add, name const & n, list<unsigned> const & idxs):
m_kind(Unfold), m_add(add), m_decl_name(n), m_arg_idxs(idxs) {}
};
unfold_hint_entry mk_add_unfold_entry(name const & n, unsigned idx) { return unfold_hint_entry(unfold_hint_entry::Unfold, true, n, idx); }
unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n, 0); }
unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n, 0); }
unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n, 0); }
unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n, 0); }
unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n, 0); }
unfold_hint_entry mk_add_unfold_entry(name const & n, list<unsigned> const & idxs) { return unfold_hint_entry(true, n, idxs); }
unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n); }
unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n); }
unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n); }
unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n); }
unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n); }
static name * g_unfold_hint_name = nullptr;
static std::string * g_key = nullptr;
struct unfold_hint_state {
name_map<unsigned> m_unfold;
name_map<list<unsigned>> m_unfold;
name_set m_unfold_full;
name_set m_constructor;
};
@ -63,7 +65,7 @@ struct unfold_hint_config {
switch (e.m_kind) {
case unfold_hint_entry::Unfold:
if (e.m_add)
s.m_unfold.insert(e.m_decl_name, e.m_arg_idx);
s.m_unfold.insert(e.m_decl_name, e.m_arg_idxs);
else
s.m_unfold.erase(e.m_decl_name);
break;
@ -88,13 +90,17 @@ struct unfold_hint_config {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << static_cast<char>(e.m_kind) << e.m_add << e.m_decl_name << e.m_arg_idx;
s << static_cast<char>(e.m_kind) << e.m_add << e.m_decl_name;
if (e.m_kind == unfold_hint_entry::Unfold)
write_list(s, e.m_arg_idxs);
}
static entry read_entry(deserializer & d) {
char k;
entry e;
d >> k >> e.m_add >> e.m_decl_name >> e.m_arg_idx;
d >> k >> e.m_add >> e.m_decl_name;
e.m_kind = static_cast<unfold_hint_entry::kind>(k);
if (e.m_kind == unfold_hint_entry::Unfold)
e.m_arg_idxs = read_list<unsigned>(d);
return e;
}
static optional<unsigned> get_fingerprint(entry const & e) {
@ -105,19 +111,20 @@ struct unfold_hint_config {
template class scoped_ext<unfold_hint_config>;
typedef scoped_ext<unfold_hint_config> unfold_hint_ext;
environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent) {
environment add_unfold_hint(environment const & env, name const & n, list<unsigned> const & idxs, bool persistent) {
lean_assert(idxs);
declaration const & d = env.get(n);
if (!d.is_definition())
throw exception("invalid [unfold] hint, declaration must be a non-opaque definition");
return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idx), persistent);
return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idxs), persistent);
}
optional<unsigned> has_unfold_hint(environment const & env, name const & d) {
list<unsigned> has_unfold_hint(environment const & env, name const & d) {
unfold_hint_state const & s = unfold_hint_ext::get_state(env);
if (auto it = s.m_unfold.find(d))
return optional<unsigned>(*it);
return list<unsigned>(*it);
else
return optional<unsigned>();
return list<unsigned>();
}
environment erase_unfold_hint(environment const & env, name const & n, bool persistent) {
@ -246,9 +253,9 @@ class normalize_fn {
return update_binding(e, d, b);
}
optional<unsigned> has_unfold_hint(expr const & f) {
list<unsigned> has_unfold_hint(expr const & f) {
if (!is_constant(f))
return optional<unsigned>();
return list<unsigned>();
return ::lean::has_unfold_hint(env(), const_name(f));
}
@ -270,27 +277,39 @@ class normalize_fn {
}
}
optional<expr> unfold_recursor_core(expr const & f, unsigned idx, buffer<expr> & args, bool is_rec) {
if (idx < args.size()) {
expr & arg = args[args.size() - idx - 1];
if (optional<expr> new_arg = is_constructor_like(arg)) {
flet<expr> set_arg(arg, *new_arg);
optional<expr> unfold_recursor_core(expr const & f, unsigned i,
buffer<unsigned> const & idxs, buffer<expr> & args, bool is_rec) {
if (i == idxs.size()) {
expr new_app = mk_rev_app(f, args);
if (is_rec)
return some_expr(normalize(new_app));
else if (optional<expr> r = unfold_app(env(), new_app))
return some_expr(normalize(*r));
}
}
else
return none_expr();
} else {
unsigned idx = idxs[i];
if (idx >= args.size())
return none_expr();
expr & arg = args[args.size() - idx - 1];
optional<expr> new_arg = is_constructor_like(arg);
if (!new_arg)
return none_expr();
flet<expr> set_arg(arg, *new_arg);
return unfold_recursor_core(f, i+1, idxs, args, is_rec);
}
}
optional<expr> unfold_recursor_like(expr const & f, unsigned idx, buffer<expr> & args) {
return unfold_recursor_core(f, idx, args, false);
optional<expr> unfold_recursor_like(expr const & f, list<unsigned> const & idx_lst, buffer<expr> & args) {
buffer<unsigned> idxs;
to_buffer(idx_lst, idxs);
return unfold_recursor_core(f, 0, idxs, args, false);
}
optional<expr> unfold_recursor_major(expr const & f, unsigned idx, buffer<expr> & args) {
return unfold_recursor_core(f, idx, args, true);
buffer<unsigned> idxs;
idxs.push_back(idx);
return unfold_recursor_core(f, 0, idxs, args, true);
}
expr normalize_app(expr const & e) {
@ -311,8 +330,8 @@ class normalize_fn {
return normalize(*r);
}
}
if (auto idx = has_unfold_hint(f)) {
if (auto r = unfold_recursor_like(f, *idx, args))
if (auto idxs = has_unfold_hint(f)) {
if (auto r = unfold_recursor_like(f, idxs, args))
return *r;
}
if (is_constant(f)) {

View file

@ -40,10 +40,10 @@ expr normalize(type_checker & tc, expr const & e, std::function<bool(expr const&
Of course, kernel opaque constants are not unfolded.
*/
environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent = true);
environment add_unfold_hint(environment const & env, name const & n, list<unsigned> const & idxs, bool persistent = true);
environment erase_unfold_hint(environment const & env, name const & n, bool persistent = true);
/** \brief Retrieve the hint added with the procedure add_unfold_hint. */
optional<unsigned> has_unfold_hint(environment const & env, name const & d);
list<unsigned> has_unfold_hint(environment const & env, name const & d);
/** \brief [unfold-full] hint instructs normalizer (and simplifier) that function application
(f a_1 ... a_n) should be unfolded when it is fully applied */

36
tests/lean/693.lean Normal file
View file

@ -0,0 +1,36 @@
open nat
definition foo [unfold 1 3] (a : nat) (b : nat) (c :nat) : nat :=
(a + c) * b
example (c : nat) : c = 1 → foo 1 c 0 = foo 1 1 0 :=
begin
intro h,
esimp,
state,
subst c
end
example (b c : nat) : c = 1 → foo 1 c b = foo 1 1 b :=
begin
intro h,
esimp, -- should not unfold foo
state,
subst c
end
example (b c : nat) : c = 1 → foo b c 0 = foo b 1 0 :=
begin
intro h,
esimp, -- should not unfold foo
state,
subst c
end
example (b c : nat) : c = 1 → foo 1 c 1 = foo c 1 1 :=
begin
intro h,
esimp, -- should fold only first foo
state,
subst c
end

View file

@ -0,0 +1,16 @@
693.lean:10:2: proof state
c : ,
h : c = 1
⊢ (1 + 0) * c = (1 + 0) * 1
693.lean:18:2: proof state
b c : ,
h : c = 1
⊢ foo 1 c b = foo 1 1 b
693.lean:26:2: proof state
b c : ,
h : c = 1
⊢ foo b c 0 = foo b 1 0
693.lean:34:2: proof state
b c : ,
h : c = 1
⊢ (1 + 1) * c = foo c 1 1