diff --git a/src/kernel/instantiate.cpp b/src/kernel/instantiate.cpp index dd54af0a3..f064d2c8d 100644 --- a/src/kernel/instantiate.cpp +++ b/src/kernel/instantiate.cpp @@ -9,7 +9,7 @@ Author: Leonardo de Moura #include "replace.h" namespace lean { -expr instantiate(unsigned n, expr const * s, expr const & e) { +expr instantiate_with_closed(unsigned n, expr const * s, expr const & e) { lean_assert(std::all_of(s, s+n, closed)); auto f = [=](expr const & e, unsigned offset) -> expr { @@ -24,7 +24,21 @@ expr instantiate(unsigned n, expr const * s, expr const & e) { } return e; }; - + return replace_fn(f)(e); +} +expr instantiate(unsigned n, expr const * s, expr const & e) { + auto f = [=](expr const & e, unsigned offset) -> expr { + if (is_var(e)) { + unsigned vidx = var_idx(e); + if (vidx >= offset) { + if (vidx < offset + n) + return lift_free_vars(s[n - (vidx - offset) - 1], offset); + else + return var(vidx - n); + } + } + return e; + }; return replace_fn(f)(e); } } diff --git a/src/kernel/instantiate.h b/src/kernel/instantiate.h index 98d606bc4..e21a4108f 100644 --- a/src/kernel/instantiate.h +++ b/src/kernel/instantiate.h @@ -13,6 +13,13 @@ namespace lean { \pre s[0], ..., s[n-1] must be closed expressions (i.e., no free variables). */ +expr instantiate_with_closed(unsigned n, expr const * s, expr const & e); +inline expr instantiate_with_closed(expr const & s, expr const & e) { return instantiate_with_closed(1, &s, e); } + +/** + \brief Replace the free variables with indices 0,...,n-1 with s[n-1],...,s[0] in e. +*/ expr instantiate(unsigned n, expr const * s, expr const & e); inline expr instantiate(expr const & s, expr const & e) { return instantiate(1, &s, e); } + } diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index 5e2fbb498..68e082c9f 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -7,6 +7,7 @@ Author: Leonardo de Moura #include #include "type_check.h" #include "normalize.h" +#include "instantiate.h" #include "free_vars.h" #include "exception.h" #include "trace.h" @@ -83,23 +84,19 @@ class infer_type_fn { case expr_kind::Var: return lookup(ctx, var_idx(e)); case expr_kind::Type: return type(ty_level(e) + 1); case expr_kind::App: { - expr f_t = infer_pi(arg(e, 0), ctx); + expr f = arg(e, 0); unsigned i = 1; unsigned num = num_args(e); lean_assert(num >= 2); while (true) { expr const & c = arg(e, i); + expr f_t = infer_pi(f, ctx); expr c_t = infer_type(c, ctx); check_type(e, i, abst_domain(f_t), c_t, ctx); - // Remark: if f_t is an arrow, we don't have to call normalize and - // lower_free_vars - f_t = normalize(abst_body(f_t), ctx, c); - f_t = lower_free_vars(f_t, 1); + f = instantiate(abst_body(f_t), c); i++; if (i == num) - return f_t; - if (!is_pi(f_t)) - throw exception("function expected"); + return f; } } case expr_kind::Lambda: { diff --git a/src/tests/kernel/replace.cpp b/src/tests/kernel/replace.cpp index 7e9d70074..b3f9b845d 100644 --- a/src/tests/kernel/replace.cpp +++ b/src/tests/kernel/replace.cpp @@ -30,13 +30,15 @@ static void tst1() { static void tst2() { expr r = lambda("x", type(level()), app(var(0), var(1), var(2))); - std::cout << instantiate(constant("a"), r) << std::endl; - lean_assert(instantiate(constant("a"), r) == lambda("x", type(level()), app(var(0), constant("a"), var(1)))); - lean_assert(instantiate(constant("b"), instantiate(constant("a"), r)) == + std::cout << instantiate_with_closed(constant("a"), r) << std::endl; + lean_assert(instantiate_with_closed(constant("a"), r) == lambda("x", type(level()), app(var(0), constant("a"), var(1)))); + lean_assert(instantiate_with_closed(constant("b"), instantiate_with_closed(constant("a"), r)) == lambda("x", type(level()), app(var(0), constant("a"), constant("b")))); - std::cout << instantiate(constant("a"), abst_body(r)) << std::endl; - lean_assert(instantiate(constant("a"), abst_body(r)) == app(constant("a"), var(0), var(1))); -} + std::cout << instantiate_with_closed(constant("a"), abst_body(r)) << std::endl; + lean_assert(instantiate_with_closed(constant("a"), abst_body(r)) == app(constant("a"), var(0), var(1))); + std::cout << instantiate(var(10), r) << std::endl; + lean_assert(instantiate(var(10), r) == lambda("x", type(level()), app(var(0), var(11), var(1)))); + } int main() { continue_on_violation(true);