feat(library/match): add basic match_plugin that just invokes whnf before failing

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-08-05 08:37:03 -07:00
parent a4b023a175
commit e6ffda0c51
6 changed files with 97 additions and 5 deletions

View file

@ -196,6 +196,8 @@ public:
}; };
}; };
typedef std::shared_ptr<type_checker> type_checker_ref;
/** /**
\brief Type check the given declaration, and return a certified declaration if it is type correct. \brief Type check the given declaration, and return a certified declaration if it is type correct.
Throw an exception if the declaration is type incorrect. Throw an exception if the declaration is type incorrect.

View file

@ -30,6 +30,7 @@ Author: Leonardo de Moura
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
#include "library/normalize.h" #include "library/normalize.h"
#include "library/module.h" #include "library/module.h"
#include "library/opaque_hints.h"
// Lua Bindings for the Kernel classes. We do not include the Lua // Lua Bindings for the Kernel classes. We do not include the Lua
// bindings in the kernel because we do not want to inflate the Kernel. // bindings in the kernel because we do not want to inflate the Kernel.
@ -1763,7 +1764,6 @@ static void open_substitution(lua_State * L) {
} }
// type_checker // type_checker
typedef std::shared_ptr<type_checker> type_checker_ref;
DECL_UDATA(type_checker_ref) DECL_UDATA(type_checker_ref)
static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> & mod_idx, bool & memoize, name_set & extra_opaque) { static void get_type_checker_args(lua_State * L, int idx, optional<module_idx> & mod_idx, bool & memoize, name_set & extra_opaque) {
@ -1826,6 +1826,22 @@ static int type_checker_keep(lua_State * L) {
static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); } static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); }
static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); } static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); }
static name g_tmp_prefix = name::mk_internal_unique_name();
static int mk_type_checker_with_hints(lua_State * L) {
environment const & env = to_environment(L, 1);
int nargs = lua_gettop(L);
if (nargs == 1) {
return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), false));
} else if (nargs == 2 && lua_isboolean(L, 2)) {
return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), lua_toboolean(L, 2)));
} else if (nargs == 2) {
return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), false));
} else {
return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), lua_toboolean(L, 3)));
}
}
static const struct luaL_Reg type_checker_ref_m[] = { static const struct luaL_Reg type_checker_ref_m[] = {
{"__gc", type_checker_ref_gc}, {"__gc", type_checker_ref_gc},
{"whnf", safe_function<type_checker_whnf>}, {"whnf", safe_function<type_checker_whnf>},
@ -1878,6 +1894,7 @@ static void open_type_checker(lua_State * L) {
setfuncs(L, type_checker_ref_m, 0); setfuncs(L, type_checker_ref_m, 0);
SET_GLOBAL_FUN(mk_type_checker, "type_checker"); SET_GLOBAL_FUN(mk_type_checker, "type_checker");
SET_GLOBAL_FUN(mk_type_checker_with_hints, "type_checker_with_hints");
SET_GLOBAL_FUN(type_checker_ref_pred, "is_type_checker"); SET_GLOBAL_FUN(type_checker_ref_pred, "is_type_checker");
SET_GLOBAL_FUN(type_check, "type_check"); SET_GLOBAL_FUN(type_check, "type_check");
SET_GLOBAL_FUN(type_check, "check"); SET_GLOBAL_FUN(type_check, "check");

View file

@ -7,6 +7,7 @@ Author: Leonardo de Moura
#pragma once #pragma once
#include "util/script_state.h" #include "util/script_state.h"
#include "kernel/environment.h" #include "kernel/environment.h"
#include "kernel/type_checker.h"
namespace lean { namespace lean {
void open_kernel_module(lua_State * L); void open_kernel_module(lua_State * L);
@ -21,6 +22,8 @@ UDATA_DEFS(justification)
UDATA_DEFS(constraint) UDATA_DEFS(constraint)
UDATA_DEFS(substitution) UDATA_DEFS(substitution)
UDATA_DEFS(io_state) UDATA_DEFS(io_state)
UDATA_DEFS_CORE(type_checker_ref)
int push_optional_level(lua_State * L, optional<level> const & e); int push_optional_level(lua_State * L, optional<level> const & e);
int push_optional_expr(lua_State * L, optional<expr> const & e); int push_optional_expr(lua_State * L, optional<expr> const & e);
int push_optional_justification(lua_State * L, optional<justification> const & j); int push_optional_justification(lua_State * L, optional<justification> const & j);

View file

@ -7,6 +7,7 @@ Author: Leonardo de Moura
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
#include "kernel/for_each_fn.h" #include "kernel/for_each_fn.h"
#include "kernel/type_checker.h"
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
#include "library/locals.h" #include "library/locals.h"
#include "library/match.h" #include "library/match.h"
@ -190,7 +191,10 @@ class match_fn : public match_context {
case expr_kind::Var: case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Constant: case expr_kind::Constant:
return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t)); if (const_name(p) == const_name(t))
return match_levels(const_levels(p), const_levels(t));
else
return try_plugin(p, t);
case expr_kind::Sort: case expr_kind::Sort:
return match_level(sort_level(p), sort_level(t)); return match_level(sort_level(p), sort_level(t));
case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Pi:
@ -255,6 +259,19 @@ bool match(expr const & p, expr const & t, buffer<optional<expr>> & esubst, buff
return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t); return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t);
} }
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> const & tc) {
return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT
try {
buffer<constraint> cs;
expr p1 = tc->whnf(p, cs);
expr t1 = tc->whnf(t, cs);
return cs.empty() && (p1 != p || t1 != t) && ctx.match(p1, t1);
} catch (exception&) {
return false;
}
};
}
static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) { static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) {
for_each(l, [&](level const & l) { for_each(l, [&](level const & l) {
if (!has_meta(l)) return false; if (!has_meta(l)) return false;
@ -284,9 +301,24 @@ static unsigned get_idx_meta_univ_range(expr const & e) {
return r; return r;
} }
DECL_UDATA(match_plugin)
static const struct luaL_Reg match_plugin_m[] = {
{"__gc", match_plugin_gc},
{0, 0}
};
static int mk_whnf_match_plugin(lua_State * L) {
return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1)));
}
static int match(lua_State * L) { static int match(lua_State * L) {
int nargs = lua_gettop(L);
expr p = to_expr(L, 1); expr p = to_expr(L, 1);
expr t = to_expr(L, 2); expr t = to_expr(L, 2);
match_plugin * plugin = nullptr;
if (nargs >= 3)
plugin = &to_match_plugin(L, 3);
if (!closed(t)) if (!closed(t))
throw exception("higher-order pattern matching failure, input term must not contain free variables"); throw exception("higher-order pattern matching failure, input term must not contain free variables");
unsigned r1 = get_free_var_range(p); unsigned r1 = get_free_var_range(p);
@ -294,7 +326,7 @@ static int match(lua_State * L) {
buffer<optional<expr>> esubst; buffer<optional<expr>> esubst;
buffer<optional<level>> lsubst; buffer<optional<level>> lsubst;
esubst.resize(r1); lsubst.resize(r2); esubst.resize(r1); lsubst.resize(r2);
if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) { if (match(p, t, esubst, lsubst, nullptr, nullptr, plugin)) {
lua_newtable(L); lua_newtable(L);
int i = 1; int i = 1;
for (auto s : esubst) { for (auto s : esubst) {
@ -327,7 +359,14 @@ static int mk_idx_meta_univ(lua_State * L) {
} }
void open_match(lua_State * L) { void open_match(lua_State * L) {
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); luaL_newmetatable(L, match_plugin_mt);
SET_GLOBAL_FUN(match, "match"); lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
setfuncs(L, match_plugin_m, 0);
SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin");
SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin");
SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ");
SET_GLOBAL_FUN(match, "match");
} }
} }

View file

@ -48,6 +48,9 @@ public:
*/ */
typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT typedef std::function<bool(expr const &, expr const &, match_context &)> match_plugin; // NOLINT
/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */
match_plugin mk_whnf_match_plugin(std::shared_ptr<type_checker> const & tc);
/** /**
\brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p. \brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p.
The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders

View file

@ -0,0 +1,28 @@
import data.nat
using nat
definition two1 : nat := 2
definition two2 : nat := succ (succ (zero))
variable f : nat → nat → nat
(*
local tc = type_checker_with_hints(get_env(), true)
local plugin = whnf_match_plugin(tc)
function tst_match(p, t)
local r1, r2 = match(p, t, plugin)
assert(r1)
print("--------------")
for i = 1, #r1 do
print(" expr:#" .. i .. " := " .. tostring(r1[i]))
end
for i = 1, #r2 do
print(" lvl:#" .. i .. " := " .. tostring(r2[i]))
end
end
local f = Const("f")
local two1 = Const("two1")
local two2 = Const("two2")
local succ = Const({"nat", "succ"})
tst_match(f(succ(mk_var(0)), two1), f(two2, two2))
*)