// This is the MEX wrapper for hierarchical k-means prediction (hkmeans_predict).
// 
// Invocation form within Matlab:
// [idx innerproducts] = hkmeans_predict(data, k, hyperplanes)
// 
// input arguments: 
// 		data: the d by n sparse matrix. 
// 		k	: depth of the binary cluster tree (there are 2^k leaf clusters).
// 		hyperplanes: dense (d+1) by (2^k-1) matrix.
//
// output arguments: 
// 		idx: cluster membership (n*1)
// 		innerproducts (k*n)

#include "mex.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <vector>

#define MAXDIS 1000000000000000.0
using namespace std;

// Prints the usage text for hkmeans_predict through mexPrintf.
// The text mirrors the calling convention implemented by mexFunction:
// three inputs (data, k, hyperplanes) and two outputs (idx, innerproducts).
void exit_with_help()
{
	mexPrintf(
 "[idx innerproducts] = hkmeans_predict(data, k, hyperplanes)\n"
 "input arguments: \n"
 "      data   : the d by n sparse matrix. \n"
 "      k      : depth of the cluster tree (2^k leaf clusters). \n"
 "      hyperplanes: dense (d+1) by (2^k-1) matrix. \n"
 "output arguments: \n"
 "      idx: cluster membership (n)\n"
 "      innerproducts (k*n)\n"
	);
}

// Stores an empty (0 x 0) real double matrix in each of the first two output
// slots, plhs[0] and plhs[1].
static void fake_answer(mxArray *plhs[])
{
	for ( int slot = 0 ; slot < 2 ; ++slot )
		plhs[slot] = mxCreateDoubleMatrix(0, 0, mxREAL);
}


void run_hcluster_predict(int n, int d, mwIndex *ir, mwIndex *jc, double *values, int k, int *idx, double *hyperplanes, double *innerproducts)
{
	int ncluster = pow(2,k);
	vector<double> hyperplane_norm(ncluster-1,0);

	for ( int i=0 ; i<ncluster-1 ; i++ )
		for ( int j=0, gg=i*(d+1) ; j<d ; j++, gg++ )
			hyperplane_norm[i] += hyperplanes[gg]*hyperplanes[gg];

	for ( int i=0 ; i<n ; i++ )
	{
		int nowg = 0;
		double *nowinner = &(innerproducts[i*k]);
		int prev_ncluster = 1;
		double mysq = 0;
		for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
			mysq += values[ptr]*values[ptr];
		for ( int level = 0 ; level < k ; level ++ )
		{
			int nownn = prev_ncluster-1+nowg;
			if ( hyperplane_norm[nownn] == 0 )
			{
				nowinner[level] = 0;
				nowg = nowg<<1;
				prev_ncluster = prev_ncluster <<1;
				continue;
			}
			double *now_hyper = &(hyperplanes[(prev_ncluster-1+nowg)*(d+1)]);
			double v = 0.0;
			for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
				v += values[ptr]*now_hyper[ir[ptr]];
			nowinner[level] = hyperplane_norm[nownn] + mysq - 2*v;
			if ( v > now_hyper[d] )
				nowg = nowg <<1 | 1;
			else
				nowg = nowg <<1;
			prev_ncluster  = prev_ncluster <<1;
		}

		idx[i] = nowg;
	}
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	double tol = 1e-10;
	srand(0);
    if (nrhs !=3 ) {
		exit_with_help();
		fake_answer(plhs);
    } 
	else
	{
		mwIndex *ir, *jc;
		int d = (int)mxGetM(prhs[0]);
		int n = (int)mxGetN(prhs[0]);
		double *values = mxGetPr(prhs[0]);
		long nnz = (long)mxGetNzmax(prhs[0]);
		ir = mxGetIr(prhs[0]);
		jc = mxGetJc(prhs[0]);

		int k = (int)mxGetScalar(prhs[1]);
		int ncluster = pow(2,k);

		int dd = (int)mxGetM(prhs[2]);
		int kk = (int)mxGetN(prhs[2]);
		if ( (dd != d+1) || (kk != ncluster-1)) 
		{
			printf("dimension wrong dd: %d d: %d,,,,,   kk %d ncluster-1 %d!!\n", dd, d+1, kk, ncluster-1);
			exit_with_help();
			fake_answer(plhs);
		}
		double *hyperplanes = mxGetPr(prhs[2]);

		plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
		plhs[1] = mxCreateDoubleMatrix(k, n, mxREAL);
		double  *idx_out = mxGetPr(plhs[0]);
		double *innerproducts = mxGetPr(plhs[1]);
		
		int *idx = (int *)malloc(sizeof(int)*n);
		run_hcluster_predict(n, d, ir, jc, values, k, idx, hyperplanes, innerproducts);

		for ( int i=0 ; i<n ; i++ )
			idx_out[i] = idx[i]+1;
		free(idx);
	}
}
