diff --git a/src/frontends/lean/lean_elaborator.cpp b/src/frontends/lean/lean_elaborator.cpp index cba746a30..fda09f3c3 100644 --- a/src/frontends/lean/lean_elaborator.cpp +++ b/src/frontends/lean/lean_elaborator.cpp @@ -199,52 +199,54 @@ class elaborator::imp { context const & ctx, expr const & src) { lean_assert(f_choices.size() == f_choice_types.size()); buffer good_choices; + unsigned best_num_coercions = std::numeric_limits::max(); unsigned num_choices = f_choices.size(); unsigned num_args = args.size(); - for (unsigned round = 0; round < 2; round++) { - // In the first round we only select perfect matches without considering - // overloads. This is the same approach used in C++. - // If a perfect match does not exist, then we try again using coercions. - for (unsigned j = 0; j < num_choices; j++) { - expr f_t = f_choice_types[j]; - try { - unsigned i = 1; - for (; i < num_args; i++) { - f_t = check_pi(f_t, ctx, src, ctx); - expr expected = abst_domain(f_t); - expr given = types[i]; - if (!has_metavar(expected) && !has_metavar(given)) { - if (!is_convertible(expected, given, ctx) && - // remark, we only consider coercions in the second round - (round == 0 || !m_frontend.get_coercion(given, expected))) - break; // failed to use this overload + // We consider two overloads ambiguous if they need the same number of coercions. + for (unsigned j = 0; j < num_choices; j++) { + expr f_t = f_choice_types[j]; + unsigned num_coercions = 0; // number of coercions needed by current choice + try { + unsigned i = 1; + for (; i < num_args; i++) { + f_t = check_pi(f_t, ctx, src, ctx); + expr expected = abst_domain(f_t); + expr given = types[i]; + if (!has_metavar(expected) && !has_metavar(given)) { + if (is_convertible(expected, given, ctx)) { + // compatible + } else if (m_frontend.get_coercion(given, expected)) { + // compatible if using coercion + num_coercions++; + } else { + break; // failed } - f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); } - if (i == num_args) { - if (good_choices.empty()) { - // first good choice - args[0] = f_choices[j]; - types[0] = f_choice_types[j]; - } - good_choices.push_back(j); - } - } catch (exception & ex) { - // candidate failed - // do nothing + f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); } + if (i == num_args) { + if (num_coercions < best_num_coercions) { + // found best choice + args[0] = f_choices[j]; + types[0] = f_choice_types[j]; + good_choices.clear(); + } + good_choices.push_back(j); + } + } catch (exception & ex) { + // candidate failed + // do nothing } - if (good_choices.size() == 0) { - // TODO add information to the exception - if (round == 1) - throw exception("none of the overloads are good"); - } else if (good_choices.size() == 1) { - // found overload - return; - } else { - // TODO add information to the exception - throw exception("ambiguous overload"); - } + } + if (good_choices.size() == 0) { + // TODO add information to the exception + throw exception("none of the overloads are good"); + } else if (good_choices.size() == 1) { + // found overload + return; + } else { + // TODO add information to the exception + throw exception("ambiguous overload"); } } diff --git a/src/frontends/lean/lean_notation.cpp b/src/frontends/lean/lean_notation.cpp index 98921fd2d..4e11f9ea5 100644 --- a/src/frontends/lean/lean_notation.cpp +++ b/src/frontends/lean/lean_notation.cpp @@ -42,7 +42,7 @@ void init_builtin_notation(frontend & f) { f.add_infixl("+", 65, mk_int_add_fn()); f.add_infixl("-", 65, mk_int_sub_fn()); f.add_infixl("*", 70, mk_int_mul_fn()); - f.add_infixl("/", 70, mk_int_div_fn()); + f.add_infixl("div", 70, mk_int_div_fn()); f.add_infix("<=", 50, mk_int_le_fn()); f.add_infix("\u2264", 50, mk_int_le_fn()); // ≤ f.add_infix(">=", 50, mk_int_ge_fn()); @@ -50,7 +50,20 @@ void init_builtin_notation(frontend & f) { f.add_infix("<", 50, mk_int_lt_fn()); f.add_infix(">", 50, mk_int_gt_fn()); + f.add_infixl("+", 65, mk_real_add_fn()); + f.add_infixl("-", 65, mk_real_sub_fn()); + f.add_infixl("*", 70, mk_real_mul_fn()); + f.add_infixl("/", 70, mk_real_div_fn()); + f.add_infix("<=", 50, mk_real_le_fn()); + f.add_infix("\u2264", 50, mk_real_le_fn()); // ≤ + f.add_infix(">=", 50, mk_real_ge_fn()); + f.add_infix("\u2265", 50, mk_real_ge_fn()); // ≥ + f.add_infix("<", 50, mk_real_lt_fn()); + f.add_infix(">", 50, mk_real_gt_fn()); + f.add_coercion(mk_nat_to_int_fn()); + f.add_coercion(mk_int_to_real_fn()); + f.add_coercion(mk_nat_to_real_fn()); // implicit arguments for builtin axioms f.mark_implicit_arguments(mk_mp_fn(), {true, true, false, false}); diff --git a/src/frontends/lean/lean_parser.cpp b/src/frontends/lean/lean_parser.cpp index d13f4b904..34a2408b2 100644 --- a/src/frontends/lean/lean_parser.cpp +++ b/src/frontends/lean/lean_parser.cpp @@ -253,8 +253,9 @@ class parser::imp { m_builtins["false"] = False; m_builtins["\u22A4"] = True; m_builtins["\u22A5"] = False; - m_builtins["Int"] = Int; m_builtins["Nat"] = Nat; + m_builtins["Int"] = Int; + m_builtins["Real"] = Real; } unsigned parse_unsigned(char const * msg) { diff --git a/src/frontends/lean/lean_pp.cpp b/src/frontends/lean/lean_pp.cpp index d9077ed71..dfc8c8cc3 100644 --- a/src/frontends/lean/lean_pp.cpp +++ b/src/frontends/lean/lean_pp.cpp @@ -169,6 +169,10 @@ class pp_fn { typedef std::pair result; + bool is_coercion(expr const & e) { + return is_app(e) && num_args(e) == 2 && m_frontend.is_coercion(arg(e,0)); + } + /** \brief Return true iff \c e is an atomic operation. */ @@ -176,7 +180,12 @@ class pp_fn { switch (e.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Value: case expr_kind::Type: return true; - case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Eq: case expr_kind::Let: + case expr_kind::App: + if (!m_coercion && is_coercion(e)) + return is_atomic(arg(e,1)); + else + return false; + case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Eq: case expr_kind::Let: return false; } return false; @@ -399,6 +408,22 @@ class pp_fn { } } + /** + \brief Return true iff the given expression has the given fixity. + */ + bool has_fixity(expr const & e, fixity fx) { + operator_info op = get_operator(e); + if (op) { + return op.get_fixity() == fx; + } else if (is_eq(e)) { + return fixity::Infix == fx; + } else if (is_arrow(e)) { + return fixity::Infixr == fx; + } else { + return false; + } + } + /** \brief Pretty print the child of an infix, prefix, postfix or mixfix operator. It will add parethesis when needed. @@ -418,11 +443,14 @@ class pp_fn { \brief Pretty print the child of an associative infix operator. It will add parethesis when needed. */ - result pp_infix_child(operator_info const & op, expr const & e, unsigned depth) { + result pp_infix_child(operator_info const & op, expr const & e, unsigned depth, fixity fx) { if (is_atomic(e)) { return pp(e, depth + 1); } else { - if (op.get_precedence() < get_operator_precedence(e) || op == get_operator(e)) + unsigned e_prec = get_operator_precedence(e); + if (op.get_precedence() < e_prec) + return pp(e, depth + 1); + else if (op.get_precedence() == e_prec && has_fixity(e, fx)) return pp(e, depth + 1); else return pp_child_with_paren(e, depth); @@ -543,7 +571,7 @@ class pp_fn { \brief Pretty print an application. */ result pp_app(expr const & e, unsigned depth) { - if (!m_coercion && num_args(e) == 2 && m_frontend.is_coercion(arg(e,0))) + if (!m_coercion && is_coercion(e)) return pp(arg(e,1), depth); application app(e, *this, m_implict); operator_info op; @@ -556,11 +584,11 @@ class pp_fn { case fixity::Infix: return mk_infix(op, pp_mixfix_child(op, app.get_arg(0), depth), pp_mixfix_child(op, app.get_arg(1), depth)); case fixity::Infixr: - return mk_infix(op, pp_mixfix_child(op, app.get_arg(0), depth), pp_infix_child(op, app.get_arg(1), depth)); + return mk_infix(op, pp_mixfix_child(op, app.get_arg(0), depth), pp_infix_child(op, app.get_arg(1), depth, fixity::Infixr)); case fixity::Infixl: - return mk_infix(op, pp_infix_child(op, app.get_arg(0), depth), pp_mixfix_child(op, app.get_arg(1), depth)); + return mk_infix(op, pp_infix_child(op, app.get_arg(0), depth, fixity::Infixl), pp_mixfix_child(op, app.get_arg(1), depth)); case fixity::Prefix: - p_arg = pp_infix_child(op, app.get_arg(0), depth); + p_arg = pp_infix_child(op, app.get_arg(0), depth, fixity::Prefix); sz = op.get_op_name().size(); return mk_result(group(format{format(op.get_op_name()), nest(sz+1, format{line(), p_arg.first})}), p_arg.second + 1); diff --git a/src/kernel/arith/arith.cpp b/src/kernel/arith/arith.cpp index 35a7f096a..7d38c619a 100644 --- a/src/kernel/arith/arith.cpp +++ b/src/kernel/arith/arith.cpp @@ -247,7 +247,131 @@ MK_BUILTIN(int_le_fn, int_le_value); MK_CONSTANT(int_ge_fn, name(name("Int"), "ge")); MK_CONSTANT(int_lt_fn, name(name("Int"), "lt")); MK_CONSTANT(int_gt_fn, name(name("Int"), "gt")); +// ======================================= +// ======================================= +// Reals +class real_type_value : public num_type_value { +public: + real_type_value():num_type_value("Real") {} + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } +}; +expr const Real = mk_value(*(new real_type_value())); +expr mk_real_type() { return Real; } + +class real_value_value : public value { + mpq m_val; +public: + real_value_value(mpq const & v):m_val(v) {} + virtual ~real_value_value() {} + virtual expr get_type() const { return Real; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } + virtual bool operator==(value const & other) const { + real_value_value const * _other = dynamic_cast(&other); + return _other && _other->m_val == m_val; + } + virtual void display(std::ostream & out) const { out << m_val; } + virtual format pp() const { return format(m_val); } + virtual unsigned hash() const { return m_val.hash(); } + mpq const & get_num() const { return m_val; } +}; + +expr mk_real_value(mpq const & v) { + return mk_value(*(new real_value_value(v))); +} + +bool is_real_value(expr const & e) { + return is_value(e) && dynamic_cast(&to_value(e)) != nullptr; +} + +mpq const & real_value_numeral(expr const & e) { + lean_assert(is_real_value(e)); + return static_cast(to_value(e)).get_num(); +} + +template +class real_bin_op : public value { + expr m_type; + name m_name; +public: + real_bin_op() { + m_type = Real >> (Real >> Real); + m_name = name("Real", Name); + } + virtual ~real_bin_op() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 3 && is_real_value(args[1]) && is_real_value(args[2])) { + r = mk_real_value(F()(real_value_numeral(args[1]), real_value_numeral(args[2]))); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; + +constexpr char real_add_name[] = "add"; +struct real_add_eval { mpq operator()(mpq const & v1, mpq const & v2) { return v1 + v2; }; }; +typedef real_bin_op real_add_value; +MK_BUILTIN(real_add_fn, real_add_value); + +constexpr char real_sub_name[] = "sub"; +struct real_sub_eval { mpq operator()(mpq const & v1, mpq const & v2) { return v1 - v2; }; }; +typedef real_bin_op real_sub_value; +MK_BUILTIN(real_sub_fn, real_sub_value); + +constexpr char real_mul_name[] = "mul"; +struct real_mul_eval { mpq operator()(mpq const & v1, mpq const & v2) { return v1 * v2; }; }; +typedef real_bin_op real_mul_value; +MK_BUILTIN(real_mul_fn, real_mul_value); + +constexpr char real_div_name[] = "div"; +struct real_div_eval { + mpq operator()(mpq const & v1, mpq const & v2) { + if (v2.is_zero()) + return v2; + else + return v1 / v2; + }; +}; +typedef real_bin_op real_div_value; +MK_BUILTIN(real_div_fn, real_div_value); + +class real_le_value : public value { + expr m_type; + name m_name; +public: + real_le_value() { + m_type = Real >> (Real >> Bool); + m_name = name{"Real", "le"}; + } + virtual ~real_le_value() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 3 && is_real_value(args[1]) && is_real_value(args[2])) { + r = mk_bool_value(real_value_numeral(args[1]) <= real_value_numeral(args[2])); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; +MK_BUILTIN(real_le_fn, real_le_value); +MK_CONSTANT(real_ge_fn, name(name("Real"), "ge")); +MK_CONSTANT(real_lt_fn, name(name("Real"), "lt")); +MK_CONSTANT(real_gt_fn, name(name("Real"), "gt")); +// ======================================= + +// ======================================= +// Coercions class nat_to_int_value : public value { expr m_type; name m_name; @@ -273,11 +397,37 @@ public: }; MK_BUILTIN(nat_to_int_fn, nat_to_int_value); +class int_to_real_value : public value { + expr m_type; + name m_name; +public: + int_to_real_value() { + m_type = Int >> Real; + m_name = "int_to_real"; + } + virtual ~int_to_real_value() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 2 && is_int_value(args[1])) { + r = mk_real_value(mpq(int_value_numeral(args[1]))); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; +MK_BUILTIN(int_to_real_fn, int_to_real_value); +MK_CONSTANT(nat_to_real_fn, name("nat_to_real")); // ======================================= -void add_int_theory(environment & env) { - expr p_ii = Int >> (Int >> Bool); +void add_arith_theory(environment & env) { expr p_nn = Nat >> (Nat >> Bool); + expr p_ii = Int >> (Int >> Bool); + expr p_rr = Real >> (Real >> Bool); expr x = Const("x"); expr y = Const("y"); @@ -288,5 +438,11 @@ void add_int_theory(environment & env) { env.add_definition(int_ge_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, iLe(y, x))); env.add_definition(int_lt_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, Not(iLe(y, x)))); env.add_definition(int_gt_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, Not(iLe(x, y)))); + + env.add_definition(real_ge_fn_name, p_rr, Fun({{x, Real}, {y, Real}}, rLe(y, x))); + env.add_definition(real_lt_fn_name, p_rr, Fun({{x, Real}, {y, Real}}, Not(rLe(y, x)))); + env.add_definition(real_gt_fn_name, p_rr, Fun({{x, Real}, {y, Real}}, Not(rLe(x, y)))); + + env.add_definition(nat_to_real_fn_name, Nat >> Real, Fun({x, Nat}, i2r(n2i(x)))); } } diff --git a/src/kernel/arith/arith.h b/src/kernel/arith/arith.h index ec32ff0cf..43dc03044 100644 --- a/src/kernel/arith/arith.h +++ b/src/kernel/arith/arith.h @@ -11,6 +11,8 @@ Author: Leonardo de Moura #include "mpq.h" namespace lean { +// ======================================= +// Natural numbers expr mk_nat_type(); extern expr const Nat; @@ -39,7 +41,10 @@ expr mk_nat_gt_fn(); inline expr nGt(expr const & e1, expr const & e2) { return mk_app(mk_nat_gt_fn(), e1, e2); } inline expr nIf(expr const & c, expr const & t, expr const & e) { return mk_if(Nat, c, t, e); } +// ======================================= +// ======================================= +// Integers expr mk_int_type(); extern expr const Int; @@ -74,10 +79,56 @@ expr mk_int_gt_fn(); inline expr iGt(expr const & e1, expr const & e2) { return mk_app(mk_int_gt_fn(), e1, e2); } inline expr iIf(expr const & c, expr const & t, expr const & e) { return mk_if(Int, c, t, e); } +// ======================================= +// ======================================= +// Reals +expr mk_real_type(); +extern expr const Real; + +expr mk_real_value(mpq const & v); +inline expr mk_real_value(int v) { return mk_real_value(mpq(v)); } +inline expr rVal(int v) { return mk_real_value(v); } +bool is_real_value(expr const & e); +mpq const & real_value_numeral(expr const & e); + +expr mk_real_add_fn(); +inline expr rAdd(expr const & e1, expr const & e2) { return mk_app(mk_real_add_fn(), e1, e2); } + +expr mk_real_sub_fn(); +inline expr rSub(expr const & e1, expr const & e2) { return mk_app(mk_real_sub_fn(), e1, e2); } + +expr mk_real_mul_fn(); +inline expr rMul(expr const & e1, expr const & e2) { return mk_app(mk_real_mul_fn(), e1, e2); } + +expr mk_real_div_fn(); +inline expr rDiv(expr const & e1, expr const & e2) { return mk_app(mk_real_div_fn(), e1, e2); } + +expr mk_real_le_fn(); +inline expr rLe(expr const & e1, expr const & e2) { return mk_app(mk_real_le_fn(), e1, e2); } + +expr mk_real_ge_fn(); +inline expr rGe(expr const & e1, expr const & e2) { return mk_app(mk_real_ge_fn(), e1, e2); } + +expr mk_real_lt_fn(); +inline expr rLt(expr const & e1, expr const & e2) { return mk_app(mk_real_lt_fn(), e1, e2); } + +expr mk_real_gt_fn(); +inline expr rGt(expr const & e1, expr const & e2) { return mk_app(mk_real_gt_fn(), e1, e2); } + +inline expr rIf(expr const & c, expr const & t, expr const & e) { return mk_if(Real, c, t, e); } +// ======================================= + +// ======================================= +// Coercions expr mk_nat_to_int_fn(); inline expr n2i(expr const & e) { return mk_app(mk_nat_to_int_fn(), e); } +expr mk_int_to_real_fn(); +inline expr i2r(expr const & e) { return mk_app(mk_int_to_real_fn(), e); } +expr mk_nat_to_real_fn(); +inline expr n2r(expr const & e) { return mk_app(mk_nat_to_real_fn(), e); } +// ======================================= class environment; -void add_int_theory(environment & env); +void add_arith_theory(environment & env); } diff --git a/src/library/toplevel.cpp b/src/library/toplevel.cpp index c922a648a..068ba79f2 100644 --- a/src/library/toplevel.cpp +++ b/src/library/toplevel.cpp @@ -13,7 +13,7 @@ namespace lean { void init_toplevel(environment & env) { add_basic_theory(env); add_basic_thms(env); - add_int_theory(env); + add_arith_theory(env); } environment mk_toplevel() { environment r; diff --git a/tests/lean/arith2.lean b/tests/lean/arith2.lean new file mode 100644 index 000000000..dc7bc11b8 --- /dev/null +++ b/tests/lean/arith2.lean @@ -0,0 +1,14 @@ +Show 1/2 +Eval 4/6 +Show 3 div 2 +Variable x : Real +Variable i : Int +Variable n : Nat +Show x + i + 1 + n +Set lean::pp::coercion true +Show x + i + 1 + n +Show x * i + x +Show x - i + x - x >= 0 +Show x < x +Show x <= x +Show x > x \ No newline at end of file diff --git a/tests/lean/arith2.lean.expected.out b/tests/lean/arith2.lean.expected.out new file mode 100644 index 000000000..9b53ea364 --- /dev/null +++ b/tests/lean/arith2.lean.expected.out @@ -0,0 +1,14 @@ +1 / 2 +2/3 +3 div 2 + Assumed: x + Assumed: i + Assumed: n +x + i + 1 + n + Set: lean::pp::coercion +x + (int_to_real i) + (nat_to_real 1) + (nat_to_real n) +x * (int_to_real i) + x +x - (int_to_real i) + x - x ≥ (nat_to_real 0) +x < x +x ≤ x +x > x