// Listing 2: The Matrix class


// Stephen O. Schulist - 1 September 1995

#ifndef _SOS_MATRIX_H
#define _SOS_MATRIX_H

#include "array.h"

// Matrix<T>: a dense nRows x nCols matrix of T with zero-based
// (row, column) indexing.  Elements live in a single Array<T> (from
// array.h) in row-major order: element (r, c) is stored at flat index
// r*Columns() + c.
//
// Operations used on T by the members below: copy construction and
// assignment, assignment from the literal 0 (matrix product), unary -,
// binary +, * and /, and ==.
template<class T> class Matrix
{
public:

  // Creates an nRows x nCols matrix backed by nRows*nCols elements; the
  // default arguments give an empty 0 x 0 matrix.  Initial element values
  // are whatever Array<T> gives them.
   Matrix(int nRows = 0, int nCols = 0);
   Matrix(const Matrix<T>& matrix);
  ~Matrix();

  Matrix<T>& operator=(const Matrix<T>& matrix);

  // Nonzero iff both the shape and every element match.
  int operator==(const Matrix<T>& matrix) const;
  int operator!=(const Matrix<T>& matrix) const;

  int    Rows() const { return m_nRows; }
  int Columns() const { return m_nCols; }

  // Element-wise negation, and transpose (~).
  Matrix<T> operator-() const;
  Matrix<T> operator~() const;

  // Element-wise sum / difference; operands must have the same shape.
  Matrix<T> operator+(const Matrix<T>& m) const;
  Matrix<T> operator-(const Matrix<T>& m) const { return *this + -m; }

  // Multiplies every element by the scalar t (computed as element * t).
  Matrix<T> operator*(const T& t) const;

  // Divides every element by the scalar t.  Each element is divided
  // directly rather than multiplied by (1/t): for integral T, 1/t
  // truncates to 0 whenever |t| > 1, which would zero the whole matrix,
  // and for floating-point T the reciprocal adds a second rounding step.
  Matrix<T> operator/(const T& t) const
  {
    Matrix<T> result(m_nRows, m_nCols);
    for ( int m = 0; m < m_nRows; m++ )
    for ( int n = 0; n < m_nCols; n++ )
      result(m, n) = (*this)(m, n) / t;
    return result;
  }

  Matrix<T>& operator+=(const Matrix<T>& m) { return *this = *this + m; }
  Matrix<T>& operator-=(const Matrix<T>& m) { return *this = *this - m; }
  Matrix<T>& operator*=(const T& t)         { return *this = *this * t; }
  Matrix<T>& operator/=(const T& t)         { return *this = *this / t; }

  // Scalar on the left; forwards to m * t, i.e. each element is computed
  // as element * t.
  friend Matrix<T> operator*(const T& t, const Matrix<T>& m) { return m * t; }

  // Matrix product; requires Columns() == m.Rows().
  Matrix<T> operator*(const Matrix<T>& m) const;

  // Copies of one row (1 x Columns()) or one column (Rows() x 1).
  Matrix<T>    Row(int nRow) const;
  Matrix<T> Column(int nCol) const;

  // Overwrite row nRow from a 1 x Columns() matrix, or column nCol from a
  // Rows() x 1 matrix.
  void SetRow   (int nRow, const Matrix<T>& matrix);
  void SetColumn(int nCol, const Matrix<T>& matrix);

  // Zero-based element access.
  const T& operator()(int nRow, int nCol) const;
        T& operator()(int nRow, int nCol);

private:

  int m_nRows;      // number of rows
  int m_nCols;      // number of columns
  Array<T> m_data;  // m_nRows*m_nCols elements, row-major
};

// Builds an nRows x nCols matrix whose storage is a flat Array<T> of
// nRows*nCols elements.  With the default arguments from the class
// declaration the result is an empty 0 x 0 matrix.
template<class T>
inline Matrix<T>::Matrix(int nRows, int nCols)
  : m_nRows(nRows)
  , m_nCols(nCols)
  , m_data(nRows * nCols)
{
  // All setup happens in the member initializer list.
}

// Copy constructor: takes the source's dimensions and copy-constructs
// the element storage from the source's Array<T>.
template<class T>
inline Matrix<T>::Matrix(const Matrix<T>& matrix)
  : m_nRows(matrix.m_nRows)
  , m_nCols(matrix.m_nCols)
  , m_data(matrix.m_data)
{
}

// Destructor.  The only members are two ints and an Array<T>, all of
// which are destroyed automatically, so the body is intentionally empty.
template<class T>
inline Matrix<T>::~Matrix()
{
}

// Copy assignment; self-assignment is a no-op.
//
// The element storage is copied BEFORE the dimensions are updated, so
// that if the Array<T> assignment throws, *this keeps its old dimensions
// instead of describing storage that was never copied.  (Whether m_data
// itself is left untouched in that case depends on Array<T>'s own
// assignment operator.)
template<class T> inline Matrix<T>& Matrix<T>::operator=(
  const Matrix<T>& matrix)
{
  if ( this != &matrix )
  {
    m_data  = matrix.m_data;
    m_nRows = matrix.m_nRows;
    m_nCols = matrix.m_nCols;
  }
  return *this;
}

// Element-wise equality.  Returns 0 as soon as the shapes differ or a
// pair of corresponding elements fails T's operator==; otherwise
// returns 1 (including for two equally-shaped empty matrices).
template<class T>
inline int Matrix<T>::operator==(const Matrix<T>& matrix) const
{
  if ( m_nRows != matrix.m_nRows || m_nCols != matrix.m_nCols )
    return 0;

  for ( int row = 0; row < m_nRows; ++row )
  {
    for ( int col = 0; col < m_nCols; ++col )
    {
      if ( !( (*this)(row, col) == matrix(row, col) ) )
        return 0;
    }
  }
  return 1;
}

// Inequality: 1 when the matrices differ in shape or in any element,
// 0 when operator== reports them equal.
template<class T>
inline int Matrix<T>::operator!=(const Matrix<T>& matrix) const
{
  const int same = ( *this == matrix );
  return same ? 0 : 1;
}

// Unary minus: a new matrix of the same shape holding the negation of
// every element.
template<class T>
inline Matrix<T> Matrix<T>::operator-() const
{
  Matrix<T> negated(m_nRows, m_nCols);
  for ( int row = 0; row < m_nRows; ++row )
  {
    for ( int col = 0; col < m_nCols; ++col )
    {
      negated(row, col) = -(*this)(row, col);
    }
  }
  return negated;
}

// Transpose: a new Columns() x Rows() matrix whose (c, r) element is
// this matrix's (r, c) element.
template<class T>
inline Matrix<T> Matrix<T>::operator~() const
{
  Matrix<T> transposed(m_nCols, m_nRows);
  for ( int row = 0; row < m_nRows; ++row )
  {
    for ( int col = 0; col < m_nCols; ++col )
    {
      transposed(col, row) = (*this)(row, col);
    }
  }
  return transposed;
}

// Element-wise sum.  Both operands must have the same shape (ASSERTed);
// the result is a new matrix of that shape.
template<class T>
inline Matrix<T> Matrix<T>::operator+(const Matrix<T>& matrix) const
{
  ASSERT(m_nRows == matrix.m_nRows);
  ASSERT(m_nCols == matrix.m_nCols);

  Matrix<T> sum(m_nRows, m_nCols);
  for ( int row = 0; row < m_nRows; ++row )
  {
    for ( int col = 0; col < m_nCols; ++col )
    {
      sum(row, col) = (*this)(row, col) + matrix(row, col);
    }
  }
  return sum;
}

// Read-only access to element (nRow, nCol), both zero-based.
//
// Both ends of each index are checked.  With row-major storage
// (flat index nRow*m_nCols + nCol) a negative column would otherwise
// silently alias the previous row, e.g. (1, -1) would read
// (0, m_nCols-1).
template<class T> inline const T& Matrix<T>::operator()(
  int nRow, int nCol) const
{
  ASSERT(0 <= nRow && nRow < m_nRows);
  ASSERT(0 <= nCol && nCol < m_nCols);

  return m_data[nRow*m_nCols + nCol];
}

// Mutable access to element (nRow, nCol), both zero-based.
//
// Both ends of each index are checked.  With row-major storage
// (flat index nRow*m_nCols + nCol) a negative column would otherwise
// silently alias the previous row, e.g. (1, -1) would write
// (0, m_nCols-1).
template<class T> inline T& Matrix<T>::operator()(int nRow, int nCol)
{
  ASSERT(0 <= nRow && nRow < m_nRows);
  ASSERT(0 <= nCol && nCol < m_nCols);

  return m_data[nRow*m_nCols + nCol];
}

// Scalar multiply: a new matrix of the same shape in which every
// element is multiplied on the right by t (element * t).
template<class T>
inline Matrix<T> Matrix<T>::operator*(const T& t) const
{
  Matrix<T> scaled(m_nRows, m_nCols);
  for ( int row = 0; row < m_nRows; ++row )
  {
    for ( int col = 0; col < m_nCols; ++col )
    {
      scaled(row, col) = (*this)(row, col) * t;
    }
  }
  return scaled;
}

// Matrix product: (Rows() x Columns()) * (matrix.Rows() x matrix.Columns())
// gives a Rows() x matrix.Columns() result.  The inner dimensions must
// agree (ASSERTed).  Each result element is set to 0 and then
// accumulates the dot product of a row of *this with a column of matrix,
// in increasing order of the inner index.
template<class T>
inline Matrix<T> Matrix<T>::operator*(const Matrix<T>& matrix) const
{
  ASSERT(m_nCols == matrix.m_nRows);

  const int inner = m_nCols;
  Matrix<T> product(m_nRows, matrix.m_nCols);

  for ( int row = 0; row < m_nRows; ++row )
  {
    for ( int col = 0; col < matrix.m_nCols; ++col )
    {
      // product is never resized here, so this reference stays valid.
      T& cell = product(row, col);
      cell = 0;
      for ( int k = 0; k < inner; ++k )
        cell += (*this)(row, k) * matrix(k, col);
    }
  }

  return product;
}

// Returns a copy of row nRow (zero-based) as a new 1 x Columns() matrix.
// The row index is checked at both ends.
template<class T> inline Matrix<T> Matrix<T>::Row(int nRow) const
{
  ASSERT(0 <= nRow && nRow < m_nRows);
  Matrix<T> matrix(1, m_nCols);
  for ( int n = 0; n < m_nCols; n++ )
    matrix(0, n) = (*this)(nRow, n);
  return matrix;
}

// Returns a copy of column nCol (zero-based) as a new Rows() x 1 matrix.
// The column index is checked at both ends.
template<class T> inline Matrix<T> Matrix<T>::Column(int nCol) const
{
  ASSERT(0 <= nCol && nCol < m_nCols);
  Matrix<T> matrix(m_nRows, 1);
  for ( int m = 0; m < m_nRows; m++ )
    matrix(m, 0) = (*this)(m, nCol);
  return matrix;
}

// Overwrites row nRow (zero-based) with the elements of a 1 x Columns()
// matrix.  The row index is checked at both ends, and the source shape
// is ASSERTed.
template<class T> inline void Matrix<T>::SetRow(
  int nRow, const Matrix<T>& matrix)
{
  ASSERT(0 <= nRow && nRow < m_nRows);
  ASSERT(matrix.Rows() == 1);
  ASSERT(matrix.Columns() == m_nCols);
  for ( int n = 0; n < m_nCols; n++ )
    (*this)(nRow, n) = matrix(0, n);
}

// Overwrites column nCol (zero-based) with the elements of a Rows() x 1
// matrix.  The column index is checked at both ends, and the source
// shape is ASSERTed.
template<class T> inline void Matrix<T>::SetColumn(
  int nCol, const Matrix<T>& matrix)
{
  ASSERT(0 <= nCol && nCol < m_nCols);
  ASSERT(matrix.Rows() == m_nRows);
  ASSERT(matrix.Columns() == 1);
  for ( int m = 0; m < m_nRows; m++ )
    (*this)(m, nCol) = matrix(m, 0);
}

// Writes the matrix one row per line: each element is followed by a
// single space, and each row is terminated with endl.  Returns the
// stream to allow chaining.
template<class T>
inline ostream& operator<<(ostream& stream, const Matrix<T>& matrix)
{
  const int nRows = matrix.Rows();
  const int nCols = matrix.Columns();

  for ( int row = 0; row < nRows; ++row )
  {
    for ( int col = 0; col < nCols; ++col )
      stream << matrix(row, col) << " ";
    stream << endl;
  }
  return stream;
}

#endif // MATRIX_H