From abc939382b07469c964af2fca2308196d143d1ad Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 2 Sep 2013 13:20:00 -0700 Subject: [PATCH] Add Real arithmetic. Fix elaborator for coercions. Now, two overloads are considered ambiguous if they need the same number of coercions. Improve pretty printer for nest infix operators with same precedence and associativity. Signed-off-by: Leonardo de Moura --- src/frontends/lean/lean_elaborator.cpp | 82 ++++++------- src/frontends/lean/lean_notation.cpp | 15 ++- src/frontends/lean/lean_parser.cpp | 3 +- src/frontends/lean/lean_pp.cpp | 42 +++++-- src/kernel/arith/arith.cpp | 160 ++++++++++++++++++++++++- src/kernel/arith/arith.h | 53 +++++++- src/library/toplevel.cpp | 2 +- tests/lean/arith2.lean | 14 +++ tests/lean/arith2.lean.expected.out | 14 +++ 9 files changed, 332 insertions(+), 53 deletions(-) create mode 100644 tests/lean/arith2.lean create mode 100644 tests/lean/arith2.lean.expected.out diff --git a/src/frontends/lean/lean_elaborator.cpp b/src/frontends/lean/lean_elaborator.cpp index cba746a30..fda09f3c3 100644 --- a/src/frontends/lean/lean_elaborator.cpp +++ b/src/frontends/lean/lean_elaborator.cpp @@ -199,52 +199,54 @@ class elaborator::imp { context const & ctx, expr const & src) { lean_assert(f_choices.size() == f_choice_types.size()); buffer good_choices; + unsigned best_num_coercions = std::numeric_limits::max(); unsigned num_choices = f_choices.size(); unsigned num_args = args.size(); - for (unsigned round = 0; round < 2; round++) { - // In the first round we only select perfect matches without considering - // overloads. This is the same approach used in C++. - // If a perfect match does not exist, then we try again using coercions. - for (unsigned j = 0; j < num_choices; j++) { - expr f_t = f_choice_types[j]; - try { - unsigned i = 1; - for (; i < num_args; i++) { - f_t = check_pi(f_t, ctx, src, ctx); - expr expected = abst_domain(f_t); - expr given = types[i]; - if (!has_metavar(expected) && !has_metavar(given)) { - if (!is_convertible(expected, given, ctx) && - // remark, we only consider coercions in the second round - (round == 0 || !m_frontend.get_coercion(given, expected))) - break; // failed to use this overload + // We consider two overloads ambiguous if they need the same number of coercions. + for (unsigned j = 0; j < num_choices; j++) { + expr f_t = f_choice_types[j]; + unsigned num_coercions = 0; // number of coercions needed by current choice + try { + unsigned i = 1; + for (; i < num_args; i++) { + f_t = check_pi(f_t, ctx, src, ctx); + expr expected = abst_domain(f_t); + expr given = types[i]; + if (!has_metavar(expected) && !has_metavar(given)) { + if (is_convertible(expected, given, ctx)) { + // compatible + } else if (m_frontend.get_coercion(given, expected)) { + // compatible if using coercion + num_coercions++; + } else { + break; // failed } - f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); } - if (i == num_args) { - if (good_choices.empty()) { - // first good choice - args[0] = f_choices[j]; - types[0] = f_choice_types[j]; - } - good_choices.push_back(j); - } - } catch (exception & ex) { - // candidate failed - // do nothing + f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); } + if (i == num_args) { + if (num_coercions < best_num_coercions) { + // found best choice + args[0] = f_choices[j]; + types[0] = f_choice_types[j]; + good_choices.clear(); + } + good_choices.push_back(j); + } + } catch (exception & ex) { + // candidate failed + // do nothing } - if (good_choices.size() == 0) { - // TODO add information to the exception - if (round == 1) - throw exception("none of the overloads are good"); - } else if (good_choices.size() == 1) { - // found overload - return; - } else { - // TODO add information to the exception - throw exception("ambiguous overload"); - } + } + if (good_choices.size() == 0) { + // TODO add information to the exception + throw exception("none of the overloads are good"); + } else if (good_choices.size() == 1) { + // found overload + return; + } else { + // TODO add information to the exception + throw exception("ambiguous overload"); } } diff --git a/src/frontends/lean/lean_notation.cpp b/src/frontends/lean/lean_notation.cpp index 98921fd2d..4e11f9ea5 100644 --- a/src/frontends/lean/lean_notation.cpp +++ b/src/frontends/lean/lean_notation.cpp @@ -42,7 +42,7 @@ void init_builtin_notation(frontend & f) { f.add_infixl("+", 65, mk_int_add_fn()); f.add_infixl("-", 65, mk_int_sub_fn()); f.add_infixl("*", 70, mk_int_mul_fn()); - f.add_infixl("/", 70, mk_int_div_fn()); + f.add_infixl("div", 70, mk_int_div_fn()); f.add_infix("<=", 50, mk_int_le_fn()); f.add_infix("\u2264", 50, mk_int_le_fn()); // ≤ f.add_infix(">=", 50, mk_int_ge_fn()); @@ -50,7 +50,20 @@ void init_builtin_notation(frontend & f) { f.add_infix("<", 50, mk_int_lt_fn()); f.add_infix(">", 50, mk_int_gt_fn()); + f.add_infixl("+", 65, mk_real_add_fn()); + f.add_infixl("-", 65, mk_real_sub_fn()); + f.add_infixl("*", 70, mk_real_mul_fn()); + f.add_infixl("/", 70, mk_real_div_fn()); + f.add_infix("<=", 50, mk_real_le_fn()); + f.add_infix("\u2264", 50, mk_real_le_fn()); // ≤ + f.add_infix(">=", 50, mk_real_ge_fn()); + f.add_infix("\u2265", 50, mk_real_ge_fn()); // ≥ + f.add_infix("<", 50, mk_real_lt_fn()); + f.add_infix(">", 50, mk_real_gt_fn()); + f.add_coercion(mk_nat_to_int_fn()); + f.add_coercion(mk_int_to_real_fn()); + f.add_coercion(mk_nat_to_real_fn()); // implicit arguments for builtin axioms f.mark_implicit_arguments(mk_mp_fn(), {true, true, false, false}); diff --git a/src/frontends/lean/lean_parser.cpp b/src/frontends/lean/lean_parser.cpp index d13f4b904..34a2408b2 100644 --- a/src/frontends/lean/lean_parser.cpp +++ b/src/frontends/lean/lean_parser.cpp @@ -253,8 +253,9 @@ class parser::imp { m_builtins["false"] = False; m_builtins["\u22A4"] = True; m_builtins["\u22A5"] = False; - m_builtins["Int"] = Int; m_builtins["Nat"] = Nat; + m_builtins["Int"] = Int; + m_builtins["Real"] = Real; } unsigned parse_unsigned(char const * msg) { diff --git a/src/frontends/lean/lean_pp.cpp b/src/frontends/lean/lean_pp.cpp index d9077ed71..dfc8c8cc3 100644 --- a/src/frontends/lean/lean_pp.cpp +++ b/src/frontends/lean/lean_pp.cpp @@ -169,6 +169,10 @@ class pp_fn { typedef std::pair result; + bool is_coercion(expr const & e) { + return is_app(e) && num_args(e) == 2 && m_frontend.is_coercion(arg(e,0)); + } + /** \brief Return true iff \c e is an atomic operation. */ @@ -176,7 +180,12 @@ class pp_fn { switch (e.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Value: case expr_kind::Type: return true; - case expr_kind::App: case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Eq: case expr_kind::Let: + case expr_kind::App: + if (!m_coercion && is_coercion(e)) + return is_atomic(arg(e,1)); + else + return false; + case expr_kind::Lambda: case expr_kind::Pi: case expr_kind::Eq: case expr_kind::Let: return false; } return false; @@ -399,6 +408,22 @@ class pp_fn { } } + /** + \brief Return true iff the given expression has the given fixity. + */ + bool has_fixity(expr const & e, fixity fx) { + operator_info op = get_operator(e); + if (op) { + return op.get_fixity() == fx; + } else if (is_eq(e)) { + return fixity::Infix == fx; + } else if (is_arrow(e)) { + return fixity::Infixr == fx; + } else { + return false; + } + } + /** \brief Pretty print the child of an infix, prefix, postfix or mixfix operator. It will add parethesis when needed. @@ -418,11 +443,14 @@ class pp_fn { \brief Pretty print the child of an associative infix operator. It will add parethesis when needed. */ - result pp_infix_child(operator_info const & op, expr const & e, unsigned depth) { + result pp_infix_child(operator_info const & op, expr const & e, unsigned depth, fixity fx) { if (is_atomic(e)) { return pp(e, depth + 1); } else { - if (op.get_precedence() < get_operator_precedence(e) || op == get_operator(e)) + unsigned e_prec = get_operator_precedence(e); + if (op.get_precedence() < e_prec) + return pp(e, depth + 1); + else if (op.get_precedence() == e_prec && has_fixity(e, fx)) return pp(e, depth + 1); else return pp_child_with_paren(e, depth); @@ -543,7 +571,7 @@ class pp_fn { \brief Pretty print an application. */ result pp_app(expr const & e, unsigned depth) { - if (!m_coercion && num_args(e) == 2 && m_frontend.is_coercion(arg(e,0))) + if (!m_coercion && is_coercion(e)) return pp(arg(e,1), depth); application app(e, *this, m_implict); operator_info op; @@ -556,11 +584,11 @@ class pp_fn { case fixity::Infix: return mk_infix(op, pp_mixfix_child(op, app.get_arg(0), depth), pp_mixfix_child(op, app.get_arg(1), depth)); case fixity::Infixr: - return mk_infix(op, pp_mixfix_child(op, app.get_arg(0), depth), pp_infix_child(op, app.get_arg(1), depth)); + return mk_infix(op, pp_mixfix_child(op, app.get_arg(0), depth), pp_infix_child(op, app.get_arg(1), depth, fixity::Infixr)); case fixity::Infixl: - return mk_infix(op, pp_infix_child(op, app.get_arg(0), depth), pp_mixfix_child(op, app.get_arg(1), depth)); + return mk_infix(op, pp_infix_child(op, app.get_arg(0), depth, fixity::Infixl), pp_mixfix_child(op, app.get_arg(1), depth)); case fixity::Prefix: - p_arg = pp_infix_child(op, app.get_arg(0), depth); + p_arg = pp_infix_child(op, app.get_arg(0), depth, fixity::Prefix); sz = op.get_op_name().size(); return mk_result(group(format{format(op.get_op_name()), nest(sz+1, format{line(), p_arg.first})}), p_arg.second + 1); diff --git a/src/kernel/arith/arith.cpp b/src/kernel/arith/arith.cpp index 35a7f096a..7d38c619a 100644 --- a/src/kernel/arith/arith.cpp +++ b/src/kernel/arith/arith.cpp @@ -247,7 +247,131 @@ MK_BUILTIN(int_le_fn, int_le_value); MK_CONSTANT(int_ge_fn, name(name("Int"), "ge")); MK_CONSTANT(int_lt_fn, name(name("Int"), "lt")); MK_CONSTANT(int_gt_fn, name(name("Int"), "gt")); +// ======================================= +// ======================================= +// Reals +class real_type_value : public num_type_value { +public: + real_type_value():num_type_value("Real") {} + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } +}; +expr const Real = mk_value(*(new real_type_value())); +expr mk_real_type() { return Real; } + +class real_value_value : public value { + mpq m_val; +public: + real_value_value(mpq const & v):m_val(v) {} + virtual ~real_value_value() {} + virtual expr get_type() const { return Real; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } + virtual bool operator==(value const & other) const { + real_value_value const * _other = dynamic_cast(&other); + return _other && _other->m_val == m_val; + } + virtual void display(std::ostream & out) const { out << m_val; } + virtual format pp() const { return format(m_val); } + virtual unsigned hash() const { return m_val.hash(); } + mpq const & get_num() const { return m_val; } +}; + +expr mk_real_value(mpq const & v) { + return mk_value(*(new real_value_value(v))); +} + +bool is_real_value(expr const & e) { + return is_value(e) && dynamic_cast(&to_value(e)) != nullptr; +} + +mpq const & real_value_numeral(expr const & e) { + lean_assert(is_real_value(e)); + return static_cast(to_value(e)).get_num(); +} + +template +class real_bin_op : public value { + expr m_type; + name m_name; +public: + real_bin_op() { + m_type = Real >> (Real >> Real); + m_name = name("Real", Name); + } + virtual ~real_bin_op() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 3 && is_real_value(args[1]) && is_real_value(args[2])) { + r = mk_real_value(F()(real_value_numeral(args[1]), real_value_numeral(args[2]))); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; + +constexpr char real_add_name[] = "add"; +struct real_add_eval { mpq operator()(mpq const & v1, mpq const & v2) { return v1 + v2; }; }; +typedef real_bin_op real_add_value; +MK_BUILTIN(real_add_fn, real_add_value); + +constexpr char real_sub_name[] = "sub"; +struct real_sub_eval { mpq operator()(mpq const & v1, mpq const & v2) { return v1 - v2; }; }; +typedef real_bin_op real_sub_value; +MK_BUILTIN(real_sub_fn, real_sub_value); + +constexpr char real_mul_name[] = "mul"; +struct real_mul_eval { mpq operator()(mpq const & v1, mpq const & v2) { return v1 * v2; }; }; +typedef real_bin_op real_mul_value; +MK_BUILTIN(real_mul_fn, real_mul_value); + +constexpr char real_div_name[] = "div"; +struct real_div_eval { + mpq operator()(mpq const & v1, mpq const & v2) { + if (v2.is_zero()) + return v2; + else + return v1 / v2; + }; +}; +typedef real_bin_op real_div_value; +MK_BUILTIN(real_div_fn, real_div_value); + +class real_le_value : public value { + expr m_type; + name m_name; +public: + real_le_value() { + m_type = Real >> (Real >> Bool); + m_name = name{"Real", "le"}; + } + virtual ~real_le_value() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 3 && is_real_value(args[1]) && is_real_value(args[2])) { + r = mk_bool_value(real_value_numeral(args[1]) <= real_value_numeral(args[2])); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; +MK_BUILTIN(real_le_fn, real_le_value); +MK_CONSTANT(real_ge_fn, name(name("Real"), "ge")); +MK_CONSTANT(real_lt_fn, name(name("Real"), "lt")); +MK_CONSTANT(real_gt_fn, name(name("Real"), "gt")); +// ======================================= + +// ======================================= +// Coercions class nat_to_int_value : public value { expr m_type; name m_name; @@ -273,11 +397,37 @@ public: }; MK_BUILTIN(nat_to_int_fn, nat_to_int_value); +class int_to_real_value : public value { + expr m_type; + name m_name; +public: + int_to_real_value() { + m_type = Int >> Real; + m_name = "int_to_real"; + } + virtual ~int_to_real_value() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 2 && is_int_value(args[1])) { + r = mk_real_value(mpq(int_value_numeral(args[1]))); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; +MK_BUILTIN(int_to_real_fn, int_to_real_value); +MK_CONSTANT(nat_to_real_fn, name("nat_to_real")); // ======================================= -void add_int_theory(environment & env) { - expr p_ii = Int >> (Int >> Bool); +void add_arith_theory(environment & env) { expr p_nn = Nat >> (Nat >> Bool); + expr p_ii = Int >> (Int >> Bool); + expr p_rr = Real >> (Real >> Bool); expr x = Const("x"); expr y = Const("y"); @@ -288,5 +438,11 @@ void add_int_theory(environment & env) { env.add_definition(int_ge_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, iLe(y, x))); env.add_definition(int_lt_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, Not(iLe(y, x)))); env.add_definition(int_gt_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, Not(iLe(x, y)))); + + env.add_definition(real_ge_fn_name, p_rr, Fun({{x, Real}, {y, Real}}, rLe(y, x))); + env.add_definition(real_lt_fn_name, p_rr, Fun({{x, Real}, {y, Real}}, Not(rLe(y, x)))); + env.add_definition(real_gt_fn_name, p_rr, Fun({{x, Real}, {y, Real}}, Not(rLe(x, y)))); + + env.add_definition(nat_to_real_fn_name, Nat >> Real, Fun({x, Nat}, i2r(n2i(x)))); } } diff --git a/src/kernel/arith/arith.h b/src/kernel/arith/arith.h index ec32ff0cf..43dc03044 100644 --- a/src/kernel/arith/arith.h +++ b/src/kernel/arith/arith.h @@ -11,6 +11,8 @@ Author: Leonardo de Moura #include "mpq.h" namespace lean { +// ======================================= +// Natural numbers expr mk_nat_type(); extern expr const Nat; @@ -39,7 +41,10 @@ expr mk_nat_gt_fn(); inline expr nGt(expr const & e1, expr const & e2) { return mk_app(mk_nat_gt_fn(), e1, e2); } inline expr nIf(expr const & c, expr const & t, expr const & e) { return mk_if(Nat, c, t, e); } +// ======================================= +// ======================================= +// Integers expr mk_int_type(); extern expr const Int; @@ -74,10 +79,56 @@ expr mk_int_gt_fn(); inline expr iGt(expr const & e1, expr const & e2) { return mk_app(mk_int_gt_fn(), e1, e2); } inline expr iIf(expr const & c, expr const & t, expr const & e) { return mk_if(Int, c, t, e); } +// ======================================= +// ======================================= +// Reals +expr mk_real_type(); +extern expr const Real; + +expr mk_real_value(mpq const & v); +inline expr mk_real_value(int v) { return mk_real_value(mpq(v)); } +inline expr rVal(int v) { return mk_real_value(v); } +bool is_real_value(expr const & e); +mpq const & real_value_numeral(expr const & e); + +expr mk_real_add_fn(); +inline expr rAdd(expr const & e1, expr const & e2) { return mk_app(mk_real_add_fn(), e1, e2); } + +expr mk_real_sub_fn(); +inline expr rSub(expr const & e1, expr const & e2) { return mk_app(mk_real_sub_fn(), e1, e2); } + +expr mk_real_mul_fn(); +inline expr rMul(expr const & e1, expr const & e2) { return mk_app(mk_real_mul_fn(), e1, e2); } + +expr mk_real_div_fn(); +inline expr rDiv(expr const & e1, expr const & e2) { return mk_app(mk_real_div_fn(), e1, e2); } + +expr mk_real_le_fn(); +inline expr rLe(expr const & e1, expr const & e2) { return mk_app(mk_real_le_fn(), e1, e2); } + +expr mk_real_ge_fn(); +inline expr rGe(expr const & e1, expr const & e2) { return mk_app(mk_real_ge_fn(), e1, e2); } + +expr mk_real_lt_fn(); +inline expr rLt(expr const & e1, expr const & e2) { return mk_app(mk_real_lt_fn(), e1, e2); } + +expr mk_real_gt_fn(); +inline expr rGt(expr const & e1, expr const & e2) { return mk_app(mk_real_gt_fn(), e1, e2); } + +inline expr rIf(expr const & c, expr const & t, expr const & e) { return mk_if(Real, c, t, e); } +// ======================================= + +// ======================================= +// Coercions expr mk_nat_to_int_fn(); inline expr n2i(expr const & e) { return mk_app(mk_nat_to_int_fn(), e); } +expr mk_int_to_real_fn(); +inline expr i2r(expr const & e) { return mk_app(mk_int_to_real_fn(), e); } +expr mk_nat_to_real_fn(); +inline expr n2r(expr const & e) { return mk_app(mk_nat_to_real_fn(), e); } +// ======================================= class environment; -void add_int_theory(environment & env); +void add_arith_theory(environment & env); } diff --git a/src/library/toplevel.cpp b/src/library/toplevel.cpp index c922a648a..068ba79f2 100644 --- a/src/library/toplevel.cpp +++ b/src/library/toplevel.cpp @@ -13,7 +13,7 @@ namespace lean { void init_toplevel(environment & env) { add_basic_theory(env); add_basic_thms(env); - add_int_theory(env); + add_arith_theory(env); } environment mk_toplevel() { environment r; diff --git a/tests/lean/arith2.lean b/tests/lean/arith2.lean new file mode 100644 index 000000000..dc7bc11b8 --- /dev/null +++ b/tests/lean/arith2.lean @@ -0,0 +1,14 @@ +Show 1/2 +Eval 4/6 +Show 3 div 2 +Variable x : Real +Variable i : Int +Variable n : Nat +Show x + i + 1 + n +Set lean::pp::coercion true +Show x + i + 1 + n +Show x * i + x +Show x - i + x - x >= 0 +Show x < x +Show x <= x +Show x > x \ No newline at end of file diff --git a/tests/lean/arith2.lean.expected.out b/tests/lean/arith2.lean.expected.out new file mode 100644 index 000000000..9b53ea364 --- /dev/null +++ b/tests/lean/arith2.lean.expected.out @@ -0,0 +1,14 @@ +1 / 2 +2/3 +3 div 2 + Assumed: x + Assumed: i + Assumed: n +x + i + 1 + n + Set: lean::pp::coercion +x + (int_to_real i) + (nat_to_real 1) + (nat_to_real n) +x * (int_to_real i) + x +x - (int_to_real i) + x - x ≥ (nat_to_real 0) +x < x +x ≤ x +x > x