diff --git a/src/kernel/context.cpp b/src/kernel/context.cpp index 7913712f8..db7cde446 100644 --- a/src/kernel/context.cpp +++ b/src/kernel/context.cpp @@ -58,11 +58,22 @@ std::ostream & operator<<(std::ostream & out, context const & c) { return out; } -context const & lookup(context const & c, unsigned i) { +std::pair lookup_ext(context const & c, unsigned i) { context const * it1 = &c; while (*it1) { if (i == 0) - return *it1; + return std::pair(head(*it1), tail(*it1)); + --i; + it1 = &tail(*it1); + } + throw exception("unknown free variable"); +} + +context_entry const & lookup(context const & c, unsigned i) { + context const * it1 = &c; + while (*it1) { + if (i == 0) + return head(*it1); --i; it1 = &tail(*it1); } diff --git a/src/kernel/context.h b/src/kernel/context.h index 535cb9c93..7cad176c8 100644 --- a/src/kernel/context.h +++ b/src/kernel/context.h @@ -25,7 +25,17 @@ public: expr const & get_domain() const { return m_domain; } expr const & get_body() const { return m_body; } }; -context const & lookup(context const & c, unsigned i); +/** + \brief Return the context entry for the free variable with de + Bruijn index \c i, and the context for this entry. +*/ +std::pair lookup_ext(context const & c, unsigned i); +/** + \brief Return the context entry for the free variable with de + Bruijn index \c i. +*/ +context_entry const & lookup(context const & c, unsigned i); + inline context extend(context const & c, name const & n, expr const & d, expr const & b) { return context(context_entry(n, d, b), c); } diff --git a/src/kernel/expr_formatter.cpp b/src/kernel/expr_formatter.cpp index 70c72b973..2b9a3dcb6 100644 --- a/src/kernel/expr_formatter.cpp +++ b/src/kernel/expr_formatter.cpp @@ -82,8 +82,7 @@ class simple_expr_formatter : public expr_formatter { switch (a.kind()) { case expr_kind::Var: try { - context const & c1 = lookup(c, var_idx(a)); - out() << head(c1).get_name(); + out() << lookup(c, var_idx(a)).get_name(); } catch (exception & ex) { out() << "#" << var_idx(a); } diff --git a/src/kernel/normalize.cpp b/src/kernel/normalize.cpp index fdf13d190..4f23dd8ec 100644 --- a/src/kernel/normalize.cpp +++ b/src/kernel/normalize.cpp @@ -64,15 +64,13 @@ class normalize_fn { --j; it1 = &tail(*it1); } - context const & c = ::lean::lookup(m_ctx, j); - if (c) { - context_entry const & entry = head(c); - if (entry.get_body()) - return svalue(::lean::normalize(entry.get_body(), m_env, tail(c))); - else - return svalue(length(c) - 1); - } - throw exception("unknown free variable"); + auto p = lookup_ext(m_ctx, j); + context_entry const & entry = p.first; + context const & entry_c = p.second; + if (entry.get_body()) + return svalue(::lean::normalize(entry.get_body(), m_env, entry_c)); + else + return svalue(length(entry_c)); } /** \brief Convert the closure \c a into an expression using the given stack in a context that contains \c k binders. */ diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index 10c9b35ee..0c8249f07 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -50,10 +50,11 @@ struct infer_type_fn { cache m_cache; expr lookup(context const & c, unsigned i) { - context const & def_c = ::lean::lookup(c, i); - lean_assert(length(c) >= length(def_c)); - lean_assert(length(def_c) > 0); - return lift_free_vars(head(def_c).get_domain(), length(c) - (length(def_c) - 1)); + auto p = lookup_ext(c, i); + context_entry const & def = p.first; + context const & def_c = p.second; + lean_assert(length(c) > length(def_c)); + return lift_free_vars(def.get_domain(), length(c) - length(def_c)); } expr_formatter & fmt() { return m_env.get_formatter(); } diff --git a/src/tests/kernel/occurs.cpp b/src/tests/kernel/occurs.cpp index c5a4b963c..7225e8a48 100644 --- a/src/tests/kernel/occurs.cpp +++ b/src/tests/kernel/occurs.cpp @@ -44,7 +44,13 @@ static void tst2() { expr b = Const("b"); context c; c = extend(c, "a", Type()); + lean_assert(length(c) == 1); + lean_assert(lookup(c, 0).get_name() == "a"); + auto p = lookup_ext(c, 0); + lean_assert(p.first.get_name() == "a"); + lean_assert(length(p.second) == 0); std::cout << sanitize_names(c, f(a)) << "\n"; + lean_assert(lookup(sanitize_names(c, f(a)), 0).get_name() != name("a")); std::cout << sanitize_names(c, f(b)) << "\n"; } diff --git a/src/tests/util/list.cpp b/src/tests/util/list.cpp index 3e99de45d..3dfaa7c8b 100644 --- a/src/tests/util/list.cpp +++ b/src/tests/util/list.cpp @@ -40,6 +40,9 @@ static void tst3() { lean_assert(head(tail(l)) == 20); lean_assert(head(tail(tail(l))) == 30); lean_assert(length(l) == 3); + lean_assert(length(list()) == 0); + lean_assert(length(list(10, list())) == 1); + lean_assert(length(tail(list(10, list()))) == 0); } int main() {