diff --git a/src/kernel/environment.cpp b/src/kernel/environment.cpp index b2de355fc..568beb907 100644 --- a/src/kernel/environment.cpp +++ b/src/kernel/environment.cpp @@ -9,19 +9,58 @@ Author: Leonardo de Moura #include #include #include +#include #include "environment.h" +#include "type_check.h" #include "exception.h" #include "debug.h" namespace lean { constexpr unsigned uninit = std::numeric_limits::max(); +environment::definition::definition(name const & n, expr const & t, expr const & v, bool opaque): + m_name(n), + m_type(t), + m_value(v), + m_opaque(opaque) { +} + +environment::definition::~definition() { +} + +environment::object_kind environment::definition::kind() const { + return object_kind::Definition; +} + +void environment::definition::display(std::ostream & out) const { + out << "Definition " << m_name << " : " << m_type << " := " << m_value << "\n"; +} + +environment::fact::fact(name const & n, expr const & t): + m_name(n), + m_type(t) { +} + +environment::fact::~fact() { +} + +environment::object_kind environment::fact::kind() const { + return object_kind::Fact; +} + +void environment::fact::display(std::ostream & out) const { + out << "Fact " << m_name << " : " << m_type << "\n"; +} + /** \brief Implementation of the Lean environment. */ struct environment::imp { - std::vector> m_uvar_distances; - std::vector m_uvars; - std::atomic m_num_children; - std::shared_ptr m_parent; + typedef std::unordered_map object_dictionary; + std::vector> m_uvar_distances; + std::vector m_uvars; + std::atomic m_num_children; + std::shared_ptr m_parent; + std::vector m_objects; + object_dictionary m_object_dictionary; bool has_children() const { return m_num_children > 0; } void inc_children() { m_num_children++; } @@ -171,6 +210,52 @@ struct environment::imp { }); } + void check_no_children() { + if (has_children()) + throw exception("invalid object declaration, environment has children environments"); + } + + void check_name(name const & n) { + if (m_object_dictionary.find(n) != m_object_dictionary.end()) { + std::ostringstream s; + s << "environment already contains an object with name '" << n << "'"; + throw exception (s.str()); + } + } + + void add_definition(name const & n, expr const & t, expr const & v, bool opaque) { + m_objects.push_back(new definition(n, t, v, opaque)); + m_object_dictionary.insert(std::make_pair(n, m_objects.back())); + } + + void add_fact(name const & n, expr const & t) { + m_objects.push_back(new fact(n, t)); + m_object_dictionary.insert(std::make_pair(n, m_objects.back())); + } + + object const * get_object_ptr(name const & n) const { + auto it = m_object_dictionary.find(n); + if (it == m_object_dictionary.end()) { + if (has_parent()) + return m_parent->get_object_ptr(n); + else + return nullptr; + } else { + return it->second; + } + } + + object const & get_object(name const & n) const { + object const * ptr = get_object_ptr(n); + if (ptr) { + return *ptr; + } else { + std::ostringstream s; + s << "unknown object '" << n << "'"; + throw exception (s.str()); + } + } + imp(): m_num_children(0) { init_uvars(); @@ -185,6 +270,7 @@ struct environment::imp { ~imp() { if (m_parent) m_parent->dec_children(); + std::for_each(m_objects.begin(), m_objects.end(), [](object * obj) { delete obj; }); } }; @@ -236,4 +322,40 @@ level environment::get_uvar(name const & n) const { return m_imp->get_uvar(n); } +void environment::add_definition(name const & n, expr const & t, expr const & v, bool opaque) { + m_imp->check_no_children(); + m_imp->check_name(n); + infer_universe(t, *this); + expr v_t = infer_type(v, *this); + if (!is_convertible(t, v_t, *this)) { + std::ostringstream buffer; + buffer << "type mismatch when defining '" << n << "'\n" + << "expected type:\n" << t << "\n" + << "given type:\n" << v_t; + throw exception(buffer.str()); + } + m_imp->add_definition(n, t, v, opaque); +} + +void environment::add_definition(name const & n, expr const & v, bool opaque) { + m_imp->check_no_children(); + m_imp->check_name(n); + expr v_t = infer_type(v, *this); + m_imp->add_definition(n, v_t, v, opaque); +} + +void environment::add_fact(name const & n, expr const & t) { + m_imp->check_no_children(); + m_imp->check_name(n); + infer_universe(t, *this); + m_imp->add_fact(n, t); +} + +environment::object const & environment::get_object(name const & n) const { + return m_imp->get_object(n); +} + +environment::object const * environment::get_object_ptr(name const & n) const { + return m_imp->get_object_ptr(n); +} } diff --git a/src/kernel/environment.h b/src/kernel/environment.h index 35b33b191..d8d40db76 100644 --- a/src/kernel/environment.h +++ b/src/kernel/environment.h @@ -69,21 +69,92 @@ public: */ environment parent() const; + enum class object_kind { Definition, Fact }; + + /** + \brief Base class for environment objects + It is just a place holder at this point. + */ + class object { + public: + object() {} + object(object const & o) = delete; + object & operator=(object const & o) = delete; + + virtual ~object() {} + virtual object_kind kind() const = 0; + virtual void display(std::ostream & out) const = 0; + virtual expr const & get_type() const = 0; + }; + + class definition : public object { + name m_name; + expr m_type; + expr m_value; + bool m_opaque; + public: + definition(name const & n, expr const & t, expr const & v, bool opaque); + virtual ~definition(); + virtual object_kind kind() const; + name const & get_name() const { return m_name; } + virtual expr const & get_type() const { return m_type; } + expr const & get_value() const { return m_value; } + bool is_opaque() const { return m_opaque; } + virtual void display(std::ostream & out) const; + }; + + class fact : public object { + name m_name; + expr m_type; + public: + fact(name const & n, expr const & t); + virtual ~fact(); + virtual object_kind kind() const; + name const & get_name() const { return m_name; } + virtual expr const & get_type() const { return m_type; } + virtual void display(std::ostream & out) const; + }; + + friend bool is_definition(object const & o) { return o.kind() == object_kind::Definition; } + friend bool is_fact(object const & o) { return o.kind() == object_kind::Fact; } + + friend definition const & to_definition(object const & o) { lean_assert(is_definition(o)); return static_cast(o); } + friend fact const & to_fact(object const & o) { lean_assert(is_fact(o)); return static_cast(o); } + /** \brief Add a new definition n : t := v. It throws an exception if v does not have type t. + It throws an exception if there is already an object with the given name. If opaque == true, then definition is not used by normalizer. */ void add_definition(name const & n, expr const & t, expr const & v, bool opaque = false); + void add_definition(char const * n, expr const & t, expr const & v, bool opaque = false) { add_definition(name(n), t, v, opaque); } /** \brief Add a new definition n : infer_type(v) := v. + It throws an exception if there is already an object with the given name. If opaque == true, then definition is not used by normalizer. */ void add_definition(name const & n, expr const & v, bool opaque = false); + void add_definition(char const * n, expr const & v, bool opaque = false) { add_definition(name(n), v, opaque); } /** + \brief Add a new fact to the environment. + It throws an exception if there is already an object with the given name. */ void add_fact(name const & n, expr const & t); + void add_fact(char const * n, expr const & t) { add_fact(name(n), t); } + + /** + \brief Return the object with the given name. + It throws an exception if the environment does not have an object with the given name. + */ + object const & get_object(name const & n) const; + + /** + \brief Return the object with the given name. + Return nullptr if there is no object with the given name. + */ + object const * get_object_ptr(name const & n) const; }; } diff --git a/src/kernel/normalize.cpp b/src/kernel/normalize.cpp index 325182146..0fe181369 100644 --- a/src/kernel/normalize.cpp +++ b/src/kernel/normalize.cpp @@ -132,7 +132,16 @@ class normalize_fn { switch (a.kind()) { case expr_kind::Var: return lookup(s, var_idx(a), k); - case expr_kind::Constant: case expr_kind::Type: case expr_kind::Value: + case expr_kind::Constant: { + environment::object const & obj = m_env.get_object(const_name(a)); + if (is_definition(obj) && !to_definition(obj).is_opaque()) { + return normalize(to_definition(obj).get_value(), value_stack(), 0); + } + else { + return svalue(a); + } + } + case expr_kind::Type: case expr_kind::Value: return svalue(a); case expr_kind::App: { svalue f = normalize(arg(a, 0), s, k); diff --git a/src/kernel/type_check.cpp b/src/kernel/type_check.cpp index 7bcb0ad07..6af4fc770 100644 --- a/src/kernel/type_check.cpp +++ b/src/kernel/type_check.cpp @@ -32,7 +32,7 @@ bool is_convertible(expr const & expected, expr const & given, environment const return is_convertible_core(e_n, g_n, env); } -class infer_type_fn { +struct infer_type_fn { environment const & m_env; expr lookup(context const & c, unsigned i) { @@ -50,10 +50,10 @@ class infer_type_fn { if (is_bool_type(u)) return level(); std::ostringstream buffer; - buffer << "type expected"; + buffer << "type expected, "; if (!empty(ctx)) - buffer << ", in context:\n" << ctx; - buffer << "\ngiven:\n" << t; + buffer << "in context:\n" << ctx << "\n"; + buffer << "got:\n" << t; throw exception(buffer.str()); } @@ -64,10 +64,10 @@ class infer_type_fn { if (is_pi(r)) return r; std::ostringstream buffer; - buffer << "function expected"; + buffer << "function expected, "; if (!empty(ctx)) - buffer << ", in context:\n" << ctx; - buffer << "\ngiven:\n" << e; + buffer << "in context:\n" << ctx << "\n"; + buffer << "got:\n" << e; throw exception(buffer.str()); } @@ -79,8 +79,7 @@ class infer_type_fn { lean_trace("type_check", tout << "infer type\n" << e << "\n" << ctx << "\n";); switch (e.kind()) { case expr_kind::Constant: - // TODO - return e; + return m_env.get_object(const_name(e)).get_type(); case expr_kind::Var: return lookup(ctx, var_idx(e)); case expr_kind::Type: return type(ty_level(e) + 1); case expr_kind::App: { @@ -128,7 +127,7 @@ class infer_type_fn { lean_unreachable(); return e; } -public: + infer_type_fn(environment const & env): m_env(env) { } @@ -141,4 +140,8 @@ public: expr infer_type(expr const & e, environment const & env, context const & ctx) { return infer_type_fn(env)(e, ctx); } + +level infer_universe(expr const & t, environment const & env, context const & ctx) { + return infer_type_fn(env).infer_universe(t, ctx); +} } diff --git a/src/kernel/type_check.h b/src/kernel/type_check.h index a7057d693..52c269b71 100644 --- a/src/kernel/type_check.h +++ b/src/kernel/type_check.h @@ -11,5 +11,6 @@ Author: Leonardo de Moura namespace lean { expr infer_type(expr const & e, environment const & env, context const & ctx = context()); +level infer_universe(expr const & t, environment const & env, context const & ctx = context()); bool is_convertible(expr const & t1, expr const & t2, environment const & env, context const & ctx = context()); } diff --git a/src/tests/kernel/arith.cpp b/src/tests/kernel/arith.cpp index c94676d65..a11f5eda7 100644 --- a/src/tests/kernel/arith.cpp +++ b/src/tests/kernel/arith.cpp @@ -77,6 +77,7 @@ static void tst4() { static void tst5() { environment env; + env.add_fact(name("a"), int_type()); expr e = eq(int_value(3), int_value(4)); std::cout << e << " --> " << normalize(e, env) << "\n"; lean_assert(normalize(e, env) == bool_value(false)); diff --git a/src/tests/kernel/environment.cpp b/src/tests/kernel/environment.cpp index 02002cbb6..9be08118f 100644 --- a/src/tests/kernel/environment.cpp +++ b/src/tests/kernel/environment.cpp @@ -31,8 +31,7 @@ static void tst1() { try { level o = env.define_uvar("o", w + 1); lean_unreachable(); - } - catch (exception const & ex) { + } catch (exception const & ex) { std::cout << "expected error: " << ex.what() << "\n"; } } @@ -58,9 +57,56 @@ static void tst2() { std::cout << "uvar: " << child.get_uvar("u") << "\n"; } +static void tst3() { + environment env; + try { + env.add_definition("a", int_type(), constant("a")); + lean_unreachable(); + } catch (exception ex) { + std::cout << "expected error: " << ex.what() << "\n"; + } + env.add_definition("a", int_type(), app(int_add(), int_value(1), int_value(2))); + expr t = app(int_add(), constant("a"), int_value(1)); + std::cout << t << " --> " << normalize(t, env) << "\n"; + lean_assert(normalize(t, env) == int_value(4)); + env.add_definition("b", int_type(), app(int_mul(), int_value(2), constant("a"))); + std::cout << "b --> " << normalize(constant("b"), env) << "\n"; + lean_assert(normalize(constant("b"), env) == int_value(6)); + try { + env.add_definition("c", arrow(int_type(), int_type()), constant("a")); + lean_unreachable(); + } catch (exception ex) { + std::cout << "expected error: " << ex.what() << "\n"; + } + try { + env.add_definition("a", int_type(), int_value(10)); + lean_unreachable(); + } catch (exception ex) { + std::cout << "expected error: " << ex.what() << "\n"; + } + environment c_env = env.mk_child(); + try { + env.add_definition("c", int_type(), constant("a")); + lean_unreachable(); + } catch (exception ex) { + std::cout << "expected error: " << ex.what() << "\n"; + } + lean_assert(normalize(constant("b"), env) == int_value(6)); + lean_assert(normalize(constant("b"), c_env) == int_value(6)); + c_env.add_definition("c", int_type(), constant("a")); + lean_assert(normalize(constant("c"), c_env) == int_value(3)); + try { + lean_assert(normalize(constant("c"), env) == int_value(3)); + lean_unreachable(); + } catch (exception ex) { + std::cout << "expected error: " << ex.what() << "\n"; + } +} + int main() { continue_on_violation(true); tst1(); tst2(); + tst3(); return has_violations() ? 1 : 0; } diff --git a/src/tests/kernel/normalize.cpp b/src/tests/kernel/normalize.cpp index cd5db02c8..bf338644f 100644 --- a/src/tests/kernel/normalize.cpp +++ b/src/tests/kernel/normalize.cpp @@ -17,7 +17,7 @@ expr normalize(expr const & e) { return normalize(e, env); } -static void eval(expr const & e) { std::cout << e << " --> " << normalize(e) << "\n"; } +static void eval(expr const & e, environment & env) { std::cout << e << " --> " << normalize(e, env) << "\n"; } static expr t() { return constant("t"); } static expr lam(expr const & e) { return lambda("_", t(), e); } static expr lam(expr const & t, expr const & e) { return lambda("_", t, e); } @@ -79,63 +79,77 @@ unsigned count(expr const & a) { } static void tst_church_numbers() { + environment env; + env.add_fact("t", type(level())); + env.add_fact("N", type(level())); + env.add_fact("z", constant("N")); + env.add_fact("s", constant("N")); expr N = constant("N"); expr z = constant("z"); expr s = constant("s"); - std::cout << normalize(app(zero(), N, s, z)) << "\n"; - std::cout << normalize(app(one(), N, s, z)) << "\n"; - std::cout << normalize(app(two(), N, s, z)) << "\n"; - std::cout << normalize(app(four(), N, s, z)) << "\n"; - std::cout << count(normalize(app(four(), N, s, z))) << "\n"; - lean_assert(count(normalize(app(four(), N, s, z))) == 4 + 2); - std::cout << normalize(app(app(times(), four(), four()), N, s, z)) << "\n"; - std::cout << normalize(app(app(power(), two(), four()), N, s, z)) << "\n"; - lean_assert(count(normalize(app(app(power(), two(), four()), N, s, z))) == 16 + 2); - std::cout << normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z)) << "\n"; - std::cout << count(normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z))) << "\n"; - std::cout << count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z))) << "\n"; - lean_assert(count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z))) == 64 + 2); - expr big = normalize(app(app(power(), two(), app(power(), two(), three())), N, s, z)); + std::cout << normalize(app(zero(), N, s, z), env) << "\n"; + std::cout << normalize(app(one(), N, s, z), env) << "\n"; + std::cout << normalize(app(two(), N, s, z), env) << "\n"; + std::cout << normalize(app(four(), N, s, z), env) << "\n"; + std::cout << count(normalize(app(four(), N, s, z), env)) << "\n"; + lean_assert(count(normalize(app(four(), N, s, z), env)) == 4 + 2); + std::cout << normalize(app(app(times(), four(), four()), N, s, z), env) << "\n"; + std::cout << normalize(app(app(power(), two(), four()), N, s, z), env) << "\n"; + lean_assert(count(normalize(app(app(power(), two(), four()), N, s, z), env)) == 16 + 2); + std::cout << normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z), env) << "\n"; + std::cout << count(normalize(app(app(times(), two(), app(power(), two(), four())), N, s, z), env)) << "\n"; + std::cout << count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z), env)) << "\n"; + lean_assert(count(normalize(app(app(times(), four(), app(power(), two(), four())), N, s, z), env)) == 64 + 2); + expr big = normalize(app(app(power(), two(), app(power(), two(), three())), N, s, z), env); std::cout << count(big) << "\n"; lean_assert(count(big) == 256 + 2); expr three = app(plus(), two(), one()); - lean_assert(count(normalize(app(app(power(), three, three), N, s, z))) == 27 + 2); - // expr big2 = normalize(app(app(power(), two(), app(times(), app(plus(), four(), one()), four())), N, s, z)); + lean_assert(count(normalize(app(app(power(), three, three), N, s, z), env)) == 27 + 2); + // expr big2 = normalize(app(app(power(), two(), app(times(), app(plus(), four(), one()), four())), N, s, z), env); // std::cout << count(big2) << "\n"; - std::cout << normalize(lam(lam(app(app(times(), four(), four()), N, var(0), z)))) << "\n"; + std::cout << normalize(lam(lam(app(app(times(), four(), four()), N, var(0), z))), env) << "\n"; } static void tst1() { + environment env; + env.add_fact("t", type(level())); + expr t = type(level()); + env.add_fact("f", arrow(t, t)); expr f = constant("f"); + env.add_fact("a", t); expr a = constant("a"); + env.add_fact("b", t); expr b = constant("b"); expr x = var(0); expr y = var(1); - expr t = type(level()); - eval(app(lambda("x", t, x), a)); - eval(app(lambda("x", t, x), a, b)); - eval(lambda("x", t, f(x))); - eval(lambda("y", t, lambda("x", t, f(y, x)))); + eval(app(lambda("x", t, x), a), env); + eval(app(lambda("x", t, x), a, b), env); + eval(lambda("x", t, f(x)), env); + eval(lambda("y", t, lambda("x", t, f(y, x))), env); eval(app(lambda("x", t, app(lambda("f", t, app(var(0), b)), lambda("g", t, f(var(1))))), - a)); + a), env); expr l01 = lam(v(0)(v(1))); expr l12 = lam(lam(v(1)(v(2)))); - eval(lam(l12(l01))); - lean_assert(normalize(lam(l12(l01))) == lam(lam(v(1)(v(1))))); + eval(lam(l12(l01)), env); + lean_assert(normalize(lam(l12(l01)), env) == lam(lam(v(1)(v(1))))); } static void tst2() { environment env; + expr t = type(level()); + env.add_fact("f", arrow(t, t)); expr f = constant("f"); - expr h = constant("h"); + env.add_fact("a", t); expr a = constant("a"); + env.add_fact("b", t); expr b = constant("b"); + env.add_fact("h", arrow(t, t)); + expr h = constant("h"); expr x = var(0); expr y = var(1); - expr t = type(level()); lean_assert(normalize(f(x,x), env, extend(context(), name("f"), t, f(a))) == f(f(a), f(a))); context c1 = extend(extend(context(), name("f"), t, f(a)), name("h"), t, h(x)); expr F1 = normalize(f(x,f(x)), env, c1); @@ -166,6 +180,7 @@ static void tst2() { static void tst3() { environment env; + env.add_fact("a", bool_type()); expr t1 = constant("a"); expr t2 = constant("a"); expr e = eq(t1, t2); @@ -175,13 +190,14 @@ static void tst3() { static void tst4() { environment env; + env.add_fact("b", type(level())); expr t1 = let("a", constant("b"), lambda("c", type(), var(1)(var(0)))); std::cout << t1 << " --> " << normalize(t1, env) << "\n"; lean_assert(normalize(t1, env) == lambda("c", type(), constant("b")(var(0)))); } int main() { - continue_on_violation(true); + // continue_on_violation(true); tst_church_numbers(); tst1(); tst2(); diff --git a/src/tests/kernel/threads.cpp b/src/tests/kernel/threads.cpp index 84ead6bb1..e48b39bdd 100644 --- a/src/tests/kernel/threads.cpp +++ b/src/tests/kernel/threads.cpp @@ -13,11 +13,16 @@ Author: Leonardo de Moura #include "deep_copy.h" #include "abstract.h" #include "normalize.h" +#include "arith.h" #include "test.h" using namespace lean; expr normalize(expr const & e) { environment env; + env.add_fact("a", int_type()); + env.add_fact("b", int_type()); + env.add_fact("f", arrow(int_type(), arrow(int_type(), int_type()))); + env.add_fact("h", arrow(int_type(), arrow(int_type(), int_type()))); return normalize(e, env); } diff --git a/src/util/name.h b/src/util/name.h index 62abcc3a4..524e7f4f8 100644 --- a/src/util/name.h +++ b/src/util/name.h @@ -62,5 +62,6 @@ public: }; friend std::ostream & operator<<(std::ostream & out, sep const & s); }; - +struct name_hash { unsigned operator()(name const & n) const { return n.hash(); } }; +struct name_eq { bool operator()(name const & n1, name const & n2) const { return n1 == n2; } }; }