From 0e06f4aedc8dece73d60c25d268ab487b904db20 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 3 Feb 2015 18:10:38 -0800 Subject: [PATCH] feat(library/match): extend match_plugin interface --- src/library/app_builder.cpp | 4 +- src/library/match.cpp | 77 +++++++++++++++++++------------------ src/library/match.h | 33 ++++++++++++++-- 3 files changed, 71 insertions(+), 43 deletions(-) diff --git a/src/library/app_builder.cpp b/src/library/app_builder.cpp index 27a2a4aad..a34e52194 100644 --- a/src/library/app_builder.cpp +++ b/src/library/app_builder.cpp @@ -54,12 +54,12 @@ struct app_builder::imp { typedef scoped_map cache; type_checker & m_tc; - match_plugin m_plugin; + whnf_match_plugin m_plugin; name_map m_decl_info; cache m_cache; buffer m_levels; - imp(type_checker & tc):m_tc(tc), m_plugin(mk_whnf_match_plugin(tc)) { + imp(type_checker & tc):m_tc(tc), m_plugin(tc) { m_levels.push_back(levels()); } diff --git a/src/library/match.cpp b/src/library/match.cpp index c9cd4e37e..02437bcc8 100644 --- a/src/library/match.cpp +++ b/src/library/match.cpp @@ -159,7 +159,7 @@ class match_fn : public match_context { bool try_plugin(expr const & p, expr const & t) { if (!m_plugin) return false; - return (*m_plugin)(p, t, *this); + return m_plugin->on_failure(p, t, *this); } bool match_binding_core(expr p, expr t) { @@ -314,6 +314,15 @@ class match_fn : public match_context { bool match_core(expr const & p, expr const & t) { if (p.kind() != t.kind()) return try_plugin(p, t); + + if (m_plugin) { + switch (m_plugin->pre(p, t, *this)) { + case l_true: return true; + case l_false: return false; + case l_undef: break; + } + } + switch (p.kind()) { case expr_kind::Local: case expr_kind::Meta: return mlocal_name(p) == mlocal_name(t); @@ -399,30 +408,15 @@ bool match(expr const & p, expr const & t, buffer> & lsubst, buf prefix, name_subst, plugin, assigned); } -match_plugin mk_whnf_match_plugin(type_checker & tc) { - return [&](expr const & p, expr const & t, match_context & ctx) { // NOLINT - try { - constraint_seq cs; - expr p1 = tc.whnf(p, cs); - expr t1 = tc.whnf(t, cs); - return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1); - } catch (exception&) { - return false; - } - }; -} - -match_plugin mk_whnf_match_plugin(std::shared_ptr tc) { - return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT - try { - constraint_seq cs; - expr p1 = tc->whnf(p, cs); - expr t1 = tc->whnf(t, cs); - return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1); - } catch (exception&) { - return false; - } - }; +bool whnf_match_plugin::on_failure(expr const & p, expr const & t, match_context & ctx) const { + try { + constraint_seq cs; + expr p1 = m_tc.whnf(p, cs); + expr t1 = m_tc.whnf(t, cs); + return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1); + } catch (exception&) { + return false; + } } static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) { @@ -457,15 +451,24 @@ static pair get_idx_meta_univ_ranges(expr const & e) { return mk_pair(rlvl, rexp); } -DECL_UDATA(match_plugin) +typedef std::shared_ptr match_plugin_ref; +DECL_UDATA(match_plugin_ref) -static const struct luaL_Reg match_plugin_m[] = { - {"__gc", match_plugin_gc}, +static const struct luaL_Reg match_plugin_ref_m[] = { + {"__gc", match_plugin_ref_gc}, {0, 0} }; +// version of whnf_match_plugin for Lua +class whnf_match_plugin2 : public whnf_match_plugin { + std::shared_ptr m_tc_ref; +public: + whnf_match_plugin2(std::shared_ptr & tc): + whnf_match_plugin(*tc), m_tc_ref(tc) {} +}; + static int mk_whnf_match_plugin(lua_State * L) { - return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1))); + return push_match_plugin_ref(L, match_plugin_ref(new whnf_match_plugin2(to_type_checker_ref(L, 1)))); } static int match(lua_State * L) { @@ -474,7 +477,7 @@ static int match(lua_State * L) { expr t = to_expr(L, 2); match_plugin * plugin = nullptr; if (nargs >= 3) - plugin = &to_match_plugin(L, 3); + plugin = to_match_plugin_ref(L, 3).get(); if (!closed(t)) throw exception("higher-order pattern matching failure, input term must not contain free variables"); unsigned r1, r2; @@ -521,15 +524,15 @@ static int mk_idx_meta(lua_State * L) { } void open_match(lua_State * L) { - luaL_newmetatable(L, match_plugin_mt); + luaL_newmetatable(L, match_plugin_ref_mt); lua_pushvalue(L, -1); lua_setfield(L, -2, "__index"); - setfuncs(L, match_plugin_m, 0); + setfuncs(L, match_plugin_ref_m, 0); - SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin"); - SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin"); - SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); - SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta"); - SET_GLOBAL_FUN(match, "match"); + SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin"); + SET_GLOBAL_FUN(match_plugin_ref_pred, "is_match_plugin"); + SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); + SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta"); + SET_GLOBAL_FUN(match, "match"); } } diff --git a/src/library/match.h b/src/library/match.h index 9c4411869..b17d68072 100644 --- a/src/library/match.h +++ b/src/library/match.h @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include #include "util/lua.h" #include "util/optional.h" +#include "util/lbool.h" #include "util/buffer.h" #include "util/name_map.h" #include "kernel/expr.h" @@ -53,11 +54,35 @@ public: plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t. */ -typedef std::function match_plugin; // NOLINT +class match_plugin { +public: + virtual ~match_plugin() {} + /** \brief The following method is invoked before the matcher tries to process + \c p and \c t. The method is only invoked when \c p and \c t have the same kind, + \c p is not a special metavariable created using \c mk_idx_meta, and + \c p and \c t are not structurally identical. -/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */ -match_plugin mk_whnf_match_plugin(std::shared_ptr tc); -match_plugin mk_whnf_match_plugin(type_checker & tc); + The result should be: + - l_false : did not match + - l_true : matched + - l_undef : did not handled (i.e., default matcher should be used) + */ + virtual lbool pre(expr const & /*p*/, expr const & /*t*/, match_context & /*ctx*/) const { return l_undef; } + + /** \brief The following method is invoked when matcher doesn't have anything else to do. + This method is usually used to invoke expensive procedures such as the normalizer. + It should return true it the plugin did manage to match p and t. + */ + virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const = 0; +}; + +/** \brief Simple plugin that just puts terms in whnf and tries again */ +class whnf_match_plugin : public match_plugin { + type_checker & m_tc; +public: + whnf_match_plugin(type_checker & tc):m_tc(tc) {} + virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const; +}; /** \brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.