fix(kernel/type_checker): restore type checker cache when a failure occurs, do not send constraints to add_cnstr_fn when a type checker failure occurrs

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-22 09:00:04 -07:00
parent 9c745057b4
commit eca22edda3
3 changed files with 120 additions and 31 deletions

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
#include "util/lbool.h"
#include "util/flet.h"
#include "util/sstream.h"
#include "util/scoped_map.h"
#include "kernel/type_checker.h"
#include "kernel/expr_maps.h"
#include "kernel/instantiate.h"
@ -33,6 +34,8 @@ add_cnstr_fn mk_no_contranint_fn() {
/** \brief Auxiliary functional object used to implement type checker. */
struct type_checker::imp {
typedef scoped_map<expr, expr, expr_hash, is_bi_equal_proc> cache;
/** \brief Interface type_checker <-> converter */
class converter_context : public converter::context {
imp & m_imp;
@ -64,16 +67,63 @@ struct type_checker::imp {
// Examples:
// The type of (lambda x : A, t) is (Pi x : A, typeof(t))
// The type of (lambda {x : A}, t) is (Pi {x : A}, typeof(t))
expr_bi_struct_map<expr> m_infer_type_cache[2];
cache m_infer_type_cache[2];
converter_context m_conv_ctx;
type_checker_context m_tc_ctx;
bool m_memoize;
// temp flag
level_param_names m_params;
buffer<constraint> m_cs; // temporary cache of constraints
bool m_cache_cs; // true if we should cache the constraints; false if we should send to m_add_cnstr_fn
// Auxiliary object used to restore cache and filter constraints
// when a failure occurs in the type checker.
// That is, we should not keep cached results, and we should not sent constraints
// when a failure occurs.
struct scope {
imp & m_imp;
unsigned m_old_cs_size;
bool m_old_cache_cs;
bool m_keep;
scope(imp & i):m_imp(i), m_old_cs_size(m_imp.m_cs.size()), m_old_cache_cs(m_imp.m_cache_cs), m_keep(false) {
m_imp.m_infer_type_cache[0].push();
m_imp.m_infer_type_cache[1].push();
m_imp.m_cache_cs = true;
}
~scope() {
if (m_keep) {
// keep results
m_imp.m_infer_type_cache[0].keep();
m_imp.m_infer_type_cache[1].keep();
} else {
// restore caches
m_imp.m_infer_type_cache[0].pop();
m_imp.m_infer_type_cache[1].pop();
m_imp.m_cs.shrink(m_old_cs_size);
}
m_imp.m_cache_cs = m_old_cache_cs;
}
void keep() {
m_keep = true;
if (!m_old_cache_cs) {
lean_assert(m_old_cs_size == 0);
// send results to m_add_cnstr_fn
try {
for (auto const & c : m_imp.m_cs)
m_imp.m_add_cnstr_fn(c);
} catch (...) {
m_imp.m_cs.clear();
throw;
}
m_imp.m_cs.clear();
}
}
};
imp(environment const & env, name_generator const & g, add_cnstr_fn const & h, std::unique_ptr<converter> && conv, bool memoize):
m_env(env), m_gen(g), m_add_cnstr_fn(h), m_conv(std::move(conv)), m_conv_ctx(*this), m_tc_ctx(*this),
m_memoize(memoize) {}
m_memoize(memoize), m_cache_cs(false) {
}
optional<expr> expand_macro(expr const & m) {
lean_assert(is_macro(m));
@ -91,17 +141,10 @@ struct type_checker::imp {
/** \brief Add given constraint using m_add_cnstr_fn. */
void add_cnstr(constraint const & c) {
m_add_cnstr_fn(c);
}
/** \brief Return true iff \c t and \c s are definitionally equal */
bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst) {
return m_conv->is_def_eq(t, s, m_conv_ctx, jst);
}
/** \brief Return true iff \c e is a proposition */
bool is_prop(expr const & e) {
return whnf(infer_type(e)) == Bool;
if (m_cache_cs)
m_cs.push_back(c);
else
m_add_cnstr_fn(c);
}
/**
@ -157,7 +200,7 @@ struct type_checker::imp {
\remark \c s is used to extract position (line number information) when an
error message is produced
*/
expr ensure_sort(expr e, expr const & s) {
expr ensure_sort_core(expr e, expr const & s) {
if (is_sort(e))
return e;
e = whnf(e);
@ -181,7 +224,7 @@ struct type_checker::imp {
}
/** \brief Similar to \c ensure_sort, but makes sure \c e "is" a Pi. */
expr ensure_pi(expr e, expr const & s) {
expr ensure_pi_core(expr e, expr const & s) {
if (is_pi(e))
return e;
e = whnf(e);
@ -339,16 +382,16 @@ struct type_checker::imp {
case expr_kind::Lambda: {
if (!infer_only) {
expr t = infer_type_core(binding_domain(e), infer_only);
ensure_sort(t, binding_domain(e));
ensure_sort_core(t, binding_domain(e));
}
auto b = open_binding_body(e);
r = mk_pi(binding_name(e), binding_domain(e), abstract_local(infer_type_core(b.first, infer_only), b.second), binding_info(e));
break;
}
case expr_kind::Pi: {
expr t1 = ensure_sort(infer_type_core(binding_domain(e), infer_only), binding_domain(e));
expr t1 = ensure_sort_core(infer_type_core(binding_domain(e), infer_only), binding_domain(e));
auto b = open_binding_body(e);
expr t2 = ensure_sort(infer_type_core(b.first, infer_only), binding_body(e));
expr t2 = ensure_sort_core(infer_type_core(b.first, infer_only), binding_body(e));
if (m_env.impredicative())
r = mk_sort(mk_imax(sort_level(t1), sort_level(t2)));
else
@ -356,7 +399,7 @@ struct type_checker::imp {
break;
}
case expr_kind::App: {
expr f_type = ensure_pi(infer_type_core(app_fn(e), infer_only), app_fn(e));
expr f_type = ensure_pi_core(infer_type_core(app_fn(e), infer_only), app_fn(e));
if (!infer_only) {
expr a_type = infer_type_core(app_arg(e), infer_only);
app_delayed_jst jst(m_env, e, f_type, a_type);
@ -373,7 +416,7 @@ struct type_checker::imp {
}
case expr_kind::Let:
if (!infer_only) {
ensure_sort(infer_type_core(let_type(e), infer_only), let_type(e));
ensure_sort_core(infer_type_core(let_type(e), infer_only), let_type(e));
expr val_type = infer_type_core(let_value(e), infer_only);
simple_delayed_justification jst([=]() { return mk_let_mismatch_jst(e, val_type); });
if (!is_def_eq(val_type, let_type(e), jst)) {
@ -394,16 +437,55 @@ struct type_checker::imp {
return r;
}
expr infer_type(expr const & e) { return infer_type_core(e, true); }
expr check(expr const & e, level_param_names const & ps) {
flet<level_param_names> updt(m_params, ps);
return infer_type_core(e, false);
expr infer_type(expr const & e) {
scope mk_scope(*this);
expr r = infer_type_core(e, true);
mk_scope.keep();
return r;
}
expr check(expr const & e, level_param_names const & ps) {
scope mk_scope(*this);
flet<level_param_names> updt(m_params, ps);
expr r = infer_type_core(e, false);
mk_scope.keep();
return r;
}
expr ensure_sort(expr const & e, expr const & s) {
scope mk_scope(*this);
expr r = ensure_sort_core(e, s);
mk_scope.keep();
return r;
}
expr ensure_pi(expr const & e, expr const & s) {
scope mk_scope(*this);
expr r = ensure_pi_core(e, s);
mk_scope.keep();
return r;
}
/** \brief Return true iff \c t and \c s are definitionally equal */
bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst) {
scope mk_scope(*this);
bool r = m_conv->is_def_eq(t, s, m_conv_ctx, jst);
if (r) mk_scope.keep();
return r;
}
bool is_def_eq(expr const & t, expr const & s) {
scope mk_scope(*this);
bool r = m_conv->is_def_eq(t, s, m_conv_ctx);
if (r) mk_scope.keep();
return r;
}
bool is_def_eq(expr const & t, expr const & s) { return m_conv->is_def_eq(t, s, m_conv_ctx); }
bool is_def_eq(expr const & t, expr const & s, justification const & j) {
as_delayed_justification djst(j);
return is_def_eq(t, s, djst);
}
/** \brief Return true iff \c e is a proposition */
bool is_prop(expr const & e) {
scope mk_scope(*this);
bool r = whnf(infer_type(e)) == Bool;
if (r) mk_scope.keep();
return r;
}
expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); }
};

View file

@ -5,11 +5,16 @@ print(env:normalize(Fun(a, m)))
print(env:normalize(Fun(a, m(a))))
local m2 = mk_metavar("m2", mk_arrow(Bool, Bool, Bool))
print(env:normalize(Fun(a, (m2(a))(a))))
print("step1")
env:type_check(m)
print("step2")
env:type_check(Fun(a, m(a)))
print("step3")
env:type_check(Fun(a, (m2(a))(a)))
local m3 = mk_metavar("m3", mk_metavar("m4", mk_sort(mk_meta_univ("l"))))
print("step4")
env:type_check(m3)
print("step5")
-- The following call fails, because the type checker will try to
-- create a constraint, but constraint generation is not supported by
-- the type checker used to implement the method type_check
@ -17,7 +22,4 @@ assert(not pcall(function()
env:type_check(m3(a))
end
))
print("before end")

View file

@ -7,14 +7,19 @@ local t = Fun(a, Bool, a)
local b = Const("b")
print(t(b))
assert(tc:whnf(t(b)) == b)
local cs = {}
local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end)
assert(tc:check(Bool) == mk_sort(mk_level_one()))
print(tc:infer(t))
local m = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("u"))))
print(tc:infer(m))
local cs = {}
local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end)
local t2 = Fun(a, Bool, m(a))
print("---------")
print("t2: ")
print(t2)
print("check(t): ")
print(tc2:check(t))
print("check(t2): ")
print(tc2:check(t2))
assert(#cs == 2)