feat(library/match): extend match_plugin interface

This commit is contained in:
Leonardo de Moura 2015-02-03 18:10:38 -08:00
parent 44e575c895
commit 0e06f4aedc
3 changed files with 71 additions and 43 deletions

View file

@ -54,12 +54,12 @@ struct app_builder::imp {
typedef scoped_map<cache_key, expr, cache_key_hash_fn, cache_key_equal_fn> cache;
type_checker & m_tc;
match_plugin m_plugin;
whnf_match_plugin m_plugin;
name_map<decl_info> m_decl_info;
cache m_cache;
buffer<levels> m_levels;
imp(type_checker & tc):m_tc(tc), m_plugin(mk_whnf_match_plugin(tc)) {
imp(type_checker & tc):m_tc(tc), m_plugin(tc) {
m_levels.push_back(levels());
}

View file

@ -159,7 +159,7 @@ class match_fn : public match_context {
bool try_plugin(expr const & p, expr const & t) {
if (!m_plugin)
return false;
return (*m_plugin)(p, t, *this);
return m_plugin->on_failure(p, t, *this);
}
bool match_binding_core(expr p, expr t) {
@ -314,6 +314,15 @@ class match_fn : public match_context {
bool match_core(expr const & p, expr const & t) {
if (p.kind() != t.kind())
return try_plugin(p, t);
if (m_plugin) {
switch (m_plugin->pre(p, t, *this)) {
case l_true: return true;
case l_false: return false;
case l_undef: break;
}
}
switch (p.kind()) {
case expr_kind::Local: case expr_kind::Meta:
return mlocal_name(p) == mlocal_name(t);
@ -399,30 +408,15 @@ bool match(expr const & p, expr const & t, buffer<optional<level>> & lsubst, buf
prefix, name_subst, plugin, assigned);
}
match_plugin mk_whnf_match_plugin(type_checker & tc) {
return [&](expr const & p, expr const & t, match_context & ctx) { // NOLINT
try {
constraint_seq cs;
expr p1 = tc.whnf(p, cs);
expr t1 = tc.whnf(t, cs);
return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1);
} catch (exception&) {
return false;
}
};
}
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> tc) {
return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT
try {
constraint_seq cs;
expr p1 = tc->whnf(p, cs);
expr t1 = tc->whnf(t, cs);
return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1);
} catch (exception&) {
return false;
}
};
bool whnf_match_plugin::on_failure(expr const & p, expr const & t, match_context & ctx) const {
try {
constraint_seq cs;
expr p1 = m_tc.whnf(p, cs);
expr t1 = m_tc.whnf(t, cs);
return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1);
} catch (exception&) {
return false;
}
}
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
@ -457,15 +451,24 @@ static pair<unsigned, unsigned> get_idx_meta_univ_ranges(expr const & e) {
return mk_pair(rlvl, rexp);
}
DECL_UDATA(match_plugin)
typedef std::shared_ptr<match_plugin> match_plugin_ref;
DECL_UDATA(match_plugin_ref)
static const struct luaL_Reg match_plugin_m[] = {
{"__gc", match_plugin_gc},
static const struct luaL_Reg match_plugin_ref_m[] = {
{"__gc", match_plugin_ref_gc},
{0, 0}
};
// version of whnf_match_plugin for Lua
class whnf_match_plugin2 : public whnf_match_plugin {
std::shared_ptr<type_checker> m_tc_ref;
public:
whnf_match_plugin2(std::shared_ptr<type_checker> & tc):
whnf_match_plugin(*tc), m_tc_ref(tc) {}
};
static int mk_whnf_match_plugin(lua_State * L) {
return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1)));
return push_match_plugin_ref(L, match_plugin_ref(new whnf_match_plugin2(to_type_checker_ref(L, 1))));
}
static int match(lua_State * L) {
@ -474,7 +477,7 @@ static int match(lua_State * L) {
expr t = to_expr(L, 2);
match_plugin * plugin = nullptr;
if (nargs >= 3)
plugin = &to_match_plugin(L, 3);
plugin = to_match_plugin_ref(L, 3).get();
if (!closed(t))
throw exception("higher-order pattern matching failure, input term must not contain free variables");
unsigned r1, r2;
@ -521,15 +524,15 @@ static int mk_idx_meta(lua_State * L) {
}
void open_match(lua_State * L) {
luaL_newmetatable(L, match_plugin_mt);
luaL_newmetatable(L, match_plugin_ref_mt);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, match_plugin_m, 0);
setfuncs(L, match_plugin_ref_m, 0);
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin");
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta");
SET_GLOBAL_FUN(match, "match");
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
SET_GLOBAL_FUN(match_plugin_ref_pred, "is_match_plugin");
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta");
SET_GLOBAL_FUN(match, "match");
}
}

View file

@ -8,6 +8,7 @@ Author: Leonardo de Moura
#include <functional>
#include "util/lua.h"
#include "util/optional.h"
#include "util/lbool.h"
#include "util/buffer.h"
#include "util/name_map.h"
#include "kernel/expr.h"
@ -53,11 +54,35 @@ public:
plugin(p, t, s) must return true iff for updated substitution s', s'(p) is definitionally equal to t.
*/
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
class match_plugin {
public:
virtual ~match_plugin() {}
/** \brief The following method is invoked before the matcher tries to process
\c p and \c t. The method is only invoked when \c p and \c t have the same kind,
\c p is not a special metavariable created using \c mk_idx_meta, and
\c p and \c t are not structurally identical.
/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> tc);
match_plugin mk_whnf_match_plugin(type_checker & tc);
The result should be:
- l_false : did not match
- l_true : matched
- l_undef : did not handled (i.e., default matcher should be used)
*/
virtual lbool pre(expr const & /*p*/, expr const & /*t*/, match_context & /*ctx*/) const { return l_undef; }
/** \brief The following method is invoked when matcher doesn't have anything else to do.
This method is usually used to invoke expensive procedures such as the normalizer.
It should return true it the plugin did manage to match p and t.
*/
virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const = 0;
};
/** \brief Simple plugin that just puts terms in whnf and tries again */
class whnf_match_plugin : public match_plugin {
type_checker & m_tc;
public:
whnf_match_plugin(type_checker & tc):m_tc(tc) {}
virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const;
};
/**
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.