From f7138b6ecfb81dd46867d648b68b68f8b0f1b9c9 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 25 Jul 2013 19:13:45 -0700 Subject: [PATCH] Fix normalize Signed-off-by: Leonardo de Moura --- src/kernel/expr.cpp | 2 + src/kernel/normalize.cpp | 139 ++++++++++++++++++--------------- src/tests/kernel/normalize.cpp | 91 ++++++++++++++++++++- 3 files changed, 169 insertions(+), 63 deletions(-) diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 5ae98fb74..d73ddd718 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -210,3 +210,5 @@ expr copy(expr const & a) { return expr(); } } + +void pp(lean::expr const & e) { std::cout << e << std::endl; } diff --git a/src/kernel/normalize.cpp b/src/kernel/normalize.cpp index e94a0f9ea..21de4b5d0 100644 --- a/src/kernel/normalize.cpp +++ b/src/kernel/normalize.cpp @@ -9,123 +9,138 @@ Author: Leonardo de Moura #include "list.h" #include "buffer.h" #include "trace.h" +#include "exception.h" namespace lean { class value; typedef list context; +enum class value_kind { Expr, Closure, BoundedVar }; class value { - expr m_expr; - context m_ctx; + unsigned m_kind:2; + unsigned m_bvar:30; + expr m_expr; + context m_ctx; public: value() {} - explicit value(expr const & e):m_expr(e) {} - value(expr const & e, context const & c):m_expr(e), m_ctx(c) {} + explicit value(expr const & e):m_kind(static_cast(value_kind::Expr)), m_expr(e) {} + explicit value(unsigned k):m_kind(static_cast(value_kind::BoundedVar)), m_bvar(k) {} + value(expr const & e, context const & c):m_kind(static_cast(value_kind::Closure)), m_expr(e), m_ctx(c) { lean_assert(is_lambda(e)); } - expr const & get_expr() const { return m_expr; } - context const & get_ctx() const { return m_ctx; } + value_kind kind() const { return static_cast(m_kind); } + + bool is_expr() const { return kind() == value_kind::Expr; } + bool is_closure() const { return kind() == value_kind::Closure; } + bool is_bounded_var() const { return kind() == value_kind::BoundedVar; } + + expr const & get_expr() const { lean_assert(is_expr() || is_closure()); return m_expr; } + context const & get_ctx() const { lean_assert(is_closure()); return m_ctx; } + unsigned get_var_idx() const { lean_assert(is_bounded_var()); return m_bvar; } }; +value_kind kind(value const & v) { return v.kind(); } expr const & to_expr(value const & v) { return v.get_expr(); } -context const & ctx_of(value const & v) { return v.get_ctx(); } +context const & ctx_of(value const & v) { return v.get_ctx(); } +unsigned to_bvar(value const & v) { return v.get_var_idx(); } -bool lookup(context const & c, unsigned i, value & r) { +value lookup(context const & c, unsigned i) { context const * curr = &c; while (!is_nil(*curr)) { - if (i == 0) { - r = head(*curr); - return !is_null(to_expr(r)); - } + if (i == 0) + return head(*curr); --i; curr = &tail(*curr); } - return false; + throw exception("unknown free variable"); } -context extend(context const & c, value const & v = value()) { return cons(v, c); } +context extend(context const & c, value const & v) { return cons(v, c); } -value normalize(expr const & a, context const & c); -expr expand(value const & v); +value normalize(expr const & a, context const & c, unsigned k); +expr reify(value const & v, unsigned k); -expr expand(expr const & a, context const & c) { - if (is_lambda(a)) { - expr new_t = to_expr(normalize(abst_type(a), c)); - expr new_b = expand(normalize(abst_body(a), extend(c))); - if (is_app(new_b)) { - // (lambda (x:T) (app f ... (var 0))) - // check eta-rule applicability - unsigned n = num_args(new_b); - lean_assert(n >= 2); - expr const & last_arg = arg(new_b, n - 1); - if (is_var(last_arg) && var_idx(last_arg) == 0) { - // FIXME: I have to shift the variables in new_b - if (n == 2) - return arg(new_b, 0); - else - return app(n - 1, begin_args(new_b)); - } +expr reify_closure(expr const & a, context const & c, unsigned k) { + lean_assert(is_lambda(a)); + expr new_t = reify(normalize(abst_type(a), c, k), k); + expr new_b = reify(normalize(abst_body(a), extend(c, value(k)), k+1), k+1); + return lambda(abst_name(a), new_t, new_b); +#if 0 + // TODO: ETA-reduction + if (is_app(new_b)) { + // (lambda (x:T) (app f ... (var 0))) + // check eta-rule applicability + unsigned n = num_args(new_b); + lean_assert(n >= 2); + expr const & last_arg = arg(new_b, n - 1); + if (is_var(last_arg) && var_idx(last_arg) == 0) { + if (n == 2) + return arg(new_b, 0); + else + return app(n - 1, begin_args(new_b)); } return lambda(abst_name(a), new_t, new_b); } else { - return a; + return lambda(abst_name(a), new_t, new_b); } +#endif +} +expr reify(value const & v, unsigned k) { + lean_trace("normalize", tout << "Reify kind: " << static_cast(v.kind()) << "\n"; + if (v.is_bounded_var()) tout << "#" << to_bvar(v); else tout << to_expr(v); tout << "\n";); + switch (v.kind()) { + case value_kind::Expr: return to_expr(v); + case value_kind::BoundedVar: return var(k - to_bvar(v) - 1); + case value_kind::Closure: return reify_closure(to_expr(v), ctx_of(v), k); + } + lean_unreachable(); + return expr(); } -expr expand(value const & v) { - return expand(to_expr(v), ctx_of(v)); -} - -value normalize(expr const & a, context const & c) { - lean_trace("normalize", tout << a << "\n";); +value normalize(expr const & a, context const & c, unsigned k) { + lean_trace("normalize", tout << "Normalize, k: " << k << "\n" << a << "\n";); switch (a.kind()) { - case expr_kind::Var: { - value r; - if (lookup(c, var_idx(a), r)) - return r; - else - return value(a); - } + case expr_kind::Var: + return lookup(c, var_idx(a)); case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: return value(a); case expr_kind::App: { - value f = normalize(arg(a, 0), c); + value f = normalize(arg(a, 0), c, k); unsigned i = 1; unsigned n = num_args(a); while (true) { - expr const & fv = to_expr(f); - lean_trace("normalize", tout << "fv: " << fv << "\ni: " << i << "\n";); - switch (fv.kind()) { - case expr_kind::Lambda: { + if (f.is_closure()) { // beta reduction - value a_v = normalize(arg(a, i), c); - f = normalize(abst_body(fv), extend(ctx_of(f), a_v)); + expr const & fv = to_expr(f); + lean_trace("normalize", tout << "beta reduction...\n" << fv << "\n";); + context new_c = extend(ctx_of(f), normalize(arg(a, i), c, k)); + f = normalize(abst_body(fv), new_c, k); if (i == n - 1) return f; i++; - break; } - default: { + else { // TODO: support for interpreted symbols buffer new_args; - new_args.push_back(fv); + new_args.push_back(reify(f, k)); for (; i < n; i++) - new_args.push_back(expand(normalize(arg(a, i), c))); + new_args.push_back(reify(normalize(arg(a, i), c, k), k)); return value(app(new_args.size(), new_args.data())); - }} + } } } case expr_kind::Lambda: return value(a, c); case expr_kind::Pi: { - expr new_t = to_expr(normalize(abst_type(a), c)); - expr new_b = to_expr(normalize(abst_body(a), extend(c))); + expr new_t = reify(normalize(abst_type(a), c, k), k); + expr new_b = reify(normalize(abst_body(a), extend(c, value(k)), k+1), k+1); return value(pi(abst_name(a), new_t, new_b)); }} + lean_unreachable(); return value(a); } expr normalize(expr const & e) { - return expand(normalize(e, context())); + return reify(normalize(e, context(), 0), 0); } } diff --git a/src/tests/kernel/normalize.cpp b/src/tests/kernel/normalize.cpp index dbe23c643..c720647f6 100644 --- a/src/tests/kernel/normalize.cpp +++ b/src/tests/kernel/normalize.cpp @@ -4,15 +4,100 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include "normalize.h" #include "trace.h" #include "test.h" +#include "sets.h" using namespace lean; static void eval(expr const & e) { std::cout << e << " --> " << normalize(e) << "\n"; } +static expr t() { return constant("t"); } +static expr lam(expr const & e) { return lambda("_", t(), e); } +static expr lam(expr const & t, expr const & e) { return lambda("_", t, e); } +static expr v(unsigned i) { return var(i); } +static expr arrow(expr const & d, expr const & r) { return pi("_", d, r); } +static expr zero() { + // fun (t : T) (s : t -> t) (z : t) z + return lam(t(), lam(arrow(v(0), v(0)), lam(v(1), v(0)))); +} +static expr one() { + // fun (t : T) (s : t -> t) s + return lam(t(), lam(arrow(v(0), v(0)), v(0))); +} +static expr num() { return constant("num"); } +static expr plus() { + // fun (m n : numeral) (A : Type 0) (f : A -> A) (x : A) => m A f (n A f x). + expr x = v(0), f = v(1), A = v(2), n = v(3), m = v(4); + expr body = m(A, f, n(A, f, x)); + return lam(num(), lam(num(), lam(t(), lam(arrow(v(0), v(0)), lam(v(1), body))))); +} +static expr two() { return app(plus(), one(), one()); } +static expr four() { return app(plus(), two(), two()); } +static expr times() { + // fun (m n : numeral) (A : Type 0) (f : A -> A) (x : A) => m A (n A f) x. + expr x = v(0), f = v(1), A = v(2), n = v(3), m = v(4); + expr body = m(A, n(A, f), x); + return lam(num(), lam(num(), lam(t(), lam(arrow(v(0), v(0)), lam(v(1), body))))); +} +static expr power() { + // fun (m n : numeral) (A : Type 0) => m (A -> A) (n A). + expr A = v(0), n = v(1), m = v(2); + expr body = n(arrow(A, A), m(A)); + return lam(num(), lam(num(), lam(arrow(v(0), v(0)), body))); +} + +unsigned count_core(expr const & a, expr_set & s) { + if (s.find(a) != s.end()) + return 0; + s.insert(a); + switch (a.kind()) { + case expr_kind::Var: case expr_kind::Constant: case expr_kind::Prop: case expr_kind::Type: case expr_kind::Numeral: + return 1; + case expr_kind::App: + return std::accumulate(begin_args(a), end_args(a), 1, + [&](unsigned sum, expr const & arg){ return sum + count_core(arg, s); }); + case expr_kind::Lambda: case expr_kind::Pi: + return count_core(abst_type(a), s) + count_core(abst_body(a), s) + 1; + } + return 0; +} + +unsigned count(expr const & a) { + expr_set s; + return count_core(a, s); +} + +static void tst_church_numbers() { + expr N = constant("N"); + expr z = constant("z"); + expr s = constant("s"); + std::cout << normalize(app(zero(), N, s, z)) << "\n"; + std::cout << normalize(app(one(), N, s, z)) << "\n"; + std::cout << normalize(app(two(), N, s, z)) << "\n"; + std::cout << normalize(app(four(), N, s, z)) << "\n"; + std::cout << count(normalize(app(four(), N, s, z))) << "\n"; + lean_assert(count(normalize(app(four(), N, s, z))) == 4 + 2); + std::cout << normalize(app(app(times(), four(), four()), N, s, z)) << "\n"; + std::cout << normalize(app(app(power(), two(), four()), N, s, z)) << "\n"; + lean_assert(count(normalize(app(app(power(), two(), four()), N, s, z))) == 16 + 2); + std::cout << normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z)) << "\n"; + std::cout << count(normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z))) << "\n"; + std::cout << count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z))) << "\n"; + lean_assert(count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z))) == 64 + 2); + expr sixty_four_k = normalize(app(app(power(), two(), app(power(), two(), four())), N, s, z)); + std::cout << count(sixty_four_k) << "\n"; + lean_assert(count(sixty_four_k) == 65536 + 2); + expr three = app(plus(), two(), one()); + lean_assert(count(normalize(app(app(power(), three, three), N, s, z))) == 27 + 2); + // expr big = normalize(app(app(power(), two(), app(times(), app(plus(), four(), one()), four())), N, s, z)); + // std::cout << count(big) << "\n"; + std::cout << normalize(lam(lam(app(app(times(), four(), four()), N, var(0), z)))) << "\n"; +} + static void tst1() { expr f = constant("f"); expr a = constant("a"); @@ -29,11 +114,15 @@ static void tst1() { app(var(0), b)), lambda("g", t, f(var(1))))), a)); + expr l01 = lam(v(0)(v(1))); + expr l12 = lam(lam(v(1)(v(2)))); + eval(lam(l12(l01))); + lean_assert(normalize(lam(l12(l01))) == lam(lam(v(1)(v(1))))); } int main() { - enable_trace("normalize"); continue_on_violation(true); tst1(); + tst_church_numbers(); return has_violations() ? 1 : 0; }