package ca.nengo.model.nef.impl;

import ca.nengo.math.Function;
import ca.nengo.math.impl.AbstractFunction;
import ca.nengo.math.impl.ConstantFunction;
import ca.nengo.math.impl.GradientDescentApproximator;
import ca.nengo.math.impl.IdentityFunction;
import ca.nengo.math.impl.IndicatorPDF;
import ca.nengo.model.Node;
import ca.nengo.model.StructuralException;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.neuron.Neuron;
import ca.nengo.model.neuron.impl.LIFNeuronFactory;
import ca.nengo.model.neuron.impl.SpikingNeuron;
import ca.nengo.util.MU;
import ca.nengo.util.VectorGenerator;
import ca.nengo.util.impl.RandomHypersphereVG;
import ca.nengo.util.impl.Rectifier;

/* loaded from: input_file:ca/nengo/model/nef/impl/BiasOrigin.class */
public class BiasOrigin extends DecodedOrigin {
    private static final long serialVersionUID = 1;
    private NEFEnsemble myInterneurons;
    private float[][] myConstantOutputs;

    /* loaded from: input_file:ca/nengo/model/nef/impl/BiasOrigin$BiasEncodersMaintained.class */
    /**
     * Constraint applied to decoder coefficients during gradient descent. It
     * forces each coefficient to a sign chosen at construction time and caps it
     * against the ratio of the stored base weights to the stored bias encoders.
     */
    private static class BiasEncodersMaintained implements GradientDescentApproximator.Constraints {
        private static final long serialVersionUID = 1;
        private double[][] myBaseWeights;
        private double[] myBiasEncoders;
        private boolean myExcitatory;

        /**
         * @param fArr  base weights, indexed [bias encoder][coefficient] by {@link #correct(float[])};
         *              stored as doubles via {@code MU.convert}
         * @param fArr2 bias encoders, stored as doubles via {@code MU.convert}
         * @param z     true to keep coefficients non-negative, false to keep them non-positive
         */
        public BiasEncodersMaintained(float[][] fArr, float[] fArr2, boolean z) {
            this.myExcitatory = z;
            this.myBiasEncoders = MU.convert(fArr2);
            this.myBaseWeights = MU.convert(fArr);
        }

        /**
         * Adjusts the given coefficients in place.
         *
         * @param fArr coefficients to constrain; modified in place
         * @return true only if every coefficient had to be adjusted (vacuously true
         *         for an empty array), false as soon as one was left untouched
         */
        @Override
        public boolean correct(float[] fArr) {
            boolean allCorrected = true;
            for (int c = 0; c < fArr.length; c++) {
                boolean corrected = false;

                // Sign constraint: a wrong-signed value is replaced by the smallest
                // non-zero float magnitude carrying the required sign.
                if (this.myExcitatory) {
                    if (fArr[c] < 0.0f) {
                        fArr[c] = Float.MIN_VALUE;
                        corrected = true;
                    }
                } else if (fArr[c] > 0.0f) {
                    fArr[c] = -Float.MIN_VALUE;
                    corrected = true;
                }

                // Ratio constraint: checked against every bias encoder in turn, using
                // the coefficient as possibly updated by an earlier iteration.
                for (int e = 0; e < this.myBiasEncoders.length; e++) {
                    double weight = this.myBaseWeights[e][c];
                    double encoder = this.myBiasEncoders[e];
                    if (-weight / fArr[c] > encoder) {
                        fArr[c] = -((float) (weight / encoder));
                        corrected = true;
                    }
                }

                allCorrected &= corrected;
            }
            return allCorrected;
        }

        /**
         * Copies this constraint, duplicating both stored arrays.
         * Decompiler note: renamed from {@code clone} (merged with bridge method).
         */
        @Override // ca.nengo.math.impl.GradientDescentApproximator.Constraints
        public GradientDescentApproximator.Constraints m88clone() throws CloneNotSupportedException {
            BiasEncodersMaintained copy = (BiasEncodersMaintained) super.clone();
            copy.myBaseWeights = MU.clone(this.myBaseWeights);
            copy.myBiasEncoders = this.myBiasEncoders.clone();
            return copy;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ca/nengo/model/nef/impl/BiasOrigin$BiasedVG.class */
    /**
     * Vector generator that delegates to another generator and then shifts one
     * fixed component of every generated vector by a constant amount.
     */
    public static class BiasedVG implements VectorGenerator {
        private VectorGenerator myVG;
        private int myDim;
        private float myBias;

        /**
         * @param vectorGenerator generator that produces the unshifted vectors
         * @param i               index of the component to shift
         * @param f               amount added to that component
         */
        public BiasedVG(VectorGenerator vectorGenerator, int i, float f) {
            this.myBias = f;
            this.myDim = i;
            this.myVG = vectorGenerator;
        }

        /**
         * Obtains vectors from the wrapped generator (both arguments are forwarded
         * unchanged) and adds the bias to the configured component of each one.
         *
         * @return the wrapped generator's array, modified in place
         */
        @Override
        public float[][] genVectors(int i, int i2) {
            float[][] vectors = this.myVG.genVectors(i, i2);
            for (int v = 0; v < vectors.length; v++) {
                vectors[v][this.myDim] += this.myBias;
            }
            return vectors;
        }
    }

    /**
     * Builds a bias origin whose single decoded function is the constant zero and
     * whose initial decoders are uniform, then creates the companion interneuron
     * ensemble.
     *
     * @param nEFEnsemble forwarded to the {@code DecodedOrigin} constructor; its
     *                    dimension sizes the constant function
     * @param str         forwarded to the {@code DecodedOrigin} constructor; also the
     *                    prefix of the interneuron ensemble name
     * @param nodeArr     forwarded to the {@code DecodedOrigin} constructor
     * @param str2        forwarded to the {@code DecodedOrigin} constructor
     * @param fArr        constant outputs used to derive the uniform decoders; kept
     *                    by reference (not copied)
     * @param i           node count passed to the interneuron factory
     * @param z           selects the sign of the decoders and the interneuron parameters
     * @throws StructuralException if the superclass constructor or interneuron creation fails
     */
    public BiasOrigin(NEFEnsemble nEFEnsemble, String str, Node[] nodeArr, String str2, float[][] fArr, int i, boolean z) throws StructuralException {
        super(nEFEnsemble, str, nodeArr, str2, new Function[]{new ConstantFunction(nEFEnsemble.getDimension(), 0.0f)}, getUniformBiasDecoders(fArr, z));
        myConstantOutputs = fArr;
        myInterneurons = createInterneurons(str + ":interneurons", i, z);
    }

    /**
     * Re-fits this origin's decoders by gradient descent against a constant-zero
     * target, subject to a {@link BiasEncodersMaintained} constraint, starting from
     * the current decoders.
     *
     * @param fArr  base weights handed to the constraint
     * @param fArr2 bias encoders handed to the constraint
     * @param z     sign flag handed to the constraint
     */
    public void optimizeDecoders(float[][] fArr, float[] fArr2, boolean z) {
        // A single all-zero row, transposed, gives one evaluation point per column
        // of the constant outputs.
        float[][] evalPoints = MU.transpose(new float[][]{new float[this.myConstantOutputs[0].length]});
        float[][] values = MU.clone(this.myConstantOutputs);
        GradientDescentApproximator.Constraints constraints = new BiasEncodersMaintained(fArr, fArr2, z);

        GradientDescentApproximator approximator = new GradientDescentApproximator(evalPoints, values, constraints, true);
        // Start from the current decoders (first row of their transpose).
        approximator.setStartingCoefficients(MU.transpose(getDecoders())[0]);

        float[] coefficients = approximator.findCoefficients(new ConstantFunction(1, 0.0f));
        // The coefficients form one row; transpose back to one decoder row per entry.
        super.setDecoders(MU.transpose(new float[][]{coefficients}));
    }

    /**
     * Sets the static bias and the 1x1 transform of two terminations from this
     * origin's output range, after extending the lower bound of that range
     * downward by 40% of its width.
     *
     * @param decodedTermination  receives bias {@code -low} and transform {@code 1 / width}
     * @param decodedTermination2 receives its existing bias plus {@code low / width},
     *                            and transform {@code -width}
     * @throws RuntimeException wrapping the {@link StructuralException} if either
     *                          transform is rejected
     */
    public void optimizeInterneuronDomain(DecodedTermination decodedTermination, DecodedTermination decodedTermination2) {
        float[] range = getRange();
        // Extend the lower bound by 0.4 of the original width; width is then
        // measured from the extended bound.
        float low = range[0] - (0.4f * (range[1] - range[0]));
        float width = range[1] - low;

        decodedTermination.setStaticBias(new float[]{-low});
        decodedTermination2.setStaticBias(MU.sum(decodedTermination2.getStaticBias(), new float[]{low / width}));
        try {
            decodedTermination.setTransform(new float[][]{{1.0f / width}});
            decodedTermination2.setTransform(new float[][]{{-width}});
        } catch (StructuralException e) {
            throw new RuntimeException("Problem parameterizing termination", e);
        }
    }

    /**
     * Multiplies the transposed constant outputs by the first row of the
     * transposed decoders and reports the extremes of the result.
     *
     * @return a new two-element array {@code {min, max}}
     */
    public float[] getRange() {
        float[][] transposedOutputs = MU.transpose(this.myConstantOutputs);
        float[] firstDecoderRow = MU.transpose(getDecoders())[0];
        float[] products = MU.prod(transposedOutputs, firstDecoderRow);

        float[] range = new float[2];
        range[0] = MU.min(products);
        range[1] = MU.max(products);
        return range;
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [float[], float[][]] */
    private static float[][] getUniformBiasDecoders(float[][] fArr, boolean z) {
        ?? r0 = new float[fArr.length];
        float biasDecoder = getBiasDecoder(fArr, z);
        for (int i = 0; i < r0.length; i++) {
            float[] fArr2 = new float[1];
            fArr2[0] = biasDecoder;
            r0[i] = fArr2;
        }
        return r0;
    }

    /**
     * Computes a decoder value from the largest column sum of the given matrix.
     * The running maximum starts at zero, so if no column sum is positive the
     * division is by zero and the result is infinite.
     *
     * @param fArr matrix whose columns are summed over all rows; the column count
     *             is taken from the first row
     * @param z    true for {@code 1 / max}, false for {@code -1 / max}
     * @return the signed reciprocal of the largest column sum
     */
    private static float getBiasDecoder(float[][] fArr, boolean z) {
        final int columns = fArr[0].length;
        float largestSum = 0.0f;
        for (int col = 0; col < columns; col++) {
            float columnSum = 0.0f;
            for (int row = 0; row < fArr.length; row++) {
                columnSum += fArr[row][col];
            }
            if (columnSum > largestSum) {
                largestSum = columnSum;
            }
        }
        float numerator = z ? 1.0f : -1.0f;
        return numerator / largestSum;
    }

    /**
     * Builds the one-dimensional interneuron ensemble used with this origin.
     *
     * <p>The ensemble comes from a factory that adds no default origins. Its
     * encoder, evaluation-point, node and approximator factories are all
     * parameterized by {@code z}. After creation, the first few neurons (at most
     * 10) have their bias set to {@code 1 - scale}. A decoded origin named
     * {@code NEFEnsemble.X} is then added, and the decoders of those same neurons
     * are overwritten with a fixed value.
     *
     * @param str name given to the ensemble
     * @param i   node count passed to the factory's {@code make} call
     * @param z   true decodes the identity function; false decodes {@code 1 + x}.
     *            Also selects the evaluation-point bias and the two PDF ranges
     *            passed to the neuron factory.
     * @return the configured ensemble
     * @throws StructuralException if ensemble creation or origin setup fails
     */
    private NEFEnsemble createInterneurons(String str, int i, boolean z) throws StructuralException {
        // Declared as Function so it can be stored in the Function[] below.
        final Function decodedFunction;
        if (z) {
            decodedFunction = new IdentityFunction(1, 0);
        } else {
            decodedFunction = new AbstractFunction(1) { // from class: ca.nengo.model.nef.impl.BiasOrigin.1
                private static final long serialVersionUID = 1;

                @Override // ca.nengo.math.impl.AbstractFunction, ca.nengo.math.Function
                public float map(float[] fArr) {
                    return 1.0f + fArr[0];
                }
            };
        }

        NEFEnsembleFactoryImpl factory = new NEFEnsembleFactoryImpl() { // from class: ca.nengo.model.nef.impl.BiasOrigin.2
            @Override // ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl
            protected void addDefaultOrigins(NEFEnsemble nEFEnsemble) {
                // Intentionally empty: the only origin is added below, after the
                // neuron biases have been adjusted.
            }
        };
        factory.setEncoderFactory(new Rectifier(factory.getEncoderFactory(), true));
        factory.setEvalPointFactory(new BiasedVG(new RandomHypersphereVG(false, 0.5f, 0.0f), 0, z ? 0.5f : -0.5f));
        factory.setNodeFactory(new LIFNeuronFactory(0.02f, 1.0E-4f, z ? new IndicatorPDF(200.0f, 500.0f) : new IndicatorPDF(400.0f, 800.0f), z ? new IndicatorPDF(-0.15f, 0.9f) : new IndicatorPDF(-1.2f, 0.1f)));
        factory.setApproximatorFactory(new GradientDescentApproximator.Factory(new GradientDescentApproximator.CoefficientsSameSign(true), false));
        NEFEnsemble ensemble = factory.make(str, i, 1);

        // Adjust at most 10 neurons, clamped to the number actually available so
        // that small ensembles do not index past the end of the node array.
        Node[] nodes = ensemble.getNodes();
        final int adjusted = Math.min(10, nodes.length);
        for (int n = 0; n < adjusted; n++) {
            SpikingNeuron neuron = (SpikingNeuron) nodes[n];
            neuron.setBias(1.0f - neuron.getScale());
        }

        DecodedOrigin origin = (DecodedOrigin) ensemble.addDecodedOrigin(NEFEnsemble.X, new Function[]{decodedFunction}, Neuron.AXON);
        float[][] decoders = origin.getDecoders();
        final int overridden = Math.min(adjusted, decoders.length);
        for (int n = 0; n < overridden; n++) {
            // Same fixed value for each adjusted neuron; equals (1/10)/300 when
            // ten neurons were adjusted.
            decoders[n] = new float[]{(1.0f / adjusted) / 300.0f};
        }
        origin.setDecoders(decoders);
        return ensemble;
    }

    /**
     * @return the interneuron ensemble created by this origin's constructor
     */
    public NEFEnsemble getInterneurons() {
        return myInterneurons;
    }
}
