(index<- )        ./libnum/complex.rs

    git branch:    * master           5200215 auto merge of #14035 : alexcrichton/rust/experimental, r=huonw
    modified:    Fri Apr 25 22:40:04 2014
   1  // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
   2  // file at the top-level directory of this distribution and at
   3  // http://rust-lang.org/COPYRIGHT.
   4  //
   5  // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
   6  // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
   7  // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
   8  // option. This file may not be copied, modified, or distributed
   9  // except according to those terms.
  10  
  11  
  12  //! Complex numbers.
  13  
  14  use std::fmt;
  15  use std::num::{Zero,One,ToStrRadix};
  16  
  17  // FIXME #1284: handle complex NaN & infinity etc. This
  18  // probably doesn't map to C's _Complex correctly.
  19  
  20  // FIXME #5734:: Need generic sin/cos for .to/from_polar().
  21  // FIXME #5735: Need generic sqrt to implement .norm().
  22  
  23  
  24  /// A complex number in Cartesian form.
  25  #[deriving(Eq,Clone)]
  26  pub struct Cmplx<T> {
  27      /// Real portion of the complex number
  28      pub re: T,
  29      /// Imaginary portion of the complex number
  30      pub im: T
  31  }
  32  
  33  pub type Complex32 = Cmplx<f32>;
  34  pub type Complex64 = Cmplx<f64>;
  35  
  36  impl<T: Clone + Num> Cmplx<T> {
  37      /// Create a new Cmplx
  38      #[inline]
  39      pub fn new(reT, imT) -> Cmplx<T> {
  40          Cmplx { re: re, im: im }
  41      }
  42  
  43      /**
  44      Returns the square of the norm (since `T` doesn't necessarily
  45      have a sqrt function), i.e. `re^2 + im^2`.
  46      */
  47      #[inline]
  48      pub fn norm_sqr(&self) -> T {
  49          self.re * self.re + self.im * self.im
  50      }
  51  
  52  
  53      /// Returns the complex conjugate. i.e. `re - i im`
  54      #[inline]
  55      pub fn conj(&self) -> Cmplx<T> {
  56          Cmplx::new(self.re.clone(), -self.im)
  57      }
  58  
  59  
  60      /// Multiplies `self` by the scalar `t`.
  61      #[inline]
  62      pub fn scale(&self, tT) -> Cmplx<T> {
  63          Cmplx::new(self.re * t, self.im * t)
  64      }
  65  
  66      /// Divides `self` by the scalar `t`.
  67      #[inline]
  68      pub fn unscale(&self, tT) -> Cmplx<T> {
  69          Cmplx::new(self.re / t, self.im / t)
  70      }
  71  
  72      /// Returns `1/self`
  73      #[inline]
  74      pub fn inv(&self) -> Cmplx<T> {
  75          let norm_sqr = self.norm_sqr();
  76          Cmplx::new(self.re / norm_sqr,
  77                      -self.im / norm_sqr)
  78      }
  79  }
  80  
  81  impl<T: Clone + Float> Cmplx<T> {
  82      /// Calculate |self|
  83      #[inline]
  84      pub fn norm(&self) -> T {
  85          self.re.hypot(self.im)
  86      }
  87  }
  88  
  89  impl<T: Clone + Float> Cmplx<T> {
  90      /// Calculate the principal Arg of self.
  91      #[inline]
  92      pub fn arg(&self) -> T {
  93          self.im.atan2(self.re)
  94      }
  95      /// Convert to polar form (r, theta), such that `self = r * exp(i
  96      /// * theta)`
  97      #[inline]
  98      pub fn to_polar(&self) -> (T, T) {
  99          (self.norm(), self.arg())
 100      }
 101      /// Convert a polar representation into a complex number.
 102      #[inline]
 103      pub fn from_polar(r&T, theta&T) -> Cmplx<T> {
 104          Cmplx::new(*r * theta.cos(), *r * theta.sin())
 105      }
 106  }
 107  
 108  /* arithmetic */
 109  // (a + i b) + (c + i d) == (a + c) + i (b + d)
 110  impl<T: Clone + Num> Add<Cmplx<T>, Cmplx<T>> for Cmplx<T> {
 111      #[inline]
 112      fn add(&self, other&Cmplx<T>) -> Cmplx<T> {
 113          Cmplx::new(self.re + other.re, self.im + other.im)
 114      }
 115  }
 116  // (a + i b) - (c + i d) == (a - c) + i (b - d)
 117  impl<T: Clone + Num> Sub<Cmplx<T>, Cmplx<T>> for Cmplx<T> {
 118      #[inline]
 119      fn sub(&self, other&Cmplx<T>) -> Cmplx<T> {
 120          Cmplx::new(self.re - other.re, self.im - other.im)
 121      }
 122  }
 123  // (a + i b) * (c + i d) == (a*c - b*d) + i (a*d + b*c)
 124  impl<T: Clone + Num> Mul<Cmplx<T>, Cmplx<T>> for Cmplx<T> {
 125      #[inline]
 126      fn mul(&self, other&Cmplx<T>) -> Cmplx<T> {
 127          Cmplx::new(self.re*other.re - self.im*other.im,
 128                     self.re*other.im + self.im*other.re)
 129      }
 130  }
 131  
 132  // (a + i b) / (c + i d) == [(a + i b) * (c - i d)] / (c*c + d*d)
 133  //   == [(a*c + b*d) / (c*c + d*d)] + i [(b*c - a*d) / (c*c + d*d)]
 134  impl<T: Clone + Num> Div<Cmplx<T>, Cmplx<T>> for Cmplx<T> {
 135      #[inline]
 136      fn div(&self, other&Cmplx<T>) -> Cmplx<T> {
 137          let norm_sqr = other.norm_sqr();
 138          Cmplx::new((self.re*other.re + self.im*other.im) / norm_sqr,
 139                     (self.im*other.re - self.re*other.im) / norm_sqr)
 140      }
 141  }
 142  
 143  impl<T: Clone + Num> Neg<Cmplx<T>> for Cmplx<T> {
 144      #[inline]
 145      fn neg(&self) -> Cmplx<T> {
 146          Cmplx::new(-self.re, -self.im)
 147      }
 148  }
 149  
 150  /* constants */
 151  impl<T: Clone + Num> Zero for Cmplx<T> {
 152      #[inline]
 153      fn zero() -> Cmplx<T> {
 154          Cmplx::new(Zero::zero(), Zero::zero())
 155      }
 156  
 157      #[inline]
 158      fn is_zero(&self) -> bool {
 159          self.re.is_zero() && self.im.is_zero()
 160      }
 161  }
 162  
 163  impl<T: Clone + Num> One for Cmplx<T> {
 164      #[inline]
 165      fn one() -> Cmplx<T> {
 166          Cmplx::new(One::one(), Zero::zero())
 167      }
 168  }
 169  
 170  /* string conversions */
 171  impl<T: fmt::Show + Num + Ord> fmt::Show for Cmplx<T> {
 172      fn fmt(&self, f&mut fmt::Formatter) -> fmt::Result {
 173          if self.im < Zero::zero() {
 174              write!(f.buf, "{}-{}i", self.re, -self.im)
 175          } else {
 176              write!(f.buf, "{}+{}i", self.re, self.im)
 177          }
 178      }
 179  }
 180  
 181  impl<T: ToStrRadix + Num + Ord> ToStrRadix for Cmplx<T> {
 182      fn to_str_radix(&self, radixuint) -> ~str {
 183          if self.im < Zero::zero() {
 184              format!("{}-{}i", self.re.to_str_radix(radix), (-self.im).to_str_radix(radix))
 185          } else {
 186              format!("{}+{}i", self.re.to_str_radix(radix), self.im.to_str_radix(radix))
 187          }
 188      }
 189  }
 190  
 191  #[cfg(test)]
 192  mod test {
 193      #![allow(non_uppercase_statics)]
 194  
 195      use super::{Complex64, Cmplx};
 196      use std::num::{Zero,One,Float};
 197  
 198      pub static _0_0i : Complex64 = Cmplx { re: 0.0, im: 0.0 };
 199      pub static _1_0i : Complex64 = Cmplx { re: 1.0, im: 0.0 };
 200      pub static _1_1i : Complex64 = Cmplx { re: 1.0, im: 1.0 };
 201      pub static _0_1i : Complex64 = Cmplx { re: 0.0, im: 1.0 };
 202      pub static _neg1_1i : Complex64 = Cmplx { re: -1.0, im: 1.0 };
 203      pub static _05_05i : Complex64 = Cmplx { re: 0.5, im: 0.5 };
 204      pub static all_consts : [Complex64, .. 5] = [_0_0i, _1_0i, _1_1i, _neg1_1i, _05_05i];
 205  
 206      #[test]
 207      fn test_consts() {
 208          // check our constants are what Cmplx::new creates
 209          fn test(c : Complex64, r : f64, i: f64) {
 210              assert_eq!(c, Cmplx::new(r,i));
 211          }
 212          test(_0_0i, 0.0, 0.0);
 213          test(_1_0i, 1.0, 0.0);
 214          test(_1_1i, 1.0, 1.0);
 215          test(_neg1_1i, -1.0, 1.0);
 216          test(_05_05i, 0.5, 0.5);
 217  
 218          assert_eq!(_0_0i, Zero::zero());
 219          assert_eq!(_1_0i, One::one());
 220      }
 221  
 222      #[test]
 223      #[ignore(cfg(target_arch = "x86"))]
 224      // FIXME #7158: (maybe?) currently failing on x86.
 225      fn test_norm() {
 226          fn test(c: Complex64, ns: f64) {
 227              assert_eq!(c.norm_sqr(), ns);
 228              assert_eq!(c.norm(), ns.sqrt())
 229          }
 230          test(_0_0i, 0.0);
 231          test(_1_0i, 1.0);
 232          test(_1_1i, 2.0);
 233          test(_neg1_1i, 2.0);
 234          test(_05_05i, 0.5);
 235      }
 236  
 237      #[test]
 238      fn test_scale_unscale() {
 239          assert_eq!(_05_05i.scale(2.0), _1_1i);
 240          assert_eq!(_1_1i.unscale(2.0), _05_05i);
 241          for &c in all_consts.iter() {
 242              assert_eq!(c.scale(2.0).unscale(2.0), c);
 243          }
 244      }
 245  
 246      #[test]
 247      fn test_conj() {
 248          for &c in all_consts.iter() {
 249              assert_eq!(c.conj(), Cmplx::new(c.re, -c.im));
 250              assert_eq!(c.conj().conj(), c);
 251          }
 252      }
 253  
 254      #[test]
 255      fn test_inv() {
 256          assert_eq!(_1_1i.inv(), _05_05i.conj());
 257          assert_eq!(_1_0i.inv(), _1_0i.inv());
 258      }
 259  
 260      #[test]
 261      #[should_fail]
 262      #[ignore]
 263      fn test_inv_zero() {
 264          // FIXME #5736: should this really fail, or just NaN?
 265          _0_0i.inv();
 266      }
 267  
 268      #[test]
 269      fn test_arg() {
 270          fn test(c: Complex64, arg: f64) {
 271              assert!((c.arg() - arg).abs() < 1.0e-6)
 272          }
 273          test(_1_0i, 0.0);
 274          test(_1_1i, 0.25 * Float::pi());
 275          test(_neg1_1i, 0.75 * Float::pi());
 276          test(_05_05i, 0.25 * Float::pi());
 277      }
 278  
 279      #[test]
 280      fn test_polar_conv() {
 281          fn test(c: Complex64) {
 282              let (r, theta) = c.to_polar();
 283              assert!((c - Cmplx::from_polar(&r, &theta)).norm() < 1e-6);
 284          }
 285          for &c in all_consts.iter() { test(c); }
 286      }
 287  
 288      mod arith {
 289          use super::{_0_0i, _1_0i, _1_1i, _0_1i, _neg1_1i, _05_05i, all_consts};
 290          use std::num::Zero;
 291  
 292          #[test]
 293          fn test_add() {
 294              assert_eq!(_05_05i + _05_05i, _1_1i);
 295              assert_eq!(_0_1i + _1_0i, _1_1i);
 296              assert_eq!(_1_0i + _neg1_1i, _0_1i);
 297  
 298              for &c in all_consts.iter() {
 299                  assert_eq!(_0_0i + c, c);
 300                  assert_eq!(c + _0_0i, c);
 301              }
 302          }
 303  
 304          #[test]
 305          fn test_sub() {
 306              assert_eq!(_05_05i - _05_05i, _0_0i);
 307              assert_eq!(_0_1i - _1_0i, _neg1_1i);
 308              assert_eq!(_0_1i - _neg1_1i, _1_0i);
 309  
 310              for &c in all_consts.iter() {
 311                  assert_eq!(c - _0_0i, c);
 312                  assert_eq!(c - c, _0_0i);
 313              }
 314          }
 315  
 316          #[test]
 317          fn test_mul() {
 318              assert_eq!(_05_05i * _05_05i, _0_1i.unscale(2.0));
 319              assert_eq!(_1_1i * _0_1i, _neg1_1i);
 320  
 321              // i^2 & i^4
 322              assert_eq!(_0_1i * _0_1i, -_1_0i);
 323              assert_eq!(_0_1i * _0_1i * _0_1i * _0_1i, _1_0i);
 324  
 325              for &c in all_consts.iter() {
 326                  assert_eq!(c * _1_0i, c);
 327                  assert_eq!(_1_0i * c, c);
 328              }
 329          }
 330          #[test]
 331          fn test_div() {
 332              assert_eq!(_neg1_1i / _0_1i, _1_1i);
 333              for &c in all_consts.iter() {
 334                  if c != Zero::zero() {
 335                      assert_eq!(c / c, _1_0i);
 336                  }
 337              }
 338          }
 339          #[test]
 340          fn test_neg() {
 341              assert_eq!(-_1_0i + _0_1i, _neg1_1i);
 342              assert_eq!((-_0_1i) * _0_1i, _1_0i);
 343              for &c in all_consts.iter() {
 344                  assert_eq!(-(-c), c);
 345              }
 346          }
 347      }
 348  
 349      #[test]
 350      fn test_to_str() {
 351          fn test(c : Complex64, s: ~str) {
 352              assert_eq!(c.to_str(), s);
 353          }
 354          test(_0_0i, "0+0i".to_owned());
 355          test(_1_0i, "1+0i".to_owned());
 356          test(_0_1i, "0+1i".to_owned());
 357          test(_1_1i, "1+1i".to_owned());
 358          test(_neg1_1i, "-1+1i".to_owned());
 359          test(-_neg1_1i, "1-1i".to_owned());
 360          test(_05_05i, "0.5+0.5i".to_owned());
 361      }
 362  }