/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.ann;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import org.ojalgo.ann.CalculationLayer;
import org.ojalgo.ann.FileFormat;
import org.ojalgo.ann.LayerTemplate;
import org.ojalgo.ann.NetworkBuilder;
import org.ojalgo.ann.NetworkInvoker;
import org.ojalgo.ann.NetworkTrainer;
import org.ojalgo.ann.NodeDropper;
import org.ojalgo.ann.TrainingConfiguration;
import org.ojalgo.data.DataBatch;
import org.ojalgo.function.BinaryFunction;
import org.ojalgo.function.PrimitiveFunction;
import org.ojalgo.function.aggregator.Aggregator;
import org.ojalgo.function.constant.PrimitiveMath;
import org.ojalgo.function.special.MissingMath;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.matrix.store.Primitive32Store;
import org.ojalgo.matrix.store.Primitive64Store;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Structure2D;

/**
 * An artificial neural network stored as an ordered array of {@link CalculationLayer}s, each with its
 * own weights, biases and {@link Activator}.
 * <p>
 * Create networks with {@link #builder(int)}, or read them back with one of the {@code from(...)}
 * methods. Use {@link #newInvoker()} to obtain a {@link NetworkInvoker} and {@link #newTrainer()} to
 * obtain a {@link NetworkTrainer}.
 */
public final class ArtificialNeuralNetwork {

    /**
     * Training settings installed by {@link #setConfiguration(TrainingConfiguration)}, or {@code null}.
     * When non-null, {@link #invoke(int, PhysicalStore, PhysicalStore)} passes a per-layer
     * keep-probability to the layers. {@code adjust(...)} dereferences it unconditionally.
     */
    private transient TrainingConfiguration myConfiguration = null;
    /** Factory handed to every layer and used by {@link #newStore(int, int)} / {@link #newBatch(int, int)}. */
    private final PhysicalStore.Factory<Double, ?> myFactory;
    /** The calculation layers; index 0 is the input side, the last index is the output side. */
    private final CalculationLayer[] myLayers;

    /**
     * Starts building a network backed by {@link Primitive64Store} matrices.
     *
     * @param numberOfNetworkInputNodes number of input nodes of the network
     * @return a builder to which calculation layers can be added
     */
    public static NetworkBuilder builder(int numberOfNetworkInputNodes) {
        return ArtificialNeuralNetwork.builder(Primitive64Store.FACTORY, numberOfNetworkInputNodes);
    }

    /**
     * Builds a {@link Primitive64Store}-backed network with one calculation layer per entry in
     * {@code nodesPerCalculationLayer} and returns a new trainer for it.
     *
     * @param numberOfInputNodes number of input nodes of the network
     * @param nodesPerCalculationLayer output node count of each calculation layer, in order
     * @return a trainer for the newly built network
     * @deprecated Use {@link #builder(int)} and add layers with {@link NetworkBuilder#layer(int)}.
     */
    @Deprecated
    public static NetworkTrainer builder(int numberOfInputNodes, int ... nodesPerCalculationLayer) {
        return ArtificialNeuralNetwork.builder(Primitive64Store.FACTORY, numberOfInputNodes, nodesPerCalculationLayer);
    }

    /**
     * Starts building a network whose matrices are created by {@code factory}.
     *
     * @param factory matrix factory used by the network
     * @param numberOfNetworkInputNodes number of input nodes of the network
     * @return a builder to which calculation layers can be added
     */
    public static NetworkBuilder builder(PhysicalStore.Factory<Double, ?> factory, int numberOfNetworkInputNodes) {
        return new NetworkBuilder(factory, numberOfNetworkInputNodes);
    }

    /**
     * Builds a network with one calculation layer per entry in {@code nodesPerCalculationLayer} and
     * returns a new trainer for it.
     *
     * @param factory matrix factory used by the network
     * @param numberOfInputNodes number of input nodes of the network
     * @param nodesPerCalculationLayer output node count of each calculation layer, in order
     * @return a trainer for the newly built network
     * @deprecated Use {@link #builder(PhysicalStore.Factory, int)} and add layers with
     *             {@link NetworkBuilder#layer(int)}.
     */
    @Deprecated
    public static NetworkTrainer builder(PhysicalStore.Factory<Double, ?> factory, int numberOfInputNodes, int ... nodesPerCalculationLayer) {
        NetworkBuilder builder = ArtificialNeuralNetwork.builder(factory, numberOfInputNodes);
        for (int nodes : nodesPerCalculationLayer) {
            builder.layer(nodes);
        }
        return builder.get().newTrainer();
    }

    /**
     * Reads a network in the {@link FileFormat} layout, e.g. as produced by {@link #writeTo(DataOutput)}.
     * A {@code null} factory is passed to {@link FileFormat}; presumably the stored format version then
     * decides the matrix factory (NOTE(review): confirm in FileFormat.read).
     *
     * @param input source to read from
     * @return the network that was read
     * @throws IOException if reading fails
     */
    public static ArtificialNeuralNetwork from(DataInput input) throws IOException {
        return FileFormat.read(null, input);
    }

    /**
     * Reads a network from {@code file}. See {@link #from(DataInput)}.
     *
     * @param file file to read
     * @return the network that was read
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ArtificialNeuralNetwork from(File file) {
        return ArtificialNeuralNetwork.from(null, file);
    }

    /**
     * Reads a network from {@code path}. See {@link #from(DataInput)}.
     *
     * @param path path to read
     * @param options options passed to {@link Files#newInputStream(Path, OpenOption...)}
     * @return the network that was read
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static ArtificialNeuralNetwork from(Path path, OpenOption ... options) {
        return ArtificialNeuralNetwork.from(null, path, options);
    }

    /**
     * Reads a network in the {@link FileFormat} layout using the given matrix factory.
     *
     * @param factory matrix factory for the network, or {@code null} (see {@link #from(DataInput)})
     * @param input source to read from
     * @return the network that was read
     * @throws IOException if reading fails
     */
    public static ArtificialNeuralNetwork from(PhysicalStore.Factory<Double, ?> factory, DataInput input) throws IOException {
        return FileFormat.read(factory, input);
    }

    /**
     * Reads a network from {@code file} using the given matrix factory. The stream is always closed.
     *
     * @param factory matrix factory for the network, or {@code null}
     * @param file file to read
     * @return the network that was read
     * @throws RuntimeException wrapping any {@link IOException} from opening, reading or closing
     */
    public static ArtificialNeuralNetwork from(PhysicalStore.Factory<Double, ?> factory, File file) {
        try (DataInputStream input = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            return ArtificialNeuralNetwork.from(factory, input);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Reads a network from {@code path} using the given matrix factory. The stream is always closed.
     *
     * @param factory matrix factory for the network, or {@code null}
     * @param path path to read
     * @param options options passed to {@link Files#newInputStream(Path, OpenOption...)}
     * @return the network that was read
     * @throws RuntimeException wrapping any {@link IOException} from opening, reading or closing
     */
    public static ArtificialNeuralNetwork from(PhysicalStore.Factory<Double, ?> factory, Path path, OpenOption ... options) {
        try (DataInputStream input = new DataInputStream(new BufferedInputStream(Files.newInputStream(path, options)))) {
            return ArtificialNeuralNetwork.from(factory, input);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /** Identity activation: leaves {@code output} unchanged. */
    static void doIdentity(PhysicalStore<Double> output) {
    }

    /** ReLU activation: replaces every element x with max(x, 0), in place. */
    static void doReLU(PhysicalStore<Double> output) {
        output.modifyAll(PrimitiveMath.MAX.second(PrimitiveMath.ZERO));
    }

    /** Sigmoid activation: applies the logistic function to every element, in place. */
    static void doSigmoid(PhysicalStore<Double> output) {
        output.modifyAll(PrimitiveMath.LOGISTIC);
    }

    /**
     * Softmax activation, in place: each row is exponentiated and divided by its row total.
     * <p>
     * Each row is first shifted by its maximum. Softmax is invariant under such a shift, and the shift
     * keeps {@code EXP} from overflowing to infinity for large inputs (which would otherwise turn the
     * whole row into NaN).
     */
    static void doSoftMax(PhysicalStore<Double> output) {
        Primitive64Store maxima = output.reduceRows(Aggregator.MAXIMUM).collect(Primitive64Store.FACTORY);
        output.onRows(PrimitiveMath.SUBTRACT, maxima).supplyTo(output);
        output.modifyAll(PrimitiveMath.EXP);
        Primitive64Store totals = output.reduceRows(Aggregator.SUM).collect(Primitive64Store.FACTORY);
        output.onRows(PrimitiveMath.DIVIDE, totals).supplyTo(output);
    }

    /** Tanh activation: applies the hyperbolic tangent to every element, in place. */
    static void doTanh(PhysicalStore<Double> output) {
        output.modifyAll(PrimitiveMath.TANH);
    }

    /**
     * Creates one calculation layer per layer template held by {@code builder}.
     *
     * @param builder supplies the factory and the layer templates
     */
    ArtificialNeuralNetwork(NetworkBuilder builder) {
        this.myFactory = builder.getFactory();
        List<LayerTemplate> templates = builder.getLayers();
        this.myLayers = new CalculationLayer[templates.size()];
        for (int i = 0; i < this.myLayers.length; ++i) {
            LayerTemplate template = templates.get(i);
            this.myLayers[i] = new CalculationLayer(this.myFactory, template.inputs, template.outputs, template.activator);
        }
    }

    /**
     * Creates a chain of SIGMOID layers. Each layer's input count is the previous layer's output count;
     * the first layer takes {@code inputs}.
     *
     * @param factory matrix factory used by the layers
     * @param inputs number of network input nodes
     * @param layers output node count of each layer, in order
     */
    ArtificialNeuralNetwork(PhysicalStore.Factory<Double, ?> factory, int inputs, int[] layers) {
        this.myFactory = factory;
        this.myLayers = new CalculationLayer[layers.length];
        int previousOutputs = inputs;
        for (int i = 0; i < layers.length; ++i) {
            this.myLayers[i] = new CalculationLayer(factory, previousOutputs, layers[i], Activator.SIGMOID);
            previousOutputs = layers[i];
        }
    }

    /** @return the number of calculation layers */
    public int depth() {
        return this.myLayers.length;
    }

    /** Two networks are equal when their layer arrays are element-wise equal; the factory is not compared. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof ArtificialNeuralNetwork)) {
            return false;
        }
        ArtificialNeuralNetwork other = (ArtificialNeuralNetwork) obj;
        return Arrays.equals(this.myLayers, other.myLayers);
    }

    /**
     * @param layer layer index
     * @return the activator of that layer
     */
    public Activator getActivator(int layer) {
        return this.myLayers[layer].getActivator();
    }

    /**
     * @param layer layer index
     * @param output output node index within that layer
     * @return the bias of that output node
     */
    public double getBias(int layer, int output) {
        return this.myLayers[layer].getBias(output);
    }

    /**
     * @param layer layer index
     * @param input input node index within that layer
     * @param output output node index within that layer
     * @return the weight connecting {@code input} to {@code output}
     */
    public double getWeight(int layer, int input, int output) {
        return this.myLayers[layer].getWeight(input, output);
    }

    /** Consistent with {@link #equals(Object)}: based only on the layers. */
    @Override
    public int hashCode() {
        return 31 + Arrays.hashCode(this.myLayers);
    }

    /** @return a new invoker with batch size 1 */
    public NetworkInvoker newInvoker() {
        return this.newInvoker(1);
    }

    /**
     * @param batchSize batch size passed to the invoker
     * @return a new invoker for this network
     */
    public NetworkInvoker newInvoker(int batchSize) {
        return new NetworkInvoker(this, batchSize);
    }

    /** @return a new trainer with batch size 1 */
    public NetworkTrainer newTrainer() {
        return this.newTrainer(1);
    }

    /**
     * Creates a trainer whose error function matches the output activator: {@link Error#CROSS_ENTROPY}
     * for a SOFTMAX output layer, otherwise {@link Error#HALF_SQUARED_DIFFERENCE}.
     *
     * @param batchSize batch size passed to the trainer
     * @return a new trainer for this network
     */
    public NetworkTrainer newTrainer(int batchSize) {
        NetworkTrainer trainer = new NetworkTrainer(this, batchSize);
        trainer.error(this.getOutputActivator() == Activator.SOFTMAX ? Error.CROSS_ENTROPY : Error.HALF_SQUARED_DIFFERENCE);
        return trainer;
    }

    /** @return one {@link Structure2D} per layer, as reported by each layer */
    public Structure2D[] structure() {
        Structure2D[] retVal = new Structure2D[this.myLayers.length];
        for (int l = 0; l < retVal.length; ++l) {
            retVal[l] = this.myLayers[l].getStructure();
        }
        return retVal;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("ArtificialNeuralNetwork [Layers=");
        for (CalculationLayer layer : this.myLayers) {
            builder.append('\n').append(layer);
        }
        return builder.append('\n').append(']').toString();
    }

    /** @return the largest node count among the first layer's inputs and every layer's outputs */
    public int width() {
        int retVal = this.myLayers[0].countInputNodes();
        for (CalculationLayer layer : this.myLayers) {
            retVal = Math.max(retVal, layer.countOutputNodes());
        }
        return retVal;
    }

    /**
     * Writes this network in the {@link FileFormat} layout. Format version 2 is used when the factory is
     * {@link Primitive32Store#FACTORY}, otherwise version 1.
     *
     * @param output destination
     * @throws IOException if writing fails
     */
    public void writeTo(DataOutput output) throws IOException {
        int version = this.myFactory == Primitive32Store.FACTORY ? 2 : 1;
        FileFormat.write(this, version, output);
    }

    /**
     * Writes this network to {@code file}; see {@link #writeTo(DataOutput)}.
     *
     * @param file destination file
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public void writeTo(File file) {
        try (DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            this.writeTo(output);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Writes this network to {@code path}; see {@link #writeTo(DataOutput)}.
     *
     * @param path destination path
     * @param options options passed to {@link Files#newOutputStream(Path, OpenOption...)}
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public void writeTo(Path path, OpenOption ... options) {
        try (DataOutputStream output = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(path, options)))) {
            this.writeTo(output);
        } catch (IOException cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Delegates a weight update to one layer. The layer receives the negated learning rate, the
     * configuration's keep-probability and its regularisation. Requires a configuration to be set
     * (dereferences {@code myConfiguration} without a null check).
     */
    void adjust(int layer, PhysicalStore<Double> input, PhysicalStore<Double> output, PhysicalStore<Double> upstreamGradient, PhysicalStore<Double> downstreamGradient) {
        this.myLayers[layer].adjust(input, output, upstreamGradient, downstreamGradient, -this.myConfiguration.learningRate, this.myConfiguration.probabilityDidKeepInput(layer), this.myConfiguration.regularisation());
    }

    /** @return number of input nodes of the first layer */
    int countInputNodes() {
        return this.myLayers[0].countInputNodes();
    }

    /** @return number of input nodes of the given layer */
    int countInputNodes(int layer) {
        return this.myLayers[layer].countInputNodes();
    }

    /** @return number of output nodes of the last layer */
    int countOutputNodes() {
        return this.myLayers[this.myLayers.length - 1].countOutputNodes();
    }

    /** @return number of output nodes of the given layer */
    int countOutputNodes(int layer) {
        return this.myLayers[layer].countOutputNodes();
    }

    /** @return the activator of the last layer */
    Activator getOutputActivator() {
        return this.myLayers[this.myLayers.length - 1].getActivator();
    }

    /** @return each layer's logical weights, in layer order */
    List<MatrixStore<Double>> getWeights() {
        List<MatrixStore<Double>> retVal = new ArrayList<>(this.myLayers.length);
        for (CalculationLayer layer : this.myLayers) {
            retVal.add(layer.getLogicalWeights());
        }
        return retVal;
    }

    /**
     * Evaluates one layer. When a training configuration is set, the layer also receives that
     * configuration's keep-probability for its output; otherwise the plain invocation is used.
     */
    PhysicalStore<Double> invoke(int layer, PhysicalStore<Double> input, PhysicalStore<Double> output) {
        if (this.myConfiguration != null) {
            return this.myLayers[layer].invoke(input, output, this.myConfiguration.probabilityWillKeepOutput(layer, this.depth()));
        }
        return this.myLayers[layer].invoke(input, output);
    }

    /** @return a new data batch created with this network's factory */
    DataBatch newBatch(int rows, int columns) {
        return DataBatch.from(this.myFactory, rows, columns);
    }

    /** @return a new matrix created with this network's factory */
    PhysicalStore<Double> newStore(int rows, int columns) {
        return this.myFactory.make(rows, columns);
    }

    /** Randomises every layer. */
    void randomise() {
        for (CalculationLayer layer : this.myLayers) {
            layer.randomise();
        }
    }

    /** Scales the given layer by {@code factor}. */
    void scale(int layer, double factor) {
        this.myLayers[layer].scale(factor);
    }

    void setActivator(int layer, Activator activator) {
        this.myLayers[layer].setActivator(activator);
    }

    void setBias(int layer, int output, double bias) {
        this.myLayers[layer].setBias(output, bias);
    }

    /**
     * Installs (or clears, with {@code null}) the training configuration. When a configuration that was
     * set is cleared, every layer except the first is scaled by that configuration's
     * {@code probabilityDidKeepInput(layer)}. Presumably this compensates for dropout used during
     * training (NOTE(review): confirm against TrainingConfiguration).
     */
    void setConfiguration(TrainingConfiguration configuration) {
        if (this.myConfiguration != null && configuration == null) {
            int limit = this.depth();
            for (int l = 1; l < limit; ++l) {
                this.scale(l, this.myConfiguration.probabilityDidKeepInput(l));
            }
        }
        this.myConfiguration = configuration;
    }

    void setWeight(int layer, int input, int output, double weight) {
        this.myLayers[layer].setWeight(input, output, weight);
    }

    /**
     * Error (loss) functions. Each has an element-wise function of {@code (target, current)} and a
     * derivative; both constants report {@code current - target} as the derivative.
     */
    public static enum Error implements PrimitiveFunction.Binary
    {
        /**
         * {@code -target * ln(current)}, with the convention {@code 0 * ln(0) = 0}. The guard is needed:
         * without it a zero target with a zero output evaluates to {@code -0.0 * -Infinity = NaN}.
         * The derivative {@code current - target} is presumably the combined softmax/cross-entropy
         * gradient (newTrainer pairs this with a SOFTMAX output layer).
         */
        CROSS_ENTROPY((target, current) -> target == PrimitiveMath.ZERO ? PrimitiveMath.ZERO : -target * Math.log(current), (target, current) -> current - target),
        /** {@code 0.5 * (target - current)^2}, derivative {@code current - target}. */
        HALF_SQUARED_DIFFERENCE((target, current) -> PrimitiveMath.HALF * (target - current) * (target - current), (target, current) -> current - target);

        private final PrimitiveFunction.Binary myDerivative;
        private final PrimitiveFunction.Binary myFunction;

        private Error(PrimitiveFunction.Binary function, PrimitiveFunction.Binary derivative) {
            this.myFunction = function;
            this.myDerivative = derivative;
        }

        /**
         * Sums the element-wise error over the first {@code min(target.count(), current.count())} elements.
         *
         * @param target expected values
         * @param current actual values
         * @return the total error
         */
        @Override
        public double invoke(Access1D<?> target, Access1D<?> current) {
            int limit = MissingMath.toMinIntExact(target.count(), current.count());
            double retVal = PrimitiveMath.ZERO;
            for (int i = 0; i < limit; ++i) {
                retVal += this.myFunction.invoke(target.doubleValue(i), current.doubleValue(i));
            }
            return retVal;
        }

        @Override
        public double invoke(double target, double current) {
            return this.myFunction.invoke(target, current);
        }

        PrimitiveFunction.Binary getDerivative() {
            return this.myDerivative;
        }
    }

    /**
     * Activation functions. Each one transforms a layer's output matrix in place and provides its
     * derivative expressed in terms of the activated output value.
     */
    public static enum Activator {
        /** f(x) = x; derivative 1. */
        IDENTITY(ArtificialNeuralNetwork::doIdentity, arg -> PrimitiveMath.ONE, true),
        /** f(x) = max(x, 0); derivative 1 for positive output, else 0. */
        RELU(ArtificialNeuralNetwork::doReLU, arg -> arg > PrimitiveMath.ZERO ? PrimitiveMath.ONE : PrimitiveMath.ZERO, true),
        /** Logistic function; derivative y * (1 - y). */
        SIGMOID(ArtificialNeuralNetwork::doSigmoid, arg -> arg * (PrimitiveMath.ONE - arg), true),
        /**
         * Row-wise softmax. Reports a derivative of 1, presumably because it is paired with
         * {@link Error#CROSS_ENTROPY}, whose derivative already is the combined gradient. It is the only
         * activator that is not single-folded.
         */
        SOFTMAX(ArtificialNeuralNetwork::doSoftMax, arg -> PrimitiveMath.ONE, false),
        /** Hyperbolic tangent; derivative 1 - y^2. */
        TANH(ArtificialNeuralNetwork::doTanh, arg -> PrimitiveMath.ONE - arg * arg, true);

        private final PrimitiveFunction.Unary myDerivativeInTermsOfOutput;
        private final Consumer<PhysicalStore<Double>> myFunction;
        private final boolean mySingleFolded;

        private Activator(Consumer<PhysicalStore<Double>> function, PrimitiveFunction.Unary derivativeInTermsOfOutput, boolean singleFolded) {
            this.myFunction = function;
            this.myDerivativeInTermsOfOutput = derivativeInTermsOfOutput;
            this.mySingleFolded = singleFolded;
        }

        /** Applies this activation to {@code output} in place. */
        void activate(PhysicalStore<Double> output) {
            this.myFunction.accept(output);
        }

        /**
         * Applies this activation to {@code output} in place, then applies {@link NodeDropper} with the
         * given keep-probability to every element.
         *
         * @throws IllegalArgumentException unless {@code 0 < probabilityToKeep <= 1}
         */
        void activate(PhysicalStore<Double> output, double probabilityToKeep) {
            if (PrimitiveMath.ZERO >= probabilityToKeep || probabilityToKeep > PrimitiveMath.ONE) {
                throw new IllegalArgumentException("probabilityToKeep must be in (0, 1]: " + probabilityToKeep);
            }
            this.myFunction.accept(output);
            output.modifyAll(NodeDropper.of(probabilityToKeep));
        }

        PrimitiveFunction.Unary getDerivativeInTermsOfOutput() {
            return this.myDerivativeInTermsOfOutput;
        }

        /** Presumably: whether the activation acts element-wise (false only for the row-wise SOFTMAX). */
        boolean isSingleFolded() {
            return this.mySingleFolded;
        }
    }
}

