feat(library/type_inferer): add support for metavariables at type_inferer
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
7f96c07a01
commit
dc51d35dc0
3 changed files with 84 additions and 13 deletions
|
@ -10,6 +10,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/normalizer.h"
|
||||
#include "kernel/builtin.h"
|
||||
#include "kernel/kernel_exception.h"
|
||||
#include "kernel/type_checker_trace.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/free_vars.h"
|
||||
#include "kernel/metavar.h"
|
||||
|
@ -17,6 +18,7 @@ Author: Leonardo de Moura
|
|||
#include "library/type_inferer.h"
|
||||
|
||||
namespace lean {
|
||||
static name g_x_name("x");
|
||||
class type_inferer::imp {
|
||||
typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache;
|
||||
typedef buffer<unification_constraint> unification_constraints;
|
||||
|
@ -30,15 +32,26 @@ class type_inferer::imp {
|
|||
cache m_cache;
|
||||
volatile bool m_interrupted;
|
||||
|
||||
expr normalize(expr const & e, context const & ctx) {
|
||||
return m_normalizer(e, ctx, m_menv);
|
||||
}
|
||||
|
||||
level infer_universe(expr const & t, context const & ctx) {
|
||||
expr u = m_normalizer(infer_type(t, ctx), ctx, m_menv);
|
||||
expr check_type(expr const & e, expr const & s, context const & ctx) {
|
||||
if (is_type(e))
|
||||
return e;
|
||||
if (e == Bool)
|
||||
return Type();
|
||||
expr u = normalize(e, ctx);
|
||||
if (is_type(u))
|
||||
return ty_level(u);
|
||||
return u;
|
||||
if (u == Bool)
|
||||
return level();
|
||||
// TODO(Leo): case when u has metavariables
|
||||
throw type_expected_exception(m_env, ctx, t);
|
||||
return Type();
|
||||
if (has_metavar(u) && m_menv) {
|
||||
trace tr = mk_type_expected_trace(ctx, s);
|
||||
m_uc->push_back(mk_convertible_constraint(ctx, u, TypeU, tr));
|
||||
return u;
|
||||
}
|
||||
throw type_expected_exception(m_env, ctx, s);
|
||||
}
|
||||
|
||||
expr get_range(expr t, expr const & e, context const & ctx) {
|
||||
|
@ -51,8 +64,16 @@ class type_inferer::imp {
|
|||
t = m_normalizer(t, ctx);
|
||||
if (is_pi(t)) {
|
||||
t = abst_body(t);
|
||||
} else if (has_metavar(t) && m_menv) {
|
||||
// Create two fresh variables A and B,
|
||||
// and assign r == (Pi(x : A), B x)
|
||||
expr A = m_menv->mk_metavar(ctx);
|
||||
expr B = m_menv->mk_metavar(ctx);
|
||||
expr p = mk_pi(g_x_name, A, B(Var(0)));
|
||||
trace tr = mk_function_expected_trace(ctx, e);
|
||||
m_uc->push_back(mk_eq_constraint(ctx, t, p, tr));
|
||||
t = abst_body(p);
|
||||
} else {
|
||||
// TODO(Leo): case when t has metavariables
|
||||
throw function_expected_exception(m_env, ctx, e);
|
||||
}
|
||||
}
|
||||
|
@ -151,13 +172,21 @@ class type_inferer::imp {
|
|||
break;
|
||||
}
|
||||
case expr_kind::Pi: {
|
||||
level l1 = infer_universe(abst_domain(e), ctx);
|
||||
level l2;
|
||||
expr t1 = check_type(infer_type(abst_domain(e), ctx), abst_domain(e), ctx);
|
||||
expr t2;
|
||||
{
|
||||
cache::mk_scope sc(m_cache);
|
||||
l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
|
||||
context new_ctx = extend(ctx, abst_name(e), abst_domain(e));
|
||||
t2 = check_type(infer_type(abst_body(e), new_ctx), abst_body(e), new_ctx);
|
||||
}
|
||||
if (is_type(t1) && is_type(t2)) {
|
||||
r = mk_type(max(ty_level(t1), ty_level(t2)));
|
||||
} else {
|
||||
lean_assert(m_uc);
|
||||
trace tr = mk_max_type_trace(ctx, e);
|
||||
r = m_menv->mk_metavar(ctx);
|
||||
m_uc->push_back(mk_max_constraint(ctx, t1, t2, r, tr));
|
||||
}
|
||||
r = mk_type(max(l1, l2));
|
||||
break;
|
||||
}
|
||||
case expr_kind::Let: {
|
||||
|
|
|
@ -24,7 +24,7 @@ Author: Leonardo de Moura
|
|||
#include "library/all/all.h"
|
||||
using namespace lean;
|
||||
|
||||
std::ostream & operator<<(std::ostream & out, substitution const & env) {
|
||||
static std::ostream & operator<<(std::ostream & out, substitution const & env) {
|
||||
bool first = true;
|
||||
env.for_each([&](name const & n, expr const & v) {
|
||||
if (first) first = false; else out << "\n";
|
||||
|
@ -33,7 +33,7 @@ std::ostream & operator<<(std::ostream & out, substitution const & env) {
|
|||
return out;
|
||||
}
|
||||
|
||||
std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
|
||||
static std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
|
||||
formatter fmt = mk_simple_formatter();
|
||||
for (auto c : uc) {
|
||||
out << c.pp(fmt, options(), nullptr, true) << "\n";
|
||||
|
|
|
@ -15,6 +15,14 @@ Author: Leonardo de Moura
|
|||
#include "library/all/all.h"
|
||||
using namespace lean;
|
||||
|
||||
static std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
|
||||
formatter fmt = mk_simple_formatter();
|
||||
for (auto c : uc) {
|
||||
out << c.pp(fmt, options(), nullptr, true) << "\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static void tst1() {
|
||||
environment env = mk_toplevel();
|
||||
type_inferer type_of(env);
|
||||
|
@ -95,10 +103,44 @@ static void tst3() {
|
|||
lean_assert(is_eqp(r, infer(F, ctx1)));
|
||||
}
|
||||
|
||||
static void tst4() {
|
||||
environment env;
|
||||
import_all(env);
|
||||
metavar_env menv;
|
||||
buffer<unification_constraint> uc;
|
||||
type_inferer inferer(env);
|
||||
expr list = Const("list");
|
||||
expr nil = Const("nil");
|
||||
expr cons = Const("cons");
|
||||
expr A = Const("A");
|
||||
env.add_var("list", Type() >> Type());
|
||||
env.add_var("nil", Pi({A, Type()}, list(A)));
|
||||
env.add_var("cons", Pi({A, Type()}, A >> (list(A) >> list(A))));
|
||||
env.add_var("a", Int);
|
||||
env.add_var("b", Int);
|
||||
env.add_var("n", Nat);
|
||||
env.add_var("m", Nat);
|
||||
expr a = Const("a");
|
||||
expr b = Const("b");
|
||||
expr n = Const("n");
|
||||
expr m = Const("m");
|
||||
expr m1 = menv.mk_metavar();
|
||||
expr m2 = menv.mk_metavar();
|
||||
expr m3 = menv.mk_metavar();
|
||||
expr A1 = menv.mk_metavar();
|
||||
expr A2 = menv.mk_metavar();
|
||||
expr A3 = menv.mk_metavar();
|
||||
expr A4 = menv.mk_metavar();
|
||||
expr F = cons(A1, m1(a), cons(A2, m2(n), cons(A3, m3(b), nil(A4))));
|
||||
std::cout << F << "\n";
|
||||
std::cout << inferer(F, context(), &menv, uc) << "\n";
|
||||
std::cout << uc << "\n";
|
||||
}
|
||||
|
||||
int main() {
|
||||
tst1();
|
||||
tst2();
|
||||
tst3();
|
||||
tst4();
|
||||
return has_violations() ? 1 : 0;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue