/*
 * Decompiled with CFR 0.152.
 */
package org.powertac.samplebroker;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.log4j.Logger;
import org.powertac.common.CustomerInfo;
import org.powertac.common.TariffSpecification;
import org.powertac.common.enumerations.PowerType;
import org.powertac.samplebroker.BrokerUtils;
import org.powertac.samplebroker.ConfiguratorFactoryService;
import org.powertac.samplebroker.interfaces.BrokerContext;
import org.powertac.samplebroker.interfaces.CustomerPredictionManager;
import org.powertac.samplebroker.interfaces.Initializable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Predicts how customer subscriptions would be redistributed if a candidate
 * tariff were offered.
 *
 * <p>For each customer, a map from tariff evaluation to number of subscribers
 * is built from the broker's own subscriptions. Customers not subscribed to
 * the broker are attributed to the highest-evaluated competing tariff. The
 * candidate's subscriber count is then predicted from that map, using either
 * locally weighted regression (when enabled via
 * {@code configuratorFactoryService.isUseLWR()} and more than two distinct
 * evaluations exist) or linear interpolation / step extrapolation. Finally the
 * customer's predicted subscriptions are rescaled so their sum stays close to
 * the customer population.
 */
@Service
public class CustomerPredictionManagerService
        implements CustomerPredictionManager, Initializable {
    private static final Logger log = Logger.getLogger(CustomerPredictionManagerService.class);

    /** Width of the interval that normalized evaluations are squeezed into. */
    private static final double SQUEEZE = 0.8;
    /** Lower bound of the normalized evaluation interval ([OFFSET, OFFSET + SQUEEZE]). */
    private static final double OFFSET = 0.1;
    /** Tolerance so that accumulated rounding remainders summing to ~1.0 are carried. */
    private static final double ROUNDING_EPSILON = 1e-9;
    /** Bandwidths tried by leave-one-out cross validation in {@link #predictWithLWR}. */
    private static final double[] CANDIDATE_TAUS = {
        0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0
    };

    @Autowired
    private ConfiguratorFactoryService configuratorFactoryService;

    /** No per-broker initialization is required by this service. */
    @Override
    public void initialize(BrokerContext broker) {
    }

    /**
     * Predicts per-customer subscriptions after {@code candidateSpec} is added.
     *
     * <p>Customers for which no evaluation of the candidate is available keep
     * their current subscriptions unchanged (a warning is logged) instead of
     * aborting the whole prediction with a NullPointerException.
     *
     * @param candidateSpec the hypothetical new tariff
     * @param customer2tariffEvaluations per-customer evaluation of every known tariff
     * @param tariff2customerSubscriptions current subscriptions to the broker's tariffs
     * @param competingTariffs tariffs offered by competitors
     * @return predicted subscriptions, keyed by tariff then customer
     */
    @Override
    public HashMap<TariffSpecification, HashMap<CustomerInfo, Integer>> predictCustomerMigration(TariffSpecification candidateSpec, HashMap<CustomerInfo, HashMap<TariffSpecification, Double>> customer2tariffEvaluations, HashMap<TariffSpecification, HashMap<CustomerInfo, Integer>> tariff2customerSubscriptions, List<TariffSpecification> competingTariffs) {
        HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>> customer2tariffSubscriptions = BrokerUtils.revertKeyMapping(tariff2customerSubscriptions);
        HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>> predictedSubscriptions = this.initializePredictedFromCurrentSubscriptions(customer2tariffSubscriptions);
        for (Map.Entry<CustomerInfo, HashMap<TariffSpecification, Integer>> entry : customer2tariffSubscriptions.entrySet()) {
            CustomerInfo customer = entry.getKey();
            HashMap<TariffSpecification, Double> allTariff2Evaluations = customer2tariffEvaluations.get(customer);
            if (allTariff2Evaluations == null) {
                log.warn("no tariff evaluations for customer " + customer + " - keeping current subscriptions");
                continue;
            }
            Double candidateEvaluation = allTariff2Evaluations.get(candidateSpec);
            if (candidateEvaluation == null) {
                log.warn("candidate tariff not evaluated for customer " + customer + " - keeping current subscriptions");
                continue;
            }
            int customerPopulation = customer.getPopulation();
            TreeMap<Double, Integer> e2n = this.initializeEvaluations2NumSubscribed(customer, customerPopulation, entry.getValue(), competingTariffs, allTariff2Evaluations);
            this.updatePredictionWithCandidateSpec(predictedSubscriptions, candidateSpec, candidateEvaluation, e2n, customer, customerPopulation);
        }
        return BrokerUtils.revertKeyMapping(predictedSubscriptions);
    }

    /**
     * Builds the evaluation -> subscriber-count map for one customer.
     *
     * <p>Each of the broker's tariffs with a positive subscription count
     * contributes one entry. Customers not subscribed to the broker are
     * attributed to the highest-evaluated competing tariff of the same generic
     * power type. Subscribed tariffs lacking an evaluation still count as
     * subscribed (so they are not attributed to a competitor) but add no entry.
     *
     * <p>NOTE(review): tariffs with identical evaluations share a key, so the
     * later {@code put} overwrites the earlier count; confirm whether ties
     * should instead be summed or averaged.
     */
    TreeMap<Double, Integer> initializeEvaluations2NumSubscribed(CustomerInfo customer, int customerPopulation, HashMap<TariffSpecification, Integer> myTariff2subscriptions, List<TariffSpecification> competingTariffs, HashMap<TariffSpecification, Double> allTariff2Evaluations) {
        TreeMap<Double, Integer> e2n = new TreeMap<Double, Integer>();
        int totalSubscriptions = 0;
        for (Map.Entry<TariffSpecification, Integer> entry : myTariff2subscriptions.entrySet()) {
            int numSubscriptions = entry.getValue();
            if (numSubscriptions <= 0) {
                continue;
            }
            totalSubscriptions += numSubscriptions;
            // Look the evaluation up only for tariffs that are actually used,
            // and tolerate a missing one instead of failing on unboxing.
            Double evaluation = allTariff2Evaluations.get(entry.getKey());
            if (evaluation == null) {
                log.warn("no evaluation for subscribed tariff " + entry.getKey() + " of customer " + customer);
                continue;
            }
            e2n.put(evaluation, numSubscriptions);
        }
        log.debug("assuming customers I don't have are with the best competing tariff");
        int numNonSubscribed = customerPopulation - totalSubscriptions;
        if (numNonSubscribed > 0) {
            double evaluation = this.findPreferedCompetingTariff(customer.getPowerType().getGenericType(), competingTariffs, allTariff2Evaluations);
            e2n.put(evaluation, numNonSubscribed);
        }
        return e2n;
    }

    /**
     * Inserts the predicted candidate subscription count for {@code customer}
     * and rescales all of that customer's predicted counts by
     * {@code population / (population + predictedCandidateCount)}.
     */
    void updatePredictionWithCandidateSpec(HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>> predictedSubscriptions, TariffSpecification candidateSpec, double candidateEvaluation, TreeMap<Double, Integer> e2n, CustomerInfo customer, int customerPopulation) {
        int hypotheticalNumSubscriptions = this.predictNumSubscriptions(candidateEvaluation, e2n);
        HashMap<TariffSpecification, Integer> customerPrediction = predictedSubscriptions.get(customer);
        customerPrediction.put(candidateSpec, hypotheticalNumSubscriptions);
        double normalizeConst = (double) customerPopulation / (double) (customerPopulation + hypotheticalNumSubscriptions);
        this.normalizeSubscriptions(customerPrediction, normalizeConst);
    }

    /** Chooses LWR when enabled and more than two data points exist, else interpolation. */
    private int predictNumSubscriptions(double candidateEval, TreeMap<Double, Integer> e2n) {
        if (e2n.size() > 2 && this.configuratorFactoryService.isUseLWR()) {
            return this.predictWithLWR(candidateEval, e2n);
        }
        return this.interpolateOrNN(candidateEval, e2n);
    }

    /**
     * Predicts the candidate's subscriber count with locally weighted
     * regression, picking the bandwidth with the lowest leave-one-out sum of
     * squared errors. Falls back to {@link #interpolateOrNN} whenever the
     * regression is undefined.
     */
    private int predictWithLWR(double candidateEval, TreeMap<Double, Integer> e2n) {
        double min = e2n.firstKey();
        double max = e2n.lastKey();
        // TreeMap keySet() and values() iterate in the same (ascending key)
        // order, so xVec[i] and yVec[i] describe the same data point.
        ArrayRealVector xVec = this.createNormalizedXVector(e2n.keySet(), min, max);
        ArrayRealVector yVec = this.createYVector(e2n.values());
        double bestTau = Double.MAX_VALUE;
        double bestSquaredError = Double.MAX_VALUE;
        for (double tau : CANDIDATE_TAUS) {
            Double squaredError = this.crossValidationError(tau, xVec, yVec);
            if (squaredError == null) {
                log.info(" cp falling back to interpolateOrNN()");
                return this.interpolateOrNN(candidateEval, e2n);
            }
            if (squaredError < bestSquaredError) {
                bestSquaredError = squaredError;
                bestTau = tau;
            }
        }
        log.info(" cp LWR bestTau " + bestTau);
        Double prediction = this.lwrPredict(xVec, yVec, this.normalizeX(candidateEval, min, max), bestTau);
        if (prediction == null) {
            log.error("LWR passed CV but cannot predict on new point. falling back to interpolateOrNN()");
            return this.interpolateOrNN(candidateEval, e2n);
        }
        return Math.max(0, (int) prediction.doubleValue());
    }

    /**
     * Leave-one-out cross validation: the sum over all points of the squared
     * error when predicting each point from the remaining ones.
     *
     * @return the summed squared error, or {@code null} if any fold cannot be predicted
     */
    private Double crossValidationError(double tau, ArrayRealVector x, ArrayRealVector y) {
        int n = x.getDimension();
        double totalError = 0.0;
        for (int i = 0; i < n; ++i) {
            int restLength = n - (i + 1);
            ArrayRealVector xTrain = new ArrayRealVector((ArrayRealVector) x.getSubVector(0, i), x.getSubVector(i + 1, restLength));
            ArrayRealVector yTrain = new ArrayRealVector((ArrayRealVector) y.getSubVector(0, i), y.getSubVector(i + 1, restLength));
            Double yPredicted = this.lwrPredict(xTrain, yTrain, x.getEntry(i), tau);
            if (yPredicted == null) {
                log.error(" cp LWR cannot predict - returning NULL");
                return null;
            }
            double predictionError = yPredicted - y.getEntry(i);
            totalError += predictionError * predictionError;
        }
        return totalError;
    }

    /**
     * Weighted least-squares fit of {@code y = theta * x} (no intercept) with
     * weights {@code exp(-(x - x0)^2 / (2 * tau))}, evaluated at {@code x0}.
     *
     * @return {@code theta * x0}, or {@code null} if the weighted {@code X'WX} is zero
     */
    private Double lwrPredict(ArrayRealVector X, ArrayRealVector Y, double x0, final double tau) {
        ArrayRealVector X0 = new ArrayRealVector(X.getDimension(), x0);
        ArrayRealVector delta = X.subtract(X0);
        ArrayRealVector sqDists = delta.ebeMultiply(delta);
        ArrayRealVector W = sqDists.map(new UnivariateFunction() {
            @Override
            public double value(double sqDist) {
                return Math.exp(-sqDist / (2.0 * tau));
            }
        });
        double Xt_W_X = X.dotProduct(W.ebeMultiply(X));
        if (Xt_W_X == 0.0) {
            log.error(" cp LWR cannot predict - 0 denominator returning NULL");
            log.error("Xcv is " + X.toString());
            log.error("Ycv is " + Y.toString());
            log.error("x0 is " + x0);
            return null;
        }
        double theta = X.ebeMultiply(W).dotProduct(Y) / Xt_W_X;
        return theta * x0;
    }

    /**
     * Linearly interpolates between the neighbouring evaluations. Above the
     * highest evaluation it predicts that entry's count + 1; below the lowest,
     * that entry's count - 1 (never negative). An exact match returns that
     * entry's count (previously this divided 0 by 0 and yielded 0).
     */
    private int interpolateOrNN(double candidateEval, TreeMap<Double, Integer> e2n) {
        log.debug("We interpolate/extrapolate to predictNumSubscriptions() - improve? LWR?");
        Map.Entry<Double, Integer> highNeighbor = e2n.ceilingEntry(candidateEval);
        Map.Entry<Double, Integer> lowNeighbor = e2n.floorEntry(candidateEval);
        if (highNeighbor == null && lowNeighbor == null) {
            log.error("predictNumSubscriptions() no entries in evaluation map");
            return 0;
        }
        if (highNeighbor == null) {
            return lowNeighbor.getValue() + 1;
        }
        if (lowNeighbor == null) {
            return Math.max(0, highNeighbor.getValue() - 1);
        }
        double x1 = lowNeighbor.getKey();
        double x2 = highNeighbor.getKey();
        double y1 = lowNeighbor.getValue();
        double y2 = highNeighbor.getValue();
        if (x1 == x2) {
            // floor and ceiling are the same entry: exact hit, avoid 0/0.
            return (int) y1;
        }
        double predicted = y1 + (y2 - y1) * (candidateEval - x1) / (x2 - x1);
        return (int) predicted;
    }

    /**
     * Multiplies every count in {@code subscriptions} by {@code normalizeConst},
     * truncating and carrying accumulated fractional remainders so that a
     * remainder total reaching 1 adds one subscriber. Entries are processed in
     * ascending order of their original count.
     */
    private void normalizeSubscriptions(HashMap<TariffSpecification, Integer> subscriptions, double normalizeConst) {
        // Group by count: keying a TreeMap directly by count (as before) kept
        // only one tariff per count and left the others un-normalized.
        TreeMap<Integer, List<TariffSpecification>> count2specs = new TreeMap<Integer, List<TariffSpecification>>();
        for (Map.Entry<TariffSpecification, Integer> spec2subs : subscriptions.entrySet()) {
            List<TariffSpecification> specs = count2specs.get(spec2subs.getValue());
            if (specs == null) {
                specs = new ArrayList<TariffSpecification>();
                count2specs.put(spec2subs.getValue(), specs);
            }
            specs.add(spec2subs.getKey());
        }
        double accumulatedFractional = 0.0;
        for (Map.Entry<Integer, List<TariffSpecification>> group : count2specs.entrySet()) {
            double newNumFractional = group.getKey().intValue() * normalizeConst;
            int truncated = (int) newNumFractional;
            for (TariffSpecification spec : group.getValue()) {
                int newSubscriptions = truncated;
                accumulatedFractional += newNumFractional - truncated;
                // ">=" (with tolerance) so that e.g. two remainders of 0.5 are carried.
                if (accumulatedFractional >= 1.0 - ROUNDING_EPSILON) {
                    ++newSubscriptions;
                    accumulatedFractional -= 1.0;
                }
                subscriptions.put(spec, newSubscriptions);
            }
        }
    }

    /**
     * Returns the highest evaluation among competing tariffs whose generic
     * power type equals {@code genericPowerType}, skipping tariffs without an
     * evaluation. Returns {@code -Double.MAX_VALUE} if none qualifies.
     *
     * <p>NOTE(review): that sentinel is then used as a real key in the
     * evaluation map by the caller; confirm this is intended.
     */
    private double findPreferedCompetingTariff(PowerType genericPowerType, List<TariffSpecification> competingTariffs, HashMap<TariffSpecification, Double> tariff2evaluation) {
        log.debug("currently comparing competing tariffs based on generic powertypes");
        double bestEval = -Double.MAX_VALUE;
        for (TariffSpecification spec : competingTariffs) {
            if (spec.getPowerType().getGenericType() != genericPowerType) {
                continue;
            }
            Double currentEval = tariff2evaluation.get(spec);
            if (currentEval != null && currentEval > bestEval) {
                bestEval = currentEval;
            }
        }
        return bestEval;
    }

    /** Deep-copies the per-customer maps so predictions do not mutate current subscriptions. */
    private HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>> initializePredictedFromCurrentSubscriptions(HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>> customer2myTariff2subscriptions) {
        HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>> predicted = new HashMap<CustomerInfo, HashMap<TariffSpecification, Integer>>();
        for (Map.Entry<CustomerInfo, HashMap<TariffSpecification, Integer>> entry : customer2myTariff2subscriptions.entrySet()) {
            predicted.put(entry.getKey(), new HashMap<TariffSpecification, Integer>(entry.getValue()));
        }
        return predicted;
    }

    /** Maps each evaluation linearly from [min, max] onto [OFFSET, OFFSET + SQUEEZE]. */
    private ArrayRealVector createNormalizedXVector(Set<Double> xValues, double min, double max) {
        ArrayRealVector xVector = new ArrayRealVector(xValues.toArray(new Double[xValues.size()]));
        xVector.mapSubtractToSelf(min);
        xVector.mapDivideToSelf(max - min);
        xVector.mapMultiplyToSelf(SQUEEZE);
        xVector.mapAddToSelf(OFFSET);
        return xVector;
    }

    /** Copies the subscriber counts, in iteration order, into a vector. */
    private ArrayRealVector createYVector(Collection<Integer> yValues) {
        double[] doubleArray = new double[yValues.size()];
        int i = 0;
        for (Integer y : yValues) {
            doubleArray[i++] = y.intValue();
        }
        return new ArrayRealVector(doubleArray);
    }

    /** Same transform as {@link #createNormalizedXVector}, for a single value. */
    private double normalizeX(double xVal, double min, double max) {
        return (xVal - min) / (max - min) * SQUEEZE + OFFSET;
    }
}

