diff --git a/src/library/ho_unifier.cpp b/src/library/ho_unifier.cpp index d64cbbd6a..f149f10fd 100644 --- a/src/library/ho_unifier.cpp +++ b/src/library/ho_unifier.cpp @@ -218,6 +218,10 @@ class ho_unifier::imp { bool failed() const { return m_failed; } }; + expr mk_lambda(name const & n, expr const & d, expr const & b) { + return ::lean::mk_lambda(n, d, b); + } + /** \brief Create the term (fun (x_0 : types[0]) ... (x_{n-1} : types[n-1]) body) */ @@ -226,7 +230,7 @@ class ho_unifier::imp { unsigned i = types.size(); while (i > 0) { --i; - r = ::lean::mk_lambda(name(g_x_name, i), types[i], r); + r = mk_lambda(name(g_x_name, i), types[i], r); } return r; } @@ -288,6 +292,7 @@ class ho_unifier::imp { expr f_b = arg(b, 0); unsigned num_b = num_args(b); // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), f_b (h_1 x_1 ... x_{num_a-1}) ... (h_{num_b-1} x_1 ... x_{num_a-1}) + // New constraints (h_i a_1 ... a_{num_a-1}) == arg(b, i) buffer imitation_args; // arguments for the imitation imitation_args.push_back(f_b); for (unsigned i = 1; i < num_b; i++) { @@ -300,14 +305,27 @@ class ho_unifier::imp { } else if (is_eq(b)) { // Imitation for equality // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), (h_1 x_1 ... x_{num_a-1}) = (h_2 x_1 ... x_{num_a-1}) + // New constraints (h_1 a_1 ... a_{num_a-1}) == eq_lhs(b) + // (h_2 a_1 ... a_{num_a-1}) == eq_rhs(b) expr h_1 = new_s.mk_metavar(ctx); expr h_2 = new_s.mk_metavar(ctx); expr imitation = mk_lambda(arg_types, mk_eq(mk_app_vars(h_1, num_a - 1), mk_app_vars(h_2, num_a - 1))); new_s.assign(midx, imitation); new_q.push_front(constraint(ctx, update_app(a, 0, h_1), eq_lhs(b))); new_q.push_front(constraint(ctx, update_app(a, 0, h_2), eq_rhs(b))); + } else if (is_abstraction(b)) { + // Imitation for lambdas and Pis + // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), fun (x_b : (?h_1 x_1 ... x_{num_a-1})), (?h_2 x_1 ... x_{num_a-1} x_b) + // New constraints (h_1 a_1 ... a_{num_a-1}) == abst_domain(x) + // (h_2 a_1 ... a_{num_a-1} x_b) == abst_body(x) + expr h_1 = new_s.mk_metavar(ctx); + expr h_2 = new_s.mk_metavar(ctx); + expr imitation = mk_lambda(arg_types, mk_lambda(abst_name(b), mk_app_vars(h_1, num_a - 1), mk_app_vars(h_2, num_a))); + new_s.assign(midx, imitation); + new_q.push_front(constraint(ctx, update_app(a, 0, h_1), abst_domain(b))); + new_q.push_front(constraint(extend(ctx, abst_name(b), abst_domain(b)), mk_app(update_app(a, 0, h_2), Var(0)), abst_body(b))); } else { - // "Dump imitation" aka the constant function + // "Dumb imitation" aka the constant function // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), b expr imitation = mk_lambda(arg_types, lift_free_vars(b, 0, num_a - 1)); new_s.assign(midx, imitation); diff --git a/src/tests/library/ho_unifier.cpp b/src/tests/library/ho_unifier.cpp index a7e85ac70..cbeedf2f7 100644 --- a/src/tests/library/ho_unifier.cpp +++ b/src/tests/library/ho_unifier.cpp @@ -6,6 +6,8 @@ Author: Leonardo de Moura */ #include "util/test.h" #include "kernel/environment.h" +#include "kernel/builtin.h" +#include "kernel/abstract.h" #include "library/ho_unifier.h" #include "library/printer.h" #include "library/reduce.h" @@ -27,7 +29,6 @@ void tst1() { expr a = Const("a"); expr b = Const("b"); expr m1 = menv.mk_metavar(); - expr m2 = menv.mk_metavar(); expr l = m1(b, a); expr r = f(b, f(a, b)); for (auto sol : unify(context(), l, r, menv)) { @@ -38,7 +39,34 @@ void tst1() { } } +void tst2() { + environment env; + import_basic(env); + metavar_env menv; + ho_unifier unify(env); + expr N = Const("N"); + expr M = Const("M"); + env.add_var("N", Type()); + env.add_var("f", N >> (Bool >> N)); + env.add_var("a", N); + env.add_var("b", N); + expr f = Const("f"); + expr x = Const("x"); + expr a = Const("a"); + expr b = Const("b"); + expr m1 = menv.mk_metavar(); + expr l = m1(b, a); + expr r = Fun({x, N}, f(x, Eq(a, b))); + for (auto sol : unify(context(), l, r, menv)) { + std::cout << m1 << " -> " << beta_reduce(sol.first.get_subst(m1)) << "\n"; + std::cout << beta_reduce(instantiate_metavars(l, sol.first)) << "\n"; + lean_assert(beta_reduce(instantiate_metavars(l, sol.first)) == r); + std::cout << "--------------\n"; + } +} + int main() { tst1(); + tst2(); return has_violations() ? 1 : 0; }