feat(library/type_inferer): add support for metavariables at type_inferer

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-10-18 10:22:40 -07:00
parent 7f96c07a01
commit dc51d35dc0
3 changed files with 84 additions and 13 deletions

View file

@ -10,6 +10,7 @@ Author: Leonardo de Moura
#include "kernel/normalizer.h" #include "kernel/normalizer.h"
#include "kernel/builtin.h" #include "kernel/builtin.h"
#include "kernel/kernel_exception.h" #include "kernel/kernel_exception.h"
#include "kernel/type_checker_trace.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
#include "kernel/free_vars.h" #include "kernel/free_vars.h"
#include "kernel/metavar.h" #include "kernel/metavar.h"
@ -17,6 +18,7 @@ Author: Leonardo de Moura
#include "library/type_inferer.h" #include "library/type_inferer.h"
namespace lean { namespace lean {
static name g_x_name("x");
class type_inferer::imp { class type_inferer::imp {
typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache; typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache;
typedef buffer<unification_constraint> unification_constraints; typedef buffer<unification_constraint> unification_constraints;
@ -30,15 +32,26 @@ class type_inferer::imp {
cache m_cache; cache m_cache;
volatile bool m_interrupted; volatile bool m_interrupted;
expr normalize(expr const & e, context const & ctx) {
return m_normalizer(e, ctx, m_menv);
}
level infer_universe(expr const & t, context const & ctx) { expr check_type(expr const & e, expr const & s, context const & ctx) {
expr u = m_normalizer(infer_type(t, ctx), ctx, m_menv); if (is_type(e))
return e;
if (e == Bool)
return Type();
expr u = normalize(e, ctx);
if (is_type(u)) if (is_type(u))
return ty_level(u); return u;
if (u == Bool) if (u == Bool)
return level(); return Type();
// TODO(Leo): case when u has metavariables if (has_metavar(u) && m_menv) {
throw type_expected_exception(m_env, ctx, t); trace tr = mk_type_expected_trace(ctx, s);
m_uc->push_back(mk_convertible_constraint(ctx, u, TypeU, tr));
return u;
}
throw type_expected_exception(m_env, ctx, s);
} }
expr get_range(expr t, expr const & e, context const & ctx) { expr get_range(expr t, expr const & e, context const & ctx) {
@ -51,8 +64,16 @@ class type_inferer::imp {
t = m_normalizer(t, ctx); t = m_normalizer(t, ctx);
if (is_pi(t)) { if (is_pi(t)) {
t = abst_body(t); t = abst_body(t);
} else if (has_metavar(t) && m_menv) {
// Create two fresh variables A and B,
// and assign r == (Pi(x : A), B x)
expr A = m_menv->mk_metavar(ctx);
expr B = m_menv->mk_metavar(ctx);
expr p = mk_pi(g_x_name, A, B(Var(0)));
trace tr = mk_function_expected_trace(ctx, e);
m_uc->push_back(mk_eq_constraint(ctx, t, p, tr));
t = abst_body(p);
} else { } else {
// TODO(Leo): case when t has metavariables
throw function_expected_exception(m_env, ctx, e); throw function_expected_exception(m_env, ctx, e);
} }
} }
@ -151,13 +172,21 @@ class type_inferer::imp {
break; break;
} }
case expr_kind::Pi: { case expr_kind::Pi: {
level l1 = infer_universe(abst_domain(e), ctx); expr t1 = check_type(infer_type(abst_domain(e), ctx), abst_domain(e), ctx);
level l2; expr t2;
{ {
cache::mk_scope sc(m_cache); cache::mk_scope sc(m_cache);
l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); context new_ctx = extend(ctx, abst_name(e), abst_domain(e));
t2 = check_type(infer_type(abst_body(e), new_ctx), abst_body(e), new_ctx);
}
if (is_type(t1) && is_type(t2)) {
r = mk_type(max(ty_level(t1), ty_level(t2)));
} else {
lean_assert(m_uc);
trace tr = mk_max_type_trace(ctx, e);
r = m_menv->mk_metavar(ctx);
m_uc->push_back(mk_max_constraint(ctx, t1, t2, r, tr));
} }
r = mk_type(max(l1, l2));
break; break;
} }
case expr_kind::Let: { case expr_kind::Let: {

View file

@ -24,7 +24,7 @@ Author: Leonardo de Moura
#include "library/all/all.h" #include "library/all/all.h"
using namespace lean; using namespace lean;
std::ostream & operator<<(std::ostream & out, substitution const & env) { static std::ostream & operator<<(std::ostream & out, substitution const & env) {
bool first = true; bool first = true;
env.for_each([&](name const & n, expr const & v) { env.for_each([&](name const & n, expr const & v) {
if (first) first = false; else out << "\n"; if (first) first = false; else out << "\n";
@ -33,7 +33,7 @@ std::ostream & operator<<(std::ostream & out, substitution const & env) {
return out; return out;
} }
std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) { static std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
formatter fmt = mk_simple_formatter(); formatter fmt = mk_simple_formatter();
for (auto c : uc) { for (auto c : uc) {
out << c.pp(fmt, options(), nullptr, true) << "\n"; out << c.pp(fmt, options(), nullptr, true) << "\n";

View file

@ -15,6 +15,14 @@ Author: Leonardo de Moura
#include "library/all/all.h" #include "library/all/all.h"
using namespace lean; using namespace lean;
static std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
formatter fmt = mk_simple_formatter();
for (auto c : uc) {
out << c.pp(fmt, options(), nullptr, true) << "\n";
}
return out;
}
static void tst1() { static void tst1() {
environment env = mk_toplevel(); environment env = mk_toplevel();
type_inferer type_of(env); type_inferer type_of(env);
@ -95,10 +103,44 @@ static void tst3() {
lean_assert(is_eqp(r, infer(F, ctx1))); lean_assert(is_eqp(r, infer(F, ctx1)));
} }
static void tst4() {
environment env;
import_all(env);
metavar_env menv;
buffer<unification_constraint> uc;
type_inferer inferer(env);
expr list = Const("list");
expr nil = Const("nil");
expr cons = Const("cons");
expr A = Const("A");
env.add_var("list", Type() >> Type());
env.add_var("nil", Pi({A, Type()}, list(A)));
env.add_var("cons", Pi({A, Type()}, A >> (list(A) >> list(A))));
env.add_var("a", Int);
env.add_var("b", Int);
env.add_var("n", Nat);
env.add_var("m", Nat);
expr a = Const("a");
expr b = Const("b");
expr n = Const("n");
expr m = Const("m");
expr m1 = menv.mk_metavar();
expr m2 = menv.mk_metavar();
expr m3 = menv.mk_metavar();
expr A1 = menv.mk_metavar();
expr A2 = menv.mk_metavar();
expr A3 = menv.mk_metavar();
expr A4 = menv.mk_metavar();
expr F = cons(A1, m1(a), cons(A2, m2(n), cons(A3, m3(b), nil(A4))));
std::cout << F << "\n";
std::cout << inferer(F, context(), &menv, uc) << "\n";
std::cout << uc << "\n";
}
int main() { int main() {
tst1(); tst1();
tst2(); tst2();
tst3(); tst3();
tst4();
return has_violations() ? 1 : 0; return has_violations() ? 1 : 0;
} }