feat(library/match): match universe levels

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-08-04 18:24:01 -07:00
parent b0a5ff7f93
commit c34c2f4f5c
3 changed files with 233 additions and 44 deletions

View file

@ -6,32 +6,74 @@ Author: Leonardo de Moura
*/ */
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
#include "kernel/for_each_fn.h"
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
#include "library/locals.h" #include "library/locals.h"
#include "library/match.h" #include "library/match.h"
namespace lean { namespace lean {
class match_fn { static name g_tmp_prefix = name::mk_internal_unique_name();
buffer<optional<expr>> & m_subst;
level mk_idx_meta_univ(unsigned i) {
return mk_meta_univ(name(g_tmp_prefix, i));
}
bool is_idx_meta_univ(level const & l) {
if (!is_meta(l))
return false;
name const & n = meta_id(l);
return !n.is_atomic() && n.is_numeral() && n.get_prefix() == g_tmp_prefix;
}
unsigned to_meta_idx(level const & l) {
lean_assert(is_idx_meta_univ(l));
return meta_id(l).get_numeral();
}
class match_fn : public match_context {
buffer<optional<expr>> & m_esubst;
buffer<optional<level>> & m_lsubst;
name_generator m_ngen; name_generator m_ngen;
name_map<name> * m_name_subst; name_map<name> * m_name_subst;
matcher_plugin const * m_plugin; match_plugin const * m_plugin;
void assign(expr const & p, expr const & t) { void _assign(expr const & p, expr const & t) {
lean_assert(var_idx(p) < m_subst.size()); lean_assert(var_idx(p) < m_esubst.size());
unsigned vidx = var_idx(p); unsigned vidx = var_idx(p);
unsigned sz = m_subst.size(); unsigned sz = m_esubst.size();
m_subst[sz - vidx - 1] = t; m_esubst[sz - vidx - 1] = t;
} }
optional<expr> get_subst(expr const & x) const { void _assign(level const & p, level const & l) {
unsigned vidx = var_idx(x); lean_assert(to_meta_idx(p) < m_lsubst.size());
unsigned sz = m_subst.size(); m_lsubst[to_meta_idx(p)] = l;
if (vidx >= sz)
throw exception("ill-formed higher-order matching problem");
return m_subst[sz - vidx - 1];
} }
void throw_exception() const {
throw exception("ill-formed higher-order matching problem");
}
optional<expr> _get_subst(expr const & x) const {
unsigned vidx = var_idx(x);
unsigned sz = m_esubst.size();
if (vidx >= sz)
throw_exception();
return m_esubst[sz - vidx - 1];
}
optional<level> _get_subst(level const & x) const {
unsigned i = to_meta_idx(x);
if (i > m_lsubst.size())
throw_exception();
return m_lsubst[i];
}
virtual void assign(expr const & p, expr const & t) { return _assign(p, t); }
virtual void assign(level const & p, level const & t) { return _assign(p, t); }
virtual optional<expr> get_subst(expr const & x) const { return _get_subst(x); }
virtual optional<level> get_subst(level const & x) const { return _get_subst(x); }
virtual name mk_name() { return m_ngen.next(); }
bool args_are_distinct_locals(buffer<expr> const & args) { bool args_are_distinct_locals(buffer<expr> const & args) {
for (auto it = args.begin(); it != args.end(); it++) { for (auto it = args.begin(); it != args.end(); it++) {
if (!is_local(*it) || contains_local(*it, args.begin(), it)) if (!is_local(*it) || contains_local(*it, args.begin(), it))
@ -48,10 +90,10 @@ class match_fn {
return some_expr(r); return some_expr(r);
} }
bool match_plugin(expr const & p, expr const & t) { bool try_plugin(expr const & p, expr const & t) {
if (!m_plugin) if (!m_plugin)
return false; return false;
return (*m_plugin)(p, t, m_subst, m_ngen.mk_child()); return (*m_plugin)(p, t, *this);
} }
bool match_binding(expr p, expr t) { bool match_binding(expr p, expr t) {
@ -92,43 +134,89 @@ class match_fn {
return match_core(app_fn(p), app_fn(t)) && match(app_arg(p), app_arg(t)); return match_core(app_fn(p), app_fn(t)) && match(app_arg(p), app_arg(t));
} }
bool match_level_core(level const & p, level const & l) {
if (p == l)
return true;
if (p.kind() == l.kind()) {
switch (p.kind()) {
case level_kind::Zero:
lean_unreachable(); // LCOV_EXCL_LINE
case level_kind::Param: case level_kind::Global: case level_kind::Meta:
return false;
case level_kind::Succ:
return match_level(succ_of(p), succ_of(l));
case level_kind::Max:
return
match_level(max_lhs(p), max_lhs(l)) &&
match_level(max_rhs(p), max_rhs(l));
case level_kind::IMax:
return
match_level(imax_lhs(p), imax_lhs(l)) &&
match_level(imax_rhs(p), imax_rhs(l));
}
}
return false;
}
bool match_level(level const & p, level const & l) {
if (is_idx_meta_univ(p)) {
auto s = _get_subst(p);
if (s) {
return match_level_core(*s, l);
} else {
_assign(p, l);
return true;
}
}
return match_level_core(p, l);
}
bool match_levels(levels ps, levels ls) {
while (ps && ls) {
if (!match_level(head(ps), head(ls)))
return false;
ps = tail(ps);
ls = tail(ls);
}
return true;
}
bool match_core(expr const & p, expr const & t) { bool match_core(expr const & p, expr const & t) {
if (p.kind() != t.kind()) if (p.kind() != t.kind())
return match_plugin(p, t); return try_plugin(p, t);
switch (p.kind()) { switch (p.kind()) {
case expr_kind::Local: case expr_kind::Meta: case expr_kind::Local: case expr_kind::Meta:
return mlocal_name(p) == mlocal_name(t); return mlocal_name(p) == mlocal_name(t);
case expr_kind::Var: case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Constant: case expr_kind::Constant:
// TODO(Leo): universe levels return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t));
return const_name(p) == const_name(t);
case expr_kind::Sort: case expr_kind::Sort:
// TODO(Leo): universe levels return match_level(sort_level(p), sort_level(t));
return true;
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Pi:
return match_binding(p, t) || match_plugin(p, t); return match_binding(p, t) || try_plugin(p, t);
case expr_kind::Macro: case expr_kind::Macro:
return match_macro(p, t) || match_plugin(p, t); return match_macro(p, t) || try_plugin(p, t);
case expr_kind::App: case expr_kind::App:
return match_app(p, t) || match_plugin(p, t); return match_app(p, t) || try_plugin(p, t);
} }
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
} }
public: public:
match_fn(buffer<optional<expr>> & subst, name_generator const & ngen, name_map<name> * name_subst, matcher_plugin const * plugin): match_fn(buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst, name_generator const & ngen,
m_subst(subst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin) {} name_map<name> * name_subst, match_plugin const * plugin):
m_esubst(esubst), m_lsubst(lsubst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin) {}
bool match(expr const & p, expr const & t) { bool match(expr const & p, expr const & t) {
if (is_var(p)) { if (is_var(p)) {
auto s = get_subst(p); auto s = _get_subst(p);
if (s) { if (s) {
return match_core(*s, t); return match_core(*s, t);
} else if (has_local(t)) { } else if (has_local(t)) {
return false; return false;
} else { } else {
assign(p, t); _assign(p, t);
return true; return true;
} }
} else if (is_app(p)) { } else if (is_app(p)) {
@ -136,7 +224,7 @@ public:
expr const & f = get_app_rev_args(p, args); expr const & f = get_app_rev_args(p, args);
if (is_var(f)) { if (is_var(f)) {
// higher-order pattern case // higher-order pattern case
auto s = get_subst(f); auto s = _get_subst(f);
if (s) { if (s) {
expr new_p = apply_beta(*s, args.size(), args.data()); expr new_p = apply_beta(*s, args.size(), args.data());
return match_core(new_p, t); return match_core(new_p, t);
@ -144,7 +232,7 @@ public:
if (args_are_distinct_locals(args)) { if (args_are_distinct_locals(args)) {
optional<expr> new_t = proj(t, args); optional<expr> new_t = proj(t, args);
if (new_t) { if (new_t) {
assign(f, *new_t); _assign(f, *new_t);
return true; return true;
} }
} }
@ -156,14 +244,42 @@ public:
} }
}; };
static name g_tmp_prefix = name::mk_internal_unique_name(); bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst,
bool match(expr const & p, expr const & t, buffer<optional<expr>> & subst, name const * prefix, name const * prefix, name_map<name> * name_subst, match_plugin const * plugin) {
name_map<name> * name_subst, matcher_plugin const * plugin) {
lean_assert(closed(t)); lean_assert(closed(t));
if (prefix) if (prefix)
return match_fn(subst, name_generator(*prefix), name_subst, plugin).match(p, t); return match_fn(esubst, lsubst, name_generator(*prefix), name_subst, plugin).match(p, t);
else else
return match_fn(subst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t); return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
}
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
for_each(l, [&](level const & l) {
if (!has_meta(l)) return false;
if (is_idx_meta_univ(l)) {
unsigned new_r = to_meta_idx(l) + 1;
if (new_r > r)
r = new_r;
}
return true;
});
return r;
}
static unsigned get_idx_meta_univ_range(expr const & e) {
if (!has_univ_metavar(e))
return 0;
unsigned r = 0;
for_each(e, [&](expr const & e, unsigned) {
if (!has_univ_metavar(e)) return false;
if (is_constant(e))
for (level const & l : const_levels(e))
r = updt_idx_meta_univ_range(l, r);
if (is_sort(e))
r = updt_idx_meta_univ_range(sort_level(e), r);
return true;
});
return r;
} }
static int match(lua_State * L) { static int match(lua_State * L) {
@ -171,13 +287,15 @@ static int match(lua_State * L) {
expr t = to_expr(L, 2); expr t = to_expr(L, 2);
if (!closed(t)) if (!closed(t))
throw exception("higher-order pattern matching failure, input term must not contain free variables"); throw exception("higher-order pattern matching failure, input term must not contain free variables");
unsigned k = get_free_var_range(p); unsigned r1 = get_free_var_range(p);
buffer<optional<expr>> subst; unsigned r2 = get_idx_meta_univ_range(p);
subst.resize(k); buffer<optional<expr>> esubst;
if (match(p, t, subst, nullptr, nullptr, nullptr)) { buffer<optional<level>> lsubst;
esubst.resize(r1); lsubst.resize(r2);
if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) {
lua_newtable(L); lua_newtable(L);
int i = 1; int i = 1;
for (auto s : subst) { for (auto s : esubst) {
if (s) if (s)
push_expr(L, *s); push_expr(L, *s);
else else
@ -185,13 +303,29 @@ static int match(lua_State * L) {
lua_rawseti(L, -2, i); lua_rawseti(L, -2, i);
i = i + 1; i = i + 1;
} }
lua_newtable(L);
i = 1;
for (auto s : lsubst) {
if (s)
push_level(L, *s);
else
lua_pushboolean(L, false);
lua_rawseti(L, -2, i);
i = i + 1;
}
} else { } else {
lua_pushnil(L); lua_pushnil(L);
lua_pushnil(L);
} }
return 1; return 2;
}
static int mk_idx_meta_univ(lua_State * L) {
return push_level(L, mk_idx_meta_univ(luaL_checkinteger(L, 1)));
} }
void open_match(lua_State * L) { void open_match(lua_State * L) {
SET_GLOBAL_FUN(match, "match"); SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
SET_GLOBAL_FUN(match, "match");
} }
} }

View file

@ -14,12 +14,38 @@ Author: Leonardo de Moura
#include "kernel/environment.h" #include "kernel/environment.h"
namespace lean { namespace lean {
/** \brief Create a universe level metavariable that can be used as a placeholder in #hop_match.
\remark The index \c i is encoded in the hierarchical name, and can be quickly accessed.
In hop_match the substitution is also efficiently represented as an array (aka buffer).
*/
level mk_idx_meta_univ(unsigned i);
/** \brief Context for match_plugins. */
class match_context {
public:
/** \brief Create a fresh name. */
virtual name mk_name() = 0;
/** \brief Given a variable \c x, return its assignment (if available) */
virtual optional<expr> get_subst(expr const & x) const = 0;
/** \brief Given a universe level meta-variable \c x (created using #mk_idx_meta_univ), return its assignment (if available) */
virtual optional<level> get_subst(level const & x) const = 0;
/** \brief Assign the variable \c x to \c e
\pre \c x is not assigned
*/
virtual void assign(expr const & x, expr const & e) = 0;
/** \brief Assign the variable \c x to \c l
\pre \c x is not assigned, \c x was created using #mk_idx_meta_univ.
*/
virtual void assign(level const & x, level const & l) = 0;
};
/** \brief Callback for extending the higher-order pattern matching procedure. /** \brief Callback for extending the higher-order pattern matching procedure.
It is invoked before the matcher fails. It is invoked before the matcher fails.
plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t. plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t.
*/ */
typedef std::function<bool(expr const &, expr const &, buffer<optional<expr>> &, name_generator const &)> matcher_plugin; // NOLINT typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
/** /**
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p. \brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
@ -45,7 +71,7 @@ typedef std::function<bool(expr const &, expr const &, buffer<optional<expr>> &,
If the plugin is provided, then it is invoked before a failure. If the plugin is provided, then it is invoked before a failure.
*/ */
bool match(expr const & p, expr const & t, buffer<optional<expr>> & subst, name const * prefix = nullptr, bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buffer<optional<level>> & lsubst,
name_map<name> * name_subst = nullptr, matcher_plugin const * plugin = nullptr); name const * prefix = nullptr, name_map<name> * name_subst = nullptr, match_plugin const * plugin = nullptr);
void open_match(lua_State * L); void open_match(lua_State * L);
} }

29
tests/lua/hop2.lua Normal file
View file

@ -0,0 +1,29 @@
function tst_match(p, t)
local r1, r2 = match(p, t)
assert(r1)
print("--------------")
for i = 1, #r1 do
print(" expr:#" .. i .. " := " .. tostring(r1[i]))
end
for i = 1, #r2 do
print(" lvl:#" .. i .. " := " .. tostring(r2[i]))
end
end
local env = environment()
local N = Const("N")
local a = Const("a")
local b = Const("b")
local x = Local("x", N)
local y = Local("y", N)
local u1 = mk_global_univ("u1")
local u2 = mk_global_univ("u2")
local z = level()
local f = Const("f", {u1, z})
local f2 = Const("f", {u1, u1+1})
local fp = Const("f", {mk_idx_meta_univ(0), mk_idx_meta_univ(1)})
tst_match(fp(Var(0), Var(0)), f(a, a))
tst_match(fp(Var(0), Var(1)), f2(a, b))
tst_match(Var(0)(x, y), f(x, f(x, y)))
assert(not match(Var(0)(x, x), f(x, f(x, y))))
assert(not match(Var(0)(x), f(x, y)))