From 7b188ea37e470af086ebb997804ad413735ec6ea Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 23 Jun 2014 11:00:35 -0700 Subject: [PATCH] feat(library/unifier): implement flex-rigid case Signed-off-by: Leonardo de Moura --- src/library/unifier.cpp | 209 +++++++++++++++++++++++++++++++++++++--- tests/lua/unify4.lua | 26 +++++ 2 files changed, 222 insertions(+), 13 deletions(-) create mode 100644 tests/lua/unify4.lua diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index a2271b7d9..6beed3549 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -12,6 +12,7 @@ Author: Leonardo de Moura #include "util/lazy_list_fn.h" #include "kernel/for_each_fn.h" #include "kernel/abstract.h" +#include "kernel/instantiate.h" #include "kernel/type_checker.h" #include "library/unifier.h" #include "library/kernel_bindings.h" @@ -255,6 +256,12 @@ struct unifier_fn { virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); } }; + struct ho_case_split : public case_split { + list m_tail; + ho_case_split(unifier_fn & u, list const & tail):case_split(u), m_tail(tail) {} + virtual bool next(unifier_fn & u) { return u.next_ho_case_split(*this); } + }; + case_split_stack m_case_splits; optional m_conflict; //!< if different from none, then there is a conflict. @@ -282,6 +289,79 @@ struct unifier_fn { void update_conflict(justification const & j) { m_conflict = j; } void reset_conflict() { m_conflict = optional(); lean_assert(!in_conflict()); } + /** \brief Given \c type of the form (Pi ctx, r), return (Pi ctx, new_range) */ + static expr replace_range(expr const & type, expr const & new_range) { + if (is_pi(type)) + return update_binding(type, binding_domain(type), replace_range(binding_body(type), new_range)); + else + return new_range; + } + + /** \brief Return the "arity" of the given type. The arity is the number of nested pi-expressions. */ + static unsigned get_arity(expr type) { + unsigned r = 0; + while (is_pi(type)) { + type = binding_body(type); + r++; + } + return r; + } + + /** \brief Return the term (f #n-1 ... #0) */ + static expr mk_app_vars(expr const & f, unsigned n) { + expr r = f; + unsigned i = n; + while (i > 0) { + --i; + r = r(mk_var(i)); + } + return r; + } + + /** + \brief Given a type \c t of the form + Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n] + return a new metavariable \c m1 with type + Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), Type.{u} + where \c u is a new universe metavariable. + */ + expr mk_aux_type_metavar_for(expr const & t) { + expr new_type = replace_range(t, mk_sort(mk_meta_univ(m_ngen.next()))); + name n = m_ngen.next(); + return mk_metavar(n, new_type); + } + + /** + \brief Given a type \c t of the form + Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n] + return a new metavariable \c m1 with type + Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), (m2 x_1 ... x_n) + where \c m2 is a new metavariable with type + Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), Type.{u} + where \c u is a new universe metavariable. + */ + expr mk_aux_metavar_for(expr const & t) { + unsigned num = get_arity(t); + expr r = mk_app_vars(mk_aux_type_metavar_for(t), num); + expr new_type = replace_range(t, r); + name n = m_ngen.next(); + return mk_metavar(n, new_type); + } + + /** + \brief Given t + Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n] + return + fun (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), v + */ + expr mk_lambda_for(expr const & t, expr const & v) { + if (is_pi(t)) { + return mk_lambda(binding_name(t), binding_domain(t), mk_lambda_for(binding_body(t), v), binding_info(t)); + } else { + return v; + } + } + /** \brief Update occurrence index with entry m -> cidx, where \c m is the name of a metavariable, and \c cidx is the index of a constraint that contains \c m. @@ -306,6 +386,12 @@ struct unifier_fn { /** \see add_occ */ void add_mlvl_occ(name const & m, unsigned cidx) { add_occ(m, cidx); } + /** + \brief Update the indices \c m_mvar_occs and \c m_mlvl_occs. + For every metavariable name \c m in \c mlvl_occs and \c mvar_occs, add an entry to \c cidx. + + \remark \c cidx is the index of some constraint in \c m_cnstrs. + */ void add_occs(unsigned cidx, name_set const * mlvl_occs, name_set const * mvar_occs) { if (mlvl_occs) { mlvl_occs->for_each([=](name const & m) { @@ -440,11 +526,9 @@ struct unifier_fn { st = process_metavar_eq(rhs, lhs, new_jst); if (st != Continue) return st == Assigned; - // Make sure the lhs/rhs are in weak-head-normal-form, when the other one is meta. - if (is_meta(lhs)) - rhs = m_tc.whnf(rhs); - else if (is_meta(rhs)) - lhs = m_tc.whnf(lhs); + // Make sure lhs/rhs are in weak-head-normal-form + rhs = m_tc.whnf(rhs); + lhs = m_tc.whnf(lhs); // If lhs or rhs were updated, then invoke is_def_eq again. if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) { @@ -678,14 +762,18 @@ struct unifier_fn { } } - bool process_flex_rigid(constraint const &) { - // TODO(Leo): - return true; - } - - bool process_flex_flex(constraint const &) { - // TODO(Leo): - return true; + bool next_ho_case_split(ho_case_split & cs) { + if (!is_nil(cs.m_tail)) { + cs.restore_state(*this); + lean_assert(!in_conflict()); + constraints c = head(cs.m_tail); + cs.m_tail = tail(cs.m_tail); + return process_constraints(c, mk_assumption_justification(cs.m_assumption_idx)); + } else { + // update conflict + update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications)); + return false; + } } /** \brief Return true iff \c c is a flex-rigid constraint. */ @@ -702,6 +790,101 @@ struct unifier_fn { return is_eq_cnstr(c) && is_meta(cnstr_lhs_expr(c)) && is_meta(cnstr_rhs_expr(c)); } + /** \brief Process a flex rigid constraint */ + bool process_flex_rigid(expr const & lhs, expr const & rhs, justification const & j) { + lean_assert(is_meta(lhs)); + lean_assert(!is_meta(rhs)); + buffer margs; + expr m = get_app_args(lhs, margs); + expr mtype = mlocal_type(m); + buffer alts; + lean_assert(!is_var(rhs)); // rhs can't be a free variable (this is an invariant of the approach we are using). + lean_assert(!is_let(rhs)); // rhs can't be a let, since the rhs is in whnf. + // Add Projections to alts + unsigned vidx = margs.size() - 1; + for (expr const & marg : margs) { + if (!is_local(marg) && !is_local(rhs)) { + // if rhs is not local, then we only add projections for the nonlocal arguments of lhs + constraint c1 = mk_eq_cnstr(marg, rhs, j); + constraint c2 = mk_eq_cnstr(m, mk_lambda_for(mtype, mk_var(vidx)), j); + alts.push_back(constraints({c1, c2})); + } else if (is_local(marg) && marg == rhs) { + // if the argument is local, and rhs is equal to it, then we also add a projection + constraint c1 = mk_eq_cnstr(m, mk_lambda_for(mtype, mk_var(vidx)), j); + alts.push_back(constraints(c1)); + } + vidx--; + } + // Add Imitation to alts + buffer cs; + bool imitate = true; + if (is_app(rhs)) { + buffer rargs; + expr f = get_app_args(rhs, rargs); + // create an auxiliary metavariable for each rhs argument + buffer sargs; + for (expr const & rarg : rargs) { + expr maux = mk_aux_metavar_for(mtype); + cs.push_back(mk_eq_cnstr(mk_app(maux, margs), rarg, j)); + sargs.push_back(mk_app_vars(maux, margs.size())); + } + expr v = mk_app(f, sargs); + v = mk_lambda_for(mtype, v); + cs.push_back(mk_eq_cnstr(m, v, j)); + } else if (is_binding(rhs)) { + expr maux1 = mk_aux_metavar_for(mtype); + cs.push_back(mk_eq_cnstr(mk_app(maux1, margs), binding_domain(rhs), j)); + expr pi = mk_pi(binding_name(rhs), binding_domain(rhs), binding_body(rhs)); + expr mtype2 = replace_range(mtype, pi); // trick for "extending" the context + expr maux2 = mk_aux_metavar_for(mtype2); + expr new_local = mk_local(m_ngen.next(), binding_name(rhs), binding_domain(rhs)); + cs.push_back(mk_eq_cnstr(mk_app(mk_app(maux2, margs), new_local), instantiate(binding_body(rhs), new_local), j)); + expr v = update_binding(rhs, mk_app_vars(maux1, margs.size()), mk_app_vars(maux2, margs.size() + 1)); + v = mk_lambda_for(mtype, v); + cs.push_back(mk_eq_cnstr(m, v, j)); + } else if (is_sort(rhs) || is_constant(rhs)) { + expr v = mk_lambda_for(mtype, rhs); + cs.push_back(mk_eq_cnstr(m, v, j)); + } else if (is_local(rhs)) { + // We don't imitate when the right-hand-side is a local constant. + // The term (fun (ctx), local) is not well-formed. + imitate = false; + } else { + // we don't support macros + lean_assert(is_macro(rhs)); + imitate = false; + } + if (imitate) + alts.push_back(to_list(cs.begin(), cs.end())); + + if (alts.empty()) { + set_conflict(j); + return false; + } else if (alts.size() == 1) { + // we don't need to create a backtracking point + return process_constraints(alts[0], justification()); + } else { + justification a = mk_assumption_justification(m_next_assumption_idx); + add_case_split(std::unique_ptr(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end())))); + return process_constraints(alts[0], a); + } + } + + /** \brief Process a flex rigid constraint */ + bool process_flex_rigid(constraint const & c) { + lean_assert(is_flex_rigid(c)); + if (is_meta(cnstr_lhs_expr(c))) + return process_flex_rigid(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification()); + else + return process_flex_rigid(cnstr_rhs_expr(c), cnstr_lhs_expr(c), c.get_justification()); + } + + bool process_flex_flex(constraint const &) { + // TODO(Leo): + return true; + } + + /** \brief Process the next constraint in the constraint queue m_cnstrs */ bool process_next() { lean_assert(!m_cnstrs.empty()); diff --git a/tests/lua/unify4.lua b/tests/lua/unify4.lua new file mode 100644 index 000000000..2d68d4654 --- /dev/null +++ b/tests/lua/unify4.lua @@ -0,0 +1,26 @@ +function test_unify(env, m, lhs, rhs, num_s) + print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s)) + local ss = unify(env, lhs, rhs) + local n = 0 + for s in ss do + print("solution: " .. tostring(s:instantiate(m))) + n = n + 1 + end + if num_s ~= n then print("n: " .. n) end + assert(num_s == n) +end + +local env = environment() +env = add_decl(env, mk_var_decl("N", Type)) +local N = Const("N") +env = add_decl(env, mk_var_decl("f", mk_arrow(N, N, N))) +env = add_decl(env, mk_var_decl("a", N)) +local f = Const("f") +local a = Const("a") +local l1 = mk_local("l1", "x", N) +local l2 = mk_local("l2", "y", N) +local l3 = mk_local("l3", "z", N) +local m = mk_metavar("m", mk_arrow(N, N, mk_metavar("m_type", mk_arrow(N, N, mk_sort(mk_meta_univ("u"))))(Var(1), Var(0)))) +test_unify(env, m, m(l1, l1), f(f(a, l1), l1), 4) +print("-----------------") +test_unify(env, m, m(l1, l1), mk_lambda("z", Bool, f(l1, f(Var(0), a))), 2)