/* Copyright (c) 2013-2014 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include #include #include "kernel/abstract.h" #include "kernel/instantiate.h" #include "kernel/for_each_fn.h" #include "kernel/type_checker.h" #include "library/kernel_bindings.h" #include "library/locals.h" #include "library/match.h" namespace lean { static name * g_tmp_prefix = nullptr; void initialize_match() { g_tmp_prefix = new name(name::mk_internal_unique_name()); } void finalize_match() { delete g_tmp_prefix; } level mk_idx_meta_univ(unsigned i) { return mk_meta_univ(name(*g_tmp_prefix, i)); } expr mk_idx_meta(unsigned i, expr const & type) { return mk_metavar(name(*g_tmp_prefix, i), type); } bool is_idx_meta_univ(level const & l) { if (!is_meta(l)) return false; name const & n = meta_id(l); return !n.is_atomic() && n.is_numeral() && n.get_prefix() == *g_tmp_prefix; } unsigned to_meta_idx(level const & l) { lean_assert(is_idx_meta_univ(l)); return meta_id(l).get_numeral(); } bool is_idx_meta(expr const & e) { if (!is_metavar(e)) return false; name const & n = mlocal_name(e); return !n.is_atomic() && n.is_numeral() && n.get_prefix() == *g_tmp_prefix; } unsigned to_meta_idx(expr const & e) { lean_assert(is_idx_meta(e)); return mlocal_name(e).get_numeral(); } class match_fn : public match_context { unsigned m_esubst_sz; optional * m_esubst; unsigned m_lsubst_sz; optional * m_lsubst; name_generator m_ngen; name_map * m_name_subst; match_plugin const * m_plugin; buffer> m_stack; buffer m_scopes; bool * m_assigned; // mark if matcher assigned anything void push() { m_scopes.push_back(m_stack.size()); } void pop() { lean_assert(!m_scopes.empty()); unsigned old_sz = m_scopes.back(); while (m_stack.size() > old_sz) { auto p = m_stack.back(); if (p.first) m_esubst[p.second] = none_expr(); else m_lsubst[p.second] = none_level(); m_stack.pop_back(); } m_scopes.pop_back(); } void keep() { m_scopes.back() = m_stack.size(); } struct scope { match_fn & m; scope(match_fn & _m):m(_m) { m.push(); } ~scope() { m.pop(); } void keep() { m.keep(); } }; void _assign(expr const & p, expr const & t) { lean_assert(to_meta_idx(p) < m_esubst_sz); unsigned i = to_meta_idx(p); m_stack.emplace_back(true, i); m_esubst[i] = t; if (m_assigned) *m_assigned = true; } void _assign(level const & p, level const & l) { lean_assert(to_meta_idx(p) < m_lsubst_sz); unsigned i = to_meta_idx(p); m_stack.emplace_back(false, i); m_lsubst[i] = l; if (m_assigned) *m_assigned = true; } void throw_exception() const { throw exception("ill-formed higher-order matching problem"); } optional _get_subst(expr const & x) const { unsigned i = to_meta_idx(x); unsigned sz = m_esubst_sz; if (i >= sz) throw_exception(); return m_esubst[i]; } optional _get_subst(level const & x) const { unsigned i = to_meta_idx(x); if (i >= m_lsubst_sz) throw_exception(); return m_lsubst[i]; } virtual void assign(expr const & p, expr const & t) { return _assign(p, t); } virtual void assign(level const & p, level const & t) { return _assign(p, t); } virtual optional get_subst(expr const & x) const { return _get_subst(x); } virtual optional get_subst(level const & x) const { return _get_subst(x); } virtual name mk_name() { return m_ngen.next(); } bool args_are_distinct_locals(buffer const & args) { for (auto it = args.begin(); it != args.end(); it++) { if (!is_local(*it) || contains_local(*it, args.begin(), it)) return false; } return true; } optional proj(expr const & t, buffer const & args) { expr r = Fun(args, t); if (has_local(r)) return none_expr(); else return some_expr(r); } bool try_plugin(expr const & p, expr const & t) { if (!m_plugin) return false; return m_plugin->on_failure(p, t, *this); } bool match_binding_core(expr p, expr t) { lean_assert(is_binding(p) && is_binding(t)); buffer ls; expr_kind k = p.kind(); while (p.kind() == k && t.kind() == k) { if (m_name_subst) (*m_name_subst)[binding_name(p)] = binding_name(t); expr p_d = instantiate_rev(binding_domain(p), ls.size(), ls.data()); expr t_d = instantiate_rev(binding_domain(t), ls.size(), ls.data()); if (!_match(p_d, t_d)) return false; expr l = mk_local(m_ngen.next(), binding_name(t), t_d, binding_info(t)); ls.push_back(l); p = binding_body(p); t = binding_body(t); } p = instantiate_rev(p, ls.size(), ls.data()); t = instantiate_rev(t, ls.size(), ls.data()); return _match(p, t); } bool match_binding(expr const & p, expr const & t) { { scope s(*this); if (match_binding_core(p, t)) { s.keep(); return true; } } return try_plugin(p, t); } bool match_macro_core(expr const & p, expr const & t) { if (macro_def(p) == macro_def(t) && macro_num_args(p) == macro_num_args(t)) { for (unsigned i = 0; i < macro_num_args(p); i++) { if (!_match(macro_arg(p, i), macro_arg(t, i))) return false; } return true; } return false; } bool match_macro(expr const & p, expr const & t) { { scope s(*this); if (match_macro_core(p, t)) { s.keep(); return true; } } return try_plugin(p, t); } bool match_app_core(expr const & p, expr const & t) { buffer p_args; buffer t_args; expr const & p_fn = get_app_args(p, p_args); expr const & t_fn = get_app_args(t, t_args); if (p_args.size() != t_args.size()) return false; if (!match_core(p_fn, t_fn)) return false; for (unsigned i = 0; i < p_args.size(); i++) { if (!_match(p_args[i], t_args[i])) return false; } return true; } bool match_app(expr const & p, expr const & t) { { scope s(*this); if (match_app_core(p, t)) { s.keep(); return true; } } return try_plugin(p, t); } bool match_level_core(level const & p, level const & l) { if (p == l) return true; if (p.kind() == l.kind()) { switch (p.kind()) { case level_kind::Zero: lean_unreachable(); // LCOV_EXCL_LINE case level_kind::Param: case level_kind::Global: case level_kind::Meta: return false; case level_kind::Succ: return match_level(succ_of(p), succ_of(l)); case level_kind::Max: { scope s(*this); if (match_level(max_lhs(p), max_lhs(l)) && match_level(max_rhs(p), max_rhs(l))) { s.keep(); return true; } break; } case level_kind::IMax: { scope s(*this); if (match_level(imax_lhs(p), imax_lhs(l)) && match_level(imax_rhs(p), imax_rhs(l))) { s.keep(); return true; } break; }} } level p1 = normalize(p); level l1 = normalize(l); if (p1 != p || l1 != l) return match_level(p1, l1); else return false; } bool match_level(level const & p, level const & l) { if (is_idx_meta_univ(p)) { auto s = _get_subst(p); if (s) { return match_level_core(*s, l); } else { _assign(p, l); return true; } } return match_level_core(p, l); } bool match_levels(levels ps, levels ls) { while (ps && ls) { if (!match_level(head(ps), head(ls))) return false; ps = tail(ps); ls = tail(ls); } return true; } bool match_constant(expr const & p, expr const & t) { if (const_name(p) == const_name(t)) return match_levels(const_levels(p), const_levels(t)); else return try_plugin(p, t); } static expr eta(expr const & e) { unsigned num = 0; expr it = e; while (is_lambda(it)) { num++; it = binding_body(it); } if (num == 0) return e; for (unsigned i = 0; i < num; i++) { if (!is_app(it)) return e; if (!is_var(app_arg(it), i)) { return e; } it = app_fn(it); } if (!closed(it)) return e; return it; } bool match_core(expr const & p, expr const & t) { if (p.kind() != t.kind()) { if (is_lambda(p) != is_lambda(t)) { expr new_p = eta(p); expr new_t = eta(t); if (is_lambda(new_p) == is_lambda(new_t)) return _match(new_p, new_t); } return try_plugin(p, t); } if (m_plugin) { switch (m_plugin->pre(p, t, *this)) { case l_true: return true; case l_false: return false; case l_undef: break; } } switch (p.kind()) { case expr_kind::Local: case expr_kind::Meta: return mlocal_name(p) == mlocal_name(t); case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Constant: return match_constant(p, t); case expr_kind::Sort: return match_level(sort_level(p), sort_level(t)); case expr_kind::Lambda: case expr_kind::Pi: return match_binding(p, t); case expr_kind::Macro: return match_macro(p, t); case expr_kind::App: return match_app(p, t); } lean_unreachable(); // LCOV_EXCL_LINE } bool _match(expr const & p, expr const & t) { if (is_idx_meta(p)) { auto s = _get_subst(p); if (s) { return match_core(*s, t); } else { _assign(p, t); return true; } } else if (is_app(p)) { buffer args; expr const & f = get_app_rev_args(p, args); if (is_idx_meta(f)) { // higher-order pattern case auto s = _get_subst(f); if (s) { expr new_p = apply_beta(*s, args.size(), args.data()); return match_core(new_p, t); } if (args_are_distinct_locals(args)) { optional new_t = proj(t, args); if (new_t) { _assign(f, *new_t); return true; } } } // fallback to the first-order case } return match_core(p, t); } public: match_fn(unsigned lsubst_sz, optional * lsubst, unsigned esubst_sz, optional * esubst, name_generator const & ngen, name_map * name_subst, match_plugin const * plugin, bool * assigned): m_esubst_sz(esubst_sz), m_esubst(esubst), m_lsubst_sz(lsubst_sz), m_lsubst(lsubst), m_ngen(ngen), m_name_subst(name_subst), m_plugin(plugin), m_assigned(assigned) {} virtual bool match(expr const & p, expr const & t) { return _match(p, t); } }; bool match(expr const & p, expr const & t, unsigned lsubst_sz, optional * lsubst, unsigned esubst_sz, optional * esubst, name const * prefix, name_map * name_subst, match_plugin const * plugin, bool * assigned) { lean_assert(closed(t)); lean_assert(closed(p)); if (prefix) return match_fn(lsubst_sz, lsubst, esubst_sz, esubst, name_generator(*prefix), name_subst, plugin, assigned).match(p, t); else return match_fn(lsubst_sz, lsubst, esubst_sz, esubst, name_generator(*g_tmp_prefix), name_subst, plugin, assigned).match(p, t); } bool match(expr const & p, expr const & t, buffer> & lsubst, buffer> & esubst, name const * prefix, name_map * name_subst, match_plugin const * plugin, bool * assigned) { return match(p, t, lsubst.size(), lsubst.data(), esubst.size(), esubst.data(), prefix, name_subst, plugin, assigned); } bool whnf_match_plugin::on_failure(expr const & p, expr const & t, match_context & ctx) const { try { constraint_seq cs; expr p1 = m_tc.whnf(p, cs); expr t1 = m_tc.whnf(t, cs); return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1); } catch (exception&) { return false; } } static unsigned updt_idx_meta_univ_range(level const & l, unsigned r) { for_each(l, [&](level const & l) { if (!has_meta(l)) return false; if (is_idx_meta_univ(l)) { unsigned new_r = to_meta_idx(l) + 1; if (new_r > r) r = new_r; } return true; }); return r; } static pair get_idx_meta_univ_ranges(expr const & e) { if (!has_metavar(e)) return mk_pair(0, 0); unsigned rlvl = 0; unsigned rexp = 0; for_each(e, [&](expr const & e, unsigned) { if (!has_metavar(e)) return false; if (is_constant(e)) for (level const & l : const_levels(e)) rlvl = updt_idx_meta_univ_range(l, rlvl); if (is_sort(e)) rlvl = updt_idx_meta_univ_range(sort_level(e), rlvl); if (is_idx_meta(e)) rexp = std::max(to_meta_idx(e) + 1, rexp); return true; }); return mk_pair(rlvl, rexp); } typedef std::shared_ptr match_plugin_ref; DECL_UDATA(match_plugin_ref) static const struct luaL_Reg match_plugin_ref_m[] = { {"__gc", match_plugin_ref_gc}, {0, 0} }; // version of whnf_match_plugin for Lua class whnf_match_plugin2 : public whnf_match_plugin { std::shared_ptr m_tc_ref; public: whnf_match_plugin2(std::shared_ptr & tc): whnf_match_plugin(*tc), m_tc_ref(tc) {} }; static int mk_whnf_match_plugin(lua_State * L) { return push_match_plugin_ref(L, match_plugin_ref(new whnf_match_plugin2(to_type_checker_ref(L, 1)))); } static int match(lua_State * L) { int nargs = lua_gettop(L); expr p = to_expr(L, 1); expr t = to_expr(L, 2); match_plugin * plugin = nullptr; if (nargs >= 3) plugin = to_match_plugin_ref(L, 3).get(); if (!closed(t)) throw exception("higher-order pattern matching failure, input term must not contain free variables"); unsigned r1, r2; auto r1_r2 = get_idx_meta_univ_ranges(p); r1 = r1_r2.first; r2 = r1_r2.second; buffer> lsubst; buffer> esubst; lsubst.resize(r1); esubst.resize(r2); if (match(p, t, lsubst, esubst, nullptr, nullptr, plugin)) { lua_newtable(L); int i = 1; for (auto s : esubst) { if (s) push_expr(L, *s); else lua_pushboolean(L, false); lua_rawseti(L, -2, i); i = i + 1; } lua_newtable(L); i = 1; for (auto s : lsubst) { if (s) push_level(L, *s); else lua_pushboolean(L, false); lua_rawseti(L, -2, i); i = i + 1; } } else { lua_pushnil(L); lua_pushnil(L); } return 2; } static int mk_idx_meta_univ(lua_State * L) { return push_level(L, mk_idx_meta_univ(luaL_checkinteger(L, 1))); } static int mk_idx_meta(lua_State * L) { return push_expr(L, mk_idx_meta(luaL_checkinteger(L, 1), to_expr(L, 2))); } void open_match(lua_State * L) { luaL_newmetatable(L, match_plugin_ref_mt); lua_pushvalue(L, -1); lua_setfield(L, -2, "__index"); setfuncs(L, match_plugin_ref_m, 0); SET_GLOBAL_FUN(mk_whnf_match_plugin, "whnf_match_plugin"); SET_GLOBAL_FUN(match_plugin_ref_pred, "is_match_plugin"); SET_GLOBAL_FUN(mk_idx_meta_univ, "mk_idx_meta_univ"); SET_GLOBAL_FUN(mk_idx_meta, "mk_idx_meta"); SET_GLOBAL_FUN(match, "match"); } }