fix(light_checker): fix inconsistent cache bug in light_checker, add tests that expose the problem

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-10-01 19:25:58 -07:00
parent aa5be3262f
commit 5bd6ba37d0
2 changed files with 41 additions and 1 deletions

View file

@ -21,6 +21,7 @@ class light_checker::imp {
typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache;
environment m_env;
context m_ctx;
metavar_env * m_menv;
unsigned m_menv_timestamp;
unification_problems * m_up;
@ -164,6 +165,13 @@ class light_checker::imp {
return r;
}
void set_ctx(context const & ctx) {
if (!is_eqp(m_ctx, ctx)) {
clear();
m_ctx = ctx;
}
}
public:
imp(environment const & env):
m_env(env),
@ -175,6 +183,7 @@ public:
}
expr operator()(expr const & e, context const & ctx, metavar_env * menv, unification_problems * up) {
set_ctx(ctx);
set_menv(menv);
flet<unification_problems*> set(m_up, up);
return infer_type(e, ctx);
@ -188,7 +197,8 @@ public:
void clear() {
m_cache.clear();
m_normalizer.clear();
m_menv = nullptr;
m_ctx = context();
m_menv = nullptr;
m_menv_timestamp = 0;
}
};

View file

@ -67,8 +67,38 @@ static void tst2() {
}
}
static void tst3() {
environment env;
import_all(env);
context ctx1, ctx2;
expr A = Const("A");
expr vec1 = Const("vec1");
expr vec2 = Const("vec2");
env.add_var("vec1", Int >> (Type() >> Type()));
env.add_var("vec2", Real >> (Type() >> Type()));
ctx1 = extend(ctx1, "x", Int, iVal(1));
ctx1 = extend(ctx1, "f", Pi({A, Int}, vec1(A, Int)));
ctx2 = extend(ctx2, "x", Real, rVal(2));
ctx2 = extend(ctx2, "f", Pi({A, Real}, vec2(A, Real)));
expr F = Var(0)(Var(1));
expr F_copy = F;
light_checker infer(env);
std::cout << infer(F, ctx1) << "\n";
lean_assert_eq(infer(F, ctx1), vec1(Var(1), Int));
lean_assert_eq(infer(F, ctx2), vec2(Var(1), Real));
lean_assert(is_eqp(infer(F, ctx2), infer(F, ctx2)));
lean_assert(is_eqp(infer(F, ctx1), infer(F, ctx1)));
expr r = infer(F, ctx1);
infer.clear();
lean_assert(!is_eqp(r, infer(F, ctx1)));
r = infer(F, ctx1);
lean_assert(is_eqp(r, infer(F, ctx1)));
}
int main() {
tst1();
tst2();
tst3();
return has_violations() ? 1 : 0;
}