feat(library/hop_match): optionally unfold constants when performing higher order matching
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
39c3b17eb7
commit
11719713ec
3 changed files with 73 additions and 22 deletions
|
@ -11,12 +11,14 @@ Author: Leonardo de Moura
|
|||
#include "kernel/kernel.h"
|
||||
#include "library/equality.h"
|
||||
#include "library/kernel_bindings.h"
|
||||
#include "library/hop_match.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
class hop_match_fn {
|
||||
buffer<optional<expr>> & m_subst;
|
||||
buffer<expr> m_vars;
|
||||
optional<ro_environment> m_env;
|
||||
|
||||
bool is_free_var(expr const & x, unsigned ctx_size) const {
|
||||
return is_var(x) && var_idx(x) >= ctx_size;
|
||||
|
@ -62,7 +64,6 @@ class hop_match_fn {
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
\brief Return t' when all locally bound variables in \c t occur in vars at positions [0, vars_size).
|
||||
The locally bound variables occurring in \c t are replaced using the following mapping:
|
||||
|
@ -150,6 +151,24 @@ class hop_match_fn {
|
|||
return some_expr(r);
|
||||
}
|
||||
|
||||
optional<expr> unfold_constant(expr const & c) {
|
||||
if (is_constant(c)) {
|
||||
auto obj = (*m_env)->find_object(const_name(c));
|
||||
if (obj && (obj->is_definition() || obj->is_builtin()))
|
||||
return some_expr(obj->get_value());
|
||||
}
|
||||
return none_expr();
|
||||
}
|
||||
|
||||
bool match_constant(expr const & p, expr const & t) {
|
||||
if (p == t)
|
||||
return true;
|
||||
auto new_p = unfold_constant(p);
|
||||
if (new_p)
|
||||
return match_constant(*new_p, t);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool match(expr const & p, expr const & t, context const & ctx, unsigned ctx_size) {
|
||||
lean_assert(ctx.size() == ctx_size);
|
||||
if (is_free_var(p, ctx_size)) {
|
||||
|
@ -184,6 +203,10 @@ class hop_match_fn {
|
|||
return true;
|
||||
}
|
||||
|
||||
if (m_env && is_constant(p)) {
|
||||
return match_constant(p, t);
|
||||
}
|
||||
|
||||
if (is_equality(p) && is_equality(t) && (!is_eq(p) || !is_eq(t))) {
|
||||
// Remark: if p and t are homogeneous equality, then we handle as an application (in the else branch)
|
||||
// We do that because we can get more information. For example, the pattern
|
||||
|
@ -228,32 +251,32 @@ class hop_match_fn {
|
|||
lean_unreachable();
|
||||
}
|
||||
public:
|
||||
hop_match_fn(buffer<optional<expr>> & subst):m_subst(subst) {}
|
||||
hop_match_fn(buffer<optional<expr>> & subst, optional<ro_environment> const & env):m_subst(subst), m_env(env) {}
|
||||
|
||||
bool operator()(expr const & p, expr const & t) {
|
||||
return match(p, t, context(), 0);
|
||||
}
|
||||
};
|
||||
|
||||
bool hop_match(expr const & p, expr const & t, buffer<optional<expr>> & subst) {
|
||||
return hop_match_fn(subst)(p, t);
|
||||
bool hop_match(expr const & p, expr const & t, buffer<optional<expr>> & subst, optional<ro_environment> const & env) {
|
||||
return hop_match_fn(subst, env)(p, t);
|
||||
}
|
||||
|
||||
static int hop_match(lua_State * L) {
|
||||
static int hop_match_core(lua_State * L, optional<ro_environment> const & env) {
|
||||
int nargs = lua_gettop(L);
|
||||
expr p = to_expr(L, 1);
|
||||
expr t = to_expr(L, 2);
|
||||
int k = 0;
|
||||
if (nargs == 3) {
|
||||
k = luaL_checkinteger(L, 3);
|
||||
if (nargs >= 4) {
|
||||
k = luaL_checkinteger(L, 4);
|
||||
if (k < 0)
|
||||
throw exception("hop_match, arg #3 must be non-negative");
|
||||
throw exception("hop_match, arg #4 must be non-negative");
|
||||
} else {
|
||||
k = free_var_range(p);
|
||||
}
|
||||
buffer<optional<expr>> subst;
|
||||
subst.resize(k);
|
||||
if (hop_match(p, t, subst)) {
|
||||
if (hop_match(p, t, subst, env)) {
|
||||
lua_newtable(L);
|
||||
int i = 1;
|
||||
for (auto s : subst) {
|
||||
|
@ -271,6 +294,21 @@ static int hop_match(lua_State * L) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
static int hop_match(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs >= 3) {
|
||||
if (!lua_isnil(L, 3)) {
|
||||
ro_shared_environment env(L, 3);
|
||||
return hop_match_core(L, optional<ro_environment>(env));
|
||||
} else {
|
||||
return hop_match_core(L, optional<ro_environment>());
|
||||
}
|
||||
} else {
|
||||
ro_shared_environment env(L);
|
||||
return hop_match_core(L, optional<ro_environment>(env));
|
||||
}
|
||||
}
|
||||
|
||||
void open_hop_match(lua_State * L) {
|
||||
SET_GLOBAL_FUN(hop_match, "hop_match");
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include "util/lua.h"
|
||||
#include "kernel/expr.h"
|
||||
#include "kernel/environment.h"
|
||||
|
||||
namespace lean {
|
||||
/**
|
||||
|
@ -24,7 +25,11 @@ namespace lean {
|
|||
2- A free variable is the function, but all other arguments are distinct locally bound variables.
|
||||
|
||||
\pre \c subst must be big enough to store all free variables occurring in subst
|
||||
|
||||
If an environment is provided, then a constant \c c matches a term \c t if
|
||||
\c c is definitionally equal to \c t.
|
||||
*/
|
||||
bool hop_match(expr const & p, expr const & t, buffer<optional<expr>> & subst);
|
||||
bool hop_match(expr const & p, expr const & t, buffer<optional<expr>> & subst,
|
||||
optional<ro_environment> const & env = optional<ro_environment>());
|
||||
void open_hop_match(lua_State * L);
|
||||
}
|
||||
|
|
|
@ -16,14 +16,19 @@ function funbody(e)
|
|||
return e
|
||||
end
|
||||
|
||||
function hoptst(rule, target, expected)
|
||||
function hoptst(rule, target, expected, perfect_match, no_env)
|
||||
if expected == nil then
|
||||
expected = true
|
||||
end
|
||||
local th = parse_lean(rule)
|
||||
local p = pibody(th):arg(2)
|
||||
local t = funbody(parse_lean(target))
|
||||
local r = hop_match(p, t)
|
||||
local r
|
||||
if no_env then
|
||||
r = hop_match(p, t, nil)
|
||||
else
|
||||
r = hop_match(p, t)
|
||||
end
|
||||
-- print(p, t)
|
||||
if (r and not expected) or (not r and expected) then
|
||||
error("test failed: " .. tostring(rule) .. " === " .. tostring(target))
|
||||
|
@ -35,13 +40,15 @@ function hoptst(rule, target, expected)
|
|||
print("#" .. tostring(i) .. " <--- " .. tostring(r[i]))
|
||||
end
|
||||
print ""
|
||||
t = t:beta_reduce()
|
||||
if s ~= t then
|
||||
print("Mismatch")
|
||||
print(s)
|
||||
print(t)
|
||||
if perfect_match then
|
||||
t = t:beta_reduce()
|
||||
if s ~= t then
|
||||
print("Mismatch")
|
||||
print(s)
|
||||
print(t)
|
||||
end
|
||||
assert(s == t)
|
||||
end
|
||||
assert(s == t)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -118,8 +125,9 @@ hoptst([[forall (h : Nat -> Bool), (forall x y : Nat, h x) = true]],
|
|||
hoptst([[forall (h : Nat -> Bool), (forall x y : Nat, h y) = true]],
|
||||
[[fun (a b : Nat), forall x y : Nat, (fun z : Nat, z + y) (fun w1 w2 : Nat, w1 + w2 + y)]])
|
||||
|
||||
parse_lean_cmds([[
|
||||
definition ww := 0
|
||||
]])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
hoptst('ww = 0', '0', true, false)
|
||||
hoptst('ww = 0', '0', false, false, true)
|
||||
|
|
Loading…
Add table
Reference in a new issue