refactor(library/match): use "special" meta-variables instead of free variables to represent placholders in the higher-order matcher

This commit is contained in:
Leonardo de Moura 2015-02-03 15:15:04 -08:00
parent fc6d9878c9
commit f79f43c702
7 changed files with 122 additions and 61 deletions

View file

@ -103,11 +103,16 @@ struct app_builder::imp {
buffer<expr> used_types;
buffer<bool> explicit_mask;
buffer<expr> domain_types;
unsigned idx = 0;
while (is_pi(type)) {
explicit_mask.push_back(is_explicit(binding_info(type)));
esubst.push_back(none_expr());
domain_types.push_back(binding_domain(type));
type = binding_body(type);
// TODO(Leo): perhaps, we should cache the result of this while-loop.
// The result of this computation can be reused in future calls.
expr meta = mk_idx_meta(idx, binding_domain(type));
idx++;
type = instantiate(binding_body(type), meta);
}
unsigned i = domain_types.size();
unsigned j = nargs;
@ -121,7 +126,7 @@ struct app_builder::imp {
if (cs)
return none_expr();
bool assigned = false;
if (!match(domain_types[i], arg_type, i, esubst.data(), lsubst.size(), lsubst.data(),
if (!match(domain_types[i], arg_type, lsubst, esubst,
nullptr, nullptr, &m_plugin, &assigned))
return none_expr();
if (assigned && use_cache) {
@ -135,7 +140,7 @@ struct app_builder::imp {
expr arg_type = m_tc.infer(*esubst[i], cs);
if (cs)
return none_expr();
if (!match(domain_types[i], arg_type, i, esubst.data(), lsubst.size(), lsubst.data(),
if (!match(domain_types[i], arg_type, lsubst, esubst,
nullptr, nullptr, &m_plugin))
return none_expr();
}

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <algorithm>
#include <utility>
#include "kernel/abstract.h"
#include "kernel/instantiate.h"
@ -28,6 +29,10 @@ level mk_idx_meta_univ(unsigned i) {
return mk_meta_univ(name(*g_tmp_prefix, i));
}
expr mk_idx_meta(unsigned i, expr const & type) {
return mk_metavar(name(*g_tmp_prefix, i), type);
}
bool is_idx_meta_univ(level const & l) {
if (!is_meta(l))
return false;
@ -40,6 +45,18 @@ unsigned to_meta_idx(level const & l) {
return meta_id(l).get_numeral();
}
bool is_idx_meta(expr const & e) {
if (!is_metavar(e))
return false;
name const & n = mlocal_name(e);
return !n.is_atomic() && n.is_numeral() && n.get_prefix() == *g_tmp_prefix;
}
unsigned to_meta_idx(expr const & e) {
lean_assert(is_idx_meta(e));
return mlocal_name(e).get_numeral();
}
class match_fn : public match_context {
unsigned m_esubst_sz;
optional<expr> * m_esubst;
@ -81,10 +98,8 @@ class match_fn : public match_context {
};
void _assign(expr const & p, expr const & t) {
lean_assert(var_idx(p) < m_esubst_sz);
unsigned vidx = var_idx(p);
unsigned sz = m_esubst_sz;
unsigned i = sz - vidx - 1;
lean_assert(to_meta_idx(p) < m_esubst_sz);
unsigned i = to_meta_idx(p);
m_stack.emplace_back(true, i);
m_esubst[i] = t;
if (m_assigned)
@ -105,11 +120,11 @@ class match_fn : public match_context {
}
optional<expr> _get_subst(expr const & x) const {
unsigned vidx = var_idx(x);
unsigned i = to_meta_idx(x);
unsigned sz = m_esubst_sz;
if (vidx >= sz)
if (i >= sz)
throw_exception();
return m_esubst[sz - vidx - 1];
return m_esubst[i];
}
optional<level> _get_subst(level const & x) const {
@ -307,7 +322,7 @@ class match_fn : public match_context {
}
bool _match(expr const & p, expr const & t) {
if (is_var(p)) {
if (is_idx_meta(p)) {
auto s = _get_subst(p);
if (s) {
return match_core(*s, t);
@ -318,7 +333,7 @@ class match_fn : public match_context {
} else if (is_app(p)) {
buffer<expr> args;
expr const & f = get_app_rev_args(p, args);
if (is_var(f)) {
if (is_idx_meta(f)) {
// higher-order pattern case
auto s = _get_subst(f);
if (s) {
@ -340,8 +355,8 @@ class match_fn : public match_context {
}
public:
match_fn(unsigned esubst_sz, optional<expr> * esubst,
unsigned lsubst_sz, optional<level> * lsubst,
match_fn(unsigned lsubst_sz, optional<level> * lsubst,
unsigned esubst_sz, optional<expr> * esubst,
name_generator const & ngen,
name_map<name> * name_subst, match_plugin const * plugin, bool * assigned):
m_esubst_sz(esubst_sz), m_esubst(esubst),
@ -353,21 +368,22 @@ public:
};
bool match(expr const & p, expr const & t,
unsigned esubst_sz, optional<expr> * esubst,
unsigned lsubst_sz, optional<level> * lsubst,
unsigned esubst_sz, optional<expr> * esubst,
name const * prefix, name_map<name> * name_subst, match_plugin const * plugin, bool * assigned) {
lean_assert(closed(t));
lean_assert(closed(p));
if (prefix)
return match_fn(esubst_sz, esubst, lsubst_sz, lsubst, name_generator(*prefix),
return match_fn(lsubst_sz, lsubst, esubst_sz, esubst, name_generator(*prefix),
name_subst, plugin, assigned).match(p, t);
else
return match_fn(esubst_sz, esubst, lsubst_sz, lsubst,
return match_fn(lsubst_sz, lsubst, esubst_sz, esubst,
name_generator(*g_tmp_prefix), name_subst, plugin, assigned).match(p, t);
}
bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst,
bool match(expr const & p, expr const & t, buffer<optional<level>> & lsubst, buffer<optional<expr>> & esubst,
name const * prefix, name_map<name> * name_subst, match_plugin const * plugin, bool * assigned) {
return match(p, t, esubst.size(), esubst.data(), lsubst.size(), lsubst.data(),
return match(p, t, lsubst.size(), lsubst.data(), esubst.size(), esubst.data(),
prefix, name_subst, plugin, assigned);
}
@ -410,20 +426,28 @@ static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
return r;
}
static unsigned get_idx_meta_univ_range(expr const & e) {
if (!has_univ_metavar(e))
return 0;
unsigned r = 0;
static pair<unsigned, unsigned> get_idx_meta_univ_ranges(expr const & e) {
if (!has_metavar(e))
return mk_pair(0, 0);
unsigned rlvl = 0;
unsigned rexp = 0;
for_each(e, [&](expr const & e, unsigned) {
if (!has_univ_metavar(e)) return false;
if (!has_metavar(e)) return false;
if (is_constant(e))
for (level const & l : const_levels(e))
r = updt_idx_meta_univ_range(l, r);
rlvl = updt_idx_meta_univ_range(l, rlvl);
if (is_sort(e))
r = updt_idx_meta_univ_range(sort_level(e), r);
rlvl = updt_idx_meta_univ_range(sort_level(e), rlvl);
if (is_idx_meta(e))
rexp = std::max(to_meta_idx(e) + 1, rexp);
return true;
});
return r;
return mk_pair(rlvl, rexp);
}
expr substitute(expr const & e, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst) {
// TODO(Leo)
return e;
}
DECL_UDATA(match_plugin)
@ -446,12 +470,14 @@ static int match(lua_State * L) {
plugin = &to_match_plugin(L, 3);
if (!closed(t))
throw exception("higher-order pattern matching failure, input term must not contain free variables");
unsigned r1 = get_free_var_range(p);
unsigned r2 = get_idx_meta_univ_range(p);
buffer<optional<expr>> esubst;
unsigned r1, r2;
auto r1_r2 = get_idx_meta_univ_ranges(p);
r1 = r1_r2.first;
r2 = r1_r2.second;
buffer<optional<level>> lsubst;
esubst.resize(r1); lsubst.resize(r2);
if (match(p, t, esubst, lsubst, nullptr, nullptr, plugin)) {
buffer<optional<expr>> esubst;
lsubst.resize(r1); esubst.resize(r2);
if (match(p, t, lsubst, esubst, nullptr, nullptr, plugin)) {
lua_newtable(L);
int i = 1;
for (auto s : esubst) {
@ -483,6 +509,10 @@ static int mk_idx_meta_univ(lua_State * L) {
return push_level(L, mk_idx_meta_univ(luaL_checkinteger(L, 1)));
}
static int mk_idx_meta(lua_State * L) {
return push_expr(L, mk_idx_meta(luaL_checkinteger(L, 1), to_expr(L, 2)));
}
void open_match(lua_State * L) {
luaL_newmetatable(L, match_plugin_mt);
lua_pushvalue(L, -1);
@ -492,6 +522,7 @@ void open_match(lua_State * L) {
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin");
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta");
SET_GLOBAL_FUN(match, "match");
}
}

View file

@ -14,13 +14,20 @@ Author: Leonardo de Moura
#include "kernel/environment.h"
namespace lean {
/** \brief Create a universe level metavariable that can be used as a placeholder in #hop_match.
/** \brief Create a universe level metavariable that can be used as a placeholder in #match.
\remark The index \c i is encoded in the hierarchical name, and can be quickly accessed.
In hop_match the substitution is also efficiently represented as an array (aka buffer).
In the match procedure the substitution is also efficiently represented as an array (aka buffer).
*/
level mk_idx_meta_univ(unsigned i);
/** \brief Create a special metavariable that can be used as a placeholder in #match.
\remark The index \c i is encoded in the hierarchical name, and can be quickly accessed.
In the match procedure the substitution is also efficiently represented as an array (aka buffer).
*/
expr mk_idx_meta(unsigned i, expr const & type);
/** \brief Context for match_plugins. */
class match_context {
public:
@ -54,19 +61,22 @@ match_plugin mk_whnf_match_plugin(type_checker & tc);
/**
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders
instead of meta-variables.
The substitution is stored in \c subst. Note that, this procedure treats "special" meta-variables
(created using #mk_idx_meta_univ and #mk_idx_meta) as placeholders. The substitution of these
metavariable can be quickly accessed using an index stored in them. The parameters
\c esubst and \c lsubst store the substitutions for them. There are just buffers.
\c subst is an assignment for the free variables occurring in \c p.
\pre \p and \c t must not contain free variables. Thus, free-variables must be replaced with local constants
before invoking this function.
\pre \c t must not contain free variables. If it does, they must be replaced with local constants
before invoking this functions.
Other (non special) meta-variables are treated as opaque constants.
\c p is a higher-order pattern when in all applications in \c p
1- A free variable is not the function OR
2- A free variable is the function, but all other arguments are distinct local constants.
1- A special meta-variable is not the function OR
2- A special meta-variable is the function, but all other arguments are distinct local constants.
\pre \c subst must be big enough to store all free variables occurring in subst
\pre \c esubst and \c lsubst must be big enough to store the substitution.
That is, their size should be > than the index of any special metavariable occuring in p.
If prefix is provided, then it is used for creating unique names.
@ -78,14 +88,19 @@ match_plugin mk_whnf_match_plugin(type_checker & tc);
If \c assigned is provided, then it is set to true if \c esubst or \c lsubst is updated.
*/
bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst,
bool match(expr const & p, expr const & t, buffer<optional<level>> & lsubst, buffer<optional<expr>> & esubst,
name const * prefix = nullptr, name_map<name> * name_subst = nullptr, match_plugin const * plugin = nullptr,
bool * assigned = nullptr);
bool match(expr const & p, expr const & t, unsigned esubst_sz, optional<expr> * esubst,
bool match(expr const & p, expr const & t,
unsigned lsubst_sz, optional<level> * lsubst,
unsigned esubst_sz, optional<expr> * esubst,
name const * prefix = nullptr, name_map<name> * name_subst = nullptr,
match_plugin const * plugin = nullptr, bool * assigned = nullptr);
/** \brief Replace special meta-variables (created using #mk_idx_meta_univ and #mk_idx_meta) with the values
provided in \c esubst and \c lsubst */
expr substitute(expr const & e, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst);
void open_match(lua_State * L);
void initialize_match();
void finalize_match();

View file

@ -1,4 +1,4 @@
import data.nat
import data.nat.basic
open nat
definition two1 : nat := 2
@ -20,9 +20,11 @@ function tst_match(p, t)
end
end
local nat = Const("nat")
local f = Const("f")
local two1 = Const("two1")
local two2 = Const("two2")
local succ = Const({"nat", "succ"})
tst_match(f(succ(mk_var(0)), two1), f(two2, two2))
local V0 = mk_idx_meta(0, nat)
tst_match(f(succ(V0), two1), f(two2, two2))
*)

View file

@ -1,4 +1,4 @@
import data.nat
import data.nat.basic
open nat
definition two1 : nat := 2
@ -22,11 +22,12 @@ function tst_match(p, t)
end
end
local nat = Const("nat")
local f = Const("f")
local g = Const("g")
local a = Const("a")
local b = Const("b")
local x = mk_var(0)
local x = mk_idx_meta(0, nat)
local p = g(x, f(x, a))
local t = g(a, f(b, a))
tst_match(p, t)

View file

@ -14,11 +14,15 @@ local a = Const("a")
local b = Const("b")
local x = Local("x", N)
local y = Local("y", N)
tst_match(f(Var(0), Var(0)), f(a, a))
tst_match(f(Var(0), Var(1)), f(a, b))
tst_match(Var(0)(x, y), f(x, f(x, y)))
assert(not match(Var(0)(x, x), f(x, f(x, y))))
assert(not match(Var(0)(x), f(x, y)))
tst_match(Pi(x, y, Var(2)(x)), Pi(x, y, f(f(x))))
tst_match(Fun(x, y, Var(2)(x)), Fun(x, y, f(f(x))))
assert(not match(Pi(x, Var(2)(x)), Pi(x, y, f(f(x)))))
local V0 = mk_idx_meta(0, N)
local V1 = mk_idx_meta(1, N)
tst_match(f(V0, V0), f(a, a))
tst_match(f(V0, V1), f(a, b))
local F0 = mk_idx_meta(0, Pi(x, y, N))
tst_match(F0(x, y), f(x, f(x, y)))
assert(not match(F0(x, x), f(x, f(x, y))))
assert(not match(F0(x), f(x, y)))
local F0 = mk_idx_meta(0, Pi(x, N))
tst_match(Pi(x, y, F0(x)), Pi(x, y, f(f(x))))
tst_match(Fun(x, y, F0(x)), Fun(x, y, f(f(x))))
assert(not match(Pi(x, F0(x)), Pi(x, y, f(f(x)))))

View file

@ -22,8 +22,11 @@ local z = level()
local f = Const("f", {u1, z})
local f2 = Const("f", {u1, u1+1})
local fp = Const("f", {mk_idx_meta_univ(0), mk_idx_meta_univ(1)})
tst_match(fp(Var(0), Var(0)), f(a, a))
tst_match(fp(Var(0), Var(1)), f2(a, b))
tst_match(Var(0)(x, y), f(x, f(x, y)))
assert(not match(Var(0)(x, x), f(x, f(x, y))))
assert(not match(Var(0)(x), f(x, y)))
local V0 = mk_idx_meta(0, N)
local V1 = mk_idx_meta(1, N)
tst_match(fp(V0, V0), f(a, a))
tst_match(fp(V0, V1), f2(a, b))
local F0 = mk_idx_meta(0, Pi(x, y, N))
tst_match(F0(x, y), f(x, f(x, y)))
assert(not match(F0(x, x), f(x, f(x, y))))
assert(not match(F0(x), f(x, y)))