feat(library/match): add basic match_plugin that just invokes whnf before failing
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
a4b023a175
commit
e6ffda0c51
6 changed files with 97 additions and 5 deletions
|
@ -196,6 +196,8 @@ public:
|
|||
};
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<type_checker> type_checker_ref;
|
||||
|
||||
/**
|
||||
\brief Type check the given declaration, and return a certified declaration if it is type correct.
|
||||
Throw an exception if the declaration is type incorrect.
|
||||
|
|
|
@ -30,6 +30,7 @@ Author: Leonardo de Moura
|
|||
#include "library/kernel_bindings.h"
|
||||
#include "library/normalize.h"
|
||||
#include "library/module.h"
|
||||
#include "library/opaque_hints.h"
|
||||
|
||||
// Lua Bindings for the Kernel classes. We do not include the Lua
|
||||
// bindings in the kernel because we do not want to inflate the Kernel.
|
||||
|
@ -1763,7 +1764,6 @@ static void open_substitution(lua_State * L) {
|
|||
}
|
||||
|
||||
// type_checker
|
||||
typedef std::shared_ptr<type_checker> type_checker_ref;
|
||||
DECL_UDATA(type_checker_ref)
|
||||
|
||||
static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> & mod_idx, bool & memoize, name_set & extra_opaque) {
|
||||
|
@ -1826,6 +1826,22 @@ static int type_checker_keep(lua_State * L) {
|
|||
static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); }
|
||||
static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); }
|
||||
|
||||
static name g_tmp_prefix = name::mk_internal_unique_name();
|
||||
|
||||
static int mk_type_checker_with_hints(lua_State * L) {
|
||||
environment const & env = to_environment(L, 1);
|
||||
int nargs = lua_gettop(L);
|
||||
if (nargs == 1) {
|
||||
return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), false));
|
||||
} else if (nargs == 2 && lua_isboolean(L, 2)) {
|
||||
return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), lua_toboolean(L, 2)));
|
||||
} else if (nargs == 2) {
|
||||
return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), false));
|
||||
} else {
|
||||
return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), lua_toboolean(L, 3)));
|
||||
}
|
||||
}
|
||||
|
||||
static const struct luaL_Reg type_checker_ref_m[] = {
|
||||
{"__gc", type_checker_ref_gc},
|
||||
{"whnf", safe_function<type_checker_whnf>},
|
||||
|
@ -1878,6 +1894,7 @@ static void open_type_checker(lua_State * L) {
|
|||
setfuncs(L, type_checker_ref_m, 0);
|
||||
|
||||
SET_GLOBAL_FUN(mk_type_checker, "type_checker");
|
||||
SET_GLOBAL_FUN(mk_type_checker_with_hints, "type_checker_with_hints");
|
||||
SET_GLOBAL_FUN(type_checker_ref_pred, "is_type_checker");
|
||||
SET_GLOBAL_FUN(type_check, "type_check");
|
||||
SET_GLOBAL_FUN(type_check, "check");
|
||||
|
|
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include "util/script_state.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "kernel/type_checker.h"
|
||||
|
||||
namespace lean {
|
||||
void open_kernel_module(lua_State * L);
|
||||
|
@ -21,6 +22,8 @@ UDATA_DEFS(justification)
|
|||
UDATA_DEFS(constraint)
|
||||
UDATA_DEFS(substitution)
|
||||
UDATA_DEFS(io_state)
|
||||
UDATA_DEFS_CORE(type_checker_ref)
|
||||
|
||||
int push_optional_level(lua_State * L, optional<level> const & e);
|
||||
int push_optional_expr(lua_State * L, optional<expr> const & e);
|
||||
int push_optional_justification(lua_State * L, optional<justification> const & j);
|
||||
|
|
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/abstract.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "library/kernel_bindings.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/match.h"
|
||||
|
@ -190,7 +191,10 @@ class match_fn : public match_context {
|
|||
case expr_kind::Var:
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
case expr_kind::Constant:
|
||||
return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t));
|
||||
if (const_name(p) == const_name(t))
|
||||
return match_levels(const_levels(p), const_levels(t));
|
||||
else
|
||||
return try_plugin(p, t);
|
||||
case expr_kind::Sort:
|
||||
return match_level(sort_level(p), sort_level(t));
|
||||
case expr_kind::Lambda: case expr_kind::Pi:
|
||||
|
@ -255,6 +259,19 @@ bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buff
|
|||
return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
|
||||
}
|
||||
|
||||
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> const & tc) {
|
||||
return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT
|
||||
try {
|
||||
buffer<constraint> cs;
|
||||
expr p1 = tc->whnf(p, cs);
|
||||
expr t1 = tc->whnf(t, cs);
|
||||
return cs.empty() && (p1 != p || t1 != t) && ctx.match(p1, t1);
|
||||
} catch (exception&) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
|
||||
for_each(l, [&](level const & l) {
|
||||
if (!has_meta(l)) return false;
|
||||
|
@ -284,9 +301,24 @@ static unsigned get_idx_meta_univ_range(expr const & e) {
|
|||
return r;
|
||||
}
|
||||
|
||||
DECL_UDATA(match_plugin)
|
||||
|
||||
static const struct luaL_Reg match_plugin_m[] = {
|
||||
{"__gc", match_plugin_gc},
|
||||
{0, 0}
|
||||
};
|
||||
|
||||
static int mk_whnf_match_plugin(lua_State * L) {
|
||||
return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1)));
|
||||
}
|
||||
|
||||
static int match(lua_State * L) {
|
||||
int nargs = lua_gettop(L);
|
||||
expr p = to_expr(L, 1);
|
||||
expr t = to_expr(L, 2);
|
||||
match_plugin * plugin = nullptr;
|
||||
if (nargs >= 3)
|
||||
plugin = &to_match_plugin(L, 3);
|
||||
if (!closed(t))
|
||||
throw exception("higher-order pattern matching failure, input term must not contain free variables");
|
||||
unsigned r1 = get_free_var_range(p);
|
||||
|
@ -294,7 +326,7 @@ static int match(lua_State * L) {
|
|||
buffer<optional<expr>> esubst;
|
||||
buffer<optional<level>> lsubst;
|
||||
esubst.resize(r1); lsubst.resize(r2);
|
||||
if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) {
|
||||
if (match(p, t, esubst, lsubst, nullptr, nullptr, plugin)) {
|
||||
lua_newtable(L);
|
||||
int i = 1;
|
||||
for (auto s : esubst) {
|
||||
|
@ -327,6 +359,13 @@ static int mk_idx_meta_univ(lua_State * L) {
|
|||
}
|
||||
|
||||
void open_match(lua_State * L) {
|
||||
luaL_newmetatable(L, match_plugin_mt);
|
||||
lua_pushvalue(L, -1);
|
||||
lua_setfield(L, -2, "__index");
|
||||
setfuncs(L, match_plugin_m, 0);
|
||||
|
||||
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
|
||||
SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin");
|
||||
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
|
||||
SET_GLOBAL_FUN(match, "match");
|
||||
}
|
||||
|
|
|
@ -48,6 +48,9 @@ public:
|
|||
*/
|
||||
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
|
||||
|
||||
/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */
|
||||
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> const & tc);
|
||||
|
||||
/**
|
||||
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
|
||||
The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders
|
||||
|
|
28
tests/lean/run/match1.lean
Normal file
28
tests/lean/run/match1.lean
Normal file
|
@ -0,0 +1,28 @@
|
|||
import data.nat
|
||||
using nat
|
||||
|
||||
definition two1 : nat := 2
|
||||
definition two2 : nat := succ (succ (zero))
|
||||
variable f : nat → nat → nat
|
||||
|
||||
(*
|
||||
local tc = type_checker_with_hints(get_env(), true)
|
||||
local plugin = whnf_match_plugin(tc)
|
||||
function tst_match(p, t)
|
||||
local r1, r2 = match(p, t, plugin)
|
||||
assert(r1)
|
||||
print("--------------")
|
||||
for i = 1, #r1 do
|
||||
print(" expr:#" .. i .. " := " .. tostring(r1[i]))
|
||||
end
|
||||
for i = 1, #r2 do
|
||||
print(" lvl:#" .. i .. " := " .. tostring(r2[i]))
|
||||
end
|
||||
end
|
||||
|
||||
local f = Const("f")
|
||||
local two1 = Const("two1")
|
||||
local two2 = Const("two2")
|
||||
local succ = Const({"nat", "succ"})
|
||||
tst_match(f(succ(mk_var(0)), two1), f(two2, two2))
|
||||
*)
|
Loading…
Reference in a new issue