/* Benchmark program for matrix multiplication

   Three different algorithms
   Loop i contains loop j contains loop k
   Loop i contains loop k contains loop j
   Loop i contains loop j contains loop k, but Btranspose is used.

   Two different cell-access methods
   Subscripting a 1-dimensional array using the standard 2-D mapping
   Explicit pointer arithmetic used to eliminate the integer multiplications

          | Scattered B access | Sequential B access | B passed as transpose
----------+--------------------+---------------------+----------------------
Subscripts|      MatMult00     |      MatMult01      |      MatMult02
Pointers  |      MatMult10     |      MatMult11      |      MatMult12
----------+--------------------+---------------------+----------------------

   Author:    Timothy Rolfe
   Language:  ANSI C
   Op.Sys.:   Unix --- developed under the Linux variant

   Required file:  cpuClock.c

   double cpuClock(void) returns the elapsed CPU time in (floating point)
   seconds since some fixed point in the past.
*/

#include <limits.h>     // INT_MAX, INT_MIN
#include <stdio.h>      // printf, scanf, etc.
#include <stdlib.h>     // srand, rand, calloc, etc.
#include <time.h>       // time for "srand(time(NULL));"
#include "cpuClock.c"   // double cpuClock();

// Fill the vector X with N pseudo-random values (one rand() call each)
// in the range [-100.0 .. 100.0)
void RandFill (double *X, int N);

// C = A * B with row-major A[N1][N2], B[N2][N3], C[N1][N3]
// Straight computation (i.e., ALL of C[i][j] at once; loops nest i, j, k)
// MatMult00 addresses cells by subscript, MatMult10 by stepping pointers
void MatMult00(double *A, double *B, double *C,
               int N1, int N2, int N3);
void MatMult10(double *A, double *B, double *C,
               int N1, int N2, int N3);

// C = A * B with row-major A[N1][N2], B[N2][N3], C[N1][N3]
// Scattered computation (i.e., row-wide in B and C; loops nest i, k, j)
// MatMult01 addresses cells by subscript, MatMult11 by stepping pointers
void MatMult01(double *A, double *B, double *C,
               int N1, int N2, int N3);
void MatMult11(double *A, double *B, double *C,
               int N1, int N2, int N3);

// C = A * B with row-major A[N1][N2], Bt[N3][N2], C[N1][N3]
// Straight computation (i.e., ALL of C[i][j] at once)
// except that B has been transposed for local access:  the caller
// passes Bt, the transpose of B, so both operands are read along rows
// MatMult02 addresses cells by subscript, MatMult12 by stepping pointers
void MatMult02(double *A, double *Bt, double *C,
               int N1, int N2, int N3);
void MatMult12(double *A, double *Bt, double *C,
               int N1, int N2, int N3);

// Does A[N1][N3] equal B[N1][N3]?  Returns 1 if so, 0 otherwise.
// Cells are compared after conversion to long, i.e. by integer part only.
int  MatCheck(double *A, double *B, int N1, int N3);

// Write the row-major N1-by-N2 matrix A to stdout, one matrix row per
// output line.  Every cell is printed with "%8.0f  " (no decimals).
void Dump (double *A, int N1, int N2)
{  int cell = 0;                       // running index into A

   for ( int r = 0; r < N1; r++ )
   {  for ( int c = 0; c < N2; c++ )
      {  printf ("%8.0f  ", A[cell]);
         cell++;
      }
      fputs ("\n", stdout);
   }
}

// Obtain one integer for main.  When arg is non-NULL it must be a complete
// decimal int; it is echoed to stdout followed by suffix.  When arg is NULL
// the value is read from stdin with scanf.
// Returns 1 on success, 0 when no valid int could be obtained.
static int ReadIntArg (const char *arg, const char *suffix, int *value)
{  char *end;
   long  parsed;

   if (arg == NULL)
      return scanf ("%d", value) == 1;

   parsed = strtol (arg, &end, 10);
   if (end == arg || *end != '\0' || parsed < INT_MIN || parsed > INT_MAX)
      return 0;

   *value = (int) parsed;
   printf ("%d%s", *value, suffix);
   fflush(stdout);
   return 1;
}

// Compare the N1-by-N3 matrices X and Y with MatCheck and report on stdout.
// On a mismatch both matrices are dumped only if dumpOnMismatch is nonzero
// (MatCheck itself always names the first differing cell).
static void ReportMatch (const char *nameX, double *X,
                         const char *nameY, double *Y,
                         int N1, int N3, int dumpOnMismatch)
{  if ( MatCheck (X, Y, N1, N3) == 1 )
      printf ("%s equals %s\n", nameX, nameY);
   else if (dumpOnMismatch)
   {  printf ("%s does NOT equal %s\n%s:\n", nameX, nameY, nameX);
      Dump (X, N1, N3);
      printf ("%s:\n", nameY);
      Dump (Y, N1, N3);
   }
}

// Benchmark driver.
//
// Reads N1, N2, N3 and the number of runs (argv[1..4] where present,
// otherwise stdin), then times the six multiplication routines.  rand is
// reseeded with the same Seed before every timed loop, so each routine
// multiplies the same sequence of random matrices and the results left in
// C0/C1/C2 by the final run can be cross-checked.  For the Bt variants the
// transposition is performed inside the timed loop.
// Times are printed to stdout and appended as one line to MatOpt.csv:
//    N3, three subscript times, three pointer times
// Returns 0 on success and EXIT_FAILURE on bad input, on failure to
// open/close the CSV file, or on insufficient memory.
int main ( int argc, char* argv[] )
{  int     N1, N2, N3;
   int     sizeA, sizeB, sizeC;
   int     run, Nruns;
   int     j, k;
   int     dumpSmall;
   int     status = EXIT_FAILURE;
   unsigned Seed;
   double *A  = NULL, *B  = NULL, *Bt = NULL,
          *C0 = NULL, *C1 = NULL, *C2 = NULL;
   FILE *fptr;

   double Start, Mid1, Mid2, Finish;

   fputs ("Enter the three dimensions:  ", stdout);
   fflush(stdout);

   if ( !ReadIntArg (argc > 1 ? argv[1] : NULL, " ",  &N1) ||
        !ReadIntArg (argc > 2 ? argv[2] : NULL, " ",  &N2) ||
        !ReadIntArg (argc > 3 ? argv[3] : NULL, "\n", &N3) )
   {  fprintf(stderr, "The three dimensions must be integers.\n");
      return EXIT_FAILURE;
   }

// The multiplication routines index with int, so every matrix must have a
// cell count that fits in an int; empty matrices are rejected as well.
   if ( N1 < 1 || N2 < 1 || N3 < 1 ||
        N1 > INT_MAX / N2 || N2 > INT_MAX / N3 || N1 > INT_MAX / N3 )
   {  fprintf(stderr, "Each dimension must be positive and each matrix "
                      "limited to %d cells.\n", INT_MAX);
      return EXIT_FAILURE;
   }
   sizeA = N1*N2;
   sizeB = N2*N3;
   sizeC = N1*N3;

   fputs ("Enter the number of random matrices to run: ", stdout);
   fflush(stdout);

   if ( !ReadIntArg (argc > 4 ? argv[4] : NULL, "\n", &Nruns) )
   {  fprintf(stderr, "The number of runs must be an integer.\n");
      return EXIT_FAILURE;
   }
   if (argc <= 4)
   {  int ch;

//    Discard the rest of the input line.  End-of-file must also end the
//    loop, otherwise exhausted or redirected input would spin forever.
      while ( (ch = getchar()) != '\n' && ch != EOF )
         ;
   }

   fptr = fopen("MatOpt.csv", "a");
   if (fptr == NULL)
   {  perror ("fptr open failed"); return EXIT_FAILURE;  }

   srand((unsigned) time(NULL));

   Seed = (unsigned) rand();

   A  = calloc ((size_t) sizeA, sizeof *A);
   B  = calloc ((size_t) sizeB, sizeof *B);
   Bt = calloc ((size_t) sizeB, sizeof *Bt);
   C0 = calloc ((size_t) sizeC, sizeof *C0);
   C1 = calloc ((size_t) sizeC, sizeof *C1);
   C2 = calloc ((size_t) sizeC, sizeof *C2);

   if (A == NULL || B == NULL || Bt == NULL ||
       C0 == NULL || C1 == NULL || C2 == NULL)
   {  fprintf(stderr, "Insufficient memory.\n"); goto cleanup;  }

// ---- Subscript versions ----------------------------------------------

   Start = cpuClock();

   srand(Seed);
   for ( run = 0; run < Nruns; run++ )
   {
      RandFill (A, sizeA);
      RandFill (B, sizeB);
      MatMult00 (A, B, C0, N1, N2, N3);
   }

   Mid1 = cpuClock();

   srand(Seed);
   for ( run = 0; run < Nruns; run++ )
   {
      RandFill (A, sizeA);
      RandFill (B, sizeB);
      MatMult01 (A, B, C1, N1, N2, N3);
   }

   Mid2 = cpuClock();

   srand(Seed);
   for ( run = 0; run < Nruns; run++ )
   {
      RandFill (A, sizeA);
      RandFill (B, sizeB);
//    B[N2][N3] and Bt[N3][N2]:  Bt[k][j] = B[j][k]
      for ( j = 0; j < N2; j++ )
      {  double *pB  = B  + j*N3,
                *pBt = Bt + j;

         for ( k = 0; k < N3; k++, pB++, pBt += N2 )
            *pBt = *pB;
      }

      MatMult02 (A, Bt, C2, N1, N2, N3);
   }

   Finish = cpuClock();

   ReportMatch ("C0", C0, "C1", C1, N1, N3, 1);
   ReportMatch ("C0", C0, "C2", C2, N1, N3, 1);
   ReportMatch ("C1", C1, "C2", C2, N1, N3, 1);

   printf ("Time for MatMult0:  %3.3f\n", Mid1-Start);
   printf ("Time for MatMult1:  %3.3f\n", Mid2-Mid1);
   printf ("Time for MatMult2:  %3.3f\n", Finish-Mid2);
   fprintf (fptr, "%d, %3.3f, %3.3f, %3.3f,", N3,
            Mid1-Start, Mid2-Mid1, Finish-Mid2);
   fflush(fptr);

// ---- Pointer versions ------------------------------------------------
// On a mismatch the matrices are dumped only when they are small.  The
// dumped matrices are N1-by-N3, so those are the dimensions tested.
   dumpSmall = (N1 < 5 && N3 < 5);

   Start = cpuClock();
   srand(Seed);
   for ( run = 0; run < Nruns; run++ )
   {
      RandFill (A, sizeA);
      RandFill (B, sizeB);
      MatMult10 (A, B, C0, N1, N2, N3);
   }
   Finish = cpuClock();

   puts ("Pointer-based MatMult0");
// C1 still holds the subscript result from MatMult01
   ReportMatch ("C0", C0, "C1", C1, N1, N3, dumpSmall);

   printf ("Time required:  %3.3f\n", Finish-Start);
   fprintf (fptr, "%3.3f, ", Finish-Start);
   fflush(fptr);

   Start = cpuClock();
   srand(Seed);
   for ( run = 0; run < Nruns; run++ )
   {
      RandFill (A, sizeA);
      RandFill (B, sizeB);
      MatMult11 (A, B, C1, N1, N2, N3);
   }
   Finish = cpuClock();

   puts ("Pointer-based MatMult1");
   ReportMatch ("C0", C0, "C1", C1, N1, N3, dumpSmall);

   printf ("Time required:  %3.3f\n", Finish-Start);
   fprintf (fptr, "%3.3f, ", Finish-Start);
   fflush(fptr);

   Start = cpuClock();
   srand(Seed);
   for ( run = 0; run < Nruns; run++ )
   {
      RandFill (A, sizeA);
      RandFill (B, sizeB);
//    B[N2][N3] and Bt[N3][N2] --- Keep the transpose time
      { double *pB = B;
         for ( j = 0; j < N2; j++ )
         {  double *pBt = Bt + j;

            for ( k = 0; k < N3; k++, pB++, pBt += N2 )
               *pBt = *pB;
         }
      }

      MatMult12 (A, Bt, C1, N1, N2, N3);
   }
   Finish = cpuClock();

   puts ("Pointer-based MatMult2");
// C2 still holds the subscript result from MatMult02
   ReportMatch ("C1", C1, "C2", C2, N1, N3, dumpSmall);

   printf ("Time required:  %3.3f\n", Finish-Start);
   fprintf (fptr, "%3.3f\n", Finish-Start);

   status = 0;

cleanup:
   free (A);
   free (B);
   free (Bt);
   free (C0);
   free (C1);
   free (C2);

// fclose flushes the CSV line, so a failure here means lost output
   if (fclose(fptr) != 0)
   {  perror ("fptr close failed"); status = EXIT_FAILURE;  }

   return status;
}

// Store N pseudo-random doubles, each in [-100.0 .. 100.0), in X[0..N-1].
// Exactly one rand() call is made per element; nothing is written and
// rand() is not called when N <= 0.
void RandFill (double *X, int N)
{  const double span = RAND_MAX + 1.0;   // count of distinct rand() results

   for ( int idx = 0; idx < N; idx++ )
      X[idx] = (double)rand() * 200 / span - 100.0;
}

// C = A * B for row-major A[n1][n2], B[n2][n3], C[n1][n3]
// Straight computation:  each C[i][j] is completed as a single dot product,
// which walks B down a column (stride n3).
// Explicit subscript calculation version:  every cell access recomputes
// its row*width + column index.
void MatMult00(double *A, double *B, double *C,
               int n1, int n2, int n3)
{
   for ( int row = 0; row < n1; row++ )
   {  for ( int col = 0; col < n3; col++ )
      {  double dot = 0;
         int inner = 0;

         while ( inner < n2 )
         {  dot += A[row*n2 + inner] * B[inner*n3 + col];
            inner++;
         }
         C[row*n3 + col] = dot;
      }
   }
}

// C = A * B for row-major A[n1][n2], B[n2][n3], C[n1][n3]
// Same i, j, k dot-product order as MatMult00, but in the
// pure pointer access version:  the inner loop steps one pointer along a
// row of A and another down a column of B instead of computing subscripts.
void MatMult10(double *A, double *B, double *C,
               int n1, int n2, int n3)
{  double *out = C;                 // C is filled strictly in storage order

   for ( int row = 0; row < n1; row++ )
   {  for ( int col = 0; col < n3; col++ )
      {  double *a   = A + row*n2;  // start of row `row` of A
         double *b   = B + col;     // top of column `col` of B
         double  dot = 0;

         for ( int step = 0; step < n2; step++, a++, b += n3 )
            dot += *a * *b;

         *out = dot;
         out++;
      }
   }
}

// C = A * B for row-major A[n1][n2], B[n2][n3], C[n1][n3]
// Scattered computation (i.e., row-wide in B and C):  for each A[i][k],
// the whole row k of B is scaled and accumulated into row i of C, so B and
// C are both traversed sequentially.
// Explicit subscript calculation version
// If n2 <= 0 every product is an empty sum:  C is set to zero and neither
// A nor B (which then hold no elements) is read.
void MatMult01(double *A, double *B, double *C,
               int n1, int n2, int n3)
{  int i, k, j;

   if ( n2 <= 0 )
   {  for ( i = 0; i < n1; i++ )
         for ( j = 0; j < n3; j++ )
            C[i*n3 + j] = 0;
      return;
   }

   for ( i = 0; i < n1; i++ )
   {  double Aik = A[i*n2];
//    The k = 0 term initialises row i of C, avoiding a zeroing pass
      for ( j = 0; j < n3; j++ )
         C[i*n3 + j] = Aik * B[j];
//    Remaining terms are accumulated on top of it
      for ( k = 1; k < n2; k++ )
      {  Aik = A[i*n2 + k];
         for ( j = 0; j < n3; j++ )
            C[i*n3 + j] += Aik * B[k*n3 + j];
      }
   }
}

// C = A * B for row-major A[n1][n2], B[n2][n3], C[n1][n3]
// Same row-wide (i, k, j) order as MatMult01, but in the
// pure pointer access version:  A is consumed strictly in storage order
// through pA, while pB and pC sweep one row of B and of C at a time.
// If n2 <= 0 every product is an empty sum:  C is set to zero and neither
// A nor B (which then hold no elements) is read.
void MatMult11(double *A, double *B, double *C,
               int n1, int n2, int n3)
{  int i, k, j;
   double *pA = A;

   if ( n2 <= 0 )
   {  double *pC = C;

      for ( i = 0; i < n1; i++ )
         for ( j = 0; j < n3; j++ )
            *pC++ = 0;
      return;
   }

   for ( i = 0; i < n1; i++ )
   {  double *pC = C + i*n3,
             *pB = B,
             Aik = *pA++;
//    The k = 0 term initialises row i of C, avoiding a zeroing pass
      for ( j = 0; j < n3; j++ )
         *pC++ = Aik * *pB++;
      for ( k = 1; k < n2; k++ )
      {  pB  = B + k*n3;       // start of row k of B
         pC  = C + i*n3;       // rewind to the start of row i of C
         Aik = *pA++;

         for ( j = 0; j < n3; j++ )
         {  *pC++ += Aik * *pB++;  }
      }
   }
}

// C = A * B for row-major A[n1][n2], C[n1][n3], where the caller supplies
// Bt[n3][n2], the transpose of B.
// Straight computation (each C[i][j] finished as one dot product), but
// because B arrives transposed the dot product reads row i of A against
// row j of Bt, i.e. both operands sequentially.
// Explicit subscript calculation version
void MatMult02(double *A, double *Bt, double *C,
               int n1, int n2, int n3)
{
   for ( int row = 0; row < n1; row++ )
   {  for ( int col = 0; col < n3; col++ )
      {  double dot = 0;

         for ( int inner = 0; inner < n2; ++inner )
            dot += A[row*n2 + inner] * Bt[col*n2 + inner];

         C[row*n3 + col] = dot;
      }
   }
}

// C = A * B for row-major A[n1][n2], C[n1][n3], where the caller supplies
// Bt[n3][n2], the transpose of B.  Same dot-product order as MatMult02, in
// the pure pointer access version:  both operand pointers advance by one
// element per step.
void MatMult12(double *A, double *Bt, double *C,
               int n1, int n2, int n3)
{  double *out = C;                  // C is filled strictly in storage order

   for ( int row = 0; row < n1; row++ )
   {  for ( int col = 0; col < n3; col++ )
      {  double *a   = A  + row*n2;  // start of row `row` of A
         double *bt  = Bt + col*n2;  // start of row `col` of Bt
         double  dot = 0;

         for ( int step = 0; step < n2; step++ )
         {  dot += *a * *bt;
            a++;
            bt++;
         }
         *out++ = dot;
      }
   }
}

// Does A[N1][N3] equal B[N1][N3]?
// Cells are compared in row-major order after conversion to long, so two
// cells count as equal when their integer parts agree.  At the first
// difference its [row][col] position is printed to stdout and 0 is
// returned; 1 is returned when no cell differs.
int  MatCheck(double *A, double *B, int N1, int N3)
{  int cell = 0;                     // running index shared by A and B

   for ( int r = 0; r < N1; r++ )
   {  for ( int c = 0; c < N3; c++, cell++ )
      {  long left  = (long)A[cell];
         long right = (long)B[cell];

         if ( left == right )
            continue;

         printf ("Different in [%d][%d].\n", r, c);
         return 0;
      }
   }
   return 1;
}
