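// This is the MEX wrapper for k-means cluster assignment followed by
// per-cluster kernel prediction.
// 
// Invocation form within Matlab:
// [idx, pred, time] = hkmeans_predict_dense(data, k, hyperplanes, landmarks, w, gamma, isone)
// 
// input arguments: 
// 		data: the d by n dense matrix (one sample per column). 
// 		k	: number of tree levels; there are 2^k clusters.
// 		hyperplanes: dense (d+1) by (2^k-1) matrix of splitting hyperplanes,
// 		             or a dense d by 2^k matrix of cluster centers.
// 		landmarks: cell array with 2^k cells; cell c is a d by m_c matrix.
// 		w	: cell array with 2^k cells; cell c holds m_c weights.
// 		gamma: scalar kernel parameter.
// 		isone: cell array with 2^k cells; a first entry other than 1000 is
// 		       used as the prediction for every sample of that cluster.
//
// output arguments: 
// 		idx: cluster membership (n*1), 1-based
// 		pred: prediction for every sample (n*1)
// 		time: prediction time in milliseconds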

#include "mex.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <vector>
#include <time.h>
#include <sys/time.h>
#include <stdlib.h>
#include <emmintrin.h>
#include <memory.h>

#define MIE_ALIGN(x) __attribute__((aligned(x)))
// Allocates `size` bytes at an address that is a multiple of `alignment`
// by delegating to posix_memalign(). Returns a null pointer when
// posix_memalign() reports an error.
static inline void *_aligned_malloc(size_t size, size_t alignment)
{
    void *block;
    if (posix_memalign(&block, alignment, size) != 0)
        return 0;
    return block;
}
// Declares `var` as a 16-byte aligned static array of two doubles, both equal
// to `val`, so that it can be fetched with an aligned 128-bit load.
#define CONST_128D(var, val) \
    MIE_ALIGN(16) static const double var[2] = {(val), (val)}
// Declares `var` as a 16-byte aligned static array of four ints (v1..v4).
#define CONST_128I(var, v1, v2, v3, v4) \
    MIE_ALIGN(16) static const int var[4] = {(v1), (v2), (v3), (v4)}

// Scalar constants for an exp() range reduction. They are not referenced in
// the visible part of this file; remez9_0_log2_sse declares its own copies.
static const double MAXLOG =  7.08396418532264106224E2;     /* log 2**1022 */
static const double MINLOG = -7.08396418532264106224E2;     /* log 2**-1022 */
static const double LOG2E  =  1.4426950408889634073599;     /* 1/log(2) */
//static const double INFINITY = 1.79769313486231570815E308;
// C1 + C2 add up to log(2); C1 is the leading part and C2 the small remainder.
static const double C1 = 6.93145751953125E-1;
static const double C2 = 1.42860682030941723212E-6;


// Large distance value (1e15); not referenced in the visible part of this file.
#define MAXDIS 1000000000000000.0
using namespace std;



// Computes an approximation of e^x for the two doubles packed in `x`.
//
// Steps (each lane is handled independently):
//   1. clamp x to [log(2^-1022), log(2^1024)],
//   2. k = floor(x / log 2),
//   3. r = x - k*log 2, subtracted in two parts (c1 + c2 == log 2) to keep
//      the low-order bits,
//   4. evaluate a degree-9 polynomial in r with Horner's scheme,
//   5. scale by 2^k, built directly by writing k + 1023 into the exponent
//      field of a double.
static inline __m128d remez9_0_exp_pd(__m128d x)
{
    const __m128d one    = _mm_set1_pd(1.);
    const __m128d log2e  = _mm_set1_pd(1.4426950408889634073599);
    const __m128d maxlog = _mm_set1_pd(7.09782712893383996843e2);   // log(2**1024)
    const __m128d minlog = _mm_set1_pd(-7.08396418532264106224e2);  // log(2**-1022)
    const __m128d c1 = _mm_set1_pd(6.93145751953125E-1);
    const __m128d c2 = _mm_set1_pd(1.42860682030941723212E-6);
    const __m128d w9 = _mm_set1_pd(3.9099787920346160288874633639268318097077213911751e-6);
    const __m128d w8 = _mm_set1_pd(2.299608440919942766555719515783308016700833740918e-5);
    const __m128d w7 = _mm_set1_pd(1.99930498409474044486498978862963995247838069436646e-4);
    const __m128d w6 = _mm_set1_pd(1.38812674551586429265054343505879910146775323730237e-3);
    const __m128d w5 = _mm_set1_pd(8.3335688409829575034112982839739473866857586300664e-3);
    const __m128d w4 = _mm_set1_pd(4.1666622504201078708502686068113075402683415962893e-2);
    const __m128d w3 = _mm_set1_pd(0.166666671414320541875332123507829990378055646330574);
    const __m128d w2 = _mm_set1_pd(0.49999999974109940909767965915362308135415179642286);
    const __m128d w1 = _mm_set1_pd(1.0000000000054730504284163017295863259125942049362);
    const __m128d w0 = _mm_set1_pd(0.99999999999998091336479463057053516986466888462081);
    const __m128i offset = _mm_setr_epi32(1023, 1023, 0, 0);

    /* Clamp the argument (operand order kept as in the original code). */
    x = _mm_min_pd(x, maxlog);
    x = _mm_max_pd(x, minlog);

    /* a = x / log2; */
    __m128d a = _mm_mul_pd(x, log2e);

    /* k = (int)floor(a): subtract 1 from negative lanes, then truncate. */
    __m128d p = _mm_and_pd(_mm_cmplt_pd(a, _mm_setzero_pd()), one);
    a = _mm_sub_pd(a, p);
    __m128i k = _mm_cvttpd_epi32(a);
    p = _mm_cvtepi32_pd(k);

    /* x -= p * log2; */
    x = _mm_sub_pd(x, _mm_mul_pd(p, c1));
    x = _mm_sub_pd(x, _mm_mul_pd(p, c2));

    /* Compute e^x using a polynomial approximation (Horner's scheme). */
    a = _mm_add_pd(_mm_mul_pd(x, w9), w8);
    a = _mm_add_pd(_mm_mul_pd(a, x), w7);
    a = _mm_add_pd(_mm_mul_pd(a, x), w6);
    a = _mm_add_pd(_mm_mul_pd(a, x), w5);
    a = _mm_add_pd(_mm_mul_pd(a, x), w4);
    a = _mm_add_pd(_mm_mul_pd(a, x), w3);
    a = _mm_add_pd(_mm_mul_pd(a, x), w2);
    a = _mm_add_pd(_mm_mul_pd(a, x), w1);
    a = _mm_add_pd(_mm_mul_pd(a, x), w0);

    /* p = 2^k: the two 32-bit values (k + 1023) << 20 are moved into the
       upper halves of the two 64-bit lanes, i.e. into the exponent field. */
    k = _mm_add_epi32(k, offset);
    k = _mm_slli_epi32(k, 20);
    k = _mm_shuffle_epi32(k, _MM_SHUFFLE(1,3,0,2));
    p = _mm_castsi128_pd(k);

    /* a *= 2^k. */
    return _mm_mul_pd(a, p);
}

// Replaces values[0 .. num-1] in place with an approximation of
// exp(values[i]).
//
//   values: array of at least `num` doubles; no particular alignment is
//           required (unaligned loads/stores are used).
//   num   : number of elements to transform; nothing happens for num <= 0.
//
// Exactly `num` elements are read and written. The previous implementation
// always worked on groups of four doubles and therefore touched up to three
// elements past the end of the array when `num` was not a multiple of four.
void remez9_0_log2_sse(double *values, int num)
{
    int i = 0;

    /* Main part: four doubles per iteration (two independent vectors). */
    for (; i + 4 <= num; i += 4) {
        const __m128d lo = remez9_0_exp_pd(_mm_loadu_pd(values + i));
        const __m128d hi = remez9_0_exp_pd(_mm_loadu_pd(values + i + 2));
        _mm_storeu_pd(values + i, lo);
        _mm_storeu_pd(values + i + 2, hi);
    }

    /* Tail: one remaining pair ... */
    if (i + 2 <= num) {
        _mm_storeu_pd(values + i, remez9_0_exp_pd(_mm_loadu_pd(values + i)));
        i += 2;
    }

    /* ... and one remaining single element (upper lane is a dummy zero). */
    if (i < num) {
        _mm_store_sd(values + i, remez9_0_exp_pd(_mm_load_sd(values + i)));
    }
}


// Declaration of the external matrix-multiply routine used by this file
// (presumably the Fortran BLAS DGEMM: C = alpha*op(A)*op(B) + beta*C with
// column-major storage — confirm against the BLAS library that is linked).
// Every argument is passed by pointer. The integer arguments are declared as
// ptrdiff_t here — NOTE(review): this must match the integer width of the
// linked BLAS.
extern "C" {
void dgemm_(char *transa, char *transb, ptrdiff_t *m, ptrdiff_t *n, ptrdiff_t *k, double *alpha,
		double *A, ptrdiff_t *lda, double *B, ptrdiff_t *ldb, double *beta, double *C, ptrdiff_t *ldc);
}

// Minimal wall-clock stopwatch built on gettimeofday().
class Timer {
	timeval timer[2];	// [0] = start stamp, [1] = stop stamp

public:
	// Records the current time as the start stamp and returns it.
	timeval StartTimer(void) {
		gettimeofday(&timer[0], NULL);
		return timer[0];
	}

	// Records the current time as the stop stamp and returns it.
	timeval StopTimer(void) {
		gettimeofday(&timer[1], NULL);
		return timer[1];
	}

	// Returns the time between the start and the stop stamp in milliseconds.
	double ElapsedTime(void) const {
		const timeval &begin = timer[0];
		const timeval &end = timer[1];

		double whole_seconds(end.tv_sec - begin.tv_sec);
		double micro_seconds(end.tv_usec - begin.tv_usec);

		// Borrow one second when the microsecond part went negative.
		if (micro_seconds < 0) {
			--whole_seconds;
			micro_seconds += 1000000;
		}
		return (whole_seconds * 1000 + micro_seconds / 1000.0);
	}
};



// Prints the calling convention of this MEX function to the MATLAB console.
// The text lists the seven inputs that mexFunction requires and the three
// outputs it produces (the previous text described a three-input call that
// mexFunction rejects).
void exit_with_help()
{
	mexPrintf(
 "[idx pred time] = hkmeans_predict(data, k, hyperplanes, landmarks, w, gamma, isone)\n"
 "input arguments: \n"
 "      data   : the d by n dense matrix. \n"
 "      k      : number of tree levels (2^k clusters). \n"
 "      hyperplanes: dense (d+1) by (2^k-1) matrix, \n"
 "                   or dense d by 2^k matrix of cluster centers. \n"
 "      landmarks: cell array (2^k cells), cell c is a d by m_c matrix. \n"
 "      w      : cell array (2^k cells), cell c holds m_c weights. \n"
 "      gamma  : scalar kernel parameter. \n"
 "      isone  : cell array (2^k cells); a first entry other than 1000 \n"
 "               is used as the prediction for the whole cluster. \n"
 "output arguments: \n"
 "      idx: cluster membership (n), 1-based\n"
 "      pred: prediction for every sample (n)\n"
 "      time: prediction time in milliseconds\n"
	);
}

// Stores an empty (0-by-0) real double matrix in each of the first two
// output slots, so that a rejected call still hands something back.
static void fake_answer(mxArray *plhs[])
{
	for (int slot = 0; slot < 2; ++slot)
		plhs[slot] = mxCreateDoubleMatrix(0, 0, mxREAL);
}


// Assigns every sample to the cluster center that minimises
// |c_j|^2 - 2*<c_j, x_i> (the squared distance without the |x_i|^2 term,
// which is the same for all centers).
//
//   n        : number of samples.
//   d        : number of features per sample.
//   values   : d by n sample matrix, one sample per column (column-major).
//   ncluster : number of centers.
//   idx      : output, n entries; idx[i] is the 0-based index of the chosen
//              center. Every entry is -1 when ncluster <= 0.
//   centers  : d by ncluster matrix of centers (column-major).
//
// The search starts from center 0 instead of a fixed "very large" sentinel,
// so idx[i] is always a valid index in [0, ncluster) — also when every score
// is NaN or larger than the former sentinel 1e20, where -1 used to be
// returned. Ties still go to the lowest index.
void run_cluster_predict(int n, int d, double *values, int ncluster, int *idx, double *centers)
{
	if (n <= 0)
		return;
	if (ncluster <= 0)
	{
		for (int i = 0; i < n; i++)
			idx[i] = -1;
		return;
	}

	// inner[j + i*ncluster] = <center j, sample i>, i.e. centers^T * values.
	std::vector<double> inner(static_cast<size_t>(n) * ncluster);
	double alpha = 1, beta = 0;
	ptrdiff_t mmm = ncluster;
	ptrdiff_t ddd = d;
	ptrdiff_t nnn = n;
	dgemm_((char *)"T", (char *)"N", &mmm, &nnn, &ddd, &alpha, centers, &ddd, values, &ddd, &beta, &inner[0], &mmm);

	// Squared norm of every center.
	std::vector<double> center_norm(ncluster, 0.0);
	for (int j = 0; j < ncluster; j++)
	{
		const double *c = centers + static_cast<size_t>(j) * d;
		for (int p = 0; p < d; p++)
			center_norm[j] += c[p] * c[p];
	}

	for (int i = 0; i < n; i++)
	{
		const double *row = &inner[static_cast<size_t>(i) * ncluster];
		int best = 0;
		double best_score = center_norm[0] - 2 * row[0];

		for (int j = 1; j < ncluster; j++)
		{
			const double score = center_norm[j] - 2 * row[j];
			if (score < best_score)
			{
				best_score = score;
				best = j;
			}
		}
		idx[i] = best;
	}
}

// Sends every sample down a binary tree of hyperplanes with k levels and
// stores the index of the leaf it ends up in.
//
//   n           : number of samples.
//   d           : number of features per sample.
//   values      : samples, d consecutive doubles per sample.
//   k           : number of tree levels.
//   idx         : output, n entries; leaf index of every sample.
//   hyperplanes : d+1 consecutive doubles per hyperplane: d weights followed
//                 by a threshold. The hyperplanes of one level are stored
//                 next to each other, level after level.
//
// At each level the sample goes to the "1" child when <weights, x> is
// strictly greater than the threshold and to the "0" child otherwise. A
// hyperplane whose weights are all zero is skipped and the sample takes the
// "0" child.
void run_hcluster_predict(int n, int d, double *values, int k, int *idx, double *hyperplanes)
{
	const int stride = d + 1;
	const int nplanes = static_cast<int>(pow(2, k)) - 1;

	// Squared norm of the weight part (threshold excluded) of each hyperplane.
	std::vector<double> weight_sq(nplanes, 0.0);
	for (int h = 0; h < nplanes; ++h)
	{
		const double *w = hyperplanes + h * stride;
		for (int j = 0; j < d; ++j)
			weight_sq[h] += w[j] * w[j];
	}

	for (int s = 0; s < n; ++s)
	{
		const double *x = values + s * d;
		int node = 0;			// position of the sample inside the current level
		int level_start = 0;	// index of the first hyperplane of the current level

		for (int level = 0; level < k; ++level)
		{
			const int h = level_start + node;
			int bit = 0;

			if (weight_sq[h] != 0)
			{
				const double *w = hyperplanes + h * stride;
				double dot = 0.0;
				for (int j = 0; j < d; ++j)
					dot += x[j] * w[j];
				if (dot > w[d])
					bit = 1;
			}

			node = (node << 1) | bit;
			level_start = 2 * level_start + 1;	// the next level holds twice as many planes
		}

		idx[s] = node;
	}
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	double tol = 1e-10;
	srand(0);
    if (nrhs !=7 ) {
		exit_with_help();
		fake_answer(plhs);
    } 
	else
	{
		mwIndex *ir, *jc;
		int d = (int)mxGetM(prhs[0]);
		int n = (int)mxGetN(prhs[0]);
		double *values = mxGetPr(prhs[0]);

		int k = (int)mxGetScalar(prhs[1]);
		int ncluster = (int)pow(2,k);

		int dd = (int)mxGetM(prhs[2]);
		int kk = (int)mxGetN(prhs[2]);
/*		if ( (dd != d+1) || (kk != ncluster-1)) 
		{
			printf("dimension wrong dd: %d d: %d,,,,,   kk %d ncluster-1 %d!!\n", dd, d+1, kk, ncluster-1);
			exit_with_help();
			fake_answer(plhs);
		}*/
		double *hyperplanes = mxGetPr(prhs[2]);
		int mode = 0;
		if ( dd==(d+1) )
			mode = 0;
		else
			mode = 1;

		double **landmarks = (double **)malloc(sizeof(double *)*ncluster);
		int *mm = (int *)malloc(sizeof(int)*ncluster);
		double **landmarks_norm = (double **)malloc(sizeof(double *)*ncluster);
		for ( int i=0 ; i<ncluster ; i++ )
		{
			mxArray *aa = mxGetCell(prhs[3], i);
			mm[i] = (int)mxGetN(aa);
			if ( mm[i] > 0)
			{

//				printf("i %d mm %d\n", i, mm[i]);
				landmarks_norm[i] = (double *)malloc(sizeof(double)*mm[i]);
				landmarks[i] = mxGetPr(aa);
				for ( int j=0 ; j<mm[i] ; j++)
				{
					double nowsq = 0;
					for ( int p=0 ; p<d ; p++ )
						nowsq += landmarks[i][j*d+p]*landmarks[i][j*d+p];
					landmarks_norm[i][j] = nowsq;
				}
			}
		}

		double **wlist = (double **)malloc(sizeof(double *)*ncluster);
		for ( int i=0 ; i<ncluster ; i++ )
		{
			mxArray *aa = mxGetCell(prhs[4], i);
			if ((int)mxGetN(aa) > 0 )
			{
				wlist[i] = mxGetPr(aa);
			}
		}

		double gamma = (double)mxGetScalar(prhs[5]);

		double **isonelist = (double **)malloc(sizeof(double *)*ncluster);
		for ( int i=0 ; i<ncluster ; i++ )
		{
			mxArray *aa = mxGetCell(prhs[6], i);
			if ((int)mxGetN(aa) > 0 )
			{
				isonelist[i] = mxGetPr(aa);
			}
		}


		plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
		plhs[1] = mxCreateDoubleMatrix(n, 1, mxREAL);
		double  *idx_out = mxGetPr(plhs[0]);
		double *pred_out = mxGetPr(plhs[1]);
		
		int *idx = (int *)malloc(sizeof(int)*n);

		double timebegin = clock();
		Timer tm;
		tm.StartTimer();

		if (mode == 0)
			run_hcluster_predict(n, d, values, k, idx, hyperplanes);
		else
			run_cluster_predict(n, d, values, ncluster, idx, hyperplanes);
//		tm.StopTimer();
//		printf("time %lf\n", tm.ElapsedTime());

//		tm.StartTimer();
		double *tmpx = (double *)malloc(sizeof(double)*d*n);
		int *nowid = (int *)malloc(sizeof(int)*n);
		double *result = (double *)malloc(sizeof(double)*d*n*10);


		int *ni = (int *)malloc(sizeof(int)*(ncluster+1));
		int *numlist = (int *)malloc(sizeof(int)*(n+1));

		for ( int i=0 ; i<ncluster+1 ; i++ )
			ni[i] = 0;

		for ( int i=0 ; i<n ; i++ )
		{
			ni[idx[i]+1]++;
		}


		for ( int i=1 ; i<ncluster+1 ; i++ )
			ni[i] = ni[i] + ni[i-1];

		for ( int i=0 ; i<n ; i++ )
		{
			int nowidx = idx[i];
			numlist[ni[nowidx]] = i;
			ni[nowidx]++;
		}

		for ( int i=ncluster-1 ; i>=1 ; i-- )
			ni[i] = ni[i-1];
		ni[0] =0;

		
		for ( int c=0 ; c<ncluster ; c++)
			if ( ni[c+1]>ni[c])
			{
				if ( isonelist[c][0] != 1000 )	
				{
					for ( int j=ni[c] ; j<ni[c+1] ; j++ )
						pred_out[numlist[j]] = isonelist[c][0];
					continue;
				} 
				if ( (mm[c] == 0) )
				{
					for ( int j=ni[c] ; j<ni[c+1] ; j++ )
						pred_out[numlist[j]] = -1;
					continue;
				}
				int nownum = 0;
				for ( int j=ni[c] ; j<ni[c+1] ; j++ )
				{
					int nowbegin = nownum*d;
					double *nowx =  &(values[numlist[j]*d]);
					for ( int jj=0 ; jj<d ; jj++ )
					{
						tmpx[nowbegin+jj] = nowx[jj];
					}
					nowid[nownum++] = numlist[j];
				}
				
				double *nowl = landmarks[c];
				double alpha=1, beta=0;
				ptrdiff_t mmm = mm[c];
				ptrdiff_t ddd = d;
				ptrdiff_t nnn = nownum;
				double *noww = wlist[c];
				dgemm_((char *)"T", (char *)"N", &mmm, &nnn, &ddd, &alpha, nowl, &ddd, tmpx, &ddd, &beta, result, &mmm);
	
				int jjj=0;
				for (int ii=0 ; ii<nownum ; ii++ )
				{
					int i = nowid[ii];
					double *nowx =  &(values[i*d]);
					double xnorm = 0.0;
					for ( int j=0 ; j<d ; j++ )
						xnorm += nowx[j]*nowx[j];
	
					for ( int j=0 ; j<mmm ; j++ )
					{
						result[jjj] = gamma*(2*result[jjj]-xnorm-landmarks_norm[c][j]);
						jjj++;
					}
				}
	
				remez9_0_log2_sse(result, nownum*mmm);
				int jj=0;
				for (int ii=0 ; ii<nownum ; ii++ )
				{
					int i = nowid[ii];
					double pred=0;
					for ( int j=0 ; j<mmm ; j++ )
					{
						pred += result[jj++]*noww[j];
					}
					pred_out[i] = pred;
					if ( pred > 0)
						pred_out[i] = 1;
					else 
						pred_out[i] = -1;
				}
	
			}

		tm.StopTimer();
		double dc_pred_time = tm.ElapsedTime();
//		printf("time %lf\n", tm.ElapsedTime());
		printf("time %lf\n", dc_pred_time);
	
		plhs[2] = mxCreateDoubleScalar(dc_pred_time);
		for ( int i=0 ; i<n ; i++ )
			idx_out[i] = idx[i]+1;
		free(idx);
		free(tmpx);
		free(nowid);
		free(result);
		free(ni);
		free(numlist);
	}
}
