feat(library/unifier): implement flex-rigid case
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
68d55ef398
commit
7b188ea37e
2 changed files with 222 additions and 13 deletions
|
@ -12,6 +12,7 @@ Author: Leonardo de Moura
|
|||
#include "util/lazy_list_fn.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "library/unifier.h"
|
||||
#include "library/kernel_bindings.h"
|
||||
|
@ -255,6 +256,12 @@ struct unifier_fn {
|
|||
virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); }
|
||||
};
|
||||
|
||||
struct ho_case_split : public case_split {
|
||||
list<constraints> m_tail;
|
||||
ho_case_split(unifier_fn & u, list<constraints> const & tail):case_split(u), m_tail(tail) {}
|
||||
virtual bool next(unifier_fn & u) { return u.next_ho_case_split(*this); }
|
||||
};
|
||||
|
||||
case_split_stack m_case_splits;
|
||||
optional<justification> m_conflict; //!< if different from none, then there is a conflict.
|
||||
|
||||
|
@ -282,6 +289,79 @@ struct unifier_fn {
|
|||
void update_conflict(justification const & j) { m_conflict = j; }
|
||||
void reset_conflict() { m_conflict = optional<justification>(); lean_assert(!in_conflict()); }
|
||||
|
||||
/** \brief Given \c type of the form <tt>(Pi ctx, r)</tt>, return <tt>(Pi ctx, new_range)</tt> */
|
||||
static expr replace_range(expr const & type, expr const & new_range) {
|
||||
if (is_pi(type))
|
||||
return update_binding(type, binding_domain(type), replace_range(binding_body(type), new_range));
|
||||
else
|
||||
return new_range;
|
||||
}
|
||||
|
||||
/** \brief Return the "arity" of the given type. The arity is the number of nested pi-expressions. */
|
||||
static unsigned get_arity(expr type) {
|
||||
unsigned r = 0;
|
||||
while (is_pi(type)) {
|
||||
type = binding_body(type);
|
||||
r++;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
/** \brief Return the term (f #n-1 ... #0) */
|
||||
static expr mk_app_vars(expr const & f, unsigned n) {
|
||||
expr r = f;
|
||||
unsigned i = n;
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = r(mk_var(i));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Given a type \c t of the form
|
||||
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
|
||||
return a new metavariable \c m1 with type
|
||||
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), Type.{u}</tt>
|
||||
where \c u is a new universe metavariable.
|
||||
*/
|
||||
expr mk_aux_type_metavar_for(expr const & t) {
|
||||
expr new_type = replace_range(t, mk_sort(mk_meta_univ(m_ngen.next())));
|
||||
name n = m_ngen.next();
|
||||
return mk_metavar(n, new_type);
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Given a type \c t of the form
|
||||
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
|
||||
return a new metavariable \c m1 with type
|
||||
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), (m2 x_1 ... x_n)</tt>
|
||||
where \c m2 is a new metavariable with type
|
||||
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), Type.{u}</tt>
|
||||
where \c u is a new universe metavariable.
|
||||
*/
|
||||
expr mk_aux_metavar_for(expr const & t) {
|
||||
unsigned num = get_arity(t);
|
||||
expr r = mk_app_vars(mk_aux_type_metavar_for(t), num);
|
||||
expr new_type = replace_range(t, r);
|
||||
name n = m_ngen.next();
|
||||
return mk_metavar(n, new_type);
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Given t
|
||||
<tt>Pi (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), B[x_1, ..., x_n]</tt>
|
||||
return
|
||||
<tt>fun (x_1 : A_1) ... (x_n : A_n[x_1, ..., x_{n-1}]), v</tt>
|
||||
*/
|
||||
expr mk_lambda_for(expr const & t, expr const & v) {
|
||||
if (is_pi(t)) {
|
||||
return mk_lambda(binding_name(t), binding_domain(t), mk_lambda_for(binding_body(t), v), binding_info(t));
|
||||
} else {
|
||||
return v;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Update occurrence index with entry <tt>m -> cidx</tt>, where \c m is the name of a metavariable,
|
||||
and \c cidx is the index of a constraint that contains \c m.
|
||||
|
@ -306,6 +386,12 @@ struct unifier_fn {
|
|||
/** \see add_occ */
|
||||
void add_mlvl_occ(name const & m, unsigned cidx) { add_occ<false>(m, cidx); }
|
||||
|
||||
/**
|
||||
\brief Update the indices \c m_mvar_occs and \c m_mlvl_occs.
|
||||
For every metavariable name \c m in \c mlvl_occs and \c mvar_occs, add an entry to \c cidx.
|
||||
|
||||
\remark \c cidx is the index of some constraint in \c m_cnstrs.
|
||||
*/
|
||||
void add_occs(unsigned cidx, name_set const * mlvl_occs, name_set const * mvar_occs) {
|
||||
if (mlvl_occs) {
|
||||
mlvl_occs->for_each([=](name const & m) {
|
||||
|
@ -440,11 +526,9 @@ struct unifier_fn {
|
|||
st = process_metavar_eq(rhs, lhs, new_jst);
|
||||
if (st != Continue) return st == Assigned;
|
||||
|
||||
// Make sure the lhs/rhs are in weak-head-normal-form, when the other one is meta.
|
||||
if (is_meta(lhs))
|
||||
rhs = m_tc.whnf(rhs);
|
||||
else if (is_meta(rhs))
|
||||
lhs = m_tc.whnf(lhs);
|
||||
// Make sure lhs/rhs are in weak-head-normal-form
|
||||
rhs = m_tc.whnf(rhs);
|
||||
lhs = m_tc.whnf(lhs);
|
||||
|
||||
// If lhs or rhs were updated, then invoke is_def_eq again.
|
||||
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
|
||||
|
@ -678,14 +762,18 @@ struct unifier_fn {
|
|||
}
|
||||
}
|
||||
|
||||
bool process_flex_rigid(constraint const &) {
|
||||
// TODO(Leo):
|
||||
return true;
|
||||
}
|
||||
|
||||
bool process_flex_flex(constraint const &) {
|
||||
// TODO(Leo):
|
||||
return true;
|
||||
bool next_ho_case_split(ho_case_split & cs) {
|
||||
if (!is_nil(cs.m_tail)) {
|
||||
cs.restore_state(*this);
|
||||
lean_assert(!in_conflict());
|
||||
constraints c = head(cs.m_tail);
|
||||
cs.m_tail = tail(cs.m_tail);
|
||||
return process_constraints(c, mk_assumption_justification(cs.m_assumption_idx));
|
||||
} else {
|
||||
// update conflict
|
||||
update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c c is a flex-rigid constraint. */
|
||||
|
@ -702,6 +790,101 @@ struct unifier_fn {
|
|||
return is_eq_cnstr(c) && is_meta(cnstr_lhs_expr(c)) && is_meta(cnstr_rhs_expr(c));
|
||||
}
|
||||
|
||||
/** \brief Process a flex rigid constraint */
|
||||
bool process_flex_rigid(expr const & lhs, expr const & rhs, justification const & j) {
|
||||
lean_assert(is_meta(lhs));
|
||||
lean_assert(!is_meta(rhs));
|
||||
buffer<expr> margs;
|
||||
expr m = get_app_args(lhs, margs);
|
||||
expr mtype = mlocal_type(m);
|
||||
buffer<constraints> alts;
|
||||
lean_assert(!is_var(rhs)); // rhs can't be a free variable (this is an invariant of the approach we are using).
|
||||
lean_assert(!is_let(rhs)); // rhs can't be a let, since the rhs is in whnf.
|
||||
// Add Projections to alts
|
||||
unsigned vidx = margs.size() - 1;
|
||||
for (expr const & marg : margs) {
|
||||
if (!is_local(marg) && !is_local(rhs)) {
|
||||
// if rhs is not local, then we only add projections for the nonlocal arguments of lhs
|
||||
constraint c1 = mk_eq_cnstr(marg, rhs, j);
|
||||
constraint c2 = mk_eq_cnstr(m, mk_lambda_for(mtype, mk_var(vidx)), j);
|
||||
alts.push_back(constraints({c1, c2}));
|
||||
} else if (is_local(marg) && marg == rhs) {
|
||||
// if the argument is local, and rhs is equal to it, then we also add a projection
|
||||
constraint c1 = mk_eq_cnstr(m, mk_lambda_for(mtype, mk_var(vidx)), j);
|
||||
alts.push_back(constraints(c1));
|
||||
}
|
||||
vidx--;
|
||||
}
|
||||
// Add Imitation to alts
|
||||
buffer<constraint> cs;
|
||||
bool imitate = true;
|
||||
if (is_app(rhs)) {
|
||||
buffer<expr> rargs;
|
||||
expr f = get_app_args(rhs, rargs);
|
||||
// create an auxiliary metavariable for each rhs argument
|
||||
buffer<expr> sargs;
|
||||
for (expr const & rarg : rargs) {
|
||||
expr maux = mk_aux_metavar_for(mtype);
|
||||
cs.push_back(mk_eq_cnstr(mk_app(maux, margs), rarg, j));
|
||||
sargs.push_back(mk_app_vars(maux, margs.size()));
|
||||
}
|
||||
expr v = mk_app(f, sargs);
|
||||
v = mk_lambda_for(mtype, v);
|
||||
cs.push_back(mk_eq_cnstr(m, v, j));
|
||||
} else if (is_binding(rhs)) {
|
||||
expr maux1 = mk_aux_metavar_for(mtype);
|
||||
cs.push_back(mk_eq_cnstr(mk_app(maux1, margs), binding_domain(rhs), j));
|
||||
expr pi = mk_pi(binding_name(rhs), binding_domain(rhs), binding_body(rhs));
|
||||
expr mtype2 = replace_range(mtype, pi); // trick for "extending" the context
|
||||
expr maux2 = mk_aux_metavar_for(mtype2);
|
||||
expr new_local = mk_local(m_ngen.next(), binding_name(rhs), binding_domain(rhs));
|
||||
cs.push_back(mk_eq_cnstr(mk_app(mk_app(maux2, margs), new_local), instantiate(binding_body(rhs), new_local), j));
|
||||
expr v = update_binding(rhs, mk_app_vars(maux1, margs.size()), mk_app_vars(maux2, margs.size() + 1));
|
||||
v = mk_lambda_for(mtype, v);
|
||||
cs.push_back(mk_eq_cnstr(m, v, j));
|
||||
} else if (is_sort(rhs) || is_constant(rhs)) {
|
||||
expr v = mk_lambda_for(mtype, rhs);
|
||||
cs.push_back(mk_eq_cnstr(m, v, j));
|
||||
} else if (is_local(rhs)) {
|
||||
// We don't imitate when the right-hand-side is a local constant.
|
||||
// The term (fun (ctx), local) is not well-formed.
|
||||
imitate = false;
|
||||
} else {
|
||||
// we don't support macros
|
||||
lean_assert(is_macro(rhs));
|
||||
imitate = false;
|
||||
}
|
||||
if (imitate)
|
||||
alts.push_back(to_list(cs.begin(), cs.end()));
|
||||
|
||||
if (alts.empty()) {
|
||||
set_conflict(j);
|
||||
return false;
|
||||
} else if (alts.size() == 1) {
|
||||
// we don't need to create a backtracking point
|
||||
return process_constraints(alts[0], justification());
|
||||
} else {
|
||||
justification a = mk_assumption_justification(m_next_assumption_idx);
|
||||
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
|
||||
return process_constraints(alts[0], a);
|
||||
}
|
||||
}
|
||||
|
||||
/** \brief Process a flex rigid constraint */
|
||||
bool process_flex_rigid(constraint const & c) {
|
||||
lean_assert(is_flex_rigid(c));
|
||||
if (is_meta(cnstr_lhs_expr(c)))
|
||||
return process_flex_rigid(cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification());
|
||||
else
|
||||
return process_flex_rigid(cnstr_rhs_expr(c), cnstr_lhs_expr(c), c.get_justification());
|
||||
}
|
||||
|
||||
bool process_flex_flex(constraint const &) {
|
||||
// TODO(Leo):
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/** \brief Process the next constraint in the constraint queue m_cnstrs */
|
||||
bool process_next() {
|
||||
lean_assert(!m_cnstrs.empty());
|
||||
|
|
26
tests/lua/unify4.lua
Normal file
26
tests/lua/unify4.lua
Normal file
|
@ -0,0 +1,26 @@
|
|||
function test_unify(env, m, lhs, rhs, num_s)
|
||||
print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s))
|
||||
local ss = unify(env, lhs, rhs)
|
||||
local n = 0
|
||||
for s in ss do
|
||||
print("solution: " .. tostring(s:instantiate(m)))
|
||||
n = n + 1
|
||||
end
|
||||
if num_s ~= n then print("n: " .. n) end
|
||||
assert(num_s == n)
|
||||
end
|
||||
|
||||
local env = environment()
|
||||
env = add_decl(env, mk_var_decl("N", Type))
|
||||
local N = Const("N")
|
||||
env = add_decl(env, mk_var_decl("f", mk_arrow(N, N, N)))
|
||||
env = add_decl(env, mk_var_decl("a", N))
|
||||
local f = Const("f")
|
||||
local a = Const("a")
|
||||
local l1 = mk_local("l1", "x", N)
|
||||
local l2 = mk_local("l2", "y", N)
|
||||
local l3 = mk_local("l3", "z", N)
|
||||
local m = mk_metavar("m", mk_arrow(N, N, mk_metavar("m_type", mk_arrow(N, N, mk_sort(mk_meta_univ("u"))))(Var(1), Var(0))))
|
||||
test_unify(env, m, m(l1, l1), f(f(a, l1), l1), 4)
|
||||
print("-----------------")
|
||||
test_unify(env, m, m(l1, l1), mk_lambda("z", Bool, f(l1, f(Var(0), a))), 2)
|
Loading…
Reference in a new issue