diff --git a/src/library/hop_match.cpp b/src/library/hop_match.cpp index 5956a1fb1..e939839c8 100644 --- a/src/library/hop_match.cpp +++ b/src/library/hop_match.cpp @@ -11,12 +11,14 @@ Author: Leonardo de Moura #include "kernel/kernel.h" #include "library/equality.h" #include "library/kernel_bindings.h" +#include "library/hop_match.h" namespace lean { class hop_match_fn { buffer> & m_subst; buffer m_vars; + optional m_env; bool is_free_var(expr const & x, unsigned ctx_size) const { return is_var(x) && var_idx(x) >= ctx_size; @@ -62,7 +64,6 @@ class hop_match_fn { return true; } - /** \brief Return t' when all locally bound variables in \c t occur in vars at positions [0, vars_size). The locally bound variables occurring in \c t are replaced using the following mapping: @@ -150,6 +151,24 @@ class hop_match_fn { return some_expr(r); } + optional unfold_constant(expr const & c) { + if (is_constant(c)) { + auto obj = (*m_env)->find_object(const_name(c)); + if (obj && (obj->is_definition() || obj->is_builtin())) + return some_expr(obj->get_value()); + } + return none_expr(); + } + + bool match_constant(expr const & p, expr const & t) { + if (p == t) + return true; + auto new_p = unfold_constant(p); + if (new_p) + return match_constant(*new_p, t); + return false; + } + bool match(expr const & p, expr const & t, context const & ctx, unsigned ctx_size) { lean_assert(ctx.size() == ctx_size); if (is_free_var(p, ctx_size)) { @@ -184,6 +203,10 @@ class hop_match_fn { return true; } + if (m_env && is_constant(p)) { + return match_constant(p, t); + } + if (is_equality(p) && is_equality(t) && (!is_eq(p) || !is_eq(t))) { // Remark: if p and t are homogeneous equality, then we handle as an application (in the else branch) // We do that because we can get more information. For example, the pattern @@ -228,32 +251,32 @@ class hop_match_fn { lean_unreachable(); } public: - hop_match_fn(buffer> & subst):m_subst(subst) {} + hop_match_fn(buffer> & subst, optional const & env):m_subst(subst), m_env(env) {} bool operator()(expr const & p, expr const & t) { return match(p, t, context(), 0); } }; -bool hop_match(expr const & p, expr const & t, buffer> & subst) { - return hop_match_fn(subst)(p, t); +bool hop_match(expr const & p, expr const & t, buffer> & subst, optional const & env) { + return hop_match_fn(subst, env)(p, t); } -static int hop_match(lua_State * L) { +static int hop_match_core(lua_State * L, optional const & env) { int nargs = lua_gettop(L); expr p = to_expr(L, 1); expr t = to_expr(L, 2); int k = 0; - if (nargs == 3) { - k = luaL_checkinteger(L, 3); + if (nargs >= 4) { + k = luaL_checkinteger(L, 4); if (k < 0) - throw exception("hop_match, arg #3 must be non-negative"); + throw exception("hop_match, arg #4 must be non-negative"); } else { k = free_var_range(p); } buffer> subst; subst.resize(k); - if (hop_match(p, t, subst)) { + if (hop_match(p, t, subst, env)) { lua_newtable(L); int i = 1; for (auto s : subst) { @@ -271,6 +294,21 @@ static int hop_match(lua_State * L) { return 1; } +static int hop_match(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs >= 3) { + if (!lua_isnil(L, 3)) { + ro_shared_environment env(L, 3); + return hop_match_core(L, optional(env)); + } else { + return hop_match_core(L, optional()); + } + } else { + ro_shared_environment env(L); + return hop_match_core(L, optional(env)); + } +} + void open_hop_match(lua_State * L) { SET_GLOBAL_FUN(hop_match, "hop_match"); } diff --git a/src/library/hop_match.h b/src/library/hop_match.h index 1cfabf271..ada6a9b39 100644 --- a/src/library/hop_match.h +++ b/src/library/hop_match.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include "util/lua.h" #include "kernel/expr.h" +#include "kernel/environment.h" namespace lean { /** @@ -24,7 +25,11 @@ namespace lean { 2- A free variable is the function, but all other arguments are distinct locally bound variables. \pre \c subst must be big enough to store all free variables occurring in subst + + If an environment is provided, then a constant \c c matches a term \c t if + \c c is definitionally equal to \c t. */ -bool hop_match(expr const & p, expr const & t, buffer> & subst); +bool hop_match(expr const & p, expr const & t, buffer> & subst, + optional const & env = optional()); void open_hop_match(lua_State * L); } diff --git a/tests/lua/hop2.lua b/tests/lua/hop2.lua index cb9abdb45..4f9685604 100644 --- a/tests/lua/hop2.lua +++ b/tests/lua/hop2.lua @@ -16,14 +16,19 @@ function funbody(e) return e end -function hoptst(rule, target, expected) +function hoptst(rule, target, expected, perfect_match, no_env) if expected == nil then expected = true end local th = parse_lean(rule) local p = pibody(th):arg(2) local t = funbody(parse_lean(target)) - local r = hop_match(p, t) + local r + if no_env then + r = hop_match(p, t, nil) + else + r = hop_match(p, t) + end -- print(p, t) if (r and not expected) or (not r and expected) then error("test failed: " .. tostring(rule) .. " === " .. tostring(target)) @@ -35,13 +40,15 @@ function hoptst(rule, target, expected) print("#" .. tostring(i) .. " <--- " .. tostring(r[i])) end print "" - t = t:beta_reduce() - if s ~= t then - print("Mismatch") - print(s) - print(t) + if perfect_match then + t = t:beta_reduce() + if s ~= t then + print("Mismatch") + print(s) + print(t) + end + assert(s == t) end - assert(s == t) end end @@ -118,8 +125,9 @@ hoptst([[forall (h : Nat -> Bool), (forall x y : Nat, h x) = true]], hoptst([[forall (h : Nat -> Bool), (forall x y : Nat, h y) = true]], [[fun (a b : Nat), forall x y : Nat, (fun z : Nat, z + y) (fun w1 w2 : Nat, w1 + w2 + y)]]) +parse_lean_cmds([[ + definition ww := 0 +]]) - - - - +hoptst('ww = 0', '0', true, false) +hoptst('ww = 0', '0', false, false, true)