#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"

#ifdef CRAY
#include <mpp/rastream.h>
#endif

/* Legacy boolean constant (not referenced by the code visible in this file). */
#ifndef TRUE
#define TRUE 1
#endif

/* Location of the results file.  main() builds the name
 * COLLMARK_PATH "/" OUTPUT_FILE "_<nprocs>.m" and opens it in append mode on
 * rank 0.  Both macros may be overridden at build time, e.g.
 *   cc -DCOLLMARK_PATH='"/scratch/results"' ...                              */
#ifndef COLLMARK_PATH
#define COLLMARK_PATH "/tmp"
#endif
#ifndef OUTPUT_FILE
#define OUTPUT_FILE "CollMark_output"
#endif

int main(int argc, char * argv[] )
{
  int 
    me, nprocs, np,
    itype,
    typesize,
    i, 
    ntrials, nrepeats, root,
    *lengths;
  double 
    *timings_bounce_0,
    *timings_bounce_1,
    *timings_bounce_min,
    *timings_bcast_0, 
    *timings_bcast_1, 
    *timings_bcast_2, 
    *timings_bcast_newmpi,
    *timings_scatter_0, 
    *timings_scatter_1,  
    *timings_gather_0,  
    *timings_gather_1,
    *timings_allgather_0,    
    *timings_allgather_1,
    *timings_allgather_2,
    *timings_allgather_3,
    *timings_reduce_0,
    *timings_reduce_1,
    *timings_reduce_newmpi,
    *timings_allreduce_0,
    *timings_allreduce_1,
    *timings_allreduce_newmpi,
    *timings_reduce_scatter_0,
    *timings_reduce_scatter_1;
  MPI_Datatype
    datatype;
  MPI_Comm
    comm = MPI_COMM_NULL;
  MPI_Status
    status;
  char
    output_file_m[ 100 ];
  FILE 
    *fp;

  /* Initialize the Message-Passing Interface */
  MPI_Init(&argc, &argv);

#ifdef CRAY
  set_d_stream( 1 );
#endif

  /* me = this node's index in the communicator */
  MPI_Comm_rank( MPI_COMM_WORLD, &me );          

  /* nprocs = number of nodes in communicator */
  MPI_Comm_size( MPI_COMM_WORLD, &nprocs );          

  /* create an output file */

  if ( me == 0 ){
    sprintf( output_file_m, "%s/%s_%d.m", COLLMARK_PATH, OUTPUT_FILE, nprocs );
//    sprintf( string, "cp %s/version_info %s", COLLMARK_PATH,  output_file_m );
//    system( string );
    fp = fopen( output_file_m, "a" );
  }  

  /* read in the datatype to time */
  if ( me == 0 ){
/*    printf( "enter datatype: (0=char, 1=int, 2=float, 3=double)\n" );
    scanf( "%d", &itype ); */
    itype = 1;  /* only MPI_INT supported right now. */
  }  

  MPI_Bcast( &itype, 1, MPI_INT, 0, MPI_COMM_WORLD );

  switch ( itype ){
  case 0: 
    datatype = MPI_CHAR;
    if ( me == 0 ) fprintf( fp, "%% datatype MPI_CHAR\n" );
    break;
  case 1: 
    datatype = MPI_INT;
    if ( me == 0 ) fprintf( fp, "%% datatype MPI_INT\n" );
    break;
  case 2: 
    datatype = MPI_FLOAT;
    if ( me == 0 ) fprintf( fp, "%% datatype MPI_FLOAT\n" );
    break;
  case 3: 
    datatype = MPI_DOUBLE;
    if ( me == 0 ) fprintf( fp, "%% datatype MPI_DOUBLE\n" );
    break;
  default:
    printf( "unrecognized datatype\n" );
    exit( 0 );
  }

  MPI_Type_size( datatype, &typesize );

  /* read in the number of message lengths to time */
  if ( me == 0 ){
    printf( "number of message lengths:\n" );
    scanf( "%d", &ntrials );
    fprintf( fp, "%% number of message lengths: %d\n", ntrials );
  }

  MPI_Bcast( &ntrials, 1, MPI_INT, 0, MPI_COMM_WORLD );

  lengths = ( int * ) malloc( ntrials * sizeof( int ) );

  /* read in the message lengths to time */
  if ( me == 0 ){
    printf( "enter message lengths:\n" );
    for ( i=0; i<ntrials; i++ ){
      scanf( "%d", &lengths[ i ] );
    }
  }

  MPI_Bcast( lengths, ntrials, MPI_INT, 0, MPI_COMM_WORLD );

  /* Perform bounce test */

  /* read number of times to bounce */
  if ( me == 0 ){
    printf( "number of bounces:\n" );
    scanf( "%d", &nrepeats );
    fprintf( fp, "%% number of bounces = %d\n", nrepeats );
  }

  MPI_Bcast( &nrepeats, 1, MPI_INT, 0, MPI_COMM_WORLD ); 
  nrepeats = 10;

  timings_bounce_0         = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_bounce_1         = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_bounce_min       = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_bcast_0          = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_bcast_1          = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_bcast_2          = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_bcast_newmpi     = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_scatter_0        = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_scatter_1        = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_gather_0         = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_gather_1         = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allgather_0      = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allgather_1      = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allgather_2      = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allgather_3      = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_reduce_0         = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_reduce_1         = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_reduce_newmpi    = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allreduce_0      = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allreduce_1      = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_allreduce_newmpi = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_reduce_scatter_0 = ( double * ) malloc( ntrials * sizeof( double ) );
  timings_reduce_scatter_1 = ( double * ) malloc( ntrials * sizeof( double ) );

  time_bounce( datatype, ntrials, lengths, nrepeats, timings_bounce_0, 
      MPI_COMM_WORLD, 0 );
  time_bounce( datatype, ntrials, lengths, nrepeats, timings_bounce_1, 
      MPI_COMM_WORLD, 1 );

  if ( me == 0 ){
    fprintf( fp, "%%\n");
    fprintf( fp, "%% Message Size | p2p_0   | p2p_1   | p2p_min \n" );
    fprintf( fp, "%% (in bytes)   |         time (in sec.)  \n" );
    fprintf( fp, "%% -------------------------------------------\n" );
    fprintf( fp, "timings_bounce = [\n");

    for ( i=0; i<ntrials; i++) {
      timings_bounce_min[ i ] = ( ( timings_bounce_0[ i ] <
				    timings_bounce_1[ i ] ) ?
				  timings_bounce_0[ i ] :
				  timings_bounce_1[ i ]);

      fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le\n",
	       lengths[ i ] * typesize,
	       timings_bounce_0[ i ],
	       timings_bounce_1[ i ],
	       timings_bounce_min[ i ]); 
    }
    fprintf( fp, "];\n");
  }

  for ( np = nprocs; np>=2; np = np / 2 ){
    if ( me == 0 ) printf("timing %d processors\n", np );

    MPI_Comm_split( MPI_COMM_WORLD, ( ( me < np ) ? 1 : MPI_UNDEFINED ),  
	  me, &comm );   

    root = 0;

    if ( me < np ){

      if ( me == 0 ) printf( "timing bcast version 0\n" );
      time_bcast( datatype, ntrials, lengths, root, timings_bcast_0,
	          comm, 0 );

      if ( me == 0 ) printf( "timing bcast version 1\n" );
      time_bcast( datatype, ntrials, lengths, root, timings_bcast_1,
	          comm, 1 );

      if ( me == 0 ) printf( "timing bcast version 2\n" );
      time_bcast( datatype, ntrials, lengths, root, timings_bcast_2,
	          comm, 2 );

      if ( me == 0 ) printf( "timing new mpi bcast\n" );
      time_bcast( datatype, ntrials, lengths, root, timings_bcast_newmpi,
	          comm, 3 );

      if ( me == 0 ){
	printf("length = %d my bcast 1 is %5.2lf times as fast as MPI_Bcast\n",
	       lengths[ ntrials-1 ],
	       timings_bcast_0[ ntrials-1 ] / timings_bcast_1[ ntrials-1 ] );
	printf("length = %d my bcast 2 is %5.2lf times as fast as MPI_Bcast\n",
	       lengths[ ntrials-1 ],
	       timings_bcast_0[ ntrials-1 ] / timings_bcast_2[ ntrials-1 ] );
	printf("length = %d new mpi bcast is %5.2lf times as fast as MPI_Bcast\n",
	       lengths[ ntrials-1 ],
	       timings_bcast_0[ ntrials-1 ] / timings_bcast_newmpi[ ntrials-1 ] );
      }

      if ( me == 0 ) printf( "timing reduce version 0\n" );
      time_reduce( datatype, ntrials, lengths, root, timings_reduce_0,
	           comm , 0 );

      if ( me == 0 ) printf( "timing reduce version 1\n" );
      time_reduce( datatype, ntrials, lengths, root, timings_reduce_1,
		   comm , 1 ); 

      if ( me == 0 ) printf( "timing new mpi reduce\n" );
      time_reduce( datatype, ntrials, lengths, root, timings_reduce_newmpi,
		  comm , 2 ); 

      if ( me == 0 ){
	printf("for length = %d my reduce is %5.2lf times as fast as MPI_Reduce\n",
	       lengths[ ntrials-1 ],
	       timings_reduce_0[ ntrials-1 ] / timings_reduce_1[ ntrials-1 ] );

	printf("for length = %d new mpi reduce is %5.2lf times as fast as MPI_Reduce\n",
	       lengths[ ntrials-1 ],
	       timings_reduce_0[ ntrials-1 ] / timings_reduce_newmpi[ ntrials-1 ] );
      }

      if ( me == 0 ) printf( "timing scatter version 0\n" );
      time_scatter( datatype, ntrials, lengths, root, timings_scatter_0,
	            comm, 0 );

      if ( me == 0 ) printf( "timing scatter version 1\n" );
      time_scatter( datatype, ntrials, lengths, root, timings_scatter_1,
	            comm, 1 );

      if ( me == 0 )
	printf("for length = %d my scatter is %5.2lf times as fast as MPI_Scatter\n",
	       lengths[ ntrials-1 ],
	       timings_scatter_0[ ntrials-1 ] / timings_scatter_1[ ntrials-1 ] );

      if ( me == 0 ) printf( "timing gather version 0 \n" );
      time_gather( datatype, ntrials, lengths, root, timings_gather_0,
	            comm, 0 );

      if ( me == 0 ) printf( "timing gather version 1 \n" );
      time_gather( datatype, ntrials, lengths, root, timings_gather_1,
	            comm, 1 );

      if ( me == 0 )
	printf("for length = %d my gather is %5.2lf times as fast as MPI_Gather\n",
	       lengths[ ntrials-1 ],
	       timings_gather_0[ ntrials-1 ] / timings_gather_1[ ntrials-1 ] );

      if ( me == 0 ) printf( "timing allgather version 0 \n" );
      
      time_allgather( datatype, ntrials, lengths, timings_allgather_0,
	            comm, 0 );

      if ( me == 0 ) printf( "timing allgather version 1 \n" );

      time_allgather( datatype, ntrials, lengths, timings_allgather_1,
		     comm, 1 );

      if ( me == 0 ) printf( "timing allgather version 2 \n" );

      time_allgather( datatype, ntrials, lengths, timings_allgather_2,
		     comm, 2 );

      if ( me == 0 ) printf( "timing allgather version 3 \n" );

      time_allgather( datatype, ntrials, lengths, timings_allgather_3,
		     comm, 3 );

      if ( me == 0 ){
	printf("length = %d my allgather 1 is %5.2lf times as fast as MPI_Allgather\n",
	       lengths[ ntrials-1 ],
	       timings_allgather_0[ ntrials-1 ] / timings_allgather_1[ ntrials-1 ] );
	printf("length = %d my allgather 2 is %5.2lf times as fast as MPI_Allgather\n",
	       lengths[ ntrials-1 ],
	       timings_allgather_0[ ntrials-1 ] / timings_allgather_2[ ntrials-1 ] );
	printf("length = %d my allgather 3 is %5.2lf times as fast as MPI_Allgather\n",
	       lengths[ ntrials-1 ],
	       timings_allgather_0[ ntrials-1 ] / timings_allgather_3[ ntrials-1 ] );
      }

      if ( me == 0 ) printf( "timing allreduce version 0 \n" );
      time_allreduce( datatype, ntrials, lengths, timings_allreduce_0,
		      comm, 0 );

      if ( me == 0 ) printf( "timing allreduce version 1 \n" );
      time_allreduce( datatype, ntrials, lengths, timings_allreduce_1,
		      comm, 1 );

      if ( me == 0 ) printf( "timing new mpi allreduce \n" );
      time_allreduce( datatype, ntrials, lengths, timings_allreduce_newmpi,
		      comm, 2 );


      if ( me == 0 ){
	printf("length = %d my allreduce is %5.2lf times as fast as MPI_Allreduce\n",
	       lengths[ ntrials-1 ],
	       timings_allreduce_0[ ntrials-1 ] / timings_allreduce_1[ ntrials-1 ] );
	printf("length = %d new mpi allreduce is %5.2lf times as fast as MPI_Allreduce\n",
	       lengths[ ntrials-1 ],
	       timings_allreduce_0[ ntrials-1 ] / timings_allreduce_newmpi[ ntrials-1 ] );
      }

      if ( me == 0 ) printf( "timing reduce_scatter version 0 \n" );
      time_reduce_scatter( datatype, ntrials, lengths, timings_reduce_scatter_0,
			   comm, 0 );

      if ( me == 0 ) printf( "timing reduce_scatter version 1 \n" );
      time_reduce_scatter( datatype, ntrials, lengths, timings_reduce_scatter_1,
			   comm, 1 );

      if ( me == 0 ){
	printf("length = %d my reduce_scatter is %5.2lf times as fast as MPI_Reduce_scatter\n",
	       lengths[ ntrials-1 ],
	       timings_reduce_scatter_0[ ntrials-1 ] / timings_reduce_scatter_1[ ntrials-1 ] );
      }

      MPI_Comm_free( &comm ); 
    }

    if ( me == 0 ) {
      fprintf( fp, "%%\n");
      fprintf( fp, "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n");
      fprintf( fp, "%% Number of nodes   = %3d                 %%\n", np );
      fprintf( fp, "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n");

      // Print timing values for Broadcast
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | bcast0  | bcast1  | bcast2  | bcast_newmpi\n");
      fprintf( fp, "%% (in bytes)   |                  time (in sec.)             \n");
      fprintf( fp, "%% -------------------------------------------------------------------\n");
      fprintf( fp, "timings_bcast_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_bcast_0[ i ],
		 timings_bcast_1[ i ],
		 timings_bcast_2[ i ],
		 timings_bcast_newmpi[ i ]);
      fprintf( fp, "];\n");

      // Print timing values for Scatter
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | scatr0  | scatr1\n");
      fprintf( fp, "%% (in bytes)   |      time (in sec.)       \n");
      fprintf( fp, "%% ------------------------------------------\n");
      fprintf( fp, "timings_scatter_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_scatter_0[ i ],
		 timings_scatter_1[ i ]);
      fprintf( fp, "];\n");

      // Print timing values for Gather
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | gathr0  | gathr1\n");
      fprintf( fp, "%% (in bytes)   |      time (in sec.)       \n");
      fprintf( fp, "%% ------------------------------------------\n");
      fprintf( fp, "timings_gather_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_gather_0[ i ],
		 timings_gather_1[ i ]);
      fprintf( fp, "];\n");

      // Print timing values for AllGather
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | agath0  | agath1  | agath2  | agath3\n");
      fprintf( fp, "%% (in bytes)   |                  time (in sec.)               \n");
      fprintf( fp, "%% --------------------------------------------------------------\n");
      fprintf( fp, "timings_allgather_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_allgather_0[ i ],
		 timings_allgather_1[ i ],
		 timings_allgather_2[ i ],
		 timings_allgather_3[ i ]);
      fprintf( fp, "];\n");

      // Print timing values for Reduce_Scatter
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | rscat0  | rscat1\n");
      fprintf( fp, "%% (in bytes)   |       time (in sec.)      \n");
      fprintf( fp, "%% ------------------------------------------\n");
      fprintf( fp, "timings_reduce_scatter_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_reduce_scatter_0[ i ],
		 timings_reduce_scatter_1[ i ]);
      fprintf( fp, "];\n");

      // Print timing values for Reduce
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | red0    | red1    | red_newmpi\n");
      fprintf( fp, "%% (in bytes)   |              time (in sec.)             \n");
      fprintf( fp, "%% -------------------------------------------------------\n");
      fprintf( fp, "timings_reduce_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_reduce_0[ i ],
		 timings_reduce_1[ i ],
		 timings_reduce_newmpi[ i ]);
      fprintf( fp, "];\n");

      // Print timing values for AllReduce
      fprintf( fp, "%%\n");
      fprintf( fp, "%% Message Size | p2p     | allred0 | allred1 | allred_newmpi\n");
      fprintf( fp, "%% (in bytes)   |                time (in sec.)              \n");
      fprintf( fp, "%% ----------------------------------------------------------\n");
      fprintf( fp, "timings_allreduce_%d = [\n", np);
      for ( i=0; i<ntrials; i++)
	fprintf( fp, "     %7d     %3.1le   %3.1le   %3.1le   %3.1le\n",
		 lengths[ i ] * typesize,
		 timings_bounce_min[ i ],
		 timings_allreduce_0[ i ],
		 timings_allreduce_1[ i ],
		 timings_allreduce_newmpi[ i ]);
      fprintf( fp, "];\n");

      fflush( fp );
    }
  }

  free( lengths );
  free( timings_bounce_0 );
  free( timings_bounce_1 );
  free( timings_bounce_min );
  free( timings_bcast_0 );
  free( timings_bcast_1 );
  free( timings_bcast_2 );
  free( timings_bcast_newmpi );
  free( timings_scatter_0 );
  free( timings_scatter_1 );
  free( timings_gather_0 );
  free( timings_gather_1 );
  free( timings_allgather_0 );
  free( timings_allgather_1 );
  free( timings_allgather_2 );
  free( timings_allgather_3 );
  free( timings_reduce_0 );
  free( timings_reduce_1 );
  free( timings_reduce_newmpi );
  free( timings_allreduce_0 );
  free( timings_allreduce_1 );
  free( timings_allreduce_newmpi );
  free( timings_reduce_scatter_0 );
  free( timings_reduce_scatter_1 );

  fclose( fp ); 

  MPI_Finalize();
}
