diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt index 1b577ab27..cabaa7306 100644 --- a/src/kernel/CMakeLists.txt +++ b/src/kernel/CMakeLists.txt @@ -1,4 +1,4 @@ add_library(kernel expr.cpp max_sharing.cpp free_vars.cpp abstract.cpp instantiate.cpp deep_copy.cpp normalize.cpp level.cpp environment.cpp - type_check.cpp context.cpp) + type_check.cpp context.cpp builtin.cpp) target_link_libraries(kernel ${EXTRA_LIBS}) diff --git a/src/kernel/builtin.cpp b/src/kernel/builtin.cpp new file mode 100644 index 000000000..e85b08193 --- /dev/null +++ b/src/kernel/builtin.cpp @@ -0,0 +1,78 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "builtin.h" + +namespace lean { + +class bool_type_value : public value { +public: + static char const * g_kind; + virtual ~bool_type_value() {} + char const * kind() const { return g_kind; } + virtual expr get_type() const { return type(level()); } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } + virtual bool operator==(value const & other) const { return other.kind() == kind(); } + virtual void display(std::ostream & out) const { out << "bool"; } + virtual format pp() const { return format("bool"); } + virtual unsigned hash() const { return 17; } +}; + +char const * bool_type_value::g_kind = "bool"; + +expr bool_type() { + static thread_local expr r; + if (!r) + r = to_expr(*(new bool_type_value())); + return r; +} + +bool is_bool_type(expr const & e) { + return is_value(e) && to_value(e).kind() == bool_type_value::g_kind; +} + +class bool_value_value : public value { + bool m_val; +public: + static char const * g_kind; + bool_value_value(bool v):m_val(v) {} + virtual ~bool_value_value() {} + char const * kind() const { return g_kind; } + virtual expr get_type() const { return bool_type(); } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } + virtual bool operator==(value const & other) const { + return other.kind() == kind() && m_val == static_cast(other).m_val; + } + virtual void display(std::ostream & out) const { out << (m_val ? "true" : "false"); } + virtual format pp() const { return format(m_val ? "true" : "false"); } + virtual unsigned hash() const { return m_val ? 3 : 5; } + bool get_val() const { return m_val; } +}; + +char const * bool_value_value::g_kind = "bool_value"; + +expr bool_value(bool v) { + return to_expr(*(new bool_value_value(v))); +} + +bool is_bool_value(expr const & e) { + return is_value(e) && to_value(e).kind() == bool_value_value::g_kind; +} + +bool to_bool(expr const & e) { + lean_assert(is_bool_value(e)); + return static_cast(to_value(e)).get_val(); +} + +bool is_true(expr const & e) { + return is_bool_value(e) && to_bool(e); +} + +bool is_false(expr const & e) { + return is_bool_value(e) && !to_bool(e); +} + +} diff --git a/src/kernel/builtin.h b/src/kernel/builtin.h new file mode 100644 index 000000000..f61d64154 --- /dev/null +++ b/src/kernel/builtin.h @@ -0,0 +1,19 @@ +/* +Copyright (c) 2013 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "expr.h" + +namespace lean { +expr bool_type(); +bool is_bool_type(expr const & e); + +expr bool_value(bool v); +bool is_bool_value(expr const & e); +bool to_bool(expr const & e); +bool is_true(expr const & e); +bool is_false(expr const & e); +} diff --git a/src/kernel/deep_copy.cpp b/src/kernel/deep_copy.cpp index dcc116452..7ebacbcf7 100644 --- a/src/kernel/deep_copy.cpp +++ b/src/kernel/deep_copy.cpp @@ -32,8 +32,10 @@ class deep_copy_fn { r = app(new_args.size(), new_args.data()); break; } + case expr_kind::Eq: r = eq(apply(eq_lhs(a)), apply(eq_rhs(a))); break; case expr_kind::Lambda: r = lambda(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))); break; case expr_kind::Pi: r = pi(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))); break; + case expr_kind::Let: r = let(let_name(a), apply(let_value(a)), apply(let_body(a))); break; } if (sh) m_cache.insert(std::make_pair(a.raw(), r)); diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 8df5abc0c..7e2060dbc 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -71,7 +71,13 @@ expr app(unsigned n, expr const * as) { to_app(r)->m_hash = hash_args(new_n, m_args); return r; } - +expr_eq::expr_eq(expr const & lhs, expr const & rhs): + expr_cell(expr_kind::Eq, ::lean::hash(lhs.hash(), rhs.hash())), + m_lhs(lhs), + m_rhs(rhs) { +} +expr_eq::~expr_eq() { +} expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, expr const & b): expr_cell(k, ::lean::hash(t.hash(), b.hash())), m_name(n), @@ -89,7 +95,14 @@ expr_type::expr_type(level const & l): } expr_type::~expr_type() { } - +expr_let::expr_let(name const & n, expr const & v, expr const & b): + expr_cell(expr_kind::Let, ::lean::hash(v.hash(), b.hash())), + m_name(n), + m_value(v), + m_body(b) { +} +expr_let::~expr_let() { +} expr_value::expr_value(value & v): expr_cell(expr_kind::Value, v.hash()), m_val(v) { @@ -104,10 +117,12 @@ void expr_cell::dealloc() { case expr_kind::Var: delete static_cast(this); break; case expr_kind::Constant: delete static_cast(this); break; case expr_kind::App: static_cast(this)->~expr_app(); delete[] reinterpret_cast(this); break; + case expr_kind::Eq: delete static_cast(this); break; case expr_kind::Lambda: delete static_cast(this); break; case expr_kind::Pi: delete static_cast(this); break; case expr_kind::Type: delete static_cast(this); break; case expr_kind::Value: delete static_cast(this); break; + case expr_kind::Let: delete static_cast(this); break; } } @@ -142,6 +157,7 @@ class eq_fn { if (!apply(arg(a, i), arg(b, i))) return false; return true; + case expr_kind::Eq: return apply(eq_lhs(a), eq_lhs(b)) && apply(eq_rhs(a), eq_rhs(b)); case expr_kind::Lambda: case expr_kind::Pi: // Lambda and Pi @@ -149,6 +165,7 @@ class eq_fn { return apply(abst_domain(a), abst_domain(b)) && apply(abst_body(a), abst_body(b)); case expr_kind::Type: return ty_level(a) == ty_level(b); case expr_kind::Value: return to_value(a) == to_value(b); + case expr_kind::Let: return apply(let_value(a), let_value(b)) && apply(let_body(a), let_body(b)); } lean_unreachable(); return false; @@ -180,15 +197,17 @@ std::ostream & operator<<(std::ostream & out, expr const & a) { } out << ")"; break; - case expr_kind::Lambda: out << "(fun (" << abst_name(a) << " : " << abst_domain(a) << ") " << abst_body(a) << ")"; break; + case expr_kind::Eq: out << "(" << eq_lhs(a) << " = " << eq_rhs(a) << ")"; break; + case expr_kind::Lambda: out << "(fun " << abst_name(a) << " : " << abst_domain(a) << " => " << abst_body(a) << ")"; break; case expr_kind::Pi: if (!is_arrow(a)) - out << "(pi (" << abst_name(a) << " : " << abst_domain(a) << ") " << abst_body(a) << ")"; + out << "(pi " << abst_name(a) << " : " << abst_domain(a) << ", " << abst_body(a) << ")"; else if (!is_arrow(abst_domain(a))) out << abst_domain(a) << " -> " << abst_body(a); else out << "(" << abst_domain(a) << ") -> " << abst_body(a); break; + case expr_kind::Let: out << "(let " << let_name(a) << " := " << let_value(a) << " in " << let_body(a) << ")"; break; case expr_kind::Type: { level const & l = ty_level(a); if (is_uvar(l) && uvar_idx(l) == 0) @@ -209,8 +228,10 @@ expr copy(expr const & a) { case expr_kind::Type: return type(ty_level(a)); case expr_kind::Value: return to_expr(static_cast(a.raw())->m_val); case expr_kind::App: return app(num_args(a), begin_args(a)); + case expr_kind::Eq: return eq(eq_lhs(a), eq_rhs(a)); case expr_kind::Lambda: return lambda(abst_name(a), abst_domain(a), abst_body(a)); case expr_kind::Pi: return pi(abst_name(a), abst_domain(a), abst_body(a)); + case expr_kind::Let: return let(let_name(a), let_value(a), let_body(a)); } lean_unreachable(); return expr(); @@ -235,6 +256,17 @@ lean::format pp_aux(lean::expr const & a) { } return paren(r); } + case expr_kind::Eq: + return paren(format{pp_aux(eq_lhs(a)), format("="), pp_aux(eq_rhs(a))}); + case expr_kind::Let: + return paren(format{ + highlight(format("let "), format::format_color::PINK), /* Use unicode lambda */ + paren(format{ + format(let_name(a)), + format(" := "), + pp_aux(let_value(a))}), + format(" in "), + pp_aux(let_body(a))}); case expr_kind::Lambda: return paren(format{ highlight(format("\u03BB "), format::format_color::PINK), /* Use unicode lambda */ diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 2e9266b69..654dd5a27 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -25,8 +25,10 @@ class value; | Lambda name expr expr | Pi name expr expr | Type universe + | Eq expr expr (heterogeneous equality) + | Let name expr expr -TODO: add meta-variables, let, constructor references and match. +TODO: add meta-variables, and match expressions. The main API is divided in the following sections - Testers @@ -34,7 +36,7 @@ The main API is divided in the following sections - Accessors - Miscellaneous ======================================= */ -enum class expr_kind { Var, Constant, Value, App, Lambda, Pi, Type }; +enum class expr_kind { Var, Constant, Value, App, Lambda, Pi, Type, Eq, Let }; /** \brief Base class used to represent expressions. @@ -96,10 +98,11 @@ public: friend expr to_expr(value & v); friend expr app(unsigned num_args, expr const * args); friend expr app(std::initializer_list const & l); + friend expr eq(expr const & l, expr const & r); friend expr lambda(name const & n, expr const & t, expr const & e); friend expr pi(name const & n, expr const & t, expr const & e); - friend expr prop(); friend expr type(level const & l); + friend expr let(name const & n, expr const & v, expr const & e); friend bool eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; } @@ -139,6 +142,16 @@ public: expr const * begin_args() const { return m_args; } expr const * end_args() const { return m_args + m_num_args; } }; +/** \brief Heterogeneous equality */ +class expr_eq : public expr_cell { + expr m_lhs; + expr m_rhs; +public: + expr_eq(expr const & lhs, expr const & rhs); + ~expr_eq(); + expr const & get_lhs() const { return m_lhs; } + expr const & get_rhs() const { return m_rhs; } +}; /** \brief Super class for lambda abstraction and pi (functional spaces). */ class expr_abstraction : public expr_cell { name m_name; @@ -160,6 +173,18 @@ class expr_pi : public expr_abstraction { public: expr_pi(name const & n, expr const & t, expr const & e); }; +/** \brief Let expressions */ +class expr_let : public expr_cell { + name m_name; + expr m_value; + expr m_body; +public: + expr_let(name const & n, expr const & v, expr const & b); + ~expr_let(); + name const & get_name() const { return m_name; } + expr const & get_value() const { return m_value; } + expr const & get_body() const { return m_body; } +}; /** \brief Type */ class expr_type : public expr_cell { level m_level; @@ -205,18 +230,22 @@ inline bool is_var(expr_cell * e) { return e->kind() == expr_kind::Var; inline bool is_constant(expr_cell * e) { return e->kind() == expr_kind::Constant; } inline bool is_value(expr_cell * e) { return e->kind() == expr_kind::Value; } inline bool is_app(expr_cell * e) { return e->kind() == expr_kind::App; } +inline bool is_eq(expr_cell * e) { return e->kind() == expr_kind::Eq; } inline bool is_lambda(expr_cell * e) { return e->kind() == expr_kind::Lambda; } inline bool is_pi(expr_cell * e) { return e->kind() == expr_kind::Pi; } inline bool is_type(expr_cell * e) { return e->kind() == expr_kind::Type; } +inline bool is_let(expr_cell * e) { return e->kind() == expr_kind::Let; } inline bool is_abstraction(expr_cell * e) { return is_lambda(e) || is_pi(e); } inline bool is_var(expr const & e) { return e.kind() == expr_kind::Var; } inline bool is_constant(expr const & e) { return e.kind() == expr_kind::Constant; } inline bool is_value(expr const & e) { return e.kind() == expr_kind::Value; } inline bool is_app(expr const & e) { return e.kind() == expr_kind::App; } +inline bool is_eq(expr const & e) { return e.kind() == expr_kind::Eq; } inline bool is_lambda(expr const & e) { return e.kind() == expr_kind::Lambda; } inline bool is_pi(expr const & e) { return e.kind() == expr_kind::Pi; } inline bool is_type(expr const & e) { return e.kind() == expr_kind::Type; } +inline bool is_let(expr const & e) { return e.kind() == expr_kind::Let; } inline bool is_abstraction(expr const & e) { return is_lambda(e) || is_pi(e); } // ======================================= @@ -231,11 +260,14 @@ inline expr app(expr const & e1, expr const & e2) { expr args[2] = {e1, e2}; ret inline expr app(expr const & e1, expr const & e2, expr const & e3) { expr args[3] = {e1, e2, e3}; return app(3, args); } inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4) { expr args[4] = {e1, e2, e3, e4}; return app(4, args); } inline expr app(expr const & e1, expr const & e2, expr const & e3, expr const & e4, expr const & e5) { expr args[5] = {e1, e2, e3, e4, e5}; return app(5, args); } +inline expr eq(expr const & l, expr const & r) { return expr(new expr_eq(l, r)); } inline expr lambda(name const & n, expr const & t, expr const & e) { return expr(new expr_lambda(n, t, e)); } inline expr lambda(char const * n, expr const & t, expr const & e) { return lambda(name(n), t, e); } inline expr pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); } inline expr pi(char const * n, expr const & t, expr const & e) { return pi(name(n), t, e); } inline expr arrow(expr const & t, expr const & e) { return pi(name("_"), t, e); } +inline expr let(name const & n, expr const & v, expr const & e) { return expr(new expr_let(n, v, e)); } +inline expr let(char const * n, expr const & v, expr const & e) { return let(name(n), v, e); } inline expr type(level const & l) { return expr(new expr_type(l)); } expr type(); @@ -250,17 +282,21 @@ inline expr expr::operator()(expr const & a1, expr const & a2, expr const & a3, inline expr_var * to_var(expr_cell * e) { lean_assert(is_var(e)); return static_cast(e); } inline expr_const * to_constant(expr_cell * e) { lean_assert(is_constant(e)); return static_cast(e); } inline expr_app * to_app(expr_cell * e) { lean_assert(is_app(e)); return static_cast(e); } +inline expr_eq * to_eq(expr_cell * e) { lean_assert(is_eq(e)); return static_cast(e); } inline expr_abstraction * to_abstraction(expr_cell * e) { lean_assert(is_abstraction(e)); return static_cast(e); } inline expr_lambda * to_lambda(expr_cell * e) { lean_assert(is_lambda(e)); return static_cast(e); } inline expr_pi * to_pi(expr_cell * e) { lean_assert(is_pi(e)); return static_cast(e); } inline expr_type * to_type(expr_cell * e) { lean_assert(is_type(e)); return static_cast(e); } +inline expr_let * to_let(expr_cell * e) { lean_assert(is_let(e)); return static_cast(e); } inline expr_var * to_var(expr const & e) { return to_var(e.raw()); } inline expr_const * to_constant(expr const & e) { return to_constant(e.raw()); } inline expr_app * to_app(expr const & e) { return to_app(e.raw()); } +inline expr_eq * to_eq(expr const & e) { return to_eq(e.raw()); } inline expr_abstraction * to_abstraction(expr const & e) { return to_abstraction(e.raw()); } inline expr_lambda * to_lambda(expr const & e) { return to_lambda(e.raw()); } inline expr_pi * to_pi(expr const & e) { return to_pi(e.raw()); } +inline expr_let * to_let(expr const & e) { return to_let(e.raw()); } inline expr_type * to_type(expr const & e) { return to_type(e.raw()); } // ======================================= @@ -274,10 +310,15 @@ inline name const & const_name(expr_cell * e) { return to_constant(e) inline value const & to_value(expr_cell * e) { lean_assert(is_value(e)); return static_cast(e)->get_value(); } inline unsigned num_args(expr_cell * e) { return to_app(e)->get_num_args(); } inline expr const & arg(expr_cell * e, unsigned idx) { return to_app(e)->get_arg(idx); } +inline expr const & eq_lhs(expr_cell * e) { return to_eq(e)->get_lhs(); } +inline expr const & eq_rhs(expr_cell * e) { return to_eq(e)->get_rhs(); } inline name const & abst_name(expr_cell * e) { return to_abstraction(e)->get_name(); } inline expr const & abst_domain(expr_cell * e) { return to_abstraction(e)->get_domain(); } inline expr const & abst_body(expr_cell * e) { return to_abstraction(e)->get_body(); } inline level const & ty_level(expr_cell * e) { return to_type(e)->get_level(); } +inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); } +inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); } +inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); } inline unsigned get_rc(expr const & e) { return e.raw()->get_rc(); } inline bool is_shared(expr const & e) { return get_rc(e) > 1; } @@ -289,10 +330,15 @@ inline unsigned num_args(expr const & e) { return to_app(e)->ge inline expr const & arg(expr const & e, unsigned idx) { return to_app(e)->get_arg(idx); } inline expr const * begin_args(expr const & e) { return to_app(e)->begin_args(); } inline expr const * end_args(expr const & e) { return to_app(e)->end_args(); } +inline expr const & eq_lhs(expr const & e) { return to_eq(e)->get_lhs(); } +inline expr const & eq_rhs(expr const & e) { return to_eq(e)->get_rhs(); } inline name const & abst_name(expr const & e) { return to_abstraction(e)->get_name(); } inline expr const & abst_domain(expr const & e) { return to_abstraction(e)->get_domain(); } inline expr const & abst_body(expr const & e) { return to_abstraction(e)->get_body(); } inline level const & ty_level(expr const & e) { return to_type(e)->get_level(); } +inline name const & let_name(expr const & e) { return to_let(e)->get_name(); } +inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); } +inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } // ======================================= // ======================================= @@ -385,6 +431,30 @@ template expr update_abst(expr const & e, F f) { return e; } } +template expr update_let(expr const & e, F f) { + static_assert(std::is_same::type, + std::pair>::value, + "update_let: return type of f is not pair"); + expr const & old_v = let_value(e); + expr const & old_b = let_body(e); + std::pair p = f(old_v, old_b); + if (!eqp(p.first, old_v) || !eqp(p.second, old_b)) + return let(let_name(e), p.first, p.second); + else + return e; +} +template expr update_eq(expr const & e, F f) { + static_assert(std::is_same::type, + std::pair>::value, + "update_eq: return type of f is not pair"); + expr const & old_l = eq_lhs(e); + expr const & old_r = eq_rhs(e); + std::pair p = f(old_l, old_r); + if (!eqp(p.first, old_l) || !eqp(p.second, old_r)) + return eq(p.first, p.second); + else + return e; +} // ======================================= } void pp(lean::expr const & a); diff --git a/src/kernel/free_vars.cpp b/src/kernel/free_vars.cpp index 068be0b97..b1ddc37b0 100644 --- a/src/kernel/free_vars.cpp +++ b/src/kernel/free_vars.cpp @@ -27,7 +27,7 @@ protected: return false; case expr_kind::Var: return process_var(e, offset); - case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: + case expr_kind::App: case expr_kind::Eq: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Let: break; } @@ -50,10 +50,16 @@ protected: case expr_kind::App: result = std::any_of(begin_args(e), end_args(e), [=](expr const & arg){ return apply(arg, offset); }); break; + case expr_kind::Eq: + result = apply(eq_lhs(e), offset) || apply(eq_rhs(e), offset); + break; case expr_kind::Lambda: case expr_kind::Pi: result = apply(abst_domain(e), offset) || apply(abst_body(e), offset + 1); break; + case expr_kind::Let: + result = apply(let_value(e), offset) || apply(let_body(e), offset + 1); + break; } if (!result) diff --git a/src/kernel/max_sharing.cpp b/src/kernel/max_sharing.cpp index 033118803..d855fc3f4 100644 --- a/src/kernel/max_sharing.cpp +++ b/src/kernel/max_sharing.cpp @@ -40,11 +40,21 @@ struct max_sharing_fn::imp { cache(r); return r; } + case expr_kind::Eq : { + expr r = update_eq(a, [=](expr const & l, expr const & r){ return std::make_pair(apply(l), apply(r)); }); + cache(r); + return r; + } case expr_kind::Lambda: case expr_kind::Pi: { expr r = update_abst(a, [=](expr const & t, expr const & b) { return std::make_pair(apply(t), apply(b)); }); cache(r); return r; + } + case expr_kind::Let: { + expr r = update_let(a, [=](expr const & v, expr const & b) { return std::make_pair(apply(v), apply(b)); }); + cache(r); + return r; }} lean_unreachable(); return a; diff --git a/src/kernel/normalize.cpp b/src/kernel/normalize.cpp index 5543d0870..9dd87f2c3 100644 --- a/src/kernel/normalize.cpp +++ b/src/kernel/normalize.cpp @@ -9,6 +9,7 @@ Author: Leonardo de Moura #include "expr.h" #include "context.h" #include "environment.h" +#include "builtin.h" #include "free_vars.h" #include "list.h" #include "buffer.h" @@ -163,13 +164,26 @@ class normalize_fn { } } } + case expr_kind::Eq: { + expr new_l = reify(normalize(eq_lhs(a), s, k), k); + expr new_r = reify(normalize(eq_rhs(a), s, k), k); + if (new_l == new_r) { + return svalue(bool_value(true)); + } else { + // TODO: Invoke semantic attachments. + return svalue(eq(new_l, new_r)); + } + } case expr_kind::Lambda: return svalue(a, s); case expr_kind::Pi: { expr new_t = reify(normalize(abst_domain(a), s, k), k); expr new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1); return svalue(pi(abst_name(a), new_t, new_b)); - }} + } + case expr_kind::Let: + return normalize(let_body(a), extend(s, normalize(let_value(a), s, k)), k+1); + } lean_unreachable(); return svalue(a); } diff --git a/src/kernel/replace.h b/src/kernel/replace.h index 01d90e94d..e05cd8e7a 100644 --- a/src/kernel/replace.h +++ b/src/kernel/replace.h @@ -44,10 +44,16 @@ class replace_fn { case expr_kind::App: r = update_app(e, [=](expr const & c) { return apply(c, offset); }); break; + case expr_kind::Eq: + r = update_eq(e, [=](expr const & l, expr const & r) { return std::make_pair(apply(l, offset), apply(r, offset)); }); + break; case expr_kind::Lambda: case expr_kind::Pi: r = update_abst(e, [=](expr const & t, expr const & b) { return std::make_pair(apply(t, offset), apply(b, offset+1)); }); break; + case expr_kind::Let: + r = update_let(e, [=](expr const & v, expr const & b) { return std::make_pair(apply(v, offset), apply(b, offset+1)); }); + break; } } diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index e874ce0b4..23dda050d 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "type_check.h" #include "normalize.h" #include "instantiate.h" +#include "builtin.h" #include "free_vars.h" #include "exception.h" #include "trace.h" @@ -100,6 +101,10 @@ class infer_type_fn { check_pi(f_t, ctx); } } + case expr_kind::Eq: + infer_type(eq_lhs(e), ctx); + infer_type(eq_rhs(e), ctx); + return bool_type(); case expr_kind::Lambda: { infer_universe(abst_domain(e), ctx); expr t = infer_type(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); @@ -110,6 +115,8 @@ class infer_type_fn { level l2 = infer_universe(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); return type(max(l1, l2)); } + case expr_kind::Let: + return infer_type(let_body(e), extend(ctx, let_name(e), infer_type(let_value(e), ctx), let_value(e))); case expr_kind::Value: return to_value(e).get_type(); } diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index bed285ec3..b872b5a44 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -74,8 +74,12 @@ unsigned depth1(expr const & e) { m = std::max(m, depth1(a)); return m + 1; } + case expr_kind::Eq: + return std::max(depth1(eq_lhs(e)), depth1(eq_rhs(e))) + 1; case expr_kind::Lambda: case expr_kind::Pi: return std::max(depth1(abst_domain(e)), depth1(abst_body(e))) + 1; + case expr_kind::Let: + return std::max(depth1(let_value(e)), depth1(let_body(e))) + 1; } return 0; } @@ -90,8 +94,12 @@ unsigned depth2(expr const & e) { std::accumulate(begin_args(e), end_args(e), 0, [](unsigned m, expr const & arg){ return std::max(depth2(arg), m); }) + 1; + case expr_kind::Eq: + return std::max(depth2(eq_lhs(e)), depth2(eq_rhs(e))) + 1; case expr_kind::Lambda: case expr_kind::Pi: return std::max(depth2(abst_domain(e)), depth2(abst_body(e))) + 1; + case expr_kind::Let: + return std::max(depth2(let_value(e)), depth2(let_body(e))) + 1; } return 0; } @@ -116,10 +124,18 @@ unsigned depth3(expr const & e) { todo.push_back(std::make_pair(&arg(e, i), c)); break; } + case expr_kind::Eq: + todo.push_back(std::make_pair(&eq_lhs(e), c)); + todo.push_back(std::make_pair(&eq_rhs(e), c)); + break; case expr_kind::Lambda: case expr_kind::Pi: todo.push_back(std::make_pair(&abst_domain(e), c)); todo.push_back(std::make_pair(&abst_body(e), c)); break; + case expr_kind::Let: + todo.push_back(std::make_pair(&let_value(e), c)); + todo.push_back(std::make_pair(&let_body(e), c)); + break; } } return m; @@ -173,8 +189,12 @@ unsigned count_core(expr const & a, expr_set & s) { case expr_kind::App: return std::accumulate(begin_args(a), end_args(a), 1, [&](unsigned sum, expr const & arg){ return sum + count_core(arg, s); }); + case expr_kind::Eq: + return count_core(eq_lhs(a), s) + count_core(eq_rhs(a), s) + 1; case expr_kind::Lambda: case expr_kind::Pi: return count_core(abst_domain(a), s) + count_core(abst_body(a), s) + 1; + case expr_kind::Let: + return count_core(let_value(a), s) + count_core(let_body(a), s) + 1; } return 0; } @@ -343,6 +363,14 @@ void tst14() { std::cout << t0 << " " << t1 << "\n"; } +void tst15() { + expr t = eq(constant("a"), constant("b")); + std::cout << t << "\n"; + expr l = let("a", constant("b"), var(0)); + std::cout << l << "\n"; + lean_assert(closed(l)); +} + int main() { continue_on_violation(true); std::cout << "sizeof(expr): " << sizeof(expr) << "\n"; @@ -363,6 +391,7 @@ int main() { tst12(); tst13(); tst14(); + tst15(); std::cout << "done" << "\n"; return has_violations() ? 1 : 0; } diff --git a/src/tests/kernel/normalize.cpp b/src/tests/kernel/normalize.cpp index a5e3645dc..cd5db02c8 100644 --- a/src/tests/kernel/normalize.cpp +++ b/src/tests/kernel/normalize.cpp @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #include #include "normalize.h" +#include "builtin.h" #include "trace.h" #include "test.h" #include "sets.h" @@ -62,8 +63,12 @@ unsigned count_core(expr const & a, expr_set & s) { case expr_kind::App: return std::accumulate(begin_args(a), end_args(a), 1, [&](unsigned sum, expr const & arg){ return sum + count_core(arg, s); }); + case expr_kind::Eq: + return count_core(eq_lhs(a), s) + count_core(eq_rhs(a), s) + 1; case expr_kind::Lambda: case expr_kind::Pi: return count_core(abst_domain(a), s) + count_core(abst_body(a), s) + 1; + case expr_kind::Let: + return count_core(let_value(a), s) + count_core(let_body(a), s) + 1; } return 0; } @@ -159,10 +164,28 @@ static void tst2() { lean_assert(F6 == lambda("z1", t, lambda("z2", t, app(var(2), var(3), constant("a"))))); } +static void tst3() { + environment env; + expr t1 = constant("a"); + expr t2 = constant("a"); + expr e = eq(t1, t2); + std::cout << e << " --> " << normalize(e, env) << "\n"; + lean_assert(normalize(e, env) == bool_value(true)); +} + +static void tst4() { + environment env; + expr t1 = let("a", constant("b"), lambda("c", type(), var(1)(var(0)))); + std::cout << t1 << " --> " << normalize(t1, env) << "\n"; + lean_assert(normalize(t1, env) == lambda("c", type(), constant("b")(var(0)))); +} + int main() { continue_on_violation(true); tst_church_numbers(); tst1(); tst2(); + tst3(); + tst4(); return has_violations() ? 1 : 0; } diff --git a/src/tests/kernel/type_check.cpp b/src/tests/kernel/type_check.cpp index 397d3c88f..af0a93352 100644 --- a/src/tests/kernel/type_check.cpp +++ b/src/tests/kernel/type_check.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "type_check.h" #include "abstract.h" #include "exception.h" +#include "builtin.h" #include "trace.h" #include "test.h" using namespace lean; @@ -63,10 +64,20 @@ static void tst2() { } } +static void tst3() { + environment env; + expr f = fun("a", bool_type(), eq(constant("a"), bool_value(true))); + std::cout << infer_type(f, env) << "\n"; + lean_assert(infer_type(f, env) == arrow(bool_type(), bool_type())); + expr t = let("a", bool_value(true), var(0)); + std::cout << infer_type(t, env) << "\n"; +} + int main() { continue_on_violation(true); enable_trace("type_check"); tst1(); tst2(); + tst3(); return has_violations() ? 1 : 0; }