#include <time.h>       // clock() etc.
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <mpi.h>
#include "cpuTimes.c"
#include <unistd.h>     // fork stuff
#include <sys/wait.h>   // to allow wait

#include <sys/types.h>  // required for low-level I/O
#include <fcntl.h>      // File control definitions, like O_RDWR

// Largest supported problem; the static arrays below are dimensioned by it.
#define MAX_SIZE 32
// Best assignment found so far (solution[row] is the column chosen for that
// row) and its total benefit.  main() seeds both from the better of the two
// diagonals before the search starts.
int solution[MAX_SIZE] = {0}, lowerLimit = -1;
// NOTE:  size of the problem is available globally
// benefit[row][col] is the value gained by assigning column col to row.
int size, benefit[MAX_SIZE][MAX_SIZE];
enum { FALSE, TRUE };
int DEBUG;              // Print search / update diagnostics
int HANDSHAKE;          // Require the FINIS message
int P_TRACE;            // Print message-passing trace output
int TRUNCATE;           // When set, main() calls dumpState() before searching

// MPI message tags
#define INIT     1      // Server to client: problem size (below 1 = no work)
#define BLOCK    2      // Complete benefit matrix to client
#define DATA     4      // Both directions, data in / out
#define FINIS    8      // Zero-length shutdown handshake, sent both ways

// Reply buffer for MPI messaging.  Only the first size+1 ints are
// transferred: the permutation followed by its benefit.
typedef int vector_state[MAX_SIZE+3];

// Worker loop run by every rank other than 0.
void client (int rank,int nProc);
// Branch-and-bound over rows index..size-1 of the permutation in vect;
// prefix is the benefit already fixed for the rows before index.
void explore(int index, int *vect, int prefix);
// Server side: distributes jobs to the clients and gathers their results.
void MPIsplit(int *vect, int nProc);

// Print the size x size benefit matrix, one row per line, followed by the
// current best benefit (lowerLimit) and the solution permutation.
//
// Output is written straight to stdout.  The previous version assembled the
// text in a 2048-byte scratch buffer with unbounded sprintf calls, which
// overflowed once size reached the mid-twenties (MAX_SIZE allows 32), and
// sooner for wide values because %3d is only a minimum field width.
void dumpState(void)
{
   int row, col;

   for (row = 0; row < size; row++)
   {  for (col = 0; col < size; col++)
         printf("%3d", benefit[row][col]);
      putchar('\n');
   }
   printf("Benefit %d:  ", lowerLimit);
   for (row = 0; row < size; row++)
      printf("%3d", solution[row]);
   putchar('\n');
}

int main (int argc, char** argv)
{
   // Set global variables
   DEBUG = FALSE;
   HANDSHAKE =  TRUE;
   P_TRACE =  TRUE;
   TRUNCATE = FALSE;
   lowerLimit = 0;

   int    nProc,                    // Processes in the communicator
          proc;                     // loop variable
   MPI_Status Status;               // Return status from MPI
   int    rc;                       // Status  code from MPI_Xxx() call
   int    myPos;                    // My own position

   char* filename = argc == 1 ? "Asg5.in" : argv[1];
   FILE *input = fopen(filename, "r");
   int antisum, row, col, j, k, *work;
   double start, elapsed;
   int N_out;

   rc = MPI_Init(&argc, &argv);
   if (rc != MPI_SUCCESS)
   {  puts ("MPI_Init failed"); exit(-1);  }

   rc = MPI_Comm_rank (MPI_COMM_WORLD, &myPos);
   rc = MPI_Comm_size (MPI_COMM_WORLD, &nProc);

   if (P_TRACE)
   {  printf ("Process %d of %d started.\n", myPos, nProc); fflush(stdout);
   }


   if ( myPos > 0 )
      client(myPos, nProc);
   else
   {
      if (P_TRACE) {  puts("Starting data server"); fflush(stdout);  }
      if (input == NULL)
      {  printf("Failed to open %s\n", filename);
         // Send terminate message to clients
         size = 0;
         for (row = 1; row < nProc; row++)
            MPI_Send (&size, 1, MPI_INT, row, INIT, MPI_COMM_WORLD);
         MPI_Finalize();
         return 0;
      }
      fscanf(input, "%d", &size);
      work = (int*) calloc(size, sizeof *work);
      antisum  = 0;  k = size;
      for (row = lowerLimit = 0; row < size; row++)
      {  solution[row] = row;
         for (col = 0; col < size; col++)
            fscanf (input, "%d", &benefit[row][col]);
         lowerLimit += benefit[row][row];
         antisum  += benefit[row][--k];
      }
      if (antisum > lowerLimit)
      {  lowerLimit = antisum;
         for (row = 0, k = size; row < size; row++)
            solution[row] = --k;
      }
      memcpy(work, solution, size * sizeof *work);
      if (TRUNCATE) dumpState();
      start = wallClock();
      MPIsplit(work, nProc);
      elapsed = wallClock() - start;
      printf("Solution found with value %d\n", lowerLimit);
      for (row = 0; row < size; row++)
         printf("%3d", solution[row]);
      printf("\n%3.3f seconds\n", elapsed);
   }
// Everybody resigns from MPI before terminating.
   MPI_Finalize();
   return 0;
}

// Exchange the elements at positions j and k of the array x.
// Calling it with j == k leaves the array unchanged.
void swap(int *x, int j, int k)
{
   const int saved = x[k];

   x[k] = x[j];
   x[j] = saved;
}

// Merge one client reply into the global best.  reply[0..size-1] holds a
// permutation and reply[size] its benefit; it replaces solution and
// lowerLimit only when it is strictly better.  proc is used for tracing.
static void absorbResult(const int *reply, int proc)
{
   if (P_TRACE) printf ("Received results from client %d (%d vs %d)\n",
              proc, reply[size], lowerLimit);
   if (reply[size] > lowerLimit)
   {  if (DEBUG)
         printf("Updating from %d to %d for benefit.\n",
               lowerLimit, reply[size]);
      lowerLimit = reply[size];
      memcpy (solution, reply, size * sizeof *solution);
   }
}

// Server side of the search.  Jobs are numbered 0..size-1; each job fixes
// one choice for row 0 and searches the remaining rows.  Jobs are handed to
// clients one at a time together with the current bound, and every reply is
// merged into solution / lowerLimit.  Clients are then told to stop (job
// index -1) and, when HANDSHAKE is set, a FINIS exchange closes the session.
//   work   permutation of 0..size-1; used only when there are no clients
//   nProc  number of processes in the communicator, including this one
void MPIsplit(int *work, int nProc)
{
   int  proc, index,
        nActive,       // Number of active processes
        dummy = 0;     // Backs the zero-length FINIS messages
   vector_state result;
   int commBuffer[2]; // Communication buffer -- [0] job, [1] lowerLimit
   MPI_Status Status;             // Return status from MPI

   if (nProc < 2)
   {  // No clients exist, so the receive loops below would wait forever.
      // Run every job in this process instead.  Problems with fewer than
      // three rows are skipped: main() seeds the bound from both diagonals,
      // which already covers every permutation of one or two rows.
      if (work != NULL && size > 2)
         for (index = 0; index < size; index++)
         {  swap(work, 0, index);
            explore (1, work, benefit[0][work[0]]);
            swap(work, 0, index);
         }
      return;
   }

   commBuffer[1] = lowerLimit;    // Common to all in the first pass
// Send initial configurations to all client processes --- or to those
// needed in case not all are required.
   for ( index = 0, proc = 1; proc < nProc && index < size; proc++, index++ )
   {  commBuffer[0] = index;
      MPI_Send (&size, 1, MPI_INT, proc, INIT, MPI_COMM_WORLD);
      // Easier to ignore size and use MAX_SIZE
      if (P_TRACE) {printf("Send benefit to %d\n", proc); fflush(stdout);  }
      MPI_Send (benefit, MAX_SIZE*MAX_SIZE, MPI_INT, proc, BLOCK,
                MPI_COMM_WORLD);
      if (P_TRACE) printf ("Sending client %d job %d, bound %d\n", proc,
                         commBuffer[0], commBuffer[1]);
      MPI_Send (commBuffer, 2, MPI_INT, proc, DATA, MPI_COMM_WORLD);
   }
   if (index < size)
      nActive = nProc - 1;
   else                      // I.e., may have unused processes
   {  commBuffer[0] = 0;     // "size" zero causes termination
      for ( ; proc < nProc ; proc++ )
      {//These processes are waiting for the size message
         if (P_TRACE) printf ("Sending client %d termination message\n", proc);
         MPI_Send (commBuffer, 1, MPI_INT, proc, INIT, MPI_COMM_WORLD);
      }
      nActive = size;
   }
// Receive back results and send out new problems
   while ( index < size )
   {  MPI_Recv(result, size+1, MPI_INT, MPI_ANY_SOURCE, DATA,
               MPI_COMM_WORLD, &Status);
      proc = Status.MPI_SOURCE;
      absorbResult(result, proc);
      commBuffer[0] = index++;
      commBuffer[1] = lowerLimit;
      if (P_TRACE) printf ("Sending client %d job %d, bound %d\n", proc,
                         commBuffer[0], commBuffer[1]);
      MPI_Send (commBuffer, 2, MPI_INT, proc, DATA, MPI_COMM_WORLD);
   }
// Finally, receive back pending results and answer each one with a
// negative job index, which ends that client's work loop.
   commBuffer[0] = -1;
   while (nActive > 0)
   {
      if (P_TRACE) printf ("%d pending\n", nActive);
      MPI_Recv(result, size+1, MPI_INT, MPI_ANY_SOURCE, DATA,
               MPI_COMM_WORLD, &Status);
      --nActive;
      proc = Status.MPI_SOURCE;
      absorbResult(result, proc);
      if (P_TRACE) printf ("Sending client %d termination message\n", proc);
      MPI_Send (commBuffer, 2, MPI_INT, proc, DATA, MPI_COMM_WORLD);
   }
   if (HANDSHAKE)
   {  for (proc = 1; proc < nProc; proc++)
         MPI_Send(&dummy, 0, MPI_INT, proc, FINIS, MPI_COMM_WORLD);
      for (proc = 1; proc < nProc; proc++)
         MPI_Recv(&dummy, 0, MPI_INT, proc, FINIS, MPI_COMM_WORLD,
                  &Status);
   }
}

// Worker loop run by every rank other than 0.
// Receives the problem size (INIT); a size below 1 means there is no work.
// Otherwise receives the benefit matrix (BLOCK) and then a stream of jobs
// (DATA): msgvect[0] is the position swapped into row 0 and msgvect[1] the
// server's current bound.  Each job is searched with explore() and the
// contents of solution plus lowerLimit are sent back.  A negative job index
// ends the loop; when HANDSHAKE is set the FINIS exchange follows.
//   rank   this process's rank, used for trace output
//   nProc  number of processes in the communicator (not needed here)
void client (int rank, int nProc)
{
   MPI_Status Status;             // Return status from MPI
   int *work = NULL, msgvect[2];
   int  k;                        // loop variable
   vector_state result;

   (void) nProc;

   if (P_TRACE)
   {  printf("Compute engine %d starting\n", rank); fflush(stdout);  }

   // Receive the global variable size
   MPI_Recv(&size, 1, MPI_INT, 0, INIT, MPI_COMM_WORLD,
            &Status);
   if (size < 1)        // I.e., no work to do.
      msgvect[0] = -1;  // Avoid entering the while loop
   else                 // Get the benefit matrix and first job
   {
      MPI_Recv(benefit, MAX_SIZE*MAX_SIZE, MPI_INT, 0, BLOCK,
               MPI_COMM_WORLD, &Status);
      MPI_Recv(msgvect, 2, MPI_INT, 0, DATA, MPI_COMM_WORLD,
               &Status);
      // Allocate the working vector and fill it
      work = calloc(size, sizeof *work);
      if (work == NULL)
      {  // The server is waiting for a reply to this job; without a working
         // vector none can be produced, so bring the whole run down.
         printf("Rank %d: out of memory\n", rank); fflush(stdout);
         MPI_Abort(MPI_COMM_WORLD, 1);
         exit(1);
      }
      for (k = 0; k < size; k++)
         work[k] = k;
   }
   // msgvect[0] is the index to swap with [0]
   // msgvect[1] is the latest lowerLimit
   while (msgvect[0] >= 0)     // Exit on k < 0
   {
      k = msgvect[0];
      lowerLimit = msgvect[1];
      if (P_TRACE)
         printf("Work engine %d working on %d from %d\n", rank, k, lowerLimit);

      swap(work, 0, k);
      explore (1, work, benefit[0][work[0]]);
      swap(work, 0, k);
      if (P_TRACE)
      {  // Assemble the whole line first so it goes out in a single printf.
         char buffer[2048];
         int used, col;

         buffer[0] = '\0';
         used = snprintf(buffer, sizeof buffer, "Rank %d:  %d  ",
                         rank, lowerLimit);
         for (col = 0;
              col < size && used >= 0 && (size_t) used < sizeof buffer;
              col++)
         {  int n = snprintf(buffer + used, sizeof buffer - (size_t) used,
                             "%3d", solution[col]);
            if (n < 0)
               break;
            used += n;
         }
         printf("%s\n", buffer); fflush(stdout);
      }
      result[size] = lowerLimit;
      memcpy(result, solution, size * sizeof *solution);
      MPI_Send(result, size+1, MPI_INT, 0, DATA, MPI_COMM_WORLD);
      // Get the next problem
      if (P_TRACE)
         printf("Rank %d waiting for job\n", rank);
      MPI_Recv(msgvect, 2, MPI_INT, 0, DATA, MPI_COMM_WORLD,
               &Status);
   }
   free(work);

   if (P_TRACE)
      printf("Rank %d entering hand-shaking\n", rank);
   fflush(stdout);      // Flushed whether or not tracing is on
   if (HANDSHAKE)
   {  MPI_Recv(&k, 0, MPI_INT, 0, FINIS, MPI_COMM_WORLD,
               &Status);
      MPI_Send(&k, 0, MPI_INT, 0, FINIS, MPI_COMM_WORLD);
   }
}

// Sum of column maxima over the still-open part of the problem.
// For every column work[start..size-1], take the largest benefit found in
// rows start..size-1 and add those maxima together.
// Returns 0 when start is not below size.
int colMaxSum(int start, int work[])
{
   int total = 0;
   int pos = start;

   while (pos < size)
   {
      const int column = work[pos];
      int best = benefit[start][column];
      int r;

      for (r = start + 1; r < size; ++r)
      {
         const int candidate = benefit[r][column];
         if (candidate > best)
            best = candidate;
      }
      total += best;
      ++pos;
   }
   return total;
}

// Branch-and-bound search over the permutations of vect[index..size-1].
//   index   first row whose column is still to be chosen
//   vect    permutation being built; vect[row] is the column given to row.
//           Entries before index are fixed by the caller.  On return the
//           array holds the same contents it had on entry.
//   prefix  benefit already collected for the rows before index
// Each candidate for vect[index] is swapped into place.  The branch is
// followed only when its benefit so far plus the column-maximum bound for
// the remaining rows can beat lowerLimit.  A completed permutation that
// beats lowerLimit is stored in solution and raises lowerLimit.
void explore(int index, int vect[], int prefix)
{
   int k,       // Loop variable for swapping [index]..[size-1]
       j,       // Spare loop variable
       hold;    // Value of fixed portion of a permutation,
                // then used undoing the right rotation

   // No row left to fill.  Without this guard the rotation undo at the
   // bottom would read vect[index] past the end of the array and store that
   // value into vect[size-1].
   if (index >= size)
      return;

   for ( k = index; k < size; k++ )
   {  int upperBound;
      swap(vect, index, k);
   // Add together column maxima
      upperBound = colMaxSum(index+1, vect);

      hold = prefix + benefit[index][vect[index]];

      // Selecting [size-2] also fixes [size-1] ---
      // a permutation has been completed.  The same holds when index is
      // already the last row, which happens for a two-row problem.
      if (index >= size-2)
      {
         if ( lowerLimit < hold+upperBound )
         {  // Add in the last piece
            if (index == size-2)
               hold += benefit[size-1][vect[size-1]];
            if (DEBUG)
            {  printf("Replacing %d with %d: ", lowerLimit, hold);
               for (j = 0; j < size; j++)
                  printf("%3d", vect[j]);
               putchar('\n');
            }
            lowerLimit = hold;
            memcpy(solution, vect, size * sizeof *vect);
         }
         else if (DEBUG)
         {  printf("Reject %d: ", hold+upperBound);
            for (j = 0; j < size; j++)
               printf("%3d", vect[j]);
            putchar('\n');
         }
      }
      else if (hold + upperBound > lowerLimit)
         explore(index+1, vect, hold);
      else if (DEBUG)
      {  printf("Prune at %d, upper limit %d: ",
                index, hold + upperBound);
         for (j = 0; j < size; j++)
            printf("%3d", vect[j]);
         putchar('\n');
      }
   }
   // The swaps above are never undone one by one; taken together they leave
   // vect[index..size-1] rotated one cell to the right.  Rotate it back.
   hold = vect[index];
   for (k = index+1; k < size; k++)
      vect[k-1] = vect[k];
   vect[size-1] = hold;
}
