// 
// Invocation form within Matlab:
// [labels dec_values] = comp_predict(X, gamma, landmarks, innerproducts, quad_list, w)
// 
// input arguments: 
// 		X: data (d by n sparse matrix) 
// 		gamma: gamma for Gaussian kernel 
// 		landmarks: landmark points (d by m1 dense matrix)
// 		innerproducts: precomputed values (m2 by n dense matrix)
// 		quad_list: quadratic expansion terms (2 by m3 dense matrix of 1-based feature indices)
// 		w: weights of linear model ((m1+m2+m3) by 1 matrix)
//
// output arguments: 
// 		labels: output labels (n by 1 matrix, +1 or -1)
// 		dec_values: (n by 1 matrix)

#include "mex.h"
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <vector>

// NOTE(review): MAXDIS is not referenced anywhere in the code visible in this
// file; presumably left over from an earlier version -- confirm before removing.
#define MAXDIS 1000000000000000.0
// File-scope using-directive; the code below relies on it for unqualified `vector`.
using namespace std;

// Print the MATLAB usage text for comp_predict via mexPrintf.
//
// The feature vector scored by w is laid out as m1 Gaussian-kernel values
// (one per landmark), m2 precomputed inner products and m3 quadratic terms,
// so w must hold m1+m2+m3 weights.
void exit_with_help()
{
	mexPrintf(
	"[labels dec_values] = comp_predict(X, gamma, landmarks, innerproducts, quad_list, w)\n"
	"input arguments: \n"
	"    X: data (d by n sparse matrix)  \n"
	"    gamma: gamma for Gaussian kernel  \n"
	"    landmarks: landmark points (d by m1 dense matrix) \n"
	"    innerproducts: precomputed values (m2 by n dense matrix) \n"
	"    quad_list: quadratic expansion terms (2 by m3 dense matrix of 1-based feature indices) \n"
	"    w: weights of linear model ((m1+m2+m3) by 1 matrix)\n"
	"output arguments:  \n"
	"    labels: output labels (n by 1 matrix, +1 or -1) \n"
	"    dec_values: (n by 1 matrix) \n"
	);
}

// Store an empty (0-by-0) real double matrix in each of the two output slots,
// so the MATLAB caller receives [] for both labels and dec_values.
static void fake_answer(mxArray *plhs[])
{
	for (int slot = 0; slot < 2; slot++)
		plhs[slot] = mxCreateDoubleMatrix(0, 0, mxREAL);
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	srand(0);
    if (nrhs !=6 ) {
		exit_with_help();
		fake_answer(plhs);
    } 
	else
	{
		mwIndex *ir, *jc;
		int d = (int)mxGetM(prhs[0]);
		int n = (int)mxGetN(prhs[0]);
		double *values = mxGetPr(prhs[0]);
		long nnz = (long)mxGetNzmax(prhs[0]);
		ir = mxGetIr(prhs[0]);
		jc = mxGetJc(prhs[0]);

// [labels dec_values] = comp_predict(X, gamma, landmarks, innerproducts, quad_list, w)

		double gamma = (double)mxGetScalar(prhs[1]);

		int m = (int)mxGetN(prhs[2]);
		if ( d != (int)mxGetM(prhs[2]) )
		{
			printf("Parameter landmark dimension 1 wrong %d %d!\n", d, (int)mxGetM(prhs[2]));
			exit_with_help();
			fake_answer(plhs);
			return;
		}
//		int d = (int)mxGetM(prhs[2]);
		double *landmark_values = mxGetPr(prhs[2]);

		int m1 = (int)mxGetM(prhs[3]);
		if ( (n != (int)mxGetN(prhs[3])) && (m1!=0) )
		{
			printf("Paramemter innerproducts size not match n\n");
			exit_with_help();
			fake_answer(plhs);
			return;
		}
		double *innerproduct_values = mxGetPr(prhs[3]);

		int m2 = (int)mxGetN(prhs[4]);
/*		if ( 2 != (int)mxGetM(prhs[4]))
		{
			printf("Quadlist should be 2 * n array!\n");
			exit_with_help();
			fake_answer(plhs);
		}*/
		double *quad_list_values = mxGetPr(prhs[4]);

		double *wvalues = mxGetPr(prhs[5]);

		// Output
		plhs[0] = mxCreateDoubleMatrix(n, 1, mxREAL);
		plhs[1] = mxCreateDoubleMatrix(n, 1, mxREAL);
		double *labels_out = mxGetPr(plhs[0]);
		double *dec_vals_out = mxGetPr(plhs[1]);

		// Begin compute
		vector<double> landmark_sq(m,0.0);
		for ( int i=0 ; i<m ; i++ )
			for ( int j=0, id = i*d ; j<d ; j++, id++ )
				landmark_sq[i] += landmark_values[id]*landmark_values[id];
		vector<double> features(m+m1+m2);
		for ( int i=0 ; i<n ; i++ )
		{
			// Compute kernel values iwth landmarks
			double mysq = 0.0;
			for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
				mysq += values[ptr]*values[ptr];
			for ( int j=0 ; j<m ; j++ )
			{
				double dis = 0.0;
				double *now_landmark = &(landmark_values[j*d]);
				for ( int ptr = jc[i] ; ptr<jc[i+1] ; ptr++ )
					dis += now_landmark[ir[ptr]]*values[ptr];
				dis = mysq + landmark_sq[j] - 2*dis;
				features[j] = exp(-gamma*dis);
			}
			for ( int j=0, id=i*m1 ; j<m1 ; j++, id++ )
			{
				features[j+m] = innerproduct_values[id];
			}
			for ( int j=0, mm1=m+m1 ; j<m2 ; j++ ) 
			{
				features[mm1+j] = features[quad_list_values[2*j]-1]*features[quad_list_values[2*j+1]-1];
			}

			int M = m+m1+m2;
			double dec_val=0.0;
			for ( int j=0 ; j<M ; j++ )
				dec_val += features[j]*wvalues[j];

			dec_vals_out[i] = dec_val;
			labels_out[i] = 1;
			if ( dec_val < 0 )
				labels_out[i] = -1;
		}
	}
}
