Fix normalize
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
14c899e7ca
commit
f7138b6ecf
3 changed files with 169 additions and 63 deletions
|
@ -210,3 +210,5 @@ expr copy(expr const & a) {
|
|||
return expr();
|
||||
}
|
||||
}
|
||||
|
||||
void pp(lean::expr const & e) { std::cout << e << std::endl; }
|
||||
|
|
|
@ -9,123 +9,138 @@ Author: Leonardo de Moura
|
|||
#include "list.h"
|
||||
#include "buffer.h"
|
||||
#include "trace.h"
|
||||
#include "exception.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
class value;
|
||||
typedef list<value> context;
|
||||
enum class value_kind { Expr, Closure, BoundedVar };
|
||||
class value {
|
||||
expr m_expr;
|
||||
context m_ctx;
|
||||
unsigned m_kind:2;
|
||||
unsigned m_bvar:30;
|
||||
expr m_expr;
|
||||
context m_ctx;
|
||||
public:
|
||||
value() {}
|
||||
explicit value(expr const & e):m_expr(e) {}
|
||||
value(expr const & e, context const & c):m_expr(e), m_ctx(c) {}
|
||||
explicit value(expr const & e):m_kind(static_cast<unsigned>(value_kind::Expr)), m_expr(e) {}
|
||||
explicit value(unsigned k):m_kind(static_cast<unsigned>(value_kind::BoundedVar)), m_bvar(k) {}
|
||||
value(expr const & e, context const & c):m_kind(static_cast<unsigned>(value_kind::Closure)), m_expr(e), m_ctx(c) { lean_assert(is_lambda(e)); }
|
||||
|
||||
expr const & get_expr() const { return m_expr; }
|
||||
context const & get_ctx() const { return m_ctx; }
|
||||
value_kind kind() const { return static_cast<value_kind>(m_kind); }
|
||||
|
||||
bool is_expr() const { return kind() == value_kind::Expr; }
|
||||
bool is_closure() const { return kind() == value_kind::Closure; }
|
||||
bool is_bounded_var() const { return kind() == value_kind::BoundedVar; }
|
||||
|
||||
expr const & get_expr() const { lean_assert(is_expr() || is_closure()); return m_expr; }
|
||||
context const & get_ctx() const { lean_assert(is_closure()); return m_ctx; }
|
||||
unsigned get_var_idx() const { lean_assert(is_bounded_var()); return m_bvar; }
|
||||
};
|
||||
|
||||
value_kind kind(value const & v) { return v.kind(); }
|
||||
expr const & to_expr(value const & v) { return v.get_expr(); }
|
||||
context const & ctx_of(value const & v) { return v.get_ctx(); }
|
||||
context const & ctx_of(value const & v) { return v.get_ctx(); }
|
||||
unsigned to_bvar(value const & v) { return v.get_var_idx(); }
|
||||
|
||||
bool lookup(context const & c, unsigned i, value & r) {
|
||||
value lookup(context const & c, unsigned i) {
|
||||
context const * curr = &c;
|
||||
while (!is_nil(*curr)) {
|
||||
if (i == 0) {
|
||||
r = head(*curr);
|
||||
return !is_null(to_expr(r));
|
||||
}
|
||||
if (i == 0)
|
||||
return head(*curr);
|
||||
--i;
|
||||
curr = &tail(*curr);
|
||||
}
|
||||
return false;
|
||||
throw exception("unknown free variable");
|
||||
}
|
||||
|
||||
context extend(context const & c, value const & v = value()) { return cons(v, c); }
|
||||
context extend(context const & c, value const & v) { return cons(v, c); }
|
||||
|
||||
value normalize(expr const & a, context const & c);
|
||||
expr expand(value const & v);
|
||||
value normalize(expr const & a, context const & c, unsigned k);
|
||||
expr reify(value const & v, unsigned k);
|
||||
|
||||
expr expand(expr const & a, context const & c) {
|
||||
if (is_lambda(a)) {
|
||||
expr new_t = to_expr(normalize(abst_type(a), c));
|
||||
expr new_b = expand(normalize(abst_body(a), extend(c)));
|
||||
if (is_app(new_b)) {
|
||||
// (lambda (x:T) (app f ... (var 0)))
|
||||
// check eta-rule applicability
|
||||
unsigned n = num_args(new_b);
|
||||
lean_assert(n >= 2);
|
||||
expr const & last_arg = arg(new_b, n - 1);
|
||||
if (is_var(last_arg) && var_idx(last_arg) == 0) {
|
||||
// FIXME: I have to shift the variables in new_b
|
||||
if (n == 2)
|
||||
return arg(new_b, 0);
|
||||
else
|
||||
return app(n - 1, begin_args(new_b));
|
||||
}
|
||||
expr reify_closure(expr const & a, context const & c, unsigned k) {
|
||||
lean_assert(is_lambda(a));
|
||||
expr new_t = reify(normalize(abst_type(a), c, k), k);
|
||||
expr new_b = reify(normalize(abst_body(a), extend(c, value(k)), k+1), k+1);
|
||||
return lambda(abst_name(a), new_t, new_b);
|
||||
#if 0
|
||||
// TODO: ETA-reduction
|
||||
if (is_app(new_b)) {
|
||||
// (lambda (x:T) (app f ... (var 0)))
|
||||
// check eta-rule applicability
|
||||
unsigned n = num_args(new_b);
|
||||
lean_assert(n >= 2);
|
||||
expr const & last_arg = arg(new_b, n - 1);
|
||||
if (is_var(last_arg) && var_idx(last_arg) == 0) {
|
||||
if (n == 2)
|
||||
return arg(new_b, 0);
|
||||
else
|
||||
return app(n - 1, begin_args(new_b));
|
||||
}
|
||||
return lambda(abst_name(a), new_t, new_b);
|
||||
}
|
||||
else {
|
||||
return a;
|
||||
return lambda(abst_name(a), new_t, new_b);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
expr reify(value const & v, unsigned k) {
|
||||
lean_trace("normalize", tout << "Reify kind: " << static_cast<unsigned>(v.kind()) << "\n";
|
||||
if (v.is_bounded_var()) tout << "#" << to_bvar(v); else tout << to_expr(v); tout << "\n";);
|
||||
switch (v.kind()) {
|
||||
case value_kind::Expr: return to_expr(v);
|
||||
case value_kind::BoundedVar: return var(k - to_bvar(v) - 1);
|
||||
case value_kind::Closure: return reify_closure(to_expr(v), ctx_of(v), k);
|
||||
}
|
||||
lean_unreachable();
|
||||
return expr();
|
||||
}
|
||||
|
||||
expr expand(value const & v) {
|
||||
return expand(to_expr(v), ctx_of(v));
|
||||
}
|
||||
|
||||
value normalize(expr const & a, context const & c) {
|
||||
lean_trace("normalize", tout << a << "\n";);
|
||||
value normalize(expr const & a, context const & c, unsigned k) {
|
||||
lean_trace("normalize", tout << "Normalize, k: " << k << "\n" << a << "\n";);
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var: {
|
||||
value r;
|
||||
if (lookup(c, var_idx(a), r))
|
||||
return r;
|
||||
else
|
||||
return value(a);
|
||||
}
|
||||
case expr_kind::Var:
|
||||
return lookup(c, var_idx(a));
|
||||
case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||
return value(a);
|
||||
case expr_kind::App: {
|
||||
value f = normalize(arg(a, 0), c);
|
||||
value f = normalize(arg(a, 0), c, k);
|
||||
unsigned i = 1;
|
||||
unsigned n = num_args(a);
|
||||
while (true) {
|
||||
expr const & fv = to_expr(f);
|
||||
lean_trace("normalize", tout << "fv: " << fv << "\ni: " << i << "\n";);
|
||||
switch (fv.kind()) {
|
||||
case expr_kind::Lambda: {
|
||||
if (f.is_closure()) {
|
||||
// beta reduction
|
||||
value a_v = normalize(arg(a, i), c);
|
||||
f = normalize(abst_body(fv), extend(ctx_of(f), a_v));
|
||||
expr const & fv = to_expr(f);
|
||||
lean_trace("normalize", tout << "beta reduction...\n" << fv << "\n";);
|
||||
context new_c = extend(ctx_of(f), normalize(arg(a, i), c, k));
|
||||
f = normalize(abst_body(fv), new_c, k);
|
||||
if (i == n - 1)
|
||||
return f;
|
||||
i++;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
else {
|
||||
// TODO: support for interpreted symbols
|
||||
buffer<expr> new_args;
|
||||
new_args.push_back(fv);
|
||||
new_args.push_back(reify(f, k));
|
||||
for (; i < n; i++)
|
||||
new_args.push_back(expand(normalize(arg(a, i), c)));
|
||||
new_args.push_back(reify(normalize(arg(a, i), c, k), k));
|
||||
return value(app(new_args.size(), new_args.data()));
|
||||
}}
|
||||
}
|
||||
}
|
||||
}
|
||||
case expr_kind::Lambda:
|
||||
return value(a, c);
|
||||
case expr_kind::Pi: {
|
||||
expr new_t = to_expr(normalize(abst_type(a), c));
|
||||
expr new_b = to_expr(normalize(abst_body(a), extend(c)));
|
||||
expr new_t = reify(normalize(abst_type(a), c, k), k);
|
||||
expr new_b = reify(normalize(abst_body(a), extend(c, value(k)), k+1), k+1);
|
||||
return value(pi(abst_name(a), new_t, new_b));
|
||||
}}
|
||||
lean_unreachable();
|
||||
return value(a);
|
||||
}
|
||||
|
||||
expr normalize(expr const & e) {
|
||||
return expand(normalize(e, context()));
|
||||
return reify(normalize(e, context(), 0), 0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,15 +4,100 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include "normalize.h"
|
||||
#include "trace.h"
|
||||
#include "test.h"
|
||||
#include "sets.h"
|
||||
using namespace lean;
|
||||
|
||||
static void eval(expr const & e) {
|
||||
std::cout << e << " --> " << normalize(e) << "\n";
|
||||
}
|
||||
|
||||
static expr t() { return constant("t"); }
|
||||
static expr lam(expr const & e) { return lambda("_", t(), e); }
|
||||
static expr lam(expr const & t, expr const & e) { return lambda("_", t, e); }
|
||||
static expr v(unsigned i) { return var(i); }
|
||||
static expr arrow(expr const & d, expr const & r) { return pi("_", d, r); }
|
||||
static expr zero() {
|
||||
// fun (t : T) (s : t -> t) (z : t) z
|
||||
return lam(t(), lam(arrow(v(0), v(0)), lam(v(1), v(0))));
|
||||
}
|
||||
static expr one() {
|
||||
// fun (t : T) (s : t -> t) s
|
||||
return lam(t(), lam(arrow(v(0), v(0)), v(0)));
|
||||
}
|
||||
static expr num() { return constant("num"); }
|
||||
static expr plus() {
|
||||
// fun (m n : numeral) (A : Type 0) (f : A -> A) (x : A) => m A f (n A f x).
|
||||
expr x = v(0), f = v(1), A = v(2), n = v(3), m = v(4);
|
||||
expr body = m(A, f, n(A, f, x));
|
||||
return lam(num(), lam(num(), lam(t(), lam(arrow(v(0), v(0)), lam(v(1), body)))));
|
||||
}
|
||||
static expr two() { return app(plus(), one(), one()); }
|
||||
static expr four() { return app(plus(), two(), two()); }
|
||||
static expr times() {
|
||||
// fun (m n : numeral) (A : Type 0) (f : A -> A) (x : A) => m A (n A f) x.
|
||||
expr x = v(0), f = v(1), A = v(2), n = v(3), m = v(4);
|
||||
expr body = m(A, n(A, f), x);
|
||||
return lam(num(), lam(num(), lam(t(), lam(arrow(v(0), v(0)), lam(v(1), body)))));
|
||||
}
|
||||
static expr power() {
|
||||
// fun (m n : numeral) (A : Type 0) => m (A -> A) (n A).
|
||||
expr A = v(0), n = v(1), m = v(2);
|
||||
expr body = n(arrow(A, A), m(A));
|
||||
return lam(num(), lam(num(), lam(arrow(v(0), v(0)), body)));
|
||||
}
|
||||
|
||||
unsigned count_core(expr const & a, expr_set & s) {
|
||||
if (s.find(a) != s.end())
|
||||
return 0;
|
||||
s.insert(a);
|
||||
switch (a.kind()) {
|
||||
case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral:
|
||||
return 1;
|
||||
case expr_kind::App:
|
||||
return std::accumulate(begin_args(a), end_args(a), 1,
|
||||
[&](unsigned sum, expr const & arg){ return sum + count_core(arg, s); });
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
return count_core(abst_type(a), s) + count_core(abst_body(a), s) + 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned count(expr const & a) {
|
||||
expr_set s;
|
||||
return count_core(a, s);
|
||||
}
|
||||
|
||||
static void tst_church_numbers() {
|
||||
expr N = constant("N");
|
||||
expr z = constant("z");
|
||||
expr s = constant("s");
|
||||
std::cout << normalize(app(zero(), N, s, z)) << "\n";
|
||||
std::cout << normalize(app(one(), N, s, z)) << "\n";
|
||||
std::cout << normalize(app(two(), N, s, z)) << "\n";
|
||||
std::cout << normalize(app(four(), N, s, z)) << "\n";
|
||||
std::cout << count(normalize(app(four(), N, s, z))) << "\n";
|
||||
lean_assert(count(normalize(app(four(), N, s, z))) == 4 + 2);
|
||||
std::cout << normalize(app(app(times(), four(), four()), N, s, z)) << "\n";
|
||||
std::cout << normalize(app(app(power(), two(), four()), N, s, z)) << "\n";
|
||||
lean_assert(count(normalize(app(app(power(), two(), four()), N, s, z))) == 16 + 2);
|
||||
std::cout << normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z)) << "\n";
|
||||
std::cout << count(normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z))) << "\n";
|
||||
std::cout << count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z))) << "\n";
|
||||
lean_assert(count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z))) == 64 + 2);
|
||||
expr sixty_four_k = normalize(app(app(power(), two(), app(power(), two(), four())), N, s, z));
|
||||
std::cout << count(sixty_four_k) << "\n";
|
||||
lean_assert(count(sixty_four_k) == 65536 + 2);
|
||||
expr three = app(plus(), two(), one());
|
||||
lean_assert(count(normalize(app(app(power(), three, three), N, s, z))) == 27 + 2);
|
||||
// expr big = normalize(app(app(power(), two(), app(times(), app(plus(), four(), one()), four())), N, s, z));
|
||||
// std::cout << count(big) << "\n";
|
||||
std::cout << normalize(lam(lam(app(app(times(), four(), four()), N, var(0), z)))) << "\n";
|
||||
}
|
||||
|
||||
static void tst1() {
|
||||
expr f = constant("f");
|
||||
expr a = constant("a");
|
||||
|
@ -29,11 +114,15 @@ static void tst1() {
|
|||
app(var(0), b)),
|
||||
lambda("g", t, f(var(1))))),
|
||||
a));
|
||||
expr l01 = lam(v(0)(v(1)));
|
||||
expr l12 = lam(lam(v(1)(v(2))));
|
||||
eval(lam(l12(l01)));
|
||||
lean_assert(normalize(lam(l12(l01))) == lam(lam(v(1)(v(1)))));
|
||||
}
|
||||
|
||||
int main() {
|
||||
enable_trace("normalize");
|
||||
continue_on_violation(true);
|
||||
tst1();
|
||||
tst_church_numbers();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue