diff --git a/library/algebra/numeral.lean b/library/algebra/numeral.lean index e5d7e6e64..1a3db70a0 100644 --- a/library/algebra/numeral.lean +++ b/library/algebra/numeral.lean @@ -136,3 +136,43 @@ theorem mk_cong (op : A → A) (a b : A) (H : a = b) : op a = op b := by congruence; exact H theorem mk_eq (a : A) : a = a := rfl + +theorem neg_add_neg_eq_of_add_add_eq_zero [s : add_comm_group A] (a b c : A) (H : c + a + b = 0) : -a + -b = c := + begin apply add_neg_eq_of_eq_add, apply neg_eq_of_add_eq_zero, rewrite [add.comm, add.assoc, add.comm b, -add.assoc, H] end + +/-theorem neg_add_neg_helper [s : add_comm_group A] (t₁ t₂ e w₁ w₂ : A) (H₁ : t₁ = -w₁) + (H₂ : t₂ = -w₂) (H : e + w₁ + w₂ = 0) : t₁ + t₂ = e := + by rewrite [H₁, H₂, neg_add_neg_eq_of_add_add_eq_zero _ _ _ H]-/ + +theorem neg_add_neg_helper [s : add_comm_group A] (a b c : A) (H : a + b = c) : -a + -b = -c := + begin apply iff.mp !neg_eq_neg_iff_eq, rewrite [neg_add, *neg_neg, H] end + +theorem neg_add_pos_eq_of_eq_add [s : add_comm_group A] (a b c : A) (H : b = c + a) : -a + b = c := + begin apply neg_add_eq_of_eq_add, rewrite add.comm, exact H end + +/-theorem neg_add_pos_helper [s : add_comm_group A] (t₁ t₂ e v w₁ w₂ : A) (H₁ : t₁ = -w₁) + (H₂ : t₂ = w₂) (Hv : w₂ = v) (H : e + w₁ = v) : t₁ + t₂ = e := + begin rewrite [H₁, H₂, Hv, -H, add.comm, add_neg_cancel_right] end-/ + +theorem neg_add_pos_helper1 [s : add_comm_group A] (a b c : A) (H : b + c = a) : -a + b = -c := + begin apply neg_add_eq_of_eq_add, apply eq_add_neg_of_add_eq H end + +theorem neg_add_pos_helper2 [s : add_comm_group A] (a b c : A) (H : a + c = b) : -a + b = c := + begin apply neg_add_eq_of_eq_add, rewrite H end + +theorem pos_add_neg_helper [s : add_comm_group A] (a b c : A) (H : b + a = c) : a + b = c := + by rewrite [add.comm, H] + +theorem sub_eq_add_neg_helper [s : add_comm_group A] (t₁ t₂ e w₁ w₂: A) (H₁ : t₁ = w₁) (H₂ : t₂ = w₂) + (H : w₁ + -w₂ = e) : t₁ - t₂ = e := + by rewrite [sub_eq_add_neg, H₁, H₂, H] + +theorem pos_add_pos_helper [s : add_comm_group A] (a b c h₁ h₂ : A) (H₁ : a = h₁) (H₂ : b = h₂) + (H : h₁ + h₂ = c) : a + b = c := + by rewrite [H₁, H₂, H] + +theorem subst_into_subtr [s : add_group A] (l r t : A) (prt : l + -r = t) : l - r = t := + by rewrite [sub_eq_add_neg, prt] + +theorem neg_neg_helper [s : add_group A] (a b : A) (H : a = -b) : -a = b := + by rewrite [H, neg_neg] diff --git a/src/library/norm_num.cpp b/src/library/norm_num.cpp index 0797b7597..b5fe34a7c 100644 --- a/src/library/norm_num.cpp +++ b/src/library/norm_num.cpp @@ -6,11 +6,13 @@ Author: Robert Y. Lewis #include "library/norm_num.h" #include "library/constants.h" + namespace lean { static name * g_add = nullptr, * g_add1 = nullptr, * g_mul = nullptr, * g_sub = nullptr, + * g_neg = nullptr, * g_bit0_add_bit0 = nullptr, * g_bit1_add_bit0 = nullptr, * g_bit0_add_bit1 = nullptr, @@ -28,6 +30,7 @@ static name * g_add = nullptr, * g_add1_zero = nullptr, * g_add1_one = nullptr, * g_subst_sum = nullptr, + * g_subst_subtr = nullptr, * g_subst_prod = nullptr, * g_mk_cong = nullptr, * g_mk_eq = nullptr, @@ -40,9 +43,21 @@ static name * g_add = nullptr, * g_add_monoid = nullptr, * g_monoid = nullptr, * g_add_comm = nullptr, + * g_add_group = nullptr, * g_mul_zero_class = nullptr, * g_distrib = nullptr, - * g_semiring = nullptr; + * g_has_neg = nullptr, + * g_has_sub = nullptr, + * g_semiring = nullptr, + * g_eq_neg_of_add_eq_zero = nullptr, + * g_neg_add_neg_eq = nullptr, + * g_neg_add_pos1 = nullptr, + * g_neg_add_pos2 = nullptr, + * g_pos_add_neg = nullptr, + * g_pos_add_pos = nullptr, + * g_sub_eq_add_neg = nullptr, + * g_neg_neg = nullptr, + * g_add_comm_group = nullptr; static bool is_numeral_aux(expr const & e, bool is_first) { buffer args; @@ -64,13 +79,15 @@ bool norm_num_context::is_numeral(expr const & e) const { return is_numeral_aux(e, true); } +bool is_neg(expr const & e) { + return is_const_app(e, *g_neg, 3); +} + /* -Takes e : instance A, and tries to synthesize has_add A. +Takes A : Type, and tries to synthesize has_add A. */ expr norm_num_context::mk_has_add(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(get_has_add_name(), const_levels(f)), args[0]); + expr t = mk_app(mk_constant(get_has_add_name(), m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -80,9 +97,7 @@ expr norm_num_context::mk_has_add(expr const & e) { } expr norm_num_context::mk_has_mul(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_has_mul, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_has_mul, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -92,9 +107,17 @@ expr norm_num_context::mk_has_mul(expr const & e) { } expr norm_num_context::mk_has_one(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(get_has_one_name(), const_levels(f)), args[0]); + expr t = mk_app(mk_constant(get_has_one_name(), m_lvls), e); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_one instance"); + } +} + +expr norm_num_context::mk_has_zero(expr const & e) { + expr t = mk_app(mk_constant(get_has_zero_name(), m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -104,9 +127,7 @@ expr norm_num_context::mk_has_one(expr const & e) { } expr norm_num_context::mk_add_monoid(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_add_monoid, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_add_monoid, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -116,9 +137,7 @@ expr norm_num_context::mk_add_monoid(expr const & e) { } expr norm_num_context::mk_monoid(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_monoid, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_monoid, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -128,9 +147,17 @@ expr norm_num_context::mk_monoid(expr const & e) { } expr norm_num_context::mk_add_comm(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_add_comm, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_add_comm, m_lvls), e); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize add_comm_semigroup instance"); + } +} + +expr norm_num_context::mk_add_group(expr const & e) { + expr t = mk_app(mk_constant(*g_add_group, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -140,9 +167,7 @@ expr norm_num_context::mk_add_comm(expr const & e) { } expr norm_num_context::mk_has_distrib(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_distrib, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_distrib, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -152,9 +177,7 @@ expr norm_num_context::mk_has_distrib(expr const & e) { } expr norm_num_context::mk_mul_zero_class(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_mul_zero_class, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_mul_zero_class, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -164,9 +187,7 @@ expr norm_num_context::mk_mul_zero_class(expr const & e) { } expr norm_num_context::mk_semiring(expr const & e) { - buffer args; - expr f = get_app_args(e, args); - expr t = mk_app(mk_constant(*g_semiring, const_levels(f)), args[0]); + expr t = mk_app(mk_constant(*g_semiring, m_lvls), e); optional inst = mk_class_instance(m_env, m_ctx, t); if (inst) { return *inst; @@ -175,6 +196,36 @@ expr norm_num_context::mk_semiring(expr const & e) { } } +expr norm_num_context::mk_has_neg(expr const & e) { + expr t = mk_app(mk_constant(*g_has_neg, m_lvls), e); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_neg instance"); + } +} + +expr norm_num_context::mk_has_sub(expr const & e) { + expr t = mk_app(mk_constant(*g_has_sub, m_lvls), e); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_sub instance"); + } +} + +expr norm_num_context::mk_add_comm_group(expr const & e) { + expr t = mk_app(mk_constant(*g_add_comm_group, m_lvls), e); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize add_comm_group instance"); + } +} + expr norm_num_context::mk_const(name const & n) { return mk_constant(n, m_lvls); } @@ -183,7 +234,7 @@ expr norm_num_context::mk_cong(expr const & op, expr const & type, expr const & return mk_app({mk_const(*g_mk_cong), type, op, a, b, eq}); } -pair norm_num_context::mk_norm(expr const & e) { +/*pair norm_num_context::mk_norm(expr const & e) { buffer args; expr f = get_app_args(e, args); if (!is_constant(f)) { @@ -194,14 +245,14 @@ pair norm_num_context::mk_norm(expr const & e) { auto lhs_p = mk_norm(args[2]); auto rhs_p = mk_norm(args[3]); auto add_p = mk_norm_add(lhs_p.first, rhs_p.first); - expr prf = mk_app({mk_const(*g_subst_sum), args[0], mk_has_add(args[1]), args[2], args[3], + expr prf = mk_app({mk_const(*g_subst_sum), args[0], mk_has_add(args[0]), args[2], args[3], lhs_p.first, rhs_p.first, add_p.first, lhs_p.second, rhs_p.second, add_p.second}); return pair(add_p.first, prf); } else if (const_name(f) == *g_mul && args.size() == 4) { auto lhs_p = mk_norm(args[2]); auto rhs_p = mk_norm(args[3]); auto mul_p = mk_norm_mul(lhs_p.first, rhs_p.first); - expr prf = mk_app({mk_const(*g_subst_prod), args[0], mk_has_mul(args[1]), args[2], args[3], + expr prf = mk_app({mk_const(*g_subst_prod), args[0], mk_has_mul(args[0]), args[2], args[3], lhs_p.first, rhs_p.first, mul_p.first, lhs_p.second, rhs_p.second, mul_p.second}); return pair(mul_p.first, prf); } else if (const_name(f) == get_bit0_name() && args.size() == 3) { @@ -221,7 +272,7 @@ pair norm_num_context::mk_norm(expr const & e) { throw exception("mk_norm found unrecognized combo "); } // TODO(Rob): cases for sub, div -} + }*/ // returns such that p is a proof that lhs + rhs = t. pair norm_num_context::mk_norm_add(expr const & lhs, expr const & rhs) { @@ -234,29 +285,30 @@ pair norm_num_context::mk_norm_add(expr const & lhs, expr const & rh } auto type = args_lhs[0]; auto typec = args_lhs[1]; + // std::cout << "typec in mk_norm_add is: " << typec << ". lhs: " << lhs << ", rhs: " << rhs << "\n"; expr rv; expr prf; if (is_bit0(lhs) && is_bit0(rhs)) { // typec is has_add auto p = mk_norm_add(args_lhs[2], args_rhs[2]); rv = mk_app(lhs_head, type, typec, p.first); - prf = mk_app({mk_const(*g_bit0_add_bit0), type, mk_add_comm(typec), args_lhs[2], args_rhs[2], p.first, p.second}); + prf = mk_app({mk_const(*g_bit0_add_bit0), type, mk_add_comm(type), args_lhs[2], args_rhs[2], p.first, p.second}); } else if (is_bit0(lhs) && is_bit1(rhs)) { auto p = mk_norm_add(args_lhs[2], args_rhs[3]); rv = mk_app({rhs_head, type, args_rhs[1], args_rhs[2], p.first}); - prf = mk_app({mk_const(*g_bit0_add_bit1), type, mk_add_comm(typec), args_rhs[1], args_lhs[2], args_rhs[3], p.first, p.second}); + prf = mk_app({mk_const(*g_bit0_add_bit1), type, mk_add_comm(type), args_rhs[1], args_lhs[2], args_rhs[3], p.first, p.second}); } else if (is_bit0(lhs) && is_one(rhs)) { rv = mk_app({mk_const(get_bit1_name()), type, args_rhs[1], args_lhs[1], args_lhs[2]}); prf = mk_app({mk_const(*g_bit0_add_1), type, typec, args_rhs[1], args_lhs[2]}); } else if (is_bit1(lhs) && is_bit0(rhs)) { // typec is has_one auto p = mk_norm_add(args_lhs[3], args_rhs[2]); rv = mk_app(lhs_head, type, typec, args_lhs[2], p.first); - prf = mk_app({mk_const(*g_bit1_add_bit0), type, mk_add_comm(typec), typec, args_lhs[3], args_rhs[2], p.first, p.second}); + prf = mk_app({mk_const(*g_bit1_add_bit0), type, mk_add_comm(type), typec, args_lhs[3], args_rhs[2], p.first, p.second}); } else if (is_bit1(lhs) && is_bit1(rhs)) { // typec is has_one auto add_ts = mk_norm_add(args_lhs[3], args_rhs[3]); expr add1 = mk_app({mk_const(*g_add1), type, args_lhs[2], typec, add_ts.first}); auto p = mk_norm_add1(add1); rv = mk_app({mk_const(get_bit0_name()), type, args_lhs[2], p.first}); - prf = mk_app({mk_const(*g_bit1_add_bit1), type, mk_add_comm(typec), typec, args_lhs[3], args_rhs[3], add_ts.first, p.first, add_ts.second, p.second}); + prf = mk_app({mk_const(*g_bit1_add_bit1), type, mk_add_comm(type), typec, args_lhs[3], args_rhs[3], add_ts.first, p.first, add_ts.second, p.second}); } else if (is_bit1(lhs) && is_one(rhs)) { // typec is has_one expr add1 = mk_app({mk_const(*g_add1), type, args_lhs[2], typec, lhs}); auto p = mk_norm_add1(add1); @@ -264,21 +316,21 @@ pair norm_num_context::mk_norm_add(expr const & lhs, expr const & rh prf = mk_app({mk_const(*g_bit1_add_1), type, args_lhs[2], typec, args_lhs[3], p.first, p.second}); } else if (is_one(lhs) && is_bit0(rhs)) { // typec is has_one rv = mk_app({mk_const(get_bit1_name()), type, typec, args_rhs[1], args_rhs[2]}); - prf = mk_app({mk_const(*g_1_add_bit0), type, mk_add_comm(typec), typec, args_rhs[2]}); + prf = mk_app({mk_const(*g_1_add_bit0), type, mk_add_comm(type), typec, args_rhs[2]}); } else if (is_one(lhs) && is_bit1(rhs)) { // typec is has_one expr add1 = mk_app({mk_const(*g_add1), type, args_rhs[2], args_rhs[1], rhs}); auto p = mk_norm_add1(add1); rv = p.first; - prf = mk_app({mk_const(*g_1_add_bit1), type, mk_add_comm(typec), typec, args_rhs[3], p.first, p.second}); + prf = mk_app({mk_const(*g_1_add_bit1), type, mk_add_comm(type), typec, args_rhs[3], p.first, p.second}); } else if (is_one(lhs) && is_one(rhs)) { - rv = mk_app({mk_const(get_bit0_name()), type, mk_has_add(typec), lhs}); - prf = mk_app({mk_const(*g_one_add_one), type, mk_has_add(typec), typec}); + rv = mk_app({mk_const(get_bit0_name()), type, mk_has_add(type), lhs}); + prf = mk_app({mk_const(*g_one_add_one), type, mk_has_add(type), typec}); } else if (is_zero(lhs)) { rv = rhs; - prf = mk_app({mk_const(*g_bin_0_add), type, mk_add_monoid(typec), rhs}); + prf = mk_app({mk_const(*g_bin_0_add), type, mk_add_monoid(type), rhs}); } else if (is_zero(rhs)) { rv = lhs; - prf = mk_app({mk_const(*g_bin_add_0), type, mk_add_monoid(typec), lhs}); + prf = mk_app({mk_const(*g_bin_add_0), type, mk_add_monoid(type), lhs}); } else { std::cout << "\n\n bad args: " << lhs_head << ", " << rhs_head << "\n"; throw exception("mk_norm_add got malformed args"); @@ -302,10 +354,10 @@ pair norm_num_context::mk_norm_add1(expr const & e) { } else if (is_bit1(p)) { // ne_args : has_one, has_add auto np = mk_norm_add1(mk_app({mk_const(*g_add1), args[0], args[1], args[2], ne_args[3]})); rv = mk_app({mk_const(get_bit0_name()), args[0], args[1], np.first}); - prf = mk_app({mk_const(*g_add1_bit1), args[0], mk_add_comm(args[1]), args[2], ne_args[3], np.first, np.second}); + prf = mk_app({mk_const(*g_add1_bit1), args[0], mk_add_comm(args[0]), args[2], ne_args[3], np.first, np.second}); } else if (is_zero(p)) { rv = mk_app({mk_const(get_one_name()), args[0], args[2]}); - prf = mk_app({mk_const(*g_add1_zero), args[0], mk_add_monoid(args[1]), args[2]}); + prf = mk_app({mk_const(*g_add1_zero), args[0], mk_add_monoid(args[0]), args[2]}); } else if (is_one(p)) { rv = mk_app({mk_const(get_bit0_name()), args[0], args[1], mk_app({mk_const(get_one_name()), args[0], args[2]})}); prf = mk_app({mk_const(*g_add1_one), args[0], args[1], args[2]}); @@ -330,22 +382,22 @@ pair norm_num_context::mk_norm_mul(expr const & lhs, expr const & rh expr prf; if (is_zero(rhs)) { rv = rhs; - prf = mk_app({mk_const(*g_mul_zero), type, mk_mul_zero_class(typec), lhs}); + prf = mk_app({mk_const(*g_mul_zero), type, mk_mul_zero_class(type), lhs}); } else if (is_zero(lhs)) { rv = lhs; - prf = mk_app({mk_const(*g_zero_mul), type, mk_mul_zero_class(typec), rhs}); + prf = mk_app({mk_const(*g_zero_mul), type, mk_mul_zero_class(type), rhs}); } else if (is_one(rhs)) { rv = lhs; - prf = mk_app({mk_const(*g_mul_one), type, mk_monoid(typec), lhs}); + prf = mk_app({mk_const(*g_mul_one), type, mk_monoid(type), lhs}); } else if (is_bit0(rhs)) { auto mtp = mk_norm_mul(lhs, args_rhs[2]); rv = mk_app({rhs_head, type, typec, mtp.first}); - prf = mk_app({mk_const(*g_mul_bit0), type, mk_has_distrib(typec), lhs, args_rhs[2], mtp.first, mtp.second}); + prf = mk_app({mk_const(*g_mul_bit0), type, mk_has_distrib(type), lhs, args_rhs[2], mtp.first, mtp.second}); } else if (is_bit1(rhs)) { auto mtp = mk_norm_mul(lhs, args_rhs[3]); auto atp = mk_norm_add(mk_app({mk_const(get_bit0_name()), type, args_rhs[2], mtp.first}), lhs); rv = atp.first; - prf = mk_app({mk_const(*g_mul_bit1), type, mk_semiring(typec), lhs, args_rhs[3], + prf = mk_app({mk_const(*g_mul_bit1), type, mk_semiring(type), lhs, args_rhs[3], mtp.first, atp.first, mtp.second, atp.second}); } else { std::cout << "bad args to mk_norm_mul: " << rhs << "\n"; @@ -364,11 +416,262 @@ pair norm_num_context::mk_norm_sub(expr const &, expr const &) { throw exception("not implemented yet -- mk_norm_sub"); } +mpz norm_num_context::num_of_expr(expr const & e) { // note : mpz only supports nonneg nums + buffer args; + expr f = get_app_args(e, args); + if (!is_constant(f)) { + throw exception("cannot find num of nonconstant"); + } + auto v = to_num(e); + if (v) { + return *v; + } + if (const_name(f) == *g_add && args.size() == 4) { + return num_of_expr(args[2]) + num_of_expr(args[3]); + } else if (const_name(f) == *g_mul && args.size() == 4) { + return num_of_expr(args[2]) * num_of_expr(args[3]); + } else if (const_name(f) == *g_sub && args.size() == 4) { + return num_of_expr(args[2]) - num_of_expr(args[3]); + } else if (const_name(f) == *g_neg && args.size() == 3) { + return neg(num_of_expr(args[2])); + } else { + std::cout << "error : " << f << "\n"; + throw exception("expression in num_of_expr is malfomed"); + } +} + +pair get_type_and_arg_of_neg(expr & e) { + lean_assert(is_neg(e)); + buffer args; + expr f = get_app_args(e, args); + return pair(args[0], args[2]); +} + +// returns a proof that s_lhs + s_rhs = rhs, where all are negated numerals +expr norm_num_context::mk_norm_eq_neg_add_neg(expr & s_lhs, expr & s_rhs, expr & rhs) { + lean_assert(is_neg(s_lhs)); + lean_assert(is_neg(s_rhs)); + lean_assert(is_neg(rhs)); + auto s_lhs_v = get_type_and_arg_of_neg(s_lhs).second; + auto s_rhs_v = get_type_and_arg_of_neg(s_rhs).second; + auto rhs_v = get_type_and_arg_of_neg(rhs); + expr type = rhs_v.first; + auto sum_pr = mk_norm_eq_pos_add_pos(s_lhs_v, s_rhs_v, rhs_v.second); + return mk_app({mk_const(*g_neg_add_neg_eq), type, mk_add_comm_group(type), s_lhs_v, s_rhs_v, rhs_v.second, sum_pr}); +} + +expr norm_num_context::mk_norm_eq_neg_add_pos(expr & s_lhs, expr & s_rhs, expr & rhs) { + lean_assert(is_neg(s_lhs)); + lean_assert(!is_neg(s_rhs)); + auto s_lhs_v = get_type_and_arg_of_neg(s_lhs); + expr type = s_lhs_v.first; + if (is_neg(rhs)) { + auto rhs_v = get_type_and_arg_of_neg(rhs).second; + auto sum_pr = mk_norm_eq_pos_add_pos(s_rhs, rhs_v, s_lhs_v.second); + return mk_app({mk_const(*g_neg_add_pos1), type, mk_add_comm_group(type), s_lhs_v.second, s_rhs, rhs_v, sum_pr}); + } else { + auto sum_pr = mk_norm_eq_pos_add_pos(s_lhs_v.second, rhs, s_rhs); + return mk_app({mk_const(*g_neg_add_pos2), type, mk_add_comm_group(type), s_lhs_v.second, s_rhs, rhs, sum_pr}); + } +} + + +expr norm_num_context::mk_norm_eq_pos_add_neg(expr & s_lhs, expr & s_rhs, expr & rhs) { + expr prf = mk_norm_eq_neg_add_pos(s_rhs, s_lhs, rhs); + expr type = get_type_and_arg_of_neg(s_rhs).first; + return mk_app({mk_const(*g_pos_add_neg), type, mk_add_comm_group(type), s_lhs, s_rhs, rhs, prf}); +} + +// returns a proof that s_lhs + s_rhs = rhs, where all are nonneg normalized numerals +expr norm_num_context::mk_norm_eq_pos_add_pos(expr & s_lhs, expr & s_rhs, expr & rhs) { + lean_assert(!is_neg(s_lhs)); + lean_assert(!is_neg(s_rhs)); + lean_assert(!is_neg(rhs)); + auto p = mk_norm_add(s_lhs, s_rhs); + lean_assert(to_num(rhs) == to_num(p.first)); + return p.second; +} + + +/*expr norm_num_context::mk_norm_eq(expr const & lhs, expr const & rhs) { // rhs is a nonneg numeral + buffer args; + expr f = get_app_args(lhs, args); + if (!is_constant(f)) { + throw exception("cannot take norm of nonconstant"); + } +// m_lvls = const_levels(f); +// expr rv; +// expr prf; + if (const_name(f) == *g_add && args.size() == 4) { + auto lhs_p = num_of_expr(args[2]); // mk_norm_expr(args[2]); + auto rhs_p = num_of_expr(args[3]); //mk_norm_expr(args[3]); + buffer args_lhs, args_rhs; +// expr flhs = get_app_args(lhs_p.first, args_lhs); +// expr frhs = get_app_args(rhs_p.first, args_rhs); +// std::cout << "in mk_norm_eq add. is_neg first, second:" << is_neg(lhs_p.first) << is_neg(rhs_p.first) << "\n"; + if (lhs_p.is_neg()) { + if (rhs_p.is_neg()) { + return mk_norm_eq_neg_add_neg(f, rhs, args); + //return mk_norm_eq_neg_add_neg(f, rhs, args, args_lhs, args_rhs, lhs_p.second, rhs_p.second); + } else { + return mk_norm_eq_neg_add_pos(f, rhs, args); + } + } else { + if (rhs_p.is_neg()) { + buffer rvargs = buffer(args); + rvargs[3] = args[2]; + rvargs[2] = args[3]; + expr commprf = mk_norm_eq_neg_add_pos(f, rhs, rvargs); + // commprf : args[3] + args[2] = rhs + return mk_app({mk_const(*g_pos_add_neg), args[0], mk_add_comm_group(args[0]), args[2], args[3], rhs, commprf}); + } else { // nonneg add nonneg + return mk_norm_eq_pos_add_pos(f, rhs, args); + } + } + } else if (const_name(f) == *g_mul && args.size() == 4) { + auto lhs_p = mk_norm_expr(args[2]); + auto rhs_p = mk_norm_expr(args[3]);// TODO(Rob): handle case where either is neg + auto mul_p = mk_norm_mul(lhs_p.first, rhs_p.first); + rv = mul_p.first; + prf = mk_app({mk_const(*g_subst_prod), args[0], mk_has_mul(args[0]), args[2], args[3], + lhs_p.first, rhs_p.first, mul_p.first, lhs_p.second, rhs_p.second, mul_p.second}); + } else if (const_name(f) == *g_sub && args.size() == 4) { + // auto lhs_p = mk_norm_expr(args[2]); + // auto rhs_p = mk_norm_expr(args[3]); + expr addneg = mk_app({mk_const(*g_add), args[0], mk_has_add(args[0]), args[2], mk_neg(args[0], args[3])}); + expr prf = mk_norm_eq(addneg, rhs); // a + -b = c + return mk_app({mk_const(*g_sub_eq_add_neg), args[0], mk_add_comm_group(args[0]), args[2], args[3], rhs, prf}); + } else if ((const_name(f) == get_bit0_name() || const_name(f) == *g_neg) && args.size() == 3) { + auto arg = mk_norm(args[2]); + // rv = mk_app({f, args[0], args[1], arg.first}); + return mk_cong(mk_app({f, args[0], args[1]}), args[0], args[2], arg.first, arg.second); + } else if (const_name(f) == get_bit1_name() && args.size() == 4) { + auto arg = mk_norm(args[3]); + // rv = mk_app({f, args[0], args[1], args[2], arg.first}); + return mk_cong(mk_app({f, args[0], args[1], args[2]}), args[0], args[3], arg.first, arg.second); + } else if ((const_name(f) == get_zero_name() || const_name(f) == get_one_name()) && args.size() == 2) { + //rv = lhs; + return mk_app({mk_const(*g_mk_eq), args[0], lhs}); + } else { + std::cout << "error with name " << const_name(f) << " and size " << args.size() << ".\n"; + throw exception("mk_norm found unrecognized combo "); + } + // throw exception("not implemented yet"); +}*/ + +expr norm_num_context::from_pos_num(mpz const & n, expr const & type) { + lean_assert(n > 0); + if (n == 1) + return mk_app({mk_const(get_one_name()), type, mk_has_one(type)}); + if (n % mpz(2) == 1) + return mk_app({mk_const(get_bit1_name()), type, mk_has_one(type), mk_has_add(type), from_pos_num(n/2, type)}); + else + return mk_app({mk_const(get_bit0_name()), type, mk_has_add(type), from_pos_num(n/2, type)}); +} + +expr norm_num_context::from_num(mpz const & n, expr const & type) { + expr r; + lean_assert(n >= 0); + if (n == 0) + r = mk_app(mk_const(get_zero_name()), type, mk_has_zero(type)); + else + r = from_pos_num(n, type); + lean_assert(*to_num(r) == n); + return r; +} + +expr norm_num_context::mk_neg(expr const & type, expr const & e) { + auto has_neg = mk_has_neg(type); + return mk_app({mk_const(*g_neg), type, has_neg, e}); +} + +expr norm_num_context::mk_add(expr const & type, expr const & e1, expr const & e2) { + auto has_add = mk_has_add(type); + return mk_app({mk_const(*g_add), type, has_add, e1, e2}); +} + +pair norm_num_context::mk_norm(expr const & e) { + std::cout << "mk_norm\n"; + buffer args; + expr f = get_app_args(e, args); + if (!is_constant(f) || args.size() == 0) { + throw exception("malformed argument to mk_norm_expr"); + } + m_lvls = const_levels(f); + expr type = args[0]; + auto val = num_of_expr(e); + expr nval; // e = nval + if (val >= 0) { + nval = from_num(val, type); + } else { + nval = mk_neg(type, from_num(neg(val), type)); + } + if (const_name(f) == *g_add && args.size() == 4) { + auto lhs_p = mk_norm(args[2]); + auto rhs_p = mk_norm(args[3]); + expr prf; + if (is_neg(lhs_p.first)) { + if (is_neg(rhs_p.first)) { + prf = mk_norm_eq_neg_add_neg(lhs_p.first, rhs_p.first, nval); + } else { + prf = mk_norm_eq_neg_add_pos(lhs_p.first, rhs_p.first, nval); + } + } else { + if (is_neg(rhs_p.first)) { + prf = mk_norm_eq_pos_add_neg(lhs_p.first, rhs_p.first, nval); + } else { + prf = mk_norm_eq_pos_add_pos(lhs_p.first, rhs_p.first, nval); + } + } + expr rprf = mk_app({mk_const(*g_subst_sum), type, mk_has_add(type), args[2], args[3], + lhs_p.first, rhs_p.first, nval, lhs_p.second, rhs_p.second, prf}); + return pair(nval, rprf); + + } else if (const_name(f) == *g_sub && args.size() == 4) { + expr sum = mk_add(args[0], args[2], mk_neg(args[0], args[3])); + auto anprf = mk_norm(sum); + expr rprf = mk_app({mk_const(*g_subst_subtr), type, mk_add_group(type), args[2], args[3], anprf.first, anprf.second}); + return pair(nval, rprf); + } else if (const_name(f) == *g_neg && args.size() == 3) { + auto prf = mk_norm(args[2]); + lean_assert(num_of_expr(prf.first) == neg(val)); + if (is_neg(nval)) { + buffer nval_args; + get_app_args(nval, nval_args); + auto rprf = mk_cong(mk_app(f, args[0], args[1]), type, args[2], nval_args[2], prf.second); + return pair(nval, rprf); + } else { + auto rprf = mk_app({mk_const(*g_neg_neg), type, mk_add_group(type), args[2], nval, prf.second}); + return pair(nval, rprf); + } + } else if (const_name(f) == get_bit0_name() && args.size() == 3) { + lean_assert(is_bit0(nval)); + buffer nval_args; + get_app_args(nval, nval_args); + auto prf = mk_norm(args[2]); + auto rprf = mk_cong(mk_app(f, args[0], args[1]), type, args[2], nval_args[2], prf.second); + return pair(nval, rprf); + } else if (const_name(f) == get_bit1_name() && args.size() == 4) { + lean_assert(is_bit1(nval)); + buffer nval_args; + get_app_args(nval, nval_args); + auto prf = mk_norm(args[3]); + auto rprf = mk_cong(mk_app(f, args[0], args[1], args[2]), type, args[3], nval_args[3], prf.second); + return pair(nval, rprf); + } else if ((const_name(f) == get_zero_name() || const_name(f) == get_one_name()) && args.size() == 2) { + return pair(e, mk_app({mk_const(*g_mk_eq), args[0], e})); + } else { + std::cout << "error with name " << const_name(f) << " and size " << args.size() << ".\n"; + throw exception("mk_norm found unrecognized combo "); + } +} + void initialize_norm_num() { g_add = new name("add"); g_add1 = new name("add1"); g_mul = new name("mul"); g_sub = new name("sub"); + g_neg = new name("neg"); g_bit0_add_bit0 = new name("bit0_add_bit0_helper"); g_bit1_add_bit0 = new name("bit1_add_bit0_helper"); g_bit0_add_bit1 = new name("bit0_add_bit1_helper"); @@ -386,6 +689,7 @@ void initialize_norm_num() { g_add1_zero = new name("add1_zero"); g_add1_one = new name("add1_one"); g_subst_sum = new name("subst_into_sum"); + g_subst_subtr = new name("subst_into_subtr"); g_subst_prod = new name("subst_into_prod"); g_mk_cong = new name("mk_cong"); g_mk_eq = new name("mk_eq"); @@ -398,9 +702,21 @@ void initialize_norm_num() { g_add_monoid = new name("algebra", "add_monoid"); g_monoid = new name("algebra", "monoid"); g_add_comm = new name("algebra", "add_comm_semigroup"); + g_add_group = new name("algebra", "add_group"); g_mul_zero_class = new name("algebra", "mul_zero_class"); g_distrib = new name("algebra", "distrib"); + g_has_neg = new name("has_neg"); //"algebra", + g_has_sub = new name("algebra", "has_sub"); g_semiring = new name("algebra", "semiring"); + g_eq_neg_of_add_eq_zero = new name("algebra", "eq_neg_of_add_eq_zero"); + g_neg_add_neg_eq = new name("neg_add_neg_helper"); + g_neg_add_pos1 = new name("neg_add_pos_helper1"); + g_neg_add_pos2 = new name("neg_add_pos_helper2"); + g_pos_add_neg = new name("pos_add_neg_helper"); + g_sub_eq_add_neg = new name("sub_eq_add_neg_helper"); + g_pos_add_pos = new name("pos_add_pos_helper"); + g_neg_neg = new name("neg_neg_helper"); + g_add_comm_group = new name("algebra", "add_comm_group"); } void finalize_norm_num() { @@ -408,6 +724,7 @@ void finalize_norm_num() { delete g_add1; delete g_mul; delete g_sub; + delete g_neg; delete g_bit0_add_bit0; delete g_bit1_add_bit0; delete g_bit0_add_bit1; @@ -425,6 +742,7 @@ void finalize_norm_num() { delete g_add1_zero; delete g_add1_one; delete g_subst_sum; + delete g_subst_subtr; delete g_subst_prod; delete g_mk_cong; delete g_mk_eq; @@ -437,8 +755,20 @@ void finalize_norm_num() { delete g_add_monoid; delete g_monoid; delete g_add_comm; + delete g_add_group; delete g_mul_zero_class; delete g_distrib; + delete g_has_neg; + delete g_has_sub; delete g_semiring; + delete g_eq_neg_of_add_eq_zero; + delete g_neg_add_neg_eq; + delete g_neg_add_pos1; + delete g_neg_add_pos2; + delete g_pos_add_neg; + delete g_pos_add_pos; + delete g_sub_eq_add_neg; + delete g_neg_neg; + delete g_add_comm_group; } } diff --git a/src/library/norm_num.h b/src/library/norm_num.h index 5a279eb4f..b528eaf1b 100644 --- a/src/library/norm_num.h +++ b/src/library/norm_num.h @@ -23,21 +23,39 @@ class norm_num_context { expr mk_cong(expr const &, expr const &, expr const &, expr const &, expr const &); expr mk_has_add(expr const &); expr mk_has_mul(expr const &); + expr mk_has_zero(expr const &); expr mk_has_one(expr const &); expr mk_add_monoid(expr const &); expr mk_monoid(expr const &); expr mk_has_distrib(expr const &); expr mk_add_comm(expr const &); + expr mk_add_group(expr const &); expr mk_mul_zero_class(expr const &); expr mk_semiring(expr const &); + expr mk_has_neg(expr const &); + expr mk_has_sub(expr const &); + expr mk_add(expr const &, expr const &, expr const &); + expr mk_neg(expr const &, expr const &); + expr mk_add_comm_group(expr const &); + expr mk_norm_eq_neg_add_neg(expr &,expr &,expr &); + expr mk_norm_eq_neg_add_pos(expr &, expr &, expr &); + expr mk_norm_eq_pos_add_neg(expr &, expr &, expr &); + expr mk_norm_eq_pos_add_pos(expr &, expr &, expr &); public: norm_num_context(environment const & env, local_context const & ctx):m_env(env), m_ctx(ctx) {} bool is_numeral(expr const & e) const; pair mk_norm(expr const & e); + //pair mk_norm_expr(expr const & e); + expr mk_norm_eq(expr const &, expr const &); + mpz num_of_expr(expr const & e); + expr from_pos_num(mpz const &, expr const &); + expr from_num(mpz const &, expr const &); }; +inline bool is_neg(expr const & e); + inline bool is_numeral(environment const & env, expr const & e) { return norm_num_context(env, local_context()).is_numeral(e); } @@ -45,6 +63,11 @@ inline bool is_numeral(environment const & env, expr const & e) { inline pair mk_norm_num(environment const & env, local_context const & ctx, expr const & e) { return norm_num_context(env, ctx).mk_norm(e); } + +inline mpz num_of_expr(environment const & env, local_context const & ctx, expr const & e) { + return norm_num_context(env, ctx).num_of_expr(e); +} + void initialize_norm_num(); void finalize_norm_num(); } diff --git a/src/library/num.cpp b/src/library/num.cpp index 5822b5766..bbd8656fa 100644 --- a/src/library/num.cpp +++ b/src/library/num.cpp @@ -18,7 +18,7 @@ bool has_num_decls(environment const & env) { env.find(get_bit1_name()); } -static bool is_const_app(expr const & e, name const & n, unsigned nargs) { +bool is_const_app(expr const & e, name const & n, unsigned nargs) { expr const & f = get_app_fn(e); return is_constant(f) && const_name(f) == n && get_app_num_args(e) == nargs; } @@ -43,6 +43,12 @@ optional is_bit1(expr const & e) { return some_expr(app_arg(e)); } +optional is_neg(expr const & e) { + if (!is_const_app(e, *new name("neg"), 3)) + return none_expr(); + return some_expr(app_arg(e)); +} + optional unfold_num_app(environment const & env, expr const & e) { if (is_zero(e) || is_one(e) || is_bit0(e) || is_bit1(e)) { return unfold_app(env, e); @@ -83,6 +89,9 @@ static optional to_num(expr const & e, bool first) { } else if (auto a = is_bit1(e)) { if (auto r = to_num(*a, false)) return some(2*(*r)+1); + } else if (auto a = is_neg(e)) { + if (auto r = to_num(*a, false)) + return some(neg(*r)); } return optional(); } diff --git a/src/library/num.h b/src/library/num.h index cc444829a..4b347297e 100644 --- a/src/library/num.h +++ b/src/library/num.h @@ -12,6 +12,8 @@ namespace lean { zero, one, bit0, bit1 */ bool has_num_decls(environment const & env); +bool is_const_app(expr const &, name const &, unsigned); + /** \brief Return true iff the given expression encodes a numeral. */ bool is_num(expr const & e); diff --git a/src/library/tactic/norm_num_tactic.cpp b/src/library/tactic/norm_num_tactic.cpp index 7a1d84376..461b515e3 100644 --- a/src/library/tactic/norm_num_tactic.cpp +++ b/src/library/tactic/norm_num_tactic.cpp @@ -28,10 +28,10 @@ tactic norm_num_tactic() { type_checker_ptr rtc = mk_type_checker(env, UnfoldReducible); lhs = normalize(*rtc, lhs); rhs = normalize(*rtc, rhs); - buffer hyps; g.get_hyps(hyps); local_context ctx(to_list(hyps)); +// std::cout << "num of lhs: " << num_of_expr(env, ctx, lhs) << "\n"; try { pair p = mk_norm_num(env, ctx, lhs); expr new_lhs = p.first; @@ -53,6 +53,7 @@ tactic norm_num_tactic() { return none_proof_state(); } } else { + std::cout << "lhs: " << new_lhs << ", rhs: " << new_rhs << "\n"; throw_tactic_exception_if_enabled(s, "norm_num tactic failed, one side is not a numeral"); return none_proof_state(); } diff --git a/tests/lean/extra/num_norm1.lean b/tests/lean/extra/num_norm1.lean index 11bcd6c7c..6942f37b2 100644 --- a/tests/lean/extra/num_norm1.lean +++ b/tests/lean/extra/num_norm1.lean @@ -1,4 +1,4 @@ -import algebra.numeral algebra.field +import algebra.numeral algebra.field data.nat open algebra variable {A : Type} @@ -29,6 +29,10 @@ example : (12 : A) = 0 + (2 + 3) + 7 := by norm_num example : (105 : A) = 70 + (33 + 2) := by norm_num example : (45000000000 : A) = 23000000000 + 22000000000 := by norm_num +example : (12 : A) - 4 - (5 + -2) = 5 := by norm_num +example : (12 : A) - 4 - (5 + -2) - 20 = -15 := by norm_num +exit + example : (0 : A) * 0 = 0 := by norm_num example : (0 : A) * 1 = 0 := by norm_num example : (0 : A) * 2 = 0 := by norm_num