Add max_shared: function for computing maximally shared expressions.
This commit is contained in:
parent
aed8a07c1b
commit
2a9d0de57b
6 changed files with 153 additions and 17 deletions
|
@ -115,9 +115,9 @@ void expr_cell::dealloc() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace expr_eq {
|
namespace expr_eq_ns {
|
||||||
static thread_local expr_cell_pair_set g_eq_visited;
|
static thread_local expr_cell_pair_set g_eq_visited;
|
||||||
bool eq(expr const & a, expr const & b) {
|
bool apply(expr const & a, expr const & b) {
|
||||||
if (eqp(a, b)) return true;
|
if (eqp(a, b)) return true;
|
||||||
if (a.hash() != b.hash()) return false;
|
if (a.hash() != b.hash()) return false;
|
||||||
if (a.kind() != b.kind()) return false;
|
if (a.kind() != b.kind()) return false;
|
||||||
|
@ -136,14 +136,14 @@ bool eq(expr const & a, expr const & b) {
|
||||||
if (get_num_args(a) != get_num_args(b))
|
if (get_num_args(a) != get_num_args(b))
|
||||||
return false;
|
return false;
|
||||||
for (unsigned i = 0; i < get_num_args(a); i++)
|
for (unsigned i = 0; i < get_num_args(a); i++)
|
||||||
if (!eq(get_arg(a, i), get_arg(b, i)))
|
if (!apply(get_arg(a, i), get_arg(b, i)))
|
||||||
return false;
|
return false;
|
||||||
return true;
|
return true;
|
||||||
case expr_kind::Lambda:
|
case expr_kind::Lambda:
|
||||||
case expr_kind::Pi:
|
case expr_kind::Pi:
|
||||||
// Lambda and Pi
|
// Lambda and Pi
|
||||||
// Remark: we ignore get_abs_name because we want alpha-equivalence
|
// Remark: we ignore get_abs_name because we want alpha-equivalence
|
||||||
return eq(get_abs_type(a), get_abs_type(b)) && eq(get_abs_expr(a), get_abs_expr(b));
|
return apply(get_abs_type(a), get_abs_type(b)) && apply(get_abs_expr(a), get_abs_expr(b));
|
||||||
case expr_kind::Prop: lean_unreachable(); return true;
|
case expr_kind::Prop: lean_unreachable(); return true;
|
||||||
case expr_kind::Type:
|
case expr_kind::Type:
|
||||||
if (get_ty_num_vars(a) != get_ty_num_vars(b))
|
if (get_ty_num_vars(a) != get_ty_num_vars(b))
|
||||||
|
@ -162,8 +162,8 @@ bool eq(expr const & a, expr const & b) {
|
||||||
}
|
}
|
||||||
} // namespace expr_eq
|
} // namespace expr_eq
|
||||||
bool operator==(expr const & a, expr const & b) {
|
bool operator==(expr const & a, expr const & b) {
|
||||||
expr_eq::g_eq_visited.clear();
|
expr_eq_ns::g_eq_visited.clear();
|
||||||
return expr_eq::eq(a, b);
|
return expr_eq_ns::apply(a, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Low-level pretty printer
|
// Low-level pretty printer
|
||||||
|
|
26
src/kernel/expr_functors.h
Normal file
26
src/kernel/expr_functors.h
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||||
|
Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
|
Author: Leonardo de Moura
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include "expr.h"
|
||||||
|
|
||||||
|
namespace lean {
|
||||||
|
|
||||||
|
struct expr_hash {
|
||||||
|
unsigned operator()(expr const & e) const { return e.hash(); }
|
||||||
|
};
|
||||||
|
struct expr_eqp {
|
||||||
|
bool operator()(expr const & e1, expr const & e2) const { return eqp(e1, e2); }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct expr_cell_hash {
|
||||||
|
unsigned operator()(expr_cell * e) const { return e->hash(); }
|
||||||
|
};
|
||||||
|
struct expr_cell_eqp {
|
||||||
|
bool operator()(expr_cell * e1, expr_cell * e2) const { return e1 == e2; }
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
20
src/kernel/expr_map.h
Normal file
20
src/kernel/expr_map.h
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||||
|
Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
|
Author: Leonardo de Moura
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <unordered_map>
|
||||||
|
#include "expr.h"
|
||||||
|
#include "expr_functors.h"
|
||||||
|
|
||||||
|
namespace lean {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
using expr_map = typename std::unordered_map<expr, T, expr_hash, expr_eqp>;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
using expr_cell_map = typename std::unordered_map<expr_cell *, T, expr_cell_hash, expr_cell_eqp>;
|
||||||
|
|
||||||
|
};
|
70
src/kernel/expr_max_shared.cpp
Normal file
70
src/kernel/expr_max_shared.cpp
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||||
|
Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
|
Author: Leonardo de Moura
|
||||||
|
*/
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
#include "expr.h"
|
||||||
|
#include "expr_functors.h"
|
||||||
|
|
||||||
|
namespace lean {
|
||||||
|
namespace max_shared_ns {
|
||||||
|
struct expr_struct_eq { unsigned operator()(expr const & e1, expr const & e2) const { return e1 == e2; }};
|
||||||
|
typedef typename std::unordered_set<expr, expr_hash, expr_struct_eq> expr_cache;
|
||||||
|
static thread_local expr_cache g_cache;
|
||||||
|
expr apply(expr const & a) {
|
||||||
|
auto r = g_cache.find(a);
|
||||||
|
if (r != g_cache.end())
|
||||||
|
return *r;
|
||||||
|
switch (a.kind()) {
|
||||||
|
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||||
|
g_cache.insert(a);
|
||||||
|
return a;
|
||||||
|
case expr_kind::App: {
|
||||||
|
std::vector<expr> new_args;
|
||||||
|
bool modified = false;
|
||||||
|
for (expr const & old_arg : app_args(a)) {
|
||||||
|
new_args.push_back(apply(old_arg));
|
||||||
|
if (!eqp(old_arg, new_args.back()))
|
||||||
|
modified = true;
|
||||||
|
}
|
||||||
|
if (!modified) {
|
||||||
|
g_cache.insert(a);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
expr r = app(static_cast<unsigned>(new_args.size()), new_args.data());
|
||||||
|
g_cache.insert(r);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case expr_kind::Lambda:
|
||||||
|
case expr_kind::Pi: {
|
||||||
|
expr const & old_t = get_abs_type(a);
|
||||||
|
expr const & old_e = get_abs_expr(a);
|
||||||
|
expr t = apply(old_t);
|
||||||
|
expr e = apply(old_e);
|
||||||
|
if (!eqp(t, old_t) || !eqp(e, old_e)) {
|
||||||
|
name const & n = get_abs_name(a);
|
||||||
|
expr r = is_pi(a) ? pi(n, t, e) : lambda(n, t, e);
|
||||||
|
g_cache.insert(r);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
g_cache.insert(a);
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
lean_unreachable();
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
} // namespace max_shared_ns
|
||||||
|
expr max_shared(expr const & a) {
|
||||||
|
max_shared_ns::g_cache.clear();
|
||||||
|
expr r = max_shared_ns::apply(a);
|
||||||
|
max_shared_ns::g_cache.clear();
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
} // namespace lean
|
15
src/kernel/expr_max_shared.h
Normal file
15
src/kernel/expr_max_shared.h
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
/*
|
||||||
|
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
|
||||||
|
Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
|
Author: Leonardo de Moura
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include "expr.h"
|
||||||
|
|
||||||
|
namespace lean {
|
||||||
|
/**
|
||||||
|
\brief Return a structurally identical expression that is maximally shared.
|
||||||
|
*/
|
||||||
|
expr max_shared(expr const & a);
|
||||||
|
}
|
|
@ -159,23 +159,28 @@ unsigned count(expr const & a) {
|
||||||
|
|
||||||
void tst5() {
|
void tst5() {
|
||||||
expr f = constant(name("f"));
|
expr f = constant(name("f"));
|
||||||
expr r1 = mk_redundant_dag(f, 1);
|
{
|
||||||
expr r2 = max_shared(r1);
|
expr r1 = mk_redundant_dag(f, 5);
|
||||||
std::cout << "r1: " << r1 << "\n";
|
expr r2 = max_shared(r1);
|
||||||
std::cout << "r2: " << r1 << "\n";
|
std::cout << "count(r1): " << count(r1) << "\n";
|
||||||
std::cout << "count(r1): " << count(r1) << "\n";
|
std::cout << "count(r2): " << count(r2) << "\n";
|
||||||
std::cout << "count(r2): " << count(r2) << "\n";
|
lean_assert(r1 == r2);
|
||||||
lean_assert(r1 == r2);
|
}
|
||||||
|
{
|
||||||
|
expr r1 = mk_redundant_dag(f, 16);
|
||||||
|
expr r2 = max_shared(r1);
|
||||||
|
lean_assert(r1 == r2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// continue_on_violation(true);
|
// continue_on_violation(true);
|
||||||
std::cout << "sizeof(expr): " << sizeof(expr) << "\n";
|
std::cout << "sizeof(expr): " << sizeof(expr) << "\n";
|
||||||
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
|
std::cout << "sizeof(expr_app): " << sizeof(expr_app) << "\n";
|
||||||
// tst1();
|
tst1();
|
||||||
// tst2();
|
tst2();
|
||||||
// tst3();
|
tst3();
|
||||||
// tst4();
|
tst4();
|
||||||
tst5();
|
tst5();
|
||||||
std::cout << "done" << "\n";
|
std::cout << "done" << "\n";
|
||||||
return has_violations() ? 1 : 0;
|
return has_violations() ? 1 : 0;
|
||||||
|
|
Loading…
Add table
Reference in a new issue