Add scoped_map. Cache type checker results.

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-08-10 19:27:56 -07:00
parent 19440bc103
commit 7ebaac62a8
5 changed files with 289 additions and 16 deletions

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#include <sstream>
#include "type_check.h"
#include "scoped_map.h"
#include "normalize.h"
#include "instantiate.h"
#include "builtin.h"
@ -43,7 +44,10 @@ bool is_convertible(expr const & expected, expr const & given, environment const
}
struct infer_type_fn {
typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache;
environment const & m_env;
cache m_cache;
expr lookup(context const & c, unsigned i) {
context const & def_c = ::lean::lookup(c, i);
@ -87,11 +91,26 @@ struct infer_type_fn {
expr infer_type(expr const & e, context const & ctx) {
lean_trace("type_check", tout << "infer type\n" << e << "\n" << ctx << "\n";);
bool shared = false;
if (true && is_shared(e)) {
shared = true;
auto it = m_cache.find(e);
if (it != m_cache.end())
return it->second;
}
expr r;
switch (e.kind()) {
case expr_kind::Constant:
return m_env.get_object(const_name(e)).get_type();
case expr_kind::Var: return lookup(ctx, var_idx(e));
case expr_kind::Type: return mk_type(ty_level(e) + 1);
r = m_env.get_object(const_name(e)).get_type();
break;
case expr_kind::Var:
r = lookup(ctx, var_idx(e));
break;
case expr_kind::Type:
r = mk_type(ty_level(e) + 1);
break;
case expr_kind::App: {
expr f_t = infer_pi(arg(e, 0), ctx);
unsigned i = 1;
@ -111,32 +130,56 @@ struct infer_type_fn {
}
f_t = instantiate(abst_body(f_t), c);
i++;
if (i == num)
return f_t;
if (i == num) {
r = f_t;
break;
}
check_pi(f_t, ctx);
}
break;
}
case expr_kind::Eq:
infer_type(eq_lhs(e), ctx);
infer_type(eq_rhs(e), ctx);
return mk_bool_type();
r = mk_bool_type();
break;
case expr_kind::Lambda: {
infer_universe(abst_domain(e), ctx);
expr t = infer_type(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
return mk_pi(abst_name(e), abst_domain(e), t);
expr t;
{
cache::mk_scope sc(m_cache);
t = infer_type(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
}
r = mk_pi(abst_name(e), abst_domain(e), t);
break;
}
case expr_kind::Pi: {
level l1 = infer_universe(abst_domain(e), ctx);
level l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
return mk_type(max(l1, l2));
level l2;
{
cache::mk_scope sc(m_cache);
l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
}
r = mk_type(max(l1, l2));
break;
}
case expr_kind::Let: {
expr lt = infer_type(let_value(e), ctx);
{
cache::mk_scope sc(m_cache);
r = infer_type(let_body(e), extend(ctx, let_name(e), lt, let_value(e)));
}
break;
}
case expr_kind::Let:
return infer_type(let_body(e), extend(ctx, let_name(e), infer_type(let_value(e), ctx), let_value(e)));
case expr_kind::Value:
return to_value(e).get_type();
r = to_value(e).get_type();
break;
}
lean_unreachable();
return e;
if (shared) {
m_cache.insert(e, r);
}
return r;
}
infer_type_fn(environment const & env):

View file

@ -9,6 +9,8 @@ Author: Leonardo de Moura
#include "environment.h"
#include "abstract.h"
#include "exception.h"
#include "toplevel.h"
#include "basic_thms.h"
#include "builtin.h"
#include "arith.h"
#include "trace.h"
@ -82,12 +84,29 @@ static void tst4() {
std::cout << infer_type(pr, env) << "\n";
}
static void tst5() {
environment env = mk_toplevel();
env.add_var("P", Bool);
expr P = Const("P");
expr H = Const("H");
unsigned n = 500;
expr prop = P;
expr pr = H;
for (unsigned i = 1; i < n; i++) {
pr = Conj(P, prop, H, pr);
prop = And(P, prop);
}
expr impPr = Discharge(P, prop, Fun({H, P}, pr));
expr prop2 = infer_type(impPr, env);
lean_assert(Implies(P, prop) == prop2);
}
int main() {
continue_on_violation(true);
enable_trace("type_check");
tst1();
tst2();
tst3();
tst4();
tst5();
return has_violations() ? 1 : 0;
}

View file

@ -22,3 +22,6 @@ add_test(scoped_set ${CMAKE_CURRENT_BINARY_DIR}/scoped_set)
add_executable(options options.cpp)
target_link_libraries(options ${EXTRA_LIBS})
add_test(options ${CMAKE_CURRENT_BINARY_DIR}/options)
add_executable(scoped_map scoped_map.cpp)
target_link_libraries(scoped_map ${EXTRA_LIBS})
add_test(scoped_map ${CMAKE_CURRENT_BINARY_DIR}/scoped_map)

View file

@ -0,0 +1,63 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "scoped_map.h"
#include "test.h"
using namespace lean;
static void tst1() {
scoped_map<int, int> s;
lean_assert(s.empty());
lean_assert(s.size() == 0);
lean_assert(s.find(10) == s.end());
s.insert(10, 20);
lean_assert(!s.empty());
lean_assert(s.size() == 1);
lean_assert(s.find(10) != s.end());
lean_assert(s.find(10)->second == 20);
lean_assert(s.num_scopes() == 0);
lean_assert(s.at_base_lvl());
s.push();
lean_assert(s.num_scopes() == 1);
lean_assert(!s.at_base_lvl());
s.insert(20, 40);
lean_assert(s.find(20) != s.end());
lean_assert(s.find(30) == s.end());
s.insert(10, 30);
lean_assert(s.find(10)->second == 30);
lean_assert(s.size() == 2);
s.pop();
lean_assert(s.size() == 1);
lean_assert(s.find(10) != s.end());
lean_assert(s.find(10)->second == 20);
lean_assert(s.find(20) == s.end());
s.push();
s.insert(30, 50);
lean_assert(s.size() == 2);
s.erase(10);
lean_assert(s.size() == 1);
s.push();
lean_assert(s.num_scopes() == 2);
lean_assert(!s.at_base_lvl());
s.erase(10);
lean_assert(s.size() == 1);
s.pop();
lean_assert(s.num_scopes() == 1);
lean_assert(s.find(10) == s.end());
lean_assert(s.find(30) != s.end());
lean_assert(s.size() == 1);
s.pop();
lean_assert(s.size() == 1);
lean_assert(s.at_base_lvl());
lean_assert(s.find(10) != s.end());
lean_assert(s.find(30) == s.end());
}
int main() {
continue_on_violation(true);
tst1();
return has_violations() ? 1 : 0;
}

145
src/util/scoped_map.h Normal file
View file

@ -0,0 +1,145 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <iostream>
#include <unordered_map>
#include <vector>
#include "debug.h"
#ifndef LEAN_SCOPED_MAP_INITIAL_BUCKET_SIZE
#define LEAN_SCOPED_MAP_INITIAL_BUCKET_SIZE 8
#endif
namespace lean {
/**
\brief Scoped maps (aka backtrackable maps).
*/
template<typename Key, typename T, typename Hash = std::hash<Key>, typename KeyEqual = std::equal_to<Key>>
class scoped_map {
typedef std::unordered_map<Key, T, Hash, KeyEqual> map;
typedef typename map::size_type size_type;
typedef typename map::value_type value_type;
enum class action_kind { Insert, Replace, Erase };
map m_map;
std::vector<std::pair<action_kind, value_type>> m_actions;
std::vector<unsigned> m_scopes;
public:
explicit scoped_map(size_type bucket_count = LEAN_SCOPED_MAP_INITIAL_BUCKET_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual()):
m_map(bucket_count, hash, equal) {}
/** \brief Return the number of scopes. */
unsigned num_scopes() const {
return m_scopes.size();
}
/** \brief Return true iff there are no scopes. */
bool at_base_lvl() const {
return m_scopes.empty();
}
/** \brief Create a new scope (it allows us to restore the current state of the map). */
void push() {
m_scopes.push_back(m_actions.size());
}
/** \brief Remove \c num scopes, and restores the state of the map. */
void pop(unsigned num = 1) {
lean_assert(num <= num_scopes());
unsigned old_sz = m_scopes[num_scopes() - num];
lean_assert(old_sz <= m_actions.size());
unsigned i = m_actions.size();
while (i > old_sz) {
--i;
auto const & p = m_actions.back();
switch (p.first) {
case action_kind::Insert:
m_map.erase(p.second.first);
break;
case action_kind::Replace: {
auto it = m_map.find(p.second.first);
it->second = p.second.second;
break;
}
case action_kind::Erase:
m_map.insert(p.second);
break;
}
m_actions.pop_back();
}
lean_assert(m_actions.size() == old_sz);
m_scopes.resize(num_scopes() - num);
}
/** \brief Return true iff the map is empty */
bool empty() const {
return m_map.empty();
}
/** \brief Return the number of elements stored in the map. */
unsigned size() const {
return m_map.size();
}
/** \brief Insert an element in the map */
void insert(Key const & k, T const & v) {
auto it = m_map.find(k);
if (it == m_map.end()) {
if (!at_base_lvl())
m_actions.push_back(std::make_pair(action_kind::Insert, value_type(k, T())));
m_map.insert(value_type(k, v));
} else {
if (!at_base_lvl())
m_actions.push_back(std::make_pair(action_kind::Replace, *it));
it->second = v;
}
lean_assert(m_map.find(k)->second == v);
}
void insert(value_type const & p) {
insert(p.first, p.second);
}
/** \brief Remove an element from the map */
void erase(Key const & k) {
if (!at_base_lvl()) {
auto it = m_map.find(k);
if (m_map.find(k) != m_map.end())
m_actions.push_back(std::make_pair(action_kind::Erase, *it));
}
m_map.erase(k);
}
/** \brief Remove all elements and scopes */
void clear() {
m_map.clear();
m_actions.clear();
m_scopes.clear();
}
typedef typename map::const_iterator const_iterator;
const_iterator find(Key const & k) const {
return m_map.find(k);
}
const_iterator begin() const {
return m_map.begin();
}
const_iterator end() const {
return m_map.end();
}
class mk_scope {
scoped_map & m_map;
public:
mk_scope(scoped_map & m):m_map(m) { m_map.push(); }
~mk_scope() { m_map.pop(); }
};
};
}