diff --git a/src/library/rewrite/rewrite.cpp b/src/library/rewrite/rewrite.cpp index 16db39870..f97464938 100644 --- a/src/library/rewrite/rewrite.cpp +++ b/src/library/rewrite/rewrite.cpp @@ -14,19 +14,7 @@ #include "library/printer.h" #include "library/rewrite/fo_match.h" #include "library/rewrite/rewrite.h" - -// Term Rewriting -// ORELSE -// APP_RW -// LAMBDA_RW -// PI_RW -// LET_RW -// DEPTH_RW -// THEOREM2RW -// TRIVIAL_RW -// FORALL -// FAIL -// FAIL_IF +#include "library/light_checker.h" using std::cout; using std::endl; @@ -57,7 +45,7 @@ theorem_rewrite::theorem_rewrite(expr const & type, expr const & body) lean_trace("rewrite", tout << "Number of Arg = " << num_args << endl;); } -pair theorem_rewrite::operator()(context & ctx, expr const & v, expr const & ) const throw(rewrite_exception) { +pair theorem_rewrite::operator()(context & ctx, expr const & v, environment const & ) const throw(rewrite_exception) { lean_trace("rewrite", tout << "Context = " << ctx << endl;); lean_trace("rewrite", tout << "Term = " << v << endl;); lean_trace("rewrite", tout << "Pattern = " << pattern << endl;); @@ -97,18 +85,75 @@ pair theorem_rewrite::operator()(context & ctx, expr const & v, expr return make_pair(new_rhs, proof); } -pair orelse_rewrite::operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception) { +pair orelse_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) { try { - return rewrite1(ctx, v, t); + return rw1(ctx, v, env); } catch (rewrite_exception & ) { - return rewrite2(ctx, v, t); + return rw2(ctx, v, env); } } -pair then_rewrite::operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception) { - pair result1 = rewrite1(ctx, v, t); - pair result2 = rewrite2(ctx, result1.first, t); +pair then_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) { + pair result1 = rw1(ctx, v, env); + pair result2 = rw2(ctx, result1.first, env); + expr const & t = light_checker(env)(v, ctx); return make_pair(result2.first, Trans(t, v, result1.first, result2.first, result1.second, result2.second)); } + +pair app_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) { + if (!is_app(v)) + throw rewrite_exception(); + + unsigned n = num_args(v); + for (unsigned i = 0; i < n; i++) { + auto result = rw(ctx, arg(v, i), env); + } + + // TODO(soonhok) + throw rewrite_exception(); +} + +pair lambda_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) { + if (!is_lambda(v)) + throw rewrite_exception(); + expr const & domain = abst_domain(v); + expr const & body = abst_body(v); + + auto result_domain = rw(ctx, domain, env); + auto result_body = rw(ctx, body, env); // TODO(soonhok): add to context! + + // TODO(soonhok) + throw rewrite_exception(); +} + +pair pi_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) { + if (!is_pi(v)) + throw rewrite_exception(); + + expr const & domain = abst_domain(v); + expr const & body = abst_body(v); + + auto result_domain = rw(ctx, domain, env); + auto result_body = rw(ctx, body, env); // TODO(soonhok): add to context! + + // TODO(soonhok) + throw rewrite_exception(); +} + +pair let_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) { + if (!is_let(v)) + throw rewrite_exception(); + + expr const & ty = let_type(v); + expr const & value = let_value(v); + expr const & body = let_body(v); + + auto result_ty = rw(ctx, ty, env); + auto result_value = rw(ctx, value, env); + auto result_body = rw(ctx, body, env); // TODO(soonhok): add to context! + + // TODO(soonhok) + throw rewrite_exception(); +} } diff --git a/src/library/rewrite/rewrite.h b/src/library/rewrite/rewrite.h index 8f35ca8ac..cb04c5d5d 100644 --- a/src/library/rewrite/rewrite.h +++ b/src/library/rewrite/rewrite.h @@ -7,6 +7,18 @@ Author: Soonho Kong #pragma once #include #include "util/exception.h" +#include "kernel/environment.h" + +// Term Rewriting +// APP_RW +// LAMBDA_RW +// PI_RW +// LET_RW +// DEPTH_RW +// TRIVIAL_RW +// FORALL +// FAIL +// FAIL_IF namespace lean { @@ -15,7 +27,7 @@ class rewrite_exception : public exception { class rewrite { public: - virtual std::pair operator()(context & ctx, expr const & v, expr const & t) const = 0; + virtual std::pair operator()(context & ctx, expr const & v, environment const & env) const = 0; }; class theorem_rewrite : public rewrite { @@ -28,29 +40,66 @@ private: public: theorem_rewrite(expr const & type, expr const & body); - std::pair operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception); + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); }; class orelse_rewrite : public rewrite { private: - rewrite const & rewrite1; - rewrite const & rewrite2; + rewrite const & rw1; + rewrite const & rw2; public: - orelse_rewrite(rewrite const & rw1, rewrite const & rw2) : - rewrite1(rw1), rewrite2(rw2) { } - std::pair operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception); + orelse_rewrite(rewrite const & rw_1, rewrite const & rw_2) : + rw1(rw_1), rw2(rw_2) { } + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); }; class then_rewrite : public rewrite { private: - rewrite const & rewrite1; - rewrite const & rewrite2; + rewrite const & rw1; + rewrite const & rw2; public: - then_rewrite(rewrite const & rw1, rewrite const & rw2) : - rewrite1(rw1), rewrite2(rw2) { } - std::pair operator()(context & ctx, expr const & v, expr const & t) const throw(rewrite_exception); + then_rewrite(rewrite const & rw_1, rewrite const & rw_2) : + rw1(rw_1), rw2(rw_2) { } + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); }; +class app_rewrite : public rewrite { +private: + rewrite const & rw; +public: + app_rewrite(rewrite const & rw_) : + rw(rw_) { } + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); +}; + +class lambda_rewrite : public rewrite { +private: + rewrite const & rw; +public: + lambda_rewrite(rewrite const & rw_) : + rw(rw_) { } + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); +}; + +class pi_rewrite : public rewrite { +private: + rewrite const & rw; +public: + pi_rewrite(rewrite const & rw_) : + rw(rw_) { } + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); +}; + +class let_rewrite : public rewrite { +private: + rewrite const & rw; +public: + let_rewrite(rewrite const & rw_) : + rw(rw_) { } + std::pair operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception); +}; + + class fail_rewrite : public rewrite { public: std::pair operator()(context &, expr const &) const throw(rewrite_exception) { diff --git a/src/tests/library/rewrite/rewrite.cpp b/src/tests/library/rewrite/rewrite.cpp index 640e7f60a..53b1d0941 100644 --- a/src/tests/library/rewrite/rewrite.cpp +++ b/src/tests/library/rewrite/rewrite.cpp @@ -8,7 +8,6 @@ Author: Soonho Kong #include "kernel/abstract.h" #include "kernel/context.h" #include "kernel/expr.h" -#include "kernel/type_checker.h" #include "library/all/all.h" #include "library/arith/arith.h" #include "library/arith/nat.h" @@ -44,7 +43,7 @@ static void theorem_rewrite1_tst() { // Rewriting theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body); context ctx; - pair result = add_comm_thm_rewriter(ctx, a_plus_b, Nat); + pair result = add_comm_thm_rewriter(ctx, a_plus_b, env); expr concl = mk_eq(a_plus_b, result.first); expr proof = result.second; @@ -75,7 +74,7 @@ static void theorem_rewrite2_tst() { // Rewriting theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body); context ctx; - pair result = add_id_thm_rewriter(ctx, a_plus_zero, Nat); + pair result = add_id_thm_rewriter(ctx, a_plus_zero, env); expr concl = mk_eq(a_plus_zero, result.first); expr proof = result.second; @@ -116,7 +115,7 @@ static void then_rewrite1_tst() { theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body); then_rewrite then_rewriter1(add_comm_thm_rewriter, add_id_thm_rewriter); context ctx; - pair result = then_rewriter1(ctx, zero_plus_a, Nat); + pair result = then_rewriter1(ctx, zero_plus_a, env); expr concl = mk_eq(zero_plus_a, result.first); expr proof = result.second; @@ -175,7 +174,7 @@ static void then_rewrite2_tst() { then_rewrite then_rewriter2(then_rewrite(add_assoc_thm_rewriter, add_id_thm_rewriter), then_rewrite(add_comm_thm_rewriter, add_id_thm_rewriter)); context ctx; - pair result = then_rewriter2(ctx, zero_plus_a_plus_zero, Nat); + pair result = then_rewriter2(ctx, zero_plus_a_plus_zero, env); expr concl = mk_eq(zero_plus_a_plus_zero, result.first); expr proof = result.second; cout << "Theorem: " << add_assoc_thm_type << " := " << add_assoc_thm_body << endl; @@ -227,7 +226,7 @@ static void orelse_rewrite1_tst() { theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body); orelse_rewrite add_assoc_or_comm_thm_rewriter(add_assoc_thm_rewriter, add_comm_thm_rewriter); context ctx; - pair result = add_assoc_or_comm_thm_rewriter(ctx, a_plus_b, Nat); + pair result = add_assoc_or_comm_thm_rewriter(ctx, a_plus_b, env); expr concl = mk_eq(a_plus_b, result.first); expr proof = result.second;