diff --git a/src/library/match.cpp b/src/library/match.cpp index 0882323bb..b61b375f5 100644 --- a/src/library/match.cpp +++ b/src/library/match.cpp @@ -6,32 +6,74 @@ Author: Leonardo de Moura */ #include "kernel/abstract.h" #include "kernel/instantiate.h" +#include "kernel/for_each_fn.h" #include "library/kernel_bindings.h" #include "library/locals.h" #include "library/match.h" namespace lean { -class match_fn { - buffer> & m_subst; +static name g_tmp_prefix = name::mk_internal_unique_name(); + +level mk_idx_meta_univ(unsigned i) { + return mk_meta_univ(name(g_tmp_prefix, i)); +} + +bool is_idx_meta_univ(level const & l) { + if (!is_meta(l)) + return false; + name const & n = meta_id(l); + return !n.is_atomic() && n.is_numeral() && n.get_prefix() == g_tmp_prefix; +} + +unsigned to_meta_idx(level const & l) { + lean_assert(is_idx_meta_univ(l)); + return meta_id(l).get_numeral(); +} + +class match_fn : public match_context { + buffer> & m_esubst; + buffer> & m_lsubst; name_generator m_ngen; name_map * m_name_subst; - matcher_plugin const * m_plugin; + match_plugin const * m_plugin; - void assign(expr const & p, expr const & t) { - lean_assert(var_idx(p) < m_subst.size()); + void _assign(expr const & p, expr const & t) { + lean_assert(var_idx(p) < m_esubst.size()); unsigned vidx = var_idx(p); - unsigned sz = m_subst.size(); - m_subst[sz - vidx - 1] = t; + unsigned sz = m_esubst.size(); + m_esubst[sz - vidx - 1] = t; } - optional get_subst(expr const & x) const { - unsigned vidx = var_idx(x); - unsigned sz = m_subst.size(); - if (vidx >= sz) - throw exception("ill-formed higher-order matching problem"); - return m_subst[sz - vidx - 1]; + void _assign(level const & p, level const & l) { + lean_assert(to_meta_idx(p) < m_lsubst.size()); + m_lsubst[to_meta_idx(p)] = l; } + void throw_exception() const { + throw exception("ill-formed higher-order matching problem"); + } + + optional _get_subst(expr const & x) const { + unsigned vidx = var_idx(x); + unsigned sz = m_esubst.size(); + if (vidx >= sz) + throw_exception(); + return m_esubst[sz - vidx - 1]; + } + + optional _get_subst(level const & x) const { + unsigned i = to_meta_idx(x); + if (i > m_lsubst.size()) + throw_exception(); + return m_lsubst[i]; + } + + virtual void assign(expr const & p, expr const & t) { return _assign(p, t); } + virtual void assign(level const & p, level const & t) { return _assign(p, t); } + virtual optional get_subst(expr const & x) const { return _get_subst(x); } + virtual optional get_subst(level const & x) const { return _get_subst(x); } + virtual name mk_name() { return m_ngen.next(); } + bool args_are_distinct_locals(buffer const & args) { for (auto it = args.begin(); it != args.end(); it++) { if (!is_local(*it) || contains_local(*it, args.begin(), it)) @@ -48,10 +90,10 @@ class match_fn { return some_expr(r); } - bool match_plugin(expr const & p, expr const & t) { + bool try_plugin(expr const & p, expr const & t) { if (!m_plugin) return false; - return (*m_plugin)(p, t, m_subst, m_ngen.mk_child()); + return (*m_plugin)(p, t, *this); } bool match_binding(expr p, expr t) { @@ -92,43 +134,89 @@ class match_fn { return match_core(app_fn(p), app_fn(t)) && match(app_arg(p), app_arg(t)); } + bool match_level_core(level const & p, level const & l) { + if (p == l) + return true; + if (p.kind() == l.kind()) { + switch (p.kind()) { + case level_kind::Zero: + lean_unreachable(); // LCOV_EXCL_LINE + case level_kind::Param: case level_kind::Global: case level_kind::Meta: + return false; + case level_kind::Succ: + return match_level(succ_of(p), succ_of(l)); + case level_kind::Max: + return + match_level(max_lhs(p), max_lhs(l)) && + match_level(max_rhs(p), max_rhs(l)); + case level_kind::IMax: + return + match_level(imax_lhs(p), imax_lhs(l)) && + match_level(imax_rhs(p), imax_rhs(l)); + } + } + return false; + } + + bool match_level(level const & p, level const & l) { + if (is_idx_meta_univ(p)) { + auto s = _get_subst(p); + if (s) { + return match_level_core(*s, l); + } else { + _assign(p, l); + return true; + } + } + return match_level_core(p, l); + } + + bool match_levels(levels ps, levels ls) { + while (ps && ls) { + if (!match_level(head(ps), head(ls))) + return false; + ps = tail(ps); + ls = tail(ls); + } + return true; + } + bool match_core(expr const & p, expr const & t) { if (p.kind() != t.kind()) - return match_plugin(p, t); + return try_plugin(p, t); switch (p.kind()) { case expr_kind::Local: case expr_kind::Meta: return mlocal_name(p) == mlocal_name(t); case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Constant: - // TODO(Leo): universe levels - return const_name(p) == const_name(t); + return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t)); case expr_kind::Sort: - // TODO(Leo): universe levels - return true; + return match_level(sort_level(p), sort_level(t)); case expr_kind::Lambda: case expr_kind::Pi: - return match_binding(p, t) || match_plugin(p, t); + return match_binding(p, t) || try_plugin(p, t); case expr_kind::Macro: - return match_macro(p, t) || match_plugin(p, t); + return match_macro(p, t) || try_plugin(p, t); case expr_kind::App: - return match_app(p, t) || match_plugin(p, t); + return match_app(p, t) || try_plugin(p, t); } lean_unreachable(); // LCOV_EXCL_LINE } public: - match_fn(buffer> & subst, name_generator const & ngen, name_map * name_subst, matcher_plugin const * plugin): - m_subst(subst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin) {} + match_fn(buffer> & esubst, buffer> & lsubst, name_generator const & ngen, + name_map * name_subst, match_plugin const * plugin): + m_esubst(esubst), m_lsubst(lsubst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin) {} bool match(expr const & p, expr const & t) { if (is_var(p)) { - auto s = get_subst(p); + auto s = _get_subst(p); if (s) { return match_core(*s, t); } else if (has_local(t)) { return false; } else { - assign(p, t); + _assign(p, t); return true; } } else if (is_app(p)) { @@ -136,7 +224,7 @@ public: expr const & f = get_app_rev_args(p, args); if (is_var(f)) { // higher-order pattern case - auto s = get_subst(f); + auto s = _get_subst(f); if (s) { expr new_p = apply_beta(*s, args.size(), args.data()); return match_core(new_p, t); @@ -144,7 +232,7 @@ public: if (args_are_distinct_locals(args)) { optional new_t = proj(t, args); if (new_t) { - assign(f, *new_t); + _assign(f, *new_t); return true; } } @@ -156,14 +244,42 @@ public: } }; -static name g_tmp_prefix = name::mk_internal_unique_name(); -bool match(expr const & p, expr const & t, buffer> & subst, name const * prefix, - name_map * name_subst, matcher_plugin const * plugin) { +bool match(expr const & p, expr const & t, buffer> & esubst, buffer> & lsubst, + name const * prefix, name_map * name_subst, match_plugin const * plugin) { lean_assert(closed(t)); if (prefix) - return match_fn(subst, name_generator(*prefix), name_subst, plugin).match(p, t); + return match_fn(esubst, lsubst, name_generator(*prefix), name_subst, plugin).match(p, t); else - return match_fn(subst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t); + return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t); +} + +static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) { + for_each(l, [&](level const & l) { + if (!has_meta(l)) return false; + if (is_idx_meta_univ(l)) { + unsigned new_r = to_meta_idx(l) + 1; + if (new_r > r) + r = new_r; + } + return true; + }); + return r; +} + +static unsigned get_idx_meta_univ_range(expr const & e) { + if (!has_univ_metavar(e)) + return 0; + unsigned r = 0; + for_each(e, [&](expr const & e, unsigned) { + if (!has_univ_metavar(e)) return false; + if (is_constant(e)) + for (level const & l : const_levels(e)) + r = updt_idx_meta_univ_range(l, r); + if (is_sort(e)) + r = updt_idx_meta_univ_range(sort_level(e), r); + return true; + }); + return r; } static int match(lua_State * L) { @@ -171,13 +287,15 @@ static int match(lua_State * L) { expr t = to_expr(L, 2); if (!closed(t)) throw exception("higher-order pattern matching failure, input term must not contain free variables"); - unsigned k = get_free_var_range(p); - buffer> subst; - subst.resize(k); - if (match(p, t, subst, nullptr, nullptr, nullptr)) { + unsigned r1 = get_free_var_range(p); + unsigned r2 = get_idx_meta_univ_range(p); + buffer> esubst; + buffer> lsubst; + esubst.resize(r1); lsubst.resize(r2); + if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) { lua_newtable(L); int i = 1; - for (auto s : subst) { + for (auto s : esubst) { if (s) push_expr(L, *s); else @@ -185,13 +303,29 @@ static int match(lua_State * L) { lua_rawseti(L, -2, i); i = i + 1; } + lua_newtable(L); + i = 1; + for (auto s : lsubst) { + if (s) + push_level(L, *s); + else + lua_pushboolean(L, false); + lua_rawseti(L, -2, i); + i = i + 1; + } } else { lua_pushnil(L); + lua_pushnil(L); } - return 1; + return 2; +} + +static int mk_idx_meta_univ(lua_State * L) { + return push_level(L, mk_idx_meta_univ(luaL_checkinteger(L, 1))); } void open_match(lua_State * L) { - SET_GLOBAL_FUN(match, "match"); + SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); + SET_GLOBAL_FUN(match, "match"); } } diff --git a/src/library/match.h b/src/library/match.h index e5fd01a23..b9985edfc 100644 --- a/src/library/match.h +++ b/src/library/match.h @@ -14,12 +14,38 @@ Author: Leonardo de Moura #include "kernel/environment.h" namespace lean { +/** \brief Create a universe level metavariable that can be used as a placeholder in #hop_match. + + \remark The index \c i is encoded in the hierarchical name, and can be quickly accessed. + In hop_match the substitution is also efficiently represented as an array (aka buffer). +*/ +level mk_idx_meta_univ(unsigned i); + +/** \brief Context for match_plugins. */ +class match_context { +public: + /** \brief Create a fresh name. */ + virtual name mk_name() = 0; + /** \brief Given a variable \c x, return its assignment (if available) */ + virtual optional get_subst(expr const & x) const = 0; + /** \brief Given a universe level meta-variable \c x (created using #mk_idx_meta_univ), return its assignment (if available) */ + virtual optional get_subst(level const & x) const = 0; + /** \brief Assign the variable \c x to \c e + \pre \c x is not assigned + */ + virtual void assign(expr const & x, expr const & e) = 0; + /** \brief Assign the variable \c x to \c l + \pre \c x is not assigned, \c x was created using #mk_idx_meta_univ. + */ + virtual void assign(level const & x, level const & l) = 0; +}; + /** \brief Callback for extending the higher-order pattern matching procedure. It is invoked before the matcher fails. plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t. */ -typedef std::function> &, name_generator const &)> matcher_plugin; // NOLINT +typedef std::function match_plugin; // NOLINT /** \brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p. @@ -45,7 +71,7 @@ typedef std::function> &, If the plugin is provided, then it is invoked before a failure. */ -bool match(expr const & p, expr const & t, buffer> & subst, name const * prefix = nullptr, - name_map * name_subst = nullptr, matcher_plugin const * plugin = nullptr); +bool match(expr const & p, expr const & t, buffer> & esubst, buffer> & lsubst, + name const * prefix = nullptr, name_map * name_subst = nullptr, match_plugin const * plugin = nullptr); void open_match(lua_State * L); } diff --git a/tests/lua/hop2.lua b/tests/lua/hop2.lua new file mode 100644 index 000000000..0277a358d --- /dev/null +++ b/tests/lua/hop2.lua @@ -0,0 +1,29 @@ +function tst_match(p, t) + local r1, r2 = match(p, t) + assert(r1) + print("--------------") + for i = 1, #r1 do + print(" expr:#" .. i .. " := " .. tostring(r1[i])) + end + for i = 1, #r2 do + print(" lvl:#" .. i .. " := " .. tostring(r2[i])) + end +end + +local env = environment() +local N = Const("N") +local a = Const("a") +local b = Const("b") +local x = Local("x", N) +local y = Local("y", N) +local u1 = mk_global_univ("u1") +local u2 = mk_global_univ("u2") +local z = level() +local f = Const("f", {u1, z}) +local f2 = Const("f", {u1, u1+1}) +local fp = Const("f", {mk_idx_meta_univ(0), mk_idx_meta_univ(1)}) +tst_match(fp(Var(0), Var(0)), f(a, a)) +tst_match(fp(Var(0), Var(1)), f2(a, b)) +tst_match(Var(0)(x, y), f(x, f(x, y))) +assert(not match(Var(0)(x, x), f(x, f(x, y)))) +assert(not match(Var(0)(x), f(x, y)))