diff --git a/src/frontends/lean/frontend.cpp b/src/frontends/lean/frontend.cpp index 0fd492fdd..d17f787a5 100644 --- a/src/frontends/lean/frontend.cpp +++ b/src/frontends/lean/frontend.cpp @@ -315,16 +315,16 @@ struct lean_extension : public environment::extension { env.add_neutral_object(new coercion_declaration(f)); } - expr get_coercion(expr const & from_type, expr const & to_type) const { + optional get_coercion(expr const & from_type, expr const & to_type) const { expr_pair p(from_type, to_type); auto it = m_coercion_map.find(p); if (it != m_coercion_map.end()) - return it->second; + return optional(it->second); lean_extension const * parent = get_parent(); if (parent) return parent->get_coercion(from_type, to_type); else - return expr(); + return optional(); } list get_coercions(expr const & from_type) const { @@ -444,7 +444,7 @@ bool frontend::is_explicit(name const & n) const { void frontend::add_coercion(expr const & f) { to_ext(m_env).add_coercion(f, m_env); } -expr frontend::get_coercion(expr const & from_type, expr const & to_type) const { +optional frontend::get_coercion(expr const & from_type, expr const & to_type) const { return to_ext(m_env).get_coercion(from_type, to_type); } list frontend::get_coercions(expr const & from_type) const { diff --git a/src/frontends/lean/frontend.h b/src/frontends/lean/frontend.h index c63a7562f..5cd0ae9f3 100644 --- a/src/frontends/lean/frontend.h +++ b/src/frontends/lean/frontend.h @@ -157,10 +157,8 @@ public: /** \brief Return a coercion from given_type to expected_type if it exists. - Return the null expression if there is no coercion from \c from_type to - \c to_type. */ - expr get_coercion(expr const & from_type, expr const & to_type) const; + optional get_coercion(expr const & from_type, expr const & to_type) const; /** \brief Return the list of coercions for the given type. diff --git a/src/frontends/lean/frontend_elaborator.cpp b/src/frontends/lean/frontend_elaborator.cpp index db6bd41af..f2b9782c9 100644 --- a/src/frontends/lean/frontend_elaborator.cpp +++ b/src/frontends/lean/frontend_elaborator.cpp @@ -63,7 +63,7 @@ public: return r; } virtual void get_children(buffer &) const {} - virtual expr const & get_main_expr() const { return m_src; } + virtual optional get_main_expr() const { return some(m_src); } context const & get_context() const { return m_ctx; } }; @@ -82,7 +82,7 @@ public: return r; } virtual void get_children(buffer &) const {} - virtual expr const & get_main_expr() const { return m_app; } + virtual optional get_main_expr() const { return some(m_app); } context const & get_context() const { return m_ctx; } expr const & get_app() const { return m_app; } }; @@ -133,30 +133,29 @@ class frontend_elaborator::imp { /** \brief Return the type of \c e if possible. - Return null expression if it was not possible to infer the type of \c e. The idea is to use the type to catch the easy cases where we can solve overloads (aka choices) and coercions during preprocessing. */ - expr get_type(expr const & e, context const & ctx) { + optional get_type(expr const & e, context const & ctx) { try { - return m_ref.m_type_inferer(e, ctx); + return some(m_ref.m_type_inferer(e, ctx)); } catch (exception &) { - return expr(); + return optional(); } } /** - \brief Make sure f_t is a Pi, if it is not, then return the null expression. + \brief Make sure f_t is a Pi, if it is not, then return optional() */ - expr check_pi(expr const & f_t, context const & ctx) { - if (!f_t || is_pi(f_t)) { + optional check_pi(optional const & f_t, context const & ctx) { + if (!f_t || is_pi(*f_t)) { return f_t; } else { - expr r = m_ref.m_normalizer(f_t, ctx); + expr r = m_ref.m_normalizer(*f_t, ctx); if (is_pi(r)) - return r; + return optional(r); else - return expr(); + return optional(); } } @@ -174,20 +173,20 @@ class frontend_elaborator::imp { return mk_app(mvar, a); } - expr find_coercion(list const & l, expr const & to_type) { + optional find_coercion(list const & l, expr const & to_type) { for (auto p : l) { if (p.first == to_type) { - return p.second; + return optional(p.second); } } - return expr(); + return optional(); } /** \brief Try to solve overload at preprocessing time. */ - void choose(buffer & f_choices, buffer & f_choice_types, - buffer const & args, buffer const & arg_types, + void choose(buffer & f_choices, buffer> & f_choice_types, + buffer const & args, buffer> const & arg_types, context const & ctx) { unsigned best_num_coercions = std::numeric_limits::max(); unsigned num_choices = f_choices.size(); @@ -196,7 +195,7 @@ class frontend_elaborator::imp { buffer matched; for (unsigned j = 0; j < num_choices; j++) { - expr f_t = f_choice_types[j]; + optional f_t = f_choice_types[j]; unsigned num_coercions = 0; // number of coercions needed by current choice unsigned num_skipped_args = 0; unsigned i = 1; @@ -206,27 +205,28 @@ class frontend_elaborator::imp { // can't process this choice at preprocessing time delayed.push_back(j); break; - } - expr expected = abst_domain(f_t); - expr given = arg_types[i]; - if (!given) { - num_skipped_args++; } else { - if (!has_metavar(expected) && !has_metavar(given)) { - if (m_ref.m_type_checker.is_convertible(given, expected, ctx)) { - // compatible - } else if (m_ref.m_frontend.get_coercion(given, expected)) { - // compatible if using coercion - num_coercions++; - } else { - // failed, this choice does not work - break; - } - } else { + expr expected = abst_domain(*f_t); + optional given = arg_types[i]; + if (!given) { num_skipped_args++; + } else { + if (!has_metavar(expected) && !has_metavar(*given)) { + if (m_ref.m_type_checker.is_convertible(*given, expected, ctx)) { + // compatible + } else if (m_ref.m_frontend.get_coercion(*given, expected)) { + // compatible if using coercion + num_coercions++; + } else { + // failed, this choice does not work + break; + } + } else { + num_skipped_args++; + } } + f_t = some(::lean::instantiate(abst_body(*f_t), args[i])); } - f_t = ::lean::instantiate(abst_body(f_t), args[i]); } if (i == num_args) { if (num_skipped_args > 0) { @@ -250,7 +250,7 @@ class frontend_elaborator::imp { // We currently do nothing, and let the elaborator to sign the error } else { buffer to_keep; - buffer to_keep_types; + buffer> to_keep_types; for (unsigned i : matched) { to_keep.push_back(f_choices[i]); to_keep_types.push_back(f_choice_types[i]); @@ -275,27 +275,27 @@ class frontend_elaborator::imp { virtual expr visit_app(expr const & e, context const & ctx) { expr f = arg(e, 0); - expr f_t; - buffer args; - buffer arg_types; + optional f_t; + buffer args; + buffer> arg_types; args.push_back(expr()); // placeholder - arg_types.push_back(expr()); // placeholder + arg_types.push_back(optional()); // placeholder for (unsigned i = 1; i < num_args(e); i++) { expr a = arg(e, i); expr new_a = visit(a, ctx); - expr new_a_t = get_type(new_a, ctx); + optional new_a_t = get_type(new_a, ctx); args.push_back(new_a); arg_types.push_back(new_a_t); } if (is_choice(f)) { buffer f_choices; - buffer f_choice_types; + buffer> f_choice_types; unsigned num_alts = get_num_choices(f); for (unsigned i = 0; i < num_alts; i++) { expr c = get_choice(f, i); expr new_c = visit(c, ctx); - expr new_c_t = get_type(new_c, ctx); + optional new_c_t = get_type(new_c, ctx); f_choices.push_back(new_c); f_choice_types.push_back(new_c_t); } @@ -304,9 +304,9 @@ class frontend_elaborator::imp { args[0] = mk_overload_mvar(f_choices, ctx, e); for (unsigned i = 1; i < args.size(); i++) { if (arg_types[i]) { - list coercions = m_ref.m_frontend.get_coercions(arg_types[i]); + list coercions = m_ref.m_frontend.get_coercions(*(arg_types[i])); if (coercions) - args[i] = add_coercion_mvar_app(coercions, args[i], arg_types[i], ctx, arg(e, i)); + args[i] = add_coercion_mvar_app(coercions, args[i], *(arg_types[i]), ctx, arg(e, i)); } } return mk_app(args); @@ -326,20 +326,20 @@ class frontend_elaborator::imp { f_t = check_pi(f_t, ctx); expr a = arg(e, i); expr new_a = args[i]; - expr new_a_t = arg_types[i]; + optional new_a_t = arg_types[i]; if (new_a_t) { - list coercions = m_ref.m_frontend.get_coercions(new_a_t); + list coercions = m_ref.m_frontend.get_coercions(*new_a_t); if (coercions) { if (!f_t) { - new_a = add_coercion_mvar_app(coercions, new_a, new_a_t, ctx, a); + new_a = add_coercion_mvar_app(coercions, new_a, *new_a_t, ctx, a); } else { - expr expected = abst_domain(f_t); - if (expected != new_a_t) { - expr c = find_coercion(coercions, expected); + expr expected = abst_domain(*f_t); + if (expected != *new_a_t) { + optional c = find_coercion(coercions, expected); if (c) { - new_a = mk_app(c, new_a); // apply coercion + new_a = mk_app(*c, new_a); // apply coercion } else { - new_a = add_coercion_mvar_app(coercions, new_a, new_a_t, ctx, a); + new_a = add_coercion_mvar_app(coercions, new_a, *new_a_t, ctx, a); } } } @@ -347,34 +347,37 @@ class frontend_elaborator::imp { } new_args.push_back(new_a); if (f_t) - f_t = ::lean::instantiate(abst_body(f_t), new_a); + f_t = some(::lean::instantiate(abst_body(*f_t), new_a)); } return mk_app(new_args); } virtual expr visit_let(expr const & e, context const & ctx) { lean_assert(is_let(e)); - return update_let(e, [&](expr const & t, expr const & v, expr const & b) { - expr new_t = t ? visit(t, ctx) : expr(); + return update_let(e, [&](optional const & t, expr const & v, expr const & b) { + optional new_t = visit(t, ctx); expr new_v = visit(v, ctx); if (new_t) { - expr new_v_t = get_type(new_v, ctx); - if (new_v_t && new_t != new_v_t) { - list coercions = m_ref.m_frontend.get_coercions(new_v_t); + optional new_v_t = get_type(new_v, ctx); + if (new_v_t && *new_t != *new_v_t) { + list coercions = m_ref.m_frontend.get_coercions(*new_v_t); if (coercions) { - new_v = add_coercion_mvar_app(coercions, new_v, new_v_t, ctx, v); + new_v = add_coercion_mvar_app(coercions, new_v, *new_v_t, ctx, v); } } } - expr new_b; { cache::mk_scope sc(m_cache); - new_b = visit(b, extend(ctx, let_name(e), new_t, new_v)); + expr new_b = visit(b, extend(ctx, let_name(e), new_t, new_v)); + return std::make_tuple(new_t, new_v, new_b); } - return std::make_tuple(new_t, new_v, new_b); }); } + optional visit(optional const & e, context const & ctx) { + return replace_visitor::visit(e, ctx); + } + virtual expr visit(expr const & e, context const & ctx) { check_interrupted(); expr r = replace_visitor::visit(e, ctx); diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index fde2da8a2..42caf0de6 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -324,9 +324,9 @@ class parser::imp { void display_error_pos(pos_info const & p) { display_error_pos(p.first, p.second); } - void display_error_pos(expr const & e) { + void display_error_pos(optional const & e) { if (e) { - auto it = m_expr_pos_info.find(e); + auto it = m_expr_pos_info.find(*e); if (it == m_expr_pos_info.end()) return display_error_pos(m_last_cmd_pos); else @@ -346,7 +346,11 @@ class parser::imp { } void display_error(kernel_exception const & ex) { - display_error_pos(m_elaborator.get_original(ex.get_main_expr())); + optional main_expr = ex.get_main_expr(); + if (main_expr) + display_error_pos(some(m_elaborator.get_original(*main_expr))); + else + display_error_pos(main_expr); regular(m_frontend) << " " << ex << endl; } @@ -912,7 +916,7 @@ class parser::imp { void parse_simple_bindings(bindings_buffer & result, bool implicit_decl, bool suppress_type) { buffer> names; parse_names(names); - expr type; + optional type; if (!suppress_type) { check_colon_next("invalid binder, ':' expected"); type = parse_expr(); @@ -928,7 +932,7 @@ class parser::imp { --i; expr arg_type; if (type) - arg_type = lift_free_vars(type, i); + arg_type = lift_free_vars(*type, i); else arg_type = save(mk_placeholder(), names[i].first); result[sz + i] = std::make_tuple(names[i].first, names[i].second, arg_type, implicit_decl); @@ -1063,11 +1067,11 @@ class parser::imp { expr parse_let() { next(); mk_scope scope(*this); - buffer> bindings; + buffer, expr>> bindings; while (true) { auto p = pos(); name id = check_identifier_next("invalid let expression, identifier expected"); - expr type; + optional type; if (curr_is_colon()) { next(); type = parse_expr(); @@ -1181,7 +1185,7 @@ class parser::imp { expr t = parse_expr(); check_next(scanner::token::By, "invalid 'show _ by _' expression, 'by' expected"); tactic tac = parse_tactic_expr(); - expr r = mk_placeholder(t); + expr r = mk_placeholder(optional(t)); m_tactic_hints[r] = tac; return save(r, p); } @@ -1440,7 +1444,7 @@ class parser::imp { */ expr parse_tactic_cmds(proof_state s, context const & ctx, expr const & expected_type) { proof_state_seq_stack stack; - expr pr; + optional pr; enum class status { Continue, Done, Eof, Abort }; status st = status::Continue; while (st == status::Continue) { @@ -1497,7 +1501,7 @@ class parser::imp { }); } switch (st) { - case status::Done: return pr; + case status::Done: return *pr; case status::Eof: throw parser_error("invalid tactic command, unexpected end of file", pos()); case status::Abort: throw parser_error("failed to prove theorem, proof has been aborted", pos()); default: lean_unreachable(); // LCOV_EXCL_LINE @@ -1542,9 +1546,14 @@ class parser::imp { throw exception("failed to synthesize metavar, its type contains metavariables"); buffer new_entries; for (auto e : menv.get_context(mvar)) { - new_entries.emplace_back(e.get_name(), - instantiate_metavars(e.get_domain(), menv), - instantiate_metavars(e.get_body(), menv)); + optional d = e.get_domain(); + optional b = e.get_body(); + if (d) d = instantiate_metavars(*d, menv); + if (b) b = instantiate_metavars(*b, menv); + if (d) + new_entries.emplace_back(e.get_name(), *d, b); + else + new_entries.emplace_back(e.get_name(), d, *b); } context mvar_ctx(to_list(new_entries.begin(), new_entries.end())); if (!m_type_inferer.is_proposition(mvar_type, mvar_ctx)) @@ -1562,8 +1571,7 @@ class parser::imp { next(); } expr mvar_val = parse_tactic_cmds(s, mvar_ctx, mvar_type); - if (mvar_val) - menv.assign(mvar, mvar_val); + menv.assign(mvar, mvar_val); } } return instantiate_metavars(val, menv); diff --git a/src/frontends/lean/pp.cpp b/src/frontends/lean/pp.cpp index de71bd16d..d215c87f2 100644 --- a/src/frontends/lean/pp.cpp +++ b/src/frontends/lean/pp.cpp @@ -725,14 +725,14 @@ class pp_fn { \remark The argument B is only relevant when processing condensed definitions. \see pp_abstraction_core. */ - std::pair collect_nested(expr const & e, expr T, expr_kind k, buffer> & r) { - if (e.kind() == k && (!T || is_abstraction(T))) { + std::pair> collect_nested(expr const & e, optional T, expr_kind k, buffer> & r) { + if (e.kind() == k && (!T || is_abstraction(*T))) { name n1 = get_unused_name(e); m_local_names.insert(n1); r.emplace_back(n1, abst_domain(e)); expr b = replace_var_with_name(abst_body(e), n1); if (T) - T = replace_var_with_name(abst_body(T), n1); + T = some(replace_var_with_name(abst_body(*T), n1)); return collect_nested(b, T, k, r); } else { return mk_pair(e, T); @@ -849,7 +849,7 @@ class pp_fn { \remark if T != 0, then T is Pi(x : A), B */ - result pp_abstraction_core(expr const & e, unsigned depth, expr T, std::vector const * implicit_args = nullptr) { + result pp_abstraction_core(expr const & e, unsigned depth, optional T, std::vector const * implicit_args = nullptr) { local_names::mk_scope mk(m_local_names); if (is_arrow(e) && !implicit_args) { lean_assert(!T); @@ -876,7 +876,7 @@ class pp_fn { } format body_sep; if (T) { - format T_f = pp(T, 0).first; + format T_f = pp(*T, 0).first; body_sep = format{space(), colon(), space(), T_f, space(), g_assign_fmt}; } else if (implicit_args) { // This is a little hack to pretty print Variable and @@ -957,10 +957,10 @@ class pp_fn { } result pp_abstraction(expr const & e, unsigned depth) { - return pp_abstraction_core(e, depth, expr()); + return pp_abstraction_core(e, depth, optional()); } - expr collect_nested_let(expr const & e, buffer> & bindings) { + expr collect_nested_let(expr const & e, buffer, expr>> & bindings) { if (is_let(e)) { name n1 = get_unused_name(e); m_local_names.insert(n1); @@ -974,7 +974,7 @@ class pp_fn { result pp_let(expr const & e, unsigned depth) { local_names::mk_scope mk(m_local_names); - buffer> bindings; + buffer, expr>> bindings; expr body = collect_nested_let(e, bindings); unsigned r_weight = 2; format r_format = g_let_fmt; @@ -985,9 +985,9 @@ class pp_fn { format beg = i == 0 ? space() : line(); format sep = i < sz - 1 ? comma() : format(); result p_def = pp_scoped_child(std::get<2>(b), depth+1); - expr type = std::get<1>(b); + optional const & type = std::get<1>(b); if (type) { - result p_type = pp_scoped_child(type, depth+1); + result p_type = pp_scoped_child(*type, depth+1); r_format += nest(3 + 1, compose(beg, group(format{format(n), space(), colon(), nest(n.size() + 1 + 1 + 1, compose(space(), p_type.first)), space(), g_assign_fmt, nest(m_indent, format{line(), p_def.first, sep})}))); @@ -1166,13 +1166,12 @@ public: format pp_definition(expr const & v, expr const & t, std::vector const * implicit_args) { init(mk_app(v, t)); - expr T(t); - return pp_abstraction_core(v, 0, T, implicit_args).first; + return pp_abstraction_core(v, 0, optional(t), implicit_args).first; } format pp_pi_with_implicit_args(expr const & e, std::vector const & implicit_args) { init(e); - return pp_abstraction_core(e, 0, expr(), &implicit_args).first; + return pp_abstraction_core(e, 0, optional(), &implicit_args).first; } void register_local(name const & n) { @@ -1198,10 +1197,13 @@ class pp_formatter_cell : public formatter_cell { check_interrupted(); name n1 = get_unused_name(c2); fn.register_local(n1); - format entry = format{format(n1), space(), colon(), space(), fn(fake_context_domain(c2))}; - expr val = fake_context_value(c2); + format entry = format(n1); + optional domain = fake_context_domain(c2); + optional val = fake_context_value(c2); + if (domain) + entry += format{space(), colon(), space(), fn(*domain)}; if (val) - entry += format{space(), g_assign_fmt, nest(indent, format{line(), fn(val)})}; + entry += format{space(), g_assign_fmt, nest(indent, format{line(), fn(*val)})}; if (first) { r = group(entry); first = false; diff --git a/src/kernel/abstract.h b/src/kernel/abstract.h index b04b5b902..587ff0ad1 100644 --- a/src/kernel/abstract.h +++ b/src/kernel/abstract.h @@ -48,8 +48,8 @@ inline expr Pi(std::pair const & p, expr const & b) /** \brief Create a Let expression (Let x := v in b), the term b is abstracted using abstract(b, x). */ -inline expr Let(name const & x, expr const & v, expr const & b) { return mk_let(x, expr(), v, abstract(b, mk_constant(x))); } -inline expr Let(expr const & x, expr const & v, expr const & b) { return mk_let(const_name(x), expr(), v, abstract(b, x)); } +inline expr Let(name const & x, expr const & v, expr const & b) { return mk_let(x, v, abstract(b, mk_constant(x))); } +inline expr Let(expr const & x, expr const & v, expr const & b) { return mk_let(const_name(x), v, abstract(b, x)); } inline expr Let(std::pair const & p, expr const & b) { return Let(p.first, p.second, b); } expr Let(std::initializer_list> const & l, expr const & b); /** diff --git a/src/kernel/builtin.cpp b/src/kernel/builtin.cpp index 2fc764160..39144c71d 100644 --- a/src/kernel/builtin.cpp +++ b/src/kernel/builtin.cpp @@ -126,27 +126,26 @@ static format g_if_fmt(g_if_name); */ class if_fn_value : public value { expr m_type; -public: - if_fn_value() { + static expr mk_type() { expr A = Const("A"); // Pi (A: Type), bool -> A -> A -> A - m_type = Pi({A, TypeU}, Bool >> (A >> (A >> A))); + return Pi({A, TypeU}, Bool >> (A >> (A >> A))); } +public: + if_fn_value():m_type(mk_type()) {} virtual ~if_fn_value() {} virtual expr get_type() const { return m_type; } virtual name get_name() const { return g_if_name; } - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 5 && is_bool_value(args[2])) { if (to_bool(args[2])) - r = args[3]; // if A true a b --> a + return some(args[3]); // if A true a b --> a else - r = args[4]; // if A false a b --> b - return true; + return some(args[4]); // if A false a b --> b } else if (num_args == 5 && args[3] == args[4]) { - r = args[3]; // if A c a a --> a - return true; + return some(args[3]); // if A c a a --> a } else { - return false; + return optional(); } } }; diff --git a/src/kernel/context.h b/src/kernel/context.h index a3e14b82a..803efdf49 100644 --- a/src/kernel/context.h +++ b/src/kernel/context.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include #include "util/list.h" +#include "util/optional.h" #include "kernel/expr.h" namespace lean { @@ -15,15 +16,17 @@ namespace lean { \see context */ class context_entry { - name m_name; - expr m_domain; - expr m_body; + name m_name; + optional m_domain; + optional m_body; public: + context_entry(name const & n, optional const & d, expr const & b):m_name(n), m_domain(d), m_body(b) {} + context_entry(name const & n, expr const & d, optional const & b):m_name(n), m_domain(d), m_body(b) {} context_entry(name const & n, expr const & d, expr const & b):m_name(n), m_domain(d), m_body(b) {} context_entry(name const & n, expr const & d):m_name(n), m_domain(d) {} name const & get_name() const { return m_name; } - expr const & get_domain() const { return m_domain; } - expr const & get_body() const { return m_body; } + optional const & get_domain() const { return m_domain; } + optional const & get_body() const { return m_body; } }; /** @@ -33,8 +36,11 @@ class context { list m_list; public: context() {} - context(context const & c, name const & n, expr const & d):m_list(context_entry(n, d), c.m_list) {} + context(context const & c, name const & n, optional const & d, expr const & b):m_list(context_entry(n, d, b), c.m_list) {} + context(context const & c, name const & n, expr const & d, optional const & b):m_list(context_entry(n, d, b), c.m_list) {} context(context const & c, name const & n, expr const & d, expr const & b):m_list(context_entry(n, d, b), c.m_list) {} + context(context const & c, name const & n, expr const & d):m_list(context_entry(n, d), c.m_list) {} + context(context const & c, context_entry const & e):m_list(e, c.m_list) {} explicit context(list const & l):m_list(l) {} context_entry const & lookup(unsigned vidx) const; std::pair lookup_ext(unsigned vidx) const; @@ -69,6 +75,8 @@ inline std::pair lookup_ext(context const & c, u Bruijn index \c i. */ inline context_entry const & lookup(context const & c, unsigned i) { return c.lookup(i); } +inline context extend(context const & c, name const & n, optional const & d, expr const & b) { return context(c, n, d, b); } +inline context extend(context const & c, name const & n, expr const & d, optional const & b) { return context(c, n, d, b); } inline context extend(context const & c, name const & n, expr const & d, expr const & b) { return context(c, n, d, b); } inline context extend(context const & c, name const & n, expr const & d) { return context(c, n, d); } inline bool empty(context const & c) { return c.empty(); } diff --git a/src/kernel/environment.cpp b/src/kernel/environment.cpp index 58c4b3d8b..43293cb33 100644 --- a/src/kernel/environment.cpp +++ b/src/kernel/environment.cpp @@ -159,10 +159,11 @@ struct environment::imp { object const & get_object(name const & n, environment const & env) const { object const & obj = get_object_core(n); - if (obj) + if (obj) { return obj; - else + } else { throw unknown_object_exception(env, n); + } } /** diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index ceb92a2cf..7f0c7292d 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -15,6 +15,9 @@ Author: Leonardo de Moura #include "kernel/metavar.h" namespace lean { +static expr g_dummy(mk_var(0)); +expr::expr():expr(g_dummy) {} + local_entry::local_entry(unsigned s, unsigned n):m_kind(local_entry_kind::Lift), m_s(s), m_n(n) {} local_entry::local_entry(unsigned s, expr const & v):m_kind(local_entry_kind::Inst), m_s(s), m_v(v) {} local_entry::~local_entry() {} @@ -31,13 +34,6 @@ unsigned hash_args(unsigned size, expr const * args) { return hash(size, [&args](unsigned i){ return args[i].hash(); }); } -static expr g_null; - -expr const & expr::null() { - lean_assert(!g_null); - return g_null; -} - expr_cell::expr_cell(expr_kind k, unsigned h, bool has_mv): m_kind(static_cast(k)), m_flags(has_mv ? 4 : 0), @@ -55,20 +51,25 @@ expr_cell::expr_cell(expr_kind k, unsigned h, bool has_mv): } void expr_cell::dec_ref(expr & e, buffer & todelete) { - if (e) { + if (e.m_ptr) { expr_cell * c = e.steal_ptr(); - lean_assert(!e); + lean_assert(!(e.m_ptr)); if (c->dec_ref_core()) todelete.push_back(c); } } +void expr_cell::dec_ref(optional & c, buffer & todelete) { + if (c) + dec_ref(*c, todelete); +} + expr_var::expr_var(unsigned idx): expr_cell(expr_kind::Var, idx, false), m_vidx(idx) {} -expr_const::expr_const(name const & n, expr const & t): - expr_cell(expr_kind::Constant, n.hash(), t && t.has_metavar()), +expr_const::expr_const(name const & n, optional const & t): + expr_cell(expr_kind::Constant, n.hash(), t && t->has_metavar()), m_name(n), m_type(t) {} void expr_const::dealloc(buffer & todelete) { @@ -86,7 +87,6 @@ void expr_app::dealloc(buffer & todelete) { while (i > 0) { --i; dec_ref(m_args[i], todelete); - lean_assert(!m_args[i]); } delete[] reinterpret_cast(this); } @@ -140,8 +140,6 @@ expr_abstraction::expr_abstraction(expr_kind k, name const & n, expr const & t, void expr_abstraction::dealloc(buffer & todelete) { dec_ref(m_body, todelete); dec_ref(m_domain, todelete); - lean_assert(!m_body); - lean_assert(!m_domain); delete(this); } expr_lambda::expr_lambda(name const & n, expr const & t, expr const & e):expr_abstraction(expr_kind::Lambda, n, t, e) {} @@ -151,8 +149,8 @@ expr_type::expr_type(level const & l): m_level(l) { } expr_type::~expr_type() {} -expr_let::expr_let(name const & n, expr const & t, expr const & v, expr const & b): - expr_cell(expr_kind::Let, ::lean::hash(v.hash(), b.hash()), v.has_metavar() || b.has_metavar() || (t && t.has_metavar())), +expr_let::expr_let(name const & n, optional const & t, expr const & v, expr const & b): + expr_cell(expr_kind::Let, ::lean::hash(v.hash(), b.hash()), v.has_metavar() || b.has_metavar() || (t && t->has_metavar())), m_name(n), m_type(t), m_value(v), @@ -166,7 +164,7 @@ void expr_let::dealloc(buffer & todelete) { } expr_let::~expr_let() {} name value::get_unicode_name() const { return get_name(); } -bool value::normalize(unsigned, expr const *, expr &) const { return false; } +optional value::normalize(unsigned, expr const *) const { return optional(); } void value::display(std::ostream & out) const { out << get_name(); } bool value::operator==(value const & other) const { return typeid(*this) == typeid(other); } bool value::operator<(value const & other) const { diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 7fe13d27a..828bc0bca 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -17,6 +17,7 @@ Author: Leonardo de Moura #include "util/hash.h" #include "util/buffer.h" #include "util/list_fn.h" +#include "util/optional.h" #include "util/sexpr/format.h" #include "kernel/level.h" @@ -85,6 +86,7 @@ protected: void set_closed() { m_flags |= 2; } friend class has_free_var_fn; static void dec_ref(expr & c, buffer & todelete); + static void dec_ref(optional & c, buffer & todelete); public: expr_cell(expr_kind k, unsigned h, bool has_mv); expr_kind kind() const { return static_cast(m_kind); } @@ -102,13 +104,11 @@ private: friend class expr_cell; expr_cell * steal_ptr() { expr_cell * r = m_ptr; m_ptr = nullptr; return r; } public: - expr():m_ptr(nullptr) {} + expr(); expr(expr const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } expr(expr && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; } ~expr() { if (m_ptr) m_ptr->dec_ref(); } - static expr const & null(); - friend void swap(expr & a, expr & b) { std::swap(a.m_ptr, b.m_ptr); } void release() { if (m_ptr) m_ptr->dec_ref(); m_ptr = nullptr; } @@ -123,20 +123,21 @@ public: expr_cell * raw() const { return m_ptr; } - explicit operator bool() const { return m_ptr != nullptr; } - friend expr mk_var(unsigned idx); - friend expr mk_constant(name const & n, expr const & t); + friend expr mk_constant(name const & n, optional const & t); friend expr mk_value(value & v); friend expr mk_app(unsigned num_args, expr const * args); friend expr mk_eq(expr const & l, expr const & r); friend expr mk_lambda(name const & n, expr const & t, expr const & e); friend expr mk_pi(name const & n, expr const & t, expr const & e); friend expr mk_type(level const & l); - friend expr mk_let(name const & n, expr const & t, expr const & v, expr const & e); + friend expr mk_let(name const & n, optional const & t, expr const & v, expr const & e); friend expr mk_metavar(name const & n, local_context const & ctx); friend bool is_eqp(expr const & a, expr const & b) { return a.m_ptr == b.m_ptr; } + friend bool is_eqp(optional const & a, optional const & b) { + return static_cast(a) == static_cast(b) && (!a || is_eqp(*a, *b)); + } // Overloaded operator() can be used to create applications expr operator()(expr const & a1) const; @@ -160,16 +161,16 @@ public: }; /** \brief Constants. */ class expr_const : public expr_cell { - name m_name; - expr m_type; // (optional) cached type + name m_name; + optional m_type; // Remark: we do *not* perform destructive updates on m_type // This field is used to efficiently implement the tactic framework friend class expr_cell; void dealloc(buffer & to_delete); public: - expr_const(name const & n, expr const & type); + expr_const(name const & n, optional const & type); name const & get_name() const { return m_name; } - expr const & get_type() const { return m_type; } + optional const & get_type() const { return m_type; } }; /** \brief Function Applications */ class expr_app : public expr_cell { @@ -223,17 +224,17 @@ public: }; /** \brief Let expressions */ class expr_let : public expr_cell { - name m_name; - expr m_type; - expr m_value; - expr m_body; + name m_name; + optional m_type; + expr m_value; + expr m_body; friend class expr_cell; void dealloc(buffer & todelete); public: - expr_let(name const & n, expr const & t, expr const & v, expr const & b); + expr_let(name const & n, optional const & t, expr const & v, expr const & b); ~expr_let(); name const & get_name() const { return m_name; } - expr const & get_type() const { return m_type; } + optional const & get_type() const { return m_type; } expr const & get_value() const { return m_value; } expr const & get_body() const { return m_body; } }; @@ -262,7 +263,7 @@ public: virtual expr get_type() const = 0; virtual name get_name() const = 0; virtual name get_unicode_name() const; - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const; + virtual optional normalize(unsigned num_args, expr const * args) const; virtual bool operator==(value const & other) const; bool operator<(value const & other) const; virtual void display(std::ostream & out) const; @@ -320,7 +321,7 @@ class local_entry { local_entry_kind m_kind; unsigned m_s; unsigned m_n; - expr m_v; + optional m_v; local_entry(unsigned s, unsigned n); local_entry(unsigned s, expr const & v); public: @@ -334,7 +335,7 @@ public: unsigned n() const { lean_assert(is_lift()); return m_n; } bool operator==(local_entry const & e) const; bool operator!=(local_entry const & e) const { return !operator==(e); } - expr const & v() const { lean_assert(is_inst()); return m_v; } + expr const & v() const { lean_assert(is_inst()); return *m_v; } }; inline local_entry mk_lift(unsigned s, unsigned n) { return local_entry(s, n); } inline local_entry mk_inst(unsigned s, expr const & v) { return local_entry(s, v); } @@ -383,7 +384,9 @@ inline bool is_abstraction(expr const & e) { return is_lambda(e) || is_pi(e); } // Constructors inline expr mk_var(unsigned idx) { return expr(new expr_var(idx)); } inline expr Var(unsigned idx) { return mk_var(idx); } -inline expr mk_constant(name const & n, expr const & t = expr()) { return expr(new expr_const(n, t)); } +inline expr mk_constant(name const & n, optional const & t) { return expr(new expr_const(n, t)); } +inline expr mk_constant(name const & n, expr const & t) { return mk_constant(n, optional(t)); } +inline expr mk_constant(name const & n) { return mk_constant(n, optional()); } inline expr Const(name const & n) { return mk_constant(n); } inline expr mk_value(value & v) { return expr(new expr_value(v)); } inline expr to_expr(value & v) { return mk_value(v); } @@ -400,7 +403,9 @@ inline expr mk_lambda(name const & n, expr const & t, expr const & e) { return e inline expr mk_pi(name const & n, expr const & t, expr const & e) { return expr(new expr_pi(n, t, e)); } inline expr mk_arrow(expr const & t, expr const & e) { return mk_pi(name("_"), t, e); } inline expr operator>>(expr const & t, expr const & e) { return mk_arrow(t, e); } -inline expr mk_let(name const & n, expr const & t, expr const & v, expr const & e) { return expr(new expr_let(n, t, v, e)); } +inline expr mk_let(name const & n, optional const & t, expr const & v, expr const & e) { return expr(new expr_let(n, t, v, e)); } +inline expr mk_let(name const & n, expr const & t, expr const & v, expr const & e) { return mk_let(n, optional(t), v, e); } +inline expr mk_let(name const & n, expr const & v, expr const & e) { return mk_let(n, optional(), v, e); } inline expr mk_type(level const & l) { return expr(new expr_type(l)); } expr mk_type(); inline expr Type(level const & l) { return mk_type(l); } @@ -453,7 +458,7 @@ inline unsigned var_idx(expr_cell * e) { return to_var(e)->get inline bool is_var(expr_cell * e, unsigned i) { return is_var(e) && var_idx(e) == i; } inline name const & const_name(expr_cell * e) { return to_constant(e)->get_name(); } // Remark: the following function should not be exposed in the internal API. -inline expr const & const_type(expr_cell * e) { return to_constant(e)->get_type(); } +inline optional const & const_type(expr_cell * e) { return to_constant(e)->get_type(); } inline value const & to_value(expr_cell * e) { lean_assert(is_value(e)); return static_cast(e)->get_value(); } inline unsigned num_args(expr_cell * e) { return to_app(e)->get_num_args(); } inline expr const & arg(expr_cell * e, unsigned idx) { return to_app(e)->get_arg(idx); } @@ -465,7 +470,7 @@ inline expr const & abst_body(expr_cell * e) { return to_abstraction inline level const & ty_level(expr_cell * e) { return to_type(e)->get_level(); } inline name const & let_name(expr_cell * e) { return to_let(e)->get_name(); } inline expr const & let_value(expr_cell * e) { return to_let(e)->get_value(); } -inline expr const & let_type(expr_cell * e) { return to_let(e)->get_type(); } +inline optional const & let_type(expr_cell * e) { return to_let(e)->get_type(); } inline expr const & let_body(expr_cell * e) { return to_let(e)->get_body(); } inline name const & metavar_name(expr_cell * e) { return to_metavar(e)->get_name(); } inline local_context const & metavar_lctx(expr_cell * e) { return to_metavar(e)->get_lctx(); } @@ -480,7 +485,7 @@ inline unsigned var_idx(expr const & e) { return to_var(e)->ge inline bool is_var(expr const & e, unsigned i) { return is_var(e) && var_idx(e) == i; } inline name const & const_name(expr const & e) { return to_constant(e)->get_name(); } // Remark: the following function should not be exposed in the internal API. -inline expr const & const_type(expr const & e) { return to_constant(e)->get_type(); } +inline optional const & const_type(expr const & e) { return to_constant(e)->get_type(); } /** \brief Return true iff the given expression is a constant with name \c n. */ inline bool is_constant(expr const & e, name const & n) { return is_constant(e) && const_name(e) == n; @@ -497,7 +502,7 @@ inline expr const & abst_domain(expr const & e) { return to_abstractio inline expr const & abst_body(expr const & e) { return to_abstraction(e)->get_body(); } inline level const & ty_level(expr const & e) { return to_type(e)->get_level(); } inline name const & let_name(expr const & e) { return to_let(e)->get_name(); } -inline expr const & let_type(expr const & e) { return to_let(e)->get_type(); } +inline optional const & let_type(expr const & e) { return to_let(e)->get_type(); } inline expr const & let_value(expr const & e) { return to_let(e)->get_value(); } inline expr const & let_body(expr const & e) { return to_let(e)->get_body(); } inline name const & metavar_name(expr const & e) { return to_metavar(e)->get_name(); } @@ -604,15 +609,15 @@ template expr update_abst(expr const & e, F f) { } } template expr update_let(expr const & e, F f) { - static_assert(std::is_same::type, - std::tuple>::value, - "update_let: return type of f is not pair"); - expr const & old_t = let_type(e); - expr const & old_v = let_value(e); - expr const & old_b = let_body(e); - std::tuple t = f(old_t, old_v, old_b); - if (!is_eqp(std::get<0>(t), old_t) || !is_eqp(std::get<1>(t), old_v) || !is_eqp(std::get<2>(t), old_b)) - return mk_let(let_name(e), std::get<0>(t), std::get<1>(t), std::get<2>(t)); + static_assert(std::is_same const &, expr const &, expr const &)>::type, + std::tuple, expr, expr>>::value, + "update_let: return type of f is not tuple, expr, expr>"); + optional const & old_t = let_type(e); + expr const & old_v = let_value(e); + expr const & old_b = let_body(e); + std::tuple, expr, expr> r = f(old_t, old_v, old_b); + if (!is_eqp(std::get<0>(r), old_t) || !is_eqp(std::get<1>(r), old_v) || !is_eqp(std::get<2>(r), old_b)) + return mk_let(let_name(e), std::get<0>(r), std::get<1>(r), std::get<2>(r)); else return e; } @@ -661,7 +666,7 @@ inline expr update_metavar(expr const & e, local_context const & lctx) { else return e; } -inline expr update_const(expr const & e, expr const & t) { +inline expr update_const(expr const & e, optional const & t) { if (!is_eqp(const_type(e), t)) return mk_constant(const_name(e), t); else diff --git a/src/kernel/expr_eq.h b/src/kernel/expr_eq.h index d036b5c9c..32a3870b1 100644 --- a/src/kernel/expr_eq.h +++ b/src/kernel/expr_eq.h @@ -29,10 +29,18 @@ class expr_eq_fn { std::unique_ptr m_eq_visited; N m_norm; + bool apply(optional const & a0, optional const & b0) { + if (is_eqp(a0, b0)) + return true; + else if (!a0 || !b0) + return false; + else + return apply(*a0, *b0); + } + bool apply(expr const & a0, expr const & b0) { check_system("expression equality test"); if (is_eqp(a0, b0)) return true; - if (!a0 || !b0) return false; if (UseHash && a0.hash() != b0.hash()) return false; expr const & a = m_norm(a0); expr const & b = m_norm(b0); diff --git a/src/kernel/find_fn.h b/src/kernel/find_fn.h index 37302a115..6071d9bf7 100644 --- a/src/kernel/find_fn.h +++ b/src/kernel/find_fn.h @@ -12,9 +12,9 @@ namespace lean { template class find_fn { struct pred_fn { - expr & m_result; + optional & m_result; P m_p; - pred_fn(expr & result, P const & p):m_result(result), m_p(p) {} + pred_fn(optional & result, P const & p):m_result(result), m_p(p) {} bool operator()(expr const & e, unsigned) { if (m_result) { return false; // already found result, stop the search @@ -26,19 +26,18 @@ class find_fn { } } }; - expr m_result; + optional m_result; for_each_fn m_proc; public: find_fn(P const & p):m_proc(pred_fn(m_result, p)) {} - expr operator()(expr const & e) { m_proc(e); return m_result; } + optional operator()(expr const & e) { m_proc(e); return m_result; } }; /** \brief Return a subexpression of \c e that satisfies the predicate \c p. - If there is none, then return the null expression. */ template -expr find(expr const & e, P p) { +optional find(expr const & e, P p) { return find_fn

(p)(e); } @@ -46,22 +45,26 @@ expr find(expr const & e, P p) { \brief Return an expression \c e that satisfies \c p and occurs in \c c or \c es. */ template -expr find(context const * c, unsigned sz, expr const * es, P p) { +optional find(context const * c, unsigned sz, expr const * es, P p) { find_fn

finder(p); if (c) { for (auto const & e : *c) { - if (expr r = finder(e.get_domain())) { - return r; - } else if (e.get_body()) { - if (expr r = finder(e.get_body())) + auto const & d = e.get_domain(); + if (d) { + if (optional r = finder(*d)) + return r; + } + auto const & b = e.get_body(); + if (b) { + if (optional r = finder(*b)) return r; } } } for (unsigned i = 0; i < sz; i++) { - if (expr r = finder(es[i])) + if (optional r = finder(es[i])) return r; } - return expr(); + return optional(); } } diff --git a/src/kernel/for_each_fn.h b/src/kernel/for_each_fn.h index 8a6b02efd..d584a9456 100644 --- a/src/kernel/for_each_fn.h +++ b/src/kernel/for_each_fn.h @@ -71,7 +71,7 @@ class for_each_fn { switch (e.kind()) { case expr_kind::Constant: if (const_type(e)) - todo.emplace_back(const_type(e), offset); + todo.emplace_back(*const_type(e), offset); goto begin_loop; case expr_kind::Type: case expr_kind::Value: case expr_kind::Var: case expr_kind::MetaVar: @@ -98,7 +98,7 @@ class for_each_fn { todo.emplace_back(let_body(e), offset + 1); todo.emplace_back(let_value(e), offset); if (let_type(e)) - todo.emplace_back(let_type(e), offset); + todo.emplace_back(*let_type(e), offset); goto begin_loop; } } diff --git a/src/kernel/free_vars.cpp b/src/kernel/free_vars.cpp index 3738b780d..5ee3e87a0 100644 --- a/src/kernel/free_vars.cpp +++ b/src/kernel/free_vars.cpp @@ -21,6 +21,10 @@ class has_free_vars_fn { protected: expr_cell_offset_set m_cached; + bool apply(optional const & e, unsigned offset) { + return e && apply(*e, offset); + } + bool apply(expr const & e, unsigned offset) { // handle easy cases switch (e.kind()) { @@ -80,7 +84,7 @@ protected: result = apply(abst_domain(e), offset) || apply(abst_body(e), offset + 1); break; case expr_kind::Let: - result = (let_type(e) && apply(let_type(e), offset)) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); + result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); break; } @@ -99,7 +103,7 @@ public: }; bool has_free_vars(expr const & e) { - return e && has_free_vars_fn()(e); + return has_free_vars_fn()(e); } /** @@ -114,6 +118,10 @@ protected: unsigned m_high; expr_cell_offset_set m_cached; + bool apply(optional const & e, unsigned offset) { + return e && apply(*e, offset); + } + bool apply(expr const & e, unsigned offset) { // handle easy cases switch (e.kind()) { @@ -163,7 +171,7 @@ protected: result = apply(abst_domain(e), offset) || apply(abst_body(e), offset + 1); break; case expr_kind::Let: - result = (let_type(e) && apply(let_type(e), offset)) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); + result = apply(let_type(e), offset) || apply(let_value(e), offset) || apply(let_body(e), offset + 1); break; } diff --git a/src/kernel/justification.cpp b/src/kernel/justification.cpp index e64dafe4b..3fe51a28f 100644 --- a/src/kernel/justification.cpp +++ b/src/kernel/justification.cpp @@ -9,10 +9,10 @@ Author: Leonardo de Moura #include "kernel/justification.h" namespace lean { -void justification_cell::add_pos_info(format & r, expr const & e, pos_info_provider const * p) { +void justification_cell::add_pos_info(format & r, optional const & e, pos_info_provider const * p) { if (!p || !e) return; - format f = p->pp(e); + format f = p->pp(*e); if (!f) return; r += f; @@ -42,7 +42,7 @@ bool justification::has_children() const { assumption_justification::assumption_justification(unsigned idx):m_idx(idx) {} void assumption_justification::get_children(buffer &) const {} -expr const & assumption_justification::get_main_expr() const { return expr::null(); } +optional assumption_justification::get_main_expr() const { return optional(); } format assumption_justification::pp_header(formatter const &, options const &) const { return format{format("Assumption"), space(), format(m_idx)}; } diff --git a/src/kernel/justification.h b/src/kernel/justification.h index 3bd8b9764..447281099 100644 --- a/src/kernel/justification.h +++ b/src/kernel/justification.h @@ -26,14 +26,14 @@ class justification_cell { MK_LEAN_RC(); void dealloc() { delete this; } protected: - static void add_pos_info(format & r, expr const & e, pos_info_provider const * p); + static void add_pos_info(format & r, optional const & e, pos_info_provider const * p); public: justification_cell():m_rc(0) {} virtual ~justification_cell() {} virtual format pp_header(formatter const & fmt, options const & opts) const = 0; virtual format pp(formatter const & fmt, options const & opts, pos_info_provider const * p, bool display_children) const; virtual void get_children(buffer & r) const = 0; - virtual expr const & get_main_expr() const { return expr::null(); } + virtual optional get_main_expr() const { return optional(); } bool is_shared() const { return get_rc() > 1; } }; @@ -45,7 +45,7 @@ class assumption_justification : public justification_cell { public: assumption_justification(unsigned idx); virtual void get_children(buffer &) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; virtual format pp_header(formatter const &, options const &) const; }; @@ -73,7 +73,7 @@ public: lean_assert(m_ptr); return m_ptr->pp(fmt, opts, p, display_children); } - expr const & get_main_expr() const { return m_ptr ? m_ptr->get_main_expr() : expr::null(); } + optional get_main_expr() const { return m_ptr ? m_ptr->get_main_expr() : optional(); } void get_children(buffer & r) const { if (m_ptr) m_ptr->get_children(r); } bool has_children() const; }; diff --git a/src/kernel/kernel_exception.h b/src/kernel/kernel_exception.h index 7bb0b5349..3e89845ed 100644 --- a/src/kernel/kernel_exception.h +++ b/src/kernel/kernel_exception.h @@ -24,11 +24,10 @@ public: virtual ~kernel_exception() noexcept {} environment const & get_environment() const { return m_env; } /** - \brief Return a reference to the main expression associated with this exception. - Return the null expression if there is none. This information is used to provide - better error messages. + \brief Return a reference (if available) to the main expression associated with this exception. + This information is used to provide better error messages. */ - virtual expr const & get_main_expr() const { return expr::null(); } + virtual optional get_main_expr() const { return optional(); } virtual format pp(formatter const & fmt, options const & opts) const; virtual exception * clone() const { return new kernel_exception(m_env, m_msg.c_str()); } virtual void rethrow() const { throw *this; } @@ -107,7 +106,7 @@ class has_no_type_exception : public type_checker_exception { public: has_no_type_exception(environment const & env, expr const & c):type_checker_exception(env), m_const(c) {} virtual ~has_no_type_exception() {} - virtual expr const & get_main_expr() const { return m_const; } + virtual optional get_main_expr() const { return some(m_const); } virtual char const * what() const noexcept { return "object does not have a type associated with it"; } virtual format pp(formatter const & fmt, options const & opts) const; virtual exception * clone() const { return new has_no_type_exception(m_env, m_const); } @@ -134,7 +133,7 @@ public: virtual ~app_type_mismatch_exception() {} context const & get_context() const { return m_context; } expr const & get_application() const { return m_app; } - virtual expr const & get_main_expr() const { return get_application(); } + virtual optional get_main_expr() const { return some(get_application()); } std::vector const & get_arg_types() const { return m_arg_types; } virtual char const * what() const noexcept { return "application argument type mismatch"; } virtual format pp(formatter const & fmt, options const & opts) const; @@ -159,7 +158,7 @@ public: virtual ~function_expected_exception() {} context const & get_context() const { return m_context; } expr const & get_expr() const { return m_expr; } - virtual expr const & get_main_expr() const { return get_expr(); } + virtual optional get_main_expr() const { return some(get_expr()); } virtual char const * what() const noexcept { return "function expected"; } virtual format pp(formatter const & fmt, options const & opts) const; virtual exception * clone() const { return new function_expected_exception(m_env, m_context, m_expr); } @@ -183,7 +182,7 @@ public: virtual ~type_expected_exception() {} context const & get_context() const { return m_context; } expr const & get_expr() const { return m_expr; } - virtual expr const & get_main_expr() const { return get_expr(); } + virtual optional get_main_expr() const { return some(get_expr()); } virtual char const * what() const noexcept { return "type expected"; } virtual format pp(formatter const & fmt, options const & opts) const; virtual exception * clone() const { return new type_expected_exception(m_env, m_context, m_expr); } @@ -219,7 +218,7 @@ public: expr const & get_type() const { return m_type; } expr const & get_value() const { return m_value; } expr const & get_value_type() const { return m_value_type; } - virtual expr const & get_main_expr() const { return m_value; } + virtual optional get_main_expr() const { return some(m_value); } virtual char const * what() const noexcept { return "definition type mismatch"; } virtual format pp(formatter const & fmt, options const & opts) const; virtual exception * clone() const { return new def_type_mismatch_exception(m_env, m_context, m_name, m_type, m_value, m_value_type); } @@ -235,7 +234,7 @@ public: invalid_builtin_value_declaration(environment const & env, expr const & e):kernel_exception(env), m_expr(e) {} virtual ~invalid_builtin_value_declaration() {} virtual char const * what() const noexcept { return "invalid builtin value declaration, expression is not a builtin value"; } - virtual expr const & get_main_expr() const { return m_expr; } + virtual optional get_main_expr() const { return some(m_expr); } virtual exception * clone() const { return new invalid_builtin_value_declaration(m_env, m_expr); } virtual void rethrow() const { throw *this; } }; @@ -250,7 +249,7 @@ public: invalid_builtin_value_reference(environment const & env, expr const & e):kernel_exception(env), m_expr(e) {} virtual ~invalid_builtin_value_reference() {} virtual char const * what() const noexcept { return "invalid builtin value reference, this kind of builtin value was not declared in the environment"; } - virtual expr const & get_main_expr() const { return m_expr; } + virtual optional get_main_expr() const { return some(m_expr); } virtual exception * clone() const { return new invalid_builtin_value_reference(m_env, m_expr); } virtual void rethrow() const { throw *this; } }; @@ -264,7 +263,7 @@ public: unexpected_metavar_occurrence(environment const & env, expr const & e):kernel_exception(env), m_expr(e) {} virtual ~unexpected_metavar_occurrence() {} virtual char const * what() const noexcept { return "unexpected metavariable occurrence"; } - virtual expr const & get_main_expr() const { return m_expr; } + virtual optional get_main_expr() const { return some(m_expr); } virtual exception * clone() const { return new unexpected_metavar_occurrence(m_env, m_expr); } virtual void rethrow() const { throw *this; } }; diff --git a/src/kernel/metavar.cpp b/src/kernel/metavar.cpp index 01d5c37f9..20907e234 100644 --- a/src/kernel/metavar.cpp +++ b/src/kernel/metavar.cpp @@ -48,7 +48,7 @@ public: append(r, m_jsts); } - virtual expr const & get_main_expr() const { return m_expr; } + virtual optional get_main_expr() const { return some(m_expr); } }; void swap(metavar_env & a, metavar_env & b) { @@ -77,7 +77,7 @@ metavar_env::metavar_env(): metavar_env(g_default_name) { } -expr metavar_env::mk_metavar(context const & ctx, expr const & type) { +expr metavar_env::mk_metavar(context const & ctx, optional const & type) { inc_timestamp(); name m = m_name_generator.next(); expr r = ::lean::mk_metavar(m); @@ -115,7 +115,7 @@ expr metavar_env::get_type(name const & m) { auto it = const_cast(this)->m_metavar_data.splay_find(m); lean_assert(it); if (it->m_type) { - return it->m_type; + return *(it->m_type); } else { expr t = mk_metavar(get_context(m)); it->m_type = t; @@ -134,13 +134,17 @@ bool metavar_env::has_type(expr const & m) const { return has_type(metavar_name(m)); } -justification metavar_env::get_justification(expr const & m) const { +optional metavar_env::get_justification(expr const & m) const { lean_assert(is_metavar(m)); return get_justification(metavar_name(m)); } -justification metavar_env::get_justification(name const & m) const { - return get_subst_jst(m).second; +optional metavar_env::get_justification(name const & m) const { + auto r = get_subst_jst(m); + if (r) + return optional(r->second); + else + return optional(); } bool metavar_env::is_assigned(name const & m) const { @@ -187,44 +191,53 @@ expr apply_local_context(expr const & a, local_context const & lctx) { } } -std::pair metavar_env::get_subst_jst(expr const & m) const { +optional> metavar_env::get_subst_jst(expr const & m) const { lean_assert(is_metavar(m)); auto p = get_subst_jst(metavar_name(m)); - expr r = p.first; - if (p.first) { + if (p) { + expr r = p->first; local_context const & lctx = metavar_lctx(m); if (lctx) r = apply_local_context(r, lctx); - return mk_pair(r, p.second); + return some(mk_pair(r, p->second)); } else { return p; } } -std::pair metavar_env::get_subst_jst(name const & m) const { +optional> metavar_env::get_subst_jst(name const & m) const { auto it = const_cast(this)->m_metavar_data.splay_find(m); if (it->m_subst) { - if (has_assigned_metavar(it->m_subst, *this)) { + expr s = *(it->m_subst); + if (has_assigned_metavar(s, *this)) { buffer jsts; - expr new_subst = instantiate_metavars(it->m_subst, *this, jsts); + expr new_subst = instantiate_metavars(s, *this, jsts); if (!jsts.empty()) { - it->m_justification = justification(new normalize_assignment_justification(it->m_context, it->m_subst, it->m_justification, + it->m_justification = justification(new normalize_assignment_justification(it->m_context, s, it->m_justification, jsts.size(), jsts.data())); it->m_subst = new_subst; } } - return mk_pair(it->m_subst, it->m_justification); + return optional>(std::pair(*(it->m_subst), it->m_justification)); } else { - return mk_pair(expr(), justification()); + return optional>(); } } -expr metavar_env::get_subst(name const & m) const { - return get_subst_jst(m).first; +optional metavar_env::get_subst(name const & m) const { + auto r = get_subst_jst(m); + if (r) + return optional(r->first); + else + return optional(); } -expr metavar_env::get_subst(expr const & m) const { - return get_subst_jst(m).first; +optional metavar_env::get_subst(expr const & m) const { + auto r = get_subst_jst(m); + if (r) + return optional(r->first); + else + return optional(); } class instantiate_metavars_proc : public replace_visitor { @@ -240,8 +253,9 @@ protected: virtual expr visit_metavar(expr const & m, context const & ctx) { if (is_metavar(m) && m_menv.is_assigned(m)) { auto p = m_menv.get_subst_jst(m); - expr r = p.first; - push_back(p.second); + lean_assert(p); + expr r = p->first; + push_back(p->second); if (has_assigned_metavar(r, m_menv)) { return visit(r, ctx); } else { @@ -275,7 +289,7 @@ public: }; expr instantiate_metavars(expr const & e, metavar_env const & menv, buffer & jsts) { - if (!e || !has_metavar(e)) { + if (!has_metavar(e)) { return e; } else { return instantiate_metavars_proc(menv, jsts)(e); @@ -365,7 +379,7 @@ bool has_metavar(expr const & e, expr const & m, metavar_env const & menv) { return is_metavar(m2) && ((metavar_name(m) == metavar_name(m2)) || - (menv.is_assigned(m2) && has_metavar(menv.get_subst(m2), m, menv))); + (menv.is_assigned(m2) && has_metavar(*menv.get_subst(m2), m, menv))); })); } else { return false; diff --git a/src/kernel/metavar.h b/src/kernel/metavar.h index e64b62c04..fa01c349d 100644 --- a/src/kernel/metavar.h +++ b/src/kernel/metavar.h @@ -25,11 +25,11 @@ namespace lean { */ class metavar_env { struct data { - expr m_subst; // substitution - expr m_type; // type of the metavariable - context m_context; // context where the metavariable was defined - justification m_justification; // justification for assigned metavariables. - data(expr const & t = expr(), context const & ctx = context()):m_type(t), m_context(ctx) {} + optional m_subst; // substitution + optional m_type; // type of the metavariable + context m_context; // context where the metavariable was defined + justification m_justification; // justification for assigned metavariables. + data(optional const & t = optional(), context const & ctx = context()):m_type(t), m_context(ctx) {} }; typedef splay_map name2data; @@ -62,7 +62,7 @@ public: /** \brief Create a new metavariable in the given context and with the given type. */ - expr mk_metavar(context const & ctx = context(), expr const & type = expr()); + expr mk_metavar(context const & ctx = context(), optional const & type = optional()); /** \brief Return the context where the given metavariable was created. @@ -93,21 +93,16 @@ public: /** \brief Return the substitution and justification for the given metavariable. - - If the metavariable is not assigned in this substitution, then it returns the null - expression. */ - std::pair get_subst_jst(name const & m) const; - std::pair get_subst_jst(expr const & m) const; + optional> get_subst_jst(name const & m) const; + optional> get_subst_jst(expr const & m) const; /** \brief Return the justification for an assigned metavariable. \pre is_metavar(m) - \pre is_assigned(m) */ - justification get_justification(expr const & m) const; - justification get_justification(name const & m) const; - + optional get_justification(expr const & m) const; + optional get_justification(name const & m) const; /** \brief Return true iff the metavariable named \c m is assigned in this substitution. @@ -142,13 +137,10 @@ public: \brief Return the substitution associated with the given metavariable in this substitution. - If the metavariable is not assigned in this substitution, then it returns the null - expression. - \pre is_metavar(m) */ - expr get_subst(expr const & m) const; - expr get_subst(name const & m) const; + optional get_subst(expr const & m) const; + optional get_subst(name const & m) const; /** \brief Apply f to each substitution in the metavariable environment. @@ -157,7 +149,7 @@ public: void for_each_subst(F f) const { m_metavar_data.for_each([&](name const & k, data const & d) { if (d.m_subst) - f(k, d.m_subst); + f(k, *(d.m_subst)); }); } }; diff --git a/src/kernel/normalizer.cpp b/src/kernel/normalizer.cpp index fdddfa638..ffe096174 100644 --- a/src/kernel/normalizer.cpp +++ b/src/kernel/normalizer.cpp @@ -38,10 +38,10 @@ typedef list value_stack; //!< Normalization stack enum class svalue_kind { Expr, Closure, BoundedVar }; /** \brief Stack value: simple expressions, closures and bounded variables. */ class svalue { - svalue_kind m_kind; - unsigned m_bvar; - expr m_expr; - value_stack m_ctx; + svalue_kind m_kind; + unsigned m_bvar; + optional m_expr; + value_stack m_ctx; public: svalue() {} explicit svalue(expr const & e): m_kind(svalue_kind::Expr), m_expr(e) {} @@ -54,9 +54,9 @@ public: bool is_closure() const { return kind() == svalue_kind::Closure; } bool is_bounded_var() const { return kind() == svalue_kind::BoundedVar; } - expr const & get_expr() const { lean_assert(is_expr() || is_closure()); return m_expr; } - value_stack const & get_ctx() const { lean_assert(is_closure()); return m_ctx; } - unsigned get_var_idx() const { lean_assert(is_bounded_var()); return m_bvar; } + expr const & get_expr() const { lean_assert(is_expr() || is_closure()); return *m_expr; } + value_stack const & get_ctx() const { lean_assert(is_closure()); return m_ctx; } + unsigned get_var_idx() const { lean_assert(is_bounded_var()); return m_bvar; } }; svalue_kind kind(svalue const & v) { return v.kind(); } @@ -113,7 +113,7 @@ class normalizer::imp { save_context save(*this); // it restores the context and cache m_ctx = entry_c; unsigned k = m_ctx.size(); - return svalue(reify(normalize(entry.get_body(), value_stack(), k), k)); + return svalue(reify(normalize(*(entry.get_body()), value_stack(), k), k)); } else { return svalue(entry_c.size()); } @@ -123,12 +123,10 @@ class normalizer::imp { expr reify_closure(expr const & a, value_stack const & s, unsigned k) { lean_assert(is_lambda(a)); expr new_t = reify(normalize(abst_domain(a), s, k), k); - expr new_b; { cache::mk_scope sc(m_cache); - new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1); + return mk_lambda(abst_name(a), new_t, reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1)); } - return mk_lambda(abst_name(a), new_t, new_b); } /** \brief Convert the value \c v back into an expression in a context that contains \c k binders. */ @@ -236,9 +234,9 @@ class normalizer::imp { for (; i < n; i++) new_args.push_back(reify(normalize(arg(a, i), s, k), k)); if (is_value(new_f)) { - expr m; - if (to_value(new_f).normalize(new_args.size(), new_args.data(), m)) { - r = normalize(m, s, k); + optional m = to_value(new_f).normalize(new_args.size(), new_args.data()); + if (m) { + r = normalize(*m, s, k); break; } } @@ -263,12 +261,11 @@ class normalizer::imp { break; case expr_kind::Pi: { expr new_t = reify(normalize(abst_domain(a), s, k), k); - expr new_b; { cache::mk_scope sc(m_cache); - new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1); + expr new_b = reify(normalize(abst_body(a), extend(s, svalue(k)), k+1), k+1); + r = svalue(mk_pi(abst_name(a), new_t, new_b)); } - r = svalue(mk_pi(abst_name(a), new_t, new_b)); break; } case expr_kind::Let: { diff --git a/src/kernel/printer.cpp b/src/kernel/printer.cpp index 6f9dc446e..3aca989d6 100644 --- a/src/kernel/printer.cpp +++ b/src/kernel/printer.cpp @@ -138,7 +138,7 @@ struct print_expr_fn { out() << "let " << let_name(a); if (let_type(a)) { out() << " : "; - print(let_type(a), c); + print(*let_type(a), c); } out() << " := "; print(let_value(a), c); @@ -182,10 +182,10 @@ static void display_context_core(std::ostream & out, context const & ctx) { if (!empty(tail_ctx)) out << "; "; out << head.get_name(); - if (head.get_domain()) - out << " : " << mk_pair(head.get_domain(), tail_ctx); - if (head.get_body()) { - out << " := " << mk_pair(head.get_body(), tail_ctx); + if (optional const & d = head.get_domain()) + out << " : " << mk_pair(*d, tail_ctx); + if (optional const & b = head.get_body()) { + out << " := " << mk_pair(*b, tail_ctx); } } } diff --git a/src/kernel/replace_fn.h b/src/kernel/replace_fn.h index 97ec13c87..b90753a31 100644 --- a/src/kernel/replace_fn.h +++ b/src/kernel/replace_fn.h @@ -47,10 +47,15 @@ class replace_fn { F m_f; P m_post; + optional apply(optional const & e, unsigned offset) { + if (e) + return optional(apply(*e, offset)); + else + return optional(); + } + expr apply(expr const & e, unsigned offset) { check_system("expression replacer"); - if (!e) - return e; bool sh = false; if (is_shared(e)) { expr_cell_offset p(e.raw(), offset); @@ -81,9 +86,8 @@ class replace_fn { r = update_abst(e, [=](expr const & t, expr const & b) { return std::make_pair(apply(t, offset), apply(b, offset+1)); }); break; case expr_kind::Let: - r = update_let(e, [=](expr const & t, expr const & v, expr const & b) { - expr new_t = t ? apply(t, offset) : expr(); - return std::make_tuple(new_t, apply(v, offset), apply(b, offset+1)); + r = update_let(e, [=](optional const & t, expr const & v, expr const & b) { + return std::make_tuple(apply(t, offset), apply(v, offset), apply(b, offset+1)); }); break; } diff --git a/src/kernel/replace_visitor.cpp b/src/kernel/replace_visitor.cpp index c6c614ce9..1e962ce31 100644 --- a/src/kernel/replace_visitor.cpp +++ b/src/kernel/replace_visitor.cpp @@ -15,11 +15,7 @@ expr replace_visitor::visit_var(expr const & e, context const &) { lean_assert(i expr replace_visitor::visit_metavar(expr const & e, context const &) { lean_assert(is_metavar(e)); return e; } expr replace_visitor::visit_constant(expr const & e, context const & ctx) { lean_assert(is_constant(e)); - if (const_type(e)) { - return update_const(e, visit(const_type(e), ctx)); - } else { - return e; - } + return update_const(e, visit(const_type(e), ctx)); } expr replace_visitor::visit_app(expr const & e, context const & ctx) { lean_assert(is_app(e)); @@ -33,12 +29,11 @@ expr replace_visitor::visit_abst(expr const & e, context const & ctx) { lean_assert(is_abstraction(e)); return update_abst(e, [&](expr const & t, expr const & b) { expr new_t = visit(t, ctx); - expr new_b; { cache::mk_scope sc(m_cache); - new_b = visit(b, extend(ctx, abst_name(e), new_t)); + expr new_b = visit(b, extend(ctx, abst_name(e), new_t)); + return std::make_pair(new_t, new_b); } - return std::make_pair(new_t, new_b); }); } expr replace_visitor::visit_lambda(expr const & e, context const & ctx) { @@ -49,23 +44,26 @@ expr replace_visitor::visit_pi(expr const & e, context const & ctx) { lean_assert(is_pi(e)); return visit_abst(e, ctx); } + expr replace_visitor::visit_let(expr const & e, context const & ctx) { lean_assert(is_let(e)); - return update_let(e, [&](expr const & t, expr const & v, expr const & b) { - expr new_t = visit(t, ctx); + return update_let(e, [&](optional const & t, expr const & v, expr const & b) { + optional new_t = visit(t, ctx); expr new_v = visit(v, ctx); - expr new_b; { cache::mk_scope sc(m_cache); - new_b = visit(b, extend(ctx, let_name(e), new_t, new_v)); + expr new_b = visit(b, extend(ctx, let_name(e), new_t, new_v)); + return std::make_tuple(new_t, new_v, new_b); } - return std::make_tuple(new_t, new_v, new_b); }); } +expr replace_visitor::save_result(expr const & e, expr && r, bool shared) { + if (shared) + m_cache.insert(std::make_pair(e, r)); + return r; +} expr replace_visitor::visit(expr const & e, context const & ctx) { check_system("expression replacer"); - if (!e) - return e; bool shared = false; if (is_shared(e)) { shared = true; @@ -74,23 +72,25 @@ expr replace_visitor::visit(expr const & e, context const & ctx) { return it->second; } - expr r; switch (e.kind()) { - case expr_kind::Type: r = visit_type(e, ctx); break; - case expr_kind::Value: r = visit_value(e, ctx); break; - case expr_kind::Constant: r = visit_constant(e, ctx); break; - case expr_kind::Var: r = visit_var(e, ctx); break; - case expr_kind::MetaVar: r = visit_metavar(e, ctx); break; - case expr_kind::App: r = visit_app(e, ctx); break; - case expr_kind::Eq: r = visit_eq(e, ctx); break; - case expr_kind::Lambda: r = visit_lambda(e, ctx); break; - case expr_kind::Pi: r = visit_pi(e, ctx); break; - case expr_kind::Let: r = visit_let(e, ctx); break; + case expr_kind::Type: return save_result(e, visit_type(e, ctx), shared); + case expr_kind::Value: return save_result(e, visit_value(e, ctx), shared); + case expr_kind::Constant: return save_result(e, visit_constant(e, ctx), shared); + case expr_kind::Var: return save_result(e, visit_var(e, ctx), shared); + case expr_kind::MetaVar: return save_result(e, visit_metavar(e, ctx), shared); + case expr_kind::App: return save_result(e, visit_app(e, ctx), shared); + case expr_kind::Eq: return save_result(e, visit_eq(e, ctx), shared); + case expr_kind::Lambda: return save_result(e, visit_lambda(e, ctx), shared); + case expr_kind::Pi: return save_result(e, visit_pi(e, ctx), shared); + case expr_kind::Let: return save_result(e, visit_let(e, ctx), shared); } - if (shared) - m_cache.insert(std::make_pair(e, r)); - - return r; + lean_unreachable(); // LCOV_EXCL_LINE +} +optional replace_visitor::visit(optional const & e, context const & ctx) { + if (e) + return some(visit(*e, ctx)); + else + return optional(); } } diff --git a/src/kernel/replace_visitor.h b/src/kernel/replace_visitor.h index 9f402543b..1f8dfe71c 100644 --- a/src/kernel/replace_visitor.h +++ b/src/kernel/replace_visitor.h @@ -24,7 +24,7 @@ protected: typedef scoped_map cache; cache m_cache; context m_ctx; - + expr save_result(expr const & e, expr && r, bool shared); virtual expr visit_type(expr const &, context const &); virtual expr visit_value(expr const &, context const &); virtual expr visit_constant(expr const &, context const &); @@ -37,6 +37,7 @@ protected: virtual expr visit_pi(expr const &, context const &); virtual expr visit_let(expr const &, context const &); virtual expr visit(expr const &, context const &); + optional visit(optional const &, context const &); void set_ctx(context const & ctx) { if (!is_eqp(m_ctx, ctx)) { diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 532d1209f..dbb4c2b23 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -39,14 +39,6 @@ class type_checker::imp { return m_normalizer(e, ctx); } - expr lookup(context const & c, unsigned i) { - auto p = lookup_ext(c, i); - context_entry const & def = p.first; - context const & def_c = p.second; - lean_assert(c.size() > def_c.size()); - return lift_free_vars(def.get_domain(), c.size() - def_c.size()); - } - expr check_pi(expr const & e, expr const & s, context const & ctx) { if (is_pi(e)) return e; @@ -84,6 +76,12 @@ class type_checker::imp { throw type_expected_exception(env(), ctx, s); } + expr save_result(expr const & e, expr const & r, bool shared) { + if (shared) + m_cache.insert(e, r); + return r; + } + expr infer_type_core(expr const & e, context const & ctx) { check_system("type checker"); bool shared = false; @@ -94,36 +92,44 @@ class type_checker::imp { return it->second; } - expr r; switch (e.kind()) { case expr_kind::MetaVar: if (m_menv) { if (m_menv->is_assigned(e)) - return infer_type_core(m_menv->get_subst(e), ctx); + return infer_type_core(*(m_menv->get_subst(e)), ctx); else return m_menv->get_type(e); } else { throw unexpected_metavar_occurrence(env(), e); } - break; case expr_kind::Constant: { if (const_type(e)) { - r = const_type(e); + return save_result(e, *const_type(e), shared); } else { object const & obj = env().get_object(const_name(e)); if (obj.has_type()) - r = obj.get_type(); + return save_result(e, obj.get_type(), shared); else throw has_no_type_exception(env(), e); } - break; } - case expr_kind::Var: - r = lookup(ctx, var_idx(e)); - break; + case expr_kind::Var: { + unsigned i = var_idx(e); + auto p = lookup_ext(ctx, i); + context_entry const & def = p.first; + context const & def_ctx = p.second; + lean_assert(ctx.size() > def_ctx.size()); + if (optional const & d = def.get_domain()) { + return save_result(e, lift_free_vars(*d, ctx.size() - def_ctx.size()), shared); + } else { + optional const & b = def.get_body(); + lean_assert(b); + expr t = infer_type_core(*b, def_ctx); + return save_result(e, lift_free_vars(t, ctx.size() - def_ctx.size()), shared); + } + } case expr_kind::Type: - r = mk_type(ty_level(e) + 1); - break; + return save_result(e, mk_type(ty_level(e) + 1), shared); case expr_kind::App: { unsigned num = num_args(e); lean_assert(num >= 2); @@ -146,81 +152,71 @@ class type_checker::imp { else f_t = instantiate(abst_body(f_t), c); i++; - if (i == num) { - r = f_t; - break; - } + if (i == num) + return save_result(e, f_t, shared); f_t = check_pi(f_t, e, ctx); } - break; } case expr_kind::Eq: infer_type_core(eq_lhs(e), ctx); infer_type_core(eq_rhs(e), ctx); - r = mk_bool_type(); - break; + return save_result(e, mk_bool_type(), shared); case expr_kind::Lambda: { expr d = infer_type_core(abst_domain(e), ctx); check_type(d, abst_domain(e), ctx); - expr t; { cache::mk_scope sc(m_cache); - t = infer_type_core(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); + return save_result(e, + mk_pi(abst_name(e), abst_domain(e), infer_type_core(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)))), + shared); } - r = mk_pi(abst_name(e), abst_domain(e), t); - break; } case expr_kind::Pi: { expr t1 = check_type(infer_type_core(abst_domain(e), ctx), abst_domain(e), ctx); - expr t2; + optional t2; context new_ctx = extend(ctx, abst_name(e), abst_domain(e)); { cache::mk_scope sc(m_cache); t2 = check_type(infer_type_core(abst_body(e), new_ctx), abst_body(e), new_ctx); } - if (is_type(t1) && is_type(t2)) { - r = mk_type(max(ty_level(t1), ty_level(t2))); + if (is_type(t1) && is_type(*t2)) { + return save_result(e, mk_type(max(ty_level(t1), ty_level(*t2))), shared); } else { lean_assert(m_uc); justification jst = mk_max_type_justification(ctx, e); - r = m_menv->mk_metavar(ctx); - m_uc->push_back(mk_max_constraint(new_ctx, lift_free_vars(t1, 0, 1), t2, r, jst)); + expr r = m_menv->mk_metavar(ctx); + m_uc->push_back(mk_max_constraint(new_ctx, lift_free_vars(t1, 0, 1), *t2, r, jst)); + return save_result(e, r, shared); } - break; } case expr_kind::Let: { expr lt = infer_type_core(let_value(e), ctx); if (let_type(e)) { - expr ty = infer_type_core(let_type(e), ctx); - check_type(ty, let_type(e), ctx); // check if it is really a type - auto mk_justification = [&](){ return mk_def_type_match_justification(ctx, let_name(e), let_value(e)); }; // thunk for creating justification object if needed - if (!is_convertible(lt, let_type(e), ctx, mk_justification)) - throw def_type_mismatch_exception(env(), ctx, let_name(e), let_type(e), let_value(e), lt); + expr ty = infer_type_core(*let_type(e), ctx); + check_type(ty, *let_type(e), ctx); // check if it is really a type + // thunk for creating justification object if needed + auto mk_justification = [&](){ return mk_def_type_match_justification(ctx, let_name(e), let_value(e)); }; + if (!is_convertible(lt, *let_type(e), ctx, mk_justification)) + throw def_type_mismatch_exception(env(), ctx, let_name(e), *let_type(e), let_value(e), lt); } { cache::mk_scope sc(m_cache); expr t = infer_type_core(let_body(e), extend(ctx, let_name(e), lt, let_value(e))); - r = instantiate(t, let_value(e)); + return save_result(e, instantiate(t, let_value(e)), shared); } - break; } case expr_kind::Value: { // Check if the builtin value (or its set) is declared in the environment. name const & n = to_value(e).get_name(); object const & obj = env().get_object(n); if (obj && ((obj.is_builtin() && obj.get_value() == e) || (obj.is_builtin_set() && obj.in_builtin_set(e)))) { - r = to_value(e).get_type(); + return save_result(e, to_value(e).get_type(), shared); } else { throw invalid_builtin_value_reference(env(), e); } - break; } } - - if (shared) { - m_cache.insert(e, r); - } - return r; + lean_unreachable(); // LCOV_EXCL_LINE } bool is_convertible_core(expr const & given, expr const & expected) { diff --git a/src/kernel/type_checker_justification.cpp b/src/kernel/type_checker_justification.cpp index 1c52af185..da70ac680 100644 --- a/src/kernel/type_checker_justification.cpp +++ b/src/kernel/type_checker_justification.cpp @@ -23,8 +23,8 @@ format function_expected_justification_cell::pp_header(formatter const & fmt, op void function_expected_justification_cell::get_children(buffer &) const { } -expr const & function_expected_justification_cell::get_main_expr() const { - return m_app; +optional function_expected_justification_cell::get_main_expr() const { + return some(m_app); } app_type_match_justification_cell::~app_type_match_justification_cell() { @@ -51,8 +51,8 @@ format app_type_match_justification_cell::pp_header(formatter const & fmt, optio void app_type_match_justification_cell::get_children(buffer &) const { } -expr const & app_type_match_justification_cell::get_main_expr() const { - return m_app; +optional app_type_match_justification_cell::get_main_expr() const { + return some(m_app); } type_expected_justification_cell::~type_expected_justification_cell() { @@ -70,8 +70,8 @@ format type_expected_justification_cell::pp_header(formatter const & fmt, option void type_expected_justification_cell::get_children(buffer &) const { } -expr const & type_expected_justification_cell::get_main_expr() const { - return m_type; +optional type_expected_justification_cell::get_main_expr() const { + return some(m_type); } def_type_match_justification_cell::~def_type_match_justification_cell() { @@ -88,8 +88,8 @@ format def_type_match_justification_cell::pp_header(formatter const &, options c void def_type_match_justification_cell::get_children(buffer &) const { } -expr const & def_type_match_justification_cell::get_main_expr() const { - return m_value; +optional def_type_match_justification_cell::get_main_expr() const { + return some(m_value); } type_match_justification_cell::~type_match_justification_cell() { @@ -102,7 +102,7 @@ format type_match_justification_cell::pp_header(formatter const &, options const void type_match_justification_cell::get_children(buffer &) const { } -expr const & type_match_justification_cell::get_main_expr() const { - return m_value; +optional type_match_justification_cell::get_main_expr() const { + return some(m_value); } } diff --git a/src/kernel/type_checker_justification.h b/src/kernel/type_checker_justification.h index ebd947527..9f559c15e 100644 --- a/src/kernel/type_checker_justification.h +++ b/src/kernel/type_checker_justification.h @@ -29,7 +29,7 @@ public: virtual ~function_expected_justification_cell(); virtual format pp_header(formatter const & fmt, options const & opts) const; virtual void get_children(buffer &) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; context const & get_context() const { return m_ctx; } expr const & get_app() const { return m_app; } }; @@ -54,7 +54,7 @@ public: virtual ~app_type_match_justification_cell(); virtual format pp_header(formatter const & fmt, options const & opts) const; virtual void get_children(buffer &) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; context const & get_context() const { return m_ctx; } expr const & get_app() const { return m_app; } unsigned get_arg_pos() const { return m_i; } @@ -77,7 +77,7 @@ public: virtual ~type_expected_justification_cell(); virtual format pp_header(formatter const & fmt, options const & opts) const; virtual void get_children(buffer &) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; context const & get_context() const { return m_ctx; } expr const & get_type() const { return m_type; } }; @@ -114,7 +114,7 @@ public: virtual ~def_type_match_justification_cell(); virtual format pp_header(formatter const & fmt, options const & opts) const; virtual void get_children(buffer &) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; context const & get_context() const { return m_ctx; } name const & get_name() const { return m_name; } expr const & get_value() const { return m_value; } @@ -133,7 +133,7 @@ public: virtual ~type_match_justification_cell(); virtual format pp_header(formatter const & fmt, options const & opts) const; virtual void get_children(buffer &) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; context const & get_context() const { return m_ctx; } expr const & get_type() const { return m_type; } expr const & get_value() const { return m_value; } diff --git a/src/library/arith/int.cpp b/src/library/arith/int.cpp index 0f070800d..cd7dd3b00 100644 --- a/src/library/arith/int.cpp +++ b/src/library/arith/int.cpp @@ -72,12 +72,11 @@ template class int_bin_op : public const_value { public: int_bin_op():const_value(name("Int", Name), Int >> (Int >> Int)) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 3 && is_int_value(args[1]) && is_int_value(args[2])) { - r = mk_int_value(F()(int_value_numeral(args[1]), int_value_numeral(args[2]))); - return true; + return some(mk_int_value(F()(int_value_numeral(args[1]), int_value_numeral(args[2])))); } else { - return false; + return optional(); } } }; @@ -107,12 +106,11 @@ MK_BUILTIN(int_div_fn, int_div_value); class int_le_value : public const_value { public: int_le_value():const_value(name{"Int", "le"}, Int >> (Int >> Bool)) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 3 && is_int_value(args[1]) && is_int_value(args[2])) { - r = mk_bool_value(int_value_numeral(args[1]) <= int_value_numeral(args[2])); - return true; + return some(mk_bool_value(int_value_numeral(args[1]) <= int_value_numeral(args[2]))); } else { - return false; + return optional(); } } }; @@ -133,12 +131,11 @@ MK_CONSTANT(int_gt_fn, name({"Int", "gt"})); class nat_to_int_value : public const_value { public: nat_to_int_value():const_value("nat_to_int", Nat >> Int) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 2 && is_nat_value(args[1])) { - r = mk_int_value(nat_value_numeral(args[1])); - return true; + return some(mk_int_value(nat_value_numeral(args[1]))); } else { - return false; + return optional(); } } }; diff --git a/src/library/arith/nat.cpp b/src/library/arith/nat.cpp index 10e2a66e7..a96549b4d 100644 --- a/src/library/arith/nat.cpp +++ b/src/library/arith/nat.cpp @@ -66,12 +66,11 @@ template class nat_bin_op : public const_value { public: nat_bin_op():const_value(name("Nat", Name), Nat >> (Nat >> Nat)) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 3 && is_nat_value(args[1]) && is_nat_value(args[2])) { - r = mk_nat_value(F()(nat_value_numeral(args[1]), nat_value_numeral(args[2]))); - return true; + return some(mk_nat_value(F()(nat_value_numeral(args[1]), nat_value_numeral(args[2])))); } else { - return false; + return optional(); } } }; @@ -95,12 +94,11 @@ MK_BUILTIN(nat_mul_fn, nat_mul_value); class nat_le_value : public const_value { public: nat_le_value():const_value(name{"Nat", "le"}, Nat >> (Nat >> Bool)) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 3 && is_nat_value(args[1]) && is_nat_value(args[2])) { - r = mk_bool_value(nat_value_numeral(args[1]) <= nat_value_numeral(args[2])); - return true; + return some(mk_bool_value(nat_value_numeral(args[1]) <= nat_value_numeral(args[2]))); } else { - return false; + return optional(); } } }; diff --git a/src/library/arith/real.cpp b/src/library/arith/real.cpp index 93c26e03f..6ed5b8952 100644 --- a/src/library/arith/real.cpp +++ b/src/library/arith/real.cpp @@ -75,12 +75,11 @@ template class real_bin_op : public const_value { public: real_bin_op():const_value(name("Real", Name), Real >> (Real >> Real)) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 3 && is_real_value(args[1]) && is_real_value(args[2])) { - r = mk_real_value(F()(real_value_numeral(args[1]), real_value_numeral(args[2]))); - return true; + return some(mk_real_value(F()(real_value_numeral(args[1]), real_value_numeral(args[2])))); } else { - return false; + return optional(); } } }; @@ -117,12 +116,11 @@ MK_BUILTIN(real_div_fn, real_div_value); class real_le_value : public const_value { public: real_le_value():const_value(name{"Real", "le"}, Real >> (Real >> Bool)) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 3 && is_real_value(args[1]) && is_real_value(args[2])) { - r = mk_bool_value(real_value_numeral(args[1]) <= real_value_numeral(args[2])); - return true; + return some(mk_bool_value(real_value_numeral(args[1]) <= real_value_numeral(args[2]))); } else { - return false; + return optional(); } } }; @@ -167,12 +165,11 @@ void import_real(environment & env) { class int_to_real_value : public const_value { public: int_to_real_value():const_value("int_to_real", Int >> Real) {} - virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + virtual optional normalize(unsigned num_args, expr const * args) const { if (num_args == 2 && is_int_value(args[1])) { - r = mk_real_value(mpq(int_value_numeral(args[1]))); - return true; + return some(mk_real_value(mpq(int_value_numeral(args[1])))); } else { - return false; + return optional(); } } }; diff --git a/src/library/cast/cast.cpp b/src/library/cast/cast.cpp index 6ce705d06..6fadb8eeb 100644 --- a/src/library/cast/cast.cpp +++ b/src/library/cast/cast.cpp @@ -27,14 +27,13 @@ public: virtual ~cast_fn_value() {} virtual expr get_type() const { return m_type; } virtual name get_name() const { return g_cast_name; } - virtual bool normalize(unsigned num_as, expr const * as, expr & r) const { + virtual optional normalize(unsigned num_as, expr const * as) const { if (num_as > 4 && as[1] == as[2]) { // Cast T T H a == a if (num_as == 5) - r = as[4]; + return some(as[4]); else - r = mk_app(num_as - 4, as + 4); - return true; + return some(mk_app(num_as - 4, as + 4)); } else if (is_app(as[4]) && arg(as[4], 0) == mk_Cast_fn() && num_args(as[4]) == 5 && @@ -49,14 +48,13 @@ public: expr const & a = arg(nested, 4); expr c = Cast(T3, T2, Trans(TypeU, T3, T1, T2, H1, H2), a); if (num_as == 5) { - r = c; + return optional(c); } else { buffer new_as; new_as.push_back(c); new_as.append(num_as - 5, as + 5); - r = mk_app(new_as); + return optional(mk_app(new_as)); } - return true; } else if (num_as > 5 && is_pi(as[1]) && is_pi(as[2])) { // cast T1 T2 H f a_1 ... a_k // Propagate application over cast. @@ -86,16 +84,15 @@ public: expr B1_eq_B2_at_a_1p = RanInj(A1, A2, B1f, B2f, H, a_1p); expr fa_1_B2 = Cast(B1, B2, B1_eq_B2_at_a_1p, fa_1); if (num_as == 6) { - r = fa_1_B2; + return optional(fa_1_B2); } else { buffer new_as; new_as.push_back(fa_1_B2); new_as.append(num_as - 6, as + 6); - r = mk_app(new_as); + return optional(mk_app(new_as)); } - return true; } else { - return false; + return optional(); } } }; diff --git a/src/library/context_to_lambda.cpp b/src/library/context_to_lambda.cpp index d65fe21f1..04507e58f 100644 --- a/src/library/context_to_lambda.cpp +++ b/src/library/context_to_lambda.cpp @@ -13,12 +13,18 @@ expr context_to_lambda(context::iterator it, context::iterator const & end, expr return e; } else { context_entry const & entry = *it; - expr t; - if (entry.get_body()) - t = mk_app(g_fake, entry.get_domain(), entry.get_body()); + optional t; + optional const & d = entry.get_domain(); + optional const & b = entry.get_body(); + lean_assert(b || d); + if (b && d) + t = mk_app(g_fake, *d, *b); + else if (d) + t = mk_app(g_fake, *d, g_fake); else - t = mk_app(g_fake, entry.get_domain()); - return context_to_lambda(++it, end, mk_lambda(entry.get_name(), t, e)); + t = mk_app(g_fake, g_fake, *b); + lean_assert(t); + return context_to_lambda(++it, end, mk_lambda(entry.get_name(), *t, e)); } } expr context_to_lambda(context const & c, expr const & e) { @@ -31,16 +37,21 @@ name const & fake_context_name(expr const & e) { lean_assert(is_fake_context(e)); return abst_name(e); } -expr const & fake_context_domain(expr const & e) { +optional fake_context_domain(expr const & e) { lean_assert(is_fake_context(e)); - return arg(abst_domain(e), 1); -} -expr const & fake_context_value(expr const & e) { - lean_assert(is_fake_context(e)); - if (num_args(abst_domain(e)) > 2) - return arg(abst_domain(e), 2); + expr r = arg(abst_domain(e), 1); + if (!is_eqp(r, g_fake)) + return optional(r); else - return expr::null(); + return optional(); +} +optional fake_context_value(expr const & e) { + lean_assert(is_fake_context(e)); + expr r = arg(abst_domain(e), 2); + if (!is_eqp(r, g_fake)) + return optional(r); + else + return optional(); } expr const & fake_context_rest(expr const & e) { return abst_body(e); diff --git a/src/library/context_to_lambda.h b/src/library/context_to_lambda.h index fec661c30..bc3ae3344 100644 --- a/src/library/context_to_lambda.h +++ b/src/library/context_to_lambda.h @@ -42,15 +42,14 @@ name const & fake_context_name(expr const & e); \pre is_fake_context(e) */ -expr const & fake_context_domain(expr const & e); +optional fake_context_domain(expr const & e); /** - \brief Return the value V_1 of the head of the (fake) context + \brief Return (if available) the value V_1 of the head of the (fake) context (lambda (n_1 : (fake T_1 V_1)) ... (lambda (n_k : (fake T_k V_k)) e)) \pre is_fake_context(e) - \remark If the head does not have a value, then return a null expression */ -expr const & fake_context_value(expr const & e); +optional fake_context_value(expr const & e); /** \brief Return the rest (lambda (n_2 : (fake T_2 V_2)) ... (lambda (n_k : (fake T_k V_k)) e)) of the fake context (lambda (n_1 : (fake T_1 V_1)) ... (lambda (n_k : (fake T_k V_k)) e)) diff --git a/src/library/deep_copy.cpp b/src/library/deep_copy.cpp index f39905a38..4e9cc3ffd 100644 --- a/src/library/deep_copy.cpp +++ b/src/library/deep_copy.cpp @@ -13,8 +13,20 @@ namespace lean { class deep_copy_fn { expr_cell_map m_cache; + expr save_result(expr const & a, expr && r, bool shared) { + if (shared) + m_cache.insert(std::make_pair(a.raw(), r)); + return r; + } + + optional apply(optional const & a) { + if (a) + return some(apply(*a)); + else + return a; + } + expr apply(expr const & a) { - if (!a) return a; bool sh = false; if (is_shared(a)) { auto r = m_cache.find(a.raw()); @@ -22,33 +34,30 @@ class deep_copy_fn { return r->second; sh = true; } - expr r; switch (a.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Type: case expr_kind::Value: - r = copy(a); break; // shallow copy is equivalent to deep copy for these ones. + return save_result(a, copy(a), sh); case expr_kind::App: { buffer new_args; for (expr const & old_arg : args(a)) new_args.push_back(apply(old_arg)); - r = mk_app(new_args); - break; + return save_result(a, mk_app(new_args), sh); } - case expr_kind::Eq: r = mk_eq(apply(eq_lhs(a)), apply(eq_rhs(a))); break; - case expr_kind::Lambda: r = mk_lambda(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))); break; - case expr_kind::Pi: r = mk_pi(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))); break; - case expr_kind::Let: r = mk_let(let_name(a), apply(let_type(a)), apply(let_value(a)), apply(let_body(a))); break; + case expr_kind::Eq: return save_result(a, mk_eq(apply(eq_lhs(a)), apply(eq_rhs(a))), sh); + case expr_kind::Lambda: return save_result(a, mk_lambda(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))), sh); + case expr_kind::Pi: return save_result(a, mk_pi(abst_name(a), apply(abst_domain(a)), apply(abst_body(a))), sh); + case expr_kind::Let: return save_result(a, mk_let(let_name(a), apply(let_type(a)), apply(let_value(a)), apply(let_body(a))), sh); case expr_kind::MetaVar: - r = update_metavar(a, [&](local_entry const & e) -> local_entry { - if (e.is_inst()) - return mk_inst(e.s(), apply(e.v())); - else - return e; - }); - break; + return save_result(a, + update_metavar(a, [&](local_entry const & e) -> local_entry { + if (e.is_inst()) + return mk_inst(e.s(), apply(e.v())); + else + return e; + }), + sh); } - if (sh) - m_cache.insert(std::make_pair(a.raw(), r)); - return r; + lean_unreachable(); // LCOV_EXCL_LINE } public: /** diff --git a/src/library/elaborator/elaborator.cpp b/src/library/elaborator/elaborator.cpp index 8150b4b6c..7ee8a7e0f 100644 --- a/src/library/elaborator/elaborator.cpp +++ b/src/library/elaborator/elaborator.cpp @@ -187,7 +187,7 @@ class elaborator::imp { std::pair get_subst_jst(expr const & m) const { lean_assert(is_metavar(m)); lean_assert(is_assigned(m)); - return m_state.m_menv.get_subst_jst(m); + return *(m_state.m_menv.get_subst_jst(m)); } /** \brief Return the type of an metavariable */ @@ -523,7 +523,7 @@ class elaborator::imp { try { context_entry const & e = lookup(ctx, var_idx(a)); if (e.get_body()) - a = e.get_body(); + a = *(e.get_body()); } catch (exception&) { } } @@ -549,14 +549,14 @@ class elaborator::imp { if (curr != new_curr) { modified = true; new_args[i] = new_curr; - if (to_value(f).normalize(new_args.size(), new_args.data(), r)) { - a = r; + if (optional r = to_value(f).normalize(new_args.size(), new_args.data())) { + a = *r; return; } } } - if (to_value(f).normalize(new_args.size(), new_args.data(), r)) { - a = r; + if (optional r = to_value(f).normalize(new_args.size(), new_args.data())) { + a = *r; return; } if (modified) { @@ -720,8 +720,13 @@ class elaborator::imp { if (is_simple_ho_match(ctx, a, b, is_lhs, c)) { expr m = arg(a, 0); buffer types; - for (unsigned i = 1; i < num_args(a); i++) - types.push_back(lookup(ctx, var_idx(arg(a, i))).get_domain()); + for (unsigned i = 1; i < num_args(a); i++) { + optional d = lookup(ctx, var_idx(arg(a, i))).get_domain(); + if (d) + types.push_back(*d); + else + return false; + } justification new_jst(new destruct_justification(c)); expr s = mk_lambda(types, b); if (!is_lhs) @@ -811,7 +816,6 @@ class elaborator::imp { // Assign f_a <- fun (x_1 : T_0) ... (x_{num_a-1} : T_{num_a-1}), b imitation = mk_lambda(arg_types, lift_free_vars(b, 0, num_a - 1)); } - lean_assert(imitation); push_new_eq_constraint(new_state.m_queue, ctx, f_a, imitation, new_assumption); new_cs->push_back(new_state, new_assumption); } diff --git a/src/library/elaborator/elaborator_justification.cpp b/src/library/elaborator/elaborator_justification.cpp index 0d04e0c93..51fb95445 100644 --- a/src/library/elaborator/elaborator_justification.cpp +++ b/src/library/elaborator/elaborator_justification.cpp @@ -18,8 +18,8 @@ propagation_justification::~propagation_justification() { void propagation_justification::get_children(buffer & r) const { push_back(r, m_constraint.get_justification()); } -expr const & propagation_justification::get_main_expr() const { - return expr::null(); +optional propagation_justification::get_main_expr() const { + return optional(); } format propagation_justification::pp_header(formatter const & fmt, options const & opts) const { format r; @@ -120,8 +120,8 @@ format synthesis_justification::pp_header(formatter const & fmt, options const & void synthesis_justification::get_children(buffer & r) const { append(r, m_substitution_justifications); } -expr const & synthesis_justification::get_main_expr() const { - return m_mvar; +optional synthesis_justification::get_main_expr() const { + return some(m_mvar); } char const * synthesis_failure_justification::get_label() const { @@ -156,7 +156,7 @@ format next_solution_justification::pp_header(formatter const &, options const & void next_solution_justification::get_children(buffer & r) const { append(r, m_assumptions); } -expr const & next_solution_justification::get_main_expr() const { - return expr::null(); +optional next_solution_justification::get_main_expr() const { + return optional(); } } diff --git a/src/library/elaborator/elaborator_justification.h b/src/library/elaborator/elaborator_justification.h index fd33a019c..298baa4d6 100644 --- a/src/library/elaborator/elaborator_justification.h +++ b/src/library/elaborator/elaborator_justification.h @@ -21,7 +21,7 @@ public: propagation_justification(unification_constraint const & c); virtual ~propagation_justification(); virtual void get_children(buffer & r) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; virtual format pp_header(formatter const &, options const &) const; unification_constraint const & get_constraint() const { return m_constraint; } }; @@ -168,7 +168,7 @@ public: virtual ~synthesis_justification(); virtual format pp_header(formatter const &, options const &) const; virtual void get_children(buffer & r) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; }; /** @@ -212,6 +212,6 @@ public: virtual ~next_solution_justification(); virtual format pp_header(formatter const &, options const &) const; virtual void get_children(buffer & r) const; - virtual expr const & get_main_expr() const; + virtual optional get_main_expr() const; }; }; diff --git a/src/library/expr_lt.cpp b/src/library/expr_lt.cpp index f6fab22d2..8da4cd39a 100644 --- a/src/library/expr_lt.cpp +++ b/src/library/expr_lt.cpp @@ -7,10 +7,16 @@ Author: Leonardo de Moura #include "kernel/expr.h" namespace lean { +bool is_lt(expr const & a, expr const & b, bool use_hash); +static bool is_lt(optional const & a, optional const & b, bool use_hash) { + if (is_eqp(a, b)) return false; + else if (!a && b) return true; + else if (a && !b) return false; + else return is_lt(*a, *b, use_hash); +} + bool is_lt(expr const & a, expr const & b, bool use_hash) { if (is_eqp(a, b)) return false; - if (!a && b) return true; // the null expression is the smallest one - if (a && !b) return false; if (a.kind() != b.kind()) return a.kind() < b.kind(); if (use_hash) { if (a.hash() < b.hash()) return true; diff --git a/src/library/fo_unify.cpp b/src/library/fo_unify.cpp index 1130ef782..6c8a90954 100644 --- a/src/library/fo_unify.cpp +++ b/src/library/fo_unify.cpp @@ -20,16 +20,21 @@ static bool is_metavar_wo_local_context(expr const & e) { return is_metavar(e) && !metavar_lctx(e); } -static bool is_eq_heq(expr const & e, expr & lhs, expr & rhs) { - return is_eq(e, lhs, rhs) || is_homo_eq(e, lhs, rhs); +static bool is_eq_heq(expr const & e) { + return is_eq(e) || is_homo_eq(e); +} + +static expr_pair eq_heq_args(expr const & e) { + lean_assert(is_eq(e) || is_homo_eq(e)); + if (is_eq(e)) + return expr_pair(eq_lhs(e), eq_rhs(e)); + else + return expr_pair(arg(e, 1), arg(e, 2)); } optional fo_unify(expr e1, expr e2) { - lean_assert(e1); - lean_assert(e2); substitution s; unsigned i1, i2; - expr lhs1, rhs1, lhs2, rhs2; buffer todo; todo.emplace_back(e1, e2); while (!todo.empty()) { @@ -42,9 +47,11 @@ optional fo_unify(expr e1, expr e2) { assign(s, e1, e2); } else if (is_metavar_wo_local_context(e2)) { assign(s, e2, e1); - } else if (is_eq_heq(e1, lhs1, rhs1) && is_eq_heq(e2, lhs2, rhs2)) { - todo.emplace_back(lhs1, lhs2); - todo.emplace_back(rhs1, rhs2); + } else if (is_eq_heq(e1) && is_eq_heq(e2)) { + expr_pair p1 = eq_heq_args(e1); + expr_pair p2 = eq_heq_args(e2); + todo.emplace_back(p1.second, p2.second); + todo.emplace_back(p1.first, p2.first); } else { if (e1.kind() != e2.kind()) return optional(); @@ -79,7 +86,7 @@ optional fo_unify(expr e1, expr e2) { return optional(); if (let_type(e1)) { lean_assert(let_type(e2)); - todo.emplace_back(let_type(e1), let_type(e2)); + todo.emplace_back(*let_type(e1), *let_type(e2)); } break; } @@ -91,7 +98,7 @@ optional fo_unify(expr e1, expr e2) { static int fo_unify(lua_State * L) { - optional r = fo_unify(to_nonnull_expr(L, 1), to_nonnull_expr(L, 2)); + optional r = fo_unify(to_expr(L, 1), to_expr(L, 2)); if (!r) { lua_pushnil(L); return 1; diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 2051b3a8e..18c6dab64 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -162,7 +162,7 @@ static int local_entry_mk_lift(lua_State * L) { } static int local_entry_mk_inst(lua_State * L) { - return push_local_entry(L, mk_inst(luaL_checkinteger(L, 1), to_nonnull_expr(L, 2))); + return push_local_entry(L, mk_inst(luaL_checkinteger(L, 1), to_expr(L, 2))); } static int local_entry_is_lift(lua_State * L) { @@ -268,15 +268,16 @@ static void open_local_context(lua_State * L) { DECL_UDATA(expr) -expr & to_nonnull_expr(lua_State * L, int idx) { - expr & r = to_expr(L, idx); - if (!r) - throw exception("non-null Lean expression expected"); - return r; +int push_optional_expr(lua_State * L, optional const & e) { + if (e) + push_expr(L, *e); + else + lua_pushnil(L); + return 1; } expr & to_app(lua_State * L, int idx) { - expr & r = to_nonnull_expr(L, idx); + expr & r = to_expr(L, idx); if (!is_app(r)) throw exception("Lean application expression expected"); return r; @@ -284,14 +285,9 @@ expr & to_app(lua_State * L, int idx) { static int expr_tostring(lua_State * L) { std::ostringstream out; - expr & e = to_expr(L, 1); - if (e) { - formatter fmt = get_global_formatter(L); - options opts = get_global_options(L); - out << mk_pair(fmt(to_expr(L, 1), opts), opts); - } else { - out << ""; - } + formatter fmt = get_global_formatter(L); + options opts = get_global_options(L); + out << mk_pair(fmt(to_expr(L, 1), opts), opts); lua_pushstring(L, out.str().c_str()); return 1; } @@ -320,39 +316,39 @@ static int expr_mk_app(lua_State * L) { throw exception("application must have at least two arguments"); buffer args; for (int i = 1; i <= nargs; i++) - args.push_back(to_nonnull_expr(L, i)); + args.push_back(to_expr(L, i)); return push_expr(L, mk_app(args)); } static int expr_mk_eq(lua_State * L) { - return push_expr(L, mk_eq(to_nonnull_expr(L, 1), to_nonnull_expr(L, 2))); + return push_expr(L, mk_eq(to_expr(L, 1), to_expr(L, 2))); } static int expr_mk_lambda(lua_State * L) { - return push_expr(L, mk_lambda(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); + return push_expr(L, mk_lambda(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); } static int expr_mk_pi(lua_State * L) { - return push_expr(L, mk_pi(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); + return push_expr(L, mk_pi(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); } static int expr_mk_arrow(lua_State * L) { - return push_expr(L, mk_arrow(to_nonnull_expr(L, 1), to_nonnull_expr(L, 2))); + return push_expr(L, mk_arrow(to_expr(L, 1), to_expr(L, 2))); } static int expr_mk_let(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 3) - return push_expr(L, mk_let(to_name_ext(L, 1), expr(), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); + return push_expr(L, mk_let(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); else - return push_expr(L, mk_let(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3), to_nonnull_expr(L, 4))); + return push_expr(L, mk_let(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3), to_expr(L, 4))); } static expr get_expr_from_table(lua_State * L, int t, int i) { lua_pushvalue(L, t); // push table to the top lua_pushinteger(L, i); lua_gettable(L, -2); - expr r = to_nonnull_expr(L, -1); + expr r = to_expr(L, -1); lua_pop(L, 2); // remove table and value return r; } @@ -385,7 +381,7 @@ int expr_abst(lua_State * L) { int len = objlen(L, 1); if (len == 0) throw exception("function expects arg #1 to be a non-empty table"); - expr r = to_nonnull_expr(L, 2); + expr r = to_expr(L, 2); for (int i = len; i >= 1; i--) { auto p = get_expr_pair_from_table(L, 1, i); r = F1(p.first, p.second, r); @@ -394,12 +390,12 @@ int expr_abst(lua_State * L) { } else { if (nargs % 2 == 0) throw exception("function must have an odd number of arguments"); - expr r = to_nonnull_expr(L, nargs); + expr r = to_expr(L, nargs); for (int i = nargs - 1; i >= 1; i-=2) { if (is_expr(L, i - 1)) - r = F1(to_nonnull_expr(L, i - 1), to_nonnull_expr(L, i), r); + r = F1(to_expr(L, i - 1), to_expr(L, i), r); else - r = F2(to_name_ext(L, i - 1), to_nonnull_expr(L, i), r); + r = F2(to_name_ext(L, i - 1), to_expr(L, i), r); } return push_expr(L, r); } @@ -425,19 +421,14 @@ static int expr_mk_metavar(lua_State * L) { return push_expr(L, mk_metavar(to_name_ext(L, 1), to_local_context(L, 2))); } -static int expr_is_null(lua_State * L) { - lua_pushboolean(L, !to_expr(L, 1)); - return 1; -} - static int expr_get_kind(lua_State * L) { - lua_pushinteger(L, static_cast(to_nonnull_expr(L, 1).kind())); + lua_pushinteger(L, static_cast(to_expr(L, 1).kind())); return 1; } #define EXPR_PRED(P) \ static int expr_ ## P(lua_State * L) { \ - lua_pushboolean(L, P(to_nonnull_expr(L, 1))); \ + lua_pushboolean(L, P(to_expr(L, 1))); \ return 1; \ } @@ -502,7 +493,7 @@ static int expr_arg(lua_State * L) { } static int expr_fields(lua_State * L) { - expr & e = to_nonnull_expr(L, 1); + expr & e = to_expr(L, 1); switch (e.kind()) { case expr_kind::Var: lua_pushinteger(L, var_idx(e)); return 1; case expr_kind::Constant: return push_name(L, const_name(e)); @@ -511,8 +502,10 @@ static int expr_fields(lua_State * L) { case expr_kind::App: lua_pushinteger(L, num_args(e)); expr_args(L); return 2; case expr_kind::Eq: push_expr(L, eq_lhs(e)); push_expr(L, eq_rhs(e)); return 2; case expr_kind::Lambda: - case expr_kind::Pi: push_name(L, abst_name(e)); push_expr(L, abst_domain(e)); push_expr(L, abst_body(e)); return 3; - case expr_kind::Let: push_name(L, let_name(e)); push_expr(L, let_type(e)); push_expr(L, let_value(e)); push_expr(L, let_body(e)); return 4; + case expr_kind::Pi: + push_name(L, abst_name(e)); push_expr(L, abst_domain(e)); push_expr(L, abst_body(e)); return 3; + case expr_kind::Let: + push_name(L, let_name(e)); push_optional_expr(L, let_type(e)); push_expr(L, let_value(e)); push_expr(L, let_body(e)); return 4; case expr_kind::MetaVar: push_name(L, metavar_name(e)); push_local_context(L, metavar_lctx(e)); return 2; } lean_unreachable(); // LCOV_EXCL_LINE @@ -520,7 +513,7 @@ static int expr_fields(lua_State * L) { } static int expr_for_each(lua_State * L) { - expr & e = to_nonnull_expr(L, 1); // expr + expr & e = to_expr(L, 1); // expr luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun auto f = [&](expr const & a, unsigned offset) { lua_pushvalue(L, 2); // push user-fun @@ -619,7 +612,6 @@ static const struct luaL_Reg expr_m[] = { {"__lt", safe_function}, {"__call", safe_function}, {"kind", safe_function}, - {"is_null", safe_function}, {"is_var", safe_function}, {"is_constant", safe_function}, {"is_app", safe_function}, @@ -709,14 +701,14 @@ DECL_UDATA(context_entry) static int mk_context_entry(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 2) - return push_context_entry(L, context_entry(to_name_ext(L, 1), to_nonnull_expr(L, 2))); + return push_context_entry(L, context_entry(to_name_ext(L, 1), to_expr(L, 2))); else - return push_context_entry(L, context_entry(to_name_ext(L, 1), to_nonnull_expr(L, 2), to_nonnull_expr(L, 3))); + return push_context_entry(L, context_entry(to_name_ext(L, 1), to_expr(L, 2), to_expr(L, 3))); } static int context_entry_get_name(lua_State * L) { return push_name(L, to_context_entry(L, 1).get_name()); } -static int context_entry_get_domain(lua_State * L) { return push_expr(L, to_context_entry(L, 1).get_domain()); } -static int context_entry_get_body(lua_State * L) { return push_expr(L, to_context_entry(L, 1).get_body()); } +static int context_entry_get_domain(lua_State * L) { return push_optional_expr(L, to_context_entry(L, 1).get_domain()); } +static int context_entry_get_body(lua_State * L) { return push_optional_expr(L, to_context_entry(L, 1).get_body()); } static const struct luaL_Reg context_entry_m[] = { {"__gc", context_entry_gc}, // never throws @@ -743,11 +735,14 @@ static int mk_context(lua_State * L) { return push_context(L, context()); } else if (nargs == 2) { context_entry & e = to_context_entry(L, 2); - return push_context(L, context(to_context(L, 1), e.get_name(), e.get_domain(), e.get_body())); + return push_context(L, context(to_context(L, 1), e)); } else if (nargs == 3) { - return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_nonnull_expr(L, 3))); + return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_expr(L, 3))); } else { - return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_expr(L, 3), to_nonnull_expr(L, 4))); + if (lua_isnil(L, 3)) + return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), optional(), to_expr(L, 4))); + else + return push_context(L, context(to_context(L, 1), to_name_ext(L, 2), to_expr(L, 3), to_expr(L, 4))); } } @@ -979,33 +974,33 @@ static int environment_add_definition(lua_State * L) { rw_environment env(L, 1); int nargs = lua_gettop(L); if (nargs == 3) { - env->add_definition(to_name_ext(L, 2), to_nonnull_expr(L, 3)); + env->add_definition(to_name_ext(L, 2), to_expr(L, 3)); } else if (nargs == 4) { if (is_expr(L, 4)) - env->add_definition(to_name_ext(L, 2), to_nonnull_expr(L, 3), to_nonnull_expr(L, 4)); + env->add_definition(to_name_ext(L, 2), to_expr(L, 3), to_expr(L, 4)); else - env->add_definition(to_name_ext(L, 2), to_nonnull_expr(L, 3), lua_toboolean(L, 4)); + env->add_definition(to_name_ext(L, 2), to_expr(L, 3), lua_toboolean(L, 4)); } else { - env->add_definition(to_name_ext(L, 2), to_nonnull_expr(L, 3), to_nonnull_expr(L, 4), lua_toboolean(L, 5)); + env->add_definition(to_name_ext(L, 2), to_expr(L, 3), to_expr(L, 4), lua_toboolean(L, 5)); } return 0; } static int environment_add_theorem(lua_State * L) { rw_environment env(L, 1); - env->add_theorem(to_name_ext(L, 2), to_nonnull_expr(L, 3), to_nonnull_expr(L, 4)); + env->add_theorem(to_name_ext(L, 2), to_expr(L, 3), to_expr(L, 4)); return 0; } static int environment_add_var(lua_State * L) { rw_environment env(L, 1); - env->add_var(to_name_ext(L, 2), to_nonnull_expr(L, 3)); + env->add_var(to_name_ext(L, 2), to_expr(L, 3)); return 0; } static int environment_add_axiom(lua_State * L) { rw_environment env(L, 1); - env->add_axiom(to_name_ext(L, 2), to_nonnull_expr(L, 3)); + env->add_axiom(to_name_ext(L, 2), to_expr(L, 3)); return 0; } @@ -1024,18 +1019,18 @@ static int environment_check_type(lua_State * L) { ro_environment env(L, 1); int nargs = lua_gettop(L); if (nargs == 2) - return push_expr(L, env->infer_type(to_nonnull_expr(L, 2))); + return push_expr(L, env->infer_type(to_expr(L, 2))); else - return push_expr(L, env->infer_type(to_nonnull_expr(L, 2), to_context(L, 3))); + return push_expr(L, env->infer_type(to_expr(L, 2), to_context(L, 3))); } static int environment_normalize(lua_State * L) { ro_environment env(L, 1); int nargs = lua_gettop(L); if (nargs == 2) - return push_expr(L, env->normalize(to_nonnull_expr(L, 2))); + return push_expr(L, env->normalize(to_expr(L, 2))); else - return push_expr(L, env->normalize(to_nonnull_expr(L, 2), to_context(L, 3))); + return push_expr(L, env->normalize(to_expr(L, 2), to_context(L, 3))); } /** @@ -1081,9 +1076,9 @@ static int environment_infer_type(lua_State * L) { int nargs = lua_gettop(L); type_inferer inferer(to_environment(L, 1)); if (nargs == 2) - return push_expr(L, inferer(to_nonnull_expr(L, 2))); + return push_expr(L, inferer(to_expr(L, 2))); else - return push_expr(L, inferer(to_nonnull_expr(L, 2), to_context(L, 3))); + return push_expr(L, inferer(to_expr(L, 2), to_context(L, 3))); } static int environment_tostring(lua_State * L) { @@ -1303,6 +1298,14 @@ static void open_object(lua_State * L) { DECL_UDATA(justification) +int push_optional_justification(lua_State * L, optional const & j) { + if (j) + push_justification(L, *j); + else + lua_pushnil(L); + return 1; +} + static int justification_tostring(lua_State * L) { std::ostringstream out; justification & jst = to_justification(L, 1); @@ -1359,7 +1362,12 @@ static int justification_children(lua_State * L) { } static int justification_get_main_expr(lua_State * L) { - return push_expr(L, to_justification(L, 1).get_main_expr()); + optional r = to_justification(L, 1).get_main_expr(); + if (r) + push_expr(L, *r); + else + lua_pushnil(L); + return 1; } static int justification_pp(lua_State * L) { @@ -1437,7 +1445,7 @@ static int menv_mk_metavar(lua_State * L) { } else if (nargs == 2) { return push_expr(L, to_metavar_env(L, 1).mk_metavar(to_context(L, 2))); } else { - return push_expr(L, to_metavar_env(L, 1).mk_metavar(to_context(L, 2), to_expr(L, 3))); + return push_expr(L, to_metavar_env(L, 1).mk_metavar(to_context(L, 2), optional(to_expr(L, 3)))); } } @@ -1500,29 +1508,32 @@ static int menv_assign(lua_State * L) { static int menv_get_subst(lua_State * L) { if (is_expr(L, 2)) - return push_expr(L, to_metavar_env(L, 1).get_subst(to_metavar(L, 2))); + return push_optional_expr(L, to_metavar_env(L, 1).get_subst(to_metavar(L, 2))); else - return push_expr(L, to_metavar_env(L, 1).get_subst(to_name_ext(L, 2))); + return push_optional_expr(L, to_metavar_env(L, 1).get_subst(to_name_ext(L, 2))); } static int menv_get_justification(lua_State * L) { if (is_expr(L, 2)) - return push_justification(L, to_metavar_env(L, 1).get_justification(to_metavar(L, 2))); + return push_optional_justification(L, to_metavar_env(L, 1).get_justification(to_metavar(L, 2))); else - return push_justification(L, to_metavar_env(L, 1).get_justification(to_name_ext(L, 2))); + return push_optional_justification(L, to_metavar_env(L, 1).get_justification(to_name_ext(L, 2))); } static int menv_get_subst_jst(lua_State * L) { + optional> r; if (is_expr(L, 2)) { - auto p = to_metavar_env(L, 1).get_subst_jst(to_metavar(L, 2)); - push_expr(L, p.first); - push_justification(L, p.second); + r = to_metavar_env(L, 1).get_subst_jst(to_metavar(L, 2)); } else { - auto p = to_metavar_env(L, 1).get_subst_jst(to_name_ext(L, 2)); - push_expr(L, p.first); - push_justification(L, p.second); + r = to_metavar_env(L, 1).get_subst_jst(to_name_ext(L, 2)); + } + if (r) { + push_expr(L, r->first); + push_justification(L, r->second); + return 2; + } else { + return 0; } - return 2; } static int menv_for_each_subst(lua_State * L) { diff --git a/src/library/kernel_bindings.h b/src/library/kernel_bindings.h index e3e018387..1b0f61159 100644 --- a/src/library/kernel_bindings.h +++ b/src/library/kernel_bindings.h @@ -21,7 +21,8 @@ UDATA_DEFS(object) UDATA_DEFS(environment) UDATA_DEFS(justification) UDATA_DEFS(metavar_env) -expr & to_nonnull_expr(lua_State * L, int idx); +int push_optional_expr(lua_State * L, optional const & e); +int push_optional_justification(lua_State * L, optional const & j); /** \brief Return the formatter object associated with the given Lua State. This procedure checks for options at: diff --git a/src/library/max_sharing.cpp b/src/library/max_sharing.cpp index e932ad692..50d0d3858 100644 --- a/src/library/max_sharing.cpp +++ b/src/library/max_sharing.cpp @@ -25,6 +25,13 @@ struct max_sharing_fn::imp { m_cache.insert(a); } + optional apply(optional const & a) { + if (a) + return some(apply(*a)); + else + return a; + } + expr apply(expr const & a) { auto r = m_cache.find(a); if (r != m_cache.end()) { @@ -36,25 +43,21 @@ struct max_sharing_fn::imp { return a; } switch (a.kind()) { - case expr_kind::Constant: - if (const_type(a)) { - expr r = update_const(a, apply(const_type(a))); - cache(r); - return r; - } else { - cache(a); - return a; - } + case expr_kind::Constant: { + expr r = update_const(a, apply(const_type(a))); + cache(r); + return r; + } case expr_kind::Var: case expr_kind::Type: case expr_kind::Value: cache(a); return a; case expr_kind::App: { - expr r = update_app(a, [=](expr const & c){ return apply(c); }); + expr r = update_app(a, [=](expr const & c) { return apply(c); }); cache(r); return r; } case expr_kind::Eq : { - expr r = update_eq(a, [=](expr const & l, expr const & r){ return std::make_pair(apply(l), apply(r)); }); + expr r = update_eq(a, [=](expr const & l, expr const & r) { return std::make_pair(apply(l), apply(r)); }); cache(r); return r; } @@ -65,9 +68,8 @@ struct max_sharing_fn::imp { return r; } case expr_kind::Let: { - expr r = update_let(a, [=](expr const & t, expr const & v, expr const & b) { - expr new_t = t ? apply(t) : expr(); - return std::make_tuple(new_t, apply(v), apply(b)); + expr r = update_let(a, [=](optional const & t, expr const & v, expr const & b) { + return std::make_tuple(apply(t), apply(v), apply(b)); }); cache(r); return r; diff --git a/src/library/placeholder.cpp b/src/library/placeholder.cpp index 54d2b0fe1..3cde485e7 100644 --- a/src/library/placeholder.cpp +++ b/src/library/placeholder.cpp @@ -8,11 +8,12 @@ Author: Leonardo de Moura #include "kernel/metavar.h" #include "kernel/expr_maps.h" #include "kernel/replace_visitor.h" +#include "library/expr_pair.h" #include "library/placeholder.h" namespace lean { static name g_placeholder_name("_"); -expr mk_placeholder(expr const & t) { +expr mk_placeholder(optional const & t) { return mk_constant(g_placeholder_name, t); } @@ -39,7 +40,7 @@ protected: expr visit(expr const & e, context const & c) { expr r = replace_visitor::visit(e, c); if (!is_eqp(r, e) && m_new2old) - (*m_new2old)[r] = e; + m_new2old->insert(expr_pair(r, e)); return r; } public: diff --git a/src/library/placeholder.h b/src/library/placeholder.h index 15254a42e..443c57d18 100644 --- a/src/library/placeholder.h +++ b/src/library/placeholder.h @@ -15,7 +15,7 @@ class metavar_env; type). To be able to track location, a new constant for each placeholder. */ -expr mk_placeholder(expr const & t = expr()); +expr mk_placeholder(optional const & t = optional()); /** \brief Return true iff the given expression is a placeholder. diff --git a/src/library/rewriter/fo_match.cpp b/src/library/rewriter/fo_match.cpp index 8c6d067af..6164b2537 100644 --- a/src/library/rewriter/fo_match.cpp +++ b/src/library/rewriter/fo_match.cpp @@ -153,6 +153,13 @@ bool fo_match::match_metavar(expr const & p, expr const & t, unsigned, subst_map return p == t; } +bool fo_match::match_main(optional const & p, optional const & t, unsigned o, subst_map & s) { + if (p && t) + return match_main(*p, *t, o, s); + else + return !p && !t; +} + bool fo_match::match_main(expr const & p, expr const & t, unsigned o, subst_map & s) { lean_trace("fo_match", tout << "match : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;); // LCOV_EXCL_LINE switch (p.kind()) { diff --git a/src/library/rewriter/fo_match.h b/src/library/rewriter/fo_match.h index 9ae86cc73..223d4d25d 100644 --- a/src/library/rewriter/fo_match.h +++ b/src/library/rewriter/fo_match.h @@ -26,6 +26,7 @@ private: bool match_let(expr const & p, expr const & t, unsigned o, subst_map & s); bool match_metavar(expr const & p, expr const & t, unsigned o, subst_map & s); bool match_main(expr const & p, expr const & t, unsigned o, subst_map & s); + bool match_main(optional const & p, optional const & t, unsigned o, subst_map & s); public: bool match(expr const & p, expr const & t, unsigned o, subst_map & s); diff --git a/src/library/rewriter/rewriter.cpp b/src/library/rewriter/rewriter.cpp index a2529c5b2..16147129e 100644 --- a/src/library/rewriter/rewriter.cpp +++ b/src/library/rewriter/rewriter.cpp @@ -370,21 +370,25 @@ pair rewrite_eq(environment const & env, context & ctx, expr const & pair rewrite_let_type(environment const & env, context & ctx, expr const & v, pair const & result_ty) { lean_assert(is_let(v)); type_inferer ti(env); - name const & n = let_name(v); - expr const & ty = let_type(v); - expr const & val = let_value(v); - expr const & body = let_body(v); - expr const & new_ty = result_ty.first; - expr const & pf = result_ty.second; - expr const & new_v = mk_let(n, new_ty, val, body); - expr const & ty_ty = ti(ty, ctx); - expr const & ty_v = ti(v, ctx); - expr const & proof = Subst(ty_ty, ty, new_ty, - Fun({Const("x"), ty_ty}, - mk_eq(v, mk_let(n, Const("x"), val, body))), - Refl(ty_v, v), - pf); - return make_pair(new_v, proof); + if (!let_type(v)) { + name const & n = let_name(v); + expr const & ty = *let_type(v); + expr const & val = let_value(v); + expr const & body = let_body(v); + expr const & new_ty = result_ty.first; + expr const & pf = result_ty.second; + expr const & new_v = mk_let(n, new_ty, val, body); + expr const & ty_ty = ti(ty, ctx); + expr const & ty_v = ti(v, ctx); + expr const & proof = Subst(ty_ty, ty, new_ty, + Fun({Const("x"), ty_ty}, + mk_eq(v, mk_let(n, Const("x"), val, body))), + Refl(ty_v, v), + pf); + return make_pair(new_v, proof); + } else { + throw rewriter_exception(); + } } /** @@ -403,7 +407,7 @@ pair rewrite_let_value(environment const & env, context & ctx, expr lean_assert(is_let(v)); type_inferer ti(env); name const & n = let_name(v); - expr const & ty = let_type(v); + optional const & ty = let_type(v); expr const & val = let_value(v); expr const & body = let_body(v); expr const & new_val = result_value.first; @@ -436,7 +440,7 @@ pair rewrite_let_body(environment const & env, context & ctx, expr c lean_assert(is_let(v)); type_inferer ti(env); name const & n = let_name(v); - expr const & ty = let_type(v); + optional const & ty = let_type(v); expr const & val = let_value(v); expr const & body = let_body(v); expr const & new_body = result_body.first; @@ -828,10 +832,11 @@ let_type_rewriter_cell::let_type_rewriter_cell(rewriter const & rw) :rewriter_cell(rewriter_kind::LetType), m_rw(rw) { } let_type_rewriter_cell::~let_type_rewriter_cell() { } pair let_type_rewriter_cell::operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception) { - if (!is_let(v)) + if (!is_let(v) || !let_type(v)) throw rewriter_exception(); - expr const & ty = let_type(v); + expr const & ty = *let_type(v); + pair result_ty = m_rw(env, ctx, ty); if (ty != result_ty.first) { // ty changed @@ -878,9 +883,9 @@ pair let_body_rewriter_cell::operator()(environment const & env, con throw rewriter_exception(); name const & n = let_name(v); - expr const & ty = let_type(v); + optional const & ty = let_type(v); expr const & body = let_body(v); - context new_ctx = extend(ctx, n, ty); + context new_ctx = extend(ctx, n, ty, let_value(v)); pair result_body = m_rw(env, new_ctx, body); if (body != result_body.first) { return rewrite_let_body(env, ctx, v, result_body); diff --git a/src/library/rewriter/rewriter.h b/src/library/rewriter/rewriter.h index 478200335..543f786e5 100644 --- a/src/library/rewriter/rewriter.h +++ b/src/library/rewriter/rewriter.h @@ -483,7 +483,7 @@ class apply_rewriter_fn { break; case expr_kind::Let: { name const & n = let_name(v); - expr const & ty = let_type(v); + optional const & ty = let_type(v); expr const & val = let_value(v); expr const & body = let_body(v); @@ -492,13 +492,15 @@ class apply_rewriter_fn { expr pf = Refl(ty_v, v); bool changed = false; - std::pair result_ty = apply(env, ctx, ty); - if (ty != result_ty.first) { - // ty changed - result = rewrite_let_type(env, ctx, new_v, result_ty); - new_v = result.first; - pf = result.second; - changed = true; + if (ty) { + std::pair result_ty = apply(env, ctx, *ty); + if (*ty != result_ty.first) { + // ty changed + result = rewrite_let_type(env, ctx, new_v, result_ty); + new_v = result.first; + pf = result.second; + changed = true; + } } std::pair result_val = apply(env, ctx, val); @@ -513,7 +515,7 @@ class apply_rewriter_fn { changed = true; } - context new_ctx = extend(ctx, n, ty); + context new_ctx = extend(ctx, n, ty, val); std::pair result_body = apply(env, new_ctx, body); if (body != result_body.first) { result = rewrite_let_body(env, ctx, new_v, result_body); diff --git a/src/library/substitution.cpp b/src/library/substitution.cpp index e913fb073..21402682d 100644 --- a/src/library/substitution.cpp +++ b/src/library/substitution.cpp @@ -59,7 +59,7 @@ static int substitution_find(lua_State * L) { substitution & s = to_substitution(L, 1); expr * it; if (is_expr(L, 2)) { - expr const & e = to_nonnull_expr(L, 2); + expr const & e = to_expr(L, 2); if (is_metavar(e)) it = s.splay_find(metavar_name(e)); else @@ -75,7 +75,7 @@ static int substitution_find(lua_State * L) { } static int substitution_apply(lua_State * L) { - return push_expr(L, apply(to_substitution(L, 1), to_nonnull_expr(L, 2))); + return push_expr(L, apply(to_substitution(L, 1), to_expr(L, 2))); } static const struct luaL_Reg substitution_m[] = { diff --git a/src/library/tactic/apply_tactic.cpp b/src/library/tactic/apply_tactic.cpp index 23355c129..c04f78b95 100644 --- a/src/library/tactic/apply_tactic.cpp +++ b/src/library/tactic/apply_tactic.cpp @@ -44,7 +44,7 @@ static optional apply_tactic(environment const & env, proof_state c // 1) regular arguments computed using unification. // 2) propostions that generate new subgoals. // We use a pair to simulate this "union" type. - typedef list> arg_list; + typedef list, name>> arg_list; // We may solve more than one goal. // We store the solved goals using a list of pairs // name, args. Where the 'name' is the name of the solved goal. @@ -66,21 +66,21 @@ static optional apply_tactic(environment const & env, proof_state c for (auto const & mvar : mvars) { expr mvar_sol = apply(*subst, mvar); if (mvar_sol != mvar) { - l.emplace_front(mvar_sol, name()); + l = cons(mk_pair(optional(mvar_sol), name()), l); th_type_c = instantiate(abst_body(th_type_c), mvar_sol); } else { if (inferer.is_proposition(abst_domain(th_type_c), context(), &new_menv)) { name new_gname(gname, new_goal_idx); new_goal_idx++; - l.emplace_front(expr(), new_gname); + l = cons(mk_pair(optional(), new_gname), l); new_goals_buf.emplace_back(new_gname, update(g, abst_domain(th_type_c))); th_type_c = instantiate(abst_body(th_type_c), mk_constant(new_gname, abst_domain(th_type_c))); } else { // we have to create a new metavar in menv // since we do not have a substitution for mvar, and // it is not a proposition - expr new_m = new_menv.mk_metavar(context(), abst_domain(th_type_c)); - l.emplace_front(new_m, name()); + expr new_m = new_menv.mk_metavar(context(), optional(abst_domain(th_type_c))); + l = cons(mk_pair(optional(new_m), name()), l); // we use instantiate_with_closed_relaxed because we do not want // to introduce a lift operator in the new_m th_type_c = instantiate_with_closed_relaxed(abst_body(th_type_c), 1, &new_m); @@ -105,10 +105,10 @@ static optional apply_tactic(environment const & env, proof_state c buffer args; args.push_back(th); for (auto const & p2 : l) { - expr const & arg = p2.first; + optional const & arg = p2.first; if (arg) { // TODO(Leo): decide if we instantiate the metavars in the end or not. - args.push_back(arg); + args.push_back(*arg); } else { name const & subgoal_name = p2.second; args.push_back(find(m, subgoal_name)); diff --git a/src/library/tactic/assignment.h b/src/library/tactic/assignment.h index 6600b89bd..55f475690 100644 --- a/src/library/tactic/assignment.h +++ b/src/library/tactic/assignment.h @@ -15,6 +15,6 @@ class assignment { metavar_env m_menv; public: assignment(metavar_env const & menv):m_menv(menv) {} - expr operator()(name const & mvar) const { return m_menv.get_subst(mvar); } + optional operator()(name const & mvar) const { return m_menv.get_subst(mvar); } }; } diff --git a/src/library/tactic/boolean_tactics.cpp b/src/library/tactic/boolean_tactics.cpp index d3571b9d8..b8b98a6a7 100644 --- a/src/library/tactic/boolean_tactics.cpp +++ b/src/library/tactic/boolean_tactics.cpp @@ -62,16 +62,16 @@ tactic imp_tactic(name const & H_name, bool all) { expr impfn = mk_implies_fn(); bool found = false; list> proof_info; - goals new_goals = map_goals(s, [&](name const & g_name, goal const & g) -> goal { + goals new_goals = map_goals(s, [&](name const & g_name, goal const & g) -> optional { expr const & c = g.get_conclusion(); expr new_h, new_c; if ((all || !found) && is_implies(c, new_h, new_c)) { found = true; name new_h_name = g.mk_unique_hypothesis_name(H_name); proof_info.emplace_front(g_name, new_h_name, c); - return goal(add_hypothesis(new_h_name, new_h, g.get_hypotheses()), new_c); + return optional(goal(add_hypothesis(new_h_name, new_h, g.get_hypotheses()), new_c)); } else { - return g; + return optional(g); } }); if (found) { @@ -100,7 +100,7 @@ tactic conj_hyp_tactic(bool all) { return mk_tactic01([=](environment const &, io_state const &, proof_state const & s) -> optional { bool found = false; list> proof_info; // goal name -> expanded hypotheses - goals new_goals = map_goals(s, [&](name const & ng, goal const & g) -> goal { + goals new_goals = map_goals(s, [&](name const & ng, goal const & g) -> optional { if (all || !found) { buffer new_hyp_buf; hypotheses proof_info_data; @@ -119,12 +119,12 @@ tactic conj_hyp_tactic(bool all) { } if (proof_info_data) { proof_info.emplace_front(ng, proof_info_data); - return update(g, new_hyp_buf); + return some(update(g, new_hyp_buf)); } else { - return g; + return some(g); } } else { - return g; + return some(g); } }); if (found) { @@ -158,7 +158,7 @@ tactic conj_hyp_tactic(bool all) { optional disj_hyp_tactic_core(name const & goal_name, name const & hyp_name, proof_state const & s) { buffer> new_goals_buf; - expr H; + optional H; expr conclusion; for (auto const & p1 : s.get_goals()) { check_interrupted(); @@ -170,10 +170,10 @@ optional disj_hyp_tactic_core(name const & goal_name, name const & for (auto const & p2 : g.get_hypotheses()) { if (p2.first == hyp_name) { H = p2.second; - if (!is_or(H)) + if (!is_or(*H)) return none_proof_state(); // tactic failed - new_hyp_buf1.emplace_back(p2.first, arg(H, 1)); - new_hyp_buf2.emplace_back(p2.first, arg(H, 2)); + new_hyp_buf1.emplace_back(p2.first, arg(*H, 1)); + new_hyp_buf2.emplace_back(p2.first, arg(*H, 2)); } else { new_hyp_buf1.push_back(p2); new_hyp_buf2.push_back(p2); @@ -191,13 +191,14 @@ optional disj_hyp_tactic_core(name const & goal_name, name const & return none_proof_state(); // tactic failed goals new_gs = to_list(new_goals_buf.begin(), new_goals_buf.end()); proof_builder pb = s.get_proof_builder(); + expr Href = *H; proof_builder new_pb = mk_proof_builder([=](proof_map const & m, assignment const & a) -> expr { proof_map new_m(m); expr pr1 = find(m, name(goal_name, 1)); expr pr2 = find(m, name(goal_name, 2)); - pr1 = Fun(hyp_name, arg(H, 1), pr1); - pr2 = Fun(hyp_name, arg(H, 2), pr2); - new_m.insert(goal_name, DisjCases(arg(H, 1), arg(H, 2), conclusion, mk_constant(hyp_name), pr1, pr2)); + pr1 = Fun(hyp_name, arg(Href, 1), pr1); + pr2 = Fun(hyp_name, arg(Href, 2), pr2); + new_m.insert(goal_name, DisjCases(arg(Href, 1), arg(Href, 2), conclusion, mk_constant(hyp_name), pr1, pr2)); new_m.erase(name(goal_name, 1)); new_m.erase(name(goal_name, 2)); return pb(new_m, a); @@ -248,7 +249,7 @@ optional disj_tactic(proof_state const & s, name gname) { optional(); } buffer> new_goals_buf1, new_goals_buf2; - expr conclusion; + optional conclusion; for (auto const & p : s.get_goals()) { check_interrupted(); goal const & g = p.second; @@ -276,12 +277,12 @@ optional disj_tactic(proof_state const & s, name gname) { proof_builder pb = s.get_proof_builder(); proof_builder new_pb1 = mk_proof_builder([=](proof_map const & m, assignment const & a) -> expr { proof_map new_m(m); - new_m.insert(gname, Disj1(arg(conclusion, 1), arg(conclusion, 2), find(m, gname))); + new_m.insert(gname, Disj1(arg(*conclusion, 1), arg(*conclusion, 2), find(m, gname))); return pb(new_m, a); }); proof_builder new_pb2 = mk_proof_builder([=](proof_map const & m, assignment const & a) -> expr { proof_map new_m(m); - new_m.insert(gname, Disj2(arg(conclusion, 2), arg(conclusion, 1), find(m, gname))); + new_m.insert(gname, Disj2(arg(*conclusion, 2), arg(*conclusion, 1), find(m, gname))); return pb(new_m, a); }); proof_state s1(precision::Over, new_gs1, s.get_menv(), new_pb1, s.get_cex_builder()); @@ -323,7 +324,7 @@ tactic disj_tactic(unsigned i) { tactic absurd_tactic() { return mk_tactic01([](environment const &, io_state const &, proof_state const & s) -> optional { list> proofs; - goals new_gs = map_goals(s, [&](name const & gname, goal const & g) -> goal { + goals new_gs = map_goals(s, [&](name const & gname, goal const & g) -> optional { expr const & c = g.get_conclusion(); for (auto const & p1 : g.get_hypotheses()) { check_interrupted(); @@ -333,12 +334,12 @@ tactic absurd_tactic() { if (p2.second == a) { expr pr = AbsurdImpAny(a, c, mk_constant(p2.first), mk_constant(p1.first)); proofs.emplace_front(gname, pr); - return goal(); // remove goal + return optional(); // remove goal } } } } - return g; // keep goal + return some(g); // keep goal }); if (empty(proofs)) return none_proof_state(); // tactic failed diff --git a/src/library/tactic/goal.cpp b/src/library/tactic/goal.cpp index dafa5a665..fdf9c1ebf 100644 --- a/src/library/tactic/goal.cpp +++ b/src/library/tactic/goal.cpp @@ -76,10 +76,10 @@ static name_set collect_used_names(context const & ctx, expr const & t) { auto f = [&r](expr const & e, unsigned) { if (is_constant(e)) r.insert(const_name(e)); return true; }; for_each_fn visitor(f); for (auto const & e : ctx) { - if (expr const & d = e.get_domain()) - visitor(d); - if (expr const & b = e.get_body()) - visitor(b); + if (optional const & d = e.get_domain()) + visitor(*d); + if (optional const & b = e.get_body()) + visitor(*b); } visitor(t); return r; @@ -108,9 +108,9 @@ std::pair to_goal(environment const & env, context const & for (auto const & e : ctx) entries.push_back(e); std::reverse(entries.begin(), entries.end()); - buffer hypotheses; // normalized names and types of the entries processed so far - buffer bodies; // normalized bodies of the entries processed so far - std::vector consts; // cached consts[i] == mk_constant(names[i], hypotheses[i]) + buffer hypotheses; // normalized names and types of the entries processed so far + buffer> bodies; // normalized bodies of the entries processed so far + std::vector consts; // cached consts[i] == mk_constant(names[i], hypotheses[i]) auto replace_vars = [&](expr const & e, unsigned offset) -> expr { if (is_var(e)) { unsigned vidx = var_idx(e); @@ -121,7 +121,7 @@ std::pair to_goal(environment const & env, context const & throw exception("to_goal failed, unknown free variable"); unsigned lvl = nfv - nvidx - 1; if (bodies[lvl]) - return bodies[lvl]; + return *(bodies[lvl]); else return consts[lvl]; } @@ -134,20 +134,22 @@ std::pair to_goal(environment const & env, context const & for (; it != end; ++it) { auto const & e = *it; name n = mk_unique_name(used_names, e.get_name()); - expr d = replacer(e.get_domain()); - expr b = replacer(e.get_body()); + optional d = e.get_domain(); + optional b = e.get_body(); + if (d) d = some(replacer(*d)); + if (b) b = some(replacer(*b)); if (b && !d) { - d = inferer(b); + d = some(inferer(*b)); } replacer.clear(); - if (b && !inferer.is_proposition(d)) { + if (b && !inferer.is_proposition(*d)) { bodies.push_back(b); consts.push_back(expr()); } else { lean_assert(d); - hypotheses.emplace_back(n, d); - bodies.push_back(expr()); - consts.push_back(mk_constant(n, d)); + hypotheses.emplace_back(n, *d); + bodies.push_back(optional()); + consts.push_back(mk_constant(n, *d)); } } expr conclusion = replacer(t); @@ -162,9 +164,9 @@ static int mk_hypotheses(lua_State * L) { if (nargs == 0) { return push_hypotheses(L, hypotheses()); } else if (nargs == 2) { - return push_hypotheses(L, hypotheses(mk_pair(to_name_ext(L, 1), to_nonnull_expr(L, 2)), hypotheses())); + return push_hypotheses(L, hypotheses(mk_pair(to_name_ext(L, 1), to_expr(L, 2)), hypotheses())); } else if (nargs == 3) { - return push_hypotheses(L, hypotheses(mk_pair(to_name_ext(L, 1), to_nonnull_expr(L, 2)), to_hypotheses(L, 3))); + return push_hypotheses(L, hypotheses(mk_pair(to_name_ext(L, 1), to_expr(L, 2)), to_hypotheses(L, 3))); } else { throw exception("hypotheses functions expects 0 (empty list), 2 (name & expr for singleton hypotheses list), or 3 (name & expr & hypotheses list) arguments"); } @@ -233,12 +235,7 @@ static const struct luaL_Reg hypotheses_m[] = { DECL_UDATA(goal) static int mk_goal(lua_State * L) { - return push_goal(L, goal(to_hypotheses(L, 1), to_nonnull_expr(L, 2))); -} - -static int goal_is_null(lua_State * L) { - lua_pushboolean(L, !to_goal(L, 1)); - return 1; + return push_goal(L, goal(to_hypotheses(L, 1), to_expr(L, 2))); } static int goal_hypotheses(lua_State * L) { @@ -256,13 +253,9 @@ static int goal_unique_name(lua_State * L) { static int goal_tostring(lua_State * L) { std::ostringstream out; goal & g = to_goal(L, 1); - if (g) { - formatter fmt = get_global_formatter(L); - options opts = get_global_options(L); - out << mk_pair(g.pp(fmt, opts), opts); - } else { - out << ""; - } + formatter fmt = get_global_formatter(L); + options opts = get_global_options(L); + out << mk_pair(g.pp(fmt, opts), opts); lua_pushstring(L, out.str().c_str()); return 1; } @@ -270,9 +263,7 @@ static int goal_tostring(lua_State * L) { static int goal_pp(lua_State * L) { int nargs = lua_gettop(L); goal & g = to_goal(L, 1); - if (!g) { - return push_format(L, format()); - } else if (nargs == 1) { + if (nargs == 1) { return push_format(L, g.pp(get_global_formatter(L), get_global_options(L))); } else if (nargs == 2) { if (is_formatter(L, 2)) @@ -287,7 +278,6 @@ static int goal_pp(lua_State * L) { static const struct luaL_Reg goal_m[] = { {"__gc", goal_gc}, // never throws {"__tostring", safe_function}, - {"is_null", safe_function}, {"hypotheses", safe_function}, {"hyps", safe_function}, {"conclusion", safe_function}, diff --git a/src/library/tactic/goal.h b/src/library/tactic/goal.h index 48b47c3c7..c1bc0f2af 100644 --- a/src/library/tactic/goal.h +++ b/src/library/tactic/goal.h @@ -26,7 +26,6 @@ public: goal(hypotheses const & hs, expr const & c); hypotheses const & get_hypotheses() const { return m_hypotheses; } expr const & get_conclusion() const { return m_conclusion; } - explicit operator bool() const { return static_cast(m_conclusion); } format pp(formatter const & fmt, options const & opts) const; name mk_unique_hypothesis_name(name const & suggestion) const; }; diff --git a/src/library/tactic/proof_builder.cpp b/src/library/tactic/proof_builder.cpp index 1f7fbae3e..c42a62cb2 100644 --- a/src/library/tactic/proof_builder.cpp +++ b/src/library/tactic/proof_builder.cpp @@ -50,7 +50,7 @@ static int proof_map_find(lua_State * L) { } static int proof_map_insert(lua_State * L) { - to_proof_map(L, 1).insert(to_name_ext(L, 2), to_nonnull_expr(L, 3)); + to_proof_map(L, 1).insert(to_name_ext(L, 2), to_expr(L, 3)); return 0; } @@ -80,7 +80,7 @@ static int mk_assignment(lua_State * L) { } static int assignment_call(lua_State * L) { - return push_expr(L, to_assignment(L, 1)(to_name_ext(L, 2))); + return push_optional_expr(L, to_assignment(L, 1)(to_name_ext(L, 2))); } static const struct luaL_Reg assignment_m[] = { @@ -104,7 +104,7 @@ static int mk_proof_builder(lua_State * L) { push_proof_map(L, m); push_assignment(L, a); pcall(L, 2, 1, 0); - r = to_nonnull_expr(L, -1); + r = to_expr(L, -1); lua_pop(L, 1); }); return r; diff --git a/src/library/tactic/proof_state.cpp b/src/library/tactic/proof_state.cpp index e5458720c..3958c805b 100644 --- a/src/library/tactic/proof_state.cpp +++ b/src/library/tactic/proof_state.cpp @@ -223,7 +223,7 @@ static int mk_proof_state(lua_State * L) { } static int to_proof_state(lua_State * L) { - return push_proof_state(L, to_proof_state(to_environment(L, 1), to_context(L, 2), to_nonnull_expr(L, 3))); + return push_proof_state(L, to_proof_state(to_environment(L, 1), to_context(L, 2), to_expr(L, 3))); } static int proof_state_tostring(lua_State * L) { diff --git a/src/library/tactic/proof_state.h b/src/library/tactic/proof_state.h index b3e1ee28f..e67ae1d64 100644 --- a/src/library/tactic/proof_state.h +++ b/src/library/tactic/proof_state.h @@ -107,10 +107,10 @@ template goals map_goals(proof_state const & s, F && f) { return map_filter(s.get_goals(), [=](std::pair const & in, std::pair & out) -> bool { check_interrupted(); - goal new_goal = f(in.first, in.second); + optional new_goal = f(in.first, in.second); if (new_goal) { out.first = in.first; - out.second = new_goal; + out.second = *new_goal; return true; } else { return false; diff --git a/src/library/tactic/tactic.cpp b/src/library/tactic/tactic.cpp index 6b6698288..680135230 100644 --- a/src/library/tactic/tactic.cpp +++ b/src/library/tactic/tactic.cpp @@ -184,9 +184,9 @@ tactic suppress_trace(tactic const & t) { tactic assumption_tactic() { return mk_tactic01([](environment const &, io_state const &, proof_state const & s) -> optional { list> proofs; - goals new_gs = map_goals(s, [&](name const & gname, goal const & g) -> goal { + goals new_gs = map_goals(s, [&](name const & gname, goal const & g) -> optional { expr const & c = g.get_conclusion(); - expr pr; + optional pr; for (auto const & p : g.get_hypotheses()) { check_interrupted(); if (p.second == c) { @@ -195,10 +195,10 @@ tactic assumption_tactic() { } } if (pr) { - proofs.emplace_front(gname, pr); - return goal(); + proofs.emplace_front(gname, *pr); + return optional(); } else { - return g; + return some(g); } }); if (empty(proofs)) @@ -410,13 +410,13 @@ public: }; optional unfold_tactic_core(unfold_core_fn & fn, proof_state const & s) { - goals new_gs = map_goals(s, [&](name const &, goal const & g) -> goal { + goals new_gs = map_goals(s, [&](name const &, goal const & g) -> optional { hypotheses new_hs = map(g.get_hypotheses(), [&](hypothesis const & h) { return hypothesis(h.first, fn(h.second)); }); expr new_c = fn(g.get_conclusion()); - return goal(new_hs, new_c); + return some(goal(new_hs, new_c)); }); if (fn.unfolded()) { - return proof_state(s, new_gs); + return some(proof_state(s, new_gs)); } else { return none_proof_state(); } @@ -464,13 +464,13 @@ public: tactic beta_tactic() { return mk_tactic01([=](environment const &, io_state const &, proof_state const & s) -> optional { beta_fn fn; - goals new_gs = map_goals(s, [&](name const &, goal const & g) -> goal { + goals new_gs = map_goals(s, [&](name const &, goal const & g) -> optional { hypotheses new_hs = map(g.get_hypotheses(), [&](hypothesis const & h) { return hypothesis(h.first, fn(h.second)); }); expr new_c = fn(g.get_conclusion()); - return goal(new_hs, new_c); + return some(goal(new_hs, new_c)); }); if (fn.reduced()) { - return proof_state(s, new_gs); + return some(proof_state(s, new_gs)); } else { return none_proof_state(); } @@ -660,10 +660,10 @@ static int tactic_solve(lua_State * L) { } else { io_state * ios = get_io_state(L); check_ios(ios); - return tactic_solve_core(L, t, env, *ios, to_context(L, 3), to_nonnull_expr(L, 4)); + return tactic_solve_core(L, t, env, *ios, to_context(L, 3), to_expr(L, 4)); } } else { - return tactic_solve_core(L, t, env, to_io_state(L, 3), to_context(L, 4), to_nonnull_expr(L, 5)); + return tactic_solve_core(L, t, env, to_io_state(L, 3), to_context(L, 4), to_expr(L, 5)); } } diff --git a/src/library/type_inferer.cpp b/src/library/type_inferer.cpp index 64b21d178..b0677d139 100644 --- a/src/library/type_inferer.cpp +++ b/src/library/type_inferer.cpp @@ -104,7 +104,7 @@ class type_inferer::imp { case expr_kind::MetaVar: if (m_menv) { if (m_menv->is_assigned(e)) - return infer_type(m_menv->get_subst(e), ctx); + return infer_type(*(m_menv->get_subst(e)), ctx); else return m_menv->get_type(e); } else { @@ -112,7 +112,7 @@ class type_inferer::imp { } case expr_kind::Constant: { if (const_type(e)) { - return const_type(e); + return *const_type(e); } else { object const & obj = m_env.get_object(const_name(e)); if (obj.has_type()) @@ -127,7 +127,7 @@ class type_inferer::imp { context_entry const & ce = p.first; if (ce.get_domain()) { context const & ce_ctx = p.second; - return lift_free_vars(ce.get_domain(), ctx.size() - ce_ctx.size()); + return lift_free_vars(*(ce.get_domain()), ctx.size() - ce_ctx.size()); } // Remark: the case where ce.get_domain() is not // available is not considered cheap. @@ -164,7 +164,7 @@ class type_inferer::imp { context_entry const & ce = p.first; context const & ce_ctx = p.second; lean_assert(!ce.get_domain()); - r = lift_free_vars(infer_type(ce.get_body(), ce_ctx), ctx.size() - ce_ctx.size()); + r = lift_free_vars(infer_type(*(ce.get_body()), ce_ctx), ctx.size() - ce_ctx.size()); break; } case expr_kind::App: { @@ -269,9 +269,9 @@ static int type_inferer_call(lua_State * L) { int nargs = lua_gettop(L); type_inferer & inferer = to_type_inferer(L, 1); if (nargs == 2) - return push_expr(L, inferer(to_nonnull_expr(L, 2))); + return push_expr(L, inferer(to_expr(L, 2))); else - return push_expr(L, inferer(to_nonnull_expr(L, 2), to_context(L, 3))); + return push_expr(L, inferer(to_expr(L, 2), to_context(L, 3))); } static int type_inferer_clear(lua_State * L) { diff --git a/src/library/update_expr.cpp b/src/library/update_expr.cpp index 0284567e3..6d237bd29 100644 --- a/src/library/update_expr.cpp +++ b/src/library/update_expr.cpp @@ -60,7 +60,7 @@ expr update_abstraction(expr const & abst, expr const & d, expr const & b) { return update_pi(abst, d, b); } -expr update_let(expr const & let, expr const & t, expr const & v, expr const & b) { +expr update_let(expr const & let, optional const & t, expr const & v, expr const & b) { if (is_eqp(let_type(let), t) && is_eqp(let_value(let), v) && is_eqp(let_body(let), b)) return let; else diff --git a/src/library/update_expr.h b/src/library/update_expr.h index 92dfc531d..5836c6f08 100644 --- a/src/library/update_expr.h +++ b/src/library/update_expr.h @@ -37,7 +37,7 @@ expr update_abstraction(expr const & abst, expr const & d, expr const & b); \remark Return \c let if the given value and body are (pointer) equal to the ones in \c let. */ -expr update_let(expr const & let, expr const & t, expr const & v, expr const & b); +expr update_let(expr const & let, optional const & t, expr const & v, expr const & b); /** \brief Return a new equality with lhs \c l and rhs \c r. diff --git a/src/tests/kernel/expr.cpp b/src/tests/kernel/expr.cpp index e9a7815f0..f60223d86 100644 --- a/src/tests/kernel/expr.cpp +++ b/src/tests/kernel/expr.cpp @@ -278,7 +278,7 @@ static void tst13() { static void tst14() { expr t = Eq(Const("a"), Const("b")); std::cout << t << "\n"; - expr l = mk_let("a", expr(), Const("b"), Var(0)); + expr l = mk_let("a", optional(), Const("b"), Var(0)); std::cout << l << "\n"; lean_assert(closed(l)); } @@ -327,7 +327,7 @@ static void tst16() { check_copy(mk_metavar("M")); check_copy(mk_lambda("x", a, Var(0))); check_copy(mk_pi("x", a, Var(0))); - check_copy(mk_let("x", expr(), a, Var(0))); + check_copy(mk_let("x", optional(), a, Var(0))); } static void tst17() { diff --git a/src/tests/kernel/metavar.cpp b/src/tests/kernel/metavar.cpp index ec18b690a..11d458735 100644 --- a/src/tests/kernel/metavar.cpp +++ b/src/tests/kernel/metavar.cpp @@ -61,7 +61,7 @@ static void tst1() { menv.assign(m1, f(a)); lean_assert(menv.is_assigned(m1)); lean_assert(!menv.is_assigned(m2)); - lean_assert(menv.get_subst(m1) == f(a)); + lean_assert(*(menv.get_subst(m1)) == f(a)); } static void tst2() { diff --git a/src/tests/kernel/normalizer.cpp b/src/tests/kernel/normalizer.cpp index ce999ffe9..af7d07ecd 100644 --- a/src/tests/kernel/normalizer.cpp +++ b/src/tests/kernel/normalizer.cpp @@ -200,7 +200,7 @@ static void tst3() { static void tst4() { environment env; env.add_var("b", Type()); - expr t1 = mk_let("a", expr(), Const("b"), mk_lambda("c", Type(), Var(1)(Var(0)))); + expr t1 = mk_let("a", optional(), Const("b"), mk_lambda("c", Type(), Var(1)(Var(0)))); std::cout << t1 << " --> " << normalize(t1, env) << "\n"; lean_assert(normalize(t1, env) == mk_lambda("c", Type(), Const("b")(Var(0)))); } diff --git a/src/tests/kernel/type_checker.cpp b/src/tests/kernel/type_checker.cpp index e94f5d279..7b28acbe3 100644 --- a/src/tests/kernel/type_checker.cpp +++ b/src/tests/kernel/type_checker.cpp @@ -82,7 +82,7 @@ static void tst3() { expr f = Fun("a", Bool, Eq(Const("a"), True)); std::cout << infer_type(f, env) << "\n"; lean_assert(infer_type(f, env) == mk_arrow(Bool, Bool)); - expr t = mk_let("a", expr(), True, Var(0)); + expr t = mk_let("a", optional(), True, Var(0)); std::cout << infer_type(t, env) << "\n"; } @@ -205,7 +205,7 @@ static void tst11() { expr t3 = f(b, b); for (unsigned i = 0; i < n; i++) { t1 = f(t1, t1); - t2 = mk_let("x", expr(), t2, f(Var(0), Var(0))); + t2 = mk_let("x", optional(), t2, f(Var(0), Var(0))); t3 = f(t3, t3); } lean_assert(t1 != t2); diff --git a/src/tests/library/elaborator/elaborator.cpp b/src/tests/library/elaborator/elaborator.cpp index deb94dc2e..ba484a4cc 100644 --- a/src/tests/library/elaborator/elaborator.cpp +++ b/src/tests/library/elaborator/elaborator.cpp @@ -652,7 +652,7 @@ void tst20() { while (true) { try { auto sol = elb.next(); - std::cout << m1 << " -> " << sol.get_subst(m1) << "\n"; + std::cout << m1 << " -> " << *(sol.get_subst(m1)) << "\n"; std::cout << instantiate_metavars(l, sol) << "\n"; lean_assert(instantiate_metavars(l, sol) == r); std::cout << "--------------\n"; @@ -684,7 +684,7 @@ void tst21() { while (true) { try { auto sol = elb.next(); - std::cout << m1 << " -> " << sol.get_subst(m1) << "\n"; + std::cout << m1 << " -> " << *(sol.get_subst(m1)) << "\n"; std::cout << instantiate_metavars(l, sol) << "\n"; lean_assert(instantiate_metavars(l, sol) == r); std::cout << "--------------\n"; @@ -718,8 +718,8 @@ void tst22() { while (true) { try { auto sol = elb.next(); - std::cout << m3 << " -> " << sol.get_subst(m3) << "\n"; - lean_assert(sol.get_subst(m3) == iVal(1)); + std::cout << m3 << " -> " << *(sol.get_subst(m3)) << "\n"; + lean_assert(*(sol.get_subst(m3)) == iVal(1)); std::cout << instantiate_metavars(l, sol) << "\n"; std::cout << instantiate_metavars(r, sol) << "\n"; std::cout << "--------------\n"; @@ -748,7 +748,7 @@ void tst23() { while (true) { try { auto sol = elb.next(); - std::cout << m1 << " -> " << sol.get_subst(m1) << "\n"; + std::cout << m1 << " -> " << *(sol.get_subst(m1)) << "\n"; std::cout << instantiate_metavars(l, sol) << "\n"; lean_assert_eq(instantiate_metavars(l, sol), instantiate_metavars(r, sol)); @@ -797,7 +797,7 @@ void tst25() { while (true) { try { auto sol = elb.next(); - std::cout << m1 << " -> " << sol.get_subst(m1) << "\n"; + std::cout << m1 << " -> " << *(sol.get_subst(m1)) << "\n"; std::cout << instantiate_metavars(l, sol) << "\n"; lean_assert_eq(beta_reduce(instantiate_metavars(l, sol)), beta_reduce(instantiate_metavars(r, sol))); diff --git a/src/tests/library/expr_lt.cpp b/src/tests/library/expr_lt.cpp index 53aa54338..53a0ed92d 100644 --- a/src/tests/library/expr_lt.cpp +++ b/src/tests/library/expr_lt.cpp @@ -34,7 +34,6 @@ static void tst1() { lt(Const("a"), Const("b"), true); lt(Const("a"), Const("a"), false); lt(Var(1), Const("a"), true); - lt(expr(), Var(0), true); lt(Eq(Var(0), Var(1)), Eq(Var(1), Var(1)), true); lt(Eq(Var(1), Var(0)), Eq(Var(1), Var(1)), true); lt(Eq(Var(1), Var(1)), Eq(Var(1), Var(1)), false); diff --git a/src/tests/library/update_expr.cpp b/src/tests/library/update_expr.cpp index f47bff125..1b16f854d 100644 --- a/src/tests/library/update_expr.cpp +++ b/src/tests/library/update_expr.cpp @@ -40,9 +40,9 @@ void tst3() { expr b = Const("b"); expr f = Const("f"); expr t1 = Let(a, b, f(a)); - expr t2 = update_let(t1, expr(), b, let_body(t1)); + expr t2 = update_let(t1, optional(), b, let_body(t1)); lean_assert(is_eqp(t1, t2)); - t2 = update_let(t1, expr(), a, let_body(t1)); + t2 = update_let(t1, optional(), a, let_body(t1)); lean_assert(!is_eqp(t1, t2)); lean_assert(t2 == Let(a, a, f(a))); } diff --git a/src/util/optional.h b/src/util/optional.h index 8761fc51d..8e3513ab4 100644 --- a/src/util/optional.h +++ b/src/util/optional.h @@ -20,8 +20,9 @@ class optional { }; public: optional():m_some(false) {} - optional(T const & v):m_some(true) { - new (&m_value) T(v); + optional(optional & other):m_some(other.m_some) { + if (m_some) + new (&m_value) T(other.m_value); } optional(optional const & other):m_some(other.m_some) { if (m_some) @@ -31,9 +32,15 @@ public: if (m_some) new (&m_value) T(std::forward(other.m_value)); } + explicit optional(T const & v):m_some(true) { + new (&m_value) T(v); + } + explicit optional(T && v):m_some(true) { + new (&m_value) T(std::forward(v)); + } template - optional(Args&&... args):m_some(true) { - new (&m_value) T(args...); + explicit optional(Args&&... args):m_some(true) { + new (&m_value) T(std::forward(args)...); } ~optional() { if (m_some) @@ -95,6 +102,10 @@ public: else return !o1.m_some || o1.m_value == o2.m_value; } + + friend bool operator!=(optional const & o1, optional const & o2) { + return !operator==(o1, o2); + } }; template optional some(T const & t) { return optional(t); } diff --git a/src/util/scoped_map.h b/src/util/scoped_map.h index 618b81683..478856258 100644 --- a/src/util/scoped_map.h +++ b/src/util/scoped_map.h @@ -108,7 +108,7 @@ public: auto it = m_map.find(k); if (it == m_map.end()) { if (!at_base_lvl()) - m_actions.emplace_back(action_kind::Insert, value_type(k, T())); + m_actions.emplace_back(action_kind::Insert, value_type(k, v /* dummy */)); m_map.insert(value_type(k, v)); } else { if (!at_base_lvl()) diff --git a/tests/lua/context1.lua b/tests/lua/context1.lua index bbdcf9b9b..d05083d17 100644 --- a/tests/lua/context1.lua +++ b/tests/lua/context1.lua @@ -4,7 +4,7 @@ print(c) e = context_entry("x", Const("N")) assert(e:get_name() == name("x")) assert(e:get_domain() == Const("N")) -assert(e:get_body():is_null()) +assert(not e:get_body()) print(e:get_body()) c = context(c, e) print(c) diff --git a/tests/lua/env2.lua b/tests/lua/env2.lua index 53aef6e29..f2a6d9749 100644 --- a/tests/lua/env2.lua +++ b/tests/lua/env2.lua @@ -6,7 +6,7 @@ for v in env:objects() do print(v:get_name()) end end -assert(not env:find_object("N"):is_null()) +assert(env:find_object("N")) assert(env:find_object("Z"):is_null()) assert(env:find_object("N"):is_var_decl()) assert(env:find_object("N"):has_type()) diff --git a/tests/lua/env4.lua b/tests/lua/env4.lua index d17f89d32..1da59174e 100644 --- a/tests/lua/env4.lua +++ b/tests/lua/env4.lua @@ -11,7 +11,7 @@ for o in child:local_objects() do print(o) end local eenv = empty_environment() -assert(eenv:find_object("Int"):is_null()) +print(eenv:find_object("Int"):is_null()) assert(not pcall(function() env:parent() end)) local p = child:parent() assert(p:has_children()) diff --git a/tests/lua/expr6.lua b/tests/lua/expr6.lua index a6494e150..0179190c5 100644 --- a/tests/lua/expr6.lua +++ b/tests/lua/expr6.lua @@ -1,6 +1,5 @@ - function print_leaves(e, ctx) - if e:is_null() then + if (not e) then return end local k = e:kind() @@ -46,6 +45,3 @@ local F = fun(h, mk_arrow(N, N), Let(x, h(a), Eq(f(x), h(x)))) print(F) print_leaves(F, context()) - - - diff --git a/tests/lua/goal1.lua b/tests/lua/goal1.lua index 4160baff7..72dfd364e 100644 --- a/tests/lua/goal1.lua +++ b/tests/lua/goal1.lua @@ -21,7 +21,6 @@ assert(is_goal(g)) print(g) assert(#(g:hypotheses()) == 2) assert(g:conclusion() == Const("p1")) -assert(not g:is_null()) assert(g:unique_name("H") == name("H")) assert(g:unique_name("H1") == name("H1", 1)) print(g:pp()) diff --git a/tests/lua/jst1.lua b/tests/lua/jst1.lua index b28d5ec94..073a867e8 100644 --- a/tests/lua/jst1.lua +++ b/tests/lua/jst1.lua @@ -8,4 +8,4 @@ for c in j1:children() do end assert(not j2:has_children()) print(j1) -assert(j1:get_main_expr():is_null()) +assert(not j1:get_main_expr())