/*
  MPImessenger.C
  ------------------------------------------------------------------------
  MPI messenger class.
  ------------------------------------------------------------------------
  @(#) $Id: MPImessenger.C,v 1.33 1998/11/06 05:03:29 emery Exp $
  ------------------------------------------------------------------------
  AUTHOR/CONTACT:
 
  Emery Berger                    | <http://www.cs.utexas.edu/users/emery>
  Parallel Programming Group      |  <http://www.cs.utexas.edu/users/code>
  Department of Computer Sciences |             <http://www.cs.utexas.edu>
  University of Texas at Austin   |                <http://www.utexas.edu>
  ========================================================================
*/

#include <mpi.h>

#include <string.h>
#include <stdlib.h>

#include "atomic.H"
#include "machine.H"
#include "messenger.H"
#include "parms.H"
#include "MPImessenger.H"

// RCS revision identifier string for this source file.
static char const rcsid[] = "$Id: MPImessenger.C,v 1.33 1998/11/06 05:03:29 emery Exp $";

// Construct a messenger: both pack/unpack cursors start at zero and a
// BUFFSIZE-byte array is allocated (with new[]) for later use as the
// MPI buffered-send buffer (see initialize()).
MPIMessenger::MPIMessenger (void)
  : _inpos (0),
    _outpos (0),
    _buffer (new char[BUFFSIZE])
{
}


// Release the buffered-send buffer owned by this messenger.
MPIMessenger::~MPIMessenger (void)
{
  // _buffer was allocated with new[] in the constructor, so it must be
  // released with the array form of delete (scalar delete is undefined
  // behavior here).  Deleting a null pointer is a no-op.
  delete [] _buffer;
}


// Start up MPI for this process.
//
//   argc, argv - the program's command-line arguments; passed by
//                reference to MPI_Init, which may modify them.
//
// After MPI_Init, the messenger's buffer is attached for buffered sends
// and MPI_COMM_WORLD is switched to "errors return" so that failed MPI
// calls report an error code instead of aborting the program.
void MPIMessenger::initialize (int& argc, char **& argv)
{
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif
  MPI_Init (&argc, &argv);

  MPI_Buffer_attach (_buffer, BUFFSIZE);

  // Set the error handler so that MPI errors
  // don't kill us.
  //
  // MPI_Errhandler_set was deprecated in MPI-2 and removed in MPI-3;
  // use its replacement whenever the library advertises MPI-2 or later,
  // and fall back to the original call for MPI-1 implementations.

#if defined(MPI_VERSION) && (MPI_VERSION >= 2)
  MPI_Comm_set_errhandler (MPI_COMM_WORLD, MPI_ERRORS_RETURN);
#else
  MPI_Errhandler_set (MPI_COMM_WORLD, MPI_ERRORS_RETURN);
#endif
}


void MPIMessenger::finalize (void)
{
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

#if DEBUG_PRINT
  cout << "Finalize on processor " << self() << endl << ::flush;
#endif
  
  // Output any pending messages.
  
#if 0
  while (flush ())
    ;
#endif

  MPI_Barrier (MPI_COMM_WORLD);	// Wait for everyone.

  char *buff;
  int size;
  MPI_Buffer_detach (&buff, &size);
  delete buff;

  MPI_Finalize ();
}


// Return this process's rank in MPI_COMM_WORLD.
//
// The rank is looked up once via MPI_Comm_rank and then cached in a
// function-local static; -1 marks "not yet looked up".
int MPIMessenger::self (void)
{
  Guard m1 (objectlock());
  static int cachedRank = -1;

  // Fast path: the rank has already been fetched.
  if (cachedRank != -1) {
    return cachedRank;
  }

#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif
  MPI_Comm_rank (MPI_COMM_WORLD, &cachedRank);
  return cachedRank;
}


// Return the number of processes in MPI_COMM_WORLD.
//
// The size is looked up once via MPI_Comm_size and then cached in a
// function-local static; -1 marks "not yet looked up".
int MPIMessenger::numProcessors (void)
{
  Guard m1 (objectlock());
  static int cachedSize = -1;

  // Fast path: the size has already been fetched.
  if (cachedSize != -1) {
    return cachedSize;
  }

#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif
  MPI_Comm_size (MPI_COMM_WORLD, &cachedSize);
  return cachedSize;
}


void MPIMessenger::send (const int dest, const int tag)
{
  // cout << "Sending to " << dest << " with tag " << tag << endl;
  assert (dest >= 0);
  assert (dest < numProcessors ());
  assert (tag >= 0);
  Guard m (_outboxlock);
  _outbox.push_back (new OutPacket (dest, tag, (void *) _outbuf.array(), _outbuf.size()));
  // Packing after a send resumes at the beginning of the buffer.
  _outpos = 0;
  _outbuf.resize (0);
}


int MPIMessenger::flush (void)
{
  Guard m1 (objectlock());

  _outboxlock.acquire ();
  if (!_outbox.empty ()) {
    OutPacket * packet = _outbox.front ();
    assert (packet);
    _outbox.pop_front ();
    _outboxlock.release ();

    int flag = 0;
    
#if !(THREAD_SAFE_MPI)
    Guard classguard (classlock());
#endif

#if DEBUG_PRINT
    printf ("[%d] Sending a '%s' message to %d (buffer size = %d).\n", self(), CODE_Machine::MessageTagString[packet->getTag()], packet->getDestination(), packet->getLength());
#endif

   
    int res = MPI_Bsend (packet->getBuffer(), packet->getLength(), MPI_PACKED, packet->getDestination(), packet->getTag(), MPI_COMM_WORLD);

    delete packet;
    
    if (res != MPI_SUCCESS) {
      char errstr[255];
      int len;
      MPI_Error_string (res, errstr, &len);
      cerr << "Send error: " << errstr << endl;
    }
    // printf ("[%d] Performed a send.\n", self());
    return 1;
  } else {
    _outboxlock.release ();
    return 0;
  }
}


int MPIMessenger::receive (void)
{
  // Use MPI defines here.
  int src = MPI_ANY_SOURCE;
  int tag = MPI_ANY_TAG;
  
  int flag;
  objectlock().acquire ();
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Check to see if any messages are waiting.
  MPI_Status status;
  int res = MPI_Iprobe (src, tag, MPI_COMM_WORLD, &flag, &status);
  int len;
  if (res != MPI_SUCCESS) {
    char errstr[255];
    MPI_Error_string (res, errstr, &len);
    cerr << "Receive error: " << errstr << endl;
  }

  if (flag) {
    MPI_Get_count (&status, MPI_PACKED, &len);
    _inbuf.resize (len);
    res = MPI_Recv ((void *) _inbuf.array(), _inbuf.size(), MPI_PACKED, src, tag, MPI_COMM_WORLD, &status);
    if (res != MPI_SUCCESS) {
      char errstr[255];
      MPI_Error_string (res, errstr, &len);
      cerr << "Receive error: " << errstr << endl;
    }
#if DEBUG_PRINT
    cout << "Saving packet from " << status.MPI_SOURCE << " with tag " << status.MPI_TAG << " and length " << _inbuf.size() << endl;
#endif
    _length = _inbuf.size();
    _sender = status.MPI_SOURCE;
    _tag = status.MPI_TAG;
    _inpos = 0;
    return 1;
  } else {
    objectlock().release ();
    return 0;
  }
}


void MPIMessenger::broadcast (const int firstProcessor,
			      const int lastProcessor,
			      const int tag)
{
  assert (firstProcessor >= 0);
  assert (lastProcessor < numProcessors ());
  assert (firstProcessor <= lastProcessor);
  assert (tag >= 0);

  Guard m1 (objectlock());
  int me = self();
  for (int dest = firstProcessor; dest <= lastProcessor; ++dest) {
    if (dest != me) {
      Guard m (_outboxlock);
      _outbox.push_back (new OutPacket (dest, tag, (void *) _outbuf.array(), _outbuf.size()));
    }
  }
  // Packing after a send resumes at the beginning of the buffer.
  _outpos = 0;
  _outbuf.resize (0);
}


// PVM-style packing and unpacking routines,
// but using MPI primitives.

// Append nitem ints starting at np to the outgoing buffer.
// Always returns 0.
int MPIMessenger::pkint (const int *np, int nitem)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Grow the buffer by sizeof(int) bytes per item.
  _outbuf.resize (_outpos + nitem * sizeof(int));

  // Pack one element per call; MPI_Pack advances _outpos.
  const int *item = np;
  for (int remaining = nitem; remaining > 0; --remaining, ++item) {
    MPI_Pack ((void *) item, 1, MPI_INT, (void *) _outbuf.array(), _outbuf.size(), &_outpos, MPI_COMM_WORLD);
  }
  return 0;
}


// Append nitem longs starting at np to the outgoing buffer.
// Always returns 0.
int MPIMessenger::pklong (const long *np, int nitem)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Grow the buffer by sizeof(long) bytes per item.
  _outbuf.resize (_outpos + nitem * sizeof(long));

  // Pack one element per call; MPI_Pack advances _outpos.
  const long *item = np;
  for (int remaining = nitem; remaining > 0; --remaining, ++item) {
    MPI_Pack ((void *) item, 1, MPI_LONG, (void *) _outbuf.array(), _outbuf.size(), &_outpos, MPI_COMM_WORLD);
  }
  return 0;
}


// Append nitem doubles starting at np to the outgoing buffer.
// Always returns 0.
int MPIMessenger::pkdouble (const double *np, int nitem)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Grow the buffer by sizeof(double) bytes per item.
  _outbuf.resize (_outpos + nitem * sizeof(double));

  // Pack one element per call; MPI_Pack advances _outpos.
  const double *item = np;
  for (int remaining = nitem; remaining > 0; --remaining, ++item) {
    MPI_Pack ((void *) item, 1, MPI_DOUBLE, (void *) _outbuf.array(), _outbuf.size(), &_outpos, MPI_COMM_WORLD);
  }
  return 0;
}


// Append nitem floats starting at np to the outgoing buffer.
// Always returns 0.
int MPIMessenger::pkfloat (const float *np, int nitem)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Grow the buffer by sizeof(float) bytes per item.
  _outbuf.resize (_outpos + nitem * sizeof(float));

  // Pack one element per call; MPI_Pack advances _outpos.
  const float *item = np;
  for (int remaining = nitem; remaining > 0; --remaining, ++item) {
    MPI_Pack ((void *) item, 1, MPI_FLOAT, (void *) _outbuf.array(), _outbuf.size(), &_outpos, MPI_COMM_WORLD);
  }
  return 0;
}


// Append nitem raw bytes starting at cp to the outgoing buffer.
// The bytes are copied with memcpy rather than MPI_Pack, so no MPI call
// is made here.  Always returns 0.
int MPIMessenger::pkbyte (const char *cp, int nitem)
{
  Guard m1 (objectlock());

  const size_t nbytes = nitem * sizeof(char);

  // Make room, copy the bytes in at the current position, then advance.
  _outbuf.resize (_outpos + nbytes);
  memcpy (&_outbuf[_outpos], cp, nbytes);
  _outpos += nbytes;

  return 0;
}


// Append the NUL-terminated string cp to the outgoing buffer as an int
// length followed by that many characters (the terminating NUL itself
// is not packed).  Always returns 0.
int MPIMessenger::pkstr (const char *cp)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  int length = strlen (cp);

  // Room is reserved for the length word plus length + 1 characters,
  // although only length characters are packed below.
  _outbuf.resize (_outpos + sizeof(int) + (length + 1) * sizeof(char));

  // First the length of the string...
  MPI_Pack (&length, 1, MPI_INT, (void *) _outbuf.array(), _outbuf.size(), &_outpos, MPI_COMM_WORLD);
  // ...then the characters themselves.
  MPI_Pack ((void *) cp, length, MPI_CHAR, (void *) _outbuf.array(), _outbuf.size(), &_outpos, MPI_COMM_WORLD);
  return 0;
}



// Unpacking routines.


// Extract nitems ints from the incoming buffer into np.
// Always returns 0.
int MPIMessenger::upkint (int *np, int nitems)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Unpack one element per call; MPI_Unpack advances _inpos.
  int *slot = np;
  for (int remaining = nitems; remaining > 0; --remaining, ++slot) {
    MPI_Unpack ((void *) _inbuf.array(),_inbuf.size(), &_inpos, (void *) slot, 1, MPI_INT, MPI_COMM_WORLD);
  }
  return 0;
}


// Extract nitems longs from the incoming buffer into np.
// Always returns 0.
int MPIMessenger::upklong (long *np, int nitems)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Unpack one element per call; MPI_Unpack advances _inpos.
  long *slot = np;
  for (int remaining = nitems; remaining > 0; --remaining, ++slot) {
    MPI_Unpack ((void *) _inbuf.array(),_inbuf.size(), &_inpos, (void *) slot, 1, MPI_LONG, MPI_COMM_WORLD);
  }
  return 0;
}


// Extract nitems doubles from the incoming buffer into np.
// Always returns 0.
int MPIMessenger::upkdouble (double *np, int nitems)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Unpack one element per call; MPI_Unpack advances _inpos.
  double *slot = np;
  for (int remaining = nitems; remaining > 0; --remaining, ++slot) {
    MPI_Unpack ((void *) _inbuf.array(),_inbuf.size(), &_inpos, (void *) slot, 1, MPI_DOUBLE, MPI_COMM_WORLD);
  }
  return 0;
}


// Extract nitems floats from the incoming buffer into np.
// Always returns 0.
int MPIMessenger::upkfloat (float *np, int nitems)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif

  // Unpack one element per call; MPI_Unpack advances _inpos.
  float *slot = np;
  for (int remaining = nitems; remaining > 0; --remaining, ++slot) {
    MPI_Unpack ((void *) _inbuf.array(),_inbuf.size(), &_inpos, (void *) slot, 1, MPI_FLOAT, MPI_COMM_WORLD);
  }
  return 0;
}


// Copy nitems raw bytes from the current position of the incoming
// buffer into cp and advance the position.  Uses memcpy rather than
// MPI_Unpack, so no MPI call is made here.  Always returns 0.
int MPIMessenger::upkbyte (char *cp, int nitems)
{
  Guard m1 (objectlock());

  const size_t nbytes = nitems * sizeof(char);

  memcpy (cp, &_inbuf[_inpos], nbytes);
  _inpos += nbytes;

  return 0;
}


// Extract a string from the incoming buffer into cp.
//
// The wire format is an int length followed by that many characters.
// cp must have room for the characters plus a terminating NUL, which is
// appended here.  Always returns 0.
int MPIMessenger::upkstr (char *cp)
{
  Guard m1 (objectlock());
#if !(THREAD_SAFE_MPI)
  Guard classguard (classlock());
#endif
  // Initialized so that a failed unpack yields an empty string instead
  // of indexing cp with an indeterminate length.
  int n = 0;

  MPI_Unpack ((void *) _inbuf.array(),_inbuf.size(), &_inpos, &n, 1, MPI_INT, MPI_COMM_WORLD);

  // A negative length would write before the start of cp below.
  assert (n >= 0);
  assert (n < 256); // FIX ME! - just to prevent major string screw-ups.

  MPI_Unpack ((void *) _inbuf.array(),_inbuf.size(), &_inpos, (void *) cp, n, MPI_CHAR, MPI_COMM_WORLD);
  cp[n] = '\0';		// Terminate the string with a NUL.
  return 0;
}
