#include <stdio.h>

#include "mpi.h"
#include "PLA.h"

/*
 * Driver: rank 0 reads the matrix dimensions m and n, they are broadcast to
 * all ranks, and PLAPACK is used to build a distributed m x n matrix A, a
 * length-n vector x and a length-m vector y.  A and x are filled by
 * A_x_fill(), y = A x is computed with PLA_Gemv, and y is passed to
 * process_vector().
 *
 * template_init(), A_x_fill() and process_vector() are defined elsewhere;
 * their behavior is not visible here.
 *
 * Returns 0 on success.  Returns 1 on every rank if rank 0 could not read two
 * integers or read a negative dimension; PLAPACK and MPI are still finalized
 * on that path.
 */
int main( int argc, char **argv )
{ 
  int            me, m, n;
  int            dims[ 2 ];          /* {m, n}; broadcast with one collective */
  PLA_Template   template;
  MPI_Comm       comm;
  PLA_Obj        a_global = NULL, x_global = NULL, y_global = NULL, 
                 one = NULL, zero = NULL;

  MPI_Init( &argc, &argv );                                /* initialize MPI */
  MPI_Comm_rank( MPI_COMM_WORLD, &me );          /* extract this node's rank */
                /* create a communicator with a suggested square 2D topology */
  PLA_Comm_1D_to_2D_ratio( MPI_COMM_WORLD, 1.0, &comm );
  PLA_Init( comm );                                 /*    initialize PLAPACK */
  template_init( &template );                         /* initialize template */
                  /* get global matrix dimensions and broadcast to all nodes */
  if ( 0 == me ) {
    printf("Enter matrix dimensions (mxn):\n");
    fflush( stdout );   /* stdout may be fully buffered under an MPI launcher */
    if ( scanf("%d%d", &dims[ 0 ], &dims[ 1 ]) != 2 ||
         dims[ 0 ] < 0 || dims[ 1 ] < 0 ) {
      fprintf( stderr, "Invalid matrix dimensions\n" );
      dims[ 0 ] = dims[ 1 ] = -1;      /* sentinel: tells every rank to quit */
    }
  }
  MPI_Bcast( dims, 2, MPI_INT, 0, MPI_COMM_WORLD );
  m = dims[ 0 ];
  n = dims[ 1 ];

  if ( m < 0 || n < 0 ) {
    /* every rank sees the same broadcast sentinel, so all exit together */
    PLA_Temp_free( &template );
    PLA_Finalize( );
    MPI_Finalize( );
    return 1;
  }
                        /* create global matrix A and global vectors x and y */
  PLA_Matrix_create( MPI_DOUBLE, m, n, template, 0, 0, &a_global );
  PLA_Vector_create( MPI_DOUBLE, n,    template, 0,    &x_global );
  PLA_Vector_create( MPI_DOUBLE, m,    template, 0,    &y_global );

  A_x_fill( a_global, x_global );              /* fill matrix A and vector x */
                                            /* create constants zero and one */
  PLA_Mscalar_create( MPI_DOUBLE, PLA_ALL_ROWS, PLA_ALL_COLS, 1, 1, 
                                                          template, &zero );
  PLA_Obj_set_to_zero( zero );
  PLA_Mscalar_create( MPI_DOUBLE, PLA_ALL_ROWS, PLA_ALL_COLS, 1, 1, 
                                                          template, &one );
  PLA_Obj_set_to_one ( one );
                                  /* perform matrix-vector multiply  y = A x */
  PLA_Gemv( PLA_NOTRANS, one, a_global, x_global, zero, y_global );

  process_vector( y_global );                  /* do something with result y */
                                                         /* free the objects */
  PLA_Obj_free( &a_global );     PLA_Obj_free( &x_global );    
  PLA_Obj_free( &y_global );
  PLA_Obj_free( &one );          PLA_Obj_free( &zero );
          
  PLA_Temp_free( &template );                           /* free the template */
  PLA_Finalize( );                                       /* finalize PLAPACK */
  MPI_Finalize( );                                           /* finalize MPI */
  return 0;
}


