#include <string.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <time.h>
#include "RLPlayer.h"

RLPlayer::RLPlayer( ActHandler* act, WorldModel *wm, ServerSettings *ss,
		    PlayerSettings *ps,
		    char* strTeamName, char *loadWeightsFile, char *saveWeightsFile, 
		    bool bLearn, int iNumKeepers, int iNumTakers, 
		    double dVersion, int iReconnect ):
  KeepawayPlayer( act, wm, ss, ps, strTeamName, 
		  iNumKeepers, iNumTakers, dVersion, iReconnect )
{
  // A NULL or empty file name means "no file" for both load and save.
  bool bLoadWeights = ( loadWeightsFile != NULL && loadWeightsFile[ 0 ] != '\0' );

  if ( saveWeightsFile != NULL && saveWeightsFile[ 0 ] != '\0' ) {
    // NOTE(review): assumes weightsFile is a buffer large enough for the
    // path -- confirm its declaration in RLPlayer.h.
    strcpy( weightsFile, saveWeightsFile );
    bSaveWeights = true;
  }
  else {
    bSaveWeights = false;
  }
  bLearning = bLearn;

  epochNum = 0;

  // Learning parameters: step size, discount, trace decay, exploration rate.
  alpha = 0.125;
  gamma = 1.0;
  lambda = 0;
  epsilon = 0.01;

  numActions = WM->getNumKeepers();

  minimumTrace = 0.01;

  initializeTileWidths();

  numNonzeroTraces = 0;
  for ( int i = 0; i < RL_MEMORY_SIZE; i++ ) {
    weights[ i ] = 0;
    traces[ i ] = 0;
  }

  // rand() is seeded with a fixed value around the dummy GetTiles call,
  // presumably so the tile coder's hashing setup is reproducible -- confirm.
  srand( (unsigned int) 0 );
  int tmp[ 2 ];
  float tmpf[ 2 ];
  colTab = new collision_table( RL_MEMORY_SIZE, 1 );
  GetTiles( tmp, 1, 1, tmpf, 0 );  // A dummy call to set the hashing table    
  srand( time( NULL ) );
  // selectAction() draws its exploration decision from drand48(), which has
  // its own generator state independent of srand(); seed it too so exploration
  // is not the same fixed sequence in every run.
  srand48( (long) time( NULL ) );

  // Load saved weights only now: loadWeights() restores colTab, which must
  // already exist, and the load must come after the zeroing loop above or it
  // would be overwritten.
  if ( bLoadWeights )
    loadWeights( loadWeightsFile );
}

// Begin a new episode from `state` (numFeatures values) and return the first
// action to take. Clears all eligibility traces, computes Q for every action,
// picks an action with selectAction(), and sets that action's tile traces to 1.
int RLPlayer::SMDP_startEpisode( double state[], int numFeatures )
{
  ++epochNum;

  // Decaying by 0 drives every trace below minimumTrace, which empties the
  // nonzero-trace list before the new episode starts.
  decayTraces( 0 );

  loadTiles( state, numFeatures );
  for ( int act = 0; act < numActions; ++act )
    Q[ act ] = computeQ( act );

  const int chosen = selectAction();

  // Mark every tile active for the chosen action as fully eligible.
  for ( int t = 0; t < numTilings; ++t )
    setTrace( tiles[ chosen ][ t ], 1.0 );

  return chosen;
}

// Process one SMDP transition: `reward` was received for the previous action
// (m_lastAction), and `state` (numFeatures values) is the new state.
// Returns the next action. When learning, performs a Sarsa(lambda) weight
// update and refreshes the eligibility traces.
int RLPlayer::SMDP_step( double reward, double state[], int numFeatures )
{
  // TD error, first part: r - Q(s, a) for the previous action. This uses Q
  // values from the previous state, read before loadTiles/computeQ overwrite them.
  double delta = reward - Q[ m_lastAction ];
  loadTiles( state, numFeatures );
  for ( int a = 0; a < numActions; a++ ) {
    Q[ a ] = computeQ( a );
  }

  int action = selectAction();

  if ( !bLearning )
    return action;

  char buffer[128];
  // snprintf: "%.2f" of a very large-magnitude reward can exceed 128 chars.
  snprintf( buffer, sizeof( buffer ), "reward: %.2f", reward );
  LogDraw.logText( "reward", VecPosition( 25, 30 ),
		   buffer,
		   1, COLOR_NAVY );

  // TD error, second part: + Q(s', a'). gamma is not applied here; the update
  // assumes gamma == 1.
  delta += Q[ action ];
  updateWeights( delta );
  Q[ action ] = computeQ( action ); // need to redo because weights changed
  decayTraces( gamma * lambda );

  // Replacing traces: clear the traces of tiles belonging to other actions...
  for ( int a = 0; a < numActions; a++ ) {
    if ( a != action ) {
      for ( int j = 0; j < numTilings; j++ )
        clearTrace( tiles[ a ][ j ] );
    }
  }
  // ...and set the chosen action's tiles to full eligibility.
  for ( int j = 0; j < numTilings; j++ )
    setTrace( tiles[ action ][ j ], 1.0 );

  return action;
}

// Finish the current episode with a terminal `reward`. When learning and an
// action was taken this episode, applies the terminal TD update. Occasionally
// writes the weights to weightsFile.
void RLPlayer::SMDP_endEpisode( double reward )
{
  if ( bLearning && m_lastAction != -1 ) { /* otherwise we never ran on this episode */
    char buffer[128];
    // snprintf: "%.2f" of a very large-magnitude reward can exceed 128 chars.
    snprintf( buffer, sizeof( buffer ), "reward: %.2f", reward );
    LogDraw.logText( "reward", VecPosition( 25, 30 ),
		     buffer,
		     1, COLOR_NAVY );

    /* Terminal update: there is no successor state, so the target is just the
       reward. This assumes gamma == 1; other values are only reported. */
    if ( gamma != 1.0)
      cerr << "We're assuming gamma's 1" << endl;
    double delta = reward - Q[ m_lastAction ];
    updateWeights( delta );
  }
  // Checkpoint on roughly 1 call in 200 (chosen via rand()).
  if ( bLearning && bSaveWeights && rand() % 200 == 0 ) {
    saveWeights( weightsFile );
  }
}

// Epsilon-greedy action choice over the current Q values.
// While learning, with probability epsilon (drawn via drand48) a uniformly
// random action is returned; otherwise the greedy action from argmaxQ().
// When not learning, drand48 is not consulted and the choice is always greedy.
int RLPlayer::selectAction()
{
  const bool explore = bLearning && drand48() < epsilon;
  return explore ? rand() % numActions : argmaxQ();
}

void RLPlayer::initializeTileWidths()
{
  int numK = WM->getNumKeepers();
  int numT = WM->getNumTakers();
  int j = 0;

  tileWidths[ j++ ] = 2.0; // WB_dist_to_center                      
  for ( int i = 1; i < numK; i++ )       // WB_dist_to_T          
    tileWidths[ j++ ] = 2.0 + ( i - 1 ) / ( numK - 2 );
  for ( int i = 1; i <= numT; i++ )   // WB_dist_to_O
    tileWidths[ j++ ] = 3.0 + ( i - 1 ) / ( numT - 1 );
  for ( int i = 1; i < numK; i++ )       // dist_to_center_T    
    tileWidths[ j++ ] = 2.0 + ( i - 1 ) / ( numK - 2 );
  for ( int i = 1; i <= numT; i++ )   // dist_to_center_O  
    tileWidths[ j++ ] = 3.0;
  for ( int i = 1; i < numK; i++ )       // nearest_Opp_dist_T 
    tileWidths[ j++ ] = 4.0;
  for ( int i = 1; i < numK; i++ )       // nearest_Opp_ang_T  
    tileWidths[ j++ ] = 10.0;
}

bool RLPlayer::loadWeights( char *filename )
{
  cout << "Loading weights from " << filename << endl;
  int file = open( filename, O_RDONLY );
  read( file, (char *) weights, RL_MEMORY_SIZE * sizeof(double) );
  colTab->restore( file );
  close( file );
  cout << "...done" << endl;
  return true;
}

bool RLPlayer::saveWeights( char *filename )
{
  int file = open( filename, O_CREAT | O_WRONLY, 0664 );
  write( file, (char *) weights, RL_MEMORY_SIZE * sizeof(double) );
  colTab->save( file );
  close( file );
  return true;
}

// Value of action `a` under linear tile coding: the sum of the weights of the
// numTilings tiles currently loaded for that action in tiles[ a ].
double RLPlayer::computeQ( int a )
{
  double total = 0;
  int j = 0;
  while ( j < numTilings )
    total += weights[ tiles[ a ][ j++ ] ];
  return total;
}

// Returns the index (action) of the largest entry in Q[0..numActions-1],
// breaking ties uniformly at random (reservoir-style, one rand() call per tie).
int RLPlayer::argmaxQ()
{
  int bestAction = 0;
  double bestValue = Q[ bestAction ];
  int numTies = 0;  // how many later actions have tied the current bestValue
  for ( int a = bestAction + 1; a < numActions; a++ ) {
    double value = Q[ a ];
    if ( value > bestValue ) {
      bestValue = value;
      bestAction = a;
      numTies = 0;  // ties with the old, lower maximum must not bias the draw
    }
    else if ( value == bestValue ) {
      numTies++;
      // This is the (numTies + 1)-th action with bestValue; keep it with
      // probability 1 / (numTies + 1) so every tied action is equally likely.
      if ( rand() % ( numTies + 1 ) == 0 )
        bestAction = a;
    }
  }

  return bestAction;
}

// Apply the TD error `delta` to every feature with a nonzero trace:
// weights[ f ] += (alpha / numTilings) * delta * traces[ f ].
// Out-of-range feature indices are reported and skipped, not written.
void RLPlayer::updateWeights( double delta )
{
  if ( numTilings <= 0 )
    return;  // nothing is tiled; avoid dividing by zero and poisoning weights

  double tmp = delta * alpha / numTilings;
  for ( int i = 0; i < numNonzeroTraces; i++ ) {
    int f = nonzeroTraces[ i ];
    if ( f >= RL_MEMORY_SIZE || f < 0 ) {  // valid indices are [0, RL_MEMORY_SIZE)
      cerr << "f is too big or too small!!" << f << endl;
      continue;
    }
    weights[ f ] += tmp * traces[ f ];
  }
}

// Compute the active tiles for `state` (numFeatures values) for every action.
// Each feature gets its own group of tilingsPerGroup tilings, scaled by
// tileWidths[ v ]. numTilings is set to the total number of tiles per action.
void RLPlayer::loadTiles( double state[], int numFeatures )
{
  int tilingsPerGroup = 32;  /* num tilings per tiling group */
  numTilings = 0;

  /* These are the 'tiling groups'  --  play here with representations */
  /* One tiling group for each state variable */
  for ( int v = 0; v < numFeatures; v++ ) {
    // Check capacity before writing, so tiles[ a ] never receives more than
    // RL_MAX_NUM_TILINGS entries (assumed to be its row length -- confirm in
    // RLPlayer.h). Remaining features are dropped.
    if ( numTilings + tilingsPerGroup > RL_MAX_NUM_TILINGS ) {
      cerr << "TOO MANY TILINGS! " << numTilings + tilingsPerGroup << endl;
      break;
    }
    for ( int a = 0; a < numActions; a++ ) {
      // a and v are passed as extra hash inputs, presumably so each
      // (action, feature) pair maps to distinct tiles.
      GetTiles1( &(tiles[ a ][ numTilings ]), tilingsPerGroup, colTab,
		 state[ v ] / tileWidths[ v ], a , v );
    }
    numTilings += tilingsPerGroup;
  }
}


// Clear any trace for feature f. Does nothing if f's trace is already zero;
// an out-of-range f is reported and ignored.
void RLPlayer::clearTrace( int f)
{
  if ( f >= RL_MEMORY_SIZE || f < 0 ) {  // valid indices are [0, RL_MEMORY_SIZE)
    cerr << "ClearTrace: f out of range " << f << endl;
    return;
  }
  if ( traces[ f ] != 0 )
    clearExistentTrace( f, nonzeroTracesInverse[ f ] );
}

// Clear the trace for feature f, stored at position `loc` in nonzeroTraces.
// Removal is O(1): the last list entry is moved into slot `loc` and its
// inverse index is updated. Out-of-range feature indices are reported and
// never used to index traces[] or nonzeroTracesInverse[].
void RLPlayer::clearExistentTrace( int f, int loc )
{
  if ( f >= RL_MEMORY_SIZE || f < 0 )  // valid indices are [0, RL_MEMORY_SIZE)
    cerr << "ClearExistentTrace: f out of range " << f << endl;
  else
    traces[ f ] = 0.0;
  numNonzeroTraces--;
  int moved = nonzeroTraces[ numNonzeroTraces ];
  nonzeroTraces[ loc ] = moved;
  if ( moved >= 0 && moved < RL_MEMORY_SIZE )
    nonzeroTracesInverse[ moved ] = loc;
}

// Multiply every nonzero trace by decayRate, and remove those that fall below
// minimumTrace. Out-of-range feature indices are reported and skipped.
void RLPlayer::decayTraces( double decayRate )
{
  // Iterate downwards: clearExistentTrace moves the last list entry into slot
  // `loc`, and with this order that entry has already been processed.
  for ( int loc = numNonzeroTraces - 1; loc >= 0; loc-- ) {
    int f = nonzeroTraces[ loc ];
    if ( f >= RL_MEMORY_SIZE || f < 0 ) {  // valid indices are [0, RL_MEMORY_SIZE)
      cerr << "DecayTraces: f out of range " << f << endl;
      continue;
    }
    traces[ f ] *= decayRate;
    if ( traces[ f ] < minimumTrace )
      clearExistentTrace( f, loc );
  }
}

// Set the trace for feature f to newTraceValue, which must be positive.
// A feature not already tracked is appended to the nonzero-trace list, raising
// minimumTrace first if the list is full. An out-of-range f is reported and
// ignored.
void RLPlayer::setTrace( int f, float newTraceValue )
{
  if ( f >= RL_MEMORY_SIZE || f < 0 ) {  // valid indices are [0, RL_MEMORY_SIZE)
    cerr << "SetTraces: f out of range " << f << endl;
    return;
  }
  if ( traces[ f ] >= minimumTrace )
    traces[ f ] = newTraceValue;         // trace already exists              
  else {
    while ( numNonzeroTraces >= RL_MAX_NONZERO_TRACES )
      increaseMinTrace(); // ensure room for new trace              
    traces[ f ] = newTraceValue;
    nonzeroTraces[ numNonzeroTraces ] = f;
    nonzeroTracesInverse[ f ] = numNonzeroTraces;
    numNonzeroTraces++;
  }
}

// Free space in the nonzero-trace list by raising minimumTrace by 10% (and
// logging the new value to cerr), then dropping every trace now below it.
void RLPlayer::increaseMinTrace()
{
  minimumTrace = minimumTrace * 1.1;
  cerr << "Changing minimum_trace to " << minimumTrace << endl;

  // Walk from the end of the list to the front: each removal moves the last
  // entry into the removed slot, and that entry has already been checked.
  int loc = numNonzeroTraces;
  while ( loc-- > 0 ) {
    const int f = nonzeroTraces[ loc ];
    if ( traces[ f ] < minimumTrace )
      clearExistentTrace( f, loc );
  }
}
