diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index 913ae236d..e25f5ddcb 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -196,6 +196,8 @@ public: }; }; +typedef std::shared_ptr type_checker_ref; + /** \brief Type check the given declaration, and return a certified declaration if it is type correct. Throw an exception if the declaration is type incorrect. diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 36202e9ea..8186a92fa 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -30,6 +30,7 @@ Author: Leonardo de Moura #include "library/kernel_bindings.h" #include "library/normalize.h" #include "library/module.h" +#include "library/opaque_hints.h" // Lua Bindings for the Kernel classes. We do not include the Lua // bindings in the kernel because we do not want to inflate the Kernel. @@ -1763,7 +1764,6 @@ static void open_substitution(lua_State * L) { } // type_checker -typedef std::shared_ptr type_checker_ref; DECL_UDATA(type_checker_ref) static void get_type_checker_args(lua_State * L, int idx, optional & mod_idx, bool & memoize, name_set & extra_opaque) { @@ -1826,6 +1826,22 @@ static int type_checker_keep(lua_State * L) { static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); } static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); } +static name g_tmp_prefix = name::mk_internal_unique_name(); + +static int mk_type_checker_with_hints(lua_State * L) { + environment const & env = to_environment(L, 1); + int nargs = lua_gettop(L); + if (nargs == 1) { + return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), false)); + } else if (nargs == 2 && lua_isboolean(L, 2)) { + return push_type_checker_ref(L, mk_type_checker_with_hints(env, name_generator(g_tmp_prefix), lua_toboolean(L, 2))); + } else if (nargs == 2) { + return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), false)); + } else { + return push_type_checker_ref(L, mk_type_checker_with_hints(env, to_name_generator(L, 2), lua_toboolean(L, 3))); + } +} + static const struct luaL_Reg type_checker_ref_m[] = { {"__gc", type_checker_ref_gc}, {"whnf", safe_function}, @@ -1878,6 +1894,7 @@ static void open_type_checker(lua_State * L) { setfuncs(L, type_checker_ref_m, 0); SET_GLOBAL_FUN(mk_type_checker, "type_checker"); + SET_GLOBAL_FUN(mk_type_checker_with_hints, "type_checker_with_hints"); SET_GLOBAL_FUN(type_checker_ref_pred, "is_type_checker"); SET_GLOBAL_FUN(type_check, "type_check"); SET_GLOBAL_FUN(type_check, "check"); diff --git a/src/library/kernel_bindings.h b/src/library/kernel_bindings.h index 39a7ed487..5880221bb 100644 --- a/src/library/kernel_bindings.h +++ b/src/library/kernel_bindings.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include "util/script_state.h" #include "kernel/environment.h" +#include "kernel/type_checker.h" namespace lean { void open_kernel_module(lua_State * L); @@ -21,6 +22,8 @@ UDATA_DEFS(justification) UDATA_DEFS(constraint) UDATA_DEFS(substitution) UDATA_DEFS(io_state) +UDATA_DEFS_CORE(type_checker_ref) + int push_optional_level(lua_State * L, optional const & e); int push_optional_expr(lua_State * L, optional const & e); int push_optional_justification(lua_State * L, optional const & j); diff --git a/src/library/match.cpp b/src/library/match.cpp index 9f92bfd60..9f623237f 100644 --- a/src/library/match.cpp +++ b/src/library/match.cpp @@ -7,6 +7,7 @@ Author: Leonardo de Moura #include "kernel/abstract.h" #include "kernel/instantiate.h" #include "kernel/for_each_fn.h" +#include "kernel/type_checker.h" #include "library/kernel_bindings.h" #include "library/locals.h" #include "library/match.h" @@ -190,7 +191,10 @@ class match_fn : public match_context { case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Constant: - return const_name(p) == const_name(t) && match_levels(const_levels(p), const_levels(t)); + if (const_name(p) == const_name(t)) + return match_levels(const_levels(p), const_levels(t)); + else + return try_plugin(p, t); case expr_kind::Sort: return match_level(sort_level(p), sort_level(t)); case expr_kind::Lambda: case expr_kind::Pi: @@ -255,6 +259,19 @@ bool match(expr const & p, expr const & t, buffer> & esubst, buff return match_fn(esubst, lsubst, name_generator(g_tmp_prefix), name_subst, plugin).match(p, t); } +match_plugin mk_whnf_match_plugin(std::shared_ptr const & tc) { + return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT + try { + buffer cs; + expr p1 = tc->whnf(p, cs); + expr t1 = tc->whnf(t, cs); + return cs.empty() && (p1 != p || t1 != t) && ctx.match(p1, t1); + } catch (exception&) { + return false; + } + }; +} + static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) { for_each(l, [&](level const & l) { if (!has_meta(l)) return false; @@ -284,9 +301,24 @@ static unsigned get_idx_meta_univ_range(expr const & e) { return r; } +DECL_UDATA(match_plugin) + +static const struct luaL_Reg match_plugin_m[] = { + {"__gc", match_plugin_gc}, + {0, 0} +}; + +static int mk_whnf_match_plugin(lua_State * L) { + return push_match_plugin(L, mk_whnf_match_plugin(to_type_checker_ref(L, 1))); +} + static int match(lua_State * L) { + int nargs = lua_gettop(L); expr p = to_expr(L, 1); expr t = to_expr(L, 2); + match_plugin * plugin = nullptr; + if (nargs >= 3) + plugin = &to_match_plugin(L, 3); if (!closed(t)) throw exception("higher-order pattern matching failure, input term must not contain free variables"); unsigned r1 = get_free_var_range(p); @@ -294,7 +326,7 @@ static int match(lua_State * L) { buffer> esubst; buffer> lsubst; esubst.resize(r1); lsubst.resize(r2); - if (match(p, t, esubst, lsubst, nullptr, nullptr, nullptr)) { + if (match(p, t, esubst, lsubst, nullptr, nullptr, plugin)) { lua_newtable(L); int i = 1; for (auto s : esubst) { @@ -327,7 +359,14 @@ static int mk_idx_meta_univ(lua_State * L) { } void open_match(lua_State * L) { - SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); - SET_GLOBAL_FUN(match, "match"); + luaL_newmetatable(L, match_plugin_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, match_plugin_m, 0); + + SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin"); + SET_GLOBAL_FUN(match_plugin_pred, "is_match_plugin"); + SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); + SET_GLOBAL_FUN(match, "match"); } } diff --git a/src/library/match.h b/src/library/match.h index 72a75fd5a..1b213c225 100644 --- a/src/library/match.h +++ b/src/library/match.h @@ -48,6 +48,9 @@ public: */ typedef std::function match_plugin; // NOLINT +/** \brief Create a match_plugin that puts terms in weak-head-normal-form before failing */ +match_plugin mk_whnf_match_plugin(std::shared_ptr const & tc); + /** \brief Matching for higher-order patterns. Return true iff \c t matches the higher-order pattern \c p. The substitution is stored in \c subst. Note that, this procedure treats free-variables as placholders diff --git a/tests/lean/run/match1.lean b/tests/lean/run/match1.lean new file mode 100644 index 000000000..bc1f66753 --- /dev/null +++ b/tests/lean/run/match1.lean @@ -0,0 +1,28 @@ +import data.nat +using nat + +definition two1 : nat := 2 +definition two2 : nat := succ (succ (zero)) +variable f : nat → nat → nat + +(* +local tc = type_checker_with_hints(get_env(), true) +local plugin = whnf_match_plugin(tc) +function tst_match(p, t) + local r1, r2 = match(p, t, plugin) + assert(r1) + print("--------------") + for i = 1, #r1 do + print(" expr:#" .. i .. " := " .. tostring(r1[i])) + end + for i = 1, #r2 do + print(" lvl:#" .. i .. " := " .. tostring(r2[i])) + end +end + +local f = Const("f") +local two1 = Const("two1") +local two2 = Const("two2") +local succ = Const({"nat", "succ"}) +tst_match(f(succ(mk_var(0)), two1), f(two2, two2)) +*)