Add (optional) type to let declarations

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2013-09-06 10:06:26 -07:00
parent 6da194334e
commit 2459c4ae7c
26 changed files with 186 additions and 45 deletions

View file

@ -425,10 +425,30 @@ class elaborator::imp {
return expr_pair(new_e, t);
}
case expr_kind::Let: {
expr_pair t_p;
if (let_type(e))
t_p = process(let_type(e), ctx);
auto v_p = process(let_value(e), ctx);
auto b_p = process(let_body(e), extend(ctx, let_name(e), v_p.second, v_p.first));
if (let_type(e)) {
expr const & expected = t_p.first;
expr const & given = v_p.second;
if (has_metavar(expected) || has_metavar(given)) {
info_ref r = mk_expected_type_info(let_value(e), v_p.first, expected, given, ctx);
m_constraints.push_back(constraint(expected, given, ctx, r));
} else {
if (!is_convertible(expected, given, ctx)) {
expr coercion = m_frontend.get_coercion(given, expected);
if (coercion) {
v_p.first = mk_app(coercion, v_p.first);
} else {
throw def_type_mismatch_exception(m_env, ctx, let_name(e), let_type(e), v_p.first, v_p.second);
}
}
}
}
auto b_p = process(let_body(e), extend(ctx, let_name(e), t_p.first ? t_p.first : v_p.second, v_p.first));
expr t = lower_free_vars_mmv(b_p.second, 1, 1);
expr new_e = update_let(e, v_p.first, b_p.first);
expr new_e = update_let(e, t_p.first, v_p.first, b_p.first);
add_trace(e, new_e);
return expr_pair(new_e, t);
}}

View file

@ -853,14 +853,19 @@ class parser::imp {
expr parse_let() {
next();
mk_scope scope(*this);
buffer<std::tuple<pos_info, name, expr>> bindings;
buffer<std::tuple<pos_info, name, expr, expr>> bindings;
while (true) {
auto p = pos();
name id = check_identifier_next("invalid let expression, identifier expected");
expr type;
if (curr_is_colon()) {
next();
type = parse_expr();
}
check_assign_next("invalid let expression, ':=' expected");
expr val = parse_expr();
register_binding(id);
bindings.push_back(std::make_tuple(p, id, val));
bindings.push_back(std::make_tuple(p, id, type, val));
if (curr_is_in()) {
next();
expr r = parse_expr();
@ -868,7 +873,7 @@ class parser::imp {
while (i > 0) {
--i;
auto p = std::get<0>(bindings[i]);
r = save(mk_let(std::get<1>(bindings[i]), std::get<2>(bindings[i]), r), p);
r = save(mk_let(std::get<1>(bindings[i]), std::get<2>(bindings[i]), std::get<3>(bindings[i]), r), p);
}
return r;
} else {

View file

@ -842,11 +842,11 @@ class pp_fn {
return pp_abstraction_core(e, depth, expr());
}
expr collect_nested_let(expr const & e, buffer<std::pair<name, expr>> & bindings) {
expr collect_nested_let(expr const & e, buffer<std::tuple<name, expr, expr>> & bindings) {
if (is_let(e)) {
name n1 = get_unused_name(e);
m_local_names.insert(n1);
bindings.push_back(mk_pair(n1, let_value(e)));
bindings.push_back(std::make_tuple(n1, let_type(e), let_value(e)));
expr b = instantiate_with_closed(let_body(e), mk_constant(n1));
return collect_nested_let(b, bindings);
} else {
@ -856,19 +856,28 @@ class pp_fn {
result pp_let(expr const & e, unsigned depth) {
local_names::mk_scope mk(m_local_names);
buffer<std::pair<name, expr>> bindings;
buffer<std::tuple<name, expr, expr>> bindings;
expr body = collect_nested_let(e, bindings);
unsigned r_weight = 2;
format r_format = g_let_fmt;
unsigned sz = bindings.size();
for (unsigned i = 0; i < sz; i++) {
auto b = bindings[i];
name const & n = b.first;
result p_def = pp(b.second, depth+1);
name const & n = std::get<0>(b);
format beg = i == 0 ? space() : line();
format sep = i < sz - 1 ? comma() : format();
r_format += nest(3 + 1, format{beg, format(n), space(), g_assign_fmt, nest(n.size() + 1 + 2 + 1, format{space(), p_def.first, sep})});
r_weight += p_def.second;
result p_def = pp(std::get<2>(b), depth+1);
expr type = std::get<1>(b);
if (type) {
result p_type = pp(type, depth+1);
r_format += nest(3 + 1, compose(beg, group(format{format(n), space(),
colon(), nest(n.size() + 1 + 1 + 1, compose(space(), p_type.first)), space(), g_assign_fmt,
nest(m_indent, format{line(), p_def.first, sep})})));
r_weight += p_type.second + p_def.second;
} else {
r_format += nest(3 + 1, format{beg, format(n), space(), g_assign_fmt, nest(n.size() + 1 + 2 + 1, format{space(), p_def.first, sep})});
r_weight += p_def.second;
}
}
result p_body = pp(body, depth+1);
r_weight += p_body.second;

View file

@ -47,8 +47,14 @@ inline expr Pi(std::pair<expr const &, expr const &> const & p, expr const & b)
/**
\brief Create a Let expression (Let x := v in b), the term b is abstracted using abstract(b, x).
*/
inline expr Let(name const & x, expr const & v, expr const & b) { return mk_let(x, v, abstract(b, mk_constant(x))); }
inline expr Let(expr const & x, expr const & v, expr const & b) { return mk_let(const_name(x), v, abstract(b, x)); }
inline expr Let(name const & x, expr const & v, expr const & b) { return mk_let(x, expr(), v, abstract(b, mk_constant(x))); }
inline expr Let(expr const & x, expr const & v, expr const & b) { return mk_let(const_name(x), expr(), v, abstract(b, x)); }
inline expr Let(std::pair<expr const &, expr const &> const & p, expr const & b) { return Let(p.first, p.second, b); }
expr Let(std::initializer_list<std::pair<expr const &, expr const &>> const & l, expr const & b);
/**
\brief Create a Let expression (Let x : t := v in b), the term b is abstracted using abstract(b, x).
*/
inline expr Let(name const & x, expr const & t, expr const & v, expr const & b) { return mk_let(x, t, v, abstract(b, mk_constant(x))); }
inline expr Let(expr const & x, expr const & t, expr const & v, expr const & b) { return mk_let(const_name(x), t, v, abstract(b, x)); }
}

View file

@ -94,9 +94,10 @@ expr_type::expr_type(level const & l):
m_level(l) {
}
expr_type::~expr_type() {}
expr_let::expr_let(name const & n, expr const & v, expr const & b):
expr_let::expr_let(name const & n, expr const & t, expr const & v, expr const & b):
expr_cell(expr_kind::Let, ::lean::hash(v.hash(), b.hash())),
m_name(n),
m_type(t),
m_value(v),
m_body(b) {
}
@ -154,7 +155,7 @@ expr copy(expr const & a) {
case expr_kind::Eq: return mk_eq(eq_lhs(a), eq_rhs(a));
case expr_kind::Lambda: return mk_lambda(abst_name(a), abst_domain(a), abst_body(a));
case expr_kind::Pi: return mk_pi(abst_name(a), abst_domain(a), abst_body(a));
case expr_kind::Let: return mk_let(let_name(a), let_value(a), let_body(a));
case expr_kind::Let: return mk_let(let_name(a), let_type(a), let_value(a), let_body(a));
}
lean_unreachable();
return expr();

View file

@ -26,7 +26,7 @@ class value;
| Pi name expr expr
| Type universe
| Eq expr expr (heterogeneous equality)
| Let name expr expr
| Let name expr expr expr
TODO: match expressions.
@ -102,7 +102,7 @@ public:
friend expr mk_lambda(name const & n, expr const & t, expr const & e);
friend expr mk_pi(name const & n, expr const & t, expr const & e);
friend expr mk_type(level const & l);
friend expr mk_let(name const & n, expr const & v, expr const & e);
friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e);
friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; }
@ -179,12 +179,14 @@ public:
/** \brief Let expressions */
class expr_let : public expr_cell {
name m_name;
expr m_type;
expr m_value;
expr m_body;
public:
expr_let(name const & n, expr const & v, expr const & b);
expr_let(name const & n, expr const & t, expr const & v, expr const & b);
~expr_let();
name const & get_name() const { return m_name; }
expr const & get_type() const { return m_type; }
expr const & get_value() const { return m_value; }
expr const & get_body() const { return m_body; }
};
@ -271,7 +273,7 @@ inline expr mk_lambda(name const & n, expr const & t, expr const & e) { return e
inline expr mk_pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); }
inline expr mk_arrow(expr const & t, expr const & e) { return mk_pi(name("_"), t, e); }
inline expr operator>>(expr const & t, expr const & e) { return mk_arrow(t, e); }
inline expr mk_let(name const & n, expr const & v, expr const & e) { return expr(new expr_let(n, v, e)); }
inline expr mk_let(name const & n, expr const & t, expr const & v, expr const & e) { return expr(new expr_let(n, t, v, e)); }
inline expr mk_type(level const & l) { return expr(new expr_type(l)); }
expr mk_type();
inline expr Type(level const & l) { return mk_type(l); }
@ -327,6 +329,7 @@ inline expr const & abst_body(expr_cell * e) { return to_abstraction
inline level const & ty_level(expr_cell * e) { return to_type(e)->get_level(); }
inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); }
inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); }
inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); }
inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); }
/** \brief Return the reference counter of the given expression. */
@ -354,6 +357,7 @@ inline expr const & abst_domain(expr const & e) { return to_abstractio
inline expr const & abst_body(expr const & e) { return to_abstraction(e)->get_body(); }
inline level const & ty_level(expr const & e) { return to_type(e)->get_level(); }
inline name const & let_name(expr const & e) { return to_let(e)->get_name(); }
inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); }
inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); }
inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); }
// =======================================
@ -447,14 +451,15 @@ template<typename F> expr update_abst(expr const & e, F f) {
}
}
template<typename F> expr update_let(expr const & e, F f) {
static_assert(std::is_same<typename std::result_of<F(expr const &, expr const &)>::type,
std::pair<expr, expr>>::value,
static_assert(std::is_same<typename std::result_of<F(expr const &, expr const &, expr const &)>::type,
std::tuple<expr, expr, expr>>::value,
"update_let: return type of f is not pair<expr, expr>");
expr const & old_t = let_type(e);
expr const & old_v = let_value(e);
expr const & old_b = let_body(e);
std::pair<expr, expr> p = f(old_v, old_b);
if (!is_eqp(p.first, old_v) || !is_eqp(p.second, old_b))
return mk_let(let_name(e), p.first, p.second);
std::tuple<expr, expr, expr> t = f(old_t, old_v, old_b);
if (!is_eqp(std::get<0>(t), old_t) || !is_eqp(std::get<1>(t), old_v) || !is_eqp(std::get<2>(t), old_b))
return mk_let(let_name(e), std::get<0>(t), std::get<1>(t), std::get<2>(t));
else
return e;
}

View file

@ -55,7 +55,14 @@ class expr_eq_fn {
case expr_kind::Pi: return apply(abst_domain(a), abst_domain(b)) && apply(abst_body(a), abst_body(b));
case expr_kind::Type: return ty_level(a) == ty_level(b);
case expr_kind::Value: return to_value(a) == to_value(b);
case expr_kind::Let: return apply(let_value(a), let_value(b)) && apply(let_body(a), let_body(b));
case expr_kind::Let:
if (let_type(a) && let_type(b)) {
if (!apply(let_type(a), let_type(b)))
return false;
} else if (let_type(a) || let_type(b)) {
return false;
}
return apply(let_value(a), let_value(b)) && apply(let_body(a), let_body(b));
}
lean_unreachable(); // LCOV_EXCL_LINE
return false; // LCOV_EXCL_LINE

View file

@ -42,6 +42,8 @@ class for_each_fn {
apply(abst_body(e), offset + 1);
return;
case expr_kind::Let:
if (let_type(e))
apply(let_type(e), offset);
apply(let_value(e), offset);
apply(let_body(e), offset + 1);
return;

View file

@ -65,7 +65,7 @@ protected:
result = apply(abst_domain(e), offset) || apply(abst_body(e), offset + 1);
break;
case expr_kind::Let:
result = apply(let_value(e), offset) || apply(let_body(e), offset + 1);
result = (let_type(e) && apply(let_type(e), offset)) || apply(let_value(e), offset) || apply(let_body(e), offset + 1);
break;
}
@ -133,7 +133,7 @@ protected:
result = apply(abst_domain(e), offset) || apply(abst_body(e), offset + 1);
break;
case expr_kind::Let:
result = apply(let_value(e), offset) || apply(let_body(e), offset + 1);
result = (let_type(e) && apply(let_type(e), offset)) || apply(let_value(e), offset) || apply(let_body(e), offset + 1);
break;
}

View file

@ -154,17 +154,24 @@ public:
get_name(), type get_type() and value get_value() is incorrect
because the value has type get_value_type() and it not matches
the given type get_type().
This exception is also used to sign declaration mismatches in
let declarations.
*/
class def_type_mismatch_exception : public type_checker_exception {
name m_name;
expr m_type;
expr m_value;
expr m_value_type;
context m_context;
name m_name;
expr m_type;
expr m_value;
expr m_value_type;
public:
def_type_mismatch_exception(environment const & env, context const & ctx, name const & n, expr const & type, expr const & val, expr const & val_type):
type_checker_exception(env), m_context(ctx), m_name(n), m_type(type), m_value(val), m_value_type(val_type) {}
def_type_mismatch_exception(environment const & env, name const & n, expr const & type, expr const & val, expr const & val_type):
type_checker_exception(env), m_name(n), m_type(type), m_value(val), m_value_type(val_type) {}
virtual ~def_type_mismatch_exception() {}
name const & get_name() const { return m_name; }
context const & get_context() const { return m_context; }
expr const & get_type() const { return m_type; }
expr const & get_value() const { return m_value; }
expr const & get_value_type() const { return m_value_type; }

View file

@ -65,7 +65,10 @@ class replace_fn {
r = update_abst(e, [=](expr const & t, expr const & b) { return std::make_pair(apply(t, offset), apply(b, offset+1)); });
break;
case expr_kind::Let:
r = update_let(e, [=](expr const & v, expr const & b) { return std::make_pair(apply(v, offset), apply(b, offset+1)); });
r = update_let(e, [=](expr const & t, expr const & v, expr const & b) {
expr new_t = t ? apply(t, offset) : expr();
return std::make_tuple(new_t, apply(v, offset), apply(b, offset+1));
});
break;
}
}

View file

@ -133,6 +133,11 @@ public:
}
case expr_kind::Let: {
expr lt = infer_type(let_value(e), ctx);
if (let_type(e)) {
infer_universe(let_type(e), ctx); // check if it is really a type
if (!m_normalizer.is_convertible(let_type(e), lt, ctx))
throw def_type_mismatch_exception(m_env, ctx, let_name(e), let_type(e), let_value(e), lt);
}
{
cache::mk_scope sc(m_cache);
r = lower_free_vars(infer_type(let_body(e), extend(ctx, let_name(e), lt, let_value(e))), 1);

View file

@ -35,7 +35,10 @@ class deep_copy_fn {
case expr_kind::Eq: r = mk_eq(apply(eq_lhs(a)), apply(eq_rhs(a))); break;
case expr_kind::Lambda: r = mk_lambda(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))); break;
case expr_kind::Pi: r = mk_pi(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))); break;
case expr_kind::Let: r = mk_let(let_name(a), apply(let_value(a)), apply(let_body(a))); break;
case expr_kind::Let: {
expr new_t = let_type(a) ? apply(let_type(a)) : expr();
r = mk_let(let_name(a), new_t, apply(let_value(a)), apply(let_body(a))); break;
}
}
if (sh)
m_cache.insert(std::make_pair(a.raw(), r));

View file

@ -55,7 +55,10 @@ struct max_sharing_fn::imp {
return r;
}
case expr_kind::Let: {
expr r = update_let(a, [=](expr const & v, expr const & b) { return std::make_pair(apply(v), apply(b)); });
expr r = update_let(a, [=](expr const & t, expr const & v, expr const & b) {
expr new_t = t ? apply(t) : expr();
return std::make_tuple(new_t, apply(v), apply(b));
});
cache(r);
return r;
}}

View file

@ -124,7 +124,10 @@ struct print_expr_fn {
}
break;
case expr_kind::Let:
out() << "let " << let_name(a) << " := ";
out() << "let " << let_name(a);
if (let_type(a))
out() << " : " << let_type(a);
out() << " := ";
print(let_value(a), c);
out() << " in ";
print_child(let_body(a), extend(c, let_name(a), let_value(a)));

View file

@ -38,11 +38,11 @@ expr update_pi(expr const & pi, expr const & d, expr const & b) {
return mk_pi(abst_name(pi), d, b);
}
expr update_let(expr const & let, expr const & v, expr const & b) {
if (is_eqp(let_value(let), v) && is_eqp(let_body(let), b))
expr update_let(expr const & let, expr const & t, expr const & v, expr const & b) {
if (is_eqp(let_type(let), t) && is_eqp(let_value(let), v) && is_eqp(let_body(let), b))
return let;
else
return mk_let(let_name(let), v, b);
return mk_let(let_name(let), t, v, b);
}
expr update_eq(expr const & eq, expr const & l, expr const & r) {

View file

@ -27,11 +27,11 @@ expr update_lambda(expr const & lambda, expr const & d, expr const & b);
*/
expr update_pi(expr const & pi, expr const & d, expr const & b);
/**
\brief Return a let expression based on \c let with value \c v and \c body b.
\brief Return a let expression based on \c let with type \c t value \c v and \c body b.
\remark Return \c let if the given value and body are (pointer) equal to the ones in \c let.
*/
expr update_let(expr const & let, expr const & v, expr const & b);
expr update_let(expr const & let, expr const & t, expr const & v, expr const & b);
/**
\brief Return a new equality with lhs \c l and rhs \c r.

View file

@ -345,7 +345,7 @@ void tst14() {
void tst15() {
expr t = Eq(Const("a"), Const("b"));
std::cout << t << "\n";
expr l = mk_let("a", Const("b"), Var(0));
expr l = mk_let("a", expr(), Const("b"), Var(0));
std::cout << l << "\n";
lean_assert(closed(l));
}

View file

@ -198,7 +198,7 @@ static void tst3() {
static void tst4() {
environment env;
env.add_var("b", Type());
expr t1 = mk_let("a", Const("b"), mk_lambda("c", Type(), Var(1)(Var(0))));
expr t1 = mk_let("a", expr(), Const("b"), mk_lambda("c", Type(), Var(1)(Var(0))));
std::cout << t1 << " --> " << normalize(t1, env) << "\n";
lean_assert(normalize(t1, env) == mk_lambda("c", Type(), Const("b")(Var(0))));
}

View file

@ -77,7 +77,7 @@ static void tst3() {
expr f = Fun("a", Bool, Eq(Const("a"), True));
std::cout << infer_type(f, env) << "\n";
lean_assert(infer_type(f, env) == mk_arrow(Bool, Bool));
expr t = mk_let("a", True, Var(0));
expr t = mk_let("a", expr(), True, Var(0));
std::cout << infer_type(t, env) << "\n";
}
@ -200,7 +200,7 @@ static void tst11() {
expr t3 = f(b,b);
for (unsigned i = 0; i < n; i++) {
t1 = f(t1,t1);
t2 = mk_let("x", t2, f(Var(0), Var(0)));
t2 = mk_let("x", expr(), t2, f(Var(0), Var(0)));
t3 = f(t3,t3);
}
lean_assert(t1 != t2);

11
tests/lean/let1.lean Normal file
View file

@ -0,0 +1,11 @@
Show let a : Nat := 10, b : Nat := 20, c : Nat := 30, d : Nat := 10 in a + b + c + d
Show let a : Nat := 1000000000000000000, b : Nat := 20000000000000000000, c : Nat := 3000000000000000000, d : Nat := 4000000000000000000 in a + b + c + d
Check let a : Nat := 10 in a + 1
Eval let a : Nat := 20 in a + 10
Eval let a := 20 in a + 10
Check let a : Int := 20 in a + 10
Set pp::coercion true
Show let a : Int := 20 in a + 10

View file

@ -0,0 +1,14 @@
Set: pp::colors
Set: pp::unicode
let a : := 10, b : := 20, c : := 30, d : := 10 in a + b + c + d
let a : := 1000000000000000000,
b : := 20000000000000000000,
c : := 3000000000000000000,
d : := 4000000000000000000
in a + b + c + d
30
30
Set: lean::pp::coercion
let a : := nat_to_int 20 in a + (nat_to_int 10)

11
tests/lean/let2.lean Normal file
View file

@ -0,0 +1,11 @@
(* Annotating lemmas *)
Theorem simple (p q r : Bool) : (p ⇒ q) ∧ (q ⇒ r) ⇒ p ⇒ r :=
Discharge (λ H_pq_qr, Discharge (λ H_p,
let P_pq : (p ⇒ q) := Conjunct1 H_pq_qr,
P_qr : (q ⇒ r) := Conjunct2 H_pq_qr,
P_q : q := MP P_pq H_p
in MP P_qr P_q))
Show Environment 1

View file

@ -0,0 +1,12 @@
Set: pp::colors
Set: pp::unicode
Proved: simple
Theorem simple (p q r : Bool) : (p ⇒ q) ∧ (q ⇒ r) ⇒ p ⇒ r :=
Discharge
(λ H_pq_qr : (p ⇒ q) ∧ (q ⇒ r),
Discharge
(λ H_p : p,
let P_pq : p ⇒ q := Conjunct1 H_pq_qr,
P_qr : q ⇒ r := Conjunct2 H_pq_qr,
P_q : q := MP P_pq H_p
in MP P_qr P_q))

8
tests/lean/let3.lean Normal file
View file

@ -0,0 +1,8 @@
Variable magic : Pi (H : Bool), H
Set pp::notation false
Set pp::coercion true
Show let a : Int := 1,
H : a > 0 := magic (a > 0)
in H

View file

@ -0,0 +1,6 @@
Set: pp::colors
Set: pp::unicode
Assumed: magic
Set: lean::pp::notation
Set: lean::pp::coercion
let a : := nat_to_int 1, H : Int::gt a (nat_to_int 0) := magic (Int::gt a (nat_to_int 0)) in H