/*
 * Decompiled with CFR 0.152.
 */
package machinelearning.networks;

import java.util.Vector;
import machinelearning.evolution.evolvables.Evolvable;
import machinelearning.networks.TWEANN;
import utopia.Utils;

/**
 * A {@link TWEANN} whose output layer may contain several alternative "modes".
 *
 * <p>Output layout: with a single mode and no preference node, the network has exactly
 * {@code outputWidth} outputs. Otherwise every mode occupies {@code outputWidth + 1}
 * consecutive outputs. The first {@code outputWidth} are the mode's actual outputs; the
 * last is that mode's preference value. On each {@link #propagate(double[])} the mode with
 * the highest preference is selected, and its use is counted in {@link #modeHistory}.
 */
public class MultiModalTWEANN
extends TWEANN {
    /** Number of output modes currently encoded in the output layer. */
    protected int numModes;
    /** Number of outputs each mode exposes (excluding its preference node). */
    protected int outputWidth;
    /** Probability per {@link #mutate(boolean)} call of adding a new output mode. */
    public static final double ADD_OUTPUT_MODE_RATE = 0.1;
    /** Probability per {@link #mutate(boolean)} call of deleting the least-used output mode. */
    public static final double DELETE_OUTPUT_MODE_RATE = 0.1;
    /**
     * Per-mode usage counts, incremented by {@link #propagate(double[])} and reset by
     * {@link #flush()}. Index {@code i} corresponds to mode {@code i}.
     */
    public Vector<Integer> modeHistory = new Vector<Integer>();

    /** @return the number of output modes currently encoded in this network */
    public int getNumModes() {
        return this.numModes;
    }

    /**
     * Creates a fresh single-mode instance from the superclass's new network.
     * The network's full output count becomes the mode width.
     */
    @Override
    public Evolvable getNewInstance() {
        TWEANN net = (TWEANN)super.getNewInstance();
        return new MultiModalTWEANN(net, net.getNumberOfOutputs(), 1);
    }

    /** Wraps the node/gene structure and id of {@code net} as a multi-modal network. */
    private MultiModalTWEANN(TWEANN net, int outputWidth, int numModes) {
        this(net.allnodes, net.genes, net.net_id, outputWidth, numModes);
    }

    /**
     * Builds a new single-mode network via the superclass constructor.
     *
     * @param numInputs        number of network inputs
     * @param numOutputs       number of outputs; also the width of the single mode
     * @param featureSelective passed through to {@link TWEANN}
     */
    public MultiModalTWEANN(int numInputs, int numOutputs, boolean featureSelective) {
        super(numInputs, numOutputs, featureSelective);
        this.outputWidth = numOutputs;
        this.numModes = 1;
    }

    /**
     * Builds a network from explicit input/output/all-node lists and genes.
     *
     * @param outputWidth number of outputs per mode (excluding preference nodes)
     * @param numModes    number of modes encoded in {@code out}
     */
    public MultiModalTWEANN(Vector<TWEANN.NNode> in, Vector<TWEANN.NNode> out, Vector<TWEANN.NNode> all, int xnet_id, Vector<TWEANN.Gene> _genes, int outputWidth, int numModes) {
        super(in, out, all, xnet_id, _genes);
        this.outputWidth = outputWidth;
        this.numModes = numModes;
    }

    /**
     * Builds a network from a full node list and genes.
     *
     * @param outputWidth number of outputs per mode (excluding preference nodes)
     * @param numModes    number of modes encoded in the output layer
     */
    public MultiModalTWEANN(Vector<TWEANN.NNode> all, Vector<TWEANN.Gene> _genes, int xnet_id, int outputWidth, int numModes) {
        super(all, _genes, xnet_id);
        this.outputWidth = outputWidth;
        this.numModes = numModes;
    }

    /**
     * Reports the width of a single mode, not the raw output-layer size.
     * Use {@code super.getNumberOfOutputs()} internally for the raw count.
     */
    @Override
    public int getNumberOfOutputs() {
        return this.outputWidth;
    }

    /** Duplicates via the superclass and re-wraps the copy, preserving mode layout. */
    @Override
    public TWEANN duplicateWithNewID() {
        return new MultiModalTWEANN(super.duplicateWithNewID(), this.outputWidth, this.numModes);
    }

    /**
     * Runs the network and returns the outputs of the most preferred mode.
     *
     * <p>If the raw output already has exactly {@code outputWidth} entries (single mode, no
     * preference node), it is returned as-is and mode 0 is recorded. Otherwise the mode whose
     * preference output is strictly greatest wins; ties keep the lowest mode index. If no
     * preference exceeds {@code -Double.MAX_VALUE} (e.g. all NaN), mode 0 is used.
     *
     * @param doubles network inputs
     * @return the {@code outputWidth} outputs of the selected mode
     */
    @Override
    public double[] propagate(double[] doubles) {
        double[] fullOutput = this.processInputsToOutputs(doubles);
        int bestMode = 0;
        double[] modeOutput;
        if (fullOutput.length == this.outputWidth) {
            modeOutput = fullOutput;
        } else {
            double maxPreference = -Double.MAX_VALUE;
            for (int mode = 0; mode < this.numModes; ++mode) {
                // The preference node sits right after the mode's outputWidth outputs.
                double preference = fullOutput[this.firstOutputOfMode(mode) + this.outputWidth];
                if (preference > maxPreference) {
                    maxPreference = preference;
                    bestMode = mode;
                }
            }
            modeOutput = new double[this.outputWidth];
            System.arraycopy(fullOutput, this.firstOutputOfMode(bestMode), modeOutput, 0, this.outputWidth);
        }
        this.recordModeUse(bestMode);
        return modeOutput;
    }

    /**
     * Increments the usage count of {@code mode}.
     * Grows {@link #modeHistory} with zeros first if it is too short, e.g. before the first
     * {@link #flush()} or right after a mode was added.
     */
    private void recordModeUse(int mode) {
        this.ensureModeHistorySize(mode + 1);
        this.modeHistory.set(mode, this.modeHistory.get(mode) + 1);
    }

    /** Pads {@link #modeHistory} with zero counts until it has at least {@code size} entries. */
    private void ensureModeHistorySize(int size) {
        while (this.modeHistory.size() < size) {
            this.modeHistory.add(0);
        }
    }

    /**
     * Appends a new output mode of {@code outputWidth + 1} output nodes (outputs + preference).
     * If the network currently has no preference node (single mode, exactly
     * {@code outputWidth} raw outputs), one is added for the existing first mode first.
     * The new mode starts with a usage count of 0.
     */
    public void mutate_add_output_mode() {
        System.out.println("NET" + this.net_id + ":mutate_add_output_mode");
        if (super.getNumberOfOutputs() == this.outputWidth) {
            System.out.println("NET" + this.net_id + ":add preference for first mode");
            this.addOutputNode();
        }
        for (int i = 0; i < this.outputWidth + 1; ++i) {
            System.out.println("NET" + this.net_id + ":add node");
            this.addOutputNode();
        }
        ++this.numModes;
        // Keep usage statistics aligned with the mode count so propagate() can record the new mode.
        this.ensureModeHistorySize(this.numModes);
    }

    /**
     * Adds one output node fed by a single new gene from a randomly chosen existing node.
     * The gene has a random weight of random sign. The node is appended to
     * {@code allnodes} and {@code outputs}, then the network is rebuilt.
     *
     * @return the id of the new output node
     */
    public int addOutputNode() {
        double weight = (double)Utils.randposneg() * Utils.randomFloat();
        // NOTE(review): samples among the first (allnodes - outputs) nodes, presumably the
        // non-output nodes if outputs are stored at the end of allnodes -- confirm.
        int nodenum1 = Utils.randomInt(0, this.allnodes.size() - this.outputs.size() - 1);
        TWEANN.NNode source = (TWEANN.NNode)this.allnodes.elementAt(nodenum1);
        int curnode_id = TWEANN.getCur_node_id_and_increment();
        // NOTE(review): the meaning of the constants 2 and 5 is defined by TWEANN.NNode -- verify there.
        TWEANN.NNode new_node = new TWEANN.NNode(2, curnode_id, 5);
        double Gene_innov1 = TWEANN.getCurr_innov_num_and_increment();
        TWEANN.Gene newGene1 = new TWEANN.Gene(weight, source, new_node, Gene_innov1, 0.0);
        TWEANN.innovations.add(new TWEANN.Innovation(newGene1));
        this.genes.add(newGene1);
        this.allnodes.add(new_node);
        this.outputs.add(new_node);
        this.rebuild();
        return new_node.node_id;
    }

    /**
     * Deletes the mode with the lowest usage count in {@link #modeHistory} (ties keep the
     * lowest index), provided more than one mode exists. Only history entries for existing
     * modes are considered. On success, the mode's history entry is removed so the remaining
     * counts stay aligned. If deletion fails, diagnostics are printed and the JVM exits
     * with status 1.
     */
    public void mutate_delete_least_used_output_mode() {
        System.out.println("NET" + this.net_id + ":mutate_delete_least_used_output_mode");
        if (this.numModes > 1) {
            System.out.println("NET" + this.net_id + ":Enough modes to delete");
            int target = 0;
            int currentLeast = Integer.MAX_VALUE;
            // Never pick an index beyond the modes that actually exist, even if history is stale.
            int candidates = Math.min(this.modeHistory.size(), this.numModes);
            for (int mode = 0; mode < candidates; ++mode) {
                int uses = this.modeHistory.get(mode);
                if (uses < currentLeast) {
                    target = mode;
                    currentLeast = uses;
                }
            }
            System.out.println("NET" + this.net_id + ":Least used mode = " + currentLeast + " out of " + this.modeHistory.size());
            boolean success = this.deleteOutputMode(target);
            if (success) {
                --this.numModes;
                if (target < this.modeHistory.size()) {
                    // Drop the deleted mode's count so later indices map to the right modes.
                    this.modeHistory.remove(target);
                }
            } else {
                System.out.println("FATAL ERROR: Failure to delete mode");
                System.out.println("modeHistory: " + this.modeHistory);
                System.out.println(this);
                System.exit(1);
            }
        }
    }

    /**
     * Removes all {@code outputWidth + 1} output nodes (outputs + preference) of {@code mode}.
     * Does not change {@link #numModes}; the caller is responsible for that.
     *
     * @param mode index of the mode to delete, in {@code [0, numModes)}
     * @return {@code true} if every node was removed; {@code false} if {@code mode} is out of
     *         range or any node removal failed
     */
    public boolean deleteOutputMode(int mode) {
        if (mode < 0 || mode >= this.numModes) {
            System.out.println("ERROR: mode " + mode + " out of range [0, " + this.numModes + ")");
            return false;
        }
        int start = this.firstOutputOfMode(mode);
        boolean result = true;
        // Remove from the highest index down so earlier indices stay valid.
        for (int i = this.outputWidth; i >= 0; --i) {
            System.out.println("NET" + this.net_id + "Removing: " + i + "/" + this.outputs.size());
            int id = ((TWEANN.NNode)this.outputs.get(start + i)).node_id;
            boolean nodeResult = this.deleteOutputNode(id);
            if (!nodeResult) {
                System.out.println("ERROR:");
                System.out.println("start: " + start);
                System.out.println("Failed removing id: " + id);
                System.out.println("Num removed: " + (this.outputWidth - i));
                System.out.println("outputs: " + this.outputs);
                System.out.println("numModes: " + this.numModes);
                System.out.println("outputWidth: " + this.outputWidth);
            }
            result = result && nodeResult;
        }
        return result;
    }

    /** @return the index in {@code outputs} of the first output of {@code mode} */
    private int firstOutputOfMode(int mode) {
        return mode * (this.outputWidth + 1);
    }

    /**
     * Removes the node with id {@code nodeID} from {@code outputs} (first match) and
     * {@code allnodes} (last match), and removes every gene whose link enters or leaves it.
     *
     * @param nodeID id of the output node to delete
     * @return {@code true} if the node was found in both lists and at least one gene was removed
     */
    public boolean deleteOutputNode(int nodeID) {
        boolean outputRemoved = false;
        boolean allRemoved = false;
        int geneCountRemoved = 0;
        for (int i = 0; i < this.outputs.size(); ++i) {
            if (((TWEANN.NNode)this.outputs.get(i)).node_id == nodeID) {
                this.outputs.remove(i);
                outputRemoved = true;
                break;
            }
        }
        for (int i = this.allnodes.size() - 1; i >= 0; --i) {
            if (((TWEANN.NNode)this.allnodes.get(i)).node_id == nodeID) {
                this.allnodes.remove(i);
                allRemoved = true;
                break;
            }
        }
        // Iterate backwards so removals don't shift the indices still to be visited.
        for (int i = this.genes.size() - 1; i >= 0; --i) {
            TWEANN.Gene gene = (TWEANN.Gene)this.genes.get(i);
            if (gene.lnk.in_node.node_id == nodeID || gene.lnk.out_node.node_id == nodeID) {
                this.genes.remove(i);
                ++geneCountRemoved;
            }
        }
        boolean result = outputRemoved && allRemoved && geneCountRemoved >= 1;
        if (!result) {
            System.out.println("ERROR:");
            System.out.println("outputRemoved: " + outputRemoved);
            System.out.println("allRemoved: " + allRemoved);
            System.out.println("geneCountRemoved: " + geneCountRemoved);
        }
        return result;
    }

    /** Performs a standard mutation, including the mode-level mutations. */
    @Override
    public void mutate() {
        this.mutate(true);
    }

    /**
     * Applies the superclass mutation, then optionally the mode mutations.
     * Deletion is attempted with probability {@link #DELETE_OUTPUT_MODE_RATE}, then addition
     * with probability {@link #ADD_OUTPUT_MODE_RATE}.
     *
     * @param modeMutate whether to consider adding/deleting output modes
     */
    public void mutate(boolean modeMutate) {
        super.mutate();
        if (modeMutate) {
            if (Utils.randomFloat() < DELETE_OUTPUT_MODE_RATE) {
                this.mutate_delete_least_used_output_mode();
            }
            if (Utils.randomFloat() < ADD_OUTPUT_MODE_RATE) {
                this.mutate_add_output_mode();
            }
        }
    }

    /** Flushes the superclass state and resets every mode's usage count to zero. */
    @Override
    public void flush() {
        super.flush();
        this.modeHistory = new Vector<Integer>(this.numModes);
        for (int i = 0; i < this.numModes; ++i) {
            this.modeHistory.add(0);
        }
    }
}

