diff --git a/src/library/rewriter/fo_match.cpp b/src/library/rewriter/fo_match.cpp index 490f067f6..713c467d8 100644 --- a/src/library/rewriter/fo_match.cpp +++ b/src/library/rewriter/fo_match.cpp @@ -7,7 +7,6 @@ Author: Soonho Kong #include #include "util/trace.h" #include "kernel/expr.h" -#include "kernel/context.h" #include "library/all/all.h" #include "library/arith/nat.h" #include "library/arith/arith.h" @@ -20,52 +19,43 @@ using std::endl; namespace lean { -std::ostream & operator<<(std::ostream & out, subst_map & s) { - out << "{"; - for (auto it = s.begin(); it != s.end(); it++) { - out << it->first << " => "; - out << it->second << "; "; - } - out << "}"; - return out; -} - bool fo_match::match_var(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match_var : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + lean_trace("fo_match", tout << "match_var : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE + unsigned idx = var_idx(p); if (idx < o) { // Current variable is the one created by lambda inside of pattern // and it is *not* a target of pattern matching. return p == t; } else { - auto it = s.find(idx); + auto it = s.find(idx - o); if (it != s.end()) { // This variable already has an entry in the substitution // map. We need to make sure that 't' and s[idx] are the // same - lean_trace("fo_match", tout << "match_var exist:" << idx << " |-> " << it->second << endl;); + lean_trace("fo_match", tout << "match_var exist:" << idx - o << " |-> " << it->second << endl;); // LCOV_EXCL_LINE return it->second == t; } // This variable has no entry in the substituition map. Let's // add one. - s.insert(std::make_pair(idx, t)); - lean_trace("fo_match", tout << "match_var MATCHED : " << s << endl;); + s.insert(idx - o, t); + lean_trace("fo_match", tout << "match_var MATCHED : " << s << endl;); // LCOV_EXCL_LINE return true; } } bool fo_match::match_constant(expr const & p, expr const & t, unsigned, subst_map &) { - lean_trace("fo_match", tout << "match_constant : (" << p << ", " << t << ")" << endl;); + lean_trace("fo_match", tout << "match_constant : (" << p << ", " << t << ")" << endl;); // LCOV_EXCL_LINE return p == t; } bool fo_match::match_value(expr const & p, expr const & t, unsigned, subst_map &) { - lean_trace("fo_match", tout << "match_value : (" << p << ", " << t << ")" << endl;); + lean_trace("fo_match", tout << "match_value : (" << p << ", " << t << ")" << endl;); // LCOV_EXCL_LINE return p == t; } bool fo_match::match_app(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match_app : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + lean_trace("fo_match", tout << "match_app : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE if (!is_app(t)) return false; unsigned num_p = num_args(p); @@ -75,92 +65,92 @@ bool fo_match::match_app(expr const & p, expr const & t, unsigned o, subst_map & } for (unsigned i = 0; i < num_p; i++) { - if (!match(arg(p, i), arg(t, i), o, s)) + if (!match_main(arg(p, i), arg(t, i), o, s)) return false; } return true; } bool fo_match::match_lambda(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match_lambda : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); - lean_trace("fo_match", tout << "fun (" << abst_name(p) << " : " << abst_domain(p) << "), " << abst_body(p) << endl;); + lean_trace("fo_match", tout << "match_lambda : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE + lean_trace("fo_match", tout << "fun (" << abst_name(p) << " : " << abst_domain(p) << "), " << abst_body(p) << endl;); // LCOV_EXCL_LINE if (!is_lambda(t)) { return false; } else { // First match the domain part auto p_domain = abst_domain(p); auto t_domain = abst_domain(t); - if (!match(p_domain, t_domain, o, s)) + if (!match_main(p_domain, t_domain, o, s)) return false; // Then match the body part, increase offset by 1. auto p_body = abst_body(p); auto t_body = abst_body(t); - return match(p_domain, t_domain, o + 1, s); + return match_main(p_body, t_body, o + 1, s); } } bool fo_match::match_pi(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match_pi : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); - lean_trace("fo_match", tout << "Pi (" << abst_name(p) << " : " << abst_domain(p) << "), " << abst_body(p) << endl;); + lean_trace("fo_match", tout << "match_pi : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE + lean_trace("fo_match", tout << "Pi (" << abst_name(p) << " : " << abst_domain(p) << "), " << abst_body(p) << endl;); // LCOV_EXCL_LINE if (!is_pi(t)) { return false; } else { // First match the domain part auto p_domain = abst_domain(p); auto t_domain = abst_domain(t); - if (!match(p_domain, t_domain, o, s)) + if (!match_main(p_domain, t_domain, o, s)) return false; // Then match the body part, increase offset by 1. auto p_body = abst_body(p); auto t_body = abst_body(t); - return match(p_domain, t_domain, o + 1, s); + return match_main(p_body, t_body, o + 1, s); } } bool fo_match::match_type(expr const & p, expr const & t, unsigned, subst_map &) { - lean_trace("fo_match", tout << "match_type : (" << p << ", " << t << ")" << endl;); + lean_trace("fo_match", tout << "match_type : (" << p << ", " << t << ")" << endl;); // LCOV_EXCL_LINE return p == t; } bool fo_match::match_eq(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match_eq : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + lean_trace("fo_match", tout << "match_eq : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE if (!is_eq(t)) return false; - return match(eq_lhs(p), eq_lhs(t), o, s) && match(eq_rhs(p), eq_rhs(t), o, s); + return match_main(eq_lhs(p), eq_lhs(t), o, s) && match_main(eq_rhs(p), eq_rhs(t), o, s); } bool fo_match::match_let(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match_let : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); + lean_trace("fo_match", tout << "match_let : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE if (!is_let(t)) { return false; } else { // First match the type part auto p_type = let_type(p); auto t_type = let_type(t); - if (!match(p_type, t_type, o, s)) + if (!match_main(p_type, t_type, o, s)) return false; // then match the value part auto p_value = let_value(p); auto t_value = let_value(t); - if (!match(p_value, t_value, o, s)) + if (!match_main(p_value, t_value, o, s)) return false; // then match the value part auto p_body = let_body(p); auto t_body = let_body(t); - return match(p_body, t_body, o + 1, s); + return match_main(p_body, t_body, o + 1, s); } } bool fo_match::match_metavar(expr const & p, expr const & t, unsigned, subst_map &) { - lean_trace("fo_match", tout << "match_meta : (" << p << ", " << t << ")" << endl;); + lean_trace("fo_match", tout << "match_meta : (" << p << ", " << t << ")" << endl;); // LCOV_EXCL_LINE return p == t; } -bool fo_match::match(expr const & p, expr const & t, unsigned o, subst_map & s) { - lean_trace("fo_match", tout << "match : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); +bool fo_match::match_main(expr const & p, expr const & t, unsigned o, subst_map & s) { + lean_trace("fo_match", tout << "match : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE switch (p.kind()) { case expr_kind::Var: return match_var(p, t, o, s); @@ -183,6 +173,16 @@ bool fo_match::match(expr const & p, expr const & t, unsigned o, subst_map & s) case expr_kind::MetaVar: return match_metavar(p, t, o, s); } - lean_unreachable(); + lean_unreachable(); // LCOV_EXCL_LINE +} + +bool fo_match::match(expr const & p, expr const & t, unsigned o, subst_map & s) { + s.push(); + if (match_main(p, t, o, s)) { + return true; + } else { + s.pop(); + return false; + } } } diff --git a/src/library/rewriter/fo_match.h b/src/library/rewriter/fo_match.h index 404b61285..c1bf5c345 100644 --- a/src/library/rewriter/fo_match.h +++ b/src/library/rewriter/fo_match.h @@ -5,19 +5,13 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Soonho Kong */ #pragma once -#include +#include "library/printer.h" +#include "util/scoped_map.h" #include "kernel/expr.h" -#include "kernel/expr_maps.h" #include "kernel/context.h" -#include "util/list.h" -#include "library/all/all.h" -#include "library/expr_pair.h" -#include "library/arith/nat.h" -#include "library/arith/arith.h" namespace lean { - -using subst_map = std::unordered_map; +using subst_map = scoped_map; class fo_match { private: @@ -31,6 +25,7 @@ private: bool match_eq(expr const & p, expr const & t, unsigned o, subst_map & s); bool match_let(expr const & p, expr const & t, unsigned o, subst_map & s); bool match_metavar(expr const & p, expr const & t, unsigned o, subst_map & s); + bool match_main(expr const & p, expr const & t, unsigned o, subst_map & s); public: bool match(expr const & p, expr const & t, unsigned o, subst_map & s);