feat(kernel/type_checker): add push/pop methods to type_checker, they control the cache, and allow the type checker to reuse results even when it is used inside of a backtracking search

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-22 09:54:05 -07:00
parent eca22edda3
commit 3953d4d122
4 changed files with 65 additions and 8 deletions

View file

@ -487,6 +487,20 @@ struct type_checker::imp {
return r;
}
expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); }
void push() {
lean_assert(!m_cache_cs);
m_infer_type_cache[0].push();
m_infer_type_cache[1].push();
}
void pop() {
lean_assert(!m_cache_cs);
m_infer_type_cache[0].pop();
m_infer_type_cache[1].pop();
}
unsigned num_scopes() const {
lean_assert(m_infer_type_cache[0].num_scopes() == m_infer_type_cache[1].num_scopes());
return m_infer_type_cache[0].num_scopes();
}
};
static add_cnstr_fn g_no_constraint_fn = mk_no_contranint_fn();
@ -512,6 +526,9 @@ bool type_checker::is_prop(expr const & t) { return m_ptr->is_prop(t); }
expr type_checker::whnf(expr const & t) { return m_ptr->whnf(t); }
expr type_checker::ensure_pi(expr const & t, expr const & s) { return m_ptr->ensure_pi(t, s); }
expr type_checker::ensure_sort(expr const & t, expr const & s) { return m_ptr->ensure_sort(t, s); }
void type_checker::push() { m_ptr->push(); }
void type_checker::pop() { m_ptr->pop(); }
unsigned type_checker::num_scopes() const { return m_ptr->num_scopes(); }
static void check_no_metavar(environment const & env, expr const & e) {
if (has_metavar(e))

View file

@ -100,6 +100,13 @@ public:
/** \brief Mare sure type of \c e is a sort, and return it. Throw an exception otherwise. */
expr ensure_type(expr const & e) { return ensure_sort(infer(e), e); }
/** \brief Create a backtracking point for cache and generated constraints. */
void push();
/** \brief Restore backtracking point. */
void pop();
/** \brief Return the number of backtracking points. */
unsigned num_scopes() const;
void swap(type_checker & tc) { std::swap(m_ptr, tc.m_ptr); }
};

View file

@ -1830,7 +1830,7 @@ static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> &
extra_opaque = get_name_set_named_param(L, idx, "extra_opaque", name_set());
}
int mk_type_checker(lua_State * L) {
static int mk_type_checker(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs == 1) {
return push_type_checker_ref(L, std::make_shared<type_checker>(to_environment(L, 1)));
@ -1857,29 +1857,37 @@ int mk_type_checker(lua_State * L) {
}
}
}
int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); }
int type_checker_ensure_pi(lua_State * L) {
static int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); }
static int type_checker_ensure_pi(lua_State * L) {
if (lua_gettop(L) == 2)
return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2)));
else
return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2), to_expr(L, 3)));
}
int type_checker_ensure_sort(lua_State * L) {
static int type_checker_ensure_sort(lua_State * L) {
if (lua_gettop(L) == 2)
return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2)));
else
return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2), to_expr(L, 3)));
}
int type_checker_check(lua_State * L) {
static int type_checker_check(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs <= 2)
return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), level_param_names()));
else
return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), to_level_param_names(L, 3)));
}
int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); }
int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); }
int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); }
static int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); }
static int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); }
static int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); }
static int type_checker_push(lua_State * L) { to_type_checker_ref(L, 1)->push(); return 0; }
static int type_checker_pop(lua_State * L) {
if (to_type_checker_ref(L, 1)->num_scopes() == 0)
throw exception("invalid pop method, type_checker does not have backtracking points");
to_type_checker_ref(L, 1)->pop();
return 0;
}
static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); }
static const struct luaL_Reg type_checker_ref_m[] = {
{"__gc", type_checker_ref_gc},
@ -1890,6 +1898,9 @@ static const struct luaL_Reg type_checker_ref_m[] = {
{"infer", safe_function<type_checker_infer>},
{"is_def_eq", safe_function<type_checker_is_def_eq>},
{"is_prop", safe_function<type_checker_is_prop>},
{"push", safe_function<type_checker_push>},
{"pop", safe_function<type_checker_pop>},
{"num_scopes", safe_function<type_checker_num_scopes>},
{0, 0}
};

22
tests/lua/tc8.lua Normal file
View file

@ -0,0 +1,22 @@
local env = environment()
local N = Const("N")
env = add_decl(env, mk_var_decl("N", Type))
env = add_decl(env, mk_var_decl("f", mk_arrow(N, N)))
env = add_decl(env, mk_var_decl("a", N))
local f = Const("f")
local a = Const("a")
local m1 = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("l"))))
local cs = {}
local ngen = name_generator("tst")
local tc = type_checker(env, ngen, function (c) print(c); cs[#cs+1] = c end)
assert(tc:num_scopes() == 0)
tc:push()
assert(tc:num_scopes() == 1)
print(tc:check(f(m1)))
assert(#cs == 1)
print(tc:check(f(f(m1))))
assert(#cs == 1) -- New constraint is not generated
tc:pop() -- forget that we checked f(m1)
print(tc:check(f(m1)))
assert(#cs == 2) -- constraint is generated again
check_error(function() tc:pop() end)