Add binary operators between interval<T> and T

This commit is contained in:
Soonho Kong 2013-08-13 00:30:57 -07:00
parent cebe7d415a
commit 26d8bd2c12
2 changed files with 111 additions and 13 deletions

View file

@ -139,14 +139,24 @@ public:
void neg();
friend interval<T> neg(interval<T> o) { o.neg(); return o; }
interval & operator+=(interval const & o);
interval & operator-=(interval const & o);
interval & operator*=(interval const & o);
interval & operator/=(interval const & o);
interval & operator+=(interval<T> const & o);
interval & operator-=(interval<T> const & o);
interval & operator*=(interval<T> const & o);
interval & operator/=(interval<T> const & o);
interval & operator+=(T const & o);
interval & operator-=(T const & o);
interval & operator*=(T const & o);
interval & operator/=(T const & o);
void inv();
friend interval<T> inv(interval<T> o) { o.inv(); return o; }
void fmod(interval<T> y);
void fmod(T y);
friend interval<T> inv(interval<T> o, interval<T> y) { o.fmod(y); return o; }
friend interval<T> inv(interval<T> o, T y) { o.fmod(y); return o; }
void power(unsigned n);
void exp ();
void exp2 ();
@ -193,10 +203,21 @@ public:
friend interval<T> acosh(interval<T> o) { o.acosh(); return o; }
friend interval<T> atanh(interval<T> o) { o.atanh(); return o; }
friend interval operator+(interval a, interval const & b) { return a += b; }
friend interval operator-(interval a, interval const & b) { return a -= b; }
friend interval operator*(interval a, interval const & b) { return a *= b; }
friend interval operator/(interval a, interval const & b) { return a /= b; }
friend interval<T> operator+(interval<T> a, interval<T> const & b) { return a += b; }
friend interval<T> operator-(interval<T> a, interval<T> const & b) { return a -= b; }
friend interval<T> operator*(interval<T> a, interval<T> const & b) { return a *= b; }
friend interval<T> operator/(interval<T> a, interval<T> const & b) { return a /= b; }
friend interval<T> operator+(interval<T> a, T const & b) { return a += b; }
friend interval<T> operator-(interval<T> a, T const & b) { return a -= b; }
friend interval<T> operator*(interval<T> a, T const & b) { return a *= b; }
friend interval<T> operator/(interval<T> a, T const & b) { return a /= b; }
friend interval<T> operator+(T const & a, interval<T> b) { return b += a; }
friend interval<T> operator-(T const & a, interval<T> b) { return b += -a; }
friend interval<T> operator*(T const & a, interval<T> b) { return b *= a; }
friend interval<T> operator/(T const & a, interval<T> b) { b = b / a; return b; }
bool check_invariant() const;

View file

@ -560,6 +560,75 @@ interval<T> & interval<T>::operator/=(interval<T> const & o) {
return *this;
}
template<typename T>
interval<T> & interval<T>::operator+=(T const & o) {
xnumeral_kind new_l_kind, new_u_kind;
round_to_minus_inf();
add(m_lower, new_l_kind, m_lower, lower_kind(), o, XN_NUMERAL);
round_to_plus_inf();
add(m_upper, new_u_kind, m_upper, upper_kind(), o, XN_NUMERAL);
m_lower_inf = new_l_kind == XN_MINUS_INFINITY;
m_upper_inf = new_u_kind == XN_PLUS_INFINITY;
lean_assert(check_invariant());
return *this;
}
template<typename T>
interval<T> & interval<T>::operator-=(T const & o) {
xnumeral_kind new_l_kind, new_u_kind;
round_to_minus_inf();
sub(m_lower, new_l_kind, m_lower, lower_kind(), o, XN_NUMERAL);
round_to_plus_inf();
sub(m_upper, new_u_kind, m_upper, upper_kind(), o, XN_NUMERAL);
m_lower_inf = new_l_kind == XN_MINUS_INFINITY;
m_upper_inf = new_u_kind == XN_PLUS_INFINITY;
lean_assert(check_invariant());
return *this;
}
template<typename T>
interval<T> & interval<T>::operator*=(T const & o) {
xnumeral_kind new_l_kind, new_u_kind;
static thread_local T tmp1;
if (this->is_zero()) {
return *this;
}
if (numeric_traits<T>::is_zero(o)) {
numeric_traits<T>::reset(m_lower);
numeric_traits<T>::reset(m_upper);
m_lower_open = m_upper_open = false;
m_lower_inf = m_upper_inf = false;
return *this;
}
if(numeric_traits<T>::is_pos(o)) {
// [a, b] * c = [a*c, b*c] when c > 0
round_to_minus_inf();
mul(m_lower, new_l_kind, m_lower, lower_kind(), o, XN_NUMERAL);
round_to_plus_inf();
mul(m_upper, new_u_kind, m_upper, upper_kind(), o, XN_NUMERAL);
m_lower_inf = new_l_kind == XN_MINUS_INFINITY;
m_upper_inf = new_u_kind == XN_PLUS_INFINITY;
}
else {
// [a, b] * c = [b*c, a*c] when c < 0
round_to_minus_inf();
mul(tmp1, new_l_kind, m_upper, upper_kind(), o, XN_NUMERAL);
round_to_plus_inf();
mul(m_upper, new_u_kind, m_lower, lower_kind(), o, XN_NUMERAL);
m_lower = tmp1;
m_lower_inf = new_l_kind == XN_MINUS_INFINITY;
m_upper_inf = new_u_kind == XN_PLUS_INFINITY;
}
return *this;
}
template<typename T>
interval<T> & interval<T>::operator/=(T const & o) {
return *this;
}
template<typename T>
void interval<T>::inv() {
// If the interval [l,u] does not contain 0, then 1/[l,u] = [1/u, 1/l]
@ -756,7 +825,15 @@ void interval<T>::display(std::ostream & out) const {
out << (m_upper_open ? ")" : "]");
}
template<typename T> void interval<T>::exp () {
template<typename T> void interval<T>::fmod(interval<T> y) {
}
template<typename T> void interval<T>::fmod(T y) {
}
template<typename T> void interval<T>::exp() {
if(is_empty())
return;
if(m_lower_inf) {
@ -774,7 +851,7 @@ template<typename T> void interval<T>::exp () {
lean_assert(check_invariant());
return;
}
template<typename T> void interval<T>::exp2 () {
template<typename T> void interval<T>::exp2() {
if(is_empty())
return;
if(m_lower_inf) {
@ -810,7 +887,7 @@ template<typename T> void interval<T>::exp10() {
lean_assert(check_invariant());
return;
}
template<typename T> void interval<T>::log () {
template<typename T> void interval<T>::log() {
if(is_empty())
return;
if(is_N0()) {
@ -833,7 +910,7 @@ template<typename T> void interval<T>::log () {
lean_assert(check_invariant());
return;
}
template<typename T> void interval<T>::log2 () {
template<typename T> void interval<T>::log2() {
if(is_empty())
return;
if(is_N0()) {
@ -879,7 +956,7 @@ template<typename T> void interval<T>::log10() {
lean_assert(check_invariant());
return;
}
template<typename T> void interval<T>::sin () {
template<typename T> void interval<T>::sin() {
*this -= numeric_traits<T>::pi_half_lower();
cos();
}