feat(library/normalize,frontends/lean): allow multiple arguments in [unfold] hint

closes #693
This commit is contained in:
Leonardo de Moura 2015-07-07 18:01:57 -07:00
parent a27b20cd9c
commit 26574e29a9
7 changed files with 134 additions and 55 deletions

View file

@ -112,11 +112,17 @@ void decl_attributes::parse(buffer<name> const & ns, parser & p) {
m_constructor_hint = true; m_constructor_hint = true;
} else if (p.curr_is_token(get_unfold_tk())) { } else if (p.curr_is_token(get_unfold_tk())) {
p.next(); p.next();
buffer<unsigned> idxs;
while (true) {
unsigned r = p.parse_small_nat(); unsigned r = p.parse_small_nat();
if (r == 0) if (r == 0)
throw parser_error("invalid '[unfold]' attribute, value must be greater than 0", pos); throw parser_error("invalid '[unfold]' attribute, value must be greater than 0", pos);
m_unfold_hint = r - 1; idxs.push_back(r-1);
p.check_token_next(get_rbracket_tk(), "invalid 'unfold', ']' expected"); if (p.curr_is_token(get_rbracket_tk()))
break;
}
p.next();
m_unfold_hint = to_list(idxs);
} else if (p.curr_is_token(get_symm_tk())) { } else if (p.curr_is_token(get_symm_tk())) {
p.next(); p.next();
m_symm = true; m_symm = true;
@ -193,7 +199,7 @@ environment decl_attributes::apply(environment env, io_state const & ios, name c
if (m_is_quasireducible) if (m_is_quasireducible)
env = set_reducible(env, d, reducible_status::Quasireducible, m_persistent); env = set_reducible(env, d, reducible_status::Quasireducible, m_persistent);
if (m_unfold_hint) if (m_unfold_hint)
env = add_unfold_hint(env, d, *m_unfold_hint, m_persistent); env = add_unfold_hint(env, d, m_unfold_hint, m_persistent);
if (m_unfold_full_hint) if (m_unfold_full_hint)
env = add_unfold_full_hint(env, d, m_persistent); env = add_unfold_full_hint(env, d, m_persistent);
} }
@ -223,7 +229,8 @@ void decl_attributes::write(serializer & s) const {
<< m_is_reducible << m_is_irreducible << m_is_semireducible << m_is_quasireducible << m_is_reducible << m_is_irreducible << m_is_semireducible << m_is_quasireducible
<< m_is_class << m_is_parsing_only << m_has_multiple_instances << m_unfold_full_hint << m_is_class << m_is_parsing_only << m_has_multiple_instances << m_unfold_full_hint
<< m_constructor_hint << m_symm << m_trans << m_refl << m_subst << m_recursor << m_constructor_hint << m_symm << m_trans << m_refl << m_subst << m_recursor
<< m_rewrite << m_recursor_major_pos << m_priority << m_unfold_hint; << m_rewrite << m_recursor_major_pos << m_priority;
write_list(s, m_unfold_hint);
} }
void decl_attributes::read(deserializer & d) { void decl_attributes::read(deserializer & d) {
@ -231,6 +238,7 @@ void decl_attributes::read(deserializer & d) {
>> m_is_reducible >> m_is_irreducible >> m_is_semireducible >> m_is_quasireducible >> m_is_reducible >> m_is_irreducible >> m_is_semireducible >> m_is_quasireducible
>> m_is_class >> m_is_parsing_only >> m_has_multiple_instances >> m_unfold_full_hint >> m_is_class >> m_is_parsing_only >> m_has_multiple_instances >> m_unfold_full_hint
>> m_constructor_hint >> m_symm >> m_trans >> m_refl >> m_subst >> m_recursor >> m_constructor_hint >> m_symm >> m_trans >> m_refl >> m_subst >> m_recursor
>> m_rewrite >> m_recursor_major_pos >> m_priority >> m_unfold_hint; >> m_rewrite >> m_recursor_major_pos >> m_priority;
m_unfold_hint = read_list<unsigned>(d);
} }
} }

View file

@ -32,7 +32,7 @@ class decl_attributes {
bool m_rewrite; bool m_rewrite;
optional<unsigned> m_recursor_major_pos; optional<unsigned> m_recursor_major_pos;
optional<unsigned> m_priority; optional<unsigned> m_priority;
optional<unsigned> m_unfold_hint; list<unsigned> m_unfold_hint;
void parse(name const & n, parser & p); void parse(name const & n, parser & p);
public: public:

View file

@ -724,8 +724,8 @@ struct structure_cmd_fn {
rec_on_decl.get_type(), rec_on_decl.get_value()); rec_on_decl.get_type(), rec_on_decl.get_value());
m_env = module::add(m_env, check(m_env, new_decl)); m_env = module::add(m_env, check(m_env, new_decl));
m_env = set_reducible(m_env, n, reducible_status::Reducible); m_env = set_reducible(m_env, n, reducible_status::Reducible);
if (optional<unsigned> idx = has_unfold_hint(m_env, rec_on_name)) if (list<unsigned> idx = has_unfold_hint(m_env, rec_on_name))
m_env = add_unfold_hint(m_env, n, *idx); m_env = add_unfold_hint(m_env, n, idx);
save_def_info(n); save_def_info(n);
add_alias(n); add_alias(n);
} }

View file

@ -33,24 +33,26 @@ struct unfold_hint_entry {
kind m_kind; //!< true if it is an unfold_c hint kind m_kind; //!< true if it is an unfold_c hint
bool m_add; //!< add/remove hint bool m_add; //!< add/remove hint
name m_decl_name; name m_decl_name;
unsigned m_arg_idx; list<unsigned> m_arg_idxs; //!< only relevant if m_kind == Unfold
unfold_hint_entry():m_kind(Unfold), m_add(false), m_arg_idx(0) {} unfold_hint_entry():m_kind(Unfold), m_add(false) {}
unfold_hint_entry(kind k, bool add, name const & n, unsigned idx): unfold_hint_entry(kind k, bool add, name const & n):
m_kind(k), m_add(add), m_decl_name(n), m_arg_idx(idx) {} m_kind(k), m_add(add), m_decl_name(n) {}
unfold_hint_entry(bool add, name const & n, list<unsigned> const & idxs):
m_kind(Unfold), m_add(add), m_decl_name(n), m_arg_idxs(idxs) {}
}; };
unfold_hint_entry mk_add_unfold_entry(name const & n, unsigned idx) { return unfold_hint_entry(unfold_hint_entry::Unfold, true, n, idx); } unfold_hint_entry mk_add_unfold_entry(name const & n, list<unsigned> const & idxs) { return unfold_hint_entry(true, n, idxs); }
unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n, 0); } unfold_hint_entry mk_erase_unfold_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Unfold, false, n); }
unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n, 0); } unfold_hint_entry mk_add_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, true, n); }
unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n, 0); } unfold_hint_entry mk_erase_unfold_full_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::UnfoldFull, false, n); }
unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n, 0); } unfold_hint_entry mk_add_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, true, n); }
unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n, 0); } unfold_hint_entry mk_erase_constructor_entry(name const & n) { return unfold_hint_entry(unfold_hint_entry::Constructor, false, n); }
static name * g_unfold_hint_name = nullptr; static name * g_unfold_hint_name = nullptr;
static std::string * g_key = nullptr; static std::string * g_key = nullptr;
struct unfold_hint_state { struct unfold_hint_state {
name_map<unsigned> m_unfold; name_map<list<unsigned>> m_unfold;
name_set m_unfold_full; name_set m_unfold_full;
name_set m_constructor; name_set m_constructor;
}; };
@ -63,7 +65,7 @@ struct unfold_hint_config {
switch (e.m_kind) { switch (e.m_kind) {
case unfold_hint_entry::Unfold: case unfold_hint_entry::Unfold:
if (e.m_add) if (e.m_add)
s.m_unfold.insert(e.m_decl_name, e.m_arg_idx); s.m_unfold.insert(e.m_decl_name, e.m_arg_idxs);
else else
s.m_unfold.erase(e.m_decl_name); s.m_unfold.erase(e.m_decl_name);
break; break;
@ -88,13 +90,17 @@ struct unfold_hint_config {
return *g_key; return *g_key;
} }
static void write_entry(serializer & s, entry const & e) { static void write_entry(serializer & s, entry const & e) {
s << static_cast<char>(e.m_kind) << e.m_add << e.m_decl_name << e.m_arg_idx; s << static_cast<char>(e.m_kind) << e.m_add << e.m_decl_name;
if (e.m_kind == unfold_hint_entry::Unfold)
write_list(s, e.m_arg_idxs);
} }
static entry read_entry(deserializer & d) { static entry read_entry(deserializer & d) {
char k; char k;
entry e; entry e;
d >> k >> e.m_add >> e.m_decl_name >> e.m_arg_idx; d >> k >> e.m_add >> e.m_decl_name;
e.m_kind = static_cast<unfold_hint_entry::kind>(k); e.m_kind = static_cast<unfold_hint_entry::kind>(k);
if (e.m_kind == unfold_hint_entry::Unfold)
e.m_arg_idxs = read_list<unsigned>(d);
return e; return e;
} }
static optional<unsigned> get_fingerprint(entry const & e) { static optional<unsigned> get_fingerprint(entry const & e) {
@ -105,19 +111,20 @@ struct unfold_hint_config {
template class scoped_ext<unfold_hint_config>; template class scoped_ext<unfold_hint_config>;
typedef scoped_ext<unfold_hint_config> unfold_hint_ext; typedef scoped_ext<unfold_hint_config> unfold_hint_ext;
environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent) { environment add_unfold_hint(environment const & env, name const & n, list<unsigned> const & idxs, bool persistent) {
lean_assert(idxs);
declaration const & d = env.get(n); declaration const & d = env.get(n);
if (!d.is_definition()) if (!d.is_definition())
throw exception("invalid [unfold] hint, declaration must be a non-opaque definition"); throw exception("invalid [unfold] hint, declaration must be a non-opaque definition");
return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idx), persistent); return unfold_hint_ext::add_entry(env, get_dummy_ios(), mk_add_unfold_entry(n, idxs), persistent);
} }
optional<unsigned> has_unfold_hint(environment const & env, name const & d) { list<unsigned> has_unfold_hint(environment const & env, name const & d) {
unfold_hint_state const & s = unfold_hint_ext::get_state(env); unfold_hint_state const & s = unfold_hint_ext::get_state(env);
if (auto it = s.m_unfold.find(d)) if (auto it = s.m_unfold.find(d))
return optional<unsigned>(*it); return list<unsigned>(*it);
else else
return optional<unsigned>(); return list<unsigned>();
} }
environment erase_unfold_hint(environment const & env, name const & n, bool persistent) { environment erase_unfold_hint(environment const & env, name const & n, bool persistent) {
@ -246,9 +253,9 @@ class normalize_fn {
return update_binding(e, d, b); return update_binding(e, d, b);
} }
optional<unsigned> has_unfold_hint(expr const & f) { list<unsigned> has_unfold_hint(expr const & f) {
if (!is_constant(f)) if (!is_constant(f))
return optional<unsigned>(); return list<unsigned>();
return ::lean::has_unfold_hint(env(), const_name(f)); return ::lean::has_unfold_hint(env(), const_name(f));
} }
@ -270,27 +277,39 @@ class normalize_fn {
} }
} }
optional<expr> unfold_recursor_core(expr const & f, unsigned idx, buffer<expr> & args, bool is_rec) { optional<expr> unfold_recursor_core(expr const & f, unsigned i,
if (idx < args.size()) { buffer<unsigned> const & idxs, buffer<expr> & args, bool is_rec) {
expr & arg = args[args.size() - idx - 1]; if (i == idxs.size()) {
if (optional<expr> new_arg = is_constructor_like(arg)) {
flet<expr> set_arg(arg, *new_arg);
expr new_app = mk_rev_app(f, args); expr new_app = mk_rev_app(f, args);
if (is_rec) if (is_rec)
return some_expr(normalize(new_app)); return some_expr(normalize(new_app));
else if (optional<expr> r = unfold_app(env(), new_app)) else if (optional<expr> r = unfold_app(env(), new_app))
return some_expr(normalize(*r)); return some_expr(normalize(*r));
} else
}
return none_expr(); return none_expr();
} else {
unsigned idx = idxs[i];
if (idx >= args.size())
return none_expr();
expr & arg = args[args.size() - idx - 1];
optional<expr> new_arg = is_constructor_like(arg);
if (!new_arg)
return none_expr();
flet<expr> set_arg(arg, *new_arg);
return unfold_recursor_core(f, i+1, idxs, args, is_rec);
}
} }
optional<expr> unfold_recursor_like(expr const & f, unsigned idx, buffer<expr> & args) { optional<expr> unfold_recursor_like(expr const & f, list<unsigned> const & idx_lst, buffer<expr> & args) {
return unfold_recursor_core(f, idx, args, false); buffer<unsigned> idxs;
to_buffer(idx_lst, idxs);
return unfold_recursor_core(f, 0, idxs, args, false);
} }
optional<expr> unfold_recursor_major(expr const & f, unsigned idx, buffer<expr> & args) { optional<expr> unfold_recursor_major(expr const & f, unsigned idx, buffer<expr> & args) {
return unfold_recursor_core(f, idx, args, true); buffer<unsigned> idxs;
idxs.push_back(idx);
return unfold_recursor_core(f, 0, idxs, args, true);
} }
expr normalize_app(expr const & e) { expr normalize_app(expr const & e) {
@ -311,8 +330,8 @@ class normalize_fn {
return normalize(*r); return normalize(*r);
} }
} }
if (auto idx = has_unfold_hint(f)) { if (auto idxs = has_unfold_hint(f)) {
if (auto r = unfold_recursor_like(f, *idx, args)) if (auto r = unfold_recursor_like(f, idxs, args))
return *r; return *r;
} }
if (is_constant(f)) { if (is_constant(f)) {

View file

@ -40,10 +40,10 @@ expr normalize(type_checker & tc, expr const & e, std::function<bool(expr const&
Of course, kernel opaque constants are not unfolded. Of course, kernel opaque constants are not unfolded.
*/ */
environment add_unfold_hint(environment const & env, name const & n, unsigned idx, bool persistent = true); environment add_unfold_hint(environment const & env, name const & n, list<unsigned> const & idxs, bool persistent = true);
environment erase_unfold_hint(environment const & env, name const & n, bool persistent = true); environment erase_unfold_hint(environment const & env, name const & n, bool persistent = true);
/** \brief Retrieve the hint added with the procedure add_unfold_hint. */ /** \brief Retrieve the hint added with the procedure add_unfold_hint. */
optional<unsigned> has_unfold_hint(environment const & env, name const & d); list<unsigned> has_unfold_hint(environment const & env, name const & d);
/** \brief [unfold-full] hint instructs normalizer (and simplifier) that function application /** \brief [unfold-full] hint instructs normalizer (and simplifier) that function application
(f a_1 ... a_n) should be unfolded when it is fully applied */ (f a_1 ... a_n) should be unfolded when it is fully applied */

36
tests/lean/693.lean Normal file
View file

@ -0,0 +1,36 @@
open nat
definition foo [unfold 1 3] (a : nat) (b : nat) (c :nat) : nat :=
(a + c) * b
example (c : nat) : c = 1 → foo 1 c 0 = foo 1 1 0 :=
begin
intro h,
esimp,
state,
subst c
end
example (b c : nat) : c = 1 → foo 1 c b = foo 1 1 b :=
begin
intro h,
esimp, -- should not unfold foo
state,
subst c
end
example (b c : nat) : c = 1 → foo b c 0 = foo b 1 0 :=
begin
intro h,
esimp, -- should not unfold foo
state,
subst c
end
example (b c : nat) : c = 1 → foo 1 c 1 = foo c 1 1 :=
begin
intro h,
esimp, -- should fold only first foo
state,
subst c
end

View file

@ -0,0 +1,16 @@
693.lean:10:2: proof state
c : ,
h : c = 1
⊢ (1 + 0) * c = (1 + 0) * 1
693.lean:18:2: proof state
b c : ,
h : c = 1
⊢ foo 1 c b = foo 1 1 b
693.lean:26:2: proof state
b c : ,
h : c = 1
⊢ foo b c 0 = foo b 1 0
693.lean:34:2: proof state
b c : ,
h : c = 1
⊢ (1 + 1) * c = foo c 1 1