#include "PLA.h"

/*
 * PLA_Chol_blk_var3_opt4
 *
 * Blocked Cholesky factorization, lower-triangular form (A = L * L^T), done
 * in place: the lower triangle of A is overwritten with L.  The strictly
 * upper parts (ATR / A01, A02, A12) are never written by this routine.
 * Preconditions such as A being square and symmetric positive definite are
 * assumed, not checked here -- TODO confirm with callers.
 *
 * Each iteration of the loop:
 *   1. factors the b x b diagonal block A11 (PLA_Local_chol, lower),
 *   2. updates the panel A21 with a triangular solve against A11^T applied
 *      from the right (assuming BLAS-style trsm semantics: A21 := A21 * L11^-T),
 *   3. updates the lower triangle of the trailing matrix,
 *      A22 := A22 - A21 * A21^T (PLA_Syrk_perform_local_part with
 *      alpha = MINUS_ONE, beta = ONE).
 *
 * "opt4" distribution strategy, as visible below:
 *   - A11 is replicated on every node (PLA_ALL_ROWS, PLA_ALL_COLS) and factored
 *     redundantly on each node.
 *   - A21 is redistributed as a multivector for the local trsm.
 *   - A21 is then duplicated into column- and row-projected multivectors so the
 *     A22 update needs only node-local work.
 *
 * Parameters:
 *   A       - matrix to factor; overwritten in place.
 *   nb_alg  - algorithmic block size; the final block is
 *             min( remaining length, nb_alg ).
 *             NOTE(review): nb_alg is not validated.  If nb_alg <= 0 then
 *             b <= 0, ATL presumably never grows, and the loop would not
 *             terminate.  Consider rejecting non-positive nb_alg.
 *
 * Returns:
 *   PLA_SUCCESS, unconditionally.  NOTE(review): return codes of the PLA_*
 *   calls below are ignored, so failures are not propagated to the caller.
 */
int PLA_Chol_blk_var3_opt4( PLA_Obj A, int nb_alg )
{
  /* 2x2 partitioning of A (ATL = already-factored part) and the 3x3
     repartitioning exposing the current b x b block A11. */
  PLA_Obj ATL=NULL,   ATR=NULL,      A00=NULL, A01=NULL, A02=NULL, 
          ABL=NULL,   ABR=NULL,      A10=NULL, A11=NULL, A12=NULL,
                                     A20=NULL, A21=NULL, A22=NULL;

  /* NOTE(review): ZERO is created and freed but never otherwise used here. */
  PLA_Obj MINUS_ONE=NULL, ZERO=NULL, ONE=NULL;

  /* Per-iteration temporaries holding redistributed copies of A11 and A21;
     each is created and freed inside the loop body. */
  PLA_Obj A11_dmsc=NULL, A21_mv=NULL, A21_dpmv=NULL, A21t_dpmv=NULL;

  /* Current block size: min( rows remaining in ABR, nb_alg ). */
  int b;

  PLA_Create_constants_conf_to( A, &MINUS_ONE, &ZERO, &ONE );

  /* Start with an empty ATL (0 x 0, anchored at the top-left). */
  PLA_Part_2x2( A,    &ATL, &ATR,
                      &ABL, &ABR,     0, 0, PLA_TL );

  /* Sweep down the diagonal until ATL covers all of A. */
  while ( PLA_Obj_length( ATL ) < PLA_Obj_length( A ) ){

    /* Block according to the algorithmic block size */
    b = min( PLA_Obj_length( ABR ), nb_alg );

    PLA_Repart_2x2_to_3x3( ATL, /**/ ATR,       &A00, /**/ &A01, &A02,
                        /* ************* */   /* ******************** */
                                                &A10, /**/ &A11, &A12,
                           ABL, /**/ ABR,       &A20, /**/ &A21, &A22,
                           b, b, PLA_BR );

    /*------------------------------------------------------------*/

    /* This time A11 is copied to ALL nodes and the factorization of A11 
       proceeds independently on all nodes */
    PLA_Mscalar_create_conf_to( A11, PLA_ALL_ROWS, PLA_ALL_COLS, &A11_dmsc );
    PLA_Copy( A11, A11_dmsc );
    PLA_Local_chol( PLA_LOWER_TRIANGULAR, A11_dmsc );
    /* Copy A11_dmsc back to A11.  Notice that this involves no communication, 
       since every node has a complete copy of A11_dmsc */
    PLA_Copy( A11_dmsc, A11 );

    /* Redistribute A21 as a multivector */
    PLA_Mvector_create_conf_to( A21, b, &A21_mv );
    PLA_Copy( A21, A21_mv );
    /* This means that all nodes own part of A21 and can update that part 
       independently.  The replicated, already-factored A11_dmsc is the
       triangular operand, so no further communication of A11 is needed. */
    PLA_Local_trsm( PLA_RIGHT, PLA_LOWER_TRIANGULAR, 
                    PLA_TRANSPOSE, PLA_NONUNIT_DIAG,
                    ONE, A11_dmsc, A21_mv );
    /* The replicated A11 is no longer needed once A21 has been solved. */
    PLA_Obj_free( &A11_dmsc );

    /* Here things get very tricky: The copying into A21_dpmv and A21t_dpmv 
       makes it so that all data in A21 is distributed so that
       A22 - A21 * A21' can proceed completely locally on each node.
       A21_dpmv is conformal to A22 projected onto its columns, and
       A21t_dpmv is conformal to A22 projected onto its rows. */
    PLA_Pmvector_create_conf_to( A22, PLA_PROJ_ONTO_COLS, PLA_ALL_COLS,
                                 b, &A21_dpmv );
    PLA_Pmvector_create_conf_to( A22, PLA_PROJ_ONTO_ROWS, PLA_ALL_ROWS,
                                 b, &A21t_dpmv );
    /* Both projected copies are filled from the solved multivector;
       it is freed afterwards. */
    PLA_Copy( A21_mv, A21_dpmv );
    PLA_Copy( A21_mv, A21t_dpmv );
    PLA_Obj_free( &A21_mv );
    /* It turns out that A21_dpmv can be copied into A21 without 
       requiring communication, making it better to copy A21_dpmv
       into A21 than to copy A21_mv into A21 */
    PLA_Copy( A21_dpmv, A21 );

    /* Lower triangle of A22 := ONE * A22 + MINUS_ONE * A21 * A21^T,
       computed from the two projected copies with node-local work only. */
    PLA_Syrk_perform_local_part( PLA_LOWER_TRIANGULAR,
                                 MINUS_ONE, A21_dpmv, A21t_dpmv, ONE, A22 );

    PLA_Obj_free( &A21_dpmv );
    PLA_Obj_free( &A21t_dpmv );

    /*------------------------------------------------------------*/

    /* Move the just-factored block into ATL and advance. */
    PLA_Cont_with_3x3_to_2x2( &ATL, /**/ &ATR,       A00, A01, /**/ A02,
                                                     A10, A11, /**/ A12,
                            /* ************** */  /* ****************** */
                              &ABL, /**/ &ABR,       A20, A21, /**/ A22,
                              PLA_TL );

  }

  /* Release all partition views and the constants created above. */
  PLA_Obj_free( &ATL ); PLA_Obj_free( &ATR );
  PLA_Obj_free( &ABL ); PLA_Obj_free( &ABR );
  PLA_Obj_free( &A00 ); PLA_Obj_free( &A01 ); PLA_Obj_free( &A02 );
  PLA_Obj_free( &A10 ); PLA_Obj_free( &A11 ); PLA_Obj_free( &A12 );
  PLA_Obj_free( &A20 ); PLA_Obj_free( &A21 ); PLA_Obj_free( &A22 );

  PLA_Obj_free( &MINUS_ONE ); PLA_Obj_free( &ZERO ); PLA_Obj_free( &ONE );

  return PLA_SUCCESS;
}