From ae0508128f6cf03e853b28536dee05c91cd0ea8e Mon Sep 17 00:00:00 2001 From: Soonho Kong Date: Sun, 1 Dec 2013 00:57:09 -0500 Subject: [PATCH] refactor(library/rewriter): move apply_rewriter_fn into rewriter.h --- src/library/rewriter/apply_rewriter_fn.h | 231 ----------------------- src/library/rewriter/rewriter.h | 222 +++++++++++++++++++++- src/tests/library/rewriter/rewriter.cpp | 1 - 3 files changed, 219 insertions(+), 235 deletions(-) delete mode 100644 src/library/rewriter/apply_rewriter_fn.h diff --git a/src/library/rewriter/apply_rewriter_fn.h b/src/library/rewriter/apply_rewriter_fn.h deleted file mode 100644 index 109f81143..000000000 --- a/src/library/rewriter/apply_rewriter_fn.h +++ /dev/null @@ -1,231 +0,0 @@ -/* -Copyright (c) 2013 Microsoft Corporation. -Copyright (c) 2013 Carnegie Mellon University. -All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. - -Author: Leonardo de Moura - Soonho Kong -*/ -#pragma once -#include -#include "kernel/abstract.h" -#include "kernel/context.h" -#include "kernel/environment.h" -#include "kernel/expr.h" -#include "kernel/replace.h" -#include "library/basic_thms.h" -#include "library/rewriter/rewriter.h" -#include "library/type_inferer.h" -#include "util/scoped_map.h" - -namespace lean { -/** - \brief Functional for applying F to the subexpressions of a given expression. - - The signature of \c F is - expr const &, context const & ctx, unsigned n -> expr - - F is invoked for each subexpression \c s of the input expression e. - In a call F(s, c, n), \c c is the context where \c s occurs, - and \c n is the size of \c c. - - P is a "post-processing" functional object that is applied to each - pair (old, new) -*/ -template -class apply_rewriter_fn { - // the return type of RW()(env, ctx, e) should be std::pair - static_assert(std::is_same())(environment const &, context &, expr const &)>::type, - std::pair>::value, - "apply_rewriter_fn: the return type of RW()(env, ctx, e) should be std::pair"); - // the return type of P()(e1, e2) should be void - static_assert(std::is_same())(expr const &, expr const &)>::type, - void>::value, - "apply_rewriter_fn: the return type of P()(e1, e2) is not void"); - - typedef scoped_map, expr_hash, expr_eqp> cache; - cache m_cache; - RW m_rw; - P m_post; - - std::pair apply(environment const & env, context & ctx, expr const & v) { - bool shared = false; - if (is_shared(v)) { - shared = true; - auto it = m_cache.find(v); - if (it != m_cache.end()) - return it->second; - } - - std::pair result; // m_rw(env, ctx, v); - // if (is_eqp(v, result.first)) - type_inferer lc(env); - expr ty_v = lc(v, ctx); - - switch (v.kind()) { - case expr_kind::Type: - result = m_rw(env, ctx, v); - break; - case expr_kind::Value: - result = m_rw(env, ctx, v); - break; - case expr_kind::Constant: - result = m_rw(env, ctx, v); - break; - case expr_kind::Var: - result = m_rw(env, ctx, v); - break; - case expr_kind::MetaVar: - result = m_rw(env, ctx, v); - break; - case expr_kind::App: { - buffer> results; - for (unsigned i = 0; i < num_args(v); i++) { - results.push_back(apply(env, ctx, arg(v, i))); - } - result = rewrite_app(env, ctx, v, results); - } - break; - case expr_kind::Eq: { - expr const & lhs = eq_lhs(v); - expr const & rhs = eq_rhs(v); - std::pair result_lhs = apply(env, ctx, lhs); - std::pair result_rhs = apply(env, ctx, rhs); - expr const & new_lhs = result_lhs.first; - expr const & new_rhs = result_rhs.first; - if (lhs != new_lhs) { - if (rhs != new_rhs) { - // lhs & rhs changed - result = rewrite_eq(env, ctx, v, result_lhs, result_rhs); - } else { - // only lhs changed - result = rewrite_eq_lhs(env, ctx, v, result_lhs); - } - } else { - if (rhs != new_rhs) { - // only rhs changed - result = rewrite_eq_rhs(env, ctx, v, result_rhs); - } else { - // nothing changed - result = std::make_pair(v, Refl(lc(v, ctx), v)); - } - } - } - break; - case expr_kind::Lambda: { - name const & n = abst_name(v); - expr const & ty = abst_domain(v); - expr const & body = abst_body(v); - context new_ctx = extend(ctx, n, ty); - std::pair result_ty = apply(env, ctx, ty); - std::pair result_body = apply(env, new_ctx, body); - if (ty != result_ty.first) { - if (body != result_body.first) { - // ty and body changed - result = rewrite_lambda(env, ctx, v, result_ty, result_body); - } else { - // ty changed - result = rewrite_lambda_type(env, ctx, v, result_ty); - } - } else { - if (body != result_body.first) { - // body changed - result = rewrite_lambda_body(env, ctx, v, result_body); - } else { - // nothing changed - result = std::make_pair(v, Refl(lc(v, ctx), v)); - } - } - } - break; - - case expr_kind::Pi: { - name const & n = abst_name(v); - expr const & ty = abst_domain(v); - expr const & body = abst_body(v); - context new_ctx = extend(ctx, n, ty); - std::pair result_ty = apply(env, ctx, ty); - std::pair result_body = apply(env, new_ctx, body); - if (ty != result_ty.first) { - if (body != result_body.first) { - // ty and body changed - result = rewrite_pi(env, ctx, v, result_ty, result_body); - } else { - // ty changed - result = rewrite_pi_type(env, ctx, v, result_ty); - } - } else { - if (body != result_body.first) { - // body changed - result = rewrite_pi_body(env, ctx, v, result_body); - } else { - // nothing changed - result = std::make_pair(v, Refl(lc(v, ctx), v)); - } - } - } - break; - case expr_kind::Let: { - name const & n = let_name(v); - expr const & ty = let_type(v); - expr const & val = let_value(v); - expr const & body = let_body(v); - - expr new_v = v; - expr ty_v = lc(v, ctx); - expr pf = Refl(ty_v, v); - bool changed = false; - - std::pair result_ty = apply(env, ctx, ty); - if (ty != result_ty.first) { - // ty changed - result = rewrite_let_type(env, ctx, new_v, result_ty); - new_v = result.first; - pf = result.second; - changed = true; - } - - std::pair result_val = apply(env, ctx, val); - if (val != result_val.first) { - result = rewrite_let_value(env, ctx, new_v, result_val); - if (changed) { - pf = Trans(ty_v, v, new_v, result.first, pf, result.second); - } else { - pf = result.second; - } - new_v = result.first; - changed = true; - } - - context new_ctx = extend(ctx, n, ty); - std::pair result_body = apply(env, new_ctx, body); - if (body != result_body.first) { - result = rewrite_let_body(env, ctx, new_v, result_body); - if (changed) { - pf = Trans(ty_v, v, new_v, result.first, pf, result.second); - } else { - pf = result.second; - } - new_v = result.first; - changed = true; - } - result = std::make_pair(new_v, pf); - } - break; - } - if (shared) - m_cache.insert(std::make_pair(v, result)); - return result; - } - -public: - apply_rewriter_fn(RW const & rw, P const & p = P()): - m_rw(rw), - m_post(p) { - } - std::pair operator()(environment const & env, context & ctx, expr const & v) { - return apply(env, ctx, v); - } -}; -} diff --git a/src/library/rewriter/rewriter.h b/src/library/rewriter/rewriter.h index f6d56caab..a79a9cd94 100644 --- a/src/library/rewriter/rewriter.h +++ b/src/library/rewriter/rewriter.h @@ -7,11 +7,18 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Soonho Kong */ #pragma once -#include #include -#include "util/exception.h" +#include +#include "kernel/abstract.h" +#include "kernel/context.h" #include "kernel/environment.h" - +#include "kernel/expr.h" +#include "kernel/replace.h" +#include "library/basic_thms.h" +#include "library/rewriter/rewriter.h" +#include "library/type_inferer.h" +#include "util/exception.h" +#include "util/scoped_map.h" // TODO(soonhok) // FORALL // FAIL_IF @@ -308,4 +315,213 @@ rewriter mk_success_rewriter(); rewriter mk_repeat_rewriter(rewriter const & rw); rewriter mk_depth_rewriter(rewriter const & rw); +/** + \brief Functional for applying F to the subexpressions of a given expression. + + The signature of \c F is + expr const &, context const & ctx, unsigned n -> expr + + F is invoked for each subexpression \c s of the input expression e. + In a call F(s, c, n), \c c is the context where \c s occurs, + and \c n is the size of \c c. + + P is a "post-processing" functional object that is applied to each + pair (old, new) +*/ +template +class apply_rewriter_fn { + // the return type of RW()(env, ctx, e) should be std::pair + static_assert(std::is_same())(environment const &, context &, expr const &)>::type, + std::pair>::value, + "apply_rewriter_fn: the return type of RW()(env, ctx, e) should be std::pair"); + // the return type of P()(e1, e2) should be void + static_assert(std::is_same())(expr const &, expr const &)>::type, + void>::value, + "apply_rewriter_fn: the return type of P()(e1, e2) is not void"); + + typedef scoped_map, expr_hash, expr_eqp> cache; + cache m_cache; + RW m_rw; + P m_post; + + std::pair apply(environment const & env, context & ctx, expr const & v) { + bool shared = false; + if (is_shared(v)) { + shared = true; + auto it = m_cache.find(v); + if (it != m_cache.end()) + return it->second; + } + + std::pair result; // m_rw(env, ctx, v); + // if (is_eqp(v, result.first)) + type_inferer lc(env); + expr ty_v = lc(v, ctx); + + switch (v.kind()) { + case expr_kind::Type: + result = m_rw(env, ctx, v); + break; + case expr_kind::Value: + result = m_rw(env, ctx, v); + break; + case expr_kind::Constant: + result = m_rw(env, ctx, v); + break; + case expr_kind::Var: + result = m_rw(env, ctx, v); + break; + case expr_kind::MetaVar: + result = m_rw(env, ctx, v); + break; + case expr_kind::App: { + buffer> results; + for (unsigned i = 0; i < num_args(v); i++) { + results.push_back(apply(env, ctx, arg(v, i))); + } + result = rewrite_app(env, ctx, v, results); + } + break; + case expr_kind::Eq: { + expr const & lhs = eq_lhs(v); + expr const & rhs = eq_rhs(v); + std::pair result_lhs = apply(env, ctx, lhs); + std::pair result_rhs = apply(env, ctx, rhs); + expr const & new_lhs = result_lhs.first; + expr const & new_rhs = result_rhs.first; + if (lhs != new_lhs) { + if (rhs != new_rhs) { + // lhs & rhs changed + result = rewrite_eq(env, ctx, v, result_lhs, result_rhs); + } else { + // only lhs changed + result = rewrite_eq_lhs(env, ctx, v, result_lhs); + } + } else { + if (rhs != new_rhs) { + // only rhs changed + result = rewrite_eq_rhs(env, ctx, v, result_rhs); + } else { + // nothing changed + result = std::make_pair(v, Refl(lc(v, ctx), v)); + } + } + } + break; + case expr_kind::Lambda: { + name const & n = abst_name(v); + expr const & ty = abst_domain(v); + expr const & body = abst_body(v); + context new_ctx = extend(ctx, n, ty); + std::pair result_ty = apply(env, ctx, ty); + std::pair result_body = apply(env, new_ctx, body); + if (ty != result_ty.first) { + if (body != result_body.first) { + // ty and body changed + result = rewrite_lambda(env, ctx, v, result_ty, result_body); + } else { + // ty changed + result = rewrite_lambda_type(env, ctx, v, result_ty); + } + } else { + if (body != result_body.first) { + // body changed + result = rewrite_lambda_body(env, ctx, v, result_body); + } else { + // nothing changed + result = std::make_pair(v, Refl(lc(v, ctx), v)); + } + } + } + break; + + case expr_kind::Pi: { + name const & n = abst_name(v); + expr const & ty = abst_domain(v); + expr const & body = abst_body(v); + context new_ctx = extend(ctx, n, ty); + std::pair result_ty = apply(env, ctx, ty); + std::pair result_body = apply(env, new_ctx, body); + if (ty != result_ty.first) { + if (body != result_body.first) { + // ty and body changed + result = rewrite_pi(env, ctx, v, result_ty, result_body); + } else { + // ty changed + result = rewrite_pi_type(env, ctx, v, result_ty); + } + } else { + if (body != result_body.first) { + // body changed + result = rewrite_pi_body(env, ctx, v, result_body); + } else { + // nothing changed + result = std::make_pair(v, Refl(lc(v, ctx), v)); + } + } + } + break; + case expr_kind::Let: { + name const & n = let_name(v); + expr const & ty = let_type(v); + expr const & val = let_value(v); + expr const & body = let_body(v); + + expr new_v = v; + expr ty_v = lc(v, ctx); + expr pf = Refl(ty_v, v); + bool changed = false; + + std::pair result_ty = apply(env, ctx, ty); + if (ty != result_ty.first) { + // ty changed + result = rewrite_let_type(env, ctx, new_v, result_ty); + new_v = result.first; + pf = result.second; + changed = true; + } + + std::pair result_val = apply(env, ctx, val); + if (val != result_val.first) { + result = rewrite_let_value(env, ctx, new_v, result_val); + if (changed) { + pf = Trans(ty_v, v, new_v, result.first, pf, result.second); + } else { + pf = result.second; + } + new_v = result.first; + changed = true; + } + + context new_ctx = extend(ctx, n, ty); + std::pair result_body = apply(env, new_ctx, body); + if (body != result_body.first) { + result = rewrite_let_body(env, ctx, new_v, result_body); + if (changed) { + pf = Trans(ty_v, v, new_v, result.first, pf, result.second); + } else { + pf = result.second; + } + new_v = result.first; + changed = true; + } + result = std::make_pair(new_v, pf); + } + break; + } + if (shared) + m_cache.insert(std::make_pair(v, result)); + return result; + } + +public: + apply_rewriter_fn(RW const & rw, P const & p = P()): + m_rw(rw), + m_post(p) { + } + std::pair operator()(environment const & env, context & ctx, expr const & v) { + return apply(env, ctx, v); + } +}; + } diff --git a/src/tests/library/rewriter/rewriter.cpp b/src/tests/library/rewriter/rewriter.cpp index 394edd646..93f03e4fe 100644 --- a/src/tests/library/rewriter/rewriter.cpp +++ b/src/tests/library/rewriter/rewriter.cpp @@ -15,7 +15,6 @@ Author: Soonho Kong #include "library/arith/nat.h" #include "library/rewriter/fo_match.h" #include "library/rewriter/rewriter.h" -#include "library/rewriter/apply_rewriter_fn.h" #include "library/basic_thms.h" using namespace lean;