feat(library/type_inferer): add support for metavariables at type_inferer
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
7f96c07a01
commit
dc51d35dc0
3 changed files with 84 additions and 13 deletions
|
@ -10,6 +10,7 @@ Author: Leonardo de Moura
|
||||||
#include "kernel/normalizer.h"
|
#include "kernel/normalizer.h"
|
||||||
#include "kernel/builtin.h"
|
#include "kernel/builtin.h"
|
||||||
#include "kernel/kernel_exception.h"
|
#include "kernel/kernel_exception.h"
|
||||||
|
#include "kernel/type_checker_trace.h"
|
||||||
#include "kernel/instantiate.h"
|
#include "kernel/instantiate.h"
|
||||||
#include "kernel/free_vars.h"
|
#include "kernel/free_vars.h"
|
||||||
#include "kernel/metavar.h"
|
#include "kernel/metavar.h"
|
||||||
|
@ -17,6 +18,7 @@ Author: Leonardo de Moura
|
||||||
#include "library/type_inferer.h"
|
#include "library/type_inferer.h"
|
||||||
|
|
||||||
namespace lean {
|
namespace lean {
|
||||||
|
static name g_x_name("x");
|
||||||
class type_inferer::imp {
|
class type_inferer::imp {
|
||||||
typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache;
|
typedef scoped_map<expr, expr, expr_hash, expr_eqp> cache;
|
||||||
typedef buffer<unification_constraint> unification_constraints;
|
typedef buffer<unification_constraint> unification_constraints;
|
||||||
|
@ -30,15 +32,26 @@ class type_inferer::imp {
|
||||||
cache m_cache;
|
cache m_cache;
|
||||||
volatile bool m_interrupted;
|
volatile bool m_interrupted;
|
||||||
|
|
||||||
|
expr normalize(expr const & e, context const & ctx) {
|
||||||
|
return m_normalizer(e, ctx, m_menv);
|
||||||
|
}
|
||||||
|
|
||||||
level infer_universe(expr const & t, context const & ctx) {
|
expr check_type(expr const & e, expr const & s, context const & ctx) {
|
||||||
expr u = m_normalizer(infer_type(t, ctx), ctx, m_menv);
|
if (is_type(e))
|
||||||
|
return e;
|
||||||
|
if (e == Bool)
|
||||||
|
return Type();
|
||||||
|
expr u = normalize(e, ctx);
|
||||||
if (is_type(u))
|
if (is_type(u))
|
||||||
return ty_level(u);
|
return u;
|
||||||
if (u == Bool)
|
if (u == Bool)
|
||||||
return level();
|
return Type();
|
||||||
// TODO(Leo): case when u has metavariables
|
if (has_metavar(u) && m_menv) {
|
||||||
throw type_expected_exception(m_env, ctx, t);
|
trace tr = mk_type_expected_trace(ctx, s);
|
||||||
|
m_uc->push_back(mk_convertible_constraint(ctx, u, TypeU, tr));
|
||||||
|
return u;
|
||||||
|
}
|
||||||
|
throw type_expected_exception(m_env, ctx, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
expr get_range(expr t, expr const & e, context const & ctx) {
|
expr get_range(expr t, expr const & e, context const & ctx) {
|
||||||
|
@ -51,8 +64,16 @@ class type_inferer::imp {
|
||||||
t = m_normalizer(t, ctx);
|
t = m_normalizer(t, ctx);
|
||||||
if (is_pi(t)) {
|
if (is_pi(t)) {
|
||||||
t = abst_body(t);
|
t = abst_body(t);
|
||||||
|
} else if (has_metavar(t) && m_menv) {
|
||||||
|
// Create two fresh variables A and B,
|
||||||
|
// and assign r == (Pi(x : A), B x)
|
||||||
|
expr A = m_menv->mk_metavar(ctx);
|
||||||
|
expr B = m_menv->mk_metavar(ctx);
|
||||||
|
expr p = mk_pi(g_x_name, A, B(Var(0)));
|
||||||
|
trace tr = mk_function_expected_trace(ctx, e);
|
||||||
|
m_uc->push_back(mk_eq_constraint(ctx, t, p, tr));
|
||||||
|
t = abst_body(p);
|
||||||
} else {
|
} else {
|
||||||
// TODO(Leo): case when t has metavariables
|
|
||||||
throw function_expected_exception(m_env, ctx, e);
|
throw function_expected_exception(m_env, ctx, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -151,13 +172,21 @@ class type_inferer::imp {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case expr_kind::Pi: {
|
case expr_kind::Pi: {
|
||||||
level l1 = infer_universe(abst_domain(e), ctx);
|
expr t1 = check_type(infer_type(abst_domain(e), ctx), abst_domain(e), ctx);
|
||||||
level l2;
|
expr t2;
|
||||||
{
|
{
|
||||||
cache::mk_scope sc(m_cache);
|
cache::mk_scope sc(m_cache);
|
||||||
l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
|
context new_ctx = extend(ctx, abst_name(e), abst_domain(e));
|
||||||
|
t2 = check_type(infer_type(abst_body(e), new_ctx), abst_body(e), new_ctx);
|
||||||
|
}
|
||||||
|
if (is_type(t1) && is_type(t2)) {
|
||||||
|
r = mk_type(max(ty_level(t1), ty_level(t2)));
|
||||||
|
} else {
|
||||||
|
lean_assert(m_uc);
|
||||||
|
trace tr = mk_max_type_trace(ctx, e);
|
||||||
|
r = m_menv->mk_metavar(ctx);
|
||||||
|
m_uc->push_back(mk_max_constraint(ctx, t1, t2, r, tr));
|
||||||
}
|
}
|
||||||
r = mk_type(max(l1, l2));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case expr_kind::Let: {
|
case expr_kind::Let: {
|
||||||
|
|
|
@ -24,7 +24,7 @@ Author: Leonardo de Moura
|
||||||
#include "library/all/all.h"
|
#include "library/all/all.h"
|
||||||
using namespace lean;
|
using namespace lean;
|
||||||
|
|
||||||
std::ostream & operator<<(std::ostream & out, substitution const & env) {
|
static std::ostream & operator<<(std::ostream & out, substitution const & env) {
|
||||||
bool first = true;
|
bool first = true;
|
||||||
env.for_each([&](name const & n, expr const & v) {
|
env.for_each([&](name const & n, expr const & v) {
|
||||||
if (first) first = false; else out << "\n";
|
if (first) first = false; else out << "\n";
|
||||||
|
@ -33,7 +33,7 @@ std::ostream & operator<<(std::ostream & out, substitution const & env) {
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
|
static std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
|
||||||
formatter fmt = mk_simple_formatter();
|
formatter fmt = mk_simple_formatter();
|
||||||
for (auto c : uc) {
|
for (auto c : uc) {
|
||||||
out << c.pp(fmt, options(), nullptr, true) << "\n";
|
out << c.pp(fmt, options(), nullptr, true) << "\n";
|
||||||
|
|
|
@ -15,6 +15,14 @@ Author: Leonardo de Moura
|
||||||
#include "library/all/all.h"
|
#include "library/all/all.h"
|
||||||
using namespace lean;
|
using namespace lean;
|
||||||
|
|
||||||
|
static std::ostream & operator<<(std::ostream & out, buffer<unification_constraint> const & uc) {
|
||||||
|
formatter fmt = mk_simple_formatter();
|
||||||
|
for (auto c : uc) {
|
||||||
|
out << c.pp(fmt, options(), nullptr, true) << "\n";
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
static void tst1() {
|
static void tst1() {
|
||||||
environment env = mk_toplevel();
|
environment env = mk_toplevel();
|
||||||
type_inferer type_of(env);
|
type_inferer type_of(env);
|
||||||
|
@ -95,10 +103,44 @@ static void tst3() {
|
||||||
lean_assert(is_eqp(r, infer(F, ctx1)));
|
lean_assert(is_eqp(r, infer(F, ctx1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void tst4() {
|
||||||
|
environment env;
|
||||||
|
import_all(env);
|
||||||
|
metavar_env menv;
|
||||||
|
buffer<unification_constraint> uc;
|
||||||
|
type_inferer inferer(env);
|
||||||
|
expr list = Const("list");
|
||||||
|
expr nil = Const("nil");
|
||||||
|
expr cons = Const("cons");
|
||||||
|
expr A = Const("A");
|
||||||
|
env.add_var("list", Type() >> Type());
|
||||||
|
env.add_var("nil", Pi({A, Type()}, list(A)));
|
||||||
|
env.add_var("cons", Pi({A, Type()}, A >> (list(A) >> list(A))));
|
||||||
|
env.add_var("a", Int);
|
||||||
|
env.add_var("b", Int);
|
||||||
|
env.add_var("n", Nat);
|
||||||
|
env.add_var("m", Nat);
|
||||||
|
expr a = Const("a");
|
||||||
|
expr b = Const("b");
|
||||||
|
expr n = Const("n");
|
||||||
|
expr m = Const("m");
|
||||||
|
expr m1 = menv.mk_metavar();
|
||||||
|
expr m2 = menv.mk_metavar();
|
||||||
|
expr m3 = menv.mk_metavar();
|
||||||
|
expr A1 = menv.mk_metavar();
|
||||||
|
expr A2 = menv.mk_metavar();
|
||||||
|
expr A3 = menv.mk_metavar();
|
||||||
|
expr A4 = menv.mk_metavar();
|
||||||
|
expr F = cons(A1, m1(a), cons(A2, m2(n), cons(A3, m3(b), nil(A4))));
|
||||||
|
std::cout << F << "\n";
|
||||||
|
std::cout << inferer(F, context(), &menv, uc) << "\n";
|
||||||
|
std::cout << uc << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
tst1();
|
tst1();
|
||||||
tst2();
|
tst2();
|
||||||
tst3();
|
tst3();
|
||||||
|
tst4();
|
||||||
return has_violations() ? 1 : 0;
|
return has_violations() ? 1 : 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue