/*
 * Decompiled with CFR 0.152.
 */
package machinelearning.networks;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;
import machinelearning.Tools;
import machinelearning.evolution.evolvables.Evolvable;
import machinelearning.networks.Network;
import utopia.Utils;

/**
 * A single-layer LSTM network: {@code nrCells} memory cells, each with an input gate, a forget
 * gate, an output gate and peephole connections, followed by a linear output layer.
 *
 * <p>A constant bias input ({@code 1.0}) is appended to the external inputs, so
 * {@link #getNumberOfInputs()} reports one more than the constructor's
 * {@code numberOfExternalInputs}. Gates are squashed with {@link #f(double)} (delegating to
 * {@code Tools.sigmoid}); the cell input is squashed with {@link #g(double)} (delegating to
 * {@code Tools.tanh}).
 *
 * <p>All weights can be read and written as one flat array (see {@link #getWeightsArray()}); the
 * evolutionary operators ({@link #mutate()}, {@link #copy()}, {@link #getNewInstance()}) work on
 * that flat view.
 *
 * <p>Instances are stateful: every {@link #propagate(double[])} call advances one time step and
 * the recurrent memory is only cleared by {@link #reset()}. No synchronization is performed.
 */
public class LSTM
implements Network,
Serializable {
    // ---- per-time-step activations (reallocated by reset()) ----
    private int time;
    private double[] outputs;
    private double[] inputs;
    private double[] forgetx;
    private double[] forgety;
    private double[] ingatex;
    private double[] ingatey;
    private double[] outgatex;
    private double[] outgatey;
    private double[] cellx;
    private double[] celly;
    private double[] state;
    private double[] prevstate;
    private double[] prevcelly;
    // ---- topology ----
    private int nrCells;
    private int nrInputs;
    private int nrOutputs;
    private int nrOfWeights;
    private final double biasValue;
    private int numberOfBias = 1;
    // ---- weights: w* are used by propagate(); derivs* are declared but never used in this class ----
    private double[][] win;
    private double[][] derivsin;
    private double[][] wrecin;
    private double[][] derivsrecin;
    private double[] wpeepin;
    private double[] derivspeepin;
    private double[][] wforget;
    private double[][] derivsforget;
    private double[][] wrecforget;
    private double[][] derivsrecforget;
    private double[] wpeepforget;
    private double[] derivspeepforget;
    private double[][] wcell;
    private double[][] derivscell;
    private double[][] wreccell;
    private double[][] derivsreccell;
    private double[][] woutgate;
    private double[][] derivsoutgate;
    private double[][] wrecoutgate;
    private double[][] derivsrecoutgate;
    private double[] wpeepout;
    private double[] derivspeepout;
    private double[][] woutputs;
    private double[][] derivsoutputs;
    /** Half-width of the uniform range used when weights are randomised. */
    protected double initialWeightRange = 0.1;
    /** Scale passed to {@code Utils.randomCauchy} for each weight when mutating. */
    protected double mutationMagnitude = this.initialWeightRange / 10.0;
    private final Random random = new Random();

    /** Gate activation function (delegates to {@code Tools.sigmoid}). */
    public static double f(double x) {
        return Tools.sigmoid(x);
    }

    /** Derivative of {@link #f(double)} (delegates to {@code Tools.sigmoidprime}). */
    public static double fprime(double x) {
        return Tools.sigmoidprime(x);
    }

    /** Cell-input activation function (delegates to {@code Tools.tanh}). */
    public static double g(double x) {
        return Tools.tanh(x);
    }

    /** Derivative of {@link #g(double)} (delegates to {@code Tools.tanhprime}). */
    public static double gprime(double x) {
        return Tools.tanhprime(x);
    }

    /**
     * Creates a network with custom randomisation and mutation settings. All weights start at zero.
     *
     * @param numberOfExternalInputs number of inputs supplied to {@link #propagate(double[])},
     *     excluding the bias
     * @param numberOfHidden number of memory cells
     * @param numberOfOutputs number of linear outputs
     * @param weightRange half-width of the uniform range used by {@link #randomise()}
     * @param mutation scale of the Cauchy noise added by {@link #mutate()}
     */
    public LSTM(int numberOfExternalInputs, int numberOfHidden, int numberOfOutputs, double weightRange, double mutation) {
        this(numberOfExternalInputs, numberOfHidden, numberOfOutputs);
        this.initialWeightRange = weightRange;
        this.mutationMagnitude = mutation;
    }

    /**
     * Creates a network with the default weight range (0.1) and mutation magnitude (0.01).
     * All weights start at zero.
     *
     * @param numberOfExternalInputs number of inputs supplied to {@link #propagate(double[])},
     *     excluding the bias
     * @param numberOfHidden number of memory cells
     * @param numberOfOutputs number of linear outputs
     */
    public LSTM(int numberOfExternalInputs, int numberOfHidden, int numberOfOutputs) {
        this.nrCells = numberOfHidden;
        this.nrInputs = numberOfExternalInputs + this.numberOfBias;
        this.nrOutputs = numberOfOutputs;
        this.biasValue = 1.0;
        // 4 input matrices (in, forget, cell, outgate) + 4 recurrent matrices
        // + output matrix + 3 peephole vectors (in, forget, out).
        this.nrOfWeights = this.nrInputs * this.nrCells * 4 + this.nrCells * this.nrCells * 4 + this.nrCells * this.nrOutputs + this.nrCells * 3;
        this.reset();
        this.win = new double[this.nrCells][this.nrInputs];
        this.wrecin = new double[this.nrCells][this.nrCells];
        this.wpeepin = new double[this.nrCells];
        this.wforget = new double[this.nrCells][this.nrInputs];
        this.wrecforget = new double[this.nrCells][this.nrCells];
        this.wpeepforget = new double[this.nrCells];
        this.wcell = new double[this.nrCells][this.nrInputs];
        this.wreccell = new double[this.nrCells][this.nrCells];
        this.woutgate = new double[this.nrCells][this.nrInputs];
        this.wrecoutgate = new double[this.nrCells][this.nrCells];
        this.wpeepout = new double[this.nrCells];
        this.woutputs = new double[this.nrCells][this.nrOutputs];
    }

    /**
     * Clears all activations and the recurrent memory and rewinds the time step counter to 0.
     * Weights are not touched.
     */
    @Override
    public void reset() {
        this.time = 0;
        this.inputs = new double[this.nrInputs];
        this.outputs = new double[this.nrOutputs];
        this.ingatex = new double[this.nrCells];
        this.outgatex = new double[this.nrCells];
        this.forgetx = new double[this.nrCells];
        this.cellx = new double[this.nrCells];
        this.ingatey = new double[this.nrCells];
        this.outgatey = new double[this.nrCells];
        this.forgety = new double[this.nrCells];
        this.celly = new double[this.nrCells];
        this.state = new double[this.nrCells];
        this.prevstate = new double[this.nrCells];
        this.prevcelly = new double[this.nrCells];
    }

    /**
     * Adds {@code weightChanges} element-wise to the flat weight vector.
     *
     * @param weightChanges per-weight deltas, in {@link #getWeightsArray()} order
     * @throws IllegalArgumentException if the length differs from {@link #getNumberOfWeights()}
     */
    @Override
    public void changeWeights(double[] weightChanges) {
        this.checkWeightCount(weightChanges, "changeWeights");
        double[] newWeights = this.getWeightsArray();
        for (int i = 0; i < newWeights.length; ++i) {
            newWeights[i] += weightChanges[i];
        }
        this.setWeightsArray(newWeights);
    }

    /** Returns the flat weight vector formatted by {@link Arrays#toString(double[])}. */
    @Override
    public String toString() {
        return Arrays.toString(this.getWeightsArray());
    }

    /**
     * Returns a fresh copy of all weights as one flat array, in this order: {@code win},
     * {@code wrecin}, {@code wpeepin}, {@code wforget}, {@code wrecforget}, {@code wpeepforget},
     * {@code wcell}, {@code wreccell}, {@code woutgate}, {@code wrecoutgate}, {@code wpeepout},
     * {@code woutputs}. Matrices are flattened row by row.
     */
    @Override
    public double[] getWeightsArray() {
        double[] weightsArray = new double[this.nrOfWeights];
        int index = 0;
        index = this.copyFlat(this.win, index, weightsArray);
        index = this.copyFlat(this.wrecin, index, weightsArray);
        index = this.copyFlat(this.wpeepin, index, weightsArray);
        index = this.copyFlat(this.wforget, index, weightsArray);
        index = this.copyFlat(this.wrecforget, index, weightsArray);
        index = this.copyFlat(this.wpeepforget, index, weightsArray);
        index = this.copyFlat(this.wcell, index, weightsArray);
        index = this.copyFlat(this.wreccell, index, weightsArray);
        index = this.copyFlat(this.woutgate, index, weightsArray);
        index = this.copyFlat(this.wrecoutgate, index, weightsArray);
        index = this.copyFlat(this.wpeepout, index, weightsArray);
        this.copyFlat(this.woutputs, index, weightsArray);
        return weightsArray;
    }

    /**
     * Overwrites all weights from a flat array laid out as described in {@link #getWeightsArray()}.
     *
     * @param weightsArray the new weights; copied, not retained
     * @throws IllegalArgumentException if the length differs from {@link #getNumberOfWeights()}
     */
    @Override
    public void setWeightsArray(double[] weightsArray) {
        this.checkWeightCount(weightsArray, "setWeightsArray");
        int index = 0;
        index = this.copyExpand(weightsArray, index, this.win);
        index = this.copyExpand(weightsArray, index, this.wrecin);
        index = this.copyExpand(weightsArray, index, this.wpeepin);
        index = this.copyExpand(weightsArray, index, this.wforget);
        index = this.copyExpand(weightsArray, index, this.wrecforget);
        index = this.copyExpand(weightsArray, index, this.wpeepforget);
        index = this.copyExpand(weightsArray, index, this.wcell);
        index = this.copyExpand(weightsArray, index, this.wreccell);
        index = this.copyExpand(weightsArray, index, this.woutgate);
        index = this.copyExpand(weightsArray, index, this.wrecoutgate);
        index = this.copyExpand(weightsArray, index, this.wpeepout);
        this.copyExpand(weightsArray, index, this.woutputs);
    }

    /**
     * Replaces every weight with an independent draw from the uniform range
     * {@code [-initialWeightRange, initialWeightRange)}.
     */
    @Override
    public void randomise() {
        double[] weights = new double[this.nrOfWeights];
        for (int i = 0; i < weights.length; ++i) {
            weights[i] = this.random.nextDouble() * (this.initialWeightRange * 2.0) - this.initialWeightRange;
        }
        this.setWeightsArray(weights);
    }

    /** Returns the number of inputs including the bias input. */
    @Override
    public int getNumberOfInputs() {
        return this.nrInputs;
    }

    /** Returns the number of memory cells. */
    public int getNumberOfCells() {
        return this.nrCells;
    }

    /** Returns the length of the flat weight vector. */
    @Override
    public int getNumberOfWeights() {
        return this.nrOfWeights;
    }

    /** Returns the number of linear outputs. */
    @Override
    public int getNumberOfOutputs() {
        return this.nrOutputs;
    }

    /** Returns the scale of the Cauchy noise used by {@link #mutate()}. */
    public double getMutationMagnitude() {
        return this.mutationMagnitude;
    }

    /** Sets the scale of the Cauchy noise used by {@link #mutate()}. */
    public void setMutationMagnitude(double mutationMagnitude) {
        this.mutationMagnitude = mutationMagnitude;
    }

    /**
     * Returns a new network with the same topology, weight range and mutation magnitude as this
     * one, but freshly randomised weights (see {@link #randomise()}).
     */
    @Override
    public Evolvable getNewInstance() {
        LSTM newLSTM = new LSTM(this.getNumberOfInputs() - this.numberOfBias, this.getNumberOfCells(), this.getNumberOfOutputs(), this.initialWeightRange, this.mutationMagnitude);
        newLSTM.randomise();
        return newLSTM;
    }

    /**
     * Returns a new network with the same topology, weights, weight range and mutation magnitude.
     * The copy starts with cleared recurrent memory (time step 0).
     */
    @Override
    public Evolvable copy() {
        LSTM copy = new LSTM(this.getNumberOfInputs() - this.numberOfBias, this.getNumberOfCells(), this.getNumberOfOutputs(), this.initialWeightRange, this.mutationMagnitude);
        copy.setWeightsArray(this.getWeightsArray());
        return copy;
    }

    /** Adds {@code Utils.randomCauchy(mutationMagnitude)} noise to every weight. */
    @Override
    public void mutate() {
        double[] mutated = this.getWeightsArray();
        this.mutate(mutated);
        this.setWeightsArray(mutated);
    }

    /** Adds {@code Utils.randomCauchy(mutationMagnitude)} noise to every element, in place. */
    protected void mutate(double[] array) {
        for (int i = 0; i < array.length; ++i) {
            array[i] += Utils.randomCauchy(this.mutationMagnitude);
        }
    }

    /** Applies {@link #mutate(double[])} to every row, in place. */
    protected void mutate(double[][] array) {
        for (double[] row : array) {
            this.mutate(row);
        }
    }

    /**
     * Advances the network by one time step.
     *
     * <p>The first {@code getNumberOfInputs() - 1} values of {@code doubles} are used as external
     * inputs; any extra values are ignored. The bias input is appended automatically. After the
     * step, the cell states and cell outputs are remembered and feed the recurrent and peephole
     * connections of the next call, until {@link #reset()} is called.
     *
     * @param doubles the external inputs for this time step
     * @return the network outputs; this is an internal buffer that the next call overwrites
     * @throws IllegalArgumentException if fewer than {@code getNumberOfInputs() - 1} values are given
     */
    @Override
    public double[] propagate(double[] doubles) {
        int externalInputs = this.inputs.length - 1;
        if (doubles.length < externalInputs) {
            throw new IllegalArgumentException("propagate expects at least " + externalInputs + " inputs but got " + doubles.length);
        }
        System.arraycopy(doubles, 0, this.inputs, 0, externalInputs);
        this.inputs[externalInputs] = this.biasValue;

        for (int cell = 0; cell < this.nrCells; ++cell) {
            double inGate = 0.0;
            double outGate = 0.0;
            double forgetGate = 0.0;
            double cellInput = 0.0;
            // Feed-forward contributions from the external inputs and the bias.
            for (int input = 0; input < this.nrInputs; ++input) {
                double x = this.inputs[input];
                inGate += x * this.win[cell][input];
                outGate += x * this.woutgate[cell][input];
                forgetGate += x * this.wforget[cell][input];
                cellInput += x * this.wcell[cell][input];
            }
            if (this.time > 0) {
                // Recurrent contributions from every cell's output at the previous step.
                for (int other = 0; other < this.nrCells; ++other) {
                    double y = this.prevcelly[other];
                    inGate += y * this.wrecin[cell][other];
                    cellInput += y * this.wreccell[cell][other];
                    outGate += y * this.wrecoutgate[cell][other];
                    forgetGate += y * this.wrecforget[cell][other];
                }
                // Peepholes from this cell's previous state into the input and forget gates.
                inGate += this.wpeepin[cell] * this.prevstate[cell];
                forgetGate += this.wpeepforget[cell] * this.prevstate[cell];
            }
            this.ingatex[cell] = inGate;
            this.ingatey[cell] = LSTM.f(inGate);
            this.forgetx[cell] = forgetGate;
            this.forgety[cell] = LSTM.f(forgetGate);
            this.cellx[cell] = cellInput;

            double cellState = this.ingatey[cell] * LSTM.g(cellInput);
            if (this.time > 0) {
                cellState += this.forgety[cell] * this.prevstate[cell];
            }
            this.state[cell] = cellState;

            // The output-gate peephole looks at the state computed in this step.
            outGate += this.wpeepout[cell] * cellState;
            this.outgatex[cell] = outGate;
            this.outgatey[cell] = LSTM.f(outGate);
            this.celly[cell] = this.outgatey[cell] * cellState;
        }

        Arrays.fill(this.outputs, 0.0);
        for (int out = 0; out < this.nrOutputs; ++out) {
            for (int cell = 0; cell < this.nrCells; ++cell) {
                this.outputs[out] += this.woutputs[cell][out] * this.celly[cell];
            }
        }

        // Remember this step for the next one. Done only after all cells are updated so that
        // every cell reads the previous step's outputs, not a mix of old and new.
        System.arraycopy(this.state, 0, this.prevstate, 0, this.nrCells);
        System.arraycopy(this.celly, 0, this.prevcelly, 0, this.nrCells);
        ++this.time;
        return this.outputs;
    }

    /** Throws if {@code array} does not hold exactly {@link #nrOfWeights} values. */
    private void checkWeightCount(double[] array, String caller) {
        if (array.length != this.nrOfWeights) {
            throw new IllegalArgumentException("argument of " + caller + " has not the same length (" + array.length + ") as the number of weights (" + this.nrOfWeights + ")");
        }
    }

    /** Copies the rows of {@code source} into {@code target} from {@code position}; returns the next free index. */
    private int copyFlat(double[][] source, int position, double[] target) {
        for (double[] row : source) {
            position = this.copyFlat(row, position, target);
        }
        return position;
    }

    /** Copies {@code source} into {@code target} at {@code position}; returns the next free index. */
    private int copyFlat(double[] source, int position, double[] target) {
        System.arraycopy(source, 0, target, position, source.length);
        return position + source.length;
    }

    /** Fills the rows of {@code target} from {@code source} starting at {@code position}; returns the next unread index. */
    private int copyExpand(double[] source, int position, double[][] target) {
        for (double[] row : target) {
            position = this.copyExpand(source, position, row);
        }
        return position;
    }

    /** Fills {@code target} from {@code source} starting at {@code position}; returns the next unread index. */
    private int copyExpand(double[] source, int position, double[] target) {
        System.arraycopy(source, position, target, 0, target.length);
        return position + target.length;
    }
}

