From 958add9ef899204ef3502f2b62466333d5866422 Mon Sep 17 00:00:00 2001 From: Rob Lewis Date: Mon, 19 Oct 2015 19:03:32 -0400 Subject: [PATCH] feat(library/norm_num): fix numeral normalization to work on new numeral structure; add support for multiplication --- library/algebra/numeral.lean | 138 +++++++- src/library/init_module.cpp | 3 + src/library/norm_num.cpp | 441 ++++++++++++++++++++++++- src/library/norm_num.h | 25 +- src/library/tactic/norm_num_tactic.cpp | 31 +- tests/lean/extra/num_norm1.lean | 52 ++- 6 files changed, 663 insertions(+), 27 deletions(-) diff --git a/library/algebra/numeral.lean b/library/algebra/numeral.lean index 72deb63a1..e5d7e6e64 100644 --- a/library/algebra/numeral.lean +++ b/library/algebra/numeral.lean @@ -1,8 +1,138 @@ -import algebra.ring -- has_add, has_one, ... will be moved to init in the future +/- +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author: Robert Y. Lewis +-/ + +import algebra.ring open algebra variable {A : Type} --- variables [s : ring A] --- set_option pp.all true --- check bit1 (bit0 (one : A)) +definition add1 [s : has_add A] [s' : has_one A] (a : A) : A := add a one + +theorem add_comm_four [s : add_comm_semigroup A] (a b : A) : a + a + (b + b) = (a + b) + (a + b) := + by rewrite [-add.assoc at {1}, add.comm, {a + b}add.comm at {1}, *add.assoc] + +theorem add_comm_middle [s : add_comm_semigroup A] (a b c : A) : a + b + c = a + c + b := + by rewrite [add.assoc, add.comm b, -add.assoc] + +theorem bit0_add_bit0 [s : add_comm_semigroup A] (a b : A) : bit0 a + bit0 b = bit0 (a + b) := + !add_comm_four + +theorem bit0_add_bit0_helper [s : add_comm_semigroup A] (a b t : A) (H : a + b = t) : + bit0 a + bit0 b = bit0 t := + by rewrite -H; apply bit0_add_bit0 + +theorem bit1_add_bit0 [s : add_comm_semigroup A] [s' : has_one A] (a b : A) : + bit1 a + bit0 b = bit1 (a + b) := + begin + rewrite [↑bit0, ↑bit1, add_comm_middle], congruence, apply add_comm_four + end + +theorem bit1_add_bit0_helper [s : add_comm_semigroup A] [s' : has_one A] (a b t : A) (H : a + b = t) : + bit1 a + bit0 b = bit1 t := + by rewrite -H; apply bit1_add_bit0 + +theorem bit0_add_bit1 [s : add_comm_semigroup A] [s' : has_one A] (a b : A) : + bit0 a + bit1 b = bit1 (a + b) := + by rewrite [{bit0 a + _}add.comm, {a + _}add.comm]; apply bit1_add_bit0 + +theorem bit0_add_bit1_helper [s : add_comm_semigroup A] [s' : has_one A] (a b t : A) (H : a + b = t) : + bit0 a + bit1 b = bit1 t := + by rewrite -H; apply bit0_add_bit1 + +theorem bit1_add_bit1 [s : add_comm_semigroup A] [s' : has_one A] (a b : A) : bit1 a + bit1 b = bit0 (add1 (a + b)) := + begin + rewrite ↑[bit0, bit1, add1], + apply sorry + end + +theorem bit1_add_bit1_helper [s : add_comm_semigroup A] [s' : has_one A] (a b t s: A) + (H : (a + b) = t) (H2 : add1 t = s) : bit1 a + bit1 b = bit0 s := + begin rewrite [-H2, -H], apply bit1_add_bit1 end + +theorem bin_add_zero [s : add_monoid A] (a : A) : a + zero = a := !add_zero + +theorem bin_zero_add [s : add_monoid A] (a : A) : zero + a = a := !zero_add + +theorem one_add_bit0 [s : add_comm_semigroup A] [s' : has_one A] (a : A) : one + bit0 a = bit1 a := + begin rewrite ↑[bit0, bit1], rewrite add.comm end + +theorem bit0_add_one [s : has_add A] [s' : has_one A] (a : A) : bit0 a + one = bit1 a := rfl + +theorem bit1_add_one [s : has_add A] [s' : has_one A] (a : A) : bit1 a + one = add1 (bit1 a) := rfl + +theorem bit1_add_one_helper [s : has_add A] [s' : has_one A] (a t : A) (H : add1 (bit1 a) = t) : + bit1 a + one = t := + by rewrite -H + +theorem one_add_bit1 [s : add_comm_semigroup A] [s' : has_one A] (a : A) : + one + bit1 a = add1 (bit1 a) := !add.comm + +theorem one_add_bit1_helper [s : add_comm_semigroup A] [s' : has_one A] (a t : A) + (H : add1 (bit1 a) = t) : one + bit1 a = t := + by rewrite -H; apply one_add_bit1 + +theorem add1_bit0 [s : has_add A] [s' : has_one A] (a : A) : add1 (bit0 a) = bit1 a := + rfl + +theorem add1_bit1 [s : add_comm_semigroup A] [s' : has_one A] (a : A) : + add1 (bit1 a) = bit0 (add1 a) := + begin + rewrite ↑[add1, bit1, bit0], + rewrite [add.assoc, add_comm_four] + end + +theorem add1_bit1_helper [s : add_comm_semigroup A] [s' : has_one A] (a t : A) (H : add1 a = t) : + add1 (bit1 a) = bit0 t := + by rewrite -H; apply add1_bit1 + +theorem add1_one [s : has_add A] [s' : has_one A] : add1 (one : A) = bit0 one := + rfl + +theorem add1_zero [s : add_monoid A] [s' : has_one A] : add1 (zero : A) = one := + begin + rewrite [↑add1, zero_add] + end + +theorem one_add_one [s : has_add A] [s' : has_one A] : (one : A) + one = bit0 one := + rfl + +theorem subst_into_sum [s : has_add A] (l r tl tr t : A) (prl : l = tl) (prr : r = tr) (prt : tl + tr = t) : + l + r = t := + by rewrite [prl, prr, prt] + +-- multiplication + +theorem mul_zero [s : mul_zero_class A] (a : A) : a * zero = zero := + by rewrite [↑zero, mul_zero] + +theorem zero_mul [s : mul_zero_class A] (a : A) : zero * a = zero := + by rewrite [↑zero, zero_mul] + +theorem mul_one [s : monoid A] (a : A) : a * one = a := + by rewrite [↑one, mul_one] + +theorem mul_bit0 [s : distrib A] (a b : A) : a * (bit0 b) = bit0 (a * b) := + by rewrite [↑bit0, left_distrib] + +theorem mul_bit0_helper [s : distrib A] (a b t : A) (H : a * b = t) : a * (bit0 b) = bit0 t := + by rewrite -H; apply mul_bit0 + +theorem mul_bit1 [s : semiring A] (a b : A) : a * (bit1 b) = bit0 (a * b) + a := + by rewrite [↑bit1, ↑bit0, +left_distrib, ↑one, mul_one] + +theorem mul_bit1_helper [s : semiring A] (a b s t : A) (Hs : a * b = s) (Ht : bit0 s + a = t) : + a * (bit1 b) = t := + begin rewrite [-Ht, -Hs, mul_bit1] end + +theorem subst_into_prod [s : has_mul A] (l r tl tr t : A) (prl : l = tl) (prr : r = tr) + (prt : tl * tr = t) : + l * r = t := + by rewrite [prl, prr, prt] + +theorem mk_cong (op : A → A) (a b : A) (H : a = b) : op a = op b := + by congruence; exact H + +theorem mk_eq (a : A) : a = a := rfl diff --git a/src/library/init_module.cpp b/src/library/init_module.cpp index 1162661d9..73971cf87 100644 --- a/src/library/init_module.cpp +++ b/src/library/init_module.cpp @@ -43,6 +43,7 @@ Author: Leonardo de Moura #include "library/aux_recursors.h" #include "library/decl_stats.h" #include "library/meng_paulson.h" +#include "library/norm_num.h" namespace lean { void initialize_library_module() { @@ -85,6 +86,7 @@ void initialize_library_module() { initialize_aux_recursors(); initialize_decl_stats(); initialize_meng_paulson(); + initialize_norm_num(); } void finalize_library_module() { @@ -127,5 +129,6 @@ void finalize_library_module() { finalize_print(); finalize_fingerprint(); finalize_constants(); + finalize_norm_num(); } } diff --git a/src/library/norm_num.cpp b/src/library/norm_num.cpp index f203a02ab..1088b0b2b 100644 --- a/src/library/norm_num.cpp +++ b/src/library/norm_num.cpp @@ -1,18 +1,447 @@ /* Copyright (c) 2015 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. - -Author: Rob Lewis +Author: Robert Y. Lewis */ + #include "library/norm_num.h" +#include "library/constants.cpp" namespace lean { -bool norm_num_context::is_numeral(expr const &) const { - // TODO(Rob) + static name * g_add = nullptr, + * g_add1 = nullptr, + * g_mul = nullptr, + * g_sub = nullptr, + * g_bit0_add_bit0 = nullptr, + * g_bit1_add_bit0 = nullptr, + * g_bit0_add_bit1 = nullptr, + * g_bit1_add_bit1 = nullptr, + * g_bin_add_0 = nullptr, + * g_bin_0_add = nullptr, + * g_bin_add_1 = nullptr, + * g_1_add_bit0 = nullptr, + * g_bit0_add_1 = nullptr, + * g_bit1_add_1 = nullptr, + * g_1_add_bit1 = nullptr, + * g_one_add_one = nullptr, + * g_add1_bit0 = nullptr, + * g_add1_bit1 = nullptr, + * g_add1_zero = nullptr, + * g_add1_one = nullptr, + * g_subst_sum = nullptr, + * g_subst_prod = nullptr, + * g_mk_cong = nullptr, + * g_mk_eq = nullptr, + * g_mul_zero = nullptr, + * g_zero_mul = nullptr, + * g_mul_one = nullptr, + * g_mul_bit0 = nullptr, + * g_mul_bit1 = nullptr, + * g_has_mul = nullptr, + * g_add_monoid = nullptr, + * g_monoid = nullptr, + * g_add_comm = nullptr, + * g_mul_zero_class= nullptr, + * g_distrib = nullptr, + * g_semiring = nullptr; + + +static bool is_numeral_aux(expr const & e, bool is_first) { + buffer args; + expr const & f = get_app_args(e, args); + if (!is_constant(f)) { + return false; + } + if (const_name(f) == *g_one) { + return args.size() == 2; + } else if (const_name(f) == *g_zero) { + return is_first && args.size() == 2; + } else if (const_name(f) == *g_bit1 || const_name(f) == *g_bit0) { + return args.size() == 3 && is_numeral_aux(args[2], false); + } return false; } -pair norm_num_context::mk_norm(expr const &) { +bool norm_num_context::is_numeral(expr const & e) const { + return is_numeral_aux(e, true); +} + +/* +Takes e : instance A, and tries to synthesize has_add A. +*/ +expr norm_num_context::mk_has_add(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_has_add, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_add instance"); + } +} + +expr norm_num_context::mk_has_mul(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_has_mul, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_mul instance"); + } +} + +expr norm_num_context::mk_has_one(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_has_one, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_one instance"); + } +} + +expr norm_num_context::mk_add_monoid(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_add_monoid, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize add_monoid instance"); + } +} + +expr norm_num_context::mk_monoid(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_monoid, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize monoid instance"); + } +} + +expr norm_num_context::mk_add_comm(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_add_comm, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize add_comm_semigroup instance"); + } +} + +expr norm_num_context::mk_has_distrib(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_distrib, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize has_distrib instance"); + } +} + +expr norm_num_context::mk_mul_zero_class(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_mul_zero_class, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize mul_zero instance"); + } +} + +expr norm_num_context::mk_semiring(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr t = mk_app(mk_constant(*g_semiring, const_levels(f)), args[0]); + optional inst = mk_class_instance(m_env, m_ctx, t); + if (inst) { + return *inst; + } else { + throw exception("failed to synthesize semiring instance"); + } +} + +expr norm_num_context::mk_const(name const & n) { + return mk_constant(n, m_lvls); +} + +expr norm_num_context::mk_cong(expr const & op, expr const & type, expr const & a, expr const & b, expr const & eq) { + return mk_app({mk_const(*g_mk_cong), type, op, a, b, eq}); +} + +pair norm_num_context::mk_norm(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + if (!is_constant(f)) { + throw exception("cannot take norm of nonconstant"); + } + m_lvls = const_levels(f); + if (const_name(f) == *g_add && args.size() == 4) { + auto lhs_p = mk_norm(args[2]); + auto rhs_p = mk_norm(args[3]); + auto add_p = mk_norm_add(lhs_p.first, rhs_p.first); + expr prf = mk_app({mk_const(*g_subst_sum), args[0], mk_has_add(args[1]), args[2], args[3], + lhs_p.first, rhs_p.first, add_p.first, lhs_p.second, rhs_p.second, add_p.second}); + return pair(add_p.first, prf); + } else if (const_name(f) == *g_mul && args.size() == 4) { + auto lhs_p = mk_norm(args[2]); + auto rhs_p = mk_norm(args[3]); + auto mul_p = mk_norm_mul(lhs_p.first, rhs_p.first); + expr prf = mk_app({mk_const(*g_subst_prod), args[0], mk_has_mul(args[1]), args[2], args[3], + lhs_p.first, rhs_p.first, mul_p.first, lhs_p.second, rhs_p.second, mul_p.second}); + return pair(mul_p.first, prf); + } else if (const_name(f) == *g_bit0 && args.size() == 3) { + auto arg = mk_norm(args[2]); + expr rv = mk_app({f, args[0], args[1], arg.first}); + expr prf = mk_cong(mk_app({f, args[0], args[1]}), args[0], args[2], arg.first, arg.second); + return pair(rv, prf); + } else if (const_name(f) == *g_bit1 && args.size() == 4) { + auto arg = mk_norm(args[3]); + expr rv = mk_app({f, args[0], args[1], args[2], arg.first}); + expr prf = mk_cong(mk_app({f, args[0], args[1], args[2]}), args[0], args[3], arg.first, arg.second); + return pair(rv, prf); + } else if ((const_name(f) == *g_zero || const_name(f) == *g_one) && args.size() == 2) { + return pair(e, mk_app({mk_const(*g_mk_eq), args[0], e})); + } else { + std::cout << "error with name " << const_name(f) << " and size " << args.size() << ".\n"; + throw exception("mk_norm found unrecognized combo "); + } + // TODO(Rob): cases for sub, div +} + +// returns such that p is a proof that lhs + rhs = t. +pair norm_num_context::mk_norm_add(expr const & lhs, expr const & rhs) { + buffer args_lhs; + buffer args_rhs; + expr lhs_head = get_app_args (lhs, args_lhs); + expr rhs_head = get_app_args (rhs, args_rhs); + if (!is_constant(lhs_head) || !is_constant(rhs_head)) { + throw exception("cannot take norm_add of nonconstant"); + } + auto type = args_lhs[0]; + auto typec = args_lhs[1]; + expr rv; + expr prf; + if (is_bit0(lhs) && is_bit0(rhs)) { // typec is has_add + auto p = mk_norm_add(args_lhs[2], args_rhs[2]); + rv = mk_app(lhs_head, type, typec, p.first); + prf = mk_app({mk_const(*g_bit0_add_bit0), type, mk_add_comm(typec), args_lhs[2], args_rhs[2], p.first, p.second}); + } else if (is_bit0(lhs) && is_bit1(rhs)) { + auto p = mk_norm_add(args_lhs[2], args_rhs[3]); + rv = mk_app({rhs_head, type, args_rhs[1], args_rhs[2], p.first}); + prf = mk_app({mk_const(*g_bit0_add_bit1), type, mk_add_comm(typec), args_rhs[1], args_lhs[2], args_rhs[3], p.first, p.second}); + } else if (is_bit0(lhs) && is_one(rhs)) { + rv = mk_app({mk_const(*g_bit1), type, args_rhs[1], args_lhs[1], args_lhs[2]}); + prf = mk_app({mk_const(*g_bit0_add_1), type, typec, args_rhs[1], args_lhs[2]}); + } else if (is_bit1(lhs) && is_bit0(rhs)) { // typec is has_one + auto p = mk_norm_add(args_lhs[3], args_rhs[2]); + rv = mk_app(lhs_head, type, typec, args_lhs[2], p.first); + prf = mk_app({mk_const(*g_bit1_add_bit0), type, mk_add_comm(typec), typec, args_lhs[3], args_rhs[2], p.first, p.second}); + } else if (is_bit1(lhs) && is_bit1(rhs)) { // typec is has_one + auto add_ts = mk_norm_add(args_lhs[3], args_rhs[3]); + expr add1 = mk_app({mk_const(*g_add1), type, args_lhs[2], typec, add_ts.first}); + auto p = mk_norm_add1(add1); + rv = mk_app({mk_const(*g_bit0), type, args_lhs[2], p.first}); + prf = mk_app({mk_const(*g_bit1_add_bit1), type, mk_add_comm(typec), typec, args_lhs[3], args_rhs[3], add_ts.first, p.first, add_ts.second, p.second}); + } else if (is_bit1(lhs) && is_one(rhs)) { // typec is has_one + expr add1 = mk_app({mk_const(*g_add1), type, args_lhs[2], typec, lhs}); + auto p = mk_norm_add1(add1); + rv = p.first; + prf = mk_app({mk_const(*g_bit1_add_1), type, args_lhs[2], typec, args_lhs[3], p.first, p.second}); + } else if (is_one(lhs) && is_bit0(rhs)) { // typec is has_one + rv = mk_app({mk_const(*g_bit1), type, typec, args_rhs[1], args_rhs[2]}); + prf = mk_app({mk_const(*g_1_add_bit0), type, mk_add_comm(typec), typec, args_rhs[2]}); + } else if (is_one(lhs) && is_bit1(rhs)) { // typec is has_one + expr add1 = mk_app({mk_const(*g_add1), type, args_rhs[2], args_rhs[1], rhs}); + auto p = mk_norm_add1(add1); + rv = p.first; + prf = mk_app({mk_const(*g_1_add_bit1), type, mk_add_comm(typec), typec, args_rhs[3], p.first, p.second}); + } else if (is_one(lhs) && is_one(rhs)) { + rv = mk_app({mk_const(*g_bit0), type, mk_has_add(typec), lhs}); + prf = mk_app({mk_const(*g_one_add_one), type, mk_has_add(typec), typec}); + } else if (is_zero(lhs)) { + rv = rhs; + prf = mk_app({mk_const(*g_bin_0_add), type, mk_add_monoid(typec), rhs}); + } else if (is_zero(rhs)) { + rv = lhs; + prf = mk_app({mk_const(*g_bin_add_0), type, mk_add_monoid(typec), lhs}); + } + else { + std::cout << "\n\n bad args: " << lhs_head << ", " << rhs_head << "\n"; + throw exception("mk_norm_add got malformed args"); + } + return pair(rv, prf); +} + +pair norm_num_context::mk_norm_add1(expr const & e) { + buffer args; + expr f = get_app_args(e, args); + expr p = args[3]; + buffer ne_args; + expr ne = get_app_args(p, ne_args); + expr rv; + expr prf; + // args[1] = has_add, args[2] = has_one + if (is_bit0(p)) { + auto has_one = args[2]; + rv = mk_app({mk_const(*g_bit1), args[0], args[2], args[1], ne_args[2]}); + prf = mk_app({mk_const(*g_add1_bit0), args[0], args[1], args[2], ne_args[2]}); + } else if (is_bit1(p)) { // ne_args : has_one, has_add + auto np = mk_norm_add1(mk_app({mk_const(*g_add1), args[0], args[1], args[2], ne_args[3]})); + rv = mk_app({mk_const(*g_bit0), args[0], args[1], np.first}); + prf = mk_app({mk_const(*g_add1_bit1), args[0], mk_add_comm(args[1]), args[2], ne_args[3], np.first, np.second}); + } else if (is_zero(p)) { + rv = mk_app({mk_const(*g_one), args[0], args[2]}); + prf = mk_app({mk_const(*g_add1_zero), args[0], mk_add_monoid(args[1]), args[2]}); + } else if (is_one(p)) { + rv = mk_app({mk_const(*g_bit0), args[0], args[1], mk_app({mk_const(*g_one), args[0], args[2]})}); + prf = mk_app({mk_const(*g_add1_one), args[0], args[1], args[2]}); + } else { + std::cout << "malformed add1: " << ne << "\n"; + throw exception("malformed add1"); + } + return pair(rv, prf); +} + +pair norm_num_context::mk_norm_mul(expr const & lhs, expr const & rhs) { + buffer args_lhs; + buffer args_rhs; + expr lhs_head = get_app_args (lhs, args_lhs); + expr rhs_head = get_app_args (rhs, args_rhs); + if (!is_constant(lhs_head) || !is_constant(rhs_head)) { + throw exception("cannot take norm_add of nonconstant"); + } + auto type = args_rhs[0]; + auto typec = args_rhs[1]; + expr rv; + expr prf; + if (is_zero(rhs)) { + rv = rhs; + prf = mk_app({mk_const(*g_mul_zero), type, mk_mul_zero_class(typec), lhs}); + } else if (is_zero(lhs)) { + rv = lhs; + prf = mk_app({mk_const(*g_zero_mul), type, mk_mul_zero_class(typec), rhs}); + } else if (is_one(rhs)) { + rv = lhs; + prf = mk_app({mk_const(*g_mul_one), type, mk_monoid(typec), lhs}); + } else if (is_bit0(rhs)) { + auto mtp = mk_norm_mul(lhs, args_rhs[2]); + rv = mk_app({rhs_head, type, typec, mtp.first}); + prf = mk_app({mk_const(*g_mul_bit0), type, mk_has_distrib(typec), lhs, args_rhs[2], mtp.first, mtp.second}); + } else if (is_bit1(rhs)) { + std::cout << "is_bit1 " << rhs << "\n"; + auto mtp = mk_norm_mul(lhs, args_rhs[3]); + auto atp = mk_norm_add(mk_app({mk_const(*g_bit0), type, args_rhs[2], mtp.first}), lhs); + rv = atp.first; + prf = mk_app({mk_const(*g_mul_bit1), type, mk_semiring(typec), lhs, args_rhs[3], + mtp.first, atp.first, mtp.second, atp.second}); + } else { + std::cout << "bad args to mk_norm_mul: " << rhs << "\n"; + throw exception("mk_norm_mul got malformed args"); + } + return pair(rv, prf); +} + +pair norm_num_context::mk_norm_div(expr const &, expr const &) { // TODO(Rob) - throw exception("not implemented yet - norm_num`"); + throw exception("not implemented yet -- mk_norm_div"); +} + +pair norm_num_context::mk_norm_sub(expr const &, expr const &) { + // TODO(Rob) + throw exception("not implemented yet -- mk_norm_sub"); +} + +void initialize_norm_num() { + g_add = new name("add"); + g_add1 = new name("add1"); + g_mul = new name("mul"); + g_sub = new name("sub"); + g_bit0_add_bit0 = new name("bit0_add_bit0_helper"); + g_bit1_add_bit0 = new name("bit1_add_bit0_helper"); + g_bit0_add_bit1 = new name("bit0_add_bit1_helper"); + g_bit1_add_bit1 = new name("bit1_add_bit1_helper"); + g_bin_add_0 = new name("bin_add_zero"); + g_bin_0_add = new name("bin_zero_add"); + g_bin_add_1 = new name("bin_add_one"); + g_1_add_bit0 = new name("one_add_bit0"); + g_bit0_add_1 = new name("bit0_add_one"); + g_bit1_add_1 = new name("bit1_add_one_helper"); + g_1_add_bit1 = new name("one_add_bit1_helper"); + g_one_add_one = new name("one_add_one"); + g_add1_bit0 = new name("add1_bit0"); + g_add1_bit1 = new name("add1_bit1_helper"); + g_add1_zero = new name("add1_zero"); + g_add1_one = new name("add1_one"); + g_subst_sum = new name("subst_into_sum"); + g_subst_prod = new name("subst_into_prod"); + g_mk_cong = new name("mk_cong"); + g_mk_eq = new name("mk_eq"); + g_zero_mul = new name("zero_mul"); + g_mul_zero = new name("mul_zero"); + g_mul_one = new name("mul_one"); + g_mul_bit0 = new name("mul_bit0_helper"); + g_mul_bit1 = new name("mul_bit1_helper"); + g_has_mul = new name("has_mul"); + g_add_monoid = new name("algebra", "add_monoid"); + g_monoid = new name("algebra", "monoid"); + g_add_comm = new name("algebra", "add_comm_semigroup"); + g_mul_zero_class = new name("algebra", "mul_zero_class"); + g_distrib = new name("algebra", "distrib"); + g_semiring = new name("algebra", "semiring"); +} + +void finalize_norm_num() { + delete g_add; + delete g_add1; + delete g_mul; + delete g_sub; + delete g_bit0_add_bit0; + delete g_bit1_add_bit0; + delete g_bit0_add_bit1; + delete g_bit1_add_bit1; + delete g_bin_add_0; + delete g_bin_0_add; + delete g_bin_add_1; + delete g_1_add_bit0; + delete g_bit0_add_1; + delete g_bit1_add_1; + delete g_1_add_bit1; + delete g_one_add_one; + delete g_add1_bit0; + delete g_add1_bit1; + delete g_add1_zero; + delete g_add1_one; + delete g_subst_sum; + delete g_subst_prod; + delete g_mk_cong; + delete g_mk_eq; + delete g_mul_zero; + delete g_zero_mul; + delete g_mul_one; + delete g_mul_bit0; + delete g_mul_bit1; + delete g_has_mul; + delete g_add_monoid; + delete g_monoid; + delete g_add_comm; + delete g_mul_zero_class; + delete g_distrib; + delete g_semiring; } } diff --git a/src/library/norm_num.h b/src/library/norm_num.h index b6ea28762..491abc80a 100644 --- a/src/library/norm_num.h +++ b/src/library/norm_num.h @@ -1,17 +1,36 @@ /* Copyright (c) 2015 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. - -Author: Rob Lewis +Author: Robert Y. Lewis */ #pragma once #include "kernel/environment.h" #include "library/local_context.h" +#include "library/num.h" +#include "library/class_instance_synth.h" namespace lean { class norm_num_context { environment m_env; local_context m_ctx; + levels m_lvls; + pair mk_norm_add(expr const &, expr const &); + pair mk_norm_add1(expr const &); + pair mk_norm_mul(expr const &, expr const &); + pair mk_norm_div(expr const &, expr const &); + pair mk_norm_sub(expr const &, expr const &); + expr mk_const(name const & n); + expr mk_cong(expr const &, expr const &, expr const &, expr const &, expr const &); + expr mk_has_add(expr const &); + expr mk_has_mul(expr const &); + expr mk_has_one(expr const &); + expr mk_add_monoid(expr const &); + expr mk_monoid(expr const &); + expr mk_has_distrib(expr const &); + expr mk_add_comm(expr const &); + expr mk_mul_zero_class(expr const &); + expr mk_semiring(expr const &); + public: norm_num_context(environment const & env, local_context const & ctx):m_env(env), m_ctx(ctx) {} @@ -26,4 +45,6 @@ inline bool is_numeral(environment const & env, expr const & e) { inline pair mk_norm_num(environment const & env, local_context const & ctx, expr const & e) { return norm_num_context(env, ctx).mk_norm(e); } +void initialize_norm_num(); +void finalize_norm_num(); } diff --git a/src/library/tactic/norm_num_tactic.cpp b/src/library/tactic/norm_num_tactic.cpp index 591d6dfb6..bffcc0947 100644 --- a/src/library/tactic/norm_num_tactic.cpp +++ b/src/library/tactic/norm_num_tactic.cpp @@ -1,10 +1,12 @@ /* Copyright (c) 2015 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. - -Author: Rob Lewis +Author: Robert Y. Lewis */ +#include "kernel/type_checker.h" #include "library/util.h" +#include "library/reducible.h" +#include "library/normalize.h" #include "library/norm_num.h" #include "library/tactic/expr_to_tactic.h" @@ -23,20 +25,31 @@ tactic norm_num_tactic() { throw_tactic_exception_if_enabled(s, "norm_num tactic failed, conclusion is not an equality"); return none_proof_state(); } + type_checker_ptr rtc = mk_type_checker(env, UnfoldReducible); + lhs = normalize(*rtc, lhs); + rhs = normalize(*rtc, rhs); + buffer hyps; g.get_hyps(hyps); local_context ctx(to_list(hyps)); try { + //bool bs = is_numeral(env, lhs); pair p = mk_norm_num(env, ctx, lhs); expr new_lhs = p.first; - expr new_pr = p.second; - if (new_lhs != rhs) { - throw_tactic_exception_if_enabled(s, "norm_num tactic failed, lhs normal form doesn't match rhs"); - return none_proof_state(); - } + expr new_lhs_pr = p.second; + pair p2 = mk_norm_num(env, ctx, rhs); + expr new_rhs = p2.first; + expr new_rhs_pr = p2.second; + //if (new_lhs != new_rhs) { + // std::cout << "lhs: " << new_lhs << ", rhs: " << new_rhs << "\n"; + // throw_tactic_exception_if_enabled(s, "norm_num tactic failed, lhs normal form doesn't match rhs"); + // return none_proof_state(); + //} + type_checker tc(env); + expr g_prf = mk_trans(tc, new_lhs_pr, mk_symm(tc, new_rhs_pr)); substitution new_subst = s.get_subst(); - assign(new_subst, g, new_pr); - return some_proof_state(proof_state(s, tail(gs), new_subst)); + assign(new_subst, g, g_prf); + return some_proof_state(proof_state(s, tail(gs), new_subst)); } catch (exception & ex) { throw_tactic_exception_if_enabled(s, ex.what()); return none_proof_state(); diff --git a/tests/lean/extra/num_norm1.lean b/tests/lean/extra/num_norm1.lean index b4f0838c1..f4c6c9988 100644 --- a/tests/lean/extra/num_norm1.lean +++ b/tests/lean/extra/num_norm1.lean @@ -1,11 +1,51 @@ -import algebra.numeral algebra.ring +import algebra.numeral algebra.field open algebra variable {A : Type} -variable [s : ring A] +variable [s : comm_ring A] include s -example : add (bit0 (one:A)) one = bit1 one := -begin - norm_num -end +example : (1 : A) = 0 + 1 := by norm_num +example : (1 : A) = 1 + 0 := by norm_num +example : (2 : A) = 1 + 1 := by norm_num +example : (2 : A) = 0 + 2 := by norm_num +example : (3 : A) = 1 + 2 := by norm_num +example : (3 : A) = 2 + 1 := by norm_num +example : (4 : A) = 3 + 1 := by norm_num +example : (4 : A) = 2 + 2 := by norm_num +example : (5 : A) = 4 + 1 := by norm_num +example : (5 : A) = 3 + 2 := by norm_num +example : (5 : A) = 2 + 3 := by norm_num +example : (6 : A) = 0 + 6 := by norm_num +example : (6 : A) = 3 + 3 := by norm_num +example : (6 : A) = 4 + 2 := by norm_num +example : (6 : A) = 5 + 1 := by norm_num +example : (7 : A) = 4 + 3 := by norm_num +example : (7 : A) = 1 + 6 := by norm_num +example : (7 : A) = 6 + 1 := by norm_num +example : 33 = 5 + (28 : A) := by norm_num + + +example : (12 : A) = 0 + (2 + 3) + 7 := by norm_num +example : (105 : A) = 70 + (33 + 2) := by norm_num +example : (45000000000 : A) = 23000000000 + 22000000000 := by norm_num + +example : (0 : A) * 0 = 0 := by norm_num +example : (0 : A) * 1 = 0 := by norm_num +example : (0 : A) * 2 = 0 := by norm_num +example : (2 : A) * 0 = 0 := by norm_num +example : (1 : A) * 0 = 0 := by norm_num +example : (1 : A) * 1 = 1 := by norm_num +example : (2 : A) * 1 = 2 := by norm_num +example : (1 : A) * 2 = 2 := by norm_num +example : (2 : A) * 2 = 4 := by norm_num +example : (3 : A) * 2 = 6 := by norm_num +example : (2 : A) * 3 = 6 := by norm_num +example : (4 : A) * 1 = 4 := by norm_num +example : (1 : A) * 4 = 4 := by norm_num +example : (3 : A) * 3 = 9 := by norm_num +example : (3 : A) * 4 = 12 := by norm_num +example : (4 : A) * 4 = 16 := by norm_num +example : (11 : A) * 2 = 22 := by norm_num +example : (15 : A) * 6 = 90 := by norm_num +example : (123456 : A) * 123456 = 15241383936 := by norm_num