fix(kernel/type_checker): restore type checker cache when a failure occurs, do not send constraints to add_cnstr_fn when a type checker failure occurrs
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
9c745057b4
commit
eca22edda3
3 changed files with 120 additions and 31 deletions
|
@ -10,6 +10,7 @@ Author: Leonardo de Moura
|
|||
#include "util/lbool.h"
|
||||
#include "util/flet.h"
|
||||
#include "util/sstream.h"
|
||||
#include "util/scoped_map.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/expr_maps.h"
|
||||
#include "kernel/instantiate.h"
|
||||
|
@ -33,6 +34,8 @@ add_cnstr_fn mk_no_contranint_fn() {
|
|||
|
||||
/** \brief Auxiliary functional object used to implement type checker. */
|
||||
struct type_checker::imp {
|
||||
typedef scoped_map<expr, expr, expr_hash, is_bi_equal_proc> cache;
|
||||
|
||||
/** \brief Interface type_checker <-> converter */
|
||||
class converter_context : public converter::context {
|
||||
imp & m_imp;
|
||||
|
@ -64,16 +67,63 @@ struct type_checker::imp {
|
|||
// Examples:
|
||||
// The type of (lambda x : A, t) is (Pi x : A, typeof(t))
|
||||
// The type of (lambda {x : A}, t) is (Pi {x : A}, typeof(t))
|
||||
expr_bi_struct_map<expr> m_infer_type_cache[2];
|
||||
cache m_infer_type_cache[2];
|
||||
converter_context m_conv_ctx;
|
||||
type_checker_context m_tc_ctx;
|
||||
bool m_memoize;
|
||||
// temp flag
|
||||
level_param_names m_params;
|
||||
buffer<constraint> m_cs; // temporary cache of constraints
|
||||
bool m_cache_cs; // true if we should cache the constraints; false if we should send to m_add_cnstr_fn
|
||||
|
||||
// Auxiliary object used to restore cache and filter constraints
|
||||
// when a failure occurs in the type checker.
|
||||
// That is, we should not keep cached results, and we should not sent constraints
|
||||
// when a failure occurs.
|
||||
struct scope {
|
||||
imp & m_imp;
|
||||
unsigned m_old_cs_size;
|
||||
bool m_old_cache_cs;
|
||||
bool m_keep;
|
||||
scope(imp & i):m_imp(i), m_old_cs_size(m_imp.m_cs.size()), m_old_cache_cs(m_imp.m_cache_cs), m_keep(false) {
|
||||
m_imp.m_infer_type_cache[0].push();
|
||||
m_imp.m_infer_type_cache[1].push();
|
||||
m_imp.m_cache_cs = true;
|
||||
}
|
||||
~scope() {
|
||||
if (m_keep) {
|
||||
// keep results
|
||||
m_imp.m_infer_type_cache[0].keep();
|
||||
m_imp.m_infer_type_cache[1].keep();
|
||||
} else {
|
||||
// restore caches
|
||||
m_imp.m_infer_type_cache[0].pop();
|
||||
m_imp.m_infer_type_cache[1].pop();
|
||||
m_imp.m_cs.shrink(m_old_cs_size);
|
||||
}
|
||||
m_imp.m_cache_cs = m_old_cache_cs;
|
||||
}
|
||||
void keep() {
|
||||
m_keep = true;
|
||||
if (!m_old_cache_cs) {
|
||||
lean_assert(m_old_cs_size == 0);
|
||||
// send results to m_add_cnstr_fn
|
||||
try {
|
||||
for (auto const & c : m_imp.m_cs)
|
||||
m_imp.m_add_cnstr_fn(c);
|
||||
} catch (...) {
|
||||
m_imp.m_cs.clear();
|
||||
throw;
|
||||
}
|
||||
m_imp.m_cs.clear();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
imp(environment const & env, name_generator const & g, add_cnstr_fn const & h, std::unique_ptr<converter> && conv, bool memoize):
|
||||
m_env(env), m_gen(g), m_add_cnstr_fn(h), m_conv(std::move(conv)), m_conv_ctx(*this), m_tc_ctx(*this),
|
||||
m_memoize(memoize) {}
|
||||
m_memoize(memoize), m_cache_cs(false) {
|
||||
}
|
||||
|
||||
optional<expr> expand_macro(expr const & m) {
|
||||
lean_assert(is_macro(m));
|
||||
|
@ -91,17 +141,10 @@ struct type_checker::imp {
|
|||
|
||||
/** \brief Add given constraint using m_add_cnstr_fn. */
|
||||
void add_cnstr(constraint const & c) {
|
||||
m_add_cnstr_fn(c);
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c t and \c s are definitionally equal */
|
||||
bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst) {
|
||||
return m_conv->is_def_eq(t, s, m_conv_ctx, jst);
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c e is a proposition */
|
||||
bool is_prop(expr const & e) {
|
||||
return whnf(infer_type(e)) == Bool;
|
||||
if (m_cache_cs)
|
||||
m_cs.push_back(c);
|
||||
else
|
||||
m_add_cnstr_fn(c);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -157,7 +200,7 @@ struct type_checker::imp {
|
|||
\remark \c s is used to extract position (line number information) when an
|
||||
error message is produced
|
||||
*/
|
||||
expr ensure_sort(expr e, expr const & s) {
|
||||
expr ensure_sort_core(expr e, expr const & s) {
|
||||
if (is_sort(e))
|
||||
return e;
|
||||
e = whnf(e);
|
||||
|
@ -181,7 +224,7 @@ struct type_checker::imp {
|
|||
}
|
||||
|
||||
/** \brief Similar to \c ensure_sort, but makes sure \c e "is" a Pi. */
|
||||
expr ensure_pi(expr e, expr const & s) {
|
||||
expr ensure_pi_core(expr e, expr const & s) {
|
||||
if (is_pi(e))
|
||||
return e;
|
||||
e = whnf(e);
|
||||
|
@ -339,16 +382,16 @@ struct type_checker::imp {
|
|||
case expr_kind::Lambda: {
|
||||
if (!infer_only) {
|
||||
expr t = infer_type_core(binding_domain(e), infer_only);
|
||||
ensure_sort(t, binding_domain(e));
|
||||
ensure_sort_core(t, binding_domain(e));
|
||||
}
|
||||
auto b = open_binding_body(e);
|
||||
r = mk_pi(binding_name(e), binding_domain(e), abstract_local(infer_type_core(b.first, infer_only), b.second), binding_info(e));
|
||||
break;
|
||||
}
|
||||
case expr_kind::Pi: {
|
||||
expr t1 = ensure_sort(infer_type_core(binding_domain(e), infer_only), binding_domain(e));
|
||||
expr t1 = ensure_sort_core(infer_type_core(binding_domain(e), infer_only), binding_domain(e));
|
||||
auto b = open_binding_body(e);
|
||||
expr t2 = ensure_sort(infer_type_core(b.first, infer_only), binding_body(e));
|
||||
expr t2 = ensure_sort_core(infer_type_core(b.first, infer_only), binding_body(e));
|
||||
if (m_env.impredicative())
|
||||
r = mk_sort(mk_imax(sort_level(t1), sort_level(t2)));
|
||||
else
|
||||
|
@ -356,7 +399,7 @@ struct type_checker::imp {
|
|||
break;
|
||||
}
|
||||
case expr_kind::App: {
|
||||
expr f_type = ensure_pi(infer_type_core(app_fn(e), infer_only), app_fn(e));
|
||||
expr f_type = ensure_pi_core(infer_type_core(app_fn(e), infer_only), app_fn(e));
|
||||
if (!infer_only) {
|
||||
expr a_type = infer_type_core(app_arg(e), infer_only);
|
||||
app_delayed_jst jst(m_env, e, f_type, a_type);
|
||||
|
@ -373,7 +416,7 @@ struct type_checker::imp {
|
|||
}
|
||||
case expr_kind::Let:
|
||||
if (!infer_only) {
|
||||
ensure_sort(infer_type_core(let_type(e), infer_only), let_type(e));
|
||||
ensure_sort_core(infer_type_core(let_type(e), infer_only), let_type(e));
|
||||
expr val_type = infer_type_core(let_value(e), infer_only);
|
||||
simple_delayed_justification jst([=]() { return mk_let_mismatch_jst(e, val_type); });
|
||||
if (!is_def_eq(val_type, let_type(e), jst)) {
|
||||
|
@ -394,16 +437,55 @@ struct type_checker::imp {
|
|||
return r;
|
||||
}
|
||||
|
||||
expr infer_type(expr const & e) { return infer_type_core(e, true); }
|
||||
expr check(expr const & e, level_param_names const & ps) {
|
||||
flet<level_param_names> updt(m_params, ps);
|
||||
return infer_type_core(e, false);
|
||||
expr infer_type(expr const & e) {
|
||||
scope mk_scope(*this);
|
||||
expr r = infer_type_core(e, true);
|
||||
mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
expr check(expr const & e, level_param_names const & ps) {
|
||||
scope mk_scope(*this);
|
||||
flet<level_param_names> updt(m_params, ps);
|
||||
expr r = infer_type_core(e, false);
|
||||
mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
expr ensure_sort(expr const & e, expr const & s) {
|
||||
scope mk_scope(*this);
|
||||
expr r = ensure_sort_core(e, s);
|
||||
mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
expr ensure_pi(expr const & e, expr const & s) {
|
||||
scope mk_scope(*this);
|
||||
expr r = ensure_pi_core(e, s);
|
||||
mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
/** \brief Return true iff \c t and \c s are definitionally equal */
|
||||
bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst) {
|
||||
scope mk_scope(*this);
|
||||
bool r = m_conv->is_def_eq(t, s, m_conv_ctx, jst);
|
||||
if (r) mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
bool is_def_eq(expr const & t, expr const & s) {
|
||||
scope mk_scope(*this);
|
||||
bool r = m_conv->is_def_eq(t, s, m_conv_ctx);
|
||||
if (r) mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
bool is_def_eq(expr const & t, expr const & s) { return m_conv->is_def_eq(t, s, m_conv_ctx); }
|
||||
bool is_def_eq(expr const & t, expr const & s, justification const & j) {
|
||||
as_delayed_justification djst(j);
|
||||
return is_def_eq(t, s, djst);
|
||||
}
|
||||
/** \brief Return true iff \c e is a proposition */
|
||||
bool is_prop(expr const & e) {
|
||||
scope mk_scope(*this);
|
||||
bool r = whnf(infer_type(e)) == Bool;
|
||||
if (r) mk_scope.keep();
|
||||
return r;
|
||||
}
|
||||
expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); }
|
||||
};
|
||||
|
||||
|
|
|
@ -5,11 +5,16 @@ print(env:normalize(Fun(a, m)))
|
|||
print(env:normalize(Fun(a, m(a))))
|
||||
local m2 = mk_metavar("m2", mk_arrow(Bool, Bool, Bool))
|
||||
print(env:normalize(Fun(a, (m2(a))(a))))
|
||||
print("step1")
|
||||
env:type_check(m)
|
||||
print("step2")
|
||||
env:type_check(Fun(a, m(a)))
|
||||
print("step3")
|
||||
env:type_check(Fun(a, (m2(a))(a)))
|
||||
local m3 = mk_metavar("m3", mk_metavar("m4", mk_sort(mk_meta_univ("l"))))
|
||||
print("step4")
|
||||
env:type_check(m3)
|
||||
print("step5")
|
||||
-- The following call fails, because the type checker will try to
|
||||
-- create a constraint, but constraint generation is not supported by
|
||||
-- the type checker used to implement the method type_check
|
||||
|
@ -17,7 +22,4 @@ assert(not pcall(function()
|
|||
env:type_check(m3(a))
|
||||
end
|
||||
))
|
||||
|
||||
|
||||
|
||||
|
||||
print("before end")
|
||||
|
|
|
@ -7,14 +7,19 @@ local t = Fun(a, Bool, a)
|
|||
local b = Const("b")
|
||||
print(t(b))
|
||||
assert(tc:whnf(t(b)) == b)
|
||||
local cs = {}
|
||||
local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end)
|
||||
assert(tc:check(Bool) == mk_sort(mk_level_one()))
|
||||
print(tc:infer(t))
|
||||
local m = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("u"))))
|
||||
print(tc:infer(m))
|
||||
|
||||
local cs = {}
|
||||
local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end)
|
||||
local t2 = Fun(a, Bool, m(a))
|
||||
print("---------")
|
||||
print("t2: ")
|
||||
print(t2)
|
||||
print("check(t): ")
|
||||
print(tc2:check(t))
|
||||
print("check(t2): ")
|
||||
print(tc2:check(t2))
|
||||
assert(#cs == 2)
|
||||
|
|
Loading…
Reference in a new issue