fix(frontends/lean/pp): compute local shared nodes, and avoid unnecessary let's

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-22 21:44:24 -08:00
parent 17cce340f6
commit 1638a7bb02
2 changed files with 41 additions and 13 deletions

View file

@ -180,6 +180,7 @@ class pp_fn {
local_names m_local_names; local_names m_local_names;
unsigned m_num_steps; unsigned m_num_steps;
name m_aux; name m_aux;
expr_map<unsigned> m_num_occs;
// Configuration // Configuration
unsigned m_indent; unsigned m_indent;
unsigned m_max_depth; unsigned m_max_depth;
@ -193,18 +194,47 @@ class pp_fn {
// Create a scope for local definitions // Create a scope for local definitions
struct mk_scope { struct mk_scope {
pp_fn & m_fn; pp_fn & m_fn;
unsigned m_old_size; unsigned m_old_size;
mk_scope(pp_fn & fn):m_fn(fn), m_old_size(fn.m_local_aliases_defs.size()) { expr_map<unsigned> m_num_occs;
void update_num_occs(expr const & e) {
buffer<expr> todo;
todo.push_back(e);
while (!todo.empty()) {
expr e = todo.back();
todo.pop_back();
unsigned & n = m_num_occs[e];
n++;
// we do not visit other composite expressions such as Let, Lambda and Pi, since they create new scopes
if (n == 1 && is_app(e)) {
for (unsigned i = 0; i < num_args(e); i++)
todo.push_back(arg(e, i));
}
}
}
mk_scope(pp_fn & fn, expr const & e):m_fn(fn), m_old_size(fn.m_local_aliases_defs.size()) {
m_fn.m_local_aliases.push(); m_fn.m_local_aliases.push();
update_num_occs(e);
swap(m_fn.m_num_occs, m_num_occs);
} }
~mk_scope() { ~mk_scope() {
lean_assert(m_old_size <= m_fn.m_local_aliases_defs.size()); lean_assert(m_old_size <= m_fn.m_local_aliases_defs.size());
m_fn.m_local_aliases.pop(); m_fn.m_local_aliases.pop();
m_fn.m_local_aliases_defs.resize(m_old_size); m_fn.m_local_aliases_defs.resize(m_old_size);
swap(m_fn.m_num_occs, m_num_occs);
} }
}; };
bool has_several_occs(expr const & e) const {
auto it = m_num_occs.find(e);
if (it != m_num_occs.end())
return it->second > 1;
else
return false;
}
format nest(unsigned i, format const & f) { return ::lean::nest(i, f); } format nest(unsigned i, format const & f) { return ::lean::nest(i, f); }
typedef std::pair<format, unsigned> result; typedef std::pair<format, unsigned> result;
@ -357,7 +387,6 @@ class pp_fn {
/** \brief Auxiliary function for pretty printing exists formulas */ /** \brief Auxiliary function for pretty printing exists formulas */
result pp_exists(expr const & e, unsigned depth) { result pp_exists(expr const & e, unsigned depth) {
buffer<std::pair<name, expr>> nested; buffer<std::pair<name, expr>> nested;
local_names::mk_scope mk(m_local_names);
expr b = collect_nested_quantifiers(e, nested); expr b = collect_nested_quantifiers(e, nested);
format head; format head;
if (m_unicode) if (m_unicode)
@ -756,7 +785,7 @@ class pp_fn {
if (is_atomic(e)) { if (is_atomic(e)) {
return pp(e, depth + 1, true); return pp(e, depth + 1, true);
} else { } else {
mk_scope s(*this); mk_scope s(*this, e);
result r = pp(e, depth + 1, true); result r = pp(e, depth + 1, true);
if (m_local_aliases_defs.size() == s.m_old_size) { if (m_local_aliases_defs.size() == s.m_old_size) {
if (prec <= get_operator_precedence(e)) if (prec <= get_operator_precedence(e))
@ -864,7 +893,6 @@ class pp_fn {
*/ */
result pp_abstraction_core(expr const & e, unsigned depth, optional<expr> T, result pp_abstraction_core(expr const & e, unsigned depth, optional<expr> T,
std::vector<bool> const * implicit_args = nullptr) { std::vector<bool> const * implicit_args = nullptr) {
local_names::mk_scope mk(m_local_names);
if (is_arrow(e) && !implicit_args) { if (is_arrow(e) && !implicit_args) {
lean_assert(!T); lean_assert(!T);
result p_lhs = pp_arrow_child(abst_domain(e), depth); result p_lhs = pp_arrow_child(abst_domain(e), depth);
@ -987,7 +1015,6 @@ class pp_fn {
} }
result pp_let(expr const & e, unsigned depth) { result pp_let(expr const & e, unsigned depth) {
local_names::mk_scope mk(m_local_names);
buffer<std::tuple<name, optional<expr>, expr>> bindings; buffer<std::tuple<name, optional<expr>, expr>> bindings;
expr body = collect_nested_let(e, bindings); expr body = collect_nested_let(e, bindings);
unsigned r_weight = 2; unsigned r_weight = 2;
@ -1011,7 +1038,7 @@ class pp_fn {
r_weight += p_def.second; r_weight += p_def.second;
} }
} }
result p_body = pp(body, depth+1); result p_body = pp_scoped_child(body, depth+1);
r_weight += p_body.second; r_weight += p_body.second;
r_format += format{line(), g_in_fmt, space(), nest(2 + 1, p_body.first)}; r_format += format{line(), g_in_fmt, space(), nest(2 + 1, p_body.first)};
return mk_pair(group(r_format), r_weight); return mk_pair(group(r_format), r_weight);
@ -1078,7 +1105,7 @@ class pp_fn {
} }
} }
} }
if (m_extra_lets && is_shared(e)) { if (m_extra_lets && has_several_occs(e)) {
auto it = m_local_aliases.find(e); auto it = m_local_aliases.find(e);
if (it != m_local_aliases.end()) if (it != m_local_aliases.end())
return mk_result(format(it->second), 1); return mk_result(format(it->second), 1);
@ -1099,7 +1126,7 @@ class pp_fn {
case expr_kind::MetaVar: r = pp_metavar(e, depth); break; case expr_kind::MetaVar: r = pp_metavar(e, depth); break;
} }
} }
if (!main && m_extra_lets && is_shared(e) && r.second > m_alias_min_weight) { if (!main && m_extra_lets && has_several_occs(e) && r.second > m_alias_min_weight) {
name new_aux = name(m_aux, m_local_aliases_defs.size()+1); name new_aux = name(m_aux, m_local_aliases_defs.size()+1);
m_local_aliases.insert(e, new_aux); m_local_aliases.insert(e, new_aux);
m_local_aliases_defs.emplace_back(new_aux, r.first); m_local_aliases_defs.emplace_back(new_aux, r.first);

View file

@ -25,9 +25,10 @@ trans (congr (congr2 eq
(congr2 Nat::add (trans (congr2 (ite (a > 0) b) (Nat::add_zeror b)) (if_a_a (a > 0) b))))) (congr2 Nat::add (trans (congr2 (ite (a > 0) b) (Nat::add_zeror b)) (if_a_a (a > 0) b)))))
(congr1 10 (congr2 Nat::add (if_a_a (a > 0) b)))) (congr1 10 (congr2 Nat::add (if_a_a (a > 0) b))))
(eq_id (b + 10)) (eq_id (b + 10))
let κ::1 := congr2 (λ x : , eq ((λ x : , x + 10) x)) trans (congr (congr2 (λ x : , eq ((λ x : , x + 10) x))
(trans (congr2 (ite (a > 0) b) (Nat::add_zeror b)) (if_a_a (a > 0) b)) (trans (congr2 (ite (a > 0) b) (Nat::add_zeror b)) (if_a_a (a > 0) b)))
in trans (congr κ::1 (congr2 (λ x : , x + 10) (if_a_a (a > 0) b))) (eq_id (b + 10)) (congr2 (λ x : , x + 10) (if_a_a (a > 0) b)))
(eq_id (b + 10))
a * a + (a * b + (b * a + b * b)) a * a + (a * b + (b * a + b * b))
→ ⊥ refl ( → ⊥) → ⊥ refl ( → ⊥)
refl () refl ()