#include <stdlib.h>

#include "mpi.h"

/* NULL comes from <stdlib.h>; keep the legacy fallback only if it is somehow
   still undefined, so it cannot conflict with the standard definition. */
#ifndef NULL
#define NULL 0
#endif

/* Reduce-scatter worker taking an (np+1)-entry table of element offsets into
   send_buf; defined later in this file, declared here so my_reduce_scatter
   can call it. */
void my_reduce_scatter_x( void *send_buf, void *recv_buf, int *offsets,
                          MPI_Datatype dtype, MPI_Op op, MPI_Comm comm );

/*
 * Reduce-scatter with per-rank receive counts.
 *
 * recv_count must point to np entries (np = size of comm); rank k receives
 * recv_count[k] elements of the result. The counts are converted into an
 * (np+1)-entry prefix sum of element offsets into send_buf and passed to
 * my_reduce_scatter_x, which performs the exchange and reduction.
 *
 * If the offset table cannot be allocated the job is aborted with
 * MPI_Abort, since this void function has no way to report failure.
 */
void my_reduce_scatter( void * send_buf, void * recv_buf, int * recv_count,
		        MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
{
  int *offsets;
  int i, np;

  MPI_Comm_size( comm, &np );

  offsets = malloc( ( size_t ) ( np + 1 ) * sizeof *offsets );
  if ( offsets == NULL )
    {
      MPI_Abort( comm, 1 );
      return;
    }

  /* Exclusive prefix sum: offsets[k] is the first element of block k and
     offsets[np] is the total number of elements in send_buf. */
  offsets[ 0 ] = 0;
  for ( i = 0; i < np; i++ )
    offsets[ i+1 ] = offsets[ i ] + recv_count[ i ];

  my_reduce_scatter_x( send_buf, recv_buf, offsets, dtype, op, comm );

  free( offsets );
}

/*
 * Ring-based reduce-scatter over an explicit offset table.
 *
 * offsets has np+1 entries: block k of send_buf spans elements
 * [offsets[k], offsets[k+1]) and offsets[np] is the total element count.
 * On return, recv_buf on rank `me` holds block `me` reduced with `op`
 * across all ranks of `comm`.
 *
 * Algorithm: np-1 steps around a ring. At each step a rank sends one block
 * to its right neighbour, receives the next block from its left neighbour,
 * and folds its own contribution into the received block. Block b starts at
 * rank b+1 (its raw contribution) and arrives, fully reduced, at rank b.
 * Message tags are block indices.
 *
 * NOTE(review): assumes blocks are contiguous in units of the datatype's
 * extent and that the datatype's lower bound is 0. Contributions are not
 * combined in rank order, so `op` should be commutative.
 *
 * If scratch memory cannot be allocated the job is aborted with MPI_Abort.
 */
void my_reduce_scatter_x( void * send_buf, void * recv_buf, int * offsets,
			  MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
{
  int me, np, left, right, i, next = 0;
  MPI_Aint lb, extent;
  size_t temp_bytes;
  char *send_bytes = ( char * ) send_buf;
  char *temp_buf;
  MPI_Request request;

  MPI_Comm_rank( comm, &me );
  MPI_Comm_size( comm, &np );
  /* Element addresses advance by the extent, not by the packed size that
     MPI_Type_size reports. */
  MPI_Type_get_extent( dtype, &lb, &extent );

  /* A single process has no ring to walk: the result is its own block. */
  if ( np == 1 )
    {
      int count = offsets[ me+1 ] - offsets[ me ];
      MPI_Sendrecv( send_bytes + ( MPI_Aint ) offsets[ me ] * extent,
                    count, dtype, me, 0,
                    recv_buf, count, dtype, me, 0,
                    comm, MPI_STATUS_IGNORE );
      return;
    }

  left = ( me + np - 1 ) % np;
  right = ( me + 1 ) % np;

  /* Scratch space for partially reduced blocks, laid out like send_buf.
     Allocate at least one byte so zero-length exchanges still use a valid
     pointer and a NULL from malloc(0) is not mistaken for exhaustion. */
  temp_bytes = 1;
  if ( offsets[ np ] > 0 && extent > 0 )
    temp_bytes = ( size_t ) offsets[ np ] * ( size_t ) extent;
  temp_buf = malloc( temp_bytes );
  if ( temp_buf == NULL )
    {
      MPI_Abort( comm, 1 );
      return;
    }

  for ( i = left; i != me; i = next )
    {
      int recv_count, send_count;
      void *recv_ptr;
      void *send_ptr;

      /* Step one rank backwards, wrapping past rank 0. */
      next = ( i + np - 1 ) % np;
      recv_count = offsets[ next+1 ] - offsets[ next ];
      send_count = offsets[ i+1 ] - offsets[ i ];

      /* The last block to arrive is our own result: land it in recv_buf. */
      recv_ptr = ( next == me )
        ? recv_buf
        : ( void * ) ( temp_buf + ( MPI_Aint ) offsets[ next ] * extent );

      /* Our first send (block `left`) carries only our own contribution;
         every later block has already been reduced into temp_buf. */
      send_ptr = ( i == left )
        ? ( void * ) ( send_bytes + ( MPI_Aint ) offsets[ i ] * extent )
        : ( void * ) ( temp_buf + ( MPI_Aint ) offsets[ i ] * extent );

      /* Post the receive before the blocking send so neighbours cannot
         deadlock waiting on each other. */
      MPI_Irecv( recv_ptr, recv_count, dtype, left, next, comm, &request );
      MPI_Send( send_ptr, send_count, dtype, right, i, comm );
      MPI_Wait( &request, MPI_STATUS_IGNORE );

      /* Fold our contribution for block `next` into what was received,
         using the caller's datatype and operation. */
      MPI_Reduce_local( send_bytes + ( MPI_Aint ) offsets[ next ] * extent,
                        recv_ptr, recv_count, dtype, op );
    }

  free( temp_buf );
}
