diff --git a/src/kernel/expr.h b/src/kernel/expr.h index f0cead70d..b518fee89 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -336,7 +336,7 @@ public: macro_definition_cell():m_rc(0) {} virtual ~macro_definition_cell() {} virtual name get_name() const = 0; - virtual expr get_type(expr const & m, expr const * arg_types, extension_context & ctx) const = 0; + virtual pair get_type(expr const & m, extension_context & ctx) const = 0; virtual optional expand(expr const & m, extension_context & ctx) const = 0; virtual optional expand1(expr const & m, extension_context & ctx) const { return expand(m, ctx); } virtual unsigned trust_level() const; @@ -363,8 +363,8 @@ public: macro_definition & operator=(macro_definition && s); name get_name() const { return m_ptr->get_name(); } - expr get_type(expr const & m, expr const * arg_types, extension_context & ctx) const { - return m_ptr->get_type(m, arg_types, ctx); + pair get_type(expr const & m, extension_context & ctx) const { + return m_ptr->get_type(m, ctx); } optional expand(expr const & m, extension_context & ctx) const { return m_ptr->expand(m, ctx); } optional expand1(expr const & m, extension_context & ctx) const { return m_ptr->expand1(m, ctx); } diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index dd0a161c0..882f21f7e 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -184,13 +184,10 @@ expr type_checker::infer_constant(expr const & e, bool infer_only) { } pair type_checker::infer_macro(expr const & e, bool infer_only) { - buffer arg_types; - constraint_seq cs; - for (unsigned i = 0; i < macro_num_args(e); i++) { - arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only, cs)); - } auto def = macro_def(e); - expr t = def.get_type(e, arg_types.data(), m_tc_ctx); + pair tcs = def.get_type(e, m_tc_ctx); + expr t = tcs.first; + constraint_seq cs = tcs.second; if (!infer_only && def.trust_level() >= m_env.trust_lvl()) { optional m = expand_macro(e); if (!m) diff --git a/src/library/annotation.cpp b/src/library/annotation.cpp index 06817403e..625a36e04 100644 --- a/src/library/annotation.cpp +++ b/src/library/annotation.cpp @@ -33,9 +33,9 @@ public: virtual name get_name() const { return get_annotation_name(); } virtual format pp(formatter const &) const { return format(m_name); } virtual void display(std::ostream & out) const { out << m_name; } - virtual expr get_type(expr const & m, expr const * arg_types, extension_context &) const { + virtual pair get_type(expr const & m, extension_context & ctx) const { check_macro(m); - return arg_types[0]; + return ctx.infer_type(macro_arg(m, 0)); } virtual optional expand(expr const & m, extension_context &) const { check_macro(m); diff --git a/src/library/choice.cpp b/src/library/choice.cpp index 564ed3a59..8711b69f0 100644 --- a/src/library/choice.cpp +++ b/src/library/choice.cpp @@ -25,7 +25,7 @@ public: virtual name get_name() const { return *g_choice_name; } // Choice expressions must be replaced with metavariables before invoking the type checker. // Choice expressions cannot be exported. They are transient/auxiliary objects. - virtual expr get_type(expr const &, expr const *, extension_context &) const { throw_ex(); } + virtual pair get_type(expr const &, extension_context &) const { throw_ex(); } virtual optional expand(expr const &, extension_context &) const { throw_ex(); } virtual void write(serializer & s) const { // we should be able to write choice expressions because of notation declarations diff --git a/src/library/let.cpp b/src/library/let.cpp index 59d784c4d..44ad2d5d9 100644 --- a/src/library/let.cpp +++ b/src/library/let.cpp @@ -32,9 +32,9 @@ class let_value_definition_cell : public macro_definition_cell { public: let_value_definition_cell():m_id(next_let_value_id()) {} virtual name get_name() const { return *g_let_value; } - virtual expr get_type(expr const & m, expr const * arg_types, extension_context &) const { + virtual pair get_type(expr const & m, extension_context & ctx) const { check_macro(m); - return arg_types[0]; + return ctx.infer_type(macro_arg(m, 0)); } virtual optional expand(expr const & m, extension_context &) const { check_macro(m); @@ -75,9 +75,9 @@ public: let_macro_definition_cell(name const & n):m_var_name(n) {} name const & get_var_name() const { return m_var_name; } virtual name get_name() const { return *g_let; } - virtual expr get_type(expr const & m, expr const * arg_types, extension_context &) const { + virtual pair get_type(expr const & m, extension_context & ctx) const { check_macro(m); - return arg_types[1]; + return ctx.infer_type(macro_arg(m, 1)); } virtual optional expand(expr const & m, extension_context &) const { check_macro(m); diff --git a/src/library/resolve_macro.cpp b/src/library/resolve_macro.cpp index 525dff82a..760f4f0e6 100644 --- a/src/library/resolve_macro.cpp +++ b/src/library/resolve_macro.cpp @@ -139,13 +139,8 @@ public: } } - virtual expr get_type(expr const & m, expr const * arg_types, extension_context & ctx) const { - environment const & env = ctx.env(); - check_num_args(env, m); - expr l = whnf(macro_arg(m, 0), ctx); - expr not_l = whnf(mk_app(*g_not, l), ctx); - expr C1 = arg_types[1]; - expr C2 = arg_types[2]; + expr mk_resolvent(environment const & env, extension_context & ctx, expr const & m, + expr const & l, expr const & not_l, expr const C1, expr const & C2) const { buffer R; // resolvent if (!collect(C1, l, R, ctx)) throw_kernel_exception(env, "invalid resolve macro, positive literal was not found", m); @@ -154,6 +149,16 @@ public: return mk_bin_rop(*g_or, *g_false, R.size(), R.data()); } + virtual pair get_type(expr const & m, extension_context & ctx) const { + environment const & env = ctx.env(); + check_num_args(env, m); + expr l = whnf(macro_arg(m, 0), ctx); + expr not_l = whnf(mk_app(*g_not, l), ctx); + expr C1 = infer_type(macro_arg(m, 1), ctx); + expr C2 = infer_type(macro_arg(m, 2), ctx); + return mk_pair(mk_resolvent(env, ctx, m, l, not_l, C1, C2), constraint_seq()); + } + // End of resolve_macro get_type implementation // ---------------------------------------------- @@ -171,8 +176,7 @@ public: expr H2 = macro_arg(m, 2); expr C1 = infer_type(H1, ctx); expr C2 = infer_type(H2, ctx); - expr arg_types[3] = { expr() /* get_type() does not use first argument */, C1, C2 }; - expr R = get_type(m, arg_types, ctx); + expr R = mk_resolvent(env, ctx, m, l, not_l, C1, C2); return some_expr(mk_or_elim_tree1(l, not_l, C1, H1, C2, H2, R, ctx)); } diff --git a/src/library/string.cpp b/src/library/string.cpp index f8cea0018..45ede9e9c 100644 --- a/src/library/string.cpp +++ b/src/library/string.cpp @@ -33,8 +33,8 @@ public: return m_value < static_cast(d).m_value; } virtual name get_name() const { return *g_string_macro; } - virtual expr get_type(expr const &, expr const *, extension_context &) const { - return *g_string; + virtual pair get_type(expr const &, extension_context &) const { + return mk_pair(*g_string, constraint_seq()); } virtual optional expand(expr const &, extension_context &) const { return some_expr(from_string_core(0, m_value)); diff --git a/src/library/typed_expr.cpp b/src/library/typed_expr.cpp index 82d65b0e5..084fc9768 100644 --- a/src/library/typed_expr.cpp +++ b/src/library/typed_expr.cpp @@ -34,9 +34,9 @@ class typed_expr_macro_definition_cell : public macro_definition_cell { } public: virtual name get_name() const { return get_typed_expr_name(); } - virtual expr get_type(expr const & m, expr const * arg_types, extension_context &) const { + virtual pair get_type(expr const & m, extension_context & ctx) const { check_macro(m); - return arg_types[0]; + return ctx.infer_type(macro_arg(m, 0)); } virtual optional expand(expr const & m, extension_context &) const { check_macro(m);