diff --git a/src/library/match.cpp b/src/library/match.cpp index 9f623237f..1136e1292 100644 --- a/src/library/match.cpp +++ b/src/library/match.cpp @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include "kernel/abstract.h" #include "kernel/instantiate.h" #include "kernel/for_each_fn.h" @@ -32,22 +33,57 @@ unsigned to_meta_idx(level const & l) { } class match_fn : public match_context { - buffer> & m_esubst; - buffer> & m_lsubst; - name_generator m_ngen; - name_map * m_name_subst; - match_plugin const * m_plugin; + buffer> & m_esubst; + buffer> & m_lsubst; + name_generator m_ngen; + name_map * m_name_subst; + match_plugin const * m_plugin; + buffer> m_stack; + buffer m_scopes; + + void push() { + m_scopes.push_back(m_stack.size()); + } + + void pop() { + lean_assert(!m_scopes.empty()); + unsigned old_sz = m_scopes.back(); + while (m_stack.size() > old_sz) { + auto p = m_stack.back(); + if (p.first) + m_esubst[p.second] = none_expr(); + else + m_lsubst[p.second] = none_level(); + m_stack.pop_back(); + } + m_scopes.pop_back(); + } + + void keep() { + m_scopes.back() = m_stack.size(); + } + + struct scope { + match_fn & m; + scope(match_fn & _m):m(_m) { m.push(); } + ~scope() { m.pop(); } + void keep() { m.keep(); } + }; void _assign(expr const & p, expr const & t) { lean_assert(var_idx(p) < m_esubst.size()); unsigned vidx = var_idx(p); unsigned sz = m_esubst.size(); - m_esubst[sz - vidx - 1] = t; + unsigned i = sz - vidx - 1; + m_stack.emplace_back(true, i); + m_esubst[i] = t; } void _assign(level const & p, level const & l) { lean_assert(to_meta_idx(p) < m_lsubst.size()); - m_lsubst[to_meta_idx(p)] = l; + unsigned i = to_meta_idx(p); + m_stack.emplace_back(false, i); + m_lsubst[i] = l; } void throw_exception() const { @@ -97,7 +133,7 @@ class match_fn : public match_context { return (*m_plugin)(p, t, *this); } - bool match_binding(expr p, expr t) { + bool match_binding_core(expr p, expr t) { lean_assert(is_binding(p) && is_binding(t)); buffer ls; expr_kind k = p.kind(); @@ -120,7 +156,18 @@ class match_fn : public match_context { return _match(p, t); } - bool match_macro(expr const & p, expr const & t) { + bool match_binding(expr const & p, expr const & t) { + { + scope s(*this); + if (match_binding_core(p, t)) { + s.keep(); + return true; + } + } + return try_plugin(p, t); + } + + bool match_macro_core(expr const & p, expr const & t) { if (macro_def(p) == macro_def(t) && macro_num_args(p) == macro_num_args(t)) { for (unsigned i = 0; i < macro_num_args(p); i++) { if (!_match(macro_arg(p, i), macro_arg(t, i))) @@ -131,10 +178,32 @@ class match_fn : public match_context { return false; } - bool match_app(expr const & p, expr const & t) { + bool match_macro(expr const & p, expr const & t) { + { + scope s(*this); + if (match_macro_core(p, t)) { + s.keep(); + return true; + } + } + return try_plugin(p, t); + } + + bool match_app_core(expr const & p, expr const & t) { return match_core(app_fn(p), app_fn(t)) && _match(app_arg(p), app_arg(t)); } + bool match_app(expr const & p, expr const & t) { + { + scope s(*this); + if (match_app_core(p, t)) { + s.keep(); + return true; + } + } + return try_plugin(p, t); + } + bool match_level_core(level const & p, level const & l) { if (p == l) return true; @@ -146,17 +215,29 @@ class match_fn : public match_context { return false; case level_kind::Succ: return match_level(succ_of(p), succ_of(l)); - case level_kind::Max: - return - match_level(max_lhs(p), max_lhs(l)) && - match_level(max_rhs(p), max_rhs(l)); - case level_kind::IMax: - return - match_level(imax_lhs(p), imax_lhs(l)) && - match_level(imax_rhs(p), imax_rhs(l)); + case level_kind::Max: { + scope s(*this); + if (match_level(max_lhs(p), max_lhs(l)) && match_level(max_rhs(p), max_rhs(l))) { + s.keep(); + return true; + } + break; } + case level_kind::IMax: { + scope s(*this); + if (match_level(imax_lhs(p), imax_lhs(l)) && match_level(imax_rhs(p), imax_rhs(l))) { + s.keep(); + return true; + } + break; + }} } - return false; + level p1 = normalize(p); + level l1 = normalize(l); + if (p1 != p || l1 != l) + return match_level(p1, l1); + else + return false; } bool match_level(level const & p, level const & l) { @@ -182,6 +263,13 @@ class match_fn : public match_context { return true; } + bool match_constant(expr const & p, expr const & t) { + if (const_name(p) == const_name(t)) + return match_levels(const_levels(p), const_levels(t)); + else + return try_plugin(p, t); + } + bool match_core(expr const & p, expr const & t) { if (p.kind() != t.kind()) return try_plugin(p, t); @@ -191,18 +279,15 @@ class match_fn : public match_context { case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Constant: - if (const_name(p) == const_name(t)) - return match_levels(const_levels(p), const_levels(t)); - else - return try_plugin(p, t); + return match_constant(p, t); case expr_kind::Sort: return match_level(sort_level(p), sort_level(t)); case expr_kind::Lambda: case expr_kind::Pi: - return match_binding(p, t) || try_plugin(p, t); + return match_binding(p, t); case expr_kind::Macro: - return match_macro(p, t) || try_plugin(p, t); + return match_macro(p, t); case expr_kind::App: - return match_app(p, t) || try_plugin(p, t); + return match_app(p, t); } lean_unreachable(); // LCOV_EXCL_LINE } diff --git a/tests/lean/run/match2.lean b/tests/lean/run/match2.lean new file mode 100644 index 000000000..de86bb114 --- /dev/null +++ b/tests/lean/run/match2.lean @@ -0,0 +1,34 @@ +import data.nat +using nat + +definition two1 : nat := 2 +definition two2 : nat := succ (succ (zero)) +definition f (x : nat) (y : nat) := y +variable g : nat → nat → nat +variables a b : nat + +(* +local tc = type_checker_with_hints(get_env(), true) +local plugin = whnf_match_plugin(tc) +function tst_match(p, t) + local r1, r2 = match(p, t, plugin) + assert(r1) + print("--------------") + for i = 1, #r1 do + print(" expr:#" .. i .. " := " .. tostring(r1[i])) + end + for i = 1, #r2 do + print(" lvl:#" .. i .. " := " .. tostring(r2[i])) + end +end + +local f = Const("f") +local g = Const("g") +local a = Const("a") +local b = Const("b") +local x = mk_var(0) +local p = g(x, f(x, a)) +local t = g(a, f(b, a)) +tst_match(p, t) +tst_match(f(x, x), f(a, b)) +*)