feat(kernel/type_checker): add push/pop methods to type_checker, they control the cache, and allow the type checker to reuse results even when it is used inside of a backtracking search
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
eca22edda3
commit
3953d4d122
4 changed files with 65 additions and 8 deletions
|
@ -487,6 +487,20 @@ struct type_checker::imp {
|
|||
return r;
|
||||
}
|
||||
expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); }
|
||||
void push() {
|
||||
lean_assert(!m_cache_cs);
|
||||
m_infer_type_cache[0].push();
|
||||
m_infer_type_cache[1].push();
|
||||
}
|
||||
void pop() {
|
||||
lean_assert(!m_cache_cs);
|
||||
m_infer_type_cache[0].pop();
|
||||
m_infer_type_cache[1].pop();
|
||||
}
|
||||
unsigned num_scopes() const {
|
||||
lean_assert(m_infer_type_cache[0].num_scopes() == m_infer_type_cache[1].num_scopes());
|
||||
return m_infer_type_cache[0].num_scopes();
|
||||
}
|
||||
};
|
||||
|
||||
static add_cnstr_fn g_no_constraint_fn = mk_no_contranint_fn();
|
||||
|
@ -512,6 +526,9 @@ bool type_checker::is_prop(expr const & t) { return m_ptr->is_prop(t); }
|
|||
expr type_checker::whnf(expr const & t) { return m_ptr->whnf(t); }
|
||||
expr type_checker::ensure_pi(expr const & t, expr const & s) { return m_ptr->ensure_pi(t, s); }
|
||||
expr type_checker::ensure_sort(expr const & t, expr const & s) { return m_ptr->ensure_sort(t, s); }
|
||||
void type_checker::push() { m_ptr->push(); }
|
||||
void type_checker::pop() { m_ptr->pop(); }
|
||||
unsigned type_checker::num_scopes() const { return m_ptr->num_scopes(); }
|
||||
|
||||
static void check_no_metavar(environment const & env, expr const & e) {
|
||||
if (has_metavar(e))
|
||||
|
|
|
@ -100,6 +100,13 @@ public:
|
|||
/** \brief Mare sure type of \c e is a sort, and return it. Throw an exception otherwise. */
|
||||
expr ensure_type(expr const & e) { return ensure_sort(infer(e), e); }
|
||||
|
||||
/** \brief Create a backtracking point for cache and generated constraints. */
|
||||
void push();
|
||||
/** \brief Restore backtracking point. */
|
||||
void pop();
|
||||
/** \brief Return the number of backtracking points. */
|
||||
unsigned num_scopes() const;
|
||||
|
||||
void swap(type_checker & tc) { std::swap(m_ptr, tc.m_ptr); }
|
||||
};
|
||||
|
||||
|
|
|
@ -1830,7 +1830,7 @@ static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> &
|
|||
extra_opaque = get_name_set_named_param(L, idx, "extra_opaque", name_set());
|
||||
}
|
||||
|
||||
int mk_type_checker(lua_State * L) {
|
||||
static int mk_type_checker(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs == 1) {
|
||||
return push_type_checker_ref(L, std::make_shared<type_checker>(to_environment(L, 1)));
|
||||
|
@ -1857,29 +1857,37 @@ int mk_type_checker(lua_State * L) {
|
|||
}
|
||||
}
|
||||
}
|
||||
int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); }
|
||||
int type_checker_ensure_pi(lua_State * L) {
|
||||
static int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); }
|
||||
static int type_checker_ensure_pi(lua_State * L) {
|
||||
if (lua_gettop(L) == 2)
|
||||
return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2)));
|
||||
else
|
||||
return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2), to_expr(L, 3)));
|
||||
}
|
||||
int type_checker_ensure_sort(lua_State * L) {
|
||||
static int type_checker_ensure_sort(lua_State * L) {
|
||||
if (lua_gettop(L) == 2)
|
||||
return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2)));
|
||||
else
|
||||
return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2), to_expr(L, 3)));
|
||||
}
|
||||
int type_checker_check(lua_State * L) {
|
||||
static int type_checker_check(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs <= 2)
|
||||
return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), level_param_names()));
|
||||
else
|
||||
return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), to_level_param_names(L, 3)));
|
||||
}
|
||||
int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); }
|
||||
int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); }
|
||||
int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); }
|
||||
static int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); }
|
||||
static int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); }
|
||||
static int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); }
|
||||
static int type_checker_push(lua_State * L) { to_type_checker_ref(L, 1)->push(); return 0; }
|
||||
static int type_checker_pop(lua_State * L) {
|
||||
if (to_type_checker_ref(L, 1)->num_scopes() == 0)
|
||||
throw exception("invalid pop method, type_checker does not have backtracking points");
|
||||
to_type_checker_ref(L, 1)->pop();
|
||||
return 0;
|
||||
}
|
||||
static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); }
|
||||
|
||||
static const struct luaL_Reg type_checker_ref_m[] = {
|
||||
{"__gc", type_checker_ref_gc},
|
||||
|
@ -1890,6 +1898,9 @@ static const struct luaL_Reg type_checker_ref_m[] = {
|
|||
{"infer", safe_function<type_checker_infer>},
|
||||
{"is_def_eq", safe_function<type_checker_is_def_eq>},
|
||||
{"is_prop", safe_function<type_checker_is_prop>},
|
||||
{"push", safe_function<type_checker_push>},
|
||||
{"pop", safe_function<type_checker_pop>},
|
||||
{"num_scopes", safe_function<type_checker_num_scopes>},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
|
|
22
tests/lua/tc8.lua
Normal file
22
tests/lua/tc8.lua
Normal file
|
@ -0,0 +1,22 @@
|
|||
local env = environment()
|
||||
local N = Const("N")
|
||||
env = add_decl(env, mk_var_decl("N", Type))
|
||||
env = add_decl(env, mk_var_decl("f", mk_arrow(N, N)))
|
||||
env = add_decl(env, mk_var_decl("a", N))
|
||||
local f = Const("f")
|
||||
local a = Const("a")
|
||||
local m1 = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("l"))))
|
||||
local cs = {}
|
||||
local ngen = name_generator("tst")
|
||||
local tc = type_checker(env, ngen, function (c) print(c); cs[#cs+1] = c end)
|
||||
assert(tc:num_scopes() == 0)
|
||||
tc:push()
|
||||
assert(tc:num_scopes() == 1)
|
||||
print(tc:check(f(m1)))
|
||||
assert(#cs == 1)
|
||||
print(tc:check(f(f(m1))))
|
||||
assert(#cs == 1) -- New constraint is not generated
|
||||
tc:pop() -- forget that we checked f(m1)
|
||||
print(tc:check(f(m1)))
|
||||
assert(#cs == 2) -- constraint is generated again
|
||||
check_error(function() tc:pop() end)
|
Loading…
Add table
Reference in a new issue