#ifndef LUDECOMP_H
#define LUDECOMP_H
#include "smatrix.h"
namespace sm {
/// LU decomposition of a square sparse matrix, computed in place.
///
/// After a successful decomposition, the inherited smatrix<t> storage holds:
///   - U on and above the diagonal (see U()), and
///   - the *negated* elimination multipliers below the diagonal. L() negates
///     them back and adds the unit diagonal.
/// `pivots` starts as the identity and is handed to GetDiag(), which may
/// permute it when partial pivoting is requested (see P()).
///
/// Row and column indices are 1-based throughout (loops run 1..iDim).
///
/// NOTE(review): inherited members (iDim, jDim, NZ, Coef, IterForCol,
/// EndOfCol, row, GetDiag, iterator, const_iterator, ...) are used
/// unqualified. Strictly conforming compilers do not look names up in the
/// dependent base smatrix<t> and need this-> / typename smatrix<t>::
/// qualification. Confirm the target toolchain accepts this.
template <class t> class TLUDecomp : public smatrix<t>
{
protected:
    smatrix<t> pivots;   // permutation matrix P built during Decompose()
    bool bDecomposed;    // true only after Decompose() completed successfully

    /// Discard any (possibly partial) result, leaving an empty,
    /// undecomposed object with an empty permutation matrix.
    void Reset()
    {
        *this = TLUDecomp<t>();
    }

public:
    /// Copy `m` and decompose it.
    /// On failure the object is left empty; check IsDecomposed().
    TLUDecomp(const smatrix<t> &m, bool bWithPartialPivot) :
        smatrix<t>(m), bDecomposed(false)
    {
        if(!Decompose(bWithPartialPivot))
            Reset();
    }

    /// Build an iDim x jDim matrix from the raw data `d` and decompose it.
    /// On failure the object is left empty; check IsDecomposed().
    TLUDecomp(const t* d, size_t iDim, size_t jDim,
              bool bWithPartialPivot) :
        smatrix<t>(d, iDim, jDim), bDecomposed(false)
    {
        if(!Decompose(bWithPartialPivot))
            Reset();
    }

    /// Empty, undecomposed object.
    TLUDecomp() : bDecomposed(false) {}

    /// True when the stored matrix holds a valid LU factorisation.
    bool IsDecomposed() const { return(bDecomposed); }

    /// Replace the stored matrix with `m` and decompose it.
    /// @return true on success. On failure the object is reset to empty.
    bool Decompose(const smatrix<t> &m, bool bWithPartialPivot)
    {
        // Assign through the base class: the implicitly declared
        // TLUDecomp::operator= hides smatrix<t>::operator=, and there is no
        // converting constructor from smatrix<t>, so `*this = m` would not compile.
        smatrix<t>::operator=(m);
        bDecomposed = false;
        if(!Decompose(bWithPartialPivot))
        {
            Reset();
            return(false);
        }
        return(true);
    }

    /// Apply L^-1 to `c` in place (forward elimination).
    /// For each column k, rows below k are updated using row k. The stored
    /// multipliers are already negated, hence `+=`.
    void FwdElim(smatrix<t> &c) const
    {
        for(size_t k = 1; k < iDim; k++)
        {
            const_iterator i = IterForCol(k, k + 1);
            while(i != EndOfCol(k))
            {
                t Scaler = Coef(i);
                size_t r = Index(i++);
                c.row(r) += c.row(k) * Scaler;
            }
        }
    }

    /// Apply U^-1 to `c` in place (back substitution), from the last row
    /// upwards. Requires every diagonal entry of U to be non-zero.
    void BackElim(smatrix<t> &c) const
    {
        // size_t is unsigned: the loop ends when k wraps from 1 to 0.
        for(size_t k = iDim; k >= 1; k--)
        {
            t Diag = t(-1.0) / Coef(k, k);
            // Entries of column k above the diagonal (Index(i) < k).
            const_iterator i = IterForCol(k);
            while(Index(i) < k)
            {
                t Scaler = Coef(i) * Diag;
                size_t r = Index(i++);
                c.row(r) += c.row(k) * Scaler;
            }
            // Normalise row k by U(k, k) only after it was used above.
            c.row(k) *= -Diag;
        }
    }

    /// Decompose the stored matrix in place.
    /// @return false if the matrix is not square, is empty, or a pivot is
    ///         rejected (by GetDiag, or the final |U(n,n)| < epsilon check).
    ///         The stored data may be partially modified on failure; the
    ///         constructors and Decompose(m, ...) reset it.
    bool Decompose(bool bWithPartialPivot)
    {
        if(bDecomposed) return(true);
        if(iDim != jDim) return(false);
        // An empty matrix has no pivot to check; Coef(1, 1) below would be
        // out of range.
        if(iDim == 0) return(false);
        pivots = smIdentity<t>(iDim);
        for(size_t k = 1; k < iDim; k++)
        {
            t Diag;
            if(!GetDiag(k, Diag, bWithPartialPivot, &pivots))
                return(false);
            Diag = t(-1.0) / Diag;
            iterator i = IterForCol(k, k + 1);
            while(i != EndOfCol(k))
            {
                t Scaler = Coef(i) * Diag;
                size_t r = Index(i++);
                // Eliminate below the pivot, then keep the (negated)
                // multiplier in the slot that just became zero.
                row(r) += row(k, rng(k + 1, jDim)) * Scaler;
                Coef(r, k) = Scaler;
            }
        }
        // The loop stops before the last column, so GetDiag() never checked
        // the final pivot U(n, n). Reject a (near-)singular result here.
        // NOTE(review): confirm `abs` resolves to a floating-point overload
        // for t; the C abs(int) would truncate the pivot.
        t Diag = Coef(iDim, iDim);
        if(abs(Diag) < epsilon)
            return(false);
        bDecomposed = true;
        return(true);
    }

    /// Solve A X = m for X, where A is the decomposed matrix
    /// (presumably with P A = L U, P as recorded in `pivots`).
    /// @return X. If not decomposed, or if m does not have iDim rows, returns
    ///         the freshly constructed m-shaped matrix without filling it.
    smatrix<t> Solve(const smatrix<t> &m) const
    {
        smatrix<t> c(m.iDim, m.jDim, m.NZ);
        // Guard against reading rows of m that do not exist.
        if(bDecomposed && m.iDim == iDim)
        {
            // c = P * m: row r of m goes to the row holding the non-zero
            // entry of column r of the permutation matrix.
            for(size_t r = 1; r <= iDim; r++)
            {
                const_iterator i = pivots.IterForCol(r);
                c.row(Index(i)) = m.row(r);
            }
            FwdElim(c);
            BackElim(c);
        }
        return(c);
    }

    /// Inverse of the decomposed matrix: the solve applied to the identity,
    /// whose permuted form is the permutation matrix itself.
    /// If not decomposed, returns a copy of the (possibly empty) pivots.
    smatrix<t> Inverse() const
    {
        smatrix<t> c(pivots);
        if(bDecomposed)
        {
            FwdElim(c);
            BackElim(c);
        }
        return(c);
    }

    /// Unit lower-triangular factor L. The stored multipliers are negated
    /// back, and the unit diagonal is added.
    smatrix<t> L() const
    {
        smatrix<t> c(iDim, jDim, NZ);
        for(size_t r = 1; r <= iDim; r++)
        {
            c.row(r) = row(r, rng(1, r - 1)) * t(-1);
            c.Coef(r, r) = t(1);
        }
        return(c);
    }

    /// Upper-triangular factor U (stored entries on and above the diagonal).
    smatrix<t> U() const
    {
        smatrix<t> c(iDim, jDim, NZ);
        for(size_t r = 1; r <= iDim; r++)
            c.row(r) = row(r, rng(r, jDim));
        return(c);
    }

    /// Row permutation matrix recorded during decomposition.
    smatrix<t> P() const
    {
        return(pivots);
    }
};
}
#endif