package seed.minerva;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import seed.matrix.BlockSparseMatrix;
import seed.matrix.CholeskyDecomposition;
import seed.matrix.DenseMatrix;
import seed.matrix.DiagonalMatrix;
import seed.matrix.LowerSymmetricDenseMatrix;
import seed.matrix.Mat;
import seed.matrix.Matrix;
import seed.matrix.RowwiseSparseMatrix;
import seed.minerva.support.EvecAlignedTGaussGibbs;

/* loaded from: input_file:seed/minerva/LinearGaussianInversion.class */
public class LinearGaussianInversion extends AbstractInversion {
    /** Observed nodes (the data), obtained from the model's {@code getObservedNodes()}. */
    List<ProbabilityNode> observations;
    /**
     * Unobserved nodes being inferred, obtained from the model's {@code getUnobservedNodes()}
     * and then filtered by {@link #checkModel()} down to the normally distributed ones.
     */
    List<ProbabilityNode> parameters;
    /** Total dimension of all {@link #parameters}, as computed by {@link #dim()}. */
    int dim;
    /** Total dimension of all {@link #observations}, as computed by {@link #ndata()}. */
    int ndata;
    /**
     * True when every observation is univariate or a diagonal multivariate normal. Set in
     * createCovDInv(); selects between {@link #covDInvDiag} and {@link #covDInv}.
     */
    boolean dcovIsDiagonal;
    /** If true, checkModel() drops non-normal parameters with a warning instead of throwing. */
    boolean dropNonGaussianParams;
    /** Label prefixed to progress/timing output; defaults to "Linear Inversion". */
    private String inversionName;
    /**
     * Enables the step-by-step timing printout in createPosteriorCov() (NOTE(review): may
     * also be consulted by tic()/tocOut(), which are not visible here).
     */
    private boolean enableTiming;
    /** Constant response term: predicted data minus M times the current parameters (createC()). */
    Matrix C;
    /** Linear response matrix (ndata x dim) built by finite differences in createM(). */
    Matrix M;
    /** Observed data as an ndata x 1 column vector (createD()). */
    Matrix D;
    /** Cached product M' . inv(CovD), rebuilt by createPosteriorCov(). */
    Matrix MtcovDInv;
    /** Dense inverse data covariance; null when {@link #dcovIsDiagonal} is true. */
    Matrix covDInv;
    /** Data covariance (diagonal or dense depending on {@link #dcovIsDiagonal}). */
    Matrix dataCov;
    /** Inverse prior covariance of the parameters. */
    Matrix covPriorInv;
    /** Prior mean of the parameters as a column vector. */
    Matrix priorMean;
    /** Posterior mean as a column vector (reduced-size while a reduced parameter set is active). */
    Matrix posteriorMean;
    /** Posterior covariance (reduced-size while a reduced parameter set is active). */
    Matrix posteriorCov;
    /** Lower Cholesky factor of posteriorCov, built lazily by sample(); null when stale. */
    Matrix choleskyLower;
    /** Diagonal of the inverse data covariance; only populated when dcovIsDiagonal. */
    double[] covDInvDiag;
    /**
     * Intended to hold the diagonal of the data covariance; createCovDInv() resets it to null
     * when the data covariance is not diagonal.
     */
    double[] dataCovDiag;
    /** Finite-difference step added to each parameter when building M. */
    double parameterUnitForResponseFactors;
    /** Set by checkModel() when some parameter is a truncated normal not at its normal limit. */
    private boolean isTruncated;
    /** Selects the sparse (RowwiseSparseMatrix) construction of M in createM(). */
    private boolean sparseM;
    /** True while a reduced parameter set (refineUsingReducedParameterSet) is active. */
    private boolean usingReducedParameterSet;
    /** Full-parameter-set M saved while a reduced parameter set is active. */
    Matrix M_orig;
    /** Full-parameter-set prior mean saved while a reduced parameter set is active. */
    Matrix priorMean_orig;
    /** Full-parameter-set inverse prior covariance saved while a reduced parameter set is active. */
    Matrix covPriorInv_orig;
    /** Indices into the full parameter vector that are kept in the reduced parameter set. */
    int[] keepIndicies;
    /**
     * Cached sampler (presumably for truncated posteriors - TODO confirm); reset to null
     * whenever the posterior changes.
     */
    private EvecAlignedTGaussGibbs truncatedSampler;
    /** NOTE(review): presumably the start stamp used by tic()/tocOut(), not visible here. */
    long _timingStart;

    /**
     * Creates an inversion over the given model that rejects non-Gaussian parameters and uses
     * the default inversion name.
     *
     * @param graphicalModel the model whose observed/unobserved nodes define the inversion
     */
    public LinearGaussianInversion(GraphicalModel graphicalModel) {
        this(graphicalModel, false);
    }

    /**
     * Creates an inversion over the given model with the default inversion name.
     *
     * @param graphicalModel the model whose observed/unobserved nodes define the inversion
     * @param z if true, non-Gaussian parameters are dropped with a warning instead of failing
     */
    public LinearGaussianInversion(GraphicalModel graphicalModel, boolean z) {
        this(graphicalModel, z, (String) null);
    }

    /**
     * Creates an inversion over the given model.
     *
     * <p>Broadcasts pending model changes, collects observed nodes (data) and unobserved nodes
     * (parameters), validates them via {@link #checkModel()} and records their dimensions.
     *
     * @param graphicalModel the model whose nodes define the inversion
     * @param z if true, non-Gaussian parameters are dropped with a warning instead of failing
     * @param str display name used in log output; {@code null} keeps "Linear Inversion"
     * @throws IllegalArgumentException if the observed nodes have zero total dimension
     */
    public LinearGaussianInversion(GraphicalModel graphicalModel, boolean z, String str) {
        super(graphicalModel);
        this.inversionName = (str == null) ? "Linear Inversion" : str;
        this.dropNonGaussianParams = z;
        this.enableTiming = false;
        this.parameterUnitForResponseFactors = 1.0d;
        this.isTruncated = false;
        this.sparseM = false;
        this.usingReducedParameterSet = false;
        this.keepIndicies = null;
        this.M_orig = null;
        this.priorMean_orig = null;
        this.covPriorInv_orig = null;
        this.truncatedSampler = null;

        tic();
        graphicalModel.broadcastChanges();
        this.observations = graphicalModel.getObservedNodes();
        this.parameters = graphicalModel.getUnobservedNodes();
        checkModel();
        this.dim = dim();
        this.ndata = ndata();
        if (!(this.ndata > 0)) {
            throw new IllegalArgumentException("No data points to invert.");
        }
        tocOut("constructor");
    }

    /**
     * Turns the step-by-step timing printout on or off.
     *
     * @param z true to print timing information
     */
    public void setEnableTiming(boolean z) {
        enableTiming = z;
    }

    /**
     * Rebuilds every ingredient (data, response matrix, constant term, prior, data precision)
     * from the model and solves for the Gaussian posterior.
     *
     * <p>The posterior mean is {@code mu_p + Cov_post . M' . inv(CovD) . (D - C - M . mu_p)}.
     * It is written back into the model parameters, and cached sampling state is invalidated.
     */
    @Override // seed.minerva.Inversion
    public void refine() {
        if (this.usingReducedParameterSet) {
            restoreStateReducedParameterSet();
        }
        createBaseParameters();
        createPosteriorCov();

        tic();
        Matrix residual = this.D.minus(this.C).minus(this.M.times(this.priorMean));
        Matrix weighted = this.MtcovDInv.times(residual);
        this.posteriorMean = this.priorMean.plus(this.posteriorCov.times(weighted));
        tocOut("posteriorMean");

        setParameters(this.posteriorMean.getFlatArray());
        this.model.broadcastChanges();
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /**
     * Re-reads the model structure and rebuilds D, M, C, the prior and the data precision.
     *
     * @throws MinervaRuntimeException if any step fails; a failure while inverting the data
     *     covariance is wrapped with its original exception attached as the cause
     */
    private void createBaseParameters() {
        getModel().broadcastChanges();
        tic();
        this.observations = this.model.getObservedNodes();
        this.parameters = this.model.getUnobservedNodes();
        checkModel();
        this.dim = dim();
        this.ndata = ndata();
        tocOut("refine-init");
        createD();
        createM();
        createC();
        createPrior();
        try {
            createCovDInv();
        } catch (MinervaRuntimeException e) {
            throw e;
        } catch (Exception e) {
            MinervaRuntimeException wrapped =
                    new MinervaRuntimeException("Could not inverse covariance matrix for data.");
            // Keep the root cause on the thrown exception instead of only printing it.
            wrapped.initCause(e);
            throw wrapped;
        }
    }

    /**
     * Rebuilds the prior mean and the inverse prior covariance from the parameter nodes.
     *
     * @throws MinervaRuntimeException if building the prior fails; non-Minerva failures are
     *     wrapped with the original exception attached as the cause
     */
    private void createPrior() {
        createPriorMean();
        try {
            createCovPriorInv();
        } catch (MinervaRuntimeException e) {
            throw e;
        } catch (Exception e) {
            MinervaRuntimeException wrapped =
                    new MinervaRuntimeException("Could not inverse prior covariance matrix.");
            // Keep the root cause on the thrown exception instead of only printing it.
            wrapped.initCause(e);
            throw wrapped;
        }
    }

    /**
     * Re-solves the posterior after replacing the prior mean and/or the inverse prior
     * covariance. The existing response matrix, constant term and data are kept.
     *
     * <p>Any active reduced parameter set is abandoned first. Afterwards the posterior mean is
     * written into the model parameters, and cached sampling state is invalidated.
     *
     * @param dArr new prior mean; {@code null} keeps the current one
     * @param dArr2 new <em>inverse</em> prior covariance (stored directly in
     *     {@code covPriorInv}); {@code null} keeps the current one
     */
    public void refineWithNewPrior(double[] dArr, double[][] dArr2) {
        if (this.usingReducedParameterSet) {
            restoreStateReducedParameterSet();
        }
        if (dArr != null) {
            // 1 x n row transposed to an n x 1 column (needs a double[][] initializer).
            this.priorMean = new DenseMatrix(new double[][]{dArr}).transpose();
        }
        if (dArr2 != null) {
            this.covPriorInv = new LowerSymmetricDenseMatrix(dArr2);
        }
        createPosteriorCov();
        tic();
        this.posteriorMean = this.priorMean.plus(this.posteriorCov.times(this.MtcovDInv.times(this.D.minus(this.C).minus(this.M.times(this.priorMean)))));
        tocOut("posteriorMean");
        setParameters(this.posteriorMean.getFlatArray());
        this.model.broadcastChanges();
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /**
     * Re-reads the prior from the parameter nodes and re-solves the posterior. The stored
     * response matrix, constant term and data are kept.
     */
    public void refineFromNewPrior() {
        if (this.usingReducedParameterSet) {
            restoreStateReducedParameterSet();
        }
        createPrior();
        createPosteriorCov();

        tic();
        Matrix residual = this.D.minus(this.C).minus(this.M.times(this.priorMean));
        Matrix correction = this.posteriorCov.times(this.MtcovDInv.times(residual));
        this.posteriorMean = this.priorMean.plus(correction);
        tocOut("posteriorMean");

        setParameters(this.posteriorMean.getFlatArray());
        this.model.broadcastChanges();
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /**
     * Solves the posterior over a subset of the parameter components.
     *
     * <p>For each node in {@code probabilityNodeArr}, only the components listed in the
     * matching row of {@code iArr} are kept. Nodes not listed (or listed with a {@code null}
     * row) keep all their components. The prior is marginalised onto the kept components by
     * inverting {@code covPriorInv}, extracting the sub-block and inverting back. M is
     * restricted to the kept columns. The full-set state is saved (or restored, if a reduced
     * set was already active) so later refine calls can return to it.
     *
     * <p>Dropped components are written back to the model as 0.
     *
     * @param probabilityNodeArr parameter nodes whose components are to be sub-selected
     * @param iArr {@code iArr[k]} lists the component indices of {@code probabilityNodeArr[k]}
     *     to keep
     */
    public void refineUsingReducedParameterSet(ProbabilityNode[] probabilityNodeArr, int[][] iArr) {
        if (this.usingReducedParameterSet) {
            restoreStateReducedParameterSet();
        } else {
            saveStateReducedParameterSet();
        }
        this.usingReducedParameterSet = true;

        // Collect the global indices (into the full parameter vector) of kept components.
        DynamicIntArray kept = new DynamicIntArray();
        int globalIndex = 0;
        for (ProbabilityNode node : this.parameters) {
            int[] nodeKeep = null;
            for (int k = 0; k < probabilityNodeArr.length; k++) {
                if (node == probabilityNodeArr[k]) {
                    nodeKeep = iArr[k];
                    break;
                }
            }
            for (int component = 0; component < node.dim(); component++) {
                // A bounded search: the previous while(true) never terminated when a
                // component was absent from the keep list.
                boolean keep = (nodeKeep == null);
                for (int k = 0; !keep && k < nodeKeep.length; k++) {
                    keep = (nodeKeep[k] == component);
                }
                if (keep) {
                    kept.add(globalIndex);
                }
                globalIndex++;
            }
        }
        this.keepIndicies = kept.getTrimmedArray();

        this.priorMean = new DenseMatrix(new double[][]{Mat.extractElems(this.priorMean.getFlatArray(), this.keepIndicies)}).transpose();
        // Marginal prior precision of the kept components: invert, take sub-block, invert back.
        this.covPriorInv = Mat.inv(new DenseMatrix(Mat.extractElems(Mat.inv(this.covPriorInv).toArray(), this.keepIndicies, this.keepIndicies)));
        this.M = new DenseMatrix(Mat.extractElems(this.M.toArray(), Mat.linSpaceInt(0, this.M.getNumRows() - 1, 1), this.keepIndicies));

        createPosteriorCov();
        tic();
        this.posteriorMean = this.priorMean.plus(this.posteriorCov.times(this.MtcovDInv.times(this.D.minus(this.C).minus(this.M.times(this.priorMean)))));
        tocOut("posteriorMean");
        setParameters(getFullParametersFromReducedSet(this.posteriorMean.getFlatArray()));
        this.model.broadcastChanges();
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /**
     * Expands a reduced-set parameter vector to full length, placing each value at its
     * original index and leaving dropped components at 0.
     *
     * @param dArr values for the kept components, in {@code keepIndicies} order
     * @return full-length parameter vector
     */
    private double[] getFullParametersFromReducedSet(double[] dArr) {
        int fullLength = this.priorMean_orig.getFlatArray().length;
        double[] full = new double[fullLength];
        int k = 0;
        for (int target : this.keepIndicies) {
            full[target] = dArr[k++];
        }
        return full;
    }

    /** Remembers the full-set M, prior mean and inverse prior covariance before reducing. */
    private void saveStateReducedParameterSet() {
        M_orig = M;
        priorMean_orig = priorMean;
        covPriorInv_orig = covPriorInv;
    }

    /**
     * Puts back the saved full-set M, prior mean and inverse prior covariance, and leaves
     * reduced mode.
     */
    private void restoreStateReducedParameterSet() {
        M = M_orig;
        priorMean = priorMean_orig;
        covPriorInv = covPriorInv_orig;
        usingReducedParameterSet = false;
    }

    /**
     * Re-reads the data, constant term and data covariance, then re-solves the posterior with
     * the stored response matrix and prior. The result is written to the model parameters
     * (without broadcasting).
     */
    public void refineFromNewData() {
        if (this.usingReducedParameterSet) {
            restoreStateReducedParameterSet();
        }
        createC();
        createD();
        createCovDInv();
        createPosteriorCov();

        tic();
        Matrix residual = this.D.minus(this.C).minus(this.M.times(this.priorMean));
        Matrix correction = this.posteriorCov.times(this.MtcovDInv.times(residual));
        this.posteriorMean = this.priorMean.plus(correction);
        tocOut("posteriorMean");

        setParameters(this.posteriorMean.getFlatArray());
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /**
     * Re-reads the data and constant term and recomputes the posterior mean, reusing the
     * stored data covariance and posterior covariance. The model parameters are not updated.
     *
     * <p>If a reduced parameter set was active, the stored posterior covariance and
     * {@code MtcovDInv} are sized for the reduced set. They are therefore rebuilt for the
     * restored full set; the stored data precision is still reused.
     */
    public void refineFromNewDataUseStoredDataCovariance() {
        if (this.usingReducedParameterSet) {
            restoreStateReducedParameterSet();
            // Reduced-size posteriorCov/MtcovDInv would not match the full-size M and priorMean.
            createPosteriorCov();
            this.choleskyLower = null;
        }
        createC();
        createD();
        tic();
        this.posteriorMean = this.priorMean.plus(this.posteriorCov.times(this.MtcovDInv.times(this.D.minus(this.C).minus(this.M.times(this.priorMean)))));
        tocOut("posteriorMean");
        this.truncatedSampler = null;
    }

    /**
     * Returns the posterior mean as a flat array. For truncated inversions this is the mean of
     * the untruncated Gaussian, and a warning is printed.
     *
     * @return posterior mean values
     */
    public double[] getPosteriorMean() {
        if (!this.isTruncated) {
            return this.posteriorMean.getFlatArray();
        }
        System.err.println("WARNING: LinearGaussianInversion.getPosteriorMean() returning mean of non-trucated distribution for truncated inversion.");
        return this.posteriorMean.getFlatArray();
    }

    /** @return the posterior covariance as a 2-D array */
    public double[][] getPosteriorCov() {
        Matrix cov = this.posteriorCov;
        return cov.toArray();
    }

    /** @return the posterior covariance matrix object itself (not a copy) */
    public Matrix getPosteriorCov_Matrix() {
        return posteriorCov;
    }

    /**
     * Overrides the posterior mean and invalidates cached sampling state.
     *
     * @param dArr new posterior mean, stored as a column vector
     */
    public void setPosteriorMean(double[] dArr) {
        Matrix column = new DenseMatrix(dArr.length, 1, dArr);
        this.posteriorMean = column;
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /**
     * Overrides the posterior covariance and invalidates cached sampling state.
     *
     * @param dArr new posterior covariance
     */
    public void setPosteriorCov(double[][] dArr) {
        Matrix cov = new DenseMatrix(dArr);
        this.posteriorCov = cov;
        this.choleskyLower = null;
        this.truncatedSampler = null;
    }

    /** @return the linear response matrix as a 2-D array */
    public double[][] getM() {
        Matrix response = this.M;
        return response.toArray();
    }

    /** @return the linear response matrix object itself (not a copy) */
    public Matrix getM_Matrix() {
        return M;
    }

    /** @return the constant response column vector as a flat array */
    public double[] getC() {
        double[][] asRow = this.C.transpose().toArray();
        return asRow[0];
    }

    /** @return the constant response matrix object itself (not a copy) */
    public Matrix getC_Matrix() {
        return C;
    }

    /** @return the prior mean column vector as a flat array */
    public double[] getPriorMean() {
        double[][] asRow = this.priorMean.transpose().toArray();
        return asRow[0];
    }

    /** @return the prior mean matrix object itself (not a copy) */
    public Matrix getPriorMean_Matrix() {
        return priorMean;
    }

    /** @return the inverse prior covariance as a 2-D array */
    public double[][] getPriorCovInverse() {
        Matrix precision = this.covPriorInv;
        return precision.toArray();
    }

    /** @return the inverse prior covariance matrix object itself (not a copy) */
    public Matrix getPriorCovInverse_Matrix() {
        return covPriorInv;
    }

    /** @return the data covariance as a 2-D array */
    public double[][] getDataCov() {
        Matrix cov = this.dataCov;
        return cov.toArray();
    }

    /** @return the data covariance matrix object itself (not a copy) */
    public Matrix getDataCov_Matrix() {
        return dataCov;
    }

    /** @return the observed data column vector as a flat array */
    public double[] getD() {
        double[][] asRow = this.D.transpose().toArray();
        return asRow[0];
    }

    /** @return the observed data matrix object itself (not a copy) */
    public Matrix getD_Matrix() {
        return D;
    }

    /** @return the summed dimension of all parameter nodes */
    public int dim() {
        int total = 0;
        for (ProbabilityNode node : this.parameters) {
            total += node.dim();
        }
        return total;
    }

    /** @return the summed dimension of all observed nodes */
    public int ndata() {
        int total = 0;
        for (ProbabilityNode node : this.observations) {
            total += node.dim();
        }
        return total;
    }

    /**
     * Validates that the model is a linear-Gaussian problem, and filters {@link #parameters}.
     *
     * <p>Parameters must be {@code Normal} or {@code MultivariateNormal}; others either throw
     * or, when {@code dropNonGaussianParams} is set, are dropped with a warning. Sets
     * {@code isTruncated} if any parameter is a truncated normal away from its normal limit.
     * Every observation must be normal.
     *
     * @throws MinervaRuntimeException on a non-normal observation, a non-normal parameter
     *     (unless dropping is enabled), or a parameter with active probability-node ancestors
     */
    public void checkModel() {
        List<ProbabilityNode> accepted = new ArrayList<ProbabilityNode>();
        this.isTruncated = false;
        for (ProbabilityNode node : this.parameters) {
            boolean gaussian = (node instanceof MultivariateNormal) || (node instanceof Normal);
            if (!gaussian) {
                String str = "Model " + getModel().getName() + " not applicable for LinearGaussianInversion: parameter '" + node.getPath() + "' is non-normal: " + node.getClass().getCanonicalName() + ", removing parameter";
                if (!this.dropNonGaussianParams) {
                    throw new MinervaRuntimeException(str);
                }
                System.err.println("WARNING: " + str);
            } else {
                accepted.add(node);
            }

            boolean truncatedAwayFromNormal =
                    ((node instanceof TruncatedMultivariateNormal) && !((TruncatedMultivariateNormal) node).isNormalLimit())
                    || ((node instanceof TruncatedNormal) && !((TruncatedNormal) node).isNormalLimit());
            if (truncatedAwayFromNormal) {
                this.isTruncated = true;
            }

            if (!getModel().getActiveProbabilityNodeAncestors(node).isEmpty()) {
                throw new MinervaRuntimeException("Model " + getModel().getName() + " not applicable for LinearGaussianInversion: hierarchical prior starting at node " + node.getName());
            }
        }
        this.parameters = accepted;

        for (ProbabilityNode observation : this.observations) {
            boolean gaussian = (observation instanceof MultivariateNormal) || (observation instanceof Normal);
            if (!gaussian) {
                throw new MinervaRuntimeException("Model " + getModel().getName() + " not applicable for LinearGaussianInversion: contains non-normal observation:" + observation.getClass().getCanonicalName());
            }
        }
    }

    /**
     * Builds the posterior covariance {@code inv(M' . inv(CovD) . M + inv(CovP))} and caches
     * {@code M' . inv(CovD)} in {@link #MtcovDInv}.
     *
     * <p>The diagonal data-precision path is used when {@code dcovIsDiagonal}. The inverted
     * matrix is averaged with its transpose to remove round-off asymmetry before it is stored
     * as a {@link LowerSymmetricDenseMatrix}. With timing enabled, each step reports its
     * duration in microseconds measured with {@link System#nanoTime()}.
     */
    private void createPosteriorCov() {
        long j = 0;
        if (this.enableTiming) {
            System.out.print(String.valueOf(this.inversionName) + ": Creating inverse posterior covariance ( M' . invCov_D . M  + invCov_P )... ");
            j = System.nanoTime();
        }
        if (this.dcovIsDiagonal) {
            this.MtcovDInv = Mat.mul(this.M, true, new DiagonalMatrix(this.covDInvDiag), false);
        } else {
            this.MtcovDInv = Mat.mul(this.M, true, this.covDInv, false);
        }
        Matrix plus = this.MtcovDInv.times(this.M).toDenseMatrix().plus(this.covPriorInv);
        if (this.enableTiming) {
            System.out.println("done [" + ((System.nanoTime() - j) / 1000) + " us].");
            System.out.print(String.valueOf(this.inversionName) + ": Inverting to posterior covariance... ");
            j = System.nanoTime();
        }
        // Timing uses nanoTime throughout; never mix it with currentTimeMillis.
        double[][] array = Mat.inv(plus).toArray();
        if (this.enableTiming) {
            System.out.println("done [" + ((System.nanoTime() - j) / 1000) + " us].");
            System.out.print(String.valueOf(this.inversionName) + ": Symmetrising... ");
            j = System.nanoTime();
        }
        for (int i = 0; i < array.length; i++) {
            for (int i2 = 0; i2 <= i; i2++) {
                double d = (array[i][i2] + array[i2][i]) / 2.0d;
                array[i][i2] = d;
                array[i2][i] = d;
            }
        }
        if (this.enableTiming) {
            System.out.println("done [" + ((System.nanoTime() - j) / 1000) + " us].");
            System.out.print(String.valueOf(this.inversionName) + ": Creating matrix object... ");
            j = System.nanoTime();
        }
        this.posteriorCov = new LowerSymmetricDenseMatrix(array);
        if (this.enableTiming) {
            System.out.println("done [" + ((System.nanoTime() - j) / 1000) + " us].");
        }
    }

    /**
     * Computes and caches the lower Cholesky factor of the posterior covariance.
     *
     * @throws RuntimeException if the posterior covariance is not symmetric positive definite
     */
    private void createPosteriorCholeskyDecomposition() {
        CholeskyDecomposition decomposition = new CholeskyDecomposition(this.posteriorCov);
        if (decomposition.isSPD()) {
            this.choleskyLower = decomposition.getL();
            return;
        }
        throw new RuntimeException("Posterior covariance is not SPD.");
    }

    /**
     * Gathers the current values of all observed nodes, in observation order, into the
     * ndata x 1 column vector {@link #D}.
     */
    private void createD() {
        tic();
        double[] dArr = new double[this.ndata];
        int i = 0;
        for (ProbabilityNode probabilityNode : this.observations) {
            if (probabilityNode.isUnivariate()) {
                dArr[i] = ((Univariate) probabilityNode).getDouble();
                i++;
            } else {
                for (double d : ((Multivariate) probabilityNode).getDoubleArray()) {
                    dArr[i] = d;
                    i++;
                }
            }
        }
        // 1 x n row transposed to an n x 1 column (needs a double[][] initializer).
        this.D = new DenseMatrix(new double[][]{dArr}).transpose();
        tocOut("createD");
    }

    /**
     * Assembles the data covariance from the observed nodes and computes its inverse.
     *
     * <p>If every observation is univariate or a diagonal multivariate normal,
     * {@code dcovIsDiagonal} is set. In that case only the diagonal inverse
     * ({@code covDInvDiag}) and the data-covariance diagonal ({@code dataCovDiag}) are kept,
     * and {@code covDInv} is null. Otherwise a dense block-diagonal covariance is built and
     * inverted with {@code Mat.inv}.
     *
     * @throws MinervaRuntimeException if a multivariate observation has no covariance defined
     */
    private void createCovDInv() {
        tic();
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.print(String.valueOf(this.inversionName) + ": Inverting data matrix (" + this.ndata + "x" + this.ndata + ") ... ");
        }
        this.dcovIsDiagonal = true;
        for (ProbabilityNode probabilityNode : this.observations) {
            this.dcovIsDiagonal &= probabilityNode.isUnivariate() || ((MultivariateNormal) probabilityNode).isDiagonal();
        }
        if (this.dcovIsDiagonal) {
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.print("(Diagonal) ");
            }
            double[] dArr = new double[this.ndata];
            int i = 0;
            for (ProbabilityNode probabilityNode2 : this.observations) {
                if (probabilityNode2.isUnivariate()) {
                    dArr[i] = Math.pow(((Normal) probabilityNode2).sigma1D(), 2.0d);
                    i++;
                } else {
                    double[] covDiag = ((MultivariateNormal) probabilityNode2).getCovDiag();
                    if (covDiag == null) {
                        throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode2.getName());
                    }
                    for (int i2 = 0; i2 < covDiag.length; i2++) {
                        dArr[i + i2] = covDiag[i2];
                    }
                    i += covDiag.length;
                }
            }
            this.covDInvDiag = new double[this.ndata];
            for (int i3 = 0; i3 < this.ndata; i3++) {
                this.covDInvDiag[i3] = 1.0d / dArr[i3];
            }
            this.covDInv = null;
            this.dataCov = new DiagonalMatrix(dArr);
            // Previously never populated on this path, leaving a stale (or null) value behind.
            this.dataCovDiag = dArr.clone();
        } else {
            double[][] dArr2 = new double[this.ndata][this.ndata];
            int i4 = 0;
            for (ProbabilityNode probabilityNode3 : this.observations) {
                if (probabilityNode3.isUnivariate()) {
                    dArr2[i4][i4] = Math.pow(((Normal) probabilityNode3).sigma1D(), 2.0d);
                    i4++;
                } else if (((MultivariateNormal) probabilityNode3).isDiagonal()) {
                    double[] covDiag2 = ((MultivariateNormal) probabilityNode3).getCovDiag();
                    if (covDiag2 == null) {
                        throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode3.getName());
                    }
                    for (int i5 = 0; i5 < covDiag2.length; i5++) {
                        dArr2[i4 + i5][i4 + i5] = covDiag2[i5];
                    }
                    i4 += covDiag2.length;
                } else {
                    double[][] cov = ((MultivariateNormal) probabilityNode3).getCov();
                    if (cov == null) {
                        throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode3.getName());
                    }
                    for (int i6 = 0; i6 < cov.length; i6++) {
                        for (int i7 = 0; i7 < cov.length; i7++) {
                            dArr2[i4 + i6][i4 + i7] = cov[i6][i7];
                        }
                    }
                    i4 += cov.length;
                }
            }
            this.dataCov = new LowerSymmetricDenseMatrix(dArr2);
            this.covDInv = Mat.inv(this.dataCov);
            this.dataCovDiag = null;
            this.covDInvDiag = null;
        }
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.println("Done.");
        }
        tocOut("createCovDInv");
    }

    /**
     * Collects the current means of all observed nodes, in observation order.
     *
     * @return flat array of length {@code ndata} with the predicted data
     */
    public double[] predictData() {
        double[] predicted = new double[this.ndata];
        int pos = 0;
        for (ProbabilityNode node : this.observations) {
            if (!node.isUnivariate()) {
                double[] mean = ((MultivariateNormal) node).mean();
                for (double value : mean) {
                    predicted[pos++] = value;
                }
            } else {
                predicted[pos++] = ((Normal) node).mean1D();
            }
        }
        return predicted;
    }

    /**
     * Computes the constant response {@code C = predictData() - M . parameters} at the current
     * parameter values, as an ndata x 1 column vector.
     */
    private void createC() {
        tic();
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.print(String.valueOf(this.inversionName) + ": Calculating constant responses... ");
        }
        // Row vectors need double[][] initializers; they are transposed to columns.
        Matrix predicted = new DenseMatrix(new double[][]{predictData()}).transpose();
        Matrix current = new DenseMatrix(new double[][]{getParameters()}).transpose();
        this.C = predicted.minus(this.M.times(current));
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.println(" Done.");
        }
        tocOut("createC");
    }

    /**
     * Builds the inverse prior covariance from the parameter nodes.
     *
     * <p>If every parameter is univariate or a diagonal multivariate normal, a
     * {@link DiagonalMatrix} is built. Otherwise a {@link BlockSparseMatrix} with one diagonal
     * block per node is built.
     *
     * @throws MinervaRuntimeException if a multivariate parameter has no (inverse) covariance
     */
    private void createCovPriorInv() {
        tic();
        boolean z = true;
        for (ProbabilityNode probabilityNode : this.parameters) {
            z &= probabilityNode.isUnivariate() || ((MultivariateNormal) probabilityNode).isDiagonal();
        }
        if (z) {
            double[] dArr = new double[this.dim];
            int i = 0;
            for (ProbabilityNode probabilityNode2 : this.parameters) {
                if (probabilityNode2.isUnivariate()) {
                    dArr[i] = Math.pow(((Normal) probabilityNode2).sigma1D(), -2.0d);
                    i++;
                } else {
                    double[] invCovDiag = ((MultivariateNormal) probabilityNode2).getInvCovDiag();
                    // Same diagnostic as createCovDInv instead of a bare NullPointerException.
                    if (invCovDiag == null) {
                        throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode2.getName());
                    }
                    for (int i2 = 0; i2 < invCovDiag.length; i2++) {
                        dArr[i + i2] = invCovDiag[i2];
                    }
                    i += invCovDiag.length;
                }
            }
            this.covPriorInv = new DiagonalMatrix(dArr);
        } else {
            BlockSparseMatrix blockSparseMatrix = new BlockSparseMatrix(this.dim, this.dim);
            int i3 = 0;
            for (ProbabilityNode probabilityNode3 : this.parameters) {
                if (probabilityNode3.isUnivariate()) {
                    blockSparseMatrix.insertBlock(i3, i3, new DenseMatrix(1, 1, new double[]{Math.pow(((Normal) probabilityNode3).sigma1D(), -2.0d)}));
                    i3++;
                } else if (((MultivariateNormal) probabilityNode3).isDiagonal()) {
                    double[] invCovDiag2 = ((MultivariateNormal) probabilityNode3).getInvCovDiag();
                    if (invCovDiag2 == null) {
                        throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode3.getName());
                    }
                    blockSparseMatrix.insertBlock(i3, i3, new DiagonalMatrix(invCovDiag2));
                    i3 += invCovDiag2.length;
                } else {
                    double[][] invCov = ((MultivariateNormal) probabilityNode3).getInvCov();
                    if (invCov == null) {
                        throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode3.getName());
                    }
                    blockSparseMatrix.insertBlock(i3, i3, new LowerSymmetricDenseMatrix(invCov));
                    i3 += invCov.length;
                }
            }
            this.covPriorInv = blockSparseMatrix;
        }
        tocOut("createCovPriorInv");
    }

    /**
     * Builds the (non-inverted) prior covariance of the parameters as a dense block-diagonal
     * matrix.
     *
     * @return dim x dim prior covariance
     * @throws MinervaRuntimeException if a multivariate parameter has no covariance defined
     */
    public Matrix createCovPrior() {
        tic();
        double[][] dArr = new double[this.dim][this.dim];
        int i = 0;
        for (ProbabilityNode probabilityNode : this.parameters) {
            if (probabilityNode.isUnivariate()) {
                dArr[i][i] = Math.pow(((Normal) probabilityNode).sigma1D(), 2.0d);
                i++;
            } else {
                double[][] cov = ((MultivariateNormal) probabilityNode).getCov();
                // Same diagnostic as createCovDInv instead of a bare NullPointerException.
                if (cov == null) {
                    throw new MinervaRuntimeException("No covariance matrix defined for node " + probabilityNode.getName());
                }
                for (int i2 = 0; i2 < cov.length; i2++) {
                    for (int i3 = 0; i3 < cov.length; i3++) {
                        dArr[i + i2][i + i3] = cov[i2][i3];
                    }
                }
                i += cov.length;
            }
        }
        LowerSymmetricDenseMatrix lowerSymmetricDenseMatrix = new LowerSymmetricDenseMatrix(dArr);
        tocOut("createCovPrior");
        return lowerSymmetricDenseMatrix;
    }

    /**
     * Gathers the prior means of all parameter nodes, in parameter order, into the column
     * vector {@link #priorMean}.
     */
    private void createPriorMean() {
        tic();
        double[] dArr = new double[this.dim];
        int i = 0;
        for (ProbabilityNode probabilityNode : this.parameters) {
            if (probabilityNode.isUnivariate()) {
                dArr[i] = ((Normal) probabilityNode).mean1D();
                i++;
            } else {
                for (double d : ((MultivariateNormal) probabilityNode).mean()) {
                    dArr[i] = d;
                    i++;
                }
            }
        }
        // 1 x n row transposed to an n x 1 column (needs a double[][] initializer).
        this.priorMean = new DenseMatrix(new double[][]{dArr}).transpose();
        tocOut("createPriorMean");
    }

    /** Builds the linear response matrix M, dense or sparse depending on {@code sparseM}. */
    public void createM() {
        if (!this.sparseM) {
            createMDense();
            return;
        }
        createMSparse();
    }

    /**
     * Builds the dense linear response matrix M (ndata x dim) by forward finite differences.
     *
     * <p>Each parameter in turn is increased by {@code parameterUnitForResponseFactors}, the
     * model is re-evaluated, and the change in predicted data divided by the step gives one
     * column. Progress is printed roughly every 3 s. The model parameters are restored to
     * their starting values afterwards.
     */
    public void createMDense() {
        tic();
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.print(String.valueOf(this.inversionName) + ": Calculating linear response matrix:");
        }
        double[][] dArr = new double[this.ndata][this.dim];
        getModel().broadcastChanges();
        double[] predictData = predictData();
        double[] parameters = getParameters();
        int dbgLevel = MinervaSettings.getDbgLevel();
        long currentTimeMillis = System.currentTimeMillis();
        int i = -1;
        for (int i2 = 0; i2 < this.dim; i2++) {
            double d = parameters[i2];
            parameters[i2] = d + this.parameterUnitForResponseFactors;
            setParameters(parameters);
            getModel().broadcastChanges();
            double[] predictData2 = predictData();
            for (int i4 = 0; i4 < this.ndata; i4++) {
                dArr[i4][i2] = (predictData2[i4] - predictData[i4]) / this.parameterUnitForResponseFactors;
            }
            parameters[i2] = d;
            if (dbgLevel > 0 && i2 % ((int) ((this.dim / 50.0d) + 1.0d)) == 0.0d) {
                System.out.print(".");
            }
            if (System.currentTimeMillis() - currentTimeMillis > 3000) {
                System.out.println("[" + i2 + "/" + this.dim + " (" + ((i2 * 100) / this.dim) + "%), " + ((System.currentTimeMillis() - currentTimeMillis) / (i2 - i)) + " ms/row]");
                i = i2;
                currentTimeMillis = System.currentTimeMillis();
            }
        }
        if (this.dim > 0) {
            // The loop leaves the model with the last parameter still perturbed; undo that.
            setParameters(parameters);
            getModel().broadcastChanges();
        }
        this.M = new DenseMatrix(dArr);
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.println("Done.");
        }
        tocOut("createM");
    }

    /**
     * Builds the linear response matrix M (ndata x dim) by forward finite differences, storing
     * only non-zero responses in a {@link RowwiseSparseMatrix}.
     *
     * <p>Progress is printed roughly every 5 s. The model parameters are restored to their
     * starting values afterwards.
     */
    private void createMSparse() {
        tic();
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.print(String.valueOf(this.inversionName) + ": Calculating linear response matrix:");
        }
        RowwiseSparseMatrix rowwiseSparseMatrix = new RowwiseSparseMatrix(this.ndata, this.dim);
        getModel().broadcastChanges();
        double[] predictData = predictData();
        double[] parameters = getParameters();
        int dbgLevel = MinervaSettings.getDbgLevel();
        long currentTimeMillis = System.currentTimeMillis();
        int i = -1;
        for (int i2 = 0; i2 < this.dim; i2++) {
            double d = parameters[i2];
            parameters[i2] = d + this.parameterUnitForResponseFactors;
            setParameters(parameters);
            getModel().broadcastChanges();
            double[] predictData2 = predictData();
            for (int i4 = 0; i4 < this.ndata; i4++) {
                double response = (predictData2[i4] - predictData[i4]) / this.parameterUnitForResponseFactors;
                if (response != 0.0d) {
                    rowwiseSparseMatrix.insertRowPart(i4, i2, new double[]{response});
                }
            }
            parameters[i2] = d;
            if (dbgLevel > 0 && i2 % ((int) ((this.dim / 50.0d) + 1.0d)) == 0.0d) {
                System.out.print(".");
            }
            if (System.currentTimeMillis() - currentTimeMillis > 5000) {
                System.out.println("[" + i2 + "/" + this.dim + " (" + ((i2 * 100) / this.dim) + "%), " + ((System.currentTimeMillis() - currentTimeMillis) / (i2 - i)) + " ms/row]");
                i = i2;
                currentTimeMillis = System.currentTimeMillis();
            }
        }
        if (this.dim > 0) {
            // The loop leaves the model with the last parameter still perturbed; undo that.
            setParameters(parameters);
            getModel().broadcastChanges();
        }
        this.M = rowwiseSparseMatrix;
        if (MinervaSettings.getDbgLevel() > 0) {
            System.out.println("Done.");
        }
        tocOut("createM");
    }

    /** Sets every component of every parameter node to 0 (no broadcast). */
    private void zeroParameters() {
        for (ProbabilityNode node : this.parameters) {
            if (!node.isUnivariate()) {
                double[] zeros = new double[node.dim()];
                ((Multivariate) node).setDoubleArray(zeros);
            } else {
                ((Univariate) node).setDouble(0.0d);
            }
        }
    }

    /**
     * Distributes a flat parameter vector over the parameter nodes in order (no broadcast).
     *
     * @param dArr concatenated parameter values, at least {@code dim} long
     */
    private void setParameters(double[] dArr) {
        int offset = 0;
        for (ProbabilityNode node : this.parameters) {
            if (!node.isUnivariate()) {
                int width = node.dim();
                double[] slice = new double[width];
                for (int k = 0; k < width; k++) {
                    slice[k] = dArr[offset + k];
                }
                offset += width;
                ((Multivariate) node).setDoubleArray(slice);
            } else {
                ((Univariate) node).setDouble(dArr[offset++]);
            }
        }
    }

    /**
     * Concatenates the current values of all parameter nodes, in order.
     *
     * @return flat array of length {@code dim}
     */
    public double[] getParameters() {
        double[] values = new double[this.dim];
        int offset = 0;
        for (ProbabilityNode node : this.parameters) {
            if (!node.isUnivariate()) {
                int width = node.dim();
                double[] nodeValues = ((Multivariate) node).getDoubleArray();
                for (int k = 0; k < width; k++) {
                    values[offset + k] = nodeValues[k];
                }
                offset += width;
            } else {
                values[offset++] = ((Univariate) node).getDouble();
            }
        }
        return values;
    }

    /**
     * Draws one sample from the Gaussian posterior as {@code mean + L z}. Here {@code L} is the
     * lower-triangular factor held in {@link #choleskyLower} and {@code z} is a vector of
     * independent N(0, 1) draws. The factor is created lazily on the first call. When a reduced
     * parameter set is active, the sample is expanded back to the full parameter set before
     * being returned.
     *
     * @return a posterior sample
     */
    public double[] sample() {
        if (this.choleskyLower == null) {
            createPosteriorCholeskyDecomposition();
        }
        final int n = this.posteriorMean.getFlatArray().length;
        final double[][] lower = this.choleskyLower.toArray();

        // All standard-normal variates are drawn up front, before the triangular product.
        double[] z = new double[n];
        for (int k = 0; k < n; k++) {
            z[k] = RandomManager.instance().nextNormal(0.0d, 1.0d);
        }

        double[] drawn = new double[n];
        for (int row = 0; row < n; row++) {
            double acc = this.posteriorMean.get(row, 0);
            // Only the lower triangle (col <= row) contributes.
            for (int col = 0; col <= row; col++) {
                acc = acc + (lower[row][col] * z[col]);
            }
            drawn[row] = acc;
        }

        if (this.usingReducedParameterSet) {
            return getFullParametersFromReducedSet(drawn);
        }
        return drawn;
    }

    /**
     * Builds the per-parameter hard limits used by the truncated sampler.
     *
     * <p>Nodes implementing {@link TruncatedDistribution} supply their limits through
     * {@code getHardLimits()}: row {@code k} of that array is {@code {lower, upper}} for the
     * node's k-th component. All other nodes are unbounded, {@code (-inf, +inf)}.
     *
     * @return a {@code 2 x dim} array: row 0 holds lower limits and row 1 upper limits, in the
     *     flat parameter order of {@link #parameters}
     * @throws RuntimeException if any limit is NaN, if a lower limit is not strictly below its
     *     upper limit, or if a multivariate node reports a number of limit rows that differs
     *     from its dimension (which would otherwise silently misalign the flat layout)
     */
    private double[][] getLimits() {
        double[][] limits = new double[2][this.dim];
        int offset = 0;
        for (ProbabilityNode node : this.parameters) {
            if (node.isUnivariate()) {
                if (node instanceof TruncatedDistribution) {
                    double[][] hard = ((TruncatedDistribution) node).getHardLimits();
                    double lo = hard[0][0];
                    double hi = hard[0][1];
                    limits[0][offset] = lo;
                    limits[1][offset] = hi;
                    if (Double.isNaN(lo) || Double.isNaN(hi) || lo >= hi) {
                        throw new RuntimeException("Invalid (graph space) hard limits for parameter " + offset + "(" + node.getPath() + "): " + lo + " < x_" + offset + " < " + hi);
                    }
                } else {
                    limits[0][offset] = Double.NEGATIVE_INFINITY;
                    limits[1][offset] = Double.POSITIVE_INFINITY;
                }
                offset++;
                continue;
            }

            Multivariate multivariate = (Multivariate) node;
            int width = multivariate.dim();
            if (node instanceof TruncatedDistribution) {
                double[][] hard = ((TruncatedDistribution) node).getHardLimits();
                // A row count different from the node dimension would shift every following
                // parameter's limits (or leave [0, 0] bounds), so reject it explicitly.
                if (hard.length != width) {
                    throw new RuntimeException("Invalid (graph space) hard limits for parameter node " + node.getPath() + ": " + hard.length + " limit rows for a node of dimension " + width);
                }
                for (int k = 0; k < width; k++) {
                    int flat = offset + k;
                    double lo = hard[k][0];
                    double hi = hard[k][1];
                    limits[0][flat] = lo;
                    limits[1][flat] = hi;
                    if (Double.isNaN(lo) || Double.isNaN(hi) || lo >= hi) {
                        // Path suffix is the component index within the node; the flat index is
                        // reported separately as the parameter number.
                        throw new RuntimeException("Invalid (graph space) hard limits for parameter " + flat + "(" + node.getPath() + "_" + k + "): " + lo + " < x_" + flat + " < " + hi);
                    }
                }
            } else {
                for (int k = 0; k < width; k++) {
                    limits[0][offset + k] = Double.NEGATIVE_INFINITY;
                    limits[1][offset + k] = Double.POSITIVE_INFINITY;
                }
            }
            offset += width;
        }
        return limits;
    }

    /**
     * Creates the truncated-Gaussian Gibbs sampler ({@link EvecAlignedTGaussGibbs}) from the
     * posterior covariance, the posterior mean and the hard limits returned by
     * {@link #getLimits()}. It then runs a burn-in phase.
     *
     * <p>The chain starts at the posterior mean, with each component clamped into its
     * {@code [lower, upper]} hard limits.
     *
     * @param i number of burn-in passes; {@code 0} skips burn-in
     */
    public void sampleTruncatedInit(int i) {
        double[][] limits = getLimits();
        this.truncatedSampler = new EvecAlignedTGaussGibbs(this.posteriorCov, this.posteriorMean.getFlatArray(), limits[0], limits[1]);

        double[] start = getPosteriorMean().clone();
        for (int k = 0; k < start.length; k++) {
            if (start[k] < limits[0][k]) {
                start[k] = limits[0][k];
            }
            if (start[k] > limits[1][k]) {
                start[k] = limits[1][k];
            }
        }
        this.truncatedSampler.setPosition(start);

        System.out.println(String.valueOf(this.inversionName) + ": TruncatedLGI burn-in: ");
        // Print progress roughly every 10%. The step is clamped to >= 1 because i / 10 == 0 for
        // i < 10, and the modulo below would then throw ArithmeticException.
        int progressStep = Math.max(1, i / 10);
        for (int pass = 0; pass < i; pass++) {
            this.truncatedSampler.nextPass();
            if (pass % progressStep == 0) {
                // long arithmetic avoids int overflow of 100 * pass for very long burn-ins
                System.out.print(((100L * pass) / i) + "% ");
            }
        }
        System.out.println();
    }

    /**
     * Creates the truncated-Gaussian Gibbs sampler ({@link EvecAlignedTGaussGibbs}) from the
     * posterior covariance, the posterior mean and the hard limits returned by
     * {@link #getLimits()}. It starts the chain at a caller-supplied position, then runs a
     * burn-in phase.
     *
     * <p>The start position is passed to the sampler unchanged. No clamping into the hard
     * limits is applied here, unlike {@code sampleTruncatedInit(int)}.
     *
     * @param i number of burn-in passes; {@code 0} skips burn-in
     * @param dArr initial chain position
     */
    public void sampleTruncatedInit(int i, double[] dArr) {
        double[][] limits = getLimits();
        this.truncatedSampler = new EvecAlignedTGaussGibbs(this.posteriorCov, this.posteriorMean.getFlatArray(), limits[0], limits[1]);
        this.truncatedSampler.setPosition(dArr);

        System.out.println(String.valueOf(this.inversionName) + ": TruncatedLGI burn-in: ");
        // Print progress roughly every 10%. The step is clamped to >= 1 because i / 10 == 0 for
        // i < 10, and the modulo below would then throw ArithmeticException.
        int progressStep = Math.max(1, i / 10);
        for (int pass = 0; pass < i; pass++) {
            this.truncatedSampler.nextPass();
            if (pass % progressStep == 0) {
                // long arithmetic avoids int overflow of 100 * pass for very long burn-ins
                System.out.print(((100L * pass) / i) + "% ");
            }
        }
        System.out.println();
    }

    /**
     * Advances the truncated Gibbs chain by 10 full passes and returns the resulting position.
     * The intermediate states are discarded.
     *
     * @return the sampler's current position after the 10 passes
     * @throws RuntimeException if {@code sampleTruncatedInit} has not been called yet
     */
    public double[] sampleTruncated() {
        if (this.truncatedSampler != null) {
            for (int remaining = 10; remaining > 0; remaining--) {
                this.truncatedSampler.nextPass();
            }
            return this.truncatedSampler.getCurrentPosition();
        }
        throw new RuntimeException("Must call LinearGaussianInversion.sampleTruncatedInit() before sampleTruncated()");
    }

    /**
     * Draws one posterior sample via {@link #sample()} and writes it into the parameter nodes.
     * It then asks the model to broadcast the changes.
     *
     * @return the sample that was written into the parameters
     */
    public double[] sampleAndSet() {
        final double[] drawn = sample();
        setParameters(drawn);
        this.model.broadcastChanges();
        return drawn;
    }

    /**
     * Draws one sample from the truncated Gibbs chain via {@link #sampleTruncated()}, writes it
     * into the parameter nodes and broadcasts the change through the model.
     *
     * <p>The broadcast mirrors {@link #sampleAndSet()}. Without it, nodes that depend on the
     * parameters would not be notified of the new values.
     *
     * @return the sample that was written into the parameters
     * @throws RuntimeException if {@code sampleTruncatedInit} has not been called yet
     */
    public double[] sampleTruncatedAndSet() {
        double[] sampleTruncated = sampleTruncated();
        setParameters(sampleTruncated);
        this.model.broadcastChanges();
        return sampleTruncated;
    }

    /**
     * Computes the log marginal likelihood (log evidence) of the data under the linear-Gaussian
     * model. This is the log density of {@code D} under
     * {@code N(M * priorMean + C, dataCov + M * priorCov * M^T)}, where {@code priorCov} is the
     * inverse of {@link #covPriorInv}.
     *
     * <p>Unless a reduced parameter set is in use, {@code createBaseParameters()} is called
     * first. If the prior covariance or the combined evidence covariance fails the Cholesky
     * SPD test, it is replaced by {@code Mat.makeSPD(..., 1e-12)} before use.
     *
     * @return the log evidence
     */
    public double logEvidence() {
        if (!this.usingReducedParameterSet) {
            createBaseParameters();
        }

        Matrix priorCov = Mat.inv(this.covPriorInv);
        if (!new CholeskyDecomposition(priorCov).isSPD()) {
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.print(String.valueOf(this.inversionName) + ": logEvidence: prior covariance is not spd, trying to find nearest spd. ");
            }
            priorCov = Mat.makeSPD(priorCov, 1.0E-12d);
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.println("Done.");
            }
        }

        tic();
        Matrix predicted = this.M.times(this.priorMean).plus(this.C);
        Matrix evidenceCov = this.dataCov.plus(this.M.times(priorCov).times(this.M.transpose()));
        if (!new CholeskyDecomposition(evidenceCov).isSPD()) {
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.print(String.valueOf(this.inversionName) + ": logEvidence: data covariance is not spd, trying to find nearest spd. ");
            }
            evidenceCov = Mat.makeSPD(evidenceCov, 1.0E-12d);
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.println("Done.");
            }
        }

        Matrix residual = this.D.minus(predicted);
        // -n/2 * log(2*pi)
        double normalisation = ((-this.dataCov.getNumRows()) / 2.0d) * Mat.log(6.283185307179586d);
        double halfLogDet = 0.5d * Mat.logAbsDet(evidenceCov);
        double quadratic = residual.transpose().times(Mat.inv(evidenceCov)).times(residual).get(0, 0);
        double result = (normalisation - halfLogDet) + ((-0.5d) * quadratic);

        tocOut("Calculated evidence");
        return result;
    }

    /**
     * Computes the log evidence for a caller-supplied response matrix, prior mean and prior
     * covariance. This is the log density of {@code D} under
     * {@code N(responseMatrix * priorMeanArg + C, dataCov + responseMatrix * priorCovArg * responseMatrix^T)}.
     *
     * <p>{@link #M} is temporarily replaced by {@code responseMatrix} while
     * {@code createD()}, {@code createCovDInv()} and {@code createC()} are rerun. It is always
     * restored afterwards, even if an exception is thrown. Note that the fields rebuilt by those
     * calls are NOT restored.
     *
     * <p>The log-determinant uses {@code Mat.logAbsDet}, as {@link #logEvidence()} does. The
     * previous {@code log(det(..))} over- or underflows for large data sets.
     *
     * @param responseMatrix response (forward) matrix to use in place of {@link #M}
     * @param priorMeanArg prior mean column vector
     * @param priorCovArg prior covariance
     * @return the log evidence
     */
    public double logEvidence(Matrix responseMatrix, Matrix priorMeanArg, Matrix priorCovArg) {
        Matrix savedM = this.M;
        this.M = responseMatrix;
        try {
            createD();
            createCovDInv();
            createC();
            Matrix predicted = responseMatrix.times(priorMeanArg).plus(this.C);
            Matrix evidenceCov = this.dataCov.plus(responseMatrix.times(priorCovArg).times(responseMatrix.transpose()));
            Matrix residual = this.D.minus(predicted);
            // -n/2 * log(2*pi)
            double normalisation = ((-this.dataCov.getNumRows()) / 2.0d) * Mat.log(6.283185307179586d);
            double quadratic = residual.transpose().times(Mat.inv(evidenceCov)).times(residual).get(0, 0);
            return (normalisation - (0.5d * Mat.logAbsDet(evidenceCov))) + ((-0.5d) * quadratic);
        } finally {
            this.M = savedM;
        }
    }

    /**
     * Computes the Kullback-Leibler divergence of the posterior from the prior,
     * {@code KL(posterior || prior)}, expressed in bits.
     *
     * <p>Uses {@code 0.5 * [log|S_p| - log|S_q| + tr(S_p^-1 S_q) + (m_p - m_q)^T S_p^-1 (m_p - m_q) - dim] / ln 2}.
     * Here {@code S_p}/{@code m_p} are the prior covariance and mean, and
     * {@code S_q}/{@code m_q} are the posterior covariance and mean. The posterior covariance
     * is rebuilt via {@code createPosteriorCov()}.
     *
     * <p>The prior covariance, the prior precision ({@link #covPriorInv}) and the posterior
     * covariance are each replaced by {@code Mat.makeSPD(..., 1e-12)} if they fail the
     * Cholesky SPD test. Previously the repaired posterior covariance was computed and then
     * discarded.
     *
     * @return the KL divergence in bits
     */
    public double klDivergence() {
        Matrix priorCov = Mat.inv(this.covPriorInv);
        if (!new CholeskyDecomposition(priorCov).isSPD()) {
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.print(String.valueOf(this.inversionName) + ": klDivergence: prior covariance is not spd, trying to find nearest spd. ");
            }
            priorCov = Mat.makeSPD(priorCov, 1.0E-12d);
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.println("Done.");
            }
        }

        Matrix priorPrecision = this.covPriorInv;
        if (!new CholeskyDecomposition(priorPrecision).isSPD()) {
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.print(String.valueOf(this.inversionName) + ": klDivergence: prior inverse covariance is not spd, trying to find nearest spd. ");
            }
            priorPrecision = Mat.makeSPD(priorPrecision, 1.0E-12d);
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.println(" Done.");
            }
        }

        tic();
        createPosteriorCov();
        Matrix postCov = this.posteriorCov;
        if (!new CholeskyDecomposition(postCov).isSPD()) {
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.print(String.valueOf(this.inversionName) + ": klDivergence: posterior covariance is not spd, trying to find nearest spd. ");
            }
            // Keep the repaired matrix; it is used in the formula below.
            postCov = Mat.makeSPD(postCov, 1.0E-12d);
            if (MinervaSettings.getDbgLevel() > 0) {
                System.out.println(" Done.");
            }
        }

        Matrix meanDiff = this.priorMean.minus(this.posteriorMean);
        double twiceNats = Mat.logAbsDet(priorCov)
                - Mat.logAbsDet(postCov)
                + Mat.trace(priorPrecision.times(postCov))
                + meanDiff.transpose().times(priorPrecision).times(meanDiff).getFlatArray()[0]
                - this.dim;
        // 0.5 converts to nats; dividing by ln 2 converts nats to bits.
        double klBits = twiceNats * (0.5d / Mat.log(2.0d));
        tocOut("Calculated KL divergence.");
        return klBits;
    }

    /**
     * Computes the log evidence using the current {@link #M}, {@link #priorMean}, {@link #C},
     * {@link #D} and {@link #dataCov}, but a caller-supplied prior covariance. This is the log
     * density of {@code D} under
     * {@code N(M * priorMean + C, dataCov + M * priorCovArg * M^T)}.
     *
     * <p>The log-determinant uses {@code Mat.logAbsDet}, as {@link #logEvidence()} does. The
     * previous {@code log(det(..))} over- or underflows for large data sets.
     *
     * @param priorCovArg prior covariance to use
     * @return the log evidence
     */
    public double logEvidence(Matrix priorCovArg) {
        Matrix predicted = this.M.times(this.priorMean).plus(this.C);
        Matrix evidenceCov = this.dataCov.plus(this.M.times(priorCovArg).times(this.M.transpose()));
        Matrix residual = this.D.minus(predicted);
        // -n/2 * log(2*pi)
        double normalisation = ((-this.dataCov.getNumRows()) / 2.0d) * Mat.log(6.283185307179586d);
        double quadratic = residual.transpose().times(Mat.inv(evidenceCov)).times(residual).get(0, 0);
        return (normalisation - (0.5d * Mat.logAbsDet(evidenceCov))) + ((-0.5d) * quadratic);
    }

    /**
     * Estimates the posterior mean of the truncated posterior by Monte-Carlo averaging.
     *
     * <p>The truncated sampler is initialised (with burn-in) either from the clamped posterior
     * mean ({@code start == null}) or from {@code start}. Then {@code nSamples} draws from
     * {@link #sampleTruncated()} are averaged.
     *
     * @param burnIn number of burn-in passes for the sampler
     * @param nSamples number of samples to average; must be positive
     * @param start initial chain position, or {@code null} to start from the clamped posterior
     *     mean
     * @return the sample mean, of length {@code dim()}
     * @throws IllegalArgumentException if {@code nSamples <= 0}, which previously produced an
     *     all-NaN result after a full burn-in
     */
    public double[] getPosteriorMeanFromSamples(int burnIn, int nSamples, double[] start) {
        if (nSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive, got " + nSamples);
        }
        if (start == null) {
            sampleTruncatedInit(burnIn);
        } else {
            sampleTruncatedInit(burnIn, start);
        }
        double[] sum = new double[dim()];
        for (int k = 0; k < nSamples; k++) {
            Mat.addEquals(sum, sampleTruncated());
        }
        return Mat.divideEquals(sum, nSamples);
    }

    /**
     * Returns the stored {@code parameterUnitForResponseFactors} value. It is 1.0 unless it has
     * been changed via the setter.
     *
     * @return the current parameter unit for response factors
     */
    public double getParameterUnitForResponseFactors() {
        final double unit = this.parameterUnitForResponseFactors;
        return unit;
    }

    /**
     * Replaces the stored {@code parameterUnitForResponseFactors} value.
     *
     * @param d the new parameter unit for response factors
     */
    public void setParameterUnitForResponseFactors(double d) {
        parameterUnitForResponseFactors = d;
    }

    /** Records the current wall-clock time (ms) as the start of a timing interval. */
    private final void tic() {
        final long now = System.currentTimeMillis();
        _timingStart = now;
    }

    /**
     * Returns the elapsed wall-clock time since the last {@link #tic()}.
     *
     * @return elapsed milliseconds
     */
    private final long toc() {
        final long now = System.currentTimeMillis();
        return now - _timingStart;
    }

    /**
     * Prints the elapsed time since the last {@link #tic()}, labelled with {@code str}. It
     * prints only when timing is enabled.
     *
     * @param str label describing the timed step
     */
    private final void tocOut(String str) {
        if (!enableTiming) {
            return;
        }
        final String line = String.valueOf(inversionName) + ": Time " + str + " [ms]: " + toc();
        System.out.println(line);
    }

    /**
     * Writes the internal matrices and vectors to text files via {@code Mat.dump}. Each file is
     * named {@code <str>/<name><str2>.txt}.
     *
     * <p>Every field is skipped if it is {@code null}. Several of these (e.g.
     * {@link #posteriorCov}, {@link #choleskyLower}) are created lazily elsewhere in this
     * class. Previously only some fields were guarded, so dumping before a full inversion could
     * throw a NullPointerException partway through.
     *
     * @param str output directory prefix
     * @param str2 suffix appended to each base file name
     */
    public void dumpInternals(String str, String str2) {
        String dir = String.valueOf(str);
        if (this.C != null) {
            Mat.dump(this.C, dir + "/C" + str2 + ".txt");
        }
        if (this.M != null) {
            Mat.dump(this.M, dir + "/M" + str2 + ".txt");
        }
        if (this.D != null) {
            Mat.dump(this.D, dir + "/D" + str2 + ".txt");
        }
        if (this.MtcovDInv != null) {
            Mat.dump(this.MtcovDInv, dir + "/MtICovD" + str2 + ".txt");
        }
        if (this.covDInv != null) {
            Mat.dump(this.covDInv, dir + "/covDInv" + str2 + ".txt");
        }
        if (this.dataCov != null) {
            Mat.dump(this.dataCov, dir + "/dataCov" + str2 + ".txt");
        }
        if (this.covDInvDiag != null) {
            Mat.dump(this.covDInvDiag, dir + "/covDInvDiag" + str2 + ".txt");
        }
        if (this.dataCovDiag != null) {
            Mat.dump(this.dataCovDiag, dir + "/dataCovDiag" + str2 + ".txt");
        }
        if (this.covPriorInv != null) {
            Mat.dump(this.covPriorInv, dir + "/covPriorInv" + str2 + ".txt");
        }
        if (this.priorMean != null) {
            Mat.dump(this.priorMean, dir + "/priorMean" + str2 + ".txt");
        }
        if (this.posteriorMean != null) {
            Mat.dump(this.posteriorMean, dir + "/posteriorMean" + str2 + ".txt");
        }
        if (this.posteriorCov != null) {
            Mat.dump(this.posteriorCov, dir + "/posteriorCov" + str2 + ".txt");
        }
        if (this.choleskyLower != null) {
            Mat.dump(this.choleskyLower, dir + "/choleskyLower" + str2 + ".txt");
        }
    }

    /**
     * Sets the {@code sparseM} flag.
     *
     * @param z new value of the flag
     */
    public void setSparseM(boolean z) {
        sparseM = z;
    }

    /**
     * Returns the {@code sparseM} flag. It is {@code false} unless changed via
     * {@link #setSparseM(boolean)}.
     *
     * @return current value of the flag
     */
    public boolean isSparseM() {
        final boolean flag = sparseM;
        return flag;
    }
}
