// SOM class : a very simple SOM network
// by Yoonsuck Choe <yschoe@cs.utexas.edu>
// Fri Apr  4 05:01:50 CST 1997

import java.applet.*;
import java.awt.*;

/**
 * A very simple self-organizing map (SOM): a {@code size x size} grid of
 * {@code Neuron} units trained on randomly generated input vectors.
 *
 * <p>Each training step generates one input vector, computes every unit's
 * activity for it, selects the unit with the smallest activity as the winner,
 * and applies the delta rule to every unit inside a square window centred on
 * the winner. The learning rate and the window radius are reduced according
 * to the public schedule arrays as {@code time} advances.
 */
public class SOM {
  public  Neuron map[][]; // size x size grid of units; allocated by the constructor
  public  int size;       // edge length of the map (the map is size x size)
  public  int input_dimension;  // number of components in each input vector
  public  int time = 0;         // number of training steps performed so far
  public  int end_time = 30000; // not read anywhere inside this class
  public  int neighborhood_radius = 8; // half-width of the square update window
  // Radius adopted when `time` reaches the matching entry of reduce_phase.
  public  int neighborhood_schedule[] = { 5, 4, 3, 2, 1 };
  private int phase_count = 0;  // index of the next schedule entry to apply
  // Values of `time` at which the next schedule entry takes effect.
  public  int reduce_phase[] = { 200, 500, 1000, 2000, 3000 };
  public  float update_rate = (float) 0.1; // learning rate passed to delta_rule
  // Learning rate adopted when `time` reaches the matching entry of reduce_phase.
  public  float update_rate_schedule[] =
      { (float) 0.1, (float) 0.01, (float) 0.005,
        (float) 0.001, (float) 0.0007 };
  public  float input[];        // most recently generated input vector

  // Grid coordinates of the most recent winner (unit with the smallest activity).
  private int min_i, min_j;
  // Smallest / largest activity seen in the last feed-forward pass; used by
  // display_map to normalise activities before looking up a colour.
  private float min=99999, max=-99999;
  // Colour table used to shade unit activity in display_map.
  private Colormap cmap = new Colormap("0ry1",128);

  /**
   * Builds a {@code size x size} map whose units take {@code n}-dimensional input.
   *
   * @param n    input dimension
   * @param size edge length of the square map
   */
  public SOM(int n, int size) {
    this.input_dimension = n;
    this.size = size;

    // Allocate the grid and let every unit initialise its afferent weights.
    map = new Neuron[size][size];
    for (int i = 0; i < size; ++i) {
      for (int j = 0; j < size; ++j) {
        map[i][j] = new Neuron();
        map[i][j].setup_afferents(input_dimension, size, i, j);
      }
    }

    input = new float[n];
  }

  /**
   * Fills {@code input} with values drawn uniformly from {@code [0, size)}.
   * Scaling by the map size makes the weights easier to visualise.
   */
  private void generate_input() {
    for (int i = 0; i < input_dimension; ++i) {
      input[i] = (float) Math.random() * size;
    }
  }

  /**
   * Generates a fresh input vector, evaluates every unit against it and
   * records the smallest and largest resulting activity in {@code min}/{@code max}.
   */
  private void feed_forward() {
    min = 99999;
    max = -99999;
    this.generate_input();

    for (int i = 0; i < size; ++i) {
      for (int j = 0; j < size; ++j) {
        // Return value is ignored; the unit is expected to store the result
        // in its `activity` field, which is what is read below.
        map[i][j].euclidian_distance(input);
        if (map[i][j].activity > max) max = map[i][j].activity;
        if (map[i][j].activity < min) min = map[i][j].activity;
      }
    }
  }

  /**
   * Finds the winner (unit with the smallest activity) and applies the delta
   * rule to every unit within {@code neighborhood_radius} of it, in a square
   * window clipped to the map boundary.
   */
  private void update_weight() {
    // Find the winner. `min` is a field on purpose: display_map reads it.
    // If no activity is below the sentinel, the previous winner is kept.
    min = (float) 9999999;
    for (int i = 0; i < size; ++i) {
      for (int j = 0; j < size; ++j) {
        if (map[i][j].activity < min) {
          min_i = i;
          min_j = j;
          min = map[i][j].activity;
        }
      }
    }

    // Square update window around the winner, clipped to [0, size-1].
    int row_lo = Math.max(0, min_i - neighborhood_radius);
    int row_hi = Math.min(size - 1, min_i + neighborhood_radius);
    int col_lo = Math.max(0, min_j - neighborhood_radius);
    int col_hi = Math.min(size - 1, min_j + neighborhood_radius);

    for (int i = row_lo; i <= row_hi; ++i) {
      for (int j = col_lo; j <= col_hi; ++j) {
        map[i][j].delta_rule(update_rate, input);
      }
    }
  }

  /**
   * Number of schedule entries that can safely be applied: the length of the
   * shortest of the three public schedule arrays, so that a caller replacing
   * one of them can never cause an out-of-bounds access in train().
   */
  private int phase_limit() {
    return Math.min(reduce_phase.length,
                    Math.min(update_rate_schedule.length,
                             neighborhood_schedule.length));
  }

  /**
   * Runs {@code steps} training iterations. After each iteration {@code time}
   * is incremented; when it equals the next entry of {@code reduce_phase},
   * the learning rate and neighbourhood radius switch to the corresponding
   * schedule values.
   *
   * @param steps number of training iterations to run
   */
  public void train(int steps) {
    for (int k = 0; k < steps; ++k) {
      this.feed_forward();
      this.update_weight();
      time++;
      if (phase_count < phase_limit() && time == reduce_phase[phase_count]) {
        update_rate = update_rate_schedule[phase_count];
        neighborhood_radius = neighborhood_schedule[phase_count];
        phase_count++;
      }
    }
  }

  /**
   * Colour-table index for unit (i, j): its activity is normalised against
   * {@code min}/{@code max}, inverted so that low activity maps to a high
   * index, and scaled by 0.8 * 127. The result is clamped to the table bounds
   * so that an activity outside [min, max] cannot index past the table.
   */
  private int activity_color_index(int i, int j) {
    int index = (int)((1.0-(map[i][j].activity-min)/(max-min)) * 0.8 * 127);
    return Math.max(0, Math.min(cmap.table.length - 1, index));
  }

  /** Screen x of unit (i, j): first afferent weight, mirrored and scaled to width. */
  private int node_x(int i, int j, int width) {
    return (int)((size-map[i][j].afferent[0])*width/size);
  }

  /** Screen y of unit (i, j): second afferent weight, mirrored and scaled to height. */
  private int node_y(int i, int j, int height) {
    return (int)((size-map[i][j].afferent[1])*height/size);
  }

  /**
   * Draws the current state of the map on {@code g}: a status line, a grid of
   * 10x10 coloured cells showing unit activity, a 4x4 marker for every unit
   * at the position given by its first two afferent weights, and lines
   * joining each unit to its grid neighbours.
   *
   * <p>Only afferent components 0 and 1 are drawn, so units are expected to
   * have at least two afferent weights.
   *
   * @param g      graphics context to draw on
   * @param width  width of the area the weight space is scaled to
   * @param height height of the area the weight space is scaled to
   */
  public void display_map(Graphics g, int width, int height) {
    g.drawString("iteration = " + time + " radius = " + neighborhood_radius, 50, 50);

    // Activity cells and node markers.
    for (int i = 0; i < size; ++i) {
      for (int j = 0; j < size; ++j) {
        g.setColor(cmap.table[activity_color_index(i, j)]);
        g.fillRect(i*10, 50+j*10, 10, 10);

        g.setColor(Color.black);
        g.drawRect(node_x(i, j, width), node_y(i, j, height), 4, 4);
      }
    }

    // Edges to the next unit along each grid axis. Units on the last row or
    // column have no such neighbour, so no edge is drawn for them.
    for (int i = 0; i < size; ++i) {
      for (int j = 0; j < size; ++j) {
        int x = node_x(i, j, width);
        int y = node_y(i, j, height);
        if (i + 1 < size) {
          g.drawLine(x, y, node_x(i + 1, j, width), node_y(i + 1, j, height));
        }
        if (j + 1 < size) {
          g.drawLine(x, y, node_x(i, j + 1, width), node_y(i, j + 1, height));
        }
      }
    }
  }
}


