#include <time.h>       // clock() etc.
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <mpi.h>
#include "cpuTimes.c"
#include <unistd.h>     // fork stuff
#include <sys/wait.h>   // to allow wait

#include <sys/types.h>  // required for low-level I/O
#include <fcntl.h>      // File control definitions, like O_RDWR
 
// Global search state shared by the server (rank 0) and the clients.
// main() seeds solution and lowerLimit from whichever of the main diagonal
// and the anti-diagonal of the benefit matrix has the larger sum.
#define MAX_SIZE 32
int solution[MAX_SIZE], lowerLimit;
// NOTE:  size of the problem is available globally; benefit[row][col] is
// the square benefit matrix, of which only the first size x size cells
// are read from the input file.
int size, benefit[MAX_SIZE][MAX_SIZE];
enum { FALSE, TRUE };
int DEBUG;              // Print accept / reject / prune decisions
int HANDSHAKE;          // Require the FINIS message exchange at the end
int P_TRACE;            // Print progress tracing of each process
int TRUNCATE;           // Dump the state and stop before searching

// MPI message tags
#define INIT     1      // Server to client:  the problem size (< 1 means quit)
#define BLOCK    2      // Complete benefit matrix to client
#define DATA     4      // Both directions, data in / out
#define FINIS    8      // Final handshake, sent in both directions

// Result buffer exchanged over MPI:  [0..size-1] hold the permutation and
// [size] holds its benefit, so size+1 ints are transmitted.  The array
// itself has MAX_SIZE+3 slots.
typedef int vector_state[MAX_SIZE+3];

void client (int rank,int nProc);
void explore(int index, int *vect, int prefix);
void MPIsplit(int *vect, int nProc);

void dumpState()
{  int row, col;
   char buffer[2048];
   buffer[0] = 0;
   for (row = 0; row < size; row++)
   {  for (col = 0; col < size; col++)
         sprintf(buffer+strlen(buffer), "%3d", benefit[row][col]);
      sprintf(buffer+strlen(buffer), "\n");
   }
   sprintf(buffer+strlen(buffer), "Benefit %d:  ", lowerLimit);
   for (row = 0; row < size; row++)
      sprintf(buffer+strlen(buffer), "%3d", solution[row]);
   printf("%s\n", buffer);
   return;
}

int main (int argc, char** argv)
{
   // Set global variables
   DEBUG = FALSE;
   HANDSHAKE =  TRUE;
   P_TRACE = FALSE;
   TRUNCATE = FALSE;
   lowerLimit = 0;

   int    nProc,                    // Processes in the communicator
          proc;                     // loop variable
   MPI_Status Status;               // Return status from MPI
   int    rc;                       // Status  code from MPI_Xxx() call
   int    myPos;                    // My own position

   char* filename = argc == 1 ? "Asg5.in" : argv[1];
   FILE *input = fopen(filename, "r");
   int antisum, row, col, j, k, *work;
   double start, elapsed;
   int N_out;

   rc = MPI_Init(&argc, &argv);
   if (rc != MPI_SUCCESS)
   {  puts ("MPI_Init failed"); exit(-1);  }

   rc = MPI_Comm_rank (MPI_COMM_WORLD, &myPos);
   rc = MPI_Comm_size (MPI_COMM_WORLD, &nProc);

   if (P_TRACE)
   {  printf ("Process %d of %d started.\n", myPos, nProc); fflush(stdout);
   }
      

   if ( myPos > 0 )
      client(myPos, nProc);
   else
   {
      if (P_TRACE) {  puts("Starting data server"); fflush(stdout);  }
      if (input == NULL)
      {  printf("Failed to open %s\n", filename);
         // Send terminate message to clients
         size = 0;
         for (row = 1; row < nProc; row++)
            MPI_Send (&size, 1, MPI_INT, row, INIT, MPI_COMM_WORLD);
         MPI_Finalize();
         return 0;
      }
      fscanf(input, "%d", &size);
      work = (int*) calloc(size, sizeof *work);
      antisum  = 0;  k = size;
      for (row = lowerLimit = 0; row < size; row++)
      {  solution[row] = row;
         for (col = 0; col < size; col++)
            fscanf (input, "%d", &benefit[row][col]);
         lowerLimit += benefit[row][row];
         antisum  += benefit[row][--k];
      }
      if (antisum > lowerLimit)
      {  lowerLimit = antisum;
         for (row = 0, k = size; row < size; row++)
            solution[row] = --k;
      }
      memcpy(work, solution, size * sizeof *work);
      if (TRUNCATE) dumpState();
      start = cpuClock();
      MPIsplit(work, nProc);
      elapsed = cpuClock() - start;
      printf("Solution found with value %d\n", lowerLimit);
      for (row = 0; row < size; row++)
         printf("%3d", solution[row]);
      printf("\n%3.3f seconds\n", elapsed);
   }
// Everybody resigns from MPI before terminating.
   MPI_Finalize();
   return 0;
}

// Exchange the elements at positions j and k of the array x.
// Calling it with j == k leaves the array unchanged.
void swap(int *x, int j, int k)
{
   int *first  = x + j;
   int *second = x + k;
   int  saved  = *first;

   *first  = *second;
   *second = saved;
}

// Server side of the search (rank 0).
//
// The candidates for position 0 of the permutation are divided into nProc
// contiguous ranges with boundaries bound[k] = size * k / nProc.  Client
// `proc` is sent the size (INIT), the whole benefit matrix (BLOCK) and an
// initial state (DATA) laid out as
//    [0] start of range, [1] end of range (exclusive),
//    [2] lowerLimit,     [3..] the current solution.
// The server explores the last range itself, then collects one result per
// client (size permutation entries followed by the benefit) and keeps the
// best one in the globals solution / lowerLimit.  When HANDSHAKE is set a
// zero-length FINIS message is exchanged with each client at the end.
//
// work  - scratch permutation of `size` entries, rearranged in place and
//         restored by each explore() call
// nProc - number of processes in MPI_COMM_WORLD
//
// The range table is allocated from nProc rather than held in a
// MAX_SIZE-sized array, so running with more ranks than MAX_SIZE no
// longer overruns it.
void MPIsplit(int *work, int nProc)
{
   int  proc, j, k;
   int *bound;                      // nProc+1 range boundaries, bound[0] is 0
   vector_state result;
   int *initialState;
   MPI_Status Status;               // Return status from MPI

   bound        = calloc((size_t) nProc + 1, sizeof *bound);
   initialState = calloc((size_t) size + 3, sizeof *initialState);
   if (bound == NULL || initialState == NULL)
   {  puts("MPIsplit: out of memory");
      free(bound);
      free(initialState);
      // Clients are blocked waiting for INIT; bring the whole job down.
      MPI_Abort(MPI_COMM_WORLD, 1);
      return;
   }

   initialState[2] = lowerLimit;
   memcpy(&initialState[3], solution, size * sizeof *solution);
   for (k = 1; k <= nProc; k++)
      bound[k] = size * k / nProc;
   // The work server will do the last block.
   for ( proc = 1; proc < nProc; proc++ )
   {  if (P_TRACE) {printf("Sending to %d\n", proc); fflush(stdout);  }
      MPI_Send (&size, 1, MPI_INT, proc, INIT, MPI_COMM_WORLD);
      // Easier to ignore size and use MAX_SIZE
      if (P_TRACE) {printf("Send benefit to %d\n", proc); fflush(stdout);  }
      MPI_Send (benefit, MAX_SIZE*MAX_SIZE, MPI_INT, proc, BLOCK,
                MPI_COMM_WORLD);
      initialState[0] = bound[proc-1]; // Start position
      initialState[1] = bound[proc];   // work UP TO this one
      if (P_TRACE) { printf("Send %d initial state\n", proc); fflush(stdout); }
      MPI_Send (initialState, size+3, MPI_INT, proc, DATA, MPI_COMM_WORLD);
   }
   free(initialState);              // every client has its copy now
   if (TRUNCATE)
   {  free(bound);
      return;
   }
   // Server itself does one subset of the problem
   for (j = bound[nProc-1]; j < bound[nProc]; j++)
   {
      swap(work, 0, j);
      explore (1, work, benefit[0][work[0]]);
      swap(work, 0, j);
   }
   free(bound);
   if (P_TRACE)
   {  char buffer[2048];
      sprintf(buffer, "Rank 0:  %d as ", lowerLimit);
      sprintf(buffer+strlen(buffer), " (size %d) ", size);
      j = 0;
      while ( j < size )
         sprintf(buffer+strlen(buffer), "%3d", solution[j++]);
      printf("%s\n", buffer);
   }
   // Receive results back from the work engines
   for ( proc = 1; proc < nProc; proc++ )
   {  MPI_Recv(result, size+1, MPI_INT, proc, DATA, MPI_COMM_WORLD,
               &Status);
      if (result[size] > lowerLimit)
      {  if (DEBUG)
            printf("Updating from %d to %d for benefit.\n",
                  lowerLimit, result[size]);
         lowerLimit = result[size];
         memcpy (solution, result, size * sizeof *solution);
      }
   }
   if (HANDSHAKE)
   {  // Zero-length messages:  only the FINIS tag matters.
      for (proc = 1; proc < nProc; proc++)
         MPI_Send(&k, 0, MPI_INT, proc, FINIS, MPI_COMM_WORLD);
      for (proc = 1; proc < nProc; proc++)
         MPI_Recv(&k, 0, MPI_INT, proc, FINIS, MPI_COMM_WORLD,
                  &Status);
   }
}

// Compute-engine side of the search (every rank other than 0).
//
// Receives the problem size (INIT), the benefit matrix (BLOCK) and the
// initial state (DATA) from rank 0.  The initial state is laid out as
//    [0] start of range, [1] end of range (exclusive),
//    [2] lowerLimit,     [3..] the starting solution.
// Each value in the range is tried in position 0 of the permutation and
// the rest is searched with explore().  The best permutation known to
// this rank, followed by its benefit, is sent back on the DATA tag.  When
// HANDSHAKE is set, a zero-length FINIS message is then received from and
// echoed to rank 0.
//
// rank  - this process's rank, used only in trace output
// nProc - number of processes in the communicator (not used here)
//
// A size below 1 is the server's "give up" signal and returns at once.
// A size above MAX_SIZE cannot fit the fixed-size globals and aborts the
// job instead of overrunning them.
void client (int rank, int nProc)
{
   MPI_Status Status;             // Return status from MPI
   int *msgvect, *work;
   int  k;                        // loop variable
   vector_state result;

   (void) nProc;

   if (P_TRACE)
   {  printf("Compute engine %d starting\n", rank); fflush(stdout);  }

   // Receive the global variable size
   MPI_Recv(&size, 1, MPI_INT, 0, INIT, MPI_COMM_WORLD,
            &Status);
   if (size < 1) return;     // Failure message from server
   if (size > MAX_SIZE)
   {  printf("Rank %d: problem size %d exceeds %d\n", rank, size, MAX_SIZE);
      MPI_Abort(MPI_COMM_WORLD, 1);
      return;
   }

   msgvect = calloc ((size_t) size + 3, sizeof *msgvect);
   work    = calloc ((size_t) size, sizeof *work);
   if (msgvect == NULL || work == NULL)
   {  printf("Rank %d: out of memory\n", rank);
      free(msgvect);
      free(work);
      // The server is waiting for this rank's result; stop the whole job.
      MPI_Abort(MPI_COMM_WORLD, 1);
      return;
   }
   MPI_Recv(benefit, MAX_SIZE*MAX_SIZE, MPI_INT, 0, BLOCK,
            MPI_COMM_WORLD, &Status);
   MPI_Recv(msgvect, 3+size, MPI_INT, 0, DATA, MPI_COMM_WORLD,
            &Status);
   // Set up the global state
   memcpy(work, msgvect+3, size * sizeof *work);
   lowerLimit = msgvect[2];
   memcpy(solution, work, size * sizeof *work);
   if (P_TRACE)
      printf("Work engine %d working on %d up to %d\n", rank,
             msgvect[0], msgvect[1]);

   if (TRUNCATE)
   {  dumpState();
      free(msgvect);
      free(work);
      return;
   }
   for (k = msgvect[0]; k < msgvect[1]; k++)
   {
      swap(work, 0, k);
      explore (1, work, benefit[0][work[0]]);
      swap(work, 0, k);
   }
   // The search is finished; the scratch buffers are no longer needed.
   free(msgvect);
   free(work);
   if (P_TRACE)
   {  int pos = 0;
      char buffer[2048] = {'\0'};
      sprintf(buffer, "Rank %d:  %d (size is %d)",
              rank, lowerLimit, size);
      while ( pos < size )
      {
         sprintf(buffer+strlen(buffer), "%3d", solution[pos++]);
      }
      printf("%s\n", buffer); fflush(stdout);
   }
   // Result layout:  the permutation in [0..size-1], its benefit in [size]
   result[size] = lowerLimit;
   memcpy(result, solution, size * sizeof *solution);
   MPI_Send(result, size+1, MPI_INT, 0, DATA, MPI_COMM_WORLD);
   if (HANDSHAKE)
   {  // Zero-length messages:  only the FINIS tag matters.
      MPI_Recv(&k, 0, MPI_INT, 0, FINIS, MPI_COMM_WORLD,
               &Status);
      MPI_Send(&k, 0, MPI_INT, 0, FINIS, MPI_COMM_WORLD);
   }
}

// Optimistic estimate for the part of the permutation not yet fixed.
// For each column work[start..size-1] take the largest benefit found in
// rows start..size-1 and return the sum of those maxima.  Returns 0 when
// start >= size.
int colMaxSum(int start, int work[])
{
   int total = 0;
   int pos;

   for (pos = start; pos < size; ++pos)
   {
      const int column = work[pos];
      int best = benefit[start][column];
      int r = start;

      while (++r < size)
      {
         if (benefit[r][column] > best)
            best = benefit[r][column];
      }
      total += best;
   }
   return total;
}

// Recursive branch-and-bound step over permutations.
//
// index  - position of vect being decided at this level
// vect   - working permutation; [0..index-1] are fixed by the callers and
//          [index..size-1] are the columns still available
// prefix - benefit accumulated by the fixed positions [0..index-1]
//
// Each remaining column is swapped into position `index` in turn.  The
// branch is followed only if its benefit so far plus colMaxSum() of the
// remaining rows and columns can beat the global lowerLimit.  At
// index == size-2 the last position is forced, so the permutation is
// complete:  if it passes the bound test it is recorded in the globals
// lowerLimit and solution.  On return vect[index..size-1] is back in the
// order it had on entry.
void explore(int index, int vect[], int prefix)
{
   int k,       // Loop variable for swapping [index]..[size-1]
       j,       // Spare loop variable
       hold;    // Value of fixed portion of a permutation,
                // then used undoing the right rotation

   // Nothing left to place.  The callers start at index 1, so this is hit
   // when size == 1, and by recursion reaching index == size.  Without
   // the guard the rotation undo below would read vect[index] past the
   // end of the array and store that value in vect[size-1].
   if (index >= size)
      return;

   for ( k = index; k < size; k++ )
   {  int upperBound;
      swap(vect, index, k);
   // Add together column maxima
      upperBound = colMaxSum(index+1, vect);

      hold = prefix + benefit[index][vect[index]];

      // Selecting [size-2] also fixes [size-1] ---
      // a permutation has been completed.
      if (index == size-2)
      {
         // Here upperBound is exactly benefit[size-1][vect[size-1]], so
         // hold+upperBound is the value of the complete permutation.
         if ( lowerLimit < hold+upperBound )
         {  // Add in the last piece
            hold += benefit[size-1][vect[size-1]];
            if (DEBUG)
            {  printf("Replacing %d with %d: ", lowerLimit, hold);
               for (j = 0; j < size; j++)
                  printf("%3d", vect[j]);
               putchar('\n');
            }
            lowerLimit = hold;
            memcpy(solution, vect, size * sizeof *vect);
         }
         else if (DEBUG)
         {  printf("Reject %d: ", hold+upperBound);
            for (j = 0; j < size; j++)
               printf("%3d", vect[j]);
            putchar('\n');
         }
      }
      else if (hold + upperBound > lowerLimit)
         explore(index+1, vect, hold);
      else if (DEBUG)
      {  printf("Prune at %d, upper limit %d: ",
                index, hold + upperBound);
         for (j = 0; j < size; j++)
            printf("%3d", vect[j]);
         putchar('\n');
      }
   }
   // The swaps above are never undone one by one; together they rotate
   // vect[index..size-1] one cell to the right.  Undo that rotation.
   hold = vect[index];
   for (k = index+1; k < size; k++)
      vect[k-1] = vect[k];
   vect[size-1] = hold;
}
