Add simplification rule to add_inst

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-09-17 02:57:28 -07:00
parent 21c7a45f67
commit 99e8d2feae
2 changed files with 38 additions and 5 deletions

View file

@ -119,7 +119,10 @@ expr instantiate_metavars(expr const & e, metavar_env const & env) {
}
meta_ctx add_lift(meta_ctx const & ctx, unsigned s, unsigned n) {
return cons(mk_lift(s, n), ctx);
if (n == 0)
return ctx;
else
return cons(mk_lift(s, n), ctx);
}
expr add_lift(expr const & m, unsigned s, unsigned n) {
@ -130,10 +133,15 @@ meta_ctx add_inst(meta_ctx const & ctx, unsigned s, expr const & v) {
if (ctx) {
meta_entry e = head(ctx);
if (e.is_lift() && e.s() <= s && s < e.s() + e.n()) {
if (e.n() == 1)
return tail(ctx);
else
return add_lift(tail(ctx), e.s(), e.n() - 1);
return add_lift(tail(ctx), e.s(), e.n() - 1);
}
// Simplifications such as
// inst:4 #6 lift:5:3 --> lift:4:2
// inst:3 #7 lift:4:5 --> lift:3:4
// General rule is:
// inst:(s-1) #(s+n-2) lift:s:n --> lift:s-1:n-1
if (e.is_lift() && is_var(v) && e.s() > 0 && s == e.s() - 1 && e.s() + e.n() > 2 && var_idx(v) == e.s() + e.n() - 2) {
return add_lift(tail(ctx), e.s() - 1, e.n() - 1);
}
}
return cons(mk_inst(s, v), ctx);

View file

@ -406,6 +406,30 @@ static void tst19() {
expr F = Fun({{N, Type()}, {x, N}, {y, N}}, m1);
std::cout << norm(F) << "\n";
std::cout << norm(F, ctx) << "\n";
lean_assert(norm(F) == F);
lean_assert(norm(F, ctx) == F);
}
static void tst20() {
environment env;
metavar_env menv;
normalizer norm(env);
context ctx;
ctx = extend(ctx, "w1", Type());
ctx = extend(ctx, "w2", Type());
expr m1 = menv.mk_metavar();
expr x = Const("x");
expr y = Const("y");
expr z = Const("z");
expr N = Const("N");
expr a = Const("a");
expr b = Const("b");
env.add_var("N", Type());
env.add_var("a", N);
env.add_var("b", N);
expr F = Fun({{x, N}, {y, N}, {z, N}}, Fun({{x, N}, {y, N}}, m1)(a, b));
std::cout << norm(F) << "\n";
std::cout << norm(F, ctx) << "\n";
}
int main() {
@ -428,5 +452,6 @@ int main() {
tst17();
tst18();
tst19();
tst20();
return has_violations() ? 1 : 0;
}