feat(library/match): match universe levels
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
b0a5ff7f93
commit
c34c2f4f5c
3 changed files with 233 additions and 44 deletions
|
@ -6,32 +6,74 @@ Author: Leonardo de Moura
|
||||||
*/
|
*/
|
||||||
#include "kernel/abstract.h"
|
#include "kernel/abstract.h"
|
||||||
#include "kernel/instantiate.h"
|
#include "kernel/instantiate.h"
|
||||||
|
#include "kernel/for_each_fn.h"
|
||||||
#include "library/kernel_bindings.h"
|
#include "library/kernel_bindings.h"
|
||||||
#include "library/locals.h"
|
#include "library/locals.h"
|
||||||
#include "library/match.h"
|
#include "library/match.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
class match_fn {
|
static name g_tmp_prefix = name::mk_internal_unique_name();
|
||||||
buffer<optional<expr>> & m_subst;
|
|
||||||
|
level mk_idx_meta_univ(unsigned i) {
|
||||||
|
return mk_meta_univ(name(g_tmp_prefix, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_idx_meta_univ(level const & l) {
|
||||||
|
if (!is_meta(l))
|
||||||
|
return false;
|
||||||
|
name const & n = meta_id(l);
|
||||||
|
return !n.is_atomic() && n.is_numeral() && n.get_prefix() == g_tmp_prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned to_meta_idx(level const & l) {
|
||||||
|
lean_assert(is_idx_meta_univ(l));
|
||||||
|
return meta_id(l).get_numeral();
|
||||||
|
}
|
||||||
|
|
||||||
|
class match_fn : public match_context {
|
||||||
|
buffer<optional<expr>> & m_esubst;
|
||||||
|
buffer<optional<level>> & m_lsubst;
|
||||||
name_generator m_ngen;
|
name_generator m_ngen;
|
||||||
name_map<name> * m_name_subst;
|
name_map<name> * m_name_subst;
|
||||||
matcher_plugin const * m_plugin;
|
match_plugin const * m_plugin;
|
||||||
|
|
||||||
void assign(expr const & p, expr const & t) {
|
void _assign(expr const & p, expr const & t) {
|
||||||
lean_assert(var_idx(p) < m_subst.size());
|
lean_assert(var_idx(p) < m_esubst.size());
|
||||||
unsigned vidx = var_idx(p);
|
unsigned vidx = var_idx(p);
|
||||||
unsigned sz = m_subst.size();
|
unsigned sz = m_esubst.size();
|
||||||
m_subst[sz - vidx - 1] = t;
|
m_esubst[sz - vidx - 1] = t;
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<expr> get_subst(expr const & x) const {
|
void _assign(level const & p, level const & l) {
|
||||||
unsigned vidx = var_idx(x);
|
lean_assert(to_meta_idx(p) < m_lsubst.size());
|
||||||
unsigned sz = m_subst.size();
|
m_lsubst[to_meta_idx(p)] = l;
|
||||||
if (vidx >= sz)
|
|
||||||
throw exception("ill-formed higher-order matching problem");
|
|
||||||
return m_subst[sz - vidx - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void throw_exception() const {
|
||||||
|
throw exception("ill-formed higher-order matching problem");
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<expr> _get_subst(expr const & x) const {
|
||||||
|
unsigned vidx = var_idx(x);
|
||||||
|
unsigned sz = m_esubst.size();
|
||||||
|
if (vidx >= sz)
|
||||||
|
throw_exception();
|
||||||
|
return m_esubst[sz - vidx - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<level> _get_subst(level const & x) const {
|
||||||
|
unsigned i = to_meta_idx(x);
|
||||||
|
if (i > m_lsubst.size())
|
||||||
|
throw_exception();
|
||||||
|
return m_lsubst[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void assign(expr const & p, expr const & t) { return _assign(p, t); }
|
||||||
|
virtual void assign(level const & p, level const & t) { return _assign(p, t); }
|
||||||
|
virtual optional<expr> get_subst(expr const & x) const { return _get_subst(x); }
|
||||||
|
virtual optional<level> get_subst(level const & x) const { return _get_subst(x); }
|
||||||
|
virtual name mk_name() { return m_ngen.next(); }
|
||||||
|
|
||||||
bool args_are_distinct_locals(buffer<expr> const & args) {
|
bool args_are_distinct_locals(buffer<expr> const & args) {
|
||||||
for (auto it = args.begin(); it != args.end(); it++) {
|
for (auto it = args.begin(); it != args.end(); it++) {
|
||||||
if (!is_local(*it) || contains_local(*it, args.begin(), it))
|
if (!is_local(*it) || contains_local(*it, args.begin(), it))
|
||||||
|
@ -48,10 +90,10 @@ class match_fn {
|
||||||
return some_expr(r);
|
return some_expr(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match_plugin(expr const & p, expr const & t) {
|
bool try_plugin(expr const & p, expr const & t) {
|
||||||
if (!m_plugin)
|
if (!m_plugin)
|
||||||
return false;
|
return false;
|
||||||
return (*m_plugin)(p, t, m_subst, m_ngen.mk_child());
|
return (*m_plugin)(p, t, *this);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool match_binding(expr p, expr t) {
|
bool match_binding(expr p, expr t) {
|
||||||
|
@ -92,43 +134,89 @@ class match_fn {
|
||||||
return match_core(app_fn(p), app_fn(t)) && match(app_arg(p), app_arg(t));
|
return match_core(app_fn(p), app_fn(t)) && match(app_arg(p), app_arg(t));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool match_level_core(level const & p, level const & l) {
|
||||||
|
if (p == l)
|
||||||
|
return true;
|
||||||
|
if (p.kind() == l.kind()) {
|
||||||
|
switch (p.kind()) {
|
||||||
|
case level_kind::Zero:
|
||||||
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
|
case level_kind::Param: case level_kind::Global: case level_kind::Meta:
|
||||||
|
return false;
|
||||||
|
case level_kind::Succ:
|
||||||
|
return match_level(succ_of(p), succ_of(l));
|
||||||
|
case level_kind::Max:
|
||||||
|
return
|
||||||
|
match_level(max_lhs(p), max_lhs(l)) &&
|
||||||
|
match_level(max_rhs(p), max_rhs(l));
|
||||||
|
case level_kind::IMax:
|
||||||
|
return
|
||||||
|
match_level(imax_lhs(p), imax_lhs(l)) &&
|
||||||
|
match_level(imax_rhs(p), imax_rhs(l));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_level(level const & p, level const & l) {
|
||||||
|
if (is_idx_meta_univ(p)) {
|
||||||
|
auto s = _get_subst(p);
|
||||||
|
if (s) {
|
||||||
|
return match_level_core(*s, l);
|
||||||
|
} else {
|
||||||
|
_assign(p, l);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return match_level_core(p, l);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool match_levels(levels ps, levels ls) {
|
||||||
|
while (ps && ls) {
|
||||||
|
if (!match_level(head(ps), head(ls)))
|
||||||
|
return false;
|
||||||
|
ps = tail(ps);
|
||||||
|
ls = tail(ls);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool match_core(expr const & p, expr const & t) {
|
bool match_core(expr const & p, expr const & t) {
|
||||||
if (p.kind() != t.kind())
|
if (p.kind() != t.kind())
|
||||||
return match_plugin(p, t);
|
return try_plugin(p, t);
|
||||||
switch (p.kind()) {
|
switch (p.kind()) {
|
||||||
case expr_kind::Local: case expr_kind::Meta:
|
case expr_kind::Local: case expr_kind::Meta:
|
||||||
return mlocal_name(p) == mlocal_name(t);
|
return mlocal_name(p) == mlocal_name(t);
|
||||||
case expr_kind::Var:
|
case expr_kind::Var:
|
||||||
lean_unreachable(); // LCOV_EXCL_LINE
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
case expr_kind::Constant:
|
case expr_kind::Constant:
|
||||||
// TODO(Leo): universe levels
|
return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t));
|
||||||
return const_name(p) == const_name(t);
|
|
||||||
case expr_kind::Sort:
|
case expr_kind::Sort:
|
||||||
// TODO(Leo): universe levels
|
return match_level(sort_level(p), sort_level(t));
|
||||||
return true;
|
|
||||||
case expr_kind::Lambda: case expr_kind::Pi:
|
case expr_kind::Lambda: case expr_kind::Pi:
|
||||||
return match_binding(p, t) || match_plugin(p, t);
|
return match_binding(p, t) || try_plugin(p, t);
|
||||||
case expr_kind::Macro:
|
case expr_kind::Macro:
|
||||||
return match_macro(p, t) || match_plugin(p, t);
|
return match_macro(p, t) || try_plugin(p, t);
|
||||||
case expr_kind::App:
|
case expr_kind::App:
|
||||||
return match_app(p, t) || match_plugin(p, t);
|
return match_app(p, t) || try_plugin(p, t);
|
||||||
}
|
}
|
||||||
lean_unreachable(); // LCOV_EXCL_LINE
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
match_fn(buffer<optional<expr>> & subst, name_generator const & ngen, name_map<name> * name_subst, matcher_plugin const * plugin):
|
match_fn(buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst, name_generator const & ngen,
|
||||||
m_subst(subst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin) {}
|
name_map<name> * name_subst, match_plugin const * plugin):
|
||||||
|
m_esubst(esubst), m_lsubst(lsubst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin) {}
|
||||||
|
|
||||||
bool match(expr const & p, expr const & t) {
|
bool match(expr const & p, expr const & t) {
|
||||||
if (is_var(p)) {
|
if (is_var(p)) {
|
||||||
auto s = get_subst(p);
|
auto s = _get_subst(p);
|
||||||
if (s) {
|
if (s) {
|
||||||
return match_core(*s, t);
|
return match_core(*s, t);
|
||||||
} else if (has_local(t)) {
|
} else if (has_local(t)) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
assign(p, t);
|
_assign(p, t);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} else if (is_app(p)) {
|
} else if (is_app(p)) {
|
||||||
|
@ -136,7 +224,7 @@ public:
|
||||||
expr const & f = get_app_rev_args(p, args);
|
expr const & f = get_app_rev_args(p, args);
|
||||||
if (is_var(f)) {
|
if (is_var(f)) {
|
||||||
// higher-order pattern case
|
// higher-order pattern case
|
||||||
auto s = get_subst(f);
|
auto s = _get_subst(f);
|
||||||
if (s) {
|
if (s) {
|
||||||
expr new_p = apply_beta(*s, args.size(), args.data());
|
expr new_p = apply_beta(*s, args.size(), args.data());
|
||||||
return match_core(new_p, t);
|
return match_core(new_p, t);
|
||||||
|
@ -144,7 +232,7 @@ public:
|
||||||
if (args_are_distinct_locals(args)) {
|
if (args_are_distinct_locals(args)) {
|
||||||
optional<expr> new_t = proj(t, args);
|
optional<expr> new_t = proj(t, args);
|
||||||
if (new_t) {
|
if (new_t) {
|
||||||
assign(f, *new_t);
|
_assign(f, *new_t);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -156,14 +244,42 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static name g_tmp_prefix = name::mk_internal_unique_name();
|
bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst,
|
||||||
bool match(expr const & p, expr const & t, buffer<optional<expr>> & subst, name const * prefix,
|
name const * prefix, name_map<name> * name_subst, match_plugin const * plugin) {
|
||||||
name_map<name> * name_subst, matcher_plugin const * plugin) {
|
|
||||||
lean_assert(closed(t));
|
lean_assert(closed(t));
|
||||||
if (prefix)
|
if (prefix)
|
||||||
return match_fn(subst, name_generator(*prefix), name_subst, plugin).match(p, t);
|
return match_fn(esubst, lsubst, name_generator(*prefix), name_subst, plugin).match(p, t);
|
||||||
else
|
else
|
||||||
return match_fn(subst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
|
return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
|
||||||
|
for_each(l, [&](level const & l) {
|
||||||
|
if (!has_meta(l)) return false;
|
||||||
|
if (is_idx_meta_univ(l)) {
|
||||||
|
unsigned new_r = to_meta_idx(l) + 1;
|
||||||
|
if (new_r > r)
|
||||||
|
r = new_r;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned get_idx_meta_univ_range(expr const & e) {
|
||||||
|
if (!has_univ_metavar(e))
|
||||||
|
return 0;
|
||||||
|
unsigned r = 0;
|
||||||
|
for_each(e, [&](expr const & e, unsigned) {
|
||||||
|
if (!has_univ_metavar(e)) return false;
|
||||||
|
if (is_constant(e))
|
||||||
|
for (level const & l : const_levels(e))
|
||||||
|
r = updt_idx_meta_univ_range(l, r);
|
||||||
|
if (is_sort(e))
|
||||||
|
r = updt_idx_meta_univ_range(sort_level(e), r);
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int match(lua_State * L) {
|
static int match(lua_State * L) {
|
||||||
|
@ -171,13 +287,15 @@ static int match(lua_State * L) {
|
||||||
expr t = to_expr(L, 2);
|
expr t = to_expr(L, 2);
|
||||||
if (!closed(t))
|
if (!closed(t))
|
||||||
throw exception("higher-order pattern matching failure, input term must not contain free variables");
|
throw exception("higher-order pattern matching failure, input term must not contain free variables");
|
||||||
unsigned k = get_free_var_range(p);
|
unsigned r1 = get_free_var_range(p);
|
||||||
buffer<optional<expr>> subst;
|
unsigned r2 = get_idx_meta_univ_range(p);
|
||||||
subst.resize(k);
|
buffer<optional<expr>> esubst;
|
||||||
if (match(p, t, subst, nullptr, nullptr, nullptr)) {
|
buffer<optional<level>> lsubst;
|
||||||
|
esubst.resize(r1); lsubst.resize(r2);
|
||||||
|
if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) {
|
||||||
lua_newtable(L);
|
lua_newtable(L);
|
||||||
int i = 1;
|
int i = 1;
|
||||||
for (auto s : subst) {
|
for (auto s : esubst) {
|
||||||
if (s)
|
if (s)
|
||||||
push_expr(L, *s);
|
push_expr(L, *s);
|
||||||
else
|
else
|
||||||
|
@ -185,13 +303,29 @@ static int match(lua_State * L) {
|
||||||
lua_rawseti(L, -2, i);
|
lua_rawseti(L, -2, i);
|
||||||
i = i + 1;
|
i = i + 1;
|
||||||
}
|
}
|
||||||
|
lua_newtable(L);
|
||||||
|
i = 1;
|
||||||
|
for (auto s : lsubst) {
|
||||||
|
if (s)
|
||||||
|
push_level(L, *s);
|
||||||
|
else
|
||||||
|
lua_pushboolean(L, false);
|
||||||
|
lua_rawseti(L, -2, i);
|
||||||
|
i = i + 1;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
lua_pushnil(L);
|
lua_pushnil(L);
|
||||||
|
lua_pushnil(L);
|
||||||
}
|
}
|
||||||
return 1;
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int mk_idx_meta_univ(lua_State * L) {
|
||||||
|
return push_level(L, mk_idx_meta_univ(luaL_checkinteger(L, 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void open_match(lua_State * L) {
|
void open_match(lua_State * L) {
|
||||||
SET_GLOBAL_FUN(match, "match");
|
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
|
||||||
|
SET_GLOBAL_FUN(match, "match");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,12 +14,38 @@ Author: Leonardo de Moura
|
||||||
#include "kernel/environment.h"
|
#include "kernel/environment.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
|
/** \brief Create a universe level metavariable that can be used as a placeholder in #hop_match.
|
||||||
|
|
||||||
|
\remark The index \c i is encoded in the hierarchical name, and can be quickly accessed.
|
||||||
|
In hop_match the substitution is also efficiently represented as an array (aka buffer).
|
||||||
|
*/
|
||||||
|
level mk_idx_meta_univ(unsigned i);
|
||||||
|
|
||||||
|
/** \brief Context for match_plugins. */
|
||||||
|
class match_context {
|
||||||
|
public:
|
||||||
|
/** \brief Create a fresh name. */
|
||||||
|
virtual name mk_name() = 0;
|
||||||
|
/** \brief Given a variable \c x, return its assignment (if available) */
|
||||||
|
virtual optional<expr> get_subst(expr const & x) const = 0;
|
||||||
|
/** \brief Given a universe level meta-variable \c x (created using #mk_idx_meta_univ), return its assignment (if available) */
|
||||||
|
virtual optional<level> get_subst(level const & x) const = 0;
|
||||||
|
/** \brief Assign the variable \c x to \c e
|
||||||
|
\pre \c x is not assigned
|
||||||
|
*/
|
||||||
|
virtual void assign(expr const & x, expr const & e) = 0;
|
||||||
|
/** \brief Assign the variable \c x to \c l
|
||||||
|
\pre \c x is not assigned, \c x was created using #mk_idx_meta_univ.
|
||||||
|
*/
|
||||||
|
virtual void assign(level const & x, level const & l) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
/** \brief Callback for extending the higher-order pattern matching procedure.
|
/** \brief Callback for extending the higher-order pattern matching procedure.
|
||||||
It is invoked before the matcher fails.
|
It is invoked before the matcher fails.
|
||||||
|
|
||||||
plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t.
|
plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t.
|
||||||
*/
|
*/
|
||||||
typedef std::function<bool(expr const &, expr const &, buffer<optional<expr>> &, name_generator const &)> matcher_plugin; // NOLINT
|
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
|
||||||
|
|
||||||
/**
|
/**
|
||||||
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
|
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
|
||||||
|
@ -45,7 +71,7 @@ typedef std::function<bool(expr const &, expr const &, buffer<optional<expr>> &,
|
||||||
|
|
||||||
If the plugin is provided, then it is invoked before a failure.
|
If the plugin is provided, then it is invoked before a failure.
|
||||||
*/
|
*/
|
||||||
bool match(expr const & p, expr const & t, buffer<optional<expr>> & subst, name const * prefix = nullptr,
|
bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst,
|
||||||
name_map<name> * name_subst = nullptr, matcher_plugin const * plugin = nullptr);
|
name const * prefix = nullptr, name_map<name> * name_subst = nullptr, match_plugin const * plugin = nullptr);
|
||||||
void open_match(lua_State * L);
|
void open_match(lua_State * L);
|
||||||
}
|
}
|
||||||
|
|
29
tests/lua/hop2.lua
Normal file
29
tests/lua/hop2.lua
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
function tst_match(p, t)
|
||||||
|
local r1, r2 = match(p, t)
|
||||||
|
assert(r1)
|
||||||
|
print("--------------")
|
||||||
|
for i = 1, #r1 do
|
||||||
|
print(" expr:#" .. i .. " := " .. tostring(r1[i]))
|
||||||
|
end
|
||||||
|
for i = 1, #r2 do
|
||||||
|
print(" lvl:#" .. i .. " := " .. tostring(r2[i]))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local env = environment()
|
||||||
|
local N = Const("N")
|
||||||
|
local a = Const("a")
|
||||||
|
local b = Const("b")
|
||||||
|
local x = Local("x", N)
|
||||||
|
local y = Local("y", N)
|
||||||
|
local u1 = mk_global_univ("u1")
|
||||||
|
local u2 = mk_global_univ("u2")
|
||||||
|
local z = level()
|
||||||
|
local f = Const("f", {u1, z})
|
||||||
|
local f2 = Const("f", {u1, u1+1})
|
||||||
|
local fp = Const("f", {mk_idx_meta_univ(0), mk_idx_meta_univ(1)})
|
||||||
|
tst_match(fp(Var(0), Var(0)), f(a, a))
|
||||||
|
tst_match(fp(Var(0), Var(1)), f2(a, b))
|
||||||
|
tst_match(Var(0)(x, y), f(x, f(x, y)))
|
||||||
|
assert(not match(Var(0)(x, x), f(x, f(x, y))))
|
||||||
|
assert(not match(Var(0)(x), f(x, y)))
|
Loading…
Reference in a new issue