// This is the MEX wrapper for kmeans. 
// 
// Invocation form within Matlab:
// [idx centers maps] = hkmeans(data, k, maxiter, ratio)
// 
// input arguments: 
// 		data: the d by n sparse matrix. 
// 		k	: number of clusters. 
// 		maxiter: number of iterations. 
// 		ratio: stop when #samples < n*ratio
//
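// output arguments: 
// 		idx: n by 1 cluster membership, 1-based leaf index in 1..2^k
// 		centers: (d+1) by (2^k-1) hyperplanes, one column per tree node
// 		         (d weights followed by the threshold)
//       maps: 2^k dim array; for each leaf, the number of k-means splits
//             performed on the path from the root to that leaf. 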

#include "mex.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <vector>

#define MAXDIS 1000000000000000.0
using namespace std;

// Print the usage text for the MEX entry point to the MATLAB console.
// The text lists all four required inputs (mexFunction rejects any other
// argument count) and the three outputs with their actual sizes.
void exit_with_help()
{
	mexPrintf(
 "[idx hyperplanes maps] = hkmeans(data, k, maxiter, ratio)\n"
 "input arguments: \n"
 "      data   : the d by n sparse matrix. \n"
 "      k      : number of clusters. \n"
 "      maxiter: number of iterations. \n"
 "      ratio  : stop when #samples < n*ratio. \n"
 "output arguments: \n"
 "      idx: cluster membership (n)\n"
 "      hyperplanes: centers (d+1)*(2^k-1)\n"
 "      maps: 2^k dim array. \n"
	);
}

// Fill the first two output slots with empty (0 x 0) real double matrices,
// used when the call cannot be serviced.
static void fake_answer(mxArray *plhs[])
{
	for ( int slot = 0 ; slot < 2 ; ++slot )
		plhs[slot] = mxCreateDoubleMatrix(0, 0, mxREAL);
}


void kmeans(int subn, int d, mwIndex *ir, mwIndex *jc,  double *values, int *idlist, int k, int *subidx, double *centers, int maxiter)
{
	double tol = 1e-4;
	vector<double> xnorm(subn, 0);
	for ( int ii=0 ; ii<subn ; ii++ )
	{
		int i = idlist[ii];
		for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
			xnorm[ii] += values[ptr]*values[ptr];
	}

	for ( int i=0 ; i<subn ; i++ )
		subidx[i] = rand()%k;

	for ( int iter = 0 ; iter<maxiter ; iter++ )
	{
		double change = 0;
		// Compute centers
		for ( int i=0 ; i<k*d ; i++ )
			centers[i] = 0;
		vector<double> count(k, 0);
		for ( int ii=0 ; ii<subn ; ii++ )
		{
			double *nowcenter = &(centers[subidx[ii]*d]);
			int i = idlist[ii];
			for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
				nowcenter[ir[ptr]] += values[ptr];
			count[subidx[ii]] += 1;
		}
		for ( int i=0 ; i<k ; i++ )
		{
			double *nowcenter = &(centers[i*d]);
			for ( int j=0 ; j<d ; j++)
				nowcenter[j]/=count[i];
		}

		// Compute idx
		vector<double> center_norm(k,0);
		for ( int i=0 ; i<k ; i++ )
		{
			double *nowcenter = &(centers[i*d]);
			for (int j=0 ; j<d ; j++ )
				center_norm[i] += nowcenter[j]*nowcenter[j];
		} 
		double loss = 0;
		for ( int ii=0 ; ii<subn ; ii++ )
		{
			int i = idlist[ii];
			double dis = MAXDIS;
			int minidx = -1;
			for ( int j=0 ; j<k ; j++ )
			{
				double *nowcenter = &(centers[j*d]);
				double nowdis = 0;
				for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
					nowdis += values[ptr]*nowcenter[ir[ptr]];
				nowdis = xnorm[ii] - 2*nowdis + center_norm[j];
				if ( nowdis < dis)
				{
					dis = nowdis;
					minidx = j;
				}
			}
			if ( subidx[ii] != minidx )
			{
				subidx[ii] = minidx;
				change += 1;
			}
			loss += dis;
		}

		vector<int> center_count(k,0);
		for ( int ii=0 ; ii<subn ; ii++ )
			center_count[subidx[ii]] += 1;
		for ( int i=0 ; i<k ; i++ )
			if ( center_count[i] == 0)
			{
				while(1)
				{
					int ii = rand()%subn;
					if ( center_count[subidx[ii]] > 1)
					{
						center_count[subidx[ii]]--;
						center_count[i]++;
						subidx[ii] = i;
						break;
					}
				}
			}


//		if (change < subn*tol)
//			break;
	}
/*
	// Finalize, compute centers
	for ( int i=0 ; i<k*d ; i++ )
		centers[i] = 0;
	vector<double> count(k, 0);
	for ( int ii=0 ; ii<subn ; ii++ )
	{
		double *nowcenter = &(centers[subidx[ii]*d]);
		int i = idlist[ii];
		for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
			nowcenter[ir[ptr]] += values[ptr];
		count[subidx[ii]] += 1;
	}
	for ( int i=0 ; i<k ; i++ )
	{
		double *nowcenter = &(centers[i*d]);
		for ( int j=0 ; j<d ; j++)
			nowcenter[j]/=count[i];
	}
	*/
}


// Build a binary cluster tree of depth k by repeatedly splitting clusters
// with 2-means.
//
//   n, d        : number of columns (samples) and rows (features) of the data.
//   ir, jc, values : sparse column-compressed arrays of the data matrix.
//   k           : tree depth; there are 2^k leaves and 2^k-1 internal nodes.
//   idx         : [out] n entries, 0-based leaf index of every sample.
//   hyperplanes : [out] (d+1)*(2^k-1) doubles. Node g (level-order index
//                 2^level-1+cluster) owns hyperplanes[g*(d+1) .. g*(d+1)+d]:
//                 d weights followed by a threshold. Nodes that were not
//                 split get all zeros.
//   num_map     : [out] 2^k entries; for each leaf, how many k-means splits
//                 were performed on the path from the root to it.
//   maxiter     : iteration count forwarded to kmeans().
//   ratio       : a cluster with fewer than n*ratio samples is not split; all
//                 its samples go to the first child.
//
// Clusters with fewer than two samples are never split either, because two
// non-empty children cannot be formed from them.
void run_hcluster(int n, int d, mwIndex *ir, mwIndex *jc, double *values, int k, int *idx, double *hyperplanes, int *num_map, int maxiter, double ratio)
{
	// Integer power of two; avoids truncating a floating-point pow() result.
	const int ncluster = ( k >= 0 ) ? ( 1 << k ) : 0;

	std::vector<int> idx_now(n,0), idx_next(n,0);
	// Scratch buffers hold at least one element so that &buf[0] is always valid.
	std::vector<int> idlist(n > 0 ? n : 1, 0);
	std::vector<int> subidx(n > 0 ? n : 1, 0);
	std::vector<double> tmpcenters(d > 0 ? 2*(size_t)d : 1, 0.0);
	std::vector<int> now_num(ncluster, 0), next_num(ncluster,0);

	for (int level =0 ; level < k ; level++ )
	{
		const int prev_ncluster = 1 << level;
		for ( int nowcluster = 0; nowcluster < prev_ncluster ; nowcluster++ )
		{
			// Collect the samples currently assigned to this cluster.
			int subn = 0;
			for ( int i=0 ; i<n ; i++ )
				if ( idx_now[i] == nowcluster)
				{
					idlist[subn] = i;
					subn++;
				}

			// hyperplanes: size (d+1)*(2^k-1)
			const int nowg = prev_ncluster-1+nowcluster;
			double *nowhyper = hyperplanes + (size_t)nowg*(d+1);

			if ( subn < 2 || subn < n*ratio )
			{
				// Too small to split: everything goes to the first child and
				// the node gets an all-zero hyperplane.
				for ( int ii=0 ; ii<subn ; ii++ )
					idx_next[idlist[ii]] = nowcluster*2;
				for ( int i=0 ; i<d+1 ; i++ )
					nowhyper[i] = 0;

				for ( int i=0 ; i<2 ; i++ )
					next_num[nowcluster*2+i] = now_num[nowcluster];

				continue;
			}

			kmeans(subn, d, ir, jc, values, &idlist[0], 2, &subidx[0], &tmpcenters[0], maxiter);

			for ( int i= 0; i<subn ; i++ )
				idx_next[idlist[i]] = nowcluster*2+subidx[i];

			// Weights: difference of the two centers. Threshold: average of
			// the two centers' projections onto the weights.
			for ( int i=0 ; i<d ; i++ )
				nowhyper[i] = tmpcenters[d+i] - tmpcenters[i];
			double num0 = 0.0;
			for ( int i=0 ; i<d ; i++ )
				num0 += nowhyper[i]*tmpcenters[i];
			double num1 = 0.0;
			for ( int i=0 ; i<d ; i++ )
				num1 += nowhyper[i]*tmpcenters[i+d];
			// NOTE(review): when this branch is taken the weights are negated
			// but the threshold below is computed from the un-negated
			// projections -- confirm this is intended.
			if ( num0 > num1 )
			{
				for ( int i=0 ; i<d ; i++ )
					nowhyper[i] *=(-1);
			}
			nowhyper[d] = (num0 + num1)/2; 

			for ( int i=0 ; i<2 ; i++ )
				next_num[nowcluster*2+i] = now_num[nowcluster]+1;
		}
		for ( int i=0 ; i<n ; i++ )
			idx_now[i] = idx_next[i];
		for ( int i=0 ; i<ncluster ; i++ )
			now_num[i] = next_num[i];
	}

	for ( int i=0 ; i<n ; i++ )
		idx[i] = idx_next[i];
	for ( int i=0 ; i<ncluster ; i++ )
		num_map[i] = now_num[i];
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	double tol = 1e-10;
	srand(0);
    if (nrhs !=4 ) {
		exit_with_help();
		fake_answer(plhs);
    } 
	else
	{
		mwIndex *ir, *jc;
		int d = (int)mxGetM(prhs[0]);
		int n = (int)mxGetN(prhs[0]);
		double *values = mxGetPr(prhs[0]);
		long nnz = (long)mxGetNzmax(prhs[0]);
		ir = mxGetIr(prhs[0]);
		jc = mxGetJc(prhs[0]);

		int k = (int)mxGetScalar(prhs[1]);
		int maxiter = (int)mxGetScalar(prhs[2]);
		
		int ncluster = pow(2,k);
		plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
		plhs[1] = mxCreateDoubleMatrix(d+1, ncluster-1, mxREAL);
		plhs[2] = mxCreateDoubleMatrix(ncluster, 1, mxREAL);
		double  *idx_out = mxGetPr(plhs[0]);
		double *hyperplanes = mxGetPr(plhs[1]);
		double *num_map = mxGetPr(plhs[2]);

		double ratio = (double)mxGetScalar(prhs[3]);
		
		int *idx = (int *)malloc(sizeof(int)*n);
		int *tmp_num_map = (int *)malloc(sizeof(int)*ncluster);
		run_hcluster(n, d, ir, jc, values, k, idx, hyperplanes, tmp_num_map, maxiter, ratio);

		for ( int i=0 ; i<n ; i++ )
			idx_out[i] = idx[i]+1;
		for ( int i=0 ; i<ncluster ; i++ )
			num_map[i] = tmp_num_map[i];
		free(idx);
		free(tmp_num_map);
	}
}
