Fix bugs in type checker

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-08-01 21:28:26 -07:00
parent 7b00561a94
commit 3ef9d21875
7 changed files with 74 additions and 15 deletions

View file

@ -27,4 +27,15 @@ inline expr abstract(expr const & s, expr const & e) { return abstract(1, &s, e)
*/
expr abstract_p(unsigned n, expr const * s, expr const & e);
inline expr abstract_p(expr const & s, expr const & e) { return abstract_p(1, &s, e); }
/**
\brief Create a lambda expression (lambda (x : t) b), the term b is abstracted using abstract(constant(x), b).
*/
inline expr fun(name const & n, expr const & t, expr const & b) { return lambda(n, t, abstract(constant(n), b)); }
inline expr fun(char const * n, expr const & t, expr const & b) { return fun(name(n), t, b); }
/**
\brief Create a Pi expression (pi (x : t) b), the term b is abstracted using abstract(constant(x), b).
*/
inline expr Fun(name const & n, expr const & t, expr const & b) { return pi(n, t, abstract(constant(n), b)); }
inline expr Fun(char const * n, expr const & t, expr const & b) { return Fun(name(n), t, b); }
}

View file

@ -151,6 +151,10 @@ bool operator==(expr const & a, expr const & b) {
return eq_fn()(a, b);
}
bool is_arrow(expr const & t) {
return is_pi(t) && !has_free_var(abst_body(t), 0);
}
// Low-level pretty printer
std::ostream & operator<<(std::ostream & out, expr const & a) {
switch (a.kind()) {
@ -166,10 +170,12 @@ std::ostream & operator<<(std::ostream & out, expr const & a) {
break;
case expr_kind::Lambda: out << "(fun (" << abst_name(a) << " : " << abst_type(a) << ") " << abst_body(a) << ")"; break;
case expr_kind::Pi:
if (has_free_var(abst_body(a), 0))
if (!is_arrow(a))
out << "(pi (" << abst_name(a) << " : " << abst_type(a) << ") " << abst_body(a) << ")";
else
else if (!is_arrow(abst_type(a)))
out << abst_type(a) << " -> " << abst_body(a);
else
out << "(" << abst_type(a) << ") -> " << abst_body(a);
break;
case expr_kind::Type: {
level const & l = ty_level(a);

View file

@ -160,9 +160,18 @@ public:
unsigned k = length(m_ctx);
return reify(normalize(e, stack(), k), k);
}
expr operator()(expr const & e, expr const & v) {
unsigned k = length(m_ctx);
stack s = extend(stack(), normalize(v, stack(), k));
return reify(normalize(e, s, k+1), k+1);
}
};
expr normalize(expr const & e, environment const & env, context const & ctx) {
return normalize_fn(env, ctx)(e);
}
expr normalize(expr const & e, environment const & env, context const & ctx, expr const & v) {
return normalize_fn(env, ctx)(e, v);
}
}

View file

@ -11,5 +11,8 @@ Author: Leonardo de Moura
namespace lean {
class environment;
/** \brief Normalize e using the environment env and context ctx */
expr normalize(expr const & e, environment const & env, context const & ctx = context());
/** \brief Normalize e using the environment env, context ctx, and add v to "normalization stack" */
expr normalize(expr const & e, environment const & env, context const & ctx, expr const & v);
}

View file

@ -19,9 +19,15 @@ class infer_type_fn {
return ::lean::normalize(e, m_env, ctx);
}
expr normalize(expr const & e, context const & ctx, expr const & v) {
return ::lean::normalize(e, m_env, ctx, v);
}
expr lookup(context const & c, unsigned i) {
context const & def_c = ::lean::lookup(c, i);
return lift_free_vars(head(def_c).get_type(), length(c) - length(def_c));
lean_assert(length(c) >= length(def_c));
lean_assert(length(def_c) > 0);
return lift_free_vars(head(def_c).get_type(), length(c) - (length(def_c) - 1));
}
level infer_universe(expr const & t, context const & ctx) {
@ -85,7 +91,10 @@ class infer_type_fn {
expr const & c = arg(e, i);
expr c_t = infer_type(c, ctx);
check_type(e, i, abst_type(f_t), c_t, ctx);
f_t = normalize(abst_body(f_t), extend(ctx, abst_name(f_t), c_t));
// Remark: if f_t is an arrow, we don't have to call normalize and
// lower_free_vars
f_t = normalize(abst_body(f_t), ctx, c);
f_t = lower_free_vars(f_t, 1);
i++;
if (i == num)
return f_t;

View file

@ -90,7 +90,7 @@ static void tst5() {
}
int main() {
// continue_on_violation(true);
continue_on_violation(true);
tst1();
tst2();
tst3();

View file

@ -6,18 +6,40 @@ Author: Leonardo de Moura
*/
#include <iostream>
#include "type_check.h"
#include "abstract.h"
#include "exception.h"
#include "trace.h"
#include "test.h"
using namespace lean;
expr c(char const * n) { return constant(n); }
static void tst1() {
environment env;
expr t0 = type(level());
std::cout << infer_type(t0, env) << "\n";
lean_assert(infer_type(t0, env) == type(level()+1));
expr t1 = pi("_", t0, t0);
std::cout << infer_type(t1, env) << "\n";
expr f = pi("_", t0, t0);
std::cout << infer_type(f, env) << "\n";
lean_assert(infer_type(f, env) == type(level()+1));
level u = env.define_uvar("u", level() + 1);
level v = env.define_uvar("v", level() + 1);
expr g = pi("_", type(u), type(v));
std::cout << infer_type(g, env) << "\n";
lean_assert(infer_type(g, env) == type(max(u+1, v+1)));
std::cout << infer_type(type(u), env) << "\n";
lean_assert(infer_type(type(u), env) == type(u+1));
std::cout << infer_type(lambda("x", type(u), var(0)), env) << "\n";
lean_assert(infer_type(lambda("x", type(u), var(0)), env) == pi("_", type(u), type(u)));
std::cout << infer_type(lambda("Nat", type(level()), lambda("n", var(0), var(0))), env) << "\n";
expr nat = c("nat");
expr T = fun("nat", type(level()),
fun("+", arrow(nat, arrow(nat, nat)),
fun("m", nat, app(c("+"), c("m"), c("m")))));
std::cout << T << "\n";
std::cout << infer_type(T, env) << "\n";
std::cout << Fun("nat", type(level()), arrow(arrow(nat, arrow(nat, nat)), arrow(nat, nat))) << "\n";
lean_assert(infer_type(T, env) == Fun("nat", type(level()), arrow(arrow(nat, arrow(nat, nat)), arrow(nat, nat))));
}
static void tst2() {
@ -27,12 +49,12 @@ static void tst2() {
expr t0 = type(level());
expr t1 = type(l1);
expr F =
lambda("Nat", t0,
lambda("Vec", arrow(var(0), t0),
lambda("n", var(1),
lambda("len", arrow(app(var(1), var(0)), var(2)),
lambda("v", app(var(2), var(1)),
app(var(1), var(0)))))));
fun("Nat", t0,
fun("Vec", arrow(c("Nat"), t0),
fun("n", c("Nat"),
fun("len", arrow(app(c("Vec"), c("n")), c("Nat")),
fun("v", app(c("Vec"), c("n")),
app(c("len"), c("v")))))));
std::cout << F << "\n";
std::cout << infer_type(F, env) << "\n";
}
@ -42,10 +64,9 @@ static void tst2() {
}
int main() {
// continue_on_violation(true);
continue_on_violation(true);
enable_trace("type_check");
tst1();
return 0;
tst2();
return has_violations() ? 1 : 0;
}