feat(library/unifier): implement flex-rigid case

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-23 11:00:35 -07:00
parent 68d55ef398
commit 7b188ea37e
2 changed files with 222 additions and 13 deletions

View file

@ -12,6 +12,7 @@ Author: Leonardo de Moura
#include "util/lazy_list_fn.h" #include "util/lazy_list_fn.h"
#include "kernel/for_each_fn.h" #include "kernel/for_each_fn.h"
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "kernel/instantiate.h"
#include "kernel/type_checker.h" #include "kernel/type_checker.h"
#include "library/unifier.h" #include "library/unifier.h"
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
@ -255,6 +256,12 @@ struct unifier_fn {
virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); } virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); }
}; };
struct ho_case_split : public case_split {
list<constraints> m_tail;
ho_case_split(unifier_fn & u, list<constraints> const & tail):case_split(u), m_tail(tail) {}
virtual bool next(unifier_fn & u) { return u.next_ho_case_split(*this); }
};
case_split_stack m_case_splits; case_split_stack m_case_splits;
optional<justification> m_conflict; //!< if different from none, then there is a conflict. optional<justification> m_conflict; //!< if different from none, then there is a conflict.
@ -282,6 +289,79 @@ struct unifier_fn {
void update_conflict(justification const & j) { m_conflict = j; } void update_conflict(justification const & j) { m_conflict = j; }
void reset_conflict() { m_conflict = optional<justification>(); lean_assert(!in_conflict()); } void reset_conflict() { m_conflict = optional<justification>(); lean_assert(!in_conflict()); }
/** \brief Given \c type of the form <tt>(Pi ctx, r)</tt>, return <tt>(Pi ctx, new_range)</tt> */
static expr replace_range(expr const & type, expr const & new_range) {
if (is_pi(type))
return update_binding(type, binding_domain(type), replace_range(binding_body(type), new_range));
else
return new_range;
}
/** \brief Return the "arity" of the given type. The arity is the number of nested pi-expressions. */
static unsigned get_arity(expr type) {
unsigned r = 0;
while (is_pi(type)) {
type = binding_body(type);
r++;
}
return r;
}
/** \brief Return the term (f #n-1 ... #0) */
static expr mk_app_vars(expr const & f, unsigned n) {
expr r = f;
unsigned i = n;
while (i > 0) {
--i;
r = r(mk_var(i));
}
return r;
}
/**
\brief Given a type \c t of the form
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
return a new metavariable \c m1 with type
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), Type.{u}</tt>
where \c u is a new universe metavariable.
*/
expr mk_aux_type_metavar_for(expr const & t) {
expr new_type = replace_range(t, mk_sort(mk_meta_univ(m_ngen.next())));
name n = m_ngen.next();
return mk_metavar(n, new_type);
}
/**
\brief Given a type \c t of the form
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
return a new metavariable \c m1 with type
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), (m2 x_1 ... x_n)</tt>
where \c m2 is a new metavariable with type
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), Type.{u}</tt>
where \c u is a new universe metavariable.
*/
expr mk_aux_metavar_for(expr const & t) {
unsigned num = get_arity(t);
expr r = mk_app_vars(mk_aux_type_metavar_for(t), num);
expr new_type = replace_range(t, r);
name n = m_ngen.next();
return mk_metavar(n, new_type);
}
/**
\brief Given t
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
return
<tt>fun (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), v</tt>
*/
expr mk_lambda_for(expr const & t, expr const & v) {
if (is_pi(t)) {
return mk_lambda(binding_name(t), binding_domain(t), mk_lambda_for(binding_body(t), v), binding_info(t));
} else {
return v;
}
}
/** /**
\brief Update occurrence index with entry <tt>m -> cidx</tt>, where \c m is the name of a metavariable, \brief Update occurrence index with entry <tt>m -> cidx</tt>, where \c m is the name of a metavariable,
and \c cidx is the index of a constraint that contains \c m. and \c cidx is the index of a constraint that contains \c m.
@ -306,6 +386,12 @@ struct unifier_fn {
/** \see add_occ */ /** \see add_occ */
void add_mlvl_occ(name const & m, unsigned cidx) { add_occ<false>(m, cidx); } void add_mlvl_occ(name const & m, unsigned cidx) { add_occ<false>(m, cidx); }
/**
\brief Update the indices \c m_mvar_occs and \c m_mlvl_occs.
For every metavariable name \c m in \c mlvl_occs and \c mvar_occs, add an entry to \c cidx.
\remark \c cidx is the index of some constraint in \c m_cnstrs.
*/
void add_occs(unsigned cidx, name_set const * mlvl_occs, name_set const * mvar_occs) { void add_occs(unsigned cidx, name_set const * mlvl_occs, name_set const * mvar_occs) {
if (mlvl_occs) { if (mlvl_occs) {
mlvl_occs->for_each([=](name const & m) { mlvl_occs->for_each([=](name const & m) {
@ -440,11 +526,9 @@ struct unifier_fn {
st = process_metavar_eq(rhs, lhs, new_jst); st = process_metavar_eq(rhs, lhs, new_jst);
if (st != Continue) return st == Assigned; if (st != Continue) return st == Assigned;
// Make sure the lhs/rhs are in weak-head-normal-form, when the other one is meta. // Make sure lhs/rhs are in weak-head-normal-form
if (is_meta(lhs)) rhs = m_tc.whnf(rhs);
rhs = m_tc.whnf(rhs); lhs = m_tc.whnf(lhs);
else if (is_meta(rhs))
lhs = m_tc.whnf(lhs);
// If lhs or rhs were updated, then invoke is_def_eq again. // If lhs or rhs were updated, then invoke is_def_eq again.
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) { if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
@ -678,14 +762,18 @@ struct unifier_fn {
} }
} }
bool process_flex_rigid(constraint const &) { bool next_ho_case_split(ho_case_split & cs) {
// TODO(Leo): if (!is_nil(cs.m_tail)) {
return true; cs.restore_state(*this);
} lean_assert(!in_conflict());
constraints c = head(cs.m_tail);
bool process_flex_flex(constraint const &) { cs.m_tail = tail(cs.m_tail);
// TODO(Leo): return process_constraints(c, mk_assumption_justification(cs.m_assumption_idx));
return true; } else {
// update conflict
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
return false;
}
} }
/** \brief Return true iff \c c is a flex-rigid constraint. */ /** \brief Return true iff \c c is a flex-rigid constraint. */
@ -702,6 +790,101 @@ struct unifier_fn {
return is_eq_cnstr(c) && is_meta(cnstr_lhs_expr(c)) && is_meta(cnstr_rhs_expr(c)); return is_eq_cnstr(c) && is_meta(cnstr_lhs_expr(c)) && is_meta(cnstr_rhs_expr(c));
} }
/** \brief Process a flex rigid constraint */
bool process_flex_rigid(expr const & lhs, expr const & rhs, justification const & j) {
lean_assert(is_meta(lhs));
lean_assert(!is_meta(rhs));
buffer<expr> margs;
expr m = get_app_args(lhs, margs);
expr mtype = mlocal_type(m);
buffer<constraints> alts;
lean_assert(!is_var(rhs)); // rhs can't be a free variable (this is an invariant of the approach we are using).
lean_assert(!is_let(rhs)); // rhs can't be a let, since the rhs is in whnf.
// Add Projections to alts
unsigned vidx = margs.size() - 1;
for (expr const & marg : margs) {
if (!is_local(marg) && !is_local(rhs)) {
// if rhs is not local, then we only add projections for the nonlocal arguments of lhs
constraint c1 = mk_eq_cnstr(marg, rhs, j);
constraint c2 = mk_eq_cnstr(m, mk_lambda_for(mtype, mk_var(vidx)), j);
alts.push_back(constraints({c1, c2}));
} else if (is_local(marg) && marg == rhs) {
// if the argument is local, and rhs is equal to it, then we also add a projection
constraint c1 = mk_eq_cnstr(m, mk_lambda_for(mtype, mk_var(vidx)), j);
alts.push_back(constraints(c1));
}
vidx--;
}
// Add Imitation to alts
buffer<constraint> cs;
bool imitate = true;
if (is_app(rhs)) {
buffer<expr> rargs;
expr f = get_app_args(rhs, rargs);
// create an auxiliary metavariable for each rhs argument
buffer<expr> sargs;
for (expr const & rarg : rargs) {
expr maux = mk_aux_metavar_for(mtype);
cs.push_back(mk_eq_cnstr(mk_app(maux, margs), rarg, j));
sargs.push_back(mk_app_vars(maux, margs.size()));
}
expr v = mk_app(f, sargs);
v = mk_lambda_for(mtype, v);
cs.push_back(mk_eq_cnstr(m, v, j));
} else if (is_binding(rhs)) {
expr maux1 = mk_aux_metavar_for(mtype);
cs.push_back(mk_eq_cnstr(mk_app(maux1, margs), binding_domain(rhs), j));
expr pi = mk_pi(binding_name(rhs), binding_domain(rhs), binding_body(rhs));
expr mtype2 = replace_range(mtype, pi); // trick for "extending" the context
expr maux2 = mk_aux_metavar_for(mtype2);
expr new_local = mk_local(m_ngen.next(), binding_name(rhs), binding_domain(rhs));
cs.push_back(mk_eq_cnstr(mk_app(mk_app(maux2, margs), new_local), instantiate(binding_body(rhs), new_local), j));
expr v = update_binding(rhs, mk_app_vars(maux1, margs.size()), mk_app_vars(maux2, margs.size() + 1));
v = mk_lambda_for(mtype, v);
cs.push_back(mk_eq_cnstr(m, v, j));
} else if (is_sort(rhs) || is_constant(rhs)) {
expr v = mk_lambda_for(mtype, rhs);
cs.push_back(mk_eq_cnstr(m, v, j));
} else if (is_local(rhs)) {
// We don't imitate when the right-hand-side is a local constant.
// The term (fun (ctx), local) is not well-formed.
imitate = false;
} else {
// we don't support macros
lean_assert(is_macro(rhs));
imitate = false;
}
if (imitate)
alts.push_back(to_list(cs.begin(), cs.end()));
if (alts.empty()) {
set_conflict(j);
return false;
} else if (alts.size() == 1) {
// we don't need to create a backtracking point
return process_constraints(alts[0], justification());
} else {
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
return process_constraints(alts[0], a);
}
}
/** \brief Process a flex rigid constraint */
bool process_flex_rigid(constraint const & c) {
lean_assert(is_flex_rigid(c));
if (is_meta(cnstr_lhs_expr(c)))
return process_flex_rigid(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification());
else
return process_flex_rigid(cnstr_rhs_expr(c), cnstr_lhs_expr(c), c.get_justification());
}
bool process_flex_flex(constraint const &) {
// TODO(Leo):
return true;
}
/** \brief Process the next constraint in the constraint queue m_cnstrs */ /** \brief Process the next constraint in the constraint queue m_cnstrs */
bool process_next() { bool process_next() {
lean_assert(!m_cnstrs.empty()); lean_assert(!m_cnstrs.empty());

26
tests/lua/unify4.lua Normal file
View file

@ -0,0 +1,26 @@
function test_unify(env, m, lhs, rhs, num_s)
print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s))
local ss = unify(env, lhs, rhs)
local n = 0
for s in ss do
print("solution: " .. tostring(s:instantiate(m)))
n = n + 1
end
if num_s ~= n then print("n: " .. n) end
assert(num_s == n)
end
local env = environment()
env = add_decl(env, mk_var_decl("N", Type))
local N = Const("N")
env = add_decl(env, mk_var_decl("f", mk_arrow(N, N, N)))
env = add_decl(env, mk_var_decl("a", N))
local f = Const("f")
local a = Const("a")
local l1 = mk_local("l1", "x", N)
local l2 = mk_local("l2", "y", N)
local l3 = mk_local("l3", "z", N)
local m = mk_metavar("m", mk_arrow(N, N, mk_metavar("m_type", mk_arrow(N, N, mk_sort(mk_meta_univ("u"))))(Var(1), Var(0))))
test_unify(env, m, m(l1, l1), f(f(a, l1), l1), 4)
print("-----------------")
test_unify(env, m, m(l1, l1), mk_lambda("z", Bool, f(l1, f(Var(0), a))), 2)