test(library/rewriter): add lambda_rewrite tests

This commit is contained in:
Soonho Kong 2013-12-01 15:21:55 -05:00
parent 75f8d56eb1
commit 0553d29078

View file

@ -638,6 +638,99 @@ static void depth_rewriter1_tst() {
cout << "====================================================" << std::endl; cout << "====================================================" << std::endl;
} }
static void lambda_rewriter1_tst() {
cout << "=== lambda_rewriter1_tst() ===" << std::endl;
// Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y
// Term : f (a + b)
// Result : (f (b + a), ADD_COMM a b)
expr a = Const("a"); // a : Nat
expr b = Const("b"); // b : Nat
expr f1 = Const("f1"); // f : Nat -> Nat
expr f2 = Const("f2"); // f : Nat -> Nat -> Nat
expr f3 = Const("f3"); // f : Nat -> Nat -> Nat -> Nat
expr f4 = Const("f4"); // f : Nat -> Nat -> Nat -> Nat -> Nat
expr zero = nVal(0); // zero : Nat
expr a_plus_b = nAdd(a, b);
expr b_plus_a = nAdd(b, a);
expr add_comm_thm_type = Pi("x", Nat,
Pi("y", Nat,
Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x")))));
expr add_comm_thm_body = Const("ADD_COMM");
environment env = mk_toplevel();
env.add_var("f1", Nat >> Nat);
env.add_var("f2", Nat >> (Nat >> Nat));
env.add_var("f3", Nat >> (Nat >> (Nat >> Nat)));
env.add_var("f4", Nat >> (Nat >> (Nat >> (Nat >> Nat))));
env.add_var("a", Nat);
env.add_var("b", Nat);
env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z
// Rewriting
rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body);
rewriter lambda_rewriter = mk_lambda_body_rewriter(add_comm_thm_rewriter);
context ctx;
cout << "RW = " << lambda_rewriter << std::endl;
expr v = mk_lambda("x", Nat, nAdd(b, a));
pair<expr, expr> result = lambda_rewriter(env, ctx, v);
expr concl = mk_eq(v, result.first);
expr proof = result.second;
cout << "Concl = " << concl << std::endl
<< "Proof = " << proof << std::endl;
lean_assert_eq(concl, mk_eq(v, mk_lambda("x", Nat, nAdd(a, b))));
env.add_theorem("lambda_rewriter1", concl, proof);
cout << "====================================================" << std::endl;
}
static void lambda_rewriter2_tst() {
cout << "=== lambda_rewriter2_tst() ===" << std::endl;
// Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y
// Term : f (a + b)
// Result : (f (b + a), ADD_COMM a b)
expr a = Const("a"); // a : Nat
expr b = Const("b"); // b : Nat
expr f1 = Const("f1"); // f : Nat -> Nat
expr f2 = Const("f2"); // f : Nat -> Nat -> Nat
expr f3 = Const("f3"); // f : Nat -> Nat -> Nat -> Nat
expr f4 = Const("f4"); // f : Nat -> Nat -> Nat -> Nat -> Nat
expr zero = nVal(0); // zero : Nat
expr a_plus_b = nAdd(a, b);
expr b_plus_a = nAdd(b, a);
expr add_comm_thm_type = Pi("x", Nat,
Pi("y", Nat,
Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x")))));
expr add_comm_thm_body = Const("ADD_COMM");
environment env = mk_toplevel();
env.add_var("f1", Nat >> Nat);
env.add_var("f2", Nat >> (Nat >> Nat));
env.add_var("f3", Nat >> (Nat >> (Nat >> Nat)));
env.add_var("f4", Nat >> (Nat >> (Nat >> (Nat >> Nat))));
env.add_var("a", Nat);
env.add_var("b", Nat);
env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z
// Rewriting
rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body);
rewriter lambda_rewriter = mk_lambda_body_rewriter(add_comm_thm_rewriter);
context ctx;
cout << "RW = " << lambda_rewriter << std::endl;
expr v = mk_lambda("x", Nat, nAdd(Var(0), a));
pair<expr, expr> result = lambda_rewriter(env, ctx, v);
expr concl = mk_eq(v, result.first);
expr proof = result.second;
cout << "Concl = " << concl << std::endl
<< "Proof = " << proof << std::endl;
lean_assert_eq(concl, mk_eq(v, mk_lambda("x", Nat, nAdd(a, Var(0)))));
// TODO(soonhok): this one doesn't work now.
// env.add_theorem("lambda_rewriter2", concl, proof);
cout << "====================================================" << std::endl;
}
int main() { int main() {
save_stack_info(); save_stack_info();
theorem_rewriter1_tst(); theorem_rewriter1_tst();
@ -652,5 +745,7 @@ int main() {
repeat_rewriter1_tst(); repeat_rewriter1_tst();
repeat_rewriter2_tst(); repeat_rewriter2_tst();
depth_rewriter1_tst(); depth_rewriter1_tst();
lambda_rewriter1_tst();
lambda_rewriter2_tst();
return has_violations() ? 1 : 0; return has_violations() ? 1 : 0;
} }