feat(library/match): add basic match_plugin that just invokes whnf before failing
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
a4b023a175
commit
e6ffda0c51
6 changed files with 97 additions and 5 deletions
|
@ -196,6 +196,8 @@ public:
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
typedef std::shared_ptr<type_checker> type_checker_ref;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
\brief Type check the given declaration, and return a certified declaration if it is type correct.
|
\brief Type check the given declaration, and return a certified declaration if it is type correct.
|
||||||
Throw an exception if the declaration is type incorrect.
|
Throw an exception if the declaration is type incorrect.
|
||||||
|
|
|
@ -30,6 +30,7 @@ Author: Leonardo de Moura
|
||||||
#include "library/kernel_bindings.h"
|
#include "library/kernel_bindings.h"
|
||||||
#include "library/normalize.h"
|
#include "library/normalize.h"
|
||||||
#include "library/module.h"
|
#include "library/module.h"
|
||||||
|
#include "library/opaque_hints.h"
|
||||||
|
|
||||||
// Lua Bindings for the Kernel classes. We do not include the Lua
|
// Lua Bindings for the Kernel classes. We do not include the Lua
|
||||||
// bindings in the kernel because we do not want to inflate the Kernel.
|
// bindings in the kernel because we do not want to inflate the Kernel.
|
||||||
|
@ -1763,7 +1764,6 @@ static void open_substitution(lua_State * L) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// type_checker
|
// type_checker
|
||||||
typedef std::shared_ptr<type_checker> type_checker_ref;
|
|
||||||
DECL_UDATA(type_checker_ref)
|
DECL_UDATA(type_checker_ref)
|
||||||
|
|
||||||
static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> & mod_idx, bool & memoize, name_set & extra_opaque) {
|
static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> & mod_idx, bool & memoize, name_set & extra_opaque) {
|
||||||
|
@ -1826,6 +1826,22 @@ static int type_checker_keep(lua_State * L) {
|
||||||
static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); }
|
static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); }
|
||||||
static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); }
|
static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); }
|
||||||
|
|
||||||
|
static name g_tmp_prefix = name::mk_internal_unique_name();
|
||||||
|
|
||||||
|
static int mk_type_checker_with_hints(lua_State * L) {
|
||||||
|
environment const & env = to_environment(L, 1);
|
||||||
|
int nargs = lua_gettop(L);
|
||||||
|
if (nargs == 1) {
|
||||||
|
return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), false));
|
||||||
|
} else if (nargs == 2 && lua_isboolean(L, 2)) {
|
||||||
|
return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), lua_toboolean(L, 2)));
|
||||||
|
} else if (nargs == 2) {
|
||||||
|
return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), false));
|
||||||
|
} else {
|
||||||
|
return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), lua_toboolean(L, 3)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static const struct luaL_Reg type_checker_ref_m[] = {
|
static const struct luaL_Reg type_checker_ref_m[] = {
|
||||||
{"__gc", type_checker_ref_gc},
|
{"__gc", type_checker_ref_gc},
|
||||||
{"whnf", safe_function<type_checker_whnf>},
|
{"whnf", safe_function<type_checker_whnf>},
|
||||||
|
@ -1878,6 +1894,7 @@ static void open_type_checker(lua_State * L) {
|
||||||
setfuncs(L, type_checker_ref_m, 0);
|
setfuncs(L, type_checker_ref_m, 0);
|
||||||
|
|
||||||
SET_GLOBAL_FUN(mk_type_checker, "type_checker");
|
SET_GLOBAL_FUN(mk_type_checker, "type_checker");
|
||||||
|
SET_GLOBAL_FUN(mk_type_checker_with_hints, "type_checker_with_hints");
|
||||||
SET_GLOBAL_FUN(type_checker_ref_pred, "is_type_checker");
|
SET_GLOBAL_FUN(type_checker_ref_pred, "is_type_checker");
|
||||||
SET_GLOBAL_FUN(type_check, "type_check");
|
SET_GLOBAL_FUN(type_check, "type_check");
|
||||||
SET_GLOBAL_FUN(type_check, "check");
|
SET_GLOBAL_FUN(type_check, "check");
|
||||||
|
|
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "util/script_state.h"
|
#include "util/script_state.h"
|
||||||
#include "kernel/environment.h"
|
#include "kernel/environment.h"
|
||||||
|
#include "kernel/type_checker.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
void open_kernel_module(lua_State * L);
|
void open_kernel_module(lua_State * L);
|
||||||
|
@ -21,6 +22,8 @@ UDATA_DEFS(justification)
|
||||||
UDATA_DEFS(constraint)
|
UDATA_DEFS(constraint)
|
||||||
UDATA_DEFS(substitution)
|
UDATA_DEFS(substitution)
|
||||||
UDATA_DEFS(io_state)
|
UDATA_DEFS(io_state)
|
||||||
|
UDATA_DEFS_CORE(type_checker_ref)
|
||||||
|
|
||||||
int push_optional_level(lua_State * L, optional<level> const & e);
|
int push_optional_level(lua_State * L, optional<level> const & e);
|
||||||
int push_optional_expr(lua_State * L, optional<expr> const & e);
|
int push_optional_expr(lua_State * L, optional<expr> const & e);
|
||||||
int push_optional_justification(lua_State * L, optional<justification> const & j);
|
int push_optional_justification(lua_State * L, optional<justification> const & j);
|
||||||
|
|
|
@ -7,6 +7,7 @@ Author: Leonardo de Moura
|
||||||
#include "kernel/abstract.h"
|
#include "kernel/abstract.h"
|
||||||
#include "kernel/instantiate.h"
|
#include "kernel/instantiate.h"
|
||||||
#include "kernel/for_each_fn.h"
|
#include "kernel/for_each_fn.h"
|
||||||
|
#include "kernel/type_checker.h"
|
||||||
#include "library/kernel_bindings.h"
|
#include "library/kernel_bindings.h"
|
||||||
#include "library/locals.h"
|
#include "library/locals.h"
|
||||||
#include "library/match.h"
|
#include "library/match.h"
|
||||||
|
@ -190,7 +191,10 @@ class match_fn : public match_context {
|
||||||
case expr_kind::Var:
|
case expr_kind::Var:
|
||||||
lean_unreachable(); // LCOV_EXCL_LINE
|
lean_unreachable(); // LCOV_EXCL_LINE
|
||||||
case expr_kind::Constant:
|
case expr_kind::Constant:
|
||||||
return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t));
|
if (const_name(p) == const_name(t))
|
||||||
|
return match_levels(const_levels(p), const_levels(t));
|
||||||
|
else
|
||||||
|
return try_plugin(p, t);
|
||||||
case expr_kind::Sort:
|
case expr_kind::Sort:
|
||||||
return match_level(sort_level(p), sort_level(t));
|
return match_level(sort_level(p), sort_level(t));
|
||||||
case expr_kind::Lambda: case expr_kind::Pi:
|
case expr_kind::Lambda: case expr_kind::Pi:
|
||||||
|
@ -255,6 +259,19 @@ bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buff
|
||||||
return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
|
return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> const & tc) {
|
||||||
|
return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT
|
||||||
|
try {
|
||||||
|
buffer<constraint> cs;
|
||||||
|
expr p1 = tc->whnf(p, cs);
|
||||||
|
expr t1 = tc->whnf(t, cs);
|
||||||
|
return cs.empty() && (p1 != p || t1 != t) && ctx.match(p1, t1);
|
||||||
|
} catch (exception&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
|
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
|
||||||
for_each(l, [&](level const & l) {
|
for_each(l, [&](level const & l) {
|
||||||
if (!has_meta(l)) return false;
|
if (!has_meta(l)) return false;
|
||||||
|
@ -284,9 +301,24 @@ static unsigned get_idx_meta_univ_range(expr const & e) {
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DECL_UDATA(match_plugin)
|
||||||
|
|
||||||
|
static const struct luaL_Reg match_plugin_m[] = {
|
||||||
|
{"__gc", match_plugin_gc},
|
||||||
|
{0, 0}
|
||||||
|
};
|
||||||
|
|
||||||
|
static int mk_whnf_match_plugin(lua_State * L) {
|
||||||
|
return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1)));
|
||||||
|
}
|
||||||
|
|
||||||
static int match(lua_State * L) {
|
static int match(lua_State * L) {
|
||||||
|
int nargs = lua_gettop(L);
|
||||||
expr p = to_expr(L, 1);
|
expr p = to_expr(L, 1);
|
||||||
expr t = to_expr(L, 2);
|
expr t = to_expr(L, 2);
|
||||||
|
match_plugin * plugin = nullptr;
|
||||||
|
if (nargs >= 3)
|
||||||
|
plugin = &to_match_plugin(L, 3);
|
||||||
if (!closed(t))
|
if (!closed(t))
|
||||||
throw exception("higher-order pattern matching failure, input term must not contain free variables");
|
throw exception("higher-order pattern matching failure, input term must not contain free variables");
|
||||||
unsigned r1 = get_free_var_range(p);
|
unsigned r1 = get_free_var_range(p);
|
||||||
|
@ -294,7 +326,7 @@ static int match(lua_State * L) {
|
||||||
buffer<optional<expr>> esubst;
|
buffer<optional<expr>> esubst;
|
||||||
buffer<optional<level>> lsubst;
|
buffer<optional<level>> lsubst;
|
||||||
esubst.resize(r1); lsubst.resize(r2);
|
esubst.resize(r1); lsubst.resize(r2);
|
||||||
if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) {
|
if (match(p, t, esubst, lsubst, nullptr, nullptr, plugin)) {
|
||||||
lua_newtable(L);
|
lua_newtable(L);
|
||||||
int i = 1;
|
int i = 1;
|
||||||
for (auto s : esubst) {
|
for (auto s : esubst) {
|
||||||
|
@ -327,6 +359,13 @@ static int mk_idx_meta_univ(lua_State * L) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void open_match(lua_State * L) {
|
void open_match(lua_State * L) {
|
||||||
|
luaL_newmetatable(L, match_plugin_mt);
|
||||||
|
lua_pushvalue(L, -1);
|
||||||
|
lua_setfield(L, -2, "__index");
|
||||||
|
setfuncs(L, match_plugin_m, 0);
|
||||||
|
|
||||||
|
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
|
||||||
|
SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin");
|
||||||
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
|
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
|
||||||
SET_GLOBAL_FUN(match, "match");
|
SET_GLOBAL_FUN(match, "match");
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,6 +48,9 @@ public:
|
||||||
*/
|
*/
|
||||||
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
|
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
|
||||||
|
|
||||||
|
/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */
|
||||||
|
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> const & tc);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
|
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
|
||||||
The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders
|
The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders
|
||||||
|
|
28
tests/lean/run/match1.lean
Normal file
28
tests/lean/run/match1.lean
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
import data.nat
|
||||||
|
using nat
|
||||||
|
|
||||||
|
definition two1 : nat := 2
|
||||||
|
definition two2 : nat := succ (succ (zero))
|
||||||
|
variable f : nat → nat → nat
|
||||||
|
|
||||||
|
(*
|
||||||
|
local tc = type_checker_with_hints(get_env(), true)
|
||||||
|
local plugin = whnf_match_plugin(tc)
|
||||||
|
function tst_match(p, t)
|
||||||
|
local r1, r2 = match(p, t, plugin)
|
||||||
|
assert(r1)
|
||||||
|
print("--------------")
|
||||||
|
for i = 1, #r1 do
|
||||||
|
print(" expr:#" .. i .. " := " .. tostring(r1[i]))
|
||||||
|
end
|
||||||
|
for i = 1, #r2 do
|
||||||
|
print(" lvl:#" .. i .. " := " .. tostring(r2[i]))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local f = Const("f")
|
||||||
|
local two1 = Const("two1")
|
||||||
|
local two2 = Const("two2")
|
||||||
|
local succ = Const({"nat", "succ"})
|
||||||
|
tst_match(f(succ(mk_var(0)), two1), f(two2, two2))
|
||||||
|
*)
|
Loading…
Reference in a new issue