/* ------------ */
/* matmpy.c */
/* ------------ */
#include <assert.h>
#include "mpydefs.h"
/* ------------------------------------------------- */
/* matmpy - multiply-add/subtract operations */
/* ------------------------------------------------- */
/* use: */
/* void matmpy(a, nra,nca, b, nrb,ncb, c, m1, m2) */
/* */
/* multiplies matrices a[nra][nca] & b[nrb][ncb] */
/* and stores the result in matrix c as follows: */
/* */
/* (' = trnsp)          dimensions       stored     */
/* m1    product (r)    of product  m2   result (c) */
/* ====  ===========    ==========  ===  ========== */
/* AB    r = a*b        [nra][ncb]  P    c  = r     */
/* ATB   r = a'*b       [nca][ncb]  MP   c  = -r    */
/* ABT   r = a*b'       [nra][nrb]  CPP  c += r     */
/* ATBT  r = a'*b'      [nca][nrb]  CMP  c -= r     */
/* */
/* Symbols shown under m1 and m2 are defined in */
/* header file mpydefs.h */
/* */
/* Other parameters: */
/* */
/* a address of multiplicand matrix */
/* nra number of rows in a */
/* nca number of columns in a */
/* b address of multiplier matrix */
/* nrb number of rows in b */
/* ncb number of columns in b */
/* c address of receiving matrix */
/* ------------------------------------------------- */
/*
 * matmpy - form the product r = op(a) * op(b) and store or accumulate
 * it into c.
 *
 * a, b and c are addressed as flat arrays of doubles.  a is indexed as
 * [nra][nca] and b as [nrb][ncb], both row by row; c is written
 * sequentially, one element per (row, column) of the product, rows
 * first.
 *
 * m1 selects which operands are used transposed:
 *     AB    a  * b        ATB   a' * b
 *     ABT   a  * b'       ATBT  a' * b'
 * Any m1 other than AB/ABT uses a transposed; any m1 other than AB/ATB
 * uses b transposed.
 *
 * m2 selects how each product element r reaches c:
 *     P     c  = r        MP    c  = -r
 *     CPP   c += r        CMP   c -= r
 * Any m2 other than P/MP accumulates into c.
 *
 * m1 and m2 are independent of each other.  The length of every inner
 * product is taken from b alone; the matching dimension of a is not
 * checked against it, so the caller must pass conformable matrices.
 */
# if defined(__STDC__) || defined(__PROTO__)
void
matmpy(double *a, int nra, int nca,
       double *b, int nrb, int ncb,
       double *c, int m1, int m2)
# else
void
matmpy(a, nra, nca, b, nrb, ncb, c, m1, m2)
double a[], b[], c[];
int nra, nca, nrb, ncb, m1, m2;
# endif
{
    double r;               /* current inner product              */
    int i, j, k;            /* row, column and inner loop counters */
    int ia, ib, ic;         /* running indices into a, b and c     */
    int ij, ik;             /* start of current row of op(a) and   */
                            /* of current column of op(b)          */
    int mra, mcb, mrb;      /* rows of op(a); cols, rows of op(b)  */
    int nfwaa, incra;       /* a: offset between rows / elements   */
    int nfwab, incrb;       /* b: offset between cols / elements   */

    /* ------------------------- */
    /* Set controls for matrix a */
    /* ------------------------- */
    if ((m1 == AB) || (m1 == ABT))
    {                       /* a used as stored                    */
        mra   = nra;        /* rows in the product                 */
        nfwaa = nca;        /* next row of a starts nca further on */
        incra = 1;          /* walk along a row of a               */
    }
    else                    /* m1 = ATB or ATBT                    */
    {                       /* a used transposed                   */
        mra   = nca;        /* rows of a' = columns of a           */
        nfwaa = 1;          /* next row of a' = next column of a   */
        incra = nca;        /* walk down a column of a             */
    }

    /* ------------------------- */
    /* Set controls for matrix b */
    /* ------------------------- */
    if ((m1 == AB) || (m1 == ATB))
    {                       /* b used as stored                    */
        mcb   = ncb;        /* columns in the product              */
        mrb   = nrb;        /* length of each inner product        */
        nfwab = 1;          /* next column of b is one element on  */
        incrb = ncb;        /* walk down a column of b             */
    }
    else                    /* m1 = ABT or ATBT                    */
    {                       /* b used transposed                   */
        mcb   = nrb;        /* columns of b' = rows of b           */
        mrb   = ncb;        /* rows of b' = columns of b           */
        nfwab = ncb;        /* next column of b' = next row of b   */
        incrb = 1;          /* walk along a row of b               */
    }

    ij = 0;
    ic = 0;
    for (i = 1; i <= mra; ++i)
    {
        ik = 0;
        for (j = 1; j <= mcb; ++j)
        {
            ia = ij;
            ib = ik;

            /* ----------------------------- */
            /* calculate next vector product */
            /* ----------------------------- */
            r = 0.0;
            for (k = 1; k <= mrb; ++k)
            {
                r  += a[ia] * b[ib];
                ia += incra;
                ib += incrb;
            }

            /* --------------------------------- */
            /* Check if negative product desired */
            /* --------------------------------- */
            if ((m2 == MP) || (m2 == CMP))
                r = -r;

            /* ---------------------------------- */
            /* Check if replace-add (or subtract) */
            /* ---------------------------------- */
            if ((m2 == P) || (m2 == MP))
                c[ic] = r;      /* just store the product        */
            else
                c[ic] += r;     /* accumulate (r already signed) */

            ik += nfwab;        /* advance to next column of op(b) */
            ic++;
        }
        ij += nfwaa;            /* advance to next row of op(a)    */
    }
}
/* End of File */