diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index 97baff3f4..211b12705 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #include #include "type_check.h" +#include "scoped_map.h" #include "normalize.h" #include "instantiate.h" #include "builtin.h" @@ -43,7 +44,10 @@ bool is_convertible(expr const & expected, expr const & given, environment const } struct infer_type_fn { + typedef scoped_map cache; + environment const & m_env; + cache m_cache; expr lookup(context const & c, unsigned i) { context const & def_c = ::lean::lookup(c, i); @@ -87,11 +91,26 @@ struct infer_type_fn { expr infer_type(expr const & e, context const & ctx) { lean_trace("type_check", tout << "infer type\n" << e << "\n" << ctx << "\n";); + + bool shared = false; + if (true && is_shared(e)) { + shared = true; + auto it = m_cache.find(e); + if (it != m_cache.end()) + return it->second; + } + + expr r; switch (e.kind()) { case expr_kind::Constant: - return m_env.get_object(const_name(e)).get_type(); - case expr_kind::Var: return lookup(ctx, var_idx(e)); - case expr_kind::Type: return mk_type(ty_level(e) + 1); + r = m_env.get_object(const_name(e)).get_type(); + break; + case expr_kind::Var: + r = lookup(ctx, var_idx(e)); + break; + case expr_kind::Type: + r = mk_type(ty_level(e) + 1); + break; case expr_kind::App: { expr f_t = infer_pi(arg(e, 0), ctx); unsigned i = 1; @@ -111,32 +130,56 @@ struct infer_type_fn { } f_t = instantiate(abst_body(f_t), c); i++; - if (i == num) - return f_t; + if (i == num) { + r = f_t; + break; + } check_pi(f_t, ctx); } + break; } case expr_kind::Eq: infer_type(eq_lhs(e), ctx); infer_type(eq_rhs(e), ctx); - return mk_bool_type(); + r = mk_bool_type(); + break; case expr_kind::Lambda: { infer_universe(abst_domain(e), ctx); - expr t = infer_type(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); - return mk_pi(abst_name(e), abst_domain(e), t); + expr t; + { + cache::mk_scope sc(m_cache); + t = infer_type(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); + } + r = mk_pi(abst_name(e), abst_domain(e), t); + break; } case expr_kind::Pi: { level l1 = infer_universe(abst_domain(e), ctx); - level l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); - return mk_type(max(l1, l2)); + level l2; + { + cache::mk_scope sc(m_cache); + l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); + } + r = mk_type(max(l1, l2)); + break; + } + case expr_kind::Let: { + expr lt = infer_type(let_value(e), ctx); + { + cache::mk_scope sc(m_cache); + r = infer_type(let_body(e), extend(ctx, let_name(e), lt, let_value(e))); + } + break; } - case expr_kind::Let: - return infer_type(let_body(e), extend(ctx, let_name(e), infer_type(let_value(e), ctx), let_value(e))); case expr_kind::Value: - return to_value(e).get_type(); + r = to_value(e).get_type(); + break; } - lean_unreachable(); - return e; + + if (shared) { + m_cache.insert(e, r); + } + return r; } infer_type_fn(environment const & env): diff --git a/src/tests/kernel/type_check.cpp b/src/tests/kernel/type_check.cpp index 5314b0e1c..893d12824 100644 --- a/src/tests/kernel/type_check.cpp +++ b/src/tests/kernel/type_check.cpp @@ -9,6 +9,8 @@ Author: Leonardo de Moura #include "environment.h" #include "abstract.h" #include "exception.h" +#include "toplevel.h" +#include "basic_thms.h" #include "builtin.h" #include "arith.h" #include "trace.h" @@ -82,12 +84,29 @@ static void tst4() { std::cout << infer_type(pr, env) << "\n"; } +static void tst5() { + environment env = mk_toplevel(); + env.add_var("P", Bool); + expr P = Const("P"); + expr H = Const("H"); + unsigned n = 500; + expr prop = P; + expr pr = H; + for (unsigned i = 1; i < n; i++) { + pr = Conj(P, prop, H, pr); + prop = And(P, prop); + } + expr impPr = Discharge(P, prop, Fun({H, P}, pr)); + expr prop2 = infer_type(impPr, env); + lean_assert(Implies(P, prop) == prop2); +} + int main() { continue_on_violation(true); - enable_trace("type_check"); tst1(); tst2(); tst3(); tst4(); + tst5(); return has_violations() ? 1 : 0; } diff --git a/src/tests/util/CMakeLists.txt b/src/tests/util/CMakeLists.txt index 7317560b8..8b9627e29 100644 --- a/src/tests/util/CMakeLists.txt +++ b/src/tests/util/CMakeLists.txt @@ -22,3 +22,6 @@ add_test(scoped_set ${CMAKE_CURRENT_BINARY_DIR}/scoped_set) add_executable(options options.cpp) target_link_libraries(options ${EXTRA_LIBS}) add_test(options ${CMAKE_CURRENT_BINARY_DIR}/options) +add_executable(scoped_map scoped_map.cpp) +target_link_libraries(scoped_map ${EXTRA_LIBS}) +add_test(scoped_map ${CMAKE_CURRENT_BINARY_DIR}/scoped_map) diff --git a/src/tests/util/scoped_map.cpp b/src/tests/util/scoped_map.cpp new file mode 100644 index 000000000..507d8c42b --- /dev/null +++ b/src/tests/util/scoped_map.cpp @@ -0,0 +1,63 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "scoped_map.h" +#include "test.h" +using namespace lean; + +static void tst1() { + scoped_map s; + lean_assert(s.empty()); + lean_assert(s.size() == 0); + lean_assert(s.find(10) == s.end()); + s.insert(10, 20); + lean_assert(!s.empty()); + lean_assert(s.size() == 1); + lean_assert(s.find(10) != s.end()); + lean_assert(s.find(10)->second == 20); + lean_assert(s.num_scopes() == 0); + lean_assert(s.at_base_lvl()); + s.push(); + lean_assert(s.num_scopes() == 1); + lean_assert(!s.at_base_lvl()); + s.insert(20, 40); + lean_assert(s.find(20) != s.end()); + lean_assert(s.find(30) == s.end()); + s.insert(10, 30); + lean_assert(s.find(10)->second == 30); + lean_assert(s.size() == 2); + s.pop(); + lean_assert(s.size() == 1); + lean_assert(s.find(10) != s.end()); + lean_assert(s.find(10)->second == 20); + lean_assert(s.find(20) == s.end()); + s.push(); + s.insert(30, 50); + lean_assert(s.size() == 2); + s.erase(10); + lean_assert(s.size() == 1); + s.push(); + lean_assert(s.num_scopes() == 2); + lean_assert(!s.at_base_lvl()); + s.erase(10); + lean_assert(s.size() == 1); + s.pop(); + lean_assert(s.num_scopes() == 1); + lean_assert(s.find(10) == s.end()); + lean_assert(s.find(30) != s.end()); + lean_assert(s.size() == 1); + s.pop(); + lean_assert(s.size() == 1); + lean_assert(s.at_base_lvl()); + lean_assert(s.find(10) != s.end()); + lean_assert(s.find(30) == s.end()); +} + +int main() { + continue_on_violation(true); + tst1(); + return has_violations() ? 1 : 0; +} diff --git a/src/util/scoped_map.h b/src/util/scoped_map.h new file mode 100644 index 000000000..aebf1ec2d --- /dev/null +++ b/src/util/scoped_map.h @@ -0,0 +1,145 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include +#include +#include "debug.h" + +#ifndef LEAN_SCOPED_MAP_INITIAL_BUCKET_SIZE +#define LEAN_SCOPED_MAP_INITIAL_BUCKET_SIZE 8 +#endif + +namespace lean { +/** + \brief Scoped maps (aka backtrackable maps). +*/ +template, typename KeyEqual = std::equal_to> +class scoped_map { + typedef std::unordered_map map; + typedef typename map::size_type size_type; + typedef typename map::value_type value_type; + enum class action_kind { Insert, Replace, Erase }; + map m_map; + std::vector> m_actions; + std::vector m_scopes; +public: + explicit scoped_map(size_type bucket_count = LEAN_SCOPED_MAP_INITIAL_BUCKET_SIZE, + const Hash& hash = Hash(), + const KeyEqual& equal = KeyEqual()): + m_map(bucket_count, hash, equal) {} + + /** \brief Return the number of scopes. */ + unsigned num_scopes() const { + return m_scopes.size(); + } + + /** \brief Return true iff there are no scopes. */ + bool at_base_lvl() const { + return m_scopes.empty(); + } + + /** \brief Create a new scope (it allows us to restore the current state of the map). */ + void push() { + m_scopes.push_back(m_actions.size()); + } + + /** \brief Remove \c num scopes, and restores the state of the map. */ + void pop(unsigned num = 1) { + lean_assert(num <= num_scopes()); + unsigned old_sz = m_scopes[num_scopes() - num]; + lean_assert(old_sz <= m_actions.size()); + unsigned i = m_actions.size(); + while (i > old_sz) { + --i; + auto const & p = m_actions.back(); + switch (p.first) { + case action_kind::Insert: + m_map.erase(p.second.first); + break; + case action_kind::Replace: { + auto it = m_map.find(p.second.first); + it->second = p.second.second; + break; + } + case action_kind::Erase: + m_map.insert(p.second); + break; + } + m_actions.pop_back(); + } + lean_assert(m_actions.size() == old_sz); + m_scopes.resize(num_scopes() - num); + } + + /** \brief Return true iff the map is empty */ + bool empty() const { + return m_map.empty(); + } + + /** \brief Return the number of elements stored in the map. */ + unsigned size() const { + return m_map.size(); + } + + /** \brief Insert an element in the map */ + void insert(Key const & k, T const & v) { + auto it = m_map.find(k); + if (it == m_map.end()) { + if (!at_base_lvl()) + m_actions.push_back(std::make_pair(action_kind::Insert, value_type(k, T()))); + m_map.insert(value_type(k, v)); + } else { + if (!at_base_lvl()) + m_actions.push_back(std::make_pair(action_kind::Replace, *it)); + it->second = v; + } + lean_assert(m_map.find(k)->second == v); + } + + void insert(value_type const & p) { + insert(p.first, p.second); + } + + /** \brief Remove an element from the map */ + void erase(Key const & k) { + if (!at_base_lvl()) { + auto it = m_map.find(k); + if (m_map.find(k) != m_map.end()) + m_actions.push_back(std::make_pair(action_kind::Erase, *it)); + } + m_map.erase(k); + } + + /** \brief Remove all elements and scopes */ + void clear() { + m_map.clear(); + m_actions.clear(); + m_scopes.clear(); + } + + typedef typename map::const_iterator const_iterator; + const_iterator find(Key const & k) const { + return m_map.find(k); + } + + const_iterator begin() const { + return m_map.begin(); + } + + const_iterator end() const { + return m_map.end(); + } + + class mk_scope { + scoped_map & m_map; + public: + mk_scope(scoped_map & m):m_map(m) { m_map.push(); } + ~mk_scope() { m_map.pop(); } + }; +}; +}