package seed.minerva;

import java.lang.reflect.Field;
import java.util.Arrays;
import oneLiners.OneLiners;
import seed.matrix.CholeskyDecomposition;
import seed.matrix.DenseMatrix;
import seed.matrix.DiagonalMatrix;
import seed.matrix.LUDecomposition;
import seed.matrix.LowerSymmetricDenseMatrix;
import seed.matrix.Mat;
import seed.matrix.Matrix;
import seed.minerva.nodetypes.DoubleArray;
import seed.minerva.nodetypes.DoubleMatrix;
import seed.minerva.nodetypes.DoubleValue;

/* loaded from: input_file:seed/minerva/MultivariateNormalOld.class */
/**
 * Legacy multivariate normal (Gaussian) probability node.
 *
 * <p>The mean comes from the {@code "mean"} connection. The covariance comes from one of:
 * <ul>
 *   <li>the {@code "cov"} connection, backed by the {@link #cov} (full matrix), {@link #covdiag}
 *       (diagonal) or {@link #var} (one variance for every dimension) field — presumably chosen by the
 *       type of the connected node;</li>
 *   <li>the {@code "invcov"} connection, a full inverse covariance (precision) matrix.</li>
 * </ul>
 *
 * <p>Derived quantities (inverse matrices, Cholesky factor, log-determinant, SPD flag) are refreshed in
 * {@link #updateState()} or computed lazily on first use and cached in fields.
 */
public class MultivariateNormalOld extends Multivariate {
    /** Connected mean vector; {@code null} when no mean is connected. */
    DoubleArray mean;
    /** Connected full covariance matrix (one alternative behind the "cov" connection). */
    DoubleMatrix cov;
    /** Connected covariance diagonal (one alternative behind the "cov" connection). */
    DoubleArray covdiag;
    /** Connected single variance used for every dimension (one alternative behind the "cov" connection). */
    DoubleValue var;
    /** Connected full inverse covariance matrix. */
    DoubleMatrix invcov;
    public static final String MEAN = "mean";
    public static final String COV = "cov";
    public static final String INVCOV = "invcov";

    /**
     * Value stored in {@link #logAbsCovDeterminant} when the determinant was not evaluated (normalisation
     * disabled, non-SPD inverse covariance, LU failure) or came out infinite.
     * NOTE(review): {@link #logPdf(double[])} still adds {@code -0.5 * logAbsCovDeterminant} whenever
     * {@link #spd} is true, so in those cases the log density is offset by a constant.
     */
    private static final double LOG_DET_PLACEHOLDER = 1.0d;

    /** Off-diagonal entries whose magnitude does not exceed this are treated as zero by {@link #isDiagonal()}. */
    private static final double DIAGONAL_TOLERANCE = 1.0E-24d;

    /** Mean as a 1 x n row matrix (enable mask applied), or {@code null} when no mean is connected. */
    Matrix meanMatrix;
    /** Covariance matrix; set from {@link #cov} or computed lazily by inverting {@link #covInvMatrix}. */
    Matrix covMatrix;
    /** Inverse covariance; set from {@link #invcov} or computed lazily by inverting {@link #covMatrix}. */
    Matrix covInvMatrix;
    /** Lower Cholesky factor of {@link #covMatrix}, used by {@link #sample()}. */
    Matrix choleskyLower;
    /** Covariance diagonal when {@link #covdiag} or {@link #var} is connected. */
    double[] covdiagArray;
    /** Element-wise reciprocal of {@link #covdiagArray}. */
    double[] covinvdiagArray;
    /** log |det(covariance)|, or {@link #LOG_DET_PLACEHOLDER} when it was not evaluated. */
    double logAbsCovDeterminant;
    /** Not used within this class. */
    double absCovDeterminant;
    /** Whether the covariance (or inverse covariance) passed the Cholesky SPD test. */
    boolean spd;
    /** Whether to spend an LU decomposition on the determinant of a full (inverse) covariance. */
    boolean tryToNormalise;
    /** Initialised to {@code true} in the constructor; not otherwise used within this class. */
    boolean normalised;

    public MultivariateNormalOld() {
        this(null, "", null, null, null, 2);
    }

    public MultivariateNormalOld(String str) {
        this(null, str, null, null, null, 2);
    }

    public MultivariateNormalOld(Graph graph, String str) {
        this(graph, str, null, null, null, 2);
    }

    /**
     * Creates the node, registers its connection points and optionally wires it up.
     *
     * @param graph        graph to add this node to, or {@code null}
     * @param name         node name passed to the superclass
     * @param meanNode     node to connect to "mean", or {@code null}
     * @param covNode      node to connect to "cov", or {@code null}
     * @param initialValue initial value passed to {@code setDoubleArray}, or {@code null}
     * @param observedFlag the node is marked observed only when this equals 1 (the shorter
     *                     constructors pass 2)
     */
    public MultivariateNormalOld(Graph graph, String name, Node meanNode, Node covNode, double[] initialValue, int observedFlag) {
        super(name);
        this.tryToNormalise = true;
        this.normalised = true;
        addConnectionPoint(new ConnectionPoint(MEAN, DoubleArray.class, false, getField("mean")));
        addConnectionPoint(new ConnectionPoint(COV,
                new Class[]{DoubleMatrix.class, DoubleArray.class, DoubleValue.class}, true,
                new Field[]{getField("cov"), getField("covdiag"), getField("var")}));
        addConnectionPoint(new ConnectionPoint(INVCOV, DoubleMatrix.class, true, getField("invcov")));
        if (graph != null) {
            graph.addNode(this);
        }
        if (meanNode != null) {
            setConnection(MEAN, meanNode);
        }
        if (covNode != null) {
            setConnection(COV, covNode);
        }
        if (observedFlag == 1) {
            setObserved(true);
        }
        if (initialValue != null) {
            setDoubleArray(initialValue);
        }
    }

    /**
     * Log density at {@code x}.
     *
     * <p>When the covariance is SPD, the Gaussian normalisation
     * {@code -0.5 * log|C| - (n/2) * log(2 pi)} is added to the exponent. Otherwise only the
     * unnormalised exponent is returned.
     */
    @Override // seed.minerva.Multivariate
    public double logPdf(double[] x) {
        update();
        double quadratic = exponent(x);
        if (!this.spd) {
            return quadratic;
        }
        return -0.5d * this.logAbsCovDeterminant - (x.length / 2.0d) * Math.log(2.0d * Math.PI) + quadratic;
    }

    /**
     * Returns whether the covariance is diagonal.
     *
     * <p>A diagonal or scalar-variance input is always diagonal. For full (inverse) matrices, the strictly
     * lower triangle of the raw connected matrix is compared against {@link #DIAGONAL_TOLERANCE}.
     */
    public boolean isDiagonal() {
        update();
        if (this.covdiag != null || this.var != null) {
            return true;
        }
        double[][] matrix = this.cov != null ? this.cov.getDoubleMatrix() : this.invcov.getDoubleMatrix();
        for (int row = 0; row < matrix.length; row++) {
            for (int col = 0; col < row; col++) {
                if (Math.abs(matrix[row][col]) > DIAGONAL_TOLERANCE) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Computes {@code -0.5 * (x - mu)^T C^-1 (x - mu)}.
     *
     * <p>Lazily inverts and caches {@link #covMatrix} if only the covariance is known. Returns 0 when no
     * covariance information is available.
     */
    private double exponent(double[] x) {
        if (this.var != null || this.covdiag != null) {
            double[] mu = this.meanMatrix.toArray()[0];
            double sum = 0.0d;
            for (int i = 0; i < x.length; i++) {
                sum += this.covinvdiagArray[i] * Math.pow(x[i] - mu[i], 2.0d);
            }
            return -0.5d * sum;
        }
        if (this.covInvMatrix == null && this.covMatrix != null) {
            this.covInvMatrix = invertWithDebugLog(this.covMatrix, "covariance matrix");
        }
        if (this.covInvMatrix == null) {
            return 0.0d;
        }
        Matrix residual = new DenseMatrix(new double[][]{x}).minus(this.meanMatrix);
        return -0.5d * residual.times(this.covInvMatrix).times(residual.transpose()).get(0, 0);
    }

    /** Returns {@code -2 * exponent} evaluated at the node's current value. */
    public double chi2() {
        update();
        return -2.0d * exponent(this.value);
    }

    /**
     * Returns the mean squared normalised residual {@code sum_i p_i (v_i - mu_i)^2 / n} at the current value.
     *
     * @throws MinervaRuntimeException if the covariance is not diagonal
     */
    public double chi2Normalised() {
        update();
        if (!isDiagonal()) {
            throw new MinervaRuntimeException("Covariance matrix must be diagonal for chi2 to be calculated");
        }
        double[] mu = this.meanMatrix.toArray()[0];
        double[] precision = diagonalPrecisionOfDiagonalCov(mu.length);
        double sum = 0.0d;
        for (int i = 0; i < precision.length; i++) {
            double residual = this.value[i] - mu[i];
            sum += precision[i] * residual * residual;
        }
        return sum / mu.length;
    }

    /**
     * Returns the per-dimension normalised residuals {@code sqrt(p_i) * (v_i - mu_i)} at the current value.
     *
     * @throws MinervaRuntimeException if the covariance is not diagonal
     */
    public double[] diffNormalised() {
        update();
        if (!isDiagonal()) {
            throw new MinervaRuntimeException("Covariance matrix must be diagonal for normalised differences to be calculated");
        }
        double[] mu = this.meanMatrix.toArray()[0];
        double[] precision = diagonalPrecisionOfDiagonalCov(mu.length);
        double[] normalised = new double[mu.length];
        for (int i = 0; i < precision.length; i++) {
            normalised[i] = Math.sqrt(precision[i]) * (this.value[i] - mu[i]);
        }
        return normalised;
    }

    /**
     * Per-dimension precision for a covariance already known to be diagonal.
     *
     * <p>Uses {@code 1 / cov[i][i]} for a full covariance and {@code invcov[i][i]} for a full inverse
     * covariance. Otherwise returns the cached {@link #covinvdiagArray}.
     */
    private double[] diagonalPrecisionOfDiagonalCov(int n) {
        if (this.cov != null) {
            double[] precision = new double[n];
            for (int i = 0; i < n; i++) {
                precision[i] = 1.0d / this.covMatrix.get(i, i);
            }
            return precision;
        }
        if (this.invcov != null) {
            double[] precision = new double[n];
            for (int i = 0; i < n; i++) {
                precision[i] = this.covInvMatrix.get(i, i);
            }
            return precision;
        }
        return this.covinvdiagArray;
    }

    /**
     * Draws one sample as {@code mu + L z}, where {@code z} is standard normal and {@code L} is the
     * (diagonal or Cholesky) square root of the covariance.
     *
     * @throws RuntimeException if the covariance is not symmetric positive definite
     */
    @Override // seed.minerva.Multivariate
    public double[] sample() {
        update();
        if (!this.spd) {
            throw new RuntimeException("Covariance matrix of multivariate normal not (symmetric and positive definite)");
        }
        int n = dim();
        double[] standardNormal = new double[n];
        for (int i = 0; i < n; i++) {
            standardNormal[i] = RandomManager.instance().nextNormal(0.0d, 1.0d);
        }
        double[] draw = new double[n];
        if (this.cov == null && this.invcov == null) {
            for (int i = 0; i < n; i++) {
                draw[i] = this.meanMatrix.get(0, i) + Math.sqrt(this.covdiagArray[i]) * standardNormal[i];
            }
            return draw;
        }
        if (this.covMatrix == null) {
            this.covMatrix = invertWithDebugLog(this.covInvMatrix, "inverse covariance matrix");
        }
        if (this.choleskyLower == null) {
            this.choleskyLower = new CholeskyDecomposition(this.covMatrix).getL();
        }
        double[][] lower = this.choleskyLower.toArray();
        for (int i = 0; i < n; i++) {
            draw[i] = this.meanMatrix.get(0, i);
            for (int j = 0; j <= i; j++) {
                draw[i] += lower[i][j] * standardNormal[j];
            }
        }
        return draw;
    }

    /** Returns the length of the node's value array, or 0 when it has none. */
    @Override // seed.minerva.ProbabilityNode
    public int dim() {
        update();
        return this.value == null ? 0 : this.value.length;
    }

    /** Returns the (enable-masked) mean, or {@code null} when no mean is connected. */
    @Override // seed.minerva.ProbabilityNode
    public double[] mean() {
        update();
        return this.meanMatrix == null ? null : this.meanMatrix.getFlatArray();
    }

    /** Returns the per-dimension standard deviation {@code sqrt(C[i][i])}. */
    @Override // seed.minerva.ProbabilityNode
    public double[] sigma() {
        update();
        double[] sigma = new double[this.meanMatrix.getFlatArray().length];
        if (this.cov != null || this.invcov != null) {
            if (this.covMatrix == null) {
                this.covMatrix = invertWithDebugLog(this.covInvMatrix, "inverse covariance matrix");
            }
            for (int i = 0; i < sigma.length; i++) {
                sigma[i] = Math.sqrt(this.covMatrix.get(i, i));
            }
        } else if (this.covdiag != null || this.var != null) {
            for (int i = 0; i < sigma.length; i++) {
                sigma[i] = Math.sqrt(this.covdiagArray[i]);
            }
        }
        return sigma;
    }

    /** Like {@link #sigma()}, but expanded back to the unmasked layout when an enable mask is connected. */
    public double[] sigmaRaw() {
        update();
        return this.enable == null ? sigma() : unshuffleArray(sigma(), this.enableIndicies, this.enable.getBooleanArray().length, this.disableReplacementValue);
    }

    /** Returns the covariance matrix, building and caching it from the diagonal or the inverse if needed. */
    public double[][] cov() {
        update();
        if (this.covMatrix == null) {
            if (this.covdiag != null || this.var != null) {
                this.covMatrix = new DiagonalMatrix(this.covdiagArray);
            } else {
                this.covMatrix = invertWithDebugLog(this.covInvMatrix, "inverse covariance matrix");
            }
        }
        return this.covMatrix.toArray();
    }

    /** Like {@link #cov()}, but expanded back to the unmasked layout when an enable mask is connected. */
    public double[][] covRaw() {
        update();
        double[][] covariance = cov();
        return this.enable == null ? covariance : unshuffleSquareMatrix(covariance, this.enableIndicies, this.enable.getBooleanArray().length, this.disableReplacementValue);
    }

    /** Like {@link #mean()}, but expanded back to the unmasked layout when an enable mask is connected. */
    public double[] meanRaw() {
        update();
        return this.enable == null ? this.meanMatrix.getFlatArray() : unshuffleArray(this.meanMatrix.getFlatArray(), this.enableIndicies, this.enable.getBooleanArray().length, this.disableReplacementValue);
    }

    /** Returns the inverse covariance, building and caching it from the diagonal or the covariance if needed. */
    public double[][] covInv() {
        update();
        if (this.covInvMatrix == null) {
            if (this.covdiag != null || this.var != null) {
                this.covInvMatrix = new DiagonalMatrix(Mat.divide(1.0d, this.covdiagArray));
            } else {
                this.covInvMatrix = invertWithDebugLog(this.covMatrix, "covariance matrix");
            }
        }
        return this.covInvMatrix.toArray();
    }

    /**
     * Returns the diagonal of the inverse covariance.
     *
     * <p>For diagonal or scalar-variance inputs this is the cached reciprocal array. For full matrices the
     * inverse is computed lazily if it is not cached yet. Previously this case could throw a
     * NullPointerException, and invcov mode returned a null or stale array.
     */
    public double[] covInvDiag() {
        update();
        if (this.cov == null && this.invcov == null) {
            return this.covinvdiagArray;
        }
        if (this.covInvMatrix == null) {
            this.covInvMatrix = invertWithDebugLog(this.covMatrix, "covariance matrix");
        }
        int n = dim();
        double[] diagonal = new double[n];
        for (int i = 0; i < n; i++) {
            diagonal[i] = this.covInvMatrix.get(i, i);
        }
        return diagonal;
    }

    /**
     * Returns the diagonal of the covariance.
     *
     * <p>For diagonal or scalar-variance inputs this is the cached array. For full matrices the covariance
     * is obtained by inverting {@link #covInvMatrix} if needed. Previously invcov mode returned a null or
     * stale array.
     */
    public double[] covDiag() {
        update();
        if (this.cov == null && this.invcov == null) {
            return this.covdiagArray;
        }
        if (this.covMatrix == null) {
            this.covMatrix = invertWithDebugLog(this.covInvMatrix, "inverse covariance matrix");
        }
        int n = dim();
        double[] diagonal = new double[n];
        for (int i = 0; i < n; i++) {
            diagonal[i] = this.covMatrix.get(i, i);
        }
        return diagonal;
    }

    /** Log density at the node's current value. */
    @Override // seed.minerva.ProbabilityNode
    public double logpdf() {
        update();
        return logPdf(this.value);
    }

    public void setTryToNormalise(boolean z) {
        this.tryToNormalise = z;
    }

    public boolean getTryToNormalise() {
        return this.tryToNormalise;
    }

    /**
     * Refreshes the cached mean, covariance and inverse-covariance state from the connected inputs whose
     * ancestors changed.
     *
     * <p>A scalar variance is expanded to the mean's length, so it is also rebuilt when the mean changes.
     */
    @Override // seed.minerva.Multivariate, seed.minerva.StateFull
    public void updateState() {
        super.updateState();
        boolean enableChanged = isAncestorChanged(Multivariate.ENABLE);
        boolean meanChanged = isAncestorChanged(MEAN);
        if (meanChanged || enableChanged) {
            refreshMeanMatrix();
        }
        boolean covarianceInputConnected = this.cov != null || this.covdiag != null || this.var != null;
        if (isAncestorChanged(COV)
                || (enableChanged && covarianceInputConnected)
                || (meanChanged && this.var != null)) {
            refreshCovarianceState();
        }
        if (isAncestorChanged(INVCOV) || (enableChanged && this.invcov != null)) {
            refreshInverseCovarianceState();
        }
    }

    /** Rebuilds {@link #meanMatrix} from {@link #mean} and resizes the value array to match. */
    private void refreshMeanMatrix() {
        if (this.mean == null) {
            this.meanMatrix = null;
            return;
        }
        double[] mu = this.mean.getDoubleArray();
        if (this.enableIndicies != null && mu.length != this.enableIndicies.length) {
            mu = shuffleArray(mu, this.enableIndicies, null);
        }
        this.meanMatrix = new DenseMatrix(new double[][]{mu});
        if (this.value == null || this.value.length != mu.length) {
            this.value = new double[mu.length];
        }
    }

    /** Rebuilds covariance caches, log-determinant and SPD/Cholesky state from whichever "cov" input is set. */
    private void refreshCovarianceState() {
        if (this.cov != null) {
            double[][] raw = this.cov.getDoubleMatrix();
            double[][] masked = (this.enableIndicies == null || raw.length == this.enableIndicies.length)
                    ? OneLiners.copyArray(raw)
                    : shuffleSquareMatrix(raw, this.enableIndicies, null);
            double[][] previous = this.covMatrix == null ? null : this.covMatrix.toArray();
            // Skip the (expensive) determinant when the covariance is unchanged.
            if (!Arrays.deepEquals(masked, previous)) {
                this.covMatrix = new LowerSymmetricDenseMatrix(masked);
                this.covInvMatrix = null;
                boolean verbose = MinervaSettings.getDbgLevel() > 2;
                this.logAbsCovDeterminant = LOG_DET_PLACEHOLDER;
                if (this.tryToNormalise) {
                    if (verbose) {
                        System.out.print("MVN " + getPath() + ": Evaluating covariance matrix: LU decomposition... ");
                    }
                    this.logAbsCovDeterminant = Math.log(Math.abs(new LUDecomposition(this.covMatrix).det()));
                    if (Double.isInfinite(this.logAbsCovDeterminant)) {
                        this.logAbsCovDeterminant = LOG_DET_PLACEHOLDER;
                    }
                    if (verbose) {
                        System.out.println("Done.");
                    }
                }
            }
        } else if (this.covdiag != null) {
            this.covdiagArray = this.covdiag.getDoubleArray();
            if (this.enableIndicies != null && this.covdiagArray.length != this.enableIndicies.length) {
                this.covdiagArray = shuffleArray(this.covdiagArray, this.enableIndicies, null);
            }
            this.covinvdiagArray = new double[this.covdiagArray.length];
            this.logAbsCovDeterminant = 0.0d;
            for (int i = 0; i < this.covdiagArray.length; i++) {
                this.covinvdiagArray[i] = 1.0d / this.covdiagArray[i];
                this.logAbsCovDeterminant += Math.log(this.covdiagArray[i]);
            }
            this.covMatrix = null;
            this.covInvMatrix = null;
        } else if (this.var != null) {
            if (this.mean == null) {
                throw new RuntimeException("If single variance is used for matrix, mean must first be given.");
            }
            double variance = this.var.getDouble();
            int n = this.meanMatrix.toArray()[0].length;
            this.covdiagArray = new double[n];
            this.covinvdiagArray = new double[n];
            for (int i = 0; i < n; i++) {
                this.covdiagArray[i] = variance;
                this.covinvdiagArray[i] = 1.0d / variance;
            }
            this.logAbsCovDeterminant = n * Math.log(variance);
            this.covMatrix = null;
            this.covInvMatrix = null;
        }
        if (this.cov != null) {
            CholeskyDecomposition cholesky = new CholeskyDecomposition(this.covMatrix);
            this.spd = cholesky.isSPD();
            this.choleskyLower = cholesky.getL();
        } else {
            // Diagonal / scalar-variance inputs are assumed positive definite.
            this.spd = true;
            this.choleskyLower = null;
        }
    }

    /** Rebuilds the inverse-covariance caches, SPD flag and log-determinant from {@link #invcov}. */
    private void refreshInverseCovarianceState() {
        double[][] raw = this.invcov.getDoubleMatrix();
        double[][] masked = (this.enableIndicies == null || raw.length == this.enableIndicies.length)
                ? OneLiners.copyArray(raw)
                : shuffleSquareMatrix(raw, this.enableIndicies, null);
        double[][] previous = this.covInvMatrix == null ? null : this.covInvMatrix.toArray();
        if (Arrays.deepEquals(masked, previous)) {
            return;
        }
        this.covInvMatrix = new LowerSymmetricDenseMatrix(masked);
        // The covariance and its Cholesky factor are rebuilt lazily from the new inverse.
        this.covMatrix = null;
        this.choleskyLower = null;
        this.spd = new CholeskyDecomposition(this.covInvMatrix).isSPD();
        if (!this.spd || !this.tryToNormalise) {
            this.logAbsCovDeterminant = LOG_DET_PLACEHOLDER;
            return;
        }
        try {
            this.logAbsCovDeterminant = Math.log(Math.abs(1.0d / new LUDecomposition(this.covInvMatrix).det()));
            if (Double.isInfinite(this.logAbsCovDeterminant)) {
                this.logAbsCovDeterminant = LOG_DET_PLACEHOLDER;
            }
        } catch (Exception ignored) {
            // Best effort: normalisation is optional, so an LU failure only falls back to the placeholder.
            this.logAbsCovDeterminant = LOG_DET_PLACEHOLDER;
        }
    }

    /**
     * Inverts {@code source} with {@code Mat.inv}, printing a progress message at debug level > 2.
     *
     * @param source      matrix to invert
     * @param description text inserted after "Inverting " in the debug message
     */
    private Matrix invertWithDebugLog(Matrix source, String description) {
        boolean verbose = MinervaSettings.getDbgLevel() > 2;
        if (verbose) {
            System.out.print("MVN " + getPath() + ": Inverting " + description + "..");
        }
        Matrix inverse = Mat.inv(source);
        if (verbose) {
            System.out.println("Done");
        }
        return inverse;
    }

    /** Returns whether the density is normalised, i.e. whether the covariance passed the SPD test. */
    @Override // seed.minerva.ProbabilityNode
    public boolean isNormalised() {
        update();
        return this.spd;
    }

    /** Returns whether the covariance (or inverse covariance) passed the Cholesky SPD test. */
    public boolean isCovSPD() {
        update();
        return this.spd;
    }
}
