From 37bee8c8524feb6e9bb10666625ca8e0b32c8ae5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 21 Jun 2014 12:22:24 -0700 Subject: [PATCH] refactor(kernel/type_checker): simplify replace constraint_handler with closure Signed-off-by: Leonardo de Moura --- src/kernel/type_checker.cpp | 23 ++++++++-------- src/kernel/type_checker.h | 15 +++-------- src/library/kernel_bindings.cpp | 48 +++++++++------------------------ tests/lua/tc1.lua | 2 +- tests/lua/tc2.lua | 4 +-- tests/lua/tc7.lua | 2 +- 6 files changed, 33 insertions(+), 61 deletions(-) diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 7a7dfc3f4..20032199c 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -27,8 +27,8 @@ no_constraints_allowed_exception::no_constraints_allowed_exception():exception(" exception * no_constraints_allowed_exception::clone() const { return new no_constraints_allowed_exception(); } void no_constraints_allowed_exception::rethrow() const { throw *this; } -void no_constraint_handler::add_cnstr(constraint const &) { - throw no_constraints_allowed_exception(); +add_cnstr_fn mk_no_contranint_fn() { + return add_cnstr_fn([](constraint const &) { throw no_constraints_allowed_exception(); }); } /** \brief Auxiliary functional object used to implement type checker. */ @@ -58,7 +58,7 @@ struct type_checker::imp { environment m_env; name_generator m_gen; - constraint_handler & m_chandler; + add_cnstr_fn m_add_cnstr_fn; std::unique_ptr m_conv; // In the type checker cache, we must take into account binder information. // Examples: @@ -71,8 +71,8 @@ struct type_checker::imp { // temp flag level_param_names m_params; - imp(environment const & env, name_generator const & g, constraint_handler & h, std::unique_ptr && conv, bool memoize): - m_env(env), m_gen(g), m_chandler(h), m_conv(std::move(conv)), m_conv_ctx(*this), m_tc_ctx(*this), + imp(environment const & env, name_generator const & g, add_cnstr_fn const & h, std::unique_ptr && conv, bool memoize): + m_env(env), m_gen(g), m_add_cnstr_fn(h), m_conv(std::move(conv)), m_conv_ctx(*this), m_tc_ctx(*this), m_memoize(memoize) {} optional expand_macro(expr const & m) { @@ -89,9 +89,9 @@ struct type_checker::imp { return mk_pair(instantiate(binding_body(e), local), local); } - /** \brief Add given constraint to the constraint handler m_chandler. */ + /** \brief Add given constraint using m_add_cnstr_fn. */ void add_cnstr(constraint const & c) { - m_chandler.add_cnstr(c); + m_add_cnstr_fn(c); } /** \brief Return true iff \c t and \c s are definitionally equal */ @@ -403,18 +403,19 @@ struct type_checker::imp { expr whnf(expr const & t) { return m_conv->whnf(t, m_conv_ctx); } }; -no_constraint_handler g_no_constraint_handler; +static add_cnstr_fn g_no_constraint_fn = mk_no_contranint_fn(); -type_checker::type_checker(environment const & env, name_generator const & g, constraint_handler & h, std::unique_ptr && conv, bool memoize): +type_checker::type_checker(environment const & env, name_generator const & g, add_cnstr_fn const & h, + std::unique_ptr && conv, bool memoize): m_ptr(new imp(env, g, h, std::move(conv), memoize)) {} type_checker::type_checker(environment const & env, name_generator const & g, std::unique_ptr && conv, bool memoize): - type_checker(env, g, g_no_constraint_handler, std::move(conv), memoize) {} + type_checker(env, g, g_no_constraint_fn, std::move(conv), memoize) {} static name g_tmp_prefix = name::mk_internal_unique_name(); type_checker::type_checker(environment const & env): - type_checker(env, name_generator(g_tmp_prefix), g_no_constraint_handler, mk_default_converter(env), true) {} + type_checker(env, name_generator(g_tmp_prefix), g_no_constraint_fn, mk_default_converter(env), true) {} type_checker::~type_checker() {} expr type_checker::infer(expr const & t) { return m_ptr->infer_type(t); } diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index a539c6ba5..1c1d22534 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -15,17 +15,10 @@ Author: Leonardo de Moura #include "kernel/converter.h" namespace lean { -class constraint_handler { -public: - virtual ~constraint_handler() {} - virtual void add_cnstr(constraint const & c) = 0; -}; +typedef std::function add_cnstr_fn; /** \brief This handler always throw an exception (\c no_constraints_allowed_exception) when \c add_cnstr is invoked. */ -class no_constraint_handler : public constraint_handler { -public: - virtual void add_cnstr(constraint const & c); -}; +add_cnstr_fn mk_no_contranint_fn(); /** \brief Exception used in \c no_constraint_handler. */ class no_constraints_allowed_exception : public exception { @@ -51,9 +44,9 @@ public: memoize: if true, then inferred types are memoized/cached */ - type_checker(environment const & env, name_generator const & g, constraint_handler & h, std::unique_ptr && conv, + type_checker(environment const & env, name_generator const & g, add_cnstr_fn const & h, std::unique_ptr && conv, bool memoize = true); - type_checker(environment const & env, name_generator const & g, constraint_handler & h, bool memoize = true): + type_checker(environment const & env, name_generator const & g, add_cnstr_fn const & h, bool memoize = true): type_checker(env, g, h, mk_default_converter(env), memoize) {} /** \brief Similar to the previous constructor, but if a method tries to create a constraint, then an diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index cf3e5507f..7f7f413f1 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -1808,37 +1808,16 @@ static void open_substitution(lua_State * L) { SET_GLOBAL_FUN(substitution_pred, "is_substitution"); } -// constraint_handler -class lua_constraint_handler : public constraint_handler { - luaref m_f; -public: - lua_constraint_handler(luaref const & f):m_f(f) {} - virtual void add_cnstr(constraint const & c) { - lua_State * L = m_f.get_state(); - m_f.push(); - push_constraint(L, c); - pcall(L, 1, 0, 0); - } -}; -DECL_UDATA(lua_constraint_handler) -int mk_constraint_handler(lua_State * L) { - luaL_checktype(L, 1, LUA_TFUNCTION); // user-fun - return push_lua_constraint_handler(L, lua_constraint_handler(luaref(L, 1))); -} - -static const struct luaL_Reg lua_constraint_handler_m[] = { - {"__gc", lua_constraint_handler_gc}, - {0, 0} -}; - -static void open_constraint_handler(lua_State * L) { - luaL_newmetatable(L, lua_constraint_handler_mt); - lua_pushvalue(L, -1); - lua_setfield(L, -2, "__index"); - setfuncs(L, lua_constraint_handler_m, 0); - - SET_GLOBAL_FUN(mk_constraint_handler, "constraint_handler"); - SET_GLOBAL_FUN(lua_constraint_handler_pred, "is_constraint_handler"); +// add_cnstr_fn +add_cnstr_fn to_add_cnstr_fn(lua_State * L, int idx) { + luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun + luaref f(L, idx); + return add_cnstr_fn([=](constraint const & c) { + lua_State * L = f.get_state(); + f.push(); + push_constraint(L, c); + pcall(L, 1, 0, 0); + }); } // type_checker @@ -1857,9 +1836,9 @@ int mk_type_checker(lua_State * L) { return push_type_checker_ref(L, std::make_shared(to_environment(L, 1))); } else if (nargs == 2) { return push_type_checker_ref(L, std::make_shared(to_environment(L, 1), to_name_generator(L, 2))); - } else if (nargs == 3 && is_lua_constraint_handler(L, 3)) { + } else if (nargs == 3 && lua_isfunction(L, 3)) { return push_type_checker_ref(L, std::make_shared(to_environment(L, 1), to_name_generator(L, 2), - to_lua_constraint_handler(L, 3))); + to_add_cnstr_fn(L, 3))); } else { optional mod_idx; bool memoize; name_set extra_opaque; if (nargs == 3) { @@ -1871,7 +1850,7 @@ int mk_type_checker(lua_State * L) { } else { get_type_checker_args(L, 4, mod_idx, memoize, extra_opaque); auto t = std::make_shared(to_environment(L, 1), to_name_generator(L, 2), - to_lua_constraint_handler(L, 3), + to_add_cnstr_fn(L, 3), mk_default_converter(to_environment(L, 1), mod_idx, memoize, extra_opaque), memoize); return push_type_checker_ref(L, t); @@ -2046,7 +2025,6 @@ void open_kernel_module(lua_State * L) { open_justification(L); open_constraint(L); open_substitution(L); - open_constraint_handler(L); open_type_checker(L); open_inductive(L); } diff --git a/tests/lua/tc1.lua b/tests/lua/tc1.lua index 99041ab70..da0801458 100644 --- a/tests/lua/tc1.lua +++ b/tests/lua/tc1.lua @@ -8,7 +8,7 @@ local b = Const("b") print(t(b)) assert(tc:whnf(t(b)) == b) local cs = {} -local tc2 = type_checker(env, g, constraint_handler(function (c) print(c); cs[#cs+1] = c end)) +local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end) assert(tc:check(Bool) == mk_sort(mk_level_one())) print(tc:infer(t)) local m = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("u")))) diff --git a/tests/lua/tc2.lua b/tests/lua/tc2.lua index ea0a995d0..3021441cd 100644 --- a/tests/lua/tc2.lua +++ b/tests/lua/tc2.lua @@ -54,7 +54,7 @@ print(tc:check(Fun({{A, mk_sort(u)}, {a, A}, {b, A}, {c, A}, {d, A}, trans_u(A, b, c, d, H2, H3))))) local cs = {} local g = name_generator("tst") -local tc2 = type_checker(env, g, constraint_handler(function (c) print(c); cs[#cs+1] = c end)) +local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end) print("=================") local f = Const("f") local mf_ty = mk_metavar("f_ty", Pi(A, mk_sort(u), mk_sort(mk_meta_univ("l_f")))) @@ -64,7 +64,7 @@ print(tc2:check(Fun({{A, mk_sort(u)}, {f, mf_ty(A)}, {a, A}}, local cs = {} -local tc2 = type_checker(env, g, constraint_handler(function (c) print(c); cs[#cs+1] = c end)) +local tc2 = type_checker(env, g, function (c) print(c); cs[#cs+1] = c end) local scope = {{A, mk_sort(u)}, {a, A}, {b, A}, {c, A}, {d, A}, {H1, id_u(A, b, a)}, {H2, id_u(A, b, c)}, {H3, id_u(A, c, d)}} local mP = mk_metavar("P", Pi(scope, mk_metavar("P_ty", Pi(scope, mk_sort(mk_meta_univ("l_P"))))(A, a, b, c, d, H1, H2, H3)))(A, a, b, c, d, H1, H2, H3) diff --git a/tests/lua/tc7.lua b/tests/lua/tc7.lua index ef85276ea..77af56ffb 100644 --- a/tests/lua/tc7.lua +++ b/tests/lua/tc7.lua @@ -13,7 +13,7 @@ local a = Const("a") local m1 = mk_metavar("m1", N) local cs = {} local ngen = name_generator("tst") -local tc = type_checker(env, ngen, constraint_handler(function (c) print(c); cs[#cs+1] = c end)) +local tc = type_checker(env, ngen, function (c) print(c); cs[#cs+1] = c end) assert(tc:is_def_eq(f(m1), g(a))) assert(tc:is_def_eq(f(m1), a)) assert(not tc:is_def_eq(f(a), a))