package org.joone.engine.extenders;

import org.joone.engine.RpropParameters;
import org.joone.log.ILogger;
import org.joone.log.LoggerFactory;

/* loaded from: input_file:org/joone/engine/extenders/RpropExtender.class */
/**
 * Learner extender implementing the Rprop (resilient propagation) update rule.
 *
 * <p>Each bias or weight keeps its own step size in {@link #theDeltas}. Incoming deltas are
 * accumulated into {@link #theSummedGradients}. On cycles where the update-weight extender
 * reports that weights/biases are stored, the step size is adapted as follows:
 * <ul>
 *   <li>If the accumulated gradient has the same sign as at the last stored update, the step
 *       is multiplied by {@code etaInc} (capped at {@code maxDelta}) and the move is
 *       {@code -sign(gradient) * step}.</li>
 *   <li>If the sign flipped, the step is multiplied by {@code etaDec} (floored at
 *       {@code minDelta}) and the previously stored change is reverted.</li>
 *   <li>Otherwise the move is {@code -sign(gradient) * step}.</li>
 * </ul>
 *
 * <p>A warning is logged when the monitor's learning rate is not exactly 1.
 */
public class RpropExtender extends DeltaRuleExtender {
    private static final ILogger log = LoggerFactory.getLogger(RpropExtender.class);

    /** Per-cell step sizes: [rows][1] for a layer's biases, [input][output] for a synapse. */
    protected double[][] theDeltas;

    /**
     * Accumulated gradient at the last stored update. It is reset to 0 after a sign flip,
     * so the next stored update takes the neutral branch.
     */
    protected double[][] thePreviousGradients;

    /** Rprop parameters; lazily created with defaults by {@link #getParameters()}. */
    protected RpropParameters theRpropParameters;

    /** Gradients accumulated since the last stored update; cleared after each stored update. */
    protected double[][] theSummedGradients;

    /**
     * (Re)allocates the Rprop state to match the learner's layer (one cell per row) or
     * synapse (one cell per input/output pair).
     *
     * <p>Previous and summed gradients start at zero. Every step size is set to
     * {@code getParameters().getInitialDelta(i, j)}. If the learner has neither a layer nor a
     * synapse, any previously allocated arrays are kept and only their step sizes are reset.
     *
     * @throws IllegalStateException if the learner has neither a layer nor a synapse and no
     *     state has been allocated yet
     */
    public void reinit() {
        if (getLearner().getMonitor().getLearningRate() != 1.0d) {
            log.warn("RPROP learning rate should be equal to 1.");
        }
        if (getLearner().getLayer() != null) {
            allocateState(getLearner().getLayer().getRows(), 1);
        } else if (getLearner().getSynapse() != null) {
            allocateState(getLearner().getSynapse().getInputDimension(),
                    getLearner().getSynapse().getOutputDimension());
        } else if (this.theDeltas == null) {
            // Previously this fell through to a bare NullPointerException on theDeltas.length.
            throw new IllegalStateException(
                    "RpropExtender.reinit(): learner has neither a layer nor a synapse");
        }
        RpropParameters params = getParameters();
        for (int i = 0; i < this.theDeltas.length; i++) {
            for (int j = 0; j < this.theDeltas[i].length; j++) {
                this.theDeltas[i][j] = params.getInitialDelta(i, j);
            }
        }
    }

    /** Allocates zero-filled previous-gradient, summed-gradient and step-size arrays. */
    private void allocateState(int rows, int cols) {
        this.thePreviousGradients = new double[rows][cols];
        this.theSummedGradients = new double[rows][cols];
        this.theDeltas = new double[rows][cols];
    }

    /**
     * Accumulates the gradient for bias {@code j}. On storing cycles, returns the Rprop bias
     * change.
     *
     * @param currentGradientOuts not used by Rprop
     * @param j bias (row) index
     * @param aPreviousDelta delta produced earlier in the extender chain; it is subtracted
     *     from the summed gradient. Presumably this is the learning-rate-scaled step, hence
     *     the "learning rate should be 1" warning (confirm).
     * @return the bias change to apply, or 0 when biases are not stored this cycle
     */
    @Override // org.joone.engine.extenders.DeltaRuleExtender
    public double getDelta(double[] currentGradientOuts, int j, double aPreviousDelta) {
        return rpropStep(j, 0, aPreviousDelta, true);
    }

    /**
     * Accumulates the gradient for weight [j][k]. On storing cycles, returns the Rprop weight
     * change.
     *
     * @param currentInps not used by Rprop
     * @param j input index
     * @param currentPattern not used by Rprop
     * @param k output index
     * @param aPreviousDelta delta produced earlier in the extender chain; it is subtracted
     *     from the summed gradient
     * @return the weight change to apply, or 0 when weights are not stored this cycle
     */
    @Override // org.joone.engine.extenders.DeltaRuleExtender
    public double getDelta(double[] currentInps, int j, double[] currentPattern, int k,
            double aPreviousDelta) {
        return rpropStep(j, k, aPreviousDelta, false);
    }

    /**
     * Shared Rprop update for cell [row][col].
     *
     * @param isBias selects where a sign flip reads the change to revert: the layer's bias
     *     {@code delta} matrix (true) or the synapse's weight {@code delta} matrix (false)
     */
    private double rpropStep(int row, int col, double aPreviousDelta, boolean isBias) {
        this.theSummedGradients[row][col] -= aPreviousDelta;
        if (!getLearner().getUpdateWeightExtender().storeWeightsBiases()) {
            // Not a storing cycle: keep accumulating and apply no change yet.
            return 0.0d;
        }
        double gradient = this.theSummedGradients[row][col];
        double signProduct = this.thePreviousGradients[row][col] * gradient;
        double change;
        if (signProduct > 0.0d) {
            // Same sign as the last stored update: scale the step by etaInc, capped at maxDelta.
            this.theDeltas[row][col] = Math.min(
                    this.theDeltas[row][col] * getParameters().getEtaInc(),
                    getParameters().getMaxDelta());
            change = -sign(gradient) * this.theDeltas[row][col];
            this.thePreviousGradients[row][col] = gradient;
        } else if (signProduct < 0.0d) {
            // Sign flipped: scale the step by etaDec, floored at minDelta, and revert the
            // stored change (presumably the last applied update).
            this.theDeltas[row][col] = Math.max(
                    this.theDeltas[row][col] * getParameters().getEtaDec(),
                    getParameters().getMinDelta());
            double previousChange = isBias
                    ? getLearner().getLayer().getBias().delta[row][col]
                    : getLearner().getSynapse().getWeights().delta[row][col];
            change = -previousChange;
            // Zero makes the next stored update take the neutral branch below.
            this.thePreviousGradients[row][col] = 0.0d;
        } else {
            // Neutral case (first update, after a flip, or zero gradient).
            change = -sign(gradient) * this.theDeltas[row][col];
            this.thePreviousGradients[row][col] = gradient;
        }
        this.theSummedGradients[row][col] = 0.0d;
        return change;
    }

    /** Rprop needs no work after a bias update. */
    @Override // org.joone.engine.extenders.LearnerExtender
    public void postBiasUpdate(double[] currentGradientOuts) {
    }

    /** Rprop needs no work after a weight update. */
    @Override // org.joone.engine.extenders.LearnerExtender
    public void postWeightUpdate(double[] currentPattern, double[] currentInps) {
    }

    /** Reinitialises the state when it is missing or its row count no longer matches the layer. */
    @Override // org.joone.engine.extenders.LearnerExtender
    public void preBiasUpdate(double[] currentGradientOuts) {
        if (this.theDeltas == null || this.theDeltas.length != getLearner().getLayer().getRows()) {
            reinit();
        }
    }

    /** Reinitialises the state when it is missing or its shape no longer matches the synapse. */
    @Override // org.joone.engine.extenders.LearnerExtender
    public void preWeightUpdate(double[] currentPattern, double[] currentInps) {
        if (this.theDeltas != null
                && this.theDeltas.length == getLearner().getSynapse().getInputDimension()
                // An empty outer array has no row 0; the old check threw AIOOBE here.
                && (this.theDeltas.length == 0
                        || this.theDeltas[0].length == getLearner().getSynapse().getOutputDimension())) {
            return;
        }
        reinit();
    }

    /**
     * Returns the Rprop parameters, creating a default {@link RpropParameters} on first use.
     *
     * @return the parameters in use (never null)
     */
    public RpropParameters getParameters() {
        if (this.theRpropParameters == null) {
            this.theRpropParameters = new RpropParameters();
        }
        return this.theRpropParameters;
    }

    /**
     * Replaces the Rprop parameters.
     *
     * @param rpropParameters the new parameters; null makes {@link #getParameters()} recreate
     *     the defaults on its next call
     */
    public void setParameters(RpropParameters rpropParameters) {
        this.theRpropParameters = rpropParameters;
    }

    /**
     * Returns the sign of {@code d}.
     *
     * @return 1.0 if {@code d} is positive, -1.0 if negative, otherwise 0.0 (also for NaN)
     */
    protected double sign(double d) {
        if (d > 0.0d) {
            return 1.0d;
        }
        return d < 0.0d ? -1.0d : 0.0d;
    }
}
