feat(library/match): add 'local' backtracking
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
e6ffda0c51
commit
d1924097d5
2 changed files with 145 additions and 26 deletions
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
||||||
|
|
||||||
Author: Leonardo de Moura
|
Author: Leonardo de Moura
|
||||||
*/
|
*/
|
||||||
|
#include <utility>
|
||||||
#include "kernel/abstract.h"
|
#include "kernel/abstract.h"
|
||||||
#include "kernel/instantiate.h"
|
#include "kernel/instantiate.h"
|
||||||
#include "kernel/for_each_fn.h"
|
#include "kernel/for_each_fn.h"
|
||||||
|
@ -32,22 +33,57 @@ unsigned to_meta_idx(level const & l) {
|
||||||
}
|
}
|
||||||
|
|
||||||
class match_fn : public match_context {
|
class match_fn : public match_context {
|
||||||
buffer<optional<expr>> & m_esubst;
|
buffer<optional<expr>> & m_esubst;
|
||||||
buffer<optional<level>> & m_lsubst;
|
buffer<optional<level>> & m_lsubst;
|
||||||
name_generator m_ngen;
|
name_generator m_ngen;
|
||||||
name_map<name> * m_name_subst;
|
name_map<name> * m_name_subst;
|
||||||
match_plugin const * m_plugin;
|
match_plugin const * m_plugin;
|
||||||
|
buffer<std::pair<bool, unsigned>> m_stack;
|
||||||
|
buffer<unsigned> m_scopes;
|
||||||
|
|
||||||
|
void push() {
|
||||||
|
m_scopes.push_back(m_stack.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
void pop() {
|
||||||
|
lean_assert(!m_scopes.empty());
|
||||||
|
unsigned old_sz = m_scopes.back();
|
||||||
|
while (m_stack.size() > old_sz) {
|
||||||
|
auto p = m_stack.back();
|
||||||
|
if (p.first)
|
||||||
|
m_esubst[p.second] = none_expr();
|
||||||
|
else
|
||||||
|
m_lsubst[p.second] = none_level();
|
||||||
|
m_stack.pop_back();
|
||||||
|
}
|
||||||
|
m_scopes.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
void keep() {
|
||||||
|
m_scopes.back() = m_stack.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct scope {
|
||||||
|
match_fn & m;
|
||||||
|
scope(match_fn & _m):m(_m) { m.push(); }
|
||||||
|
~scope() { m.pop(); }
|
||||||
|
void keep() { m.keep(); }
|
||||||
|
};
|
||||||
|
|
||||||
void _assign(expr const & p, expr const & t) {
|
void _assign(expr const & p, expr const & t) {
|
||||||
lean_assert(var_idx(p) < m_esubst.size());
|
lean_assert(var_idx(p) < m_esubst.size());
|
||||||
unsigned vidx = var_idx(p);
|
unsigned vidx = var_idx(p);
|
||||||
unsigned sz = m_esubst.size();
|
unsigned sz = m_esubst.size();
|
||||||
m_esubst[sz - vidx - 1] = t;
|
unsigned i = sz - vidx - 1;
|
||||||
|
m_stack.emplace_back(true, i);
|
||||||
|
m_esubst[i] = t;
|
||||||
}
|
}
|
||||||
|
|
||||||
void _assign(level const & p, level const & l) {
|
void _assign(level const & p, level const & l) {
|
||||||
lean_assert(to_meta_idx(p) < m_lsubst.size());
|
lean_assert(to_meta_idx(p) < m_lsubst.size());
|
||||||
m_lsubst[to_meta_idx(p)] = l;
|
unsigned i = to_meta_idx(p);
|
||||||
|
m_stack.emplace_back(false, i);
|
||||||
|
m_lsubst[i] = l;
|
||||||
}
|
}
|
||||||
|
|
||||||
void throw_exception() const {
|
void throw_exception() const {
|
||||||
|
@ -97,7 +133,7 @@ class match_fn : public match_context {
|
||||||
return (*m_plugin)(p, t, *this);
|
return (*m_plugin)(p, t, *this);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match_binding(expr p, expr t) {
|
bool match_binding_core(expr p, expr t) {
|
||||||
lean_assert(is_binding(p) && is_binding(t));
|
lean_assert(is_binding(p) && is_binding(t));
|
||||||
buffer<expr> ls;
|
buffer<expr> ls;
|
||||||
expr_kind k = p.kind();
|
expr_kind k = p.kind();
|
||||||
|
@ -120,7 +156,18 @@ class match_fn : public match_context {
|
||||||
return _match(p, t);
|
return _match(p, t);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match_macro(expr const & p, expr const & t) {
|
bool match_binding(expr const & p, expr const & t) {
|
||||||
|
{
|
||||||
|
scope s(*this);
|
||||||
|
if (match_binding_core(p, t)) {
|
||||||
|
s.keep();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return try_plugin(p, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_macro_core(expr const & p, expr const & t) {
|
||||||
if (macro_def(p) == macro_def(t) && macro_num_args(p) == macro_num_args(t)) {
|
if (macro_def(p) == macro_def(t) && macro_num_args(p) == macro_num_args(t)) {
|
||||||
for (unsigned i = 0; i < macro_num_args(p); i++) {
|
for (unsigned i = 0; i < macro_num_args(p); i++) {
|
||||||
if (!_match(macro_arg(p, i), macro_arg(t, i)))
|
if (!_match(macro_arg(p, i), macro_arg(t, i)))
|
||||||
|
@ -131,10 +178,32 @@ class match_fn : public match_context {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match_app(expr const & p, expr const & t) {
|
bool match_macro(expr const & p, expr const & t) {
|
||||||
|
{
|
||||||
|
scope s(*this);
|
||||||
|
if (match_macro_core(p, t)) {
|
||||||
|
s.keep();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return try_plugin(p, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_app_core(expr const & p, expr const & t) {
|
||||||
return match_core(app_fn(p), app_fn(t)) && _match(app_arg(p), app_arg(t));
|
return match_core(app_fn(p), app_fn(t)) && _match(app_arg(p), app_arg(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool match_app(expr const & p, expr const & t) {
|
||||||
|
{
|
||||||
|
scope s(*this);
|
||||||
|
if (match_app_core(p, t)) {
|
||||||
|
s.keep();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return try_plugin(p, t);
|
||||||
|
}
|
||||||
|
|
||||||
bool match_level_core(level const & p, level const & l) {
|
bool match_level_core(level const & p, level const & l) {
|
||||||
if (p == l)
|
if (p == l)
|
||||||
return true;
|
return true;
|
||||||
|
@ -146,17 +215,29 @@ class match_fn : public match_context {
|
||||||
return false;
|
return false;
|
||||||
case level_kind::Succ:
|
case level_kind::Succ:
|
||||||
return match_level(succ_of(p), succ_of(l));
|
return match_level(succ_of(p), succ_of(l));
|
||||||
case level_kind::Max:
|
case level_kind::Max: {
|
||||||
return
|
scope s(*this);
|
||||||
match_level(max_lhs(p), max_lhs(l)) &&
|
if (match_level(max_lhs(p), max_lhs(l)) && match_level(max_rhs(p), max_rhs(l))) {
|
||||||
match_level(max_rhs(p), max_rhs(l));
|
s.keep();
|
||||||
case level_kind::IMax:
|
return true;
|
||||||
return
|
}
|
||||||
match_level(imax_lhs(p), imax_lhs(l)) &&
|
break;
|
||||||
match_level(imax_rhs(p), imax_rhs(l));
|
|
||||||
}
|
}
|
||||||
|
case level_kind::IMax: {
|
||||||
|
scope s(*this);
|
||||||
|
if (match_level(imax_lhs(p), imax_lhs(l)) && match_level(imax_rhs(p), imax_rhs(l))) {
|
||||||
|
s.keep();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}}
|
||||||
}
|
}
|
||||||
return false;
|
level p1 = normalize(p);
|
||||||
|
level l1 = normalize(l);
|
||||||
|
if (p1 != p || l1 != l)
|
||||||
|
return match_level(p1, l1);
|
||||||
|
else
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match_level(level const & p, level const & l) {
|
bool match_level(level const & p, level const & l) {
|
||||||
|
@ -182,6 +263,13 @@ class match_fn : public match_context {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool match_constant(expr const & p, expr const & t) {
|
||||||
|
if (const_name(p) == const_name(t))
|
||||||
|
return match_levels(const_levels(p), const_levels(t));
|
||||||
|
else
|
||||||
|
return try_plugin(p, t);
|
||||||
|
}
|
||||||
|
|
||||||
bool match_core(expr const & p, expr const & t) {
|
bool match_core(expr const & p, expr const & t) {
|
||||||
if (p.kind() != t.kind())
|
if (p.kind() != t.kind())
|
||||||
return try_plugin(p, t);
|
return try_plugin(p, t);
|
||||||
|
@ -191,18 +279,15 @@ class match_fn : public match_context {
|
||||||
case expr_kind::Var:
|
case expr_kind::Var:
|
||||||
lean_unreachable(); // LCOV_EXCL_LINE
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
case expr_kind::Constant:
|
case expr_kind::Constant:
|
||||||
if (const_name(p) == const_name(t))
|
return match_constant(p, t);
|
||||||
return match_levels(const_levels(p), const_levels(t));
|
|
||||||
else
|
|
||||||
return try_plugin(p, t);
|
|
||||||
case expr_kind::Sort:
|
case expr_kind::Sort:
|
||||||
return match_level(sort_level(p), sort_level(t));
|
return match_level(sort_level(p), sort_level(t));
|
||||||
case expr_kind::Lambda: case expr_kind::Pi:
|
case expr_kind::Lambda: case expr_kind::Pi:
|
||||||
return match_binding(p, t) || try_plugin(p, t);
|
return match_binding(p, t);
|
||||||
case expr_kind::Macro:
|
case expr_kind::Macro:
|
||||||
return match_macro(p, t) || try_plugin(p, t);
|
return match_macro(p, t);
|
||||||
case expr_kind::App:
|
case expr_kind::App:
|
||||||
return match_app(p, t) || try_plugin(p, t);
|
return match_app(p, t);
|
||||||
}
|
}
|
||||||
lean_unreachable(); // LCOV_EXCL_LINE
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
}
|
}
|
||||||
|
|
34
tests/lean/run/match2.lean
Normal file
34
tests/lean/run/match2.lean
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
import data.nat
|
||||||
|
using nat
|
||||||
|
|
||||||
|
definition two1 : nat := 2
|
||||||
|
definition two2 : nat := succ (succ (zero))
|
||||||
|
definition f (x : nat) (y : nat) := y
|
||||||
|
variable g : nat → nat → nat
|
||||||
|
variables a b : nat
|
||||||
|
|
||||||
|
(*
|
||||||
|
local tc = type_checker_with_hints(get_env(), true)
|
||||||
|
local plugin = whnf_match_plugin(tc)
|
||||||
|
function tst_match(p, t)
|
||||||
|
local r1, r2 = match(p, t, plugin)
|
||||||
|
assert(r1)
|
||||||
|
print("--------------")
|
||||||
|
for i = 1, #r1 do
|
||||||
|
print(" expr:#" .. i .. " := " .. tostring(r1[i]))
|
||||||
|
end
|
||||||
|
for i = 1, #r2 do
|
||||||
|
print(" lvl:#" .. i .. " := " .. tostring(r2[i]))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local f = Const("f")
|
||||||
|
local g = Const("g")
|
||||||
|
local a = Const("a")
|
||||||
|
local b = Const("b")
|
||||||
|
local x = mk_var(0)
|
||||||
|
local p = g(x, f(x, a))
|
||||||
|
local t = g(a, f(b, a))
|
||||||
|
tst_match(p, t)
|
||||||
|
tst_match(f(x, x), f(a, b))
|
||||||
|
*)
|
Loading…
Reference in a new issue