feat(library/match): extend match_plugin interface
This commit is contained in:
parent
44e575c895
commit
0e06f4aedc
3 changed files with 71 additions and 43 deletions
|
@ -54,12 +54,12 @@ struct app_builder::imp {
|
|||
typedef scoped_map<cache_key, expr, cache_key_hash_fn, cache_key_equal_fn> cache;
|
||||
|
||||
type_checker & m_tc;
|
||||
match_plugin m_plugin;
|
||||
whnf_match_plugin m_plugin;
|
||||
name_map<decl_info> m_decl_info;
|
||||
cache m_cache;
|
||||
buffer<levels> m_levels;
|
||||
|
||||
imp(type_checker & tc):m_tc(tc), m_plugin(mk_whnf_match_plugin(tc)) {
|
||||
imp(type_checker & tc):m_tc(tc), m_plugin(tc) {
|
||||
m_levels.push_back(levels());
|
||||
}
|
||||
|
||||
|
|
|
@ -159,7 +159,7 @@ class match_fn : public match_context {
|
|||
bool try_plugin(expr const & p, expr const & t) {
|
||||
if (!m_plugin)
|
||||
return false;
|
||||
return (*m_plugin)(p, t, *this);
|
||||
return m_plugin->on_failure(p, t, *this);
|
||||
}
|
||||
|
||||
bool match_binding_core(expr p, expr t) {
|
||||
|
@ -314,6 +314,15 @@ class match_fn : public match_context {
|
|||
bool match_core(expr const & p, expr const & t) {
|
||||
if (p.kind() != t.kind())
|
||||
return try_plugin(p, t);
|
||||
|
||||
if (m_plugin) {
|
||||
switch (m_plugin->pre(p, t, *this)) {
|
||||
case l_true: return true;
|
||||
case l_false: return false;
|
||||
case l_undef: break;
|
||||
}
|
||||
}
|
||||
|
||||
switch (p.kind()) {
|
||||
case expr_kind::Local: case expr_kind::Meta:
|
||||
return mlocal_name(p) == mlocal_name(t);
|
||||
|
@ -399,30 +408,15 @@ bool match(expr const & p, expr const & t, buffer<optional<level>> & lsubst, buf
|
|||
prefix, name_subst, plugin, assigned);
|
||||
}
|
||||
|
||||
match_plugin mk_whnf_match_plugin(type_checker & tc) {
|
||||
return [&](expr const & p, expr const & t, match_context & ctx) { // NOLINT
|
||||
bool whnf_match_plugin::on_failure(expr const & p, expr const & t, match_context & ctx) const {
|
||||
try {
|
||||
constraint_seq cs;
|
||||
expr p1 = tc.whnf(p, cs);
|
||||
expr t1 = tc.whnf(t, cs);
|
||||
expr p1 = m_tc.whnf(p, cs);
|
||||
expr t1 = m_tc.whnf(t, cs);
|
||||
return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1);
|
||||
} catch (exception&) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> tc) {
|
||||
return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT
|
||||
try {
|
||||
constraint_seq cs;
|
||||
expr p1 = tc->whnf(p, cs);
|
||||
expr t1 = tc->whnf(t, cs);
|
||||
return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1);
|
||||
} catch (exception&) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
|
||||
|
@ -457,15 +451,24 @@ static pair<unsigned, unsigned> get_idx_meta_univ_ranges(expr const & e) {
|
|||
return mk_pair(rlvl, rexp);
|
||||
}
|
||||
|
||||
DECL_UDATA(match_plugin)
|
||||
typedef std::shared_ptr<match_plugin> match_plugin_ref;
|
||||
DECL_UDATA(match_plugin_ref)
|
||||
|
||||
static const struct luaL_Reg match_plugin_m[] = {
|
||||
{"__gc", match_plugin_gc},
|
||||
static const struct luaL_Reg match_plugin_ref_m[] = {
|
||||
{"__gc", match_plugin_ref_gc},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
// version of whnf_match_plugin for Lua
|
||||
class whnf_match_plugin2 : public whnf_match_plugin {
|
||||
std::shared_ptr<type_checker> m_tc_ref;
|
||||
public:
|
||||
whnf_match_plugin2(std::shared_ptr<type_checker> & tc):
|
||||
whnf_match_plugin(*tc), m_tc_ref(tc) {}
|
||||
};
|
||||
|
||||
static int mk_whnf_match_plugin(lua_State * L) {
|
||||
return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1)));
|
||||
return push_match_plugin_ref(L, match_plugin_ref(new whnf_match_plugin2(to_type_checker_ref(L, 1))));
|
||||
}
|
||||
|
||||
static int match(lua_State * L) {
|
||||
|
@ -474,7 +477,7 @@ static int match(lua_State * L) {
|
|||
expr t = to_expr(L, 2);
|
||||
match_plugin * plugin = nullptr;
|
||||
if (nargs >= 3)
|
||||
plugin = &to_match_plugin(L, 3);
|
||||
plugin = to_match_plugin_ref(L, 3).get();
|
||||
if (!closed(t))
|
||||
throw exception("higher-order pattern matching failure, input term must not contain free variables");
|
||||
unsigned r1, r2;
|
||||
|
@ -521,13 +524,13 @@ static int mk_idx_meta(lua_State * L) {
|
|||
}
|
||||
|
||||
void open_match(lua_State * L) {
|
||||
luaL_newmetatable(L, match_plugin_mt);
|
||||
luaL_newmetatable(L, match_plugin_ref_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, match_plugin_m, 0);
|
||||
setfuncs(L, match_plugin_ref_m, 0);
|
||||
|
||||
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
|
||||
SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin");
|
||||
SET_GLOBAL_FUN(match_plugin_ref_pred, "is_match_plugin");
|
||||
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
|
||||
SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta");
|
||||
SET_GLOBAL_FUN(match, "match");
|
||||
|
|
|
@ -8,6 +8,7 @@ Author: Leonardo de Moura
|
|||
#include <functional>
|
||||
#include "util/lua.h"
|
||||
#include "util/optional.h"
|
||||
#include "util/lbool.h"
|
||||
#include "util/buffer.h"
|
||||
#include "util/name_map.h"
|
||||
#include "kernel/expr.h"
|
||||
|
@ -53,11 +54,35 @@ public:
|
|||
|
||||
plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t.
|
||||
*/
|
||||
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
|
||||
class match_plugin {
|
||||
public:
|
||||
virtual ~match_plugin() {}
|
||||
/** \brief The following method is invoked before the matcher tries to process
|
||||
\c p and \c t. The method is only invoked when \c p and \c t have the same kind,
|
||||
\c p is not a special metavariable created using \c mk_idx_meta, and
|
||||
\c p and \c t are not structurally identical.
|
||||
|
||||
/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */
|
||||
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> tc);
|
||||
match_plugin mk_whnf_match_plugin(type_checker & tc);
|
||||
The result should be:
|
||||
- l_false : did not match
|
||||
- l_true : matched
|
||||
- l_undef : did not handled (i.e., default matcher should be used)
|
||||
*/
|
||||
virtual lbool pre(expr const & /*p*/, expr const & /*t*/, match_context & /*ctx*/) const { return l_undef; }
|
||||
|
||||
/** \brief The following method is invoked when matcher doesn't have anything else to do.
|
||||
This method is usually used to invoke expensive procedures such as the normalizer.
|
||||
It should return true it the plugin did manage to match p and t.
|
||||
*/
|
||||
virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const = 0;
|
||||
};
|
||||
|
||||
/** \brief Simple plugin that just puts terms in whnf and tries again */
|
||||
class whnf_match_plugin : public match_plugin {
|
||||
type_checker & m_tc;
|
||||
public:
|
||||
whnf_match_plugin(type_checker & tc):m_tc(tc) {}
|
||||
virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const;
|
||||
};
|
||||
|
||||
/**
|
||||
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
|
||||
|
|
Loading…
Reference in a new issue