From 81c9de229b8d633f6a7bc964da42b82258697bf2 Mon Sep 17 00:00:00 2001 From: Soonho Kong Date: Mon, 23 Sep 2013 18:53:39 -0700 Subject: [PATCH] Add then and orelse rewrite combinators and tests --- src/library/rewrite/fo_match.cpp | 106 +++-------- src/library/rewrite/fo_match.h | 2 +- src/library/rewrite/rewrite.cpp | 105 +++++++---- src/library/rewrite/rewrite.h | 49 ++++- src/tests/library/rewrite/rewrite.cpp | 247 +++++++++++++++++++++++--- 5 files changed, 366 insertions(+), 143 deletions(-) diff --git a/src/library/rewrite/fo_match.cpp b/src/library/rewrite/fo_match.cpp index 6c07d77ce..4f4583f55 100644 --- a/src/library/rewrite/fo_match.cpp +++ b/src/library/rewrite/fo_match.cpp @@ -5,6 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Soonho Kong */ #include +#include "util/trace.h" #include "kernel/expr.h" #include "kernel/context.h" #include "library/all/all.h" @@ -30,64 +31,45 @@ std::ostream & operator<<(std::ostream & out, subst_map & s) { } bool fo_match::match_var(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_var : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; - + lean_trace("fo_match", tout << "match_var : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); unsigned idx = var_idx(p); if (idx < o) { // Current variable is the one created by lambda inside of pattern - // and it is not a target of pattern matching. + // and it is *not* a target of pattern matching. return p == t; } else { - auto it = s.find(p); + auto it = s.find(idx); if (it != s.end()) { // This variable already has an entry in the substitution // map. We need to make sure that 't' and s[idx] are the // same - cout << "match_var exist:" << p << " |-> " << it->second << endl; + lean_trace("fo_match", tout << "match_var exist:" << idx << " |-> " << it->second << endl;); return it->second == t; } // This variable has no entry in the substituition map. Let's // add one. - s.insert(std::make_pair(p, t)); - cout << "match_var MATCHED : " << s << endl; + s.insert(std::make_pair(idx, t)); + lean_trace("fo_match", tout << "match_var MATCHED : " << s << endl;); return true; } } bool fo_match::match_constant(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_constant : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; + lean_trace("fo_match", tout << "match_constant : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); return p == t; } bool fo_match::match_value(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_value : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; + lean_trace("fo_match", tout << "match_value : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); return p == t; } bool fo_match::match_app(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_app : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; + lean_trace("fo_match", tout << "match_app : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + if (!is_app(t)) + return false; unsigned num_p = num_args(p); - unsigned num_t = num_args(p); + unsigned num_t = num_args(t); if (num_p != num_t) { return false; } @@ -100,15 +82,8 @@ bool fo_match::match_app(expr const & p, expr const & t, unsigned o, subst_map & } bool fo_match::match_lambda(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_lambda : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; - cout << "fun (" << abst_name(p) - << " : " << abst_domain(p) - << "), " << abst_body(p) << endl; + lean_trace("fo_match", tout << "match_lambda : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + lean_trace("fo_match", tout << "fun (" << abst_name(p) << " : " << abst_domain(p) << "), " << abst_body(p) << endl;); if (!is_lambda(t)) { return false; } else { @@ -126,16 +101,8 @@ bool fo_match::match_lambda(expr const & p, expr const & t, unsigned o, subst_ma } bool fo_match::match_pi(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_pi : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; - cout << "Pi (" << abst_name(p) - << " : " << abst_domain(p) - << "), " << abst_body(p) << endl; - + lean_trace("fo_match", tout << "match_pi : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + lean_trace("fo_match", tout << "Pi (" << abst_name(p) << " : " << abst_domain(p) << "), " << abst_body(p) << endl;); if (!is_pi(t)) { return false; } else { @@ -153,33 +120,19 @@ bool fo_match::match_pi(expr const & p, expr const & t, unsigned o, subst_map & } bool fo_match::match_type(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_type : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; + lean_trace("fo_match", tout << "match_type : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); return p == t; } bool fo_match::match_eq(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_eq : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; + lean_trace("fo_match", tout << "match_eq : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + if (!is_eq(t)) + return false; return match(eq_lhs(p), eq_lhs(t), o, s) && match(eq_rhs(p), eq_rhs(t), o, s); } bool fo_match::match_let(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_let : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; - + lean_trace("fo_match", tout << "match_let : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); if (!is_let(t)) { return false; } else { @@ -202,23 +155,12 @@ bool fo_match::match_let(expr const & p, expr const & t, unsigned o, subst_map & } } bool fo_match::match_metavar(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match_meta : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; + lean_trace("fo_match", tout << "match_meta : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); return p == t; } bool fo_match::match(expr const & p, expr const & t, unsigned o, subst_map & s) { - cout << "match : (" - << p << ", " - << t << ", " - << o << ", " - << s << ")" - << endl; - + lean_trace("fo_match", tout << "match : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); switch (p.kind()) { case expr_kind::Var: return match_var(p, t, o, s); diff --git a/src/library/rewrite/fo_match.h b/src/library/rewrite/fo_match.h index 3fd09934d..404b61285 100644 --- a/src/library/rewrite/fo_match.h +++ b/src/library/rewrite/fo_match.h @@ -17,7 +17,7 @@ Author: Soonho Kong namespace lean { -typedef expr_map subst_map; +using subst_map = std::unordered_map; class fo_match { private: diff --git a/src/library/rewrite/rewrite.cpp b/src/library/rewrite/rewrite.cpp index fbe858121..16db39870 100644 --- a/src/library/rewrite/rewrite.cpp +++ b/src/library/rewrite/rewrite.cpp @@ -1,13 +1,16 @@ /* -Copyright (c) 2013 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. + Copyright (c) 2013 Microsoft Corporation. All rights reserved. + Released under Apache 2.0 license as described in the file LICENSE. -Author: Soonho Kong + Author: Soonho Kong */ +#include "util/trace.h" #include "kernel/abstract.h" #include "kernel/builtin.h" #include "kernel/context.h" #include "kernel/environment.h" +#include "kernel/replace.h" +#include "library/basic_thms.h" #include "library/printer.h" #include "library/rewrite/fo_match.h" #include "library/rewrite/rewrite.h" @@ -27,51 +30,85 @@ Author: Soonho Kong using std::cout; using std::endl; +using std::pair; +using std::make_pair; namespace lean { -theorem_rw::theorem_rw(expr const & t, expr const & v) - : thm_t(t), thm_v(v), num_args(0) { - cout << "================= Theorem Rewrite Constructor Start ===================" << endl; - cout << "Type = " << thm_t << endl; - cout << "Body = " << thm_v << endl; +theorem_rewrite::theorem_rewrite(expr const & type, expr const & body) + : thm_type(type), thm_body(body), num_args(0) { + lean_trace("rewrite", tout << "Type = " << thm_type << endl;); + lean_trace("rewrite", tout << "Body = " << thm_body << endl;); // We expect the theorem is in the form of - // Pi (x_1 : t_1 ... x_n : t_n), t = s - expr tmp = t; - while (is_pi(tmp)) { - tmp = abst_body(tmp); + // Pi (x_1 : t_1 ... x_n : t_n), pattern = rhs + pattern = type; + while (is_pi(pattern)) { + pattern = abst_body(pattern); num_args++; } - if (!is_eq(tmp)) { - cout << "Theorem " << t << " is not in the form of " - << "Pi (x_1 : t_1 ... x_n : t_n), t = s" << endl; + if (!is_eq(pattern)) { + lean_trace("rewrite", tout << "Theorem " << thm_type << " is not in the form of " + << "Pi (x_1 : t_1 ... x_n : t_n), pattern = rhs" << endl;); } - cout << "OK. Number of Arg = " << num_args << endl; - cout << "================= Theorem Rewrite Constructor END ===================" << endl; + rhs = eq_rhs(pattern); + pattern = eq_lhs(pattern); + + lean_trace("rewrite", tout << "Number of Arg = " << num_args << endl;); } -void theorem_rw::operator()(context & ctx, expr const & t) { - cout << "================= Theorem Rewrite () START ===================" << endl; - cout << "Context = " << ctx << endl; - cout << "Term = " << t << endl; - expr tmp = thm_t; - while (is_pi(tmp)) { - tmp = abst_body(tmp); - num_args++; - } - if (!is_eq(tmp)) { - cout << "Theorem " << t << " is not in the form of " - << "Pi (x_1 : t_1 ... x_n : t_n), t = s" << endl; - } - expr const & lhs = eq_lhs(tmp); - expr const & rhs = eq_rhs(tmp); +pair theorem_rewrite::operator()(context & ctx, expr const & v, expr const & ) const throw(rewrite_exception) { + lean_trace("rewrite", tout << "Context = " << ctx << endl;); + lean_trace("rewrite", tout << "Term = " << v << endl;); + lean_trace("rewrite", tout << "Pattern = " << pattern << endl;); + lean_trace("rewrite", tout << "Num Args = " << num_args << endl;); + fo_match fm; subst_map s; - fm.match(lhs, t, 0, s); + if (!fm.match(pattern, v, 0, s)) { + throw rewrite_exception(); + } + // apply s to rhs + auto f = [&s](expr const & e, unsigned offset) -> expr { + if (is_var(e)) { + lean_trace("rewrite", tout << "Inside of apply : offset = " << offset + << ", e = " << e + << ", idx = " << var_idx(e) << endl;); + unsigned idx = var_idx(e); + auto it = s.find(idx); + if (it != s.end()) { + lean_trace("rewrite", tout << "Inside of apply : s[" << idx << "] = " << s[idx] << endl;); + return s[idx]; + } + } + return e; + }; - cout << "================= Theorem Rewrite () END ===================" << endl; + expr new_rhs = replace_fn(f)(rhs); + lean_trace("rewrite", tout << "New RHS = " << new_rhs << endl;); + + expr proof = thm_body; + for (int i = num_args -1 ; i >= 0; i--) { + proof = mk_app(proof, s[i]); + lean_trace("rewrite", tout << "proof: " << i << "\t" << s[i] << "\t" << proof << endl;); + } + lean_trace("rewrite", tout << "Proof = " << proof << endl;); + return make_pair(new_rhs, proof); } +pair orelse_rewrite::operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception) { + try { + return rewrite1(ctx, v, t); + } catch (rewrite_exception & ) { + return rewrite2(ctx, v, t); + } +} + +pair then_rewrite::operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception) { + pair result1 = rewrite1(ctx, v, t); + pair result2 = rewrite2(ctx, result1.first, t); + return make_pair(result2.first, + Trans(t, v, result1.first, result2.first, result1.second, result2.second)); +} } diff --git a/src/library/rewrite/rewrite.h b/src/library/rewrite/rewrite.h index cb81d63f6..8f35ca8ac 100644 --- a/src/library/rewrite/rewrite.h +++ b/src/library/rewrite/rewrite.h @@ -5,17 +5,56 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Soonho Kong */ #pragma once +#include +#include "util/exception.h" namespace lean { -class theorem_rw { +class rewrite_exception : public exception { +}; + +class rewrite { +public: + virtual std::pair operator()(context & ctx, expr const & v, expr const & t) const = 0; +}; + +class theorem_rewrite : public rewrite { private: - expr const & thm_t; - expr const & thm_v; + expr const & thm_type; + expr const & thm_body; + expr pattern; + expr rhs; unsigned num_args; public: - theorem_rw(expr const & t, expr const & v); - void operator()(context & ctx, expr const & t); + theorem_rewrite(expr const & type, expr const & body); + std::pair operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception); +}; + +class orelse_rewrite : public rewrite { +private: + rewrite const & rewrite1; + rewrite const & rewrite2; +public: + orelse_rewrite(rewrite const & rw1, rewrite const & rw2) : + rewrite1(rw1), rewrite2(rw2) { } + std::pair operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception); +}; + +class then_rewrite : public rewrite { +private: + rewrite const & rewrite1; + rewrite const & rewrite2; +public: + then_rewrite(rewrite const & rw1, rewrite const & rw2) : + rewrite1(rw1), rewrite2(rw2) { } + std::pair operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception); +}; + +class fail_rewrite : public rewrite { +public: + std::pair operator()(context &, expr const &) const throw(rewrite_exception) { + throw rewrite_exception(); + } }; } diff --git a/src/tests/library/rewrite/rewrite.cpp b/src/tests/library/rewrite/rewrite.cpp index 78dbd269d..640e7f60a 100644 --- a/src/tests/library/rewrite/rewrite.cpp +++ b/src/tests/library/rewrite/rewrite.cpp @@ -4,42 +4,247 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Soonho Kong */ +#include "util/trace.h" #include "kernel/abstract.h" #include "kernel/context.h" #include "kernel/expr.h" +#include "kernel/type_checker.h" #include "library/all/all.h" #include "library/arith/arith.h" #include "library/arith/nat.h" #include "library/rewrite/fo_match.h" #include "library/rewrite/rewrite.h" +#include "library/basic_thms.h" #include "library/printer.h" using namespace lean; using std::cout; +using std::pair; using std::endl; -int main() { +static void theorem_rewrite1_tst() { + cout << "=== theorem_rewrite1_tst() ===" << endl; + // Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y + // Term : a + b + // Result : (b + a, ADD_COMM a b) + expr a = Const("a"); // a : Nat + expr b = Const("b"); // b : Nat + expr a_plus_b = nAdd(a, b); + expr b_plus_a = nAdd(b, a); + expr add_comm_thm_type = Pi("x", Nat, + Pi("y", Nat, + Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); + expr add_comm_thm_body = Const("ADD_COMM"); + environment env = mk_toplevel(); - env.add_var("x", Nat); - expr x = Const("x"); // x : Nat - expr y = Const("y"); // y : Nat - expr zero = nVal(0); // 0 : Nat - expr x_plus_zero = nAdd(x, zero); // x_plus_zero := x + 0 - expr y_plus_zero = nAdd(y, zero); // y_plus_zero := y + 0 - cout << "x := " << x << endl; - cout << "y := " << y << endl; - cout << "x + 0 := " << x_plus_zero << endl; - cout << "x + 0 := " << y_plus_zero << endl; - //env.display(cout); + env.add_var("a", Nat); + env.add_var("b", Nat); + env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z - expr thm_t = Pi("x", Nat, Eq(nAdd(Const("x"), nVal(0)), Const("x"))); // Pi (x : Z), x + 0 = x - cout << "theorem := " << thm_t << endl; - env.add_axiom("H1", thm_t); // H1 : Pi (x : N), x = x + 0 - expr thm_v = Const("H1"); - cout << "axiom := " << thm_v << endl; - - theorem_rw trw(thm_t, thm_v); + // Rewriting + theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body); context ctx; - trw(ctx, y_plus_zero); - return 0; + pair result = add_comm_thm_rewriter(ctx, a_plus_b, Nat); + expr concl = mk_eq(a_plus_b, result.first); + expr proof = result.second; + + cout << "Theorem: " << add_comm_thm_type << " := " << add_comm_thm_body << endl; + cout << " " << concl << " := " << proof << endl; + + lean_assert(concl == mk_eq(a_plus_b, b_plus_a)); + lean_assert(proof == mk_app(mk_app(Const("ADD_COMM"), a), b)); + env.add_theorem("New_theorem1", concl, proof); +} + +static void theorem_rewrite2_tst() { + cout << "=== theorem_rewrite2_tst() ===" << endl; + // Theorem: Pi(x : N), x + 0 = x := ADD_ID x + // Term : a + 0 + // Result : (a, ADD_ID a) + expr a = Const("a"); // a : at + expr zero = nVal(0); // zero : Nat + expr a_plus_zero = nAdd(a, zero); + expr add_id_thm_type = Pi("x", Nat, + Eq(nAdd(Const("x"), zero), Const("x"))); + expr add_id_thm_body = Const("ADD_ID"); + + environment env = mk_toplevel(); + env.add_var("a", Nat); + env.add_axiom("ADD_ID", add_id_thm_type); // ADD_ID : Pi (x : N), x = x + 0 + + // Rewriting + theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body); + context ctx; + pair result = add_id_thm_rewriter(ctx, a_plus_zero, Nat); + expr concl = mk_eq(a_plus_zero, result.first); + expr proof = result.second; + + cout << "Theorem: " << add_id_thm_type << " := " << add_id_thm_body << endl; + cout << " " << concl << " := " << proof << endl; + + lean_assert(concl == mk_eq(a_plus_zero, a)); + lean_assert(proof == mk_app(Const("ADD_ID"), a)); + env.add_theorem("New_theorem2", concl, proof); +} + +static void then_rewrite1_tst() { + cout << "=== then_rewrite1_tst() ===" << endl; + // Theorem1: Pi(x y : N), x + y = y + x := ADD_COMM x y + // Theorem2: Pi(x : N) , x + 0 = x := ADD_ID x + // Term : 0 + a + // Result : (a, TRANS (ADD_COMM 0 a) (ADD_ID a)) + + expr a = Const("a"); // a : Nat + expr zero = nVal(0); // zero : Nat + expr a_plus_zero = nAdd(a, zero); + expr zero_plus_a = nAdd(zero, a); + expr add_comm_thm_type = Pi("x", Nat, + Pi("y", Nat, + Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); + expr add_comm_thm_body = Const("ADD_COMM"); + expr add_id_thm_type = Pi("x", Nat, + Eq(nAdd(Const("x"), zero), Const("x"))); + expr add_id_thm_body = Const("ADD_ID"); + + environment env = mk_toplevel(); + env.add_var("a", Nat); + env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z + env.add_axiom("ADD_ID", add_id_thm_type); // ADD_ID : Pi (x : N), x = x + 0 + + // Rewriting + theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body); + theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body); + then_rewrite then_rewriter1(add_comm_thm_rewriter, add_id_thm_rewriter); + context ctx; + pair result = then_rewriter1(ctx, zero_plus_a, Nat); + expr concl = mk_eq(zero_plus_a, result.first); + expr proof = result.second; + + cout << "Theorem: " << add_comm_thm_type << " := " << add_comm_thm_body << endl; + cout << "Theorem: " << add_id_thm_type << " := " << add_id_thm_body << endl; + cout << " " << concl << " := " << proof << endl; + + lean_assert(concl == mk_eq(zero_plus_a, a)); + lean_assert(proof == Trans(Nat, zero_plus_a, a_plus_zero, a, + mk_app(mk_app(Const("ADD_COMM"), zero), a), mk_app(Const("ADD_ID"), a))); + + env.add_theorem("New_theorem3", concl, proof); +} + +static void then_rewrite2_tst() { + cout << "=== then_rewrite2_tst() ===" << endl; + // Theorem1: Pi(x y z: N), x + (y + z) = (x + y) + z := ADD_ASSOC x y z + // Theorem2: Pi(x y : N), x + y = y + x := ADD_COMM x y + // Theorem3: Pi(x : N), x + 0 = x := ADD_ID x + // Term : 0 + (a + 0) + // Result : (a, TRANS (ADD_ASSOC 0 a 0) // (0 + a) + 0 + // (ADD_ID (0 + a)) // 0 + a + // (ADD_COMM 0 a) // a + 0 + // (ADD_ID a)) // a + + expr a = Const("a"); // a : Nat + expr zero = nVal(0); // zero : Nat + expr zero_plus_a = nAdd(zero, a); + expr a_plus_zero = nAdd(a, zero); + expr zero_plus_a_plus_zero = nAdd(zero, nAdd(a, zero)); + expr zero_plus_a_plus_zero_ = nAdd(nAdd(zero, a), zero); + expr add_assoc_thm_type = Pi("x", Nat, + Pi("y", Nat, + Pi("z", Nat, + Eq(nAdd(Const("x"), nAdd(Const("y"), Const("z"))), + nAdd(nAdd(Const("x"), Const("y")), Const("z")))))); + expr add_assoc_thm_body = Const("ADD_ASSOC"); + expr add_comm_thm_type = Pi("x", Nat, + Pi("y", Nat, + Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); + expr add_comm_thm_body = Const("ADD_COMM"); + expr add_id_thm_type = Pi("x", Nat, + Eq(nAdd(Const("x"), zero), Const("x"))); + expr add_id_thm_body = Const("ADD_ID"); + + environment env = mk_toplevel(); + env.add_var("a", Nat); + env.add_axiom("ADD_ASSOC", add_assoc_thm_type); // ADD_ASSOC : Pi (x, y, z : N), x + (y + z) = (x + y) + z + env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z + env.add_axiom("ADD_ID", add_id_thm_type); // ADD_ID : Pi (x : N), x = x + 0 + + // Rewriting + theorem_rewrite add_assoc_thm_rewriter(add_assoc_thm_type, add_assoc_thm_body); + theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body); + theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body); + then_rewrite then_rewriter2(then_rewrite(add_assoc_thm_rewriter, add_id_thm_rewriter), + then_rewrite(add_comm_thm_rewriter, add_id_thm_rewriter)); + context ctx; + pair result = then_rewriter2(ctx, zero_plus_a_plus_zero, Nat); + expr concl = mk_eq(zero_plus_a_plus_zero, result.first); + expr proof = result.second; + cout << "Theorem: " << add_assoc_thm_type << " := " << add_assoc_thm_body << endl; + cout << "Theorem: " << add_comm_thm_type << " := " << add_comm_thm_body << endl; + cout << "Theorem: " << add_id_thm_type << " := " << add_id_thm_body << endl; + cout << " " << concl << " := " << proof << endl; + + lean_assert(concl == mk_eq(zero_plus_a_plus_zero, a)); + lean_assert(proof == Trans(Nat, zero_plus_a_plus_zero, zero_plus_a, a, + Trans(Nat, zero_plus_a_plus_zero, zero_plus_a_plus_zero_, zero_plus_a, + mk_app(mk_app(mk_app(Const("ADD_ASSOC"), zero), a), zero), + mk_app(Const("ADD_ID"), zero_plus_a)), + Trans(Nat, zero_plus_a, a_plus_zero, a, + mk_app(mk_app(Const("ADD_COMM"), zero), a), + mk_app(Const("ADD_ID"), a)))); + + env.add_theorem("New_theorem4", concl, proof); +} + + +static void orelse_rewrite1_tst() { + cout << "=== orelse_rewrite1_tst() ===" << endl; + // Theorem1: Pi(x y z: N), x + (y + z) = (x + y) + z := ADD_ASSOC x y z + // Theorem2: Pi(x y : N), x + y = y + x := ADD_COMM x y + // Term : a + b + // Result : (b + a, ADD_COMM a b) + expr a = Const("a"); // a : Nat + expr b = Const("b"); // b : Nat + expr a_plus_b = nAdd(a, b); + expr b_plus_a = nAdd(b, a); + expr add_assoc_thm_type = Pi("x", Nat, + Pi("y", Nat, + Pi("z", Nat, + Eq(nAdd(Const("x"), nAdd(Const("y"), Const("z"))), + nAdd(nAdd(Const("x"), Const("y")), Const("z")))))); + expr add_assoc_thm_body = Const("ADD_ASSOC"); + expr add_comm_thm_type = Pi("x", Nat, + Pi("y", Nat, + Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); + expr add_comm_thm_body = Const("ADD_COMM"); + + environment env = mk_toplevel(); + env.add_var("a", Nat); + env.add_var("b", Nat); + env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z + + // Rewriting + theorem_rewrite add_assoc_thm_rewriter(add_assoc_thm_type, add_assoc_thm_body); + theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body); + orelse_rewrite add_assoc_or_comm_thm_rewriter(add_assoc_thm_rewriter, add_comm_thm_rewriter); + context ctx; + pair result = add_assoc_or_comm_thm_rewriter(ctx, a_plus_b, Nat); + expr concl = mk_eq(a_plus_b, result.first); + expr proof = result.second; + + cout << "Theorem: " << add_assoc_thm_type << " := " << add_assoc_thm_body << endl; + cout << "Theorem: " << add_comm_thm_type << " := " << add_comm_thm_body << endl; + cout << " " << concl << " := " << proof << endl; + + lean_assert(concl == mk_eq(a_plus_b, b_plus_a)); + lean_assert(proof == mk_app(mk_app(Const("ADD_COMM"), a), b)); + env.add_theorem("New_theorem5", concl, proof); +} + +int main() { + theorem_rewrite1_tst(); + theorem_rewrite2_tst(); + then_rewrite1_tst(); + then_rewrite2_tst(); + orelse_rewrite1_tst(); + return has_violations() ? 1 : 0; }