package seed.minerva;

import oneLiners.OneLiners;
import otherSupport.TruncatedUnivarGauss;
import seed.minerva.nodetypes.DoubleArray;

/* loaded from: input_file:seed/minerva/TruncatedMultivariateNormal.class */
/**
 * Multivariate normal distribution whose support is restricted to the box
 * {@code low[i] <= x[i] <= high[i]}. A limit node that is not connected leaves that side
 * unbounded.
 *
 * <p>With a diagonal covariance every component is treated as an independent univariate
 * truncated Gaussian: the log-density, mean and sigma are computed analytically and
 * sampling is done per component. With a non-diagonal covariance the log-density is taken
 * from {@code super.logPdf} inside the box (not renormalised for the truncation, see
 * {@link #isNormalised()}), sampling falls back to rejection sampling, and
 * {@link #mean()}/{@link #sigma()} throw unless
 * {@link #forceMeansAndSigmasAsTruncatedMarginals(boolean)} has been enabled.
 */
public class TruncatedMultivariateNormal extends MultivariateNormal implements TruncatedDistribution {
    static final double SQRT_2PI = Math.sqrt(6.283185307179586d);
    static final double SQRT_2 = Math.sqrt(2.0d);
    static final double SQRT12 = Math.sqrt(12.0d);
    /** log(sqrt(2*pi)): the standard-normal normalising constant in log space. */
    private static final double LOG_SQRT_2PI = 0.5d * Math.log(2.0d * Math.PI);

    /** Message used when mean()/sigma() are requested for a non-diagonal covariance. */
    public static final String NONDIAG_ERROR_MESSAGE = "Not implemented: Cannot calculate mean/sigma for TruncatedMultivariateNormal with non-diagonal covariance. Calling forceMeansAndSigmasAsTruncatedMarginals(true) will cause mean()/sigma() to return the mean/sigma of the individually truncated marginals (of the non-truncated Gaussian), which is correct in certain circumstances, e.g. with a 0-centered Gaussian with prior smoothing-style off-diagonal elements and a P>0 limit.";

    /** Lower limits (connection point "low"); {@code null} means unbounded below. */
    protected DoubleArray low;
    /** Upper limits (connection point "high"); {@code null} means unbounded above. */
    protected DoubleArray high;
    /** When true, mean()/sigma() use per-component truncated marginals even for non-diagonal covariances. */
    private boolean meansAndSigmasAsTruncatedMarginals;
    /** Maximum number of rejection-sampling draws {@link #sample()} makes before throwing. */
    public double maxSampleAttempts;

    // Cached state, refreshed by updateState():
    private double[] m;       // untruncated mean (internalGetMean())
    private double[] s;       // untruncated per-component sigma; null when marginals are not used
    private double[] a;       // lower limits or null
    private double[] b;       // upper limits or null
    private boolean isDiagonal;
    private double[] fA;      // standard-normal pdf at the standardised lower limit (0 if none)
    private double[] fB;      // standard-normal pdf at the standardised upper limit (0 if none)
    private double[] cumfA;   // standard-normal cdf at the standardised lower limit (0 if none)
    private double[] cumfB;   // standard-normal cdf at the standardised upper limit (1 if none)

    public TruncatedMultivariateNormal() {
        this(null, "", null, null, null, 2);
    }

    public TruncatedMultivariateNormal(String str) {
        this(null, str, null, null, null, 2);
    }

    public TruncatedMultivariateNormal(Graph graph, String str) {
        this(graph, str, null, null, null, 2);
    }

    /**
     * Creates the node without connecting the limit nodes.
     *
     * @param node  forwarded to {@link MultivariateNormal} (presumably the mean node — confirm there)
     * @param node2 forwarded to {@link MultivariateNormal} (presumably the covariance node — confirm there)
     * @param dArr  forwarded to {@link MultivariateNormal}
     * @param i     forwarded unchanged to {@link MultivariateNormal}
     */
    public TruncatedMultivariateNormal(Graph graph, String str, Node node, Node node2, double[] dArr, int i) {
        // Forward the caller's int instead of a hard-coded 2 so the argument is actually honoured.
        this(graph, str, node, node2, null, null, dArr, i);
    }

    /**
     * Creates the node and registers the "low"/"high" connection points.
     *
     * @param node3 node to connect as "low", or {@code null} to leave it unconnected
     * @param node4 node to connect as "high", or {@code null} to leave it unconnected
     */
    public TruncatedMultivariateNormal(Graph graph, String str, Node node, Node node2, Node node3, Node node4, double[] dArr, int i) {
        super(graph, str, node, node2, dArr, i);
        this.meansAndSigmasAsTruncatedMarginals = false;
        this.maxSampleAttempts = 100000.0d;
        addConnectionPoint(new ConnectionPoint("low", DoubleArray.class, true, getField("low")));
        addConnectionPoint(new ConnectionPoint("high", DoubleArray.class, true, getField("high")));
        if (node3 != null) {
            setConnection("low", node3);
        }
        if (node4 != null) {
            setConnection("high", node4);
        }
    }

    /**
     * Log-density at {@code dArr}.
     *
     * @return {@code -Infinity} outside the limits; {@code super.logPdf(dArr)} for a
     *         non-diagonal covariance; otherwise the sum of univariate truncated-normal
     *         log-densities, where a component with infinite sigma contributes the uniform
     *         density {@code -log(high - low)} (both limits must be connected for that case)
     */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.Multivariate
    public double logPdf(double[] dArr) {
        update();
        if (!isInsideLimits(dArr)) {
            return Double.NEGATIVE_INFINITY;
        }
        if (!this.isDiagonal) {
            return super.logPdf(dArr);
        }
        double logP = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            if (Double.isInfinite(this.s[i])) {
                logP -= Math.log(this.b[i] - this.a[i]);
            } else {
                double z = (dArr[i] - this.m[i]) / this.s[i];
                logP += -LOG_SQRT_2PI - Math.log(this.s[i]) - 0.5d * z * z
                        - Math.log(this.cumfB[i] - this.cumfA[i]);
            }
        }
        return logP;
    }

    /**
     * Returns the limits as {@code [component][0 = low, 1 = high]}, with unconnected limits
     * reported as negative/positive infinity.
     */
    @Override // seed.minerva.TruncatedDistribution
    public double[][] getHardLimits() {
        update();
        int n = this.m.length;
        double[][] limits = new double[n][2];
        for (int i = 0; i < n; i++) {
            limits[i][0] = lowLimit(i);
            limits[i][1] = highLimit(i);
        }
        return limits;
    }

    /**
     * Draws one sample.
     *
     * <p>Diagonal covariance: each component is drawn independently with
     * {@link TruncatedUnivarGauss}. Otherwise rejection sampling: untruncated Gaussian draws
     * are generated until one falls inside the limits.
     *
     * @throws RuntimeException    if the covariance is not symmetric positive definite
     * @throws ArithmeticException if no draw is accepted within {@link #maxSampleAttempts}
     */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.Multivariate
    public double[] sample() {
        update();
        if (!internalGetIsSpd()) {
            throw new RuntimeException("Covariance matrix of multivariate normal not (symmetric and positive definite)");
        }
        int dim = dim();
        double[] x = new double[dim];
        if (isDiagonal()) {
            TruncatedUnivarGauss univariate = new TruncatedUnivarGauss();
            for (int i = 0; i < dim; i++) {
                x[i] = univariate.nextTruncatedGaussian(this.m[i], this.s[i], lowLimit(i), highLimit(i));
            }
            return x;
        }
        // The factorisation does not change between attempts, so compute it once.
        boolean fromDiagonalCovs = internalIsFromDiagonalCovs();
        double[][] cholL = fromDiagonalCovs ? null : internalGetCholeskyDecomposition().getL().toArray();
        double[] covDiag = fromDiagonalCovs ? internalGetCovDiag() : null;
        double[] z = new double[dim];
        for (int attempt = 0; attempt < this.maxSampleAttempts; attempt++) {
            for (int j = 0; j < dim; j++) {
                z[j] = RandomManager.instance().nextNormal(0.0d, 1.0d);
            }
            if (buildCandidate(x, z, cholL, covDiag)) {
                return x;
            }
        }
        throw new ArithmeticException("Too many attempts at sampling TruncatedMultivariateNormal (>" + this.maxSampleAttempts + ")!");
    }

    /**
     * Writes {@code meanStore + L*z} (or {@code meanStore + sqrt(covDiag)*z} when
     * {@code cholL} is null) into {@code x}, stopping at the first out-of-limits component.
     *
     * @return true if every component is inside the limits
     */
    private boolean buildCandidate(double[] x, double[] z, double[][] cholL, double[] covDiag) {
        for (int i = 0; i < x.length; i++) {
            double v = this.meanStore[i];
            if (cholL != null) {
                for (int j = 0; j <= i; j++) {
                    v += cholL[i][j] * z[j];
                }
            } else {
                v += Math.sqrt(covDiag[i]) * z[i];
            }
            x[i] = v;
            if (!isWithinLimits(i, v)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Mean of each (individually) truncated component: {@code m + s*(fA - fB)/(cumfB - cumfA)},
     * or the midpoint of the limits when sigma is infinite (both limits must then be connected).
     *
     * @return null if {@code meanStore} is null
     * @throws RuntimeException for a non-diagonal covariance unless marginals are forced
     */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.ProbabilityNode
    public double[] mean() {
        update();
        if (this.meanStore == null) {
            return null;
        }
        requireMarginalsAvailable();
        int n = this.m.length;
        double[] mu = new double[n];
        for (int i = 0; i < n; i++) {
            if (Double.isInfinite(this.s[i])) {
                mu[i] = (this.a[i] + this.b[i]) / 2.0d;
            } else {
                mu[i] = this.m[i] + ((this.fA[i] - this.fB[i]) / (this.cumfB[i] - this.cumfA[i])) * this.s[i];
            }
        }
        return mu;
    }

    /**
     * Standard deviation of each (individually) truncated component. Infinite sigma, or a
     * negative variance factor caused by round-off, falls back to the uniform value
     * {@code (high - low)/sqrt(12)}.
     *
     * @throws RuntimeException for a non-diagonal covariance unless marginals are forced
     */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.ProbabilityNode
    public double[] sigma() {
        update();
        requireMarginalsAvailable();
        int n = this.m.length;
        double[] sd = new double[n];
        for (int i = 0; i < n; i++) {
            if (Double.isInfinite(this.s[i])) {
                sd[i] = (this.b[i] - this.a[i]) / SQRT12;
                continue;
            }
            double mass = this.cumfB[i] - this.cumfA[i];
            // alpha*phi(alpha) and beta*phi(beta); skipped when phi is 0 to avoid inf*0 = NaN.
            double alphaTerm = (this.a == null || this.fA[i] == 0.0d) ? 0.0d : ((this.a[i] - this.m[i]) / this.s[i]) * this.fA[i];
            double betaTerm = (this.b == null || this.fB[i] == 0.0d) ? 0.0d : ((this.b[i] - this.m[i]) / this.s[i]) * this.fB[i];
            double shift = (this.fA[i] - this.fB[i]) / mass;
            double varianceFactor = 1.0d + (alphaTerm - betaTerm) / mass - shift * shift;
            if (varianceFactor < 0.0d) {
                System.err.println("Numerical problems calculating sigma for TruncatedMultivariateNormal, assuming effectivly uniform distrib");
                sd[i] = (this.b[i] - this.a[i]) / SQRT12;
            } else {
                sd[i] = this.s[i] * Math.sqrt(varianceFactor);
            }
        }
        return sd;
    }

    /** {@link #mean()} expanded via {@code unshuffleArray} when an enable mask is connected. */
    @Override // seed.minerva.MultivariateNormal
    public double[] meanRaw() {
        update();
        return this.enable == null ? mean() : unshuffleArray(mean(), this.enableIndicies, this.enable.getBooleanArray().length, this.disableReplacementValue);
    }

    /** {@link #sigma()} expanded via {@code unshuffleArray} when an enable mask is connected. */
    @Override // seed.minerva.MultivariateNormal
    public double[] sigmaRaw() {
        update();
        return this.enable == null ? sigma() : unshuffleArray(sigma(), this.enableIndicies, this.enable.getBooleanArray().length, this.disableReplacementValue);
    }

    /** Log-density evaluated at the node's current {@code value}. */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.ProbabilityNode
    public double logpdf() {
        update();
        return logPdf(this.value);
    }

    /**
     * Refreshes the cached mean, limits and (when marginals are used) sigma plus the
     * standard-normal pdf/cdf values at the standardised limits.
     *
     * @throws ArithmeticException if any {@code high[i] < low[i]}
     */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.Multivariate, seed.minerva.StateFull
    public void updateState() {
        super.updateState();
        this.isDiagonal = super.internalIsDiagonal();
        this.m = super.internalGetMean();
        this.a = this.low == null ? null : this.low.getDoubleArray();
        this.b = this.high == null ? null : this.high.getDoubleArray();
        int n = this.m.length;
        // Validate for every covariance type: invalid limits would otherwise only surface
        // as exhausted rejection sampling for non-diagonal covariances.
        validateLimits(n);
        if (!this.isDiagonal && !this.meansAndSigmasAsTruncatedMarginals) {
            this.s = null;
            return;
        }
        this.s = super.internalGetSigma();
        this.fA = ensureLength(this.fA, n);
        this.fB = ensureLength(this.fB, n);
        this.cumfA = ensureLength(this.cumfA, n);
        this.cumfB = ensureLength(this.cumfB, n);
        for (int i = 0; i < n; i++) {
            this.fA[i] = this.a == null ? 0.0d : standardNormalPdf((this.a[i] - this.m[i]) / this.s[i]);
            this.fB[i] = this.b == null ? 0.0d : standardNormalPdf((this.b[i] - this.m[i]) / this.s[i]);
            this.cumfA[i] = (this.a == null || this.a[i] == Double.NEGATIVE_INFINITY) ? 0.0d : normalCdf(this.a[i], this.m[i], this.s[i]);
            this.cumfB[i] = (this.b == null || this.b[i] == Double.POSITIVE_INFINITY) ? 1.0d : normalCdf(this.b[i], this.m[i], this.s[i]);
        }
    }

    @Override // seed.minerva.MultivariateNormal
    public boolean isDiagonal() {
        return this.isDiagonal;
    }

    /** Only the diagonal case renormalises the density for the truncation. */
    @Override // seed.minerva.MultivariateNormal, seed.minerva.ProbabilityNode
    public boolean isNormalised() {
        return this.isDiagonal;
    }

    /**
     * Enables/disables computing mean()/sigma() from the individually truncated marginals
     * when the covariance is non-diagonal.
     */
    public void forceMeansAndSigmasAsTruncatedMarginals(boolean z) {
        this.meansAndSigmasAsTruncatedMarginals = z;
        setChanged("meansAndSigmasAsTruncatedMarginals");
    }

    public boolean isForcingMeansAndSigmasAsTruncatedMarginals() {
        return this.meansAndSigmasAsTruncatedMarginals;
    }

    /** True when neither limit is connected. */
    public boolean isNormalLimit() {
        update();
        return this.a == null && this.b == null;
    }

    /**
     * True when every component's sigma is infinite. Returns false when per-component sigmas
     * are not computed (non-diagonal covariance without forced marginals).
     */
    public boolean isUniformLimit() {
        update();
        if (this.s == null) {
            return false;
        }
        for (double sigma : this.s) {
            if (!Double.isInfinite(sigma)) {
                return false;
            }
        }
        return true;
    }

    /** Throws the non-diagonal error unless marginal means/sigmas are available. */
    private void requireMarginalsAvailable() {
        if (!this.isDiagonal && !this.meansAndSigmasAsTruncatedMarginals) {
            throw new RuntimeException(String.valueOf(getPath()) + ": " + NONDIAG_ERROR_MESSAGE);
        }
    }

    /** Throws if any connected upper limit is below the corresponding lower limit. */
    private void validateLimits(int n) {
        if (this.a == null || this.b == null) {
            return;
        }
        for (int i = 0; i < n; i++) {
            if (this.b[i] < this.a[i]) {
                throw new ArithmeticException("Truncated Gaussian '" + getPath() + "' has invalid limits (low[" + i + "] > high[" + i + "]) ");
            }
        }
    }

    /** Lower limit of component i, or -Infinity when unconnected. */
    private double lowLimit(int i) {
        return this.a == null ? Double.NEGATIVE_INFINITY : this.a[i];
    }

    /** Upper limit of component i, or +Infinity when unconnected. */
    private double highLimit(int i) {
        return this.b == null ? Double.POSITIVE_INFINITY : this.b[i];
    }

    /** False only if v is strictly below low[i] or strictly above high[i] (NaN passes). */
    private boolean isWithinLimits(int i, double v) {
        return !((this.a != null && v < this.a[i]) || (this.b != null && v > this.b[i]));
    }

    /** True when every component of x is within its limits. */
    private boolean isInsideLimits(double[] x) {
        for (int i = 0; i < x.length; i++) {
            if (!isWithinLimits(i, x[i])) {
                return false;
            }
        }
        return true;
    }

    /** Standard normal density at z. */
    private static double standardNormalPdf(double z) {
        return Math.exp(-0.5d * z * z) / SQRT_2PI;
    }

    /** Normal CDF at x for mean mu and standard deviation sigma, via erf. */
    private static double normalCdf(double x, double mu, double sigma) {
        return 0.5d + (0.5d * OneLiners.erf((x - mu) / (sigma * SQRT_2)));
    }

    /** Returns arr if it already has length n, otherwise a fresh zeroed array of length n. */
    private static double[] ensureLength(double[] arr, int n) {
        return (arr != null && arr.length == n) ? arr : new double[n];
    }
}
