feat(elaborator): add support for constraints of the form ?m[inst, ...] == t, fix bugs, add more tests

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-10-22 16:39:22 -07:00
parent 891d22b3de
commit 7ad256131e
2 changed files with 376 additions and 35 deletions

View file

@ -296,6 +296,15 @@ class elaborator::imp {
push_new_constraint(q, true, new_ctx, new_a, new_b, new_tr);
}
/**
\brief Auxiliary method for pushing a new constraint to the current constraint queue.
If \c is_eq is true, then a equality constraint is created, otherwise a convertability constraint is created.
*/
void push_new_constraint(bool is_eq, context const & new_ctx, expr const & new_a, expr const & new_b, trace const & new_tr) {
reset_quota();
push_new_constraint(m_state.m_queue, is_eq, new_ctx, new_a, new_b, new_tr);
}
/**
\brief Auxiliary method for pushing a new constraint to the current constraint queue.
The new constraint is based on the constraint \c c. The constraint \c c may be a equality or convertability constraint.
@ -405,15 +414,23 @@ class elaborator::imp {
}
} else {
local_entry const & me = head(metavar_lctx(a));
if (me.is_lift() && !has_free_var(b, me.s(), me.s() + me.n())) {
// Case 3
trace new_tr(new normalize_trace(c));
expr new_a = pop_meta_context(a);
expr new_b = lower_free_vars(b, me.s() + me.n(), me.n());
if (!is_lhs)
swap(new_a, new_b);
push_updated_constraint(c, new_a, new_b, new_tr);
return Processed;
if (me.is_lift()) {
if (!has_free_var(b, me.s(), me.s() + me.n())) {
// Case 3
trace new_tr(new normalize_trace(c));
expr new_a = pop_meta_context(a);
expr new_b = lower_free_vars(b, me.s() + me.n(), me.n());
context new_ctx = get_context(c).remove(me.s(), me.n());
if (!is_lhs)
swap(new_a, new_b);
push_new_constraint(is_eq(c), new_ctx, new_a, new_b, new_tr);
return Processed;
} else if (is_var(b)) {
// Failure, there is no way to unify
// ?m[lift:s:n, ...] with a variable in [s, s+n]
m_conflict = trace(new unification_failure_trace(c));
return Failed;
}
}
}
}
@ -691,7 +708,7 @@ class elaborator::imp {
expr proj = mk_lambda(arg_types, mk_var(num_a - i - 1));
expr new_a = arg(a, i);
expr new_b = b;
if (is_lhs)
if (!is_lhs)
swap(new_a, new_b);
push_new_constraint(new_state.m_queue, is_eq(c), ctx, new_a, new_b, new_assumption);
push_new_eq_constraint(new_state.m_queue, ctx, f_a, proj, new_assumption);
@ -755,6 +772,82 @@ class elaborator::imp {
}
}
/** \brief Return true if \c a is of the form ?m[inst:i t, ...] */
bool is_metavar_inst(expr const & a) const {
return is_metavar(a) && has_local_context(a) && head(metavar_lctx(a)).is_inst();
}
/**
\brief Process a constraint <tt>ctx |- a == b</tt> where \c a is of the form <tt>?m[(inst:i t), ...]</tt>.
We perform a "case split",
Case 1) ?m[...] == #i and t == b
Case 2) imitate b
*/
bool process_metavar_inst(expr const & a, expr const & b, bool is_lhs, unification_constraint const & c) {
if (is_metavar_inst(a) && !is_metavar_inst(b) && !is_meta_app(b)) {
context const & ctx = get_context(c);
local_context lctx = metavar_lctx(a);
unsigned i = head(lctx).s();
expr t = head(lctx).v();
std::unique_ptr<generic_case_split> new_cs(new generic_case_split(c, m_state));
{
// Case 1
state new_state(m_state);
trace new_assumption = mk_assumption();
// add ?m[...] == #1
push_new_eq_constraint(new_state.m_queue, ctx, pop_meta_context(a), mk_var(i), new_assumption);
// add t == b (t << b)
expr new_a = t;
expr new_b = b;
if (!is_lhs)
swap(new_a, new_b);
push_new_constraint(new_state.m_queue, is_eq(c), ctx, new_a, new_b, new_assumption);
new_cs->push_back(new_state, new_assumption);
}
{
// Case 2
state new_state(m_state);
trace new_assumption = mk_assumption();
expr imitation;
if (is_app(b)) {
// Imitation for applications b == f(s_1, ..., s_k)
// mname <- f(?h_1, ..., ?h_k)
expr f_b = arg(b, 0);
unsigned num_b = num_args(b);
buffer<expr> imitation_args;
imitation_args.push_back(f_b);
for (unsigned i = 1; i < num_b; i++)
imitation_args.push_back(new_state.m_menv.mk_metavar(ctx));
imitation = mk_app(imitation_args.size(), imitation_args.data());
} else if (is_eq(b)) {
// Imitation for equality b == Eq(s1, s2)
// mname <- Eq(?h_1, ?h_2)
expr h_1 = new_state.m_menv.mk_metavar(ctx);
expr h_2 = new_state.m_menv.mk_metavar(ctx);
imitation = mk_eq(h_1, h_2);
} else if (is_abstraction(b)) {
// Lambdas and Pis
// Imitation for Lambdas and Pis, b == Fun(x:T) B
// mname <- Fun (x:?h_1) ?h_2 x)
expr h_1 = new_state.m_menv.mk_metavar(ctx);
expr h_2 = new_state.m_menv.mk_metavar(ctx);
imitation = update_abstraction(b, h_1, mk_app(h_2, Var(0)));
} else {
imitation = lift_free_vars(b, i, 1);
}
push_new_eq_constraint(new_state.m_queue, ctx, pop_meta_context(a), imitation, new_assumption);
new_cs->push_back(new_state, new_assumption);
}
bool r = new_cs->next(*this);
lean_assert(r);
m_case_splits.push_back(std::move(new_cs));
reset_quota();
return r;
} else {
return false;
}
}
/** \brief Process constraint of the form ctx |- a << ?m, where \c a is Type of Bool */
bool process_lower(expr const & a, expr const & b, unification_constraint const & c) {
if (is_convertible(c) && is_metavar(b) && (a == Bool || is_type(a))) {
@ -776,24 +869,6 @@ class elaborator::imp {
}
}
/**
\brief Process a constraints of the form:
- true == (t1 = t2)
- true << (t1 = t2)
\remark This method should be removed if we remove T == T ==> true normalization rule from the
kernel.
*/
bool process_true_eq(expr const & a, expr const & b, unification_constraint const & c) {
if (a == True && is_eq(b)) {
trace new_tr(new normalize_trace(c));
push_front(mk_eq_constraint(get_context(c), eq_lhs(b), eq_rhs(b), new_tr));
return true;
} else {
return false;
}
}
bool process_eq_convertible(context const & ctx, expr const & a, expr const & b, unification_constraint const & c) {
bool eq = is_eq(c);
if (a == b) {
@ -818,8 +893,7 @@ class elaborator::imp {
process_simple_ho_match(ctx, b, a, false, c))
return true;
if (process_true_eq(a, b, c) ||
process_true_eq(b, a, c))
if (!eq && a == Bool && is_type(b))
return true;
if (a.kind() == b.kind()) {
@ -895,6 +969,8 @@ class elaborator::imp {
// process expensive cases
if (process_meta_app(a, b, true, c) || process_meta_app(b, a, false, c))
return true;
if (process_metavar_inst(a, b, true, c) || process_metavar_inst(b, a, false, c))
return true;
}
if (m_quota < - static_cast<int>(m_state.m_queue.size())) {
@ -903,7 +979,7 @@ class elaborator::imp {
return true;
}
std::cout << "Postponed: "; display(std::cout, c);
// std::cout << "Postponed: "; display(std::cout, c);
push_back(c);
return true;
@ -932,7 +1008,7 @@ class elaborator::imp {
while (!m_case_splits.empty()) {
std::unique_ptr<case_split> & d = m_case_splits.back();
std::cout << "Assumption " << d->m_curr_assumption.pp(fmt, options(), nullptr, true) << "\n";
// std::cout << "Assumption " << d->m_curr_assumption.pp(fmt, options(), nullptr, true) << "\n";
if (depends_on(m_conflict, d->m_curr_assumption)) {
d->m_failed_traces.push_back(m_conflict);
if (d->next(*this)) {
@ -1035,7 +1111,7 @@ public:
}
} else {
unification_constraint c = q.front();
std::cout << "Processing, quota: " << m_quota << ", depth: " << m_case_splits.size() << " "; display(std::cout, c);
// std::cout << "Processing, quota: " << m_quota << ", depth: " << m_case_splits.size() << " "; display(std::cout, c);
q.pop_front();
if (!process(c)) {
resolve_conflict();

View file

@ -280,12 +280,29 @@ static expr elaborate(expr const & e, environment const & env) {
// Check elaborator success
static void success(expr const & e, expr const & expected, environment const & env) {
std::cout << "\n" << e << "\n------>\n";
std::cout << "\n" << e << "\n\n";
expr r = elaborate(e, env);
std::cout << r << "\n";
std::cout << "\n" << e << "\n------>\n" << r << "\n";
lean_assert(r == expected);
}
// Check elaborator failure
static void fails(expr const & e, environment const & env) {
try {
expr new_e = elaborate(e, env);
std::cout << "new_e: " << new_e << std::endl;
lean_unreachable();
} catch (exception &) {
}
}
// Check elaborator partial success (i.e., result still contain some metavariables */
static void unsolved(expr const & e, environment const & env) {
expr r = elaborate(e, env);
std::cout << "\n" << e << "\n------>\n" << r << "\n";
lean_assert(has_metavar(r));
}
static void tst7() {
environment env;
import_all(env);
@ -362,6 +379,243 @@ static void tst9() {
success(Refl(_, a), Refl(Nat, a), env);
}
static void tst10() {
environment env;
import_all(env);
expr Nat = Const("N");
env.add_var("N", Type());
expr R = Const("R");
env.add_var("R", Type());
env.add_var("a", Nat);
expr a = Const("a");
expr f = Const("f");
env.add_var("f", Nat >> ((R >> Nat) >> R));
expr x = Const("x");
expr y = Const("y");
expr z = Const("z");
success(Fun({{x, _}, {y, _}}, f(x, y)),
Fun({{x, Nat}, {y, R >> Nat}}, f(x, y)), env);
success(Fun({{x, _}, {y, _}, {z, _}}, Eq(f(x, y), f(x, z))),
Fun({{x, Nat}, {y, R >> Nat}, {z, R >> Nat}}, Eq(f(x, y), f(x, z))), env);
expr A = Const("A");
success(Fun({{A, Type()}, {x, _}, {y, _}, {z, _}}, Eq(f(x, y), f(x, z))),
Fun({{A, Type()}, {x, Nat}, {y, R >> Nat}, {z, R >> Nat}}, Eq(f(x, y), f(x, z))), env);
}
static void tst11() {
environment env;
import_all(env);
expr A = Const("A");
expr B = Const("B");
expr a = Const("a");
expr b = Const("b");
expr f = Const("f");
expr g = Const("g");
expr Nat = Const("N");
env.add_var("N", Type());
env.add_var("f", Pi({{A, Type()}, {a, A}, {b, A}}, A));
env.add_var("g", Nat >> Nat);
success(Fun({{a, _}, {b, _}}, g(f(_, a, b))),
Fun({{a, Nat}, {b, Nat}}, g(f(Nat, a, b))), env);
}
static void tst12() {
environment env;
import_all(env);
expr lst = Const("list");
expr nil = Const("nil");
expr cons = Const("cons");
expr N = Const("N");
expr A = Const("A");
expr f = Const("f");
expr l = Const("l");
expr a = Const("a");
env.add_var("N", Type());
env.add_var("list", Type() >> Type());
env.add_var("nil", Pi({A, Type()}, lst(A)));
env.add_var("cons", Pi({{A, Type()}, {a, A}, {l, lst(A)}}, lst(A)));
env.add_var("f", lst(N >> N) >> Bool);
success(Fun({a, _}, f(cons(_, a, cons(_, a, nil(_))))),
Fun({a, N >> N}, f(cons(N >> N, a, cons(N >> N, a, nil(N >> N))))), env);
}
static void tst13() {
environment env;
import_all(env);
expr B = Const("B");
expr A = Const("A");
expr x = Const("x");
expr f = Const("f");
env.add_var("f", Pi({B, Type()}, B >> B));
success(Fun({{A, Type()}, {B, Type()}, {x, _}}, f(B, x)),
Fun({{A, Type()}, {B, Type()}, {x, B}}, f(B, x)), env);
fails(Fun({{x, _}, {A, Type()}}, f(A, x)), env);
success(Fun({{A, Type()}, {x, _}}, f(A, x)),
Fun({{A, Type()}, {x, A}}, f(A, x)), env);
success(Fun({{A, Type()}, {B, Type()}, {x, _}}, f(A, x)),
Fun({{A, Type()}, {B, Type()}, {x, A}}, f(A, x)), env);
success(Fun({{A, Type()}, {B, Type()}, {x, _}}, Eq(f(B, x), f(_, x))),
Fun({{A, Type()}, {B, Type()}, {x, B}}, Eq(f(B, x), f(B, x))), env);
success(Fun({{A, Type()}, {B, Type()}, {x, _}}, Eq(f(B, x), f(_, x))),
Fun({{A, Type()}, {B, Type()}, {x, B}}, Eq(f(B, x), f(B, x))), env);
unsolved(Fun({{A, _}, {B, _}, {x, _}}, Eq(f(B, x), f(_, x))), env);
}
static void tst14() {
environment env;
import_all(env);
expr A = Const("A");
expr B = Const("B");
expr f = Const("f");
expr g = Const("g");
expr x = Const("x");
expr y = Const("y");
env.add_var("N", Type());
env.add_var("f", Pi({A, Type()}, A >> A));
expr N = Const("N");
success(Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, g(_, True, False)),
Fun({g, Pi({A, Type()}, A >> (A >> Bool))}, g(Bool, True, False)),
env);
success(Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(_, Bool, Bool)),
Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(Type(), Bool, Bool)),
env);
success(Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(_, Bool, N)),
Fun({g, Pi({A, TypeU}, A >> (A >> Bool))}, g(Type(), Bool, N)),
env);
success(Fun({g, Pi({A, Type()}, A >> (A >> Bool))},
g(_,
Fun({{x, _}, {y, _}}, Eq(f(_, x), f(_, y))),
Fun({{x, N}, {y, Bool}}, True))),
Fun({g, Pi({A, Type()}, A >> (A >> Bool))},
g((N >> (Bool >> Bool)),
Fun({{x, N}, {y, Bool}}, Eq(f(N, x), f(Bool, y))),
Fun({{x, N}, {y, Bool}}, True))), env);
success(Fun({g, Pi({A, Type()}, A >> (A >> Bool))},
g(_,
Fun({{x, N}, {y, _}}, Eq(f(_, x), f(_, y))),
Fun({{x, _}, {y, Bool}}, True))),
Fun({g, Pi({A, Type()}, A >> (A >> Bool))},
g((N >> (Bool >> Bool)),
Fun({{x, N}, {y, Bool}}, Eq(f(N, x), f(Bool, y))),
Fun({{x, N}, {y, Bool}}, True))), env);
}
static void tst15() {
environment env;
import_all(env);
expr A = Const("A");
expr B = Const("B");
expr C = Const("C");
expr a = Const("a");
expr b = Const("b");
expr eq = Const("eq");
env.add_var("eq", Pi({A, Type()}, A >> (A >> Bool)));
success(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, B}}, eq(_, a, b)),
Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}}, eq(B, a, b)), env);
success(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, A}}, eq(_, a, b)),
Fun({{A, Type()}, {B, Type()}, {a, A}, {b, A}}, eq(A, a, b)), env);
success(Fun({{A, Type()}, {B, Type()}, {a, A}, {b, _}}, eq(_, a, b)),
Fun({{A, Type()}, {B, Type()}, {a, A}, {b, A}}, eq(A, a, b)), env);
success(Fun({{A, Type()}, {B, Type()}, {a, B}, {b, _}}, eq(_, a, b)),
Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}}, eq(B, a, b)), env);
success(Fun({{A, Type()}, {B, Type()}, {a, B}, {b, _}, {C, Type()}}, eq(_, a, b)),
Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}, {C, Type()}}, eq(B, a, b)), env);
fails(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, _}, {C, Type()}}, eq(C, a, b)), env);
success(Fun({{A, Type()}, {B, Type()}, {a, _}, {b, _}, {C, Type()}}, eq(B, a, b)),
Fun({{A, Type()}, {B, Type()}, {a, B}, {b, B}, {C, Type()}}, eq(B, a, b)), env);
}
static void tst16() {
environment env;
import_all(env);
expr a = Const("a");
expr b = Const("b");
expr c = Const("c");
expr H1 = Const("H1");
expr H2 = Const("H2");
env.add_var("a", Bool);
env.add_var("b", Bool);
env.add_var("c", Bool);
success(Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}},
Trans(_, _, _, _, H1, H2)),
Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}},
Trans(Bool, a, b, c, H1, H2)),
env);
expr H3 = Const("H3");
success(Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}},
EqTIntro(_, EqMP(_, _, Symm(_, _, _, Trans(_, _, _, _, Symm(_, _, _, H2), Symm(_, _, _, H1))), H3))),
Fun({{H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}},
EqTIntro(c, EqMP(a, c, Symm(Bool, c, a, Trans(Bool, c, b, a, Symm(Bool, b, c, H2), Symm(Bool, a, b, H1))), H3))),
env);
environment env2;
import_all(env2);
success(Fun({{a, Bool}, {b, Bool}, {c, Bool}, {H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}},
EqTIntro(_, EqMP(_, _, Symm(_, _, _, Trans(_, _, _, _, Symm(_, _, _, H2), Symm(_, _, _, H1))), H3))),
Fun({{a, Bool}, {b, Bool}, {c, Bool}, {H1, Eq(a, b)}, {H2, Eq(b, c)}, {H3, a}},
EqTIntro(c, EqMP(a, c, Symm(Bool, c, a, Trans(Bool, c, b, a, Symm(Bool, b, c, H2), Symm(Bool, a, b, H1))), H3))),
env2);
expr A = Const("A");
success(Fun({{A, Type()}, {a, A}, {b, A}, {c, A}, {H1, Eq(a, b)}, {H2, Eq(b, c)}},
Symm(_, _, _, Trans(_, _, _, _, Symm(_, _, _, H2), Symm(_, _, _, H1)))),
Fun({{A, Type()}, {a, A}, {b, A}, {c, A}, {H1, Eq(a, b)}, {H2, Eq(b, c)}},
Symm(A, c, a, Trans(A, c, b, a, Symm(A, b, c, H2), Symm(A, a, b, H1)))),
env2);
}
void tst17() {
environment env;
import_all(env);
expr A = Const("A");
expr B = Const("B");
expr a = Const("a");
expr b = Const("b");
expr eq = Const("eq");
env.add_var("eq", Pi({A, Type(level()+1)}, A >> (A >> Bool)));
success(eq(_, Fun({{A, Type()}, {a, _}}, a), Fun({{B, Type()}, {b, B}}, b)),
eq(Pi({A, Type()}, A >> A), Fun({{A, Type()}, {a, A}}, a), Fun({{B, Type()}, {b, B}}, b)),
env);
}
void tst18() {
environment env;
import_all(env);
expr A = Const("A");
expr h = Const("h");
expr f = Const("f");
expr a = Const("a");
env.add_var("h", Pi({A, Type()}, A) >> Bool);
success(Fun({{f, Pi({A, Type()}, _)}, {a, Bool}}, h(f)),
Fun({{f, Pi({A, Type()}, A)}, {a, Bool}}, h(f)),
env);
}
void tst19() {
environment env;
import_all(env);
expr R = Const("R");
expr A = Const("A");
expr r = Const("r");
expr eq = Const("eq");
expr f = Const("f");
expr g = Const("g");
expr h = Const("h");
expr D = Const("D");
env.add_var("R", Type() >> Bool);
env.add_var("r", Pi({A, Type()}, R(A)));
env.add_var("h", Pi({A, Type()}, R(A)) >> Bool);
env.add_var("eq", Pi({A, Type(level()+1)}, A >> (A >> Bool)));
success(Let({{f, Fun({A, Type()}, r(_))},
{g, Fun({A, Type()}, r(_))},
{D, Fun({A, Type()}, eq(_, f(A), g(_)))}},
h(f)),
Let({{f, Fun({A, Type()}, r(A))},
{g, Fun({A, Type()}, r(A))},
{D, Fun({A, Type()}, eq(R(A), f(A), g(A)))}},
h(f)),
env);
}
int main() {
tst1();
tst2();
@ -371,6 +625,17 @@ int main() {
tst6();
tst7();
tst8();
tst9();
tst10();
tst11();
tst12();
tst13();
tst14();
tst15();
tst16();
tst17();
tst18();
tst19();
return has_violations() ? 1 : 0;
}