package seed.minerva;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import oneLiners.OneLiners;
import seed.digeom.FunctionND;
import seed.digeom.FunctionWithDryEval;
import seed.digeom.RectangularDomain;
import seed.digeom.operators.function.OpFuncLinearTransform;

/* loaded from: input_file:seed/minerva/LogPdfFunction.class */
public class LogPdfFunction extends FunctionND implements FunctionWithDryEval {
    /**
     * Largest offset (in transformed space) tried when nudging a transformed hard limit so that it
     * maps back inside the graph-space hard limit.
     */
    public static final double hardLimitsNumericalPrecisionMarginMax = 1.0E-5d;

    /** Maximum number of bisection steps used when nudging a transformed hard limit. */
    public static final int hardLimitsNumericalPrecisionMarginSteps = 1000;

    /**
     * A running log-likelihood or log-prior sum at or below this value stops {@link #eval(double[])}
     * early (unless {@link #lhdLimitedPrior} is set).
     */
    private static final double LOG_PDF_EARLY_EXIT = -1.0E20d;

    protected GraphicalModel model;
    /** Unobserved nodes whose values, concatenated in list order, form the argument vector. */
    protected List<ProbabilityNode> free;
    /** All nodes, in the order produced by the model's topological sort. */
    protected List<ProbabilityNode> nodes;
    /** Sum of {@code dim()} over {@link #free}; the length of the argument vector. */
    int numFree;
    private String[] pNames;
    private double[][] hardLimits;
    private double[][] typicalRanges;
    protected int numFunctionEvaluations = 0;
    /** When true, {@link #eval(double[])} returns only the log-prior, gated by {@link #lhdLimit}. */
    public boolean lhdLimitedPrior = false;
    /** Log-likelihood threshold used when {@link #lhdLimitedPrior} is set. */
    public double lhdLimit = Double.NaN;

    /**
     * Builds the log-pdf function of {@code graphicalModel} over its unobserved nodes.
     *
     * <p>Collects parameter names, graph-space hard limits and typical ranges for every free
     * component, and sets the function domain to the hard-limit rectangle.
     *
     * @param graphicalModel the model whose joint log-pdf is evaluated
     * @throws RuntimeException if a node reports invalid hard limits or typical ranges
     */
    public LogPdfFunction(GraphicalModel graphicalModel) {
        this.model = graphicalModel;
        ProbabilityNodeVisitor visitor = new ProbabilityNodeVisitor();
        graphicalModel.topologicalSort(visitor);
        this.nodes = visitor.getVisitedNodes();
        this.free = new ArrayList<ProbabilityNode>();
        this.free.addAll(graphicalModel.getUnobservedNodes());
        this.numFree = 0;
        for (ProbabilityNode node : this.free) {
            this.numFree += node.dim();
        }
        // Order matters: the validation messages in collectHardLimits read pNames.
        collectParameterNames();
        collectHardLimits();
        collectTypicalRanges();
        setDomain(new RectangularDomain(this.hardLimits, this.pNames));
    }

    /**
     * Fills {@link #pNames}. A univariate node contributes its path. A multivariate node
     * contributes {@code path_0 .. path_(dim-1)}.
     */
    private void collectParameterNames() {
        this.pNames = new String[this.numFree];
        int k = 0;
        for (ProbabilityNode node : this.free) {
            String path = node.getPath();
            if (node.isUnivariate()) {
                this.pNames[k++] = path;
            } else {
                for (int c = 0; c < node.dim(); c++) {
                    this.pNames[k++] = String.valueOf(path) + "_" + c;
                }
            }
        }
    }

    /**
     * Fills {@link #hardLimits} (one {@code [low, high]} row per free component).
     *
     * <p>Nodes implementing {@link TruncatedDistribution} supply their limits, which are validated.
     * All other nodes are unbounded.
     *
     * @throws RuntimeException if a supplied limit is NaN or {@code low >= high}
     */
    private void collectHardLimits() {
        this.hardLimits = new double[this.numFree][2];
        int offset = 0;
        for (ProbabilityNode node : this.free) {
            if (node.isUnivariate()) {
                if (node instanceof TruncatedDistribution) {
                    double[][] limits = ((TruncatedDistribution) node).getHardLimits();
                    storeGraphHardLimit(offset, limits[0][0], limits[0][1]);
                } else {
                    this.hardLimits[offset][0] = Double.NEGATIVE_INFINITY;
                    this.hardLimits[offset][1] = Double.POSITIVE_INFINITY;
                }
                offset++;
            } else {
                // Non-univariate nodes are required to be Multivariate (ClassCastException otherwise).
                Multivariate multivariate = (Multivariate) node;
                int dim = multivariate.dim();
                if (node instanceof TruncatedDistribution) {
                    double[][] limits = ((TruncatedDistribution) node).getHardLimits();
                    for (int c = 0; c < limits.length; c++) {
                        storeGraphHardLimit(offset + c, limits[c][0], limits[c][1]);
                    }
                } else {
                    for (int c = 0; c < dim; c++) {
                        this.hardLimits[offset + c][0] = Double.NEGATIVE_INFINITY;
                        this.hardLimits[offset + c][1] = Double.POSITIVE_INFINITY;
                    }
                }
                offset += dim;
            }
        }
    }

    /**
     * Stores one graph-space hard-limit row and validates it.
     *
     * @throws RuntimeException if either bound is NaN or {@code low >= high}
     */
    private void storeGraphHardLimit(int index, double low, double high) {
        this.hardLimits[index][0] = low;
        this.hardLimits[index][1] = high;
        if (Double.isNaN(low) || Double.isNaN(high) || low >= high) {
            throw new RuntimeException("Invalid (graph space) hard limits for parameter " + index + "(" + this.pNames[index] + "): " + low + " < x_" + index + " < " + high);
        }
    }

    /**
     * Recomputes {@link #typicalRanges} from each free node's {@code getTypicalRange()}.
     *
     * @throws RuntimeException if a range bound is NaN or {@code low >= high}
     */
    public void collectTypicalRanges() {
        this.typicalRanges = new double[this.numFree][2];
        int k = 0;
        for (ProbabilityNode node : this.free) {
            double[][] range = node.getTypicalRange();
            for (int c = 0; c < range.length; c++) {
                double low = range[c][0];
                double high = range[c][1];
                if (Double.isNaN(low) || Double.isNaN(high) || low >= high) {
                    throw new RuntimeException("Invalid typical range for parameter " + c + "(" + node.getPath() + "): " + low + " < x_" + c + " < " + high);
                }
                this.typicalRanges[k][0] = low;
                this.typicalRanges[k][1] = high;
                k++;
            }
        }
    }

    /**
     * Returns this function wrapped in a per-component linear transform built from the typical
     * ranges.
     *
     * <p>Each component uses {@code scale = 1/halfWidth} and {@code offset = -center/halfWidth}.
     * Presumably the wrapper maps {@code x' = scale*x + offset}, so a typical range maps to
     * {@code [-1, 1]} (confirm against {@code OpFuncLinearTransform}). The wrapper's domain is the
     * hard-limit rectangle expressed in transformed space.
     *
     * @return the transformed function
     * @throws ArithmeticException if a transformed hard limit cannot be nudged so that it maps
     *     back inside the graph-space limit
     */
    public OpFuncLinearTransform getTransformedFunction() {
        double[] scale = new double[this.numFree];
        double[] offset = new double[this.numFree];
        for (int i = 0; i < this.numFree; i++) {
            double halfWidth = (this.typicalRanges[i][1] - this.typicalRanges[i][0]) / 2.0d;
            double center = (this.typicalRanges[i][0] + this.typicalRanges[i][1]) / 2.0d;
            scale[i] = 1.0d / halfWidth;
            offset[i] = (-center) / halfWidth;
        }
        OpFuncLinearTransform transformed = new OpFuncLinearTransform(this, scale, offset);
        transformed.setDomain(new RectangularDomain(getTransformedHardLimits(transformed), this.pNames));
        return transformed;
    }

    /**
     * Transforms the graph-space hard limits into transformed space.
     *
     * <p>Rounding can make a transformed limit map back slightly outside the graph-space limit. In
     * that case each bound is nudged inward until it maps back inside. All lower bounds are
     * processed before any upper bound.
     */
    private double[][] getTransformedHardLimits(OpFuncLinearTransform transform) {
        double[][] graphLimits = OneLiners.transpose(this.hardLimits);
        double[][] result = new double[this.numFree][2];
        double[] lower = transform.transform(graphLimits[0]);
        for (int i = 0; i < this.numFree; i++) {
            result[i][0] = nudgeTransformedHardLimit(transform, i, lower[i], false);
        }
        double[] upper = transform.transform(graphLimits[1]);
        for (int i = 0; i < this.numFree; i++) {
            result[i][1] = nudgeTransformedHardLimit(transform, i, upper[i], true);
        }
        return result;
    }

    /**
     * Returns the transformed bound closest to {@code candidate} that maps back inside
     * graph-space hard limit {@code i}.
     *
     * <p>Component {@code i} is probed in isolation; the other vector entries are zero. The bound
     * is moved inward (up for a lower bound, down for an upper bound) by at most
     * {@link #hardLimitsNumericalPrecisionMarginMax}. The offset is found by bisection with at most
     * {@link #hardLimitsNumericalPrecisionMarginSteps} steps.
     *
     * @param transform the linear transform to probe
     * @param i parameter index
     * @param candidate the directly transformed bound
     * @param upper true for the upper bound, false for the lower bound
     * @return the (possibly nudged) transformed bound
     * @throws ArithmeticException if even the maximum margin does not map back inside
     */
    private double nudgeTransformedHardLimit(OpFuncLinearTransform transform, int i, double candidate, boolean upper) {
        double limit = this.hardLimits[i][upper ? 1 : 0];
        double direction = upper ? -1.0d : 1.0d;
        double[] probe = new double[this.numFree];

        probe[i] = candidate;
        double back = transform.transformBack(probe)[i];
        if (upper ? back <= limit : back >= limit) {
            return candidate;
        }

        probe[i] = candidate + direction * hardLimitsNumericalPrecisionMarginMax;
        back = transform.transformBack(probe)[i];
        if (upper ? back > limit : back < limit) {
            throw new ArithmeticException("Maximum margin (" + hardLimitsNumericalPrecisionMarginMax + ") exceeded trying to find a the nearest value to the transformed " + (upper ? "upper" : "lower") + " hardLimit which transform back correctly for param " + i + " (" + this.pNames[i] + ")");
        }

        // Offset 0 failed the first check and the maximum offset succeeded; bisect between them.
        double outside = 0.0d;
        double inside = hardLimitsNumericalPrecisionMarginMax;
        for (int step = 0; step < hardLimitsNumericalPrecisionMarginSteps; step++) {
            double mid = (outside + inside) / 2.0d;
            if (mid <= outside || mid >= inside) {
                break; // interval cannot be split further in double precision
            }
            probe[i] = candidate + direction * mid;
            back = transform.transformBack(probe)[i];
            if (upper ? back > limit : back < limit) {
                outside = mid;
            } else {
                inside = mid;
            }
        }
        return candidate + direction * inside;
    }

    /** @return the free-parameter names (internal array, not a copy) */
    public String[] getParameterNames() {
        return this.pNames;
    }

    /** @return the typical ranges, one {@code [low, high]} row per parameter (internal array) */
    public double[][] getTypicalRanges() {
        return this.typicalRanges;
    }

    /** @return the graph-space hard limits, one {@code [low, high]} row per parameter (internal array) */
    public double[][] getHardLimits() {
        return this.hardLimits;
    }

    /**
     * Looks up the typical ranges of the named parameters.
     *
     * <p>The returned rows are the internal row arrays, not copies.
     *
     * @param strArr parameter names, as returned by {@link #getParameterNames()}
     * @return one {@code [low, high]} row per requested name, in request order
     * @throws IllegalArgumentException if a name is not a known parameter name
     */
    public double[][] getTypicalRanges(String[] strArr) {
        double[][] dArr = new double[strArr.length][];
        for (int i = 0; i < strArr.length; i++) {
            int index = findParameterIndex(strArr[i]);
            if (index < 0) {
                throw new IllegalArgumentException("Unknown parameter name '" + strArr[i] + "'");
            }
            dArr[i] = this.typicalRanges[index];
        }
        return dArr;
    }

    /** Returns the index of the first parameter named {@code name}, or -1 if there is none. */
    private int findParameterIndex(String name) {
        for (int j = 0; j < this.pNames.length; j++) {
            if (name.equals(this.pNames[j])) {
                return j;
            }
        }
        return -1;
    }

    /** Pushes {@code dArr} into the free nodes without broadcasting changes or evaluating. */
    public void dryEval(double[] dArr) {
        setFree(dArr);
    }

    /**
     * Sets the free parameters, broadcasts the change, and returns the joint log-pdf.
     *
     * <p>Observed nodes' log-pdfs are summed as the log-likelihood; the other nodes' log-pdfs are
     * summed as the log-prior. A NaN or +infinity node value is reported on stderr but still
     * summed. Unless {@link #lhdLimitedPrior} is set, summation stops as soon as either running sum
     * drops to {@code -1e20} or below.
     *
     * @param dArr free-parameter vector (length {@link #domainDim()})
     * @return {@code logPrior + logLikelihood}. If {@link #lhdLimitedPrior} is set, returns
     *     {@code logPrior} when {@code logLikelihood > lhdLimit}, otherwise {@code -infinity}.
     *     With the default {@code lhdLimit} of NaN, the comparison is always false.
     */
    public double eval(double[] dArr) {
        setFree(dArr);
        this.model.broadcastChanges();
        double logLikelihood = 0.0d;
        double logPrior = 0.0d;
        for (ProbabilityNode node : this.nodes) {
            double logpdf = node.logpdf();
            if (Double.isNaN(logpdf) || (Double.isInfinite(logpdf) && logpdf > 0.0d)) {
                System.err.println("Node '" + node.getPath() + "' has invalid logPdf = " + logpdf);
            }
            if (node.isObserved()) {
                logLikelihood += logpdf;
                if (logLikelihood <= LOG_PDF_EARLY_EXIT && !this.lhdLimitedPrior) {
                    break;
                }
            } else {
                logPrior += logpdf;
                if (logPrior <= LOG_PDF_EARLY_EXIT && !this.lhdLimitedPrior) {
                    break;
                }
            }
        }
        this.numFunctionEvaluations++;
        if (!this.lhdLimitedPrior) {
            return logPrior + logLikelihood;
        }
        return logLikelihood > this.lhdLimit ? logPrior : Double.NEGATIVE_INFINITY;
    }

    /** @return the number of completed {@link #eval(double[])} calls */
    public int numFunctionEvaluations() {
        return this.numFunctionEvaluations;
    }

    /**
     * Writes {@code dArr} into the free nodes, consuming it sequentially in {@link #free} order.
     * (The decompiler reported this method was originally {@code protected}.)
     */
    public void setFree(double[] dArr) {
        int offset = 0;
        for (ProbabilityNode node : this.free) {
            if (node.isUnivariate()) {
                ((Univariate) node).setDouble(dArr[offset]);
                offset++;
            } else {
                Multivariate multivariate = (Multivariate) node;
                double[] values = new double[node.dim()];
                System.arraycopy(dArr, offset, values, 0, values.length);
                multivariate.setDoubleArray(values);
                offset += values.length;
            }
        }
    }

    /**
     * Returns the model's current free-parameter vector.
     * (The decompiler reported this method was originally {@code protected}.)
     */
    public double[] getFreeFromGraph() {
        return this.model.getFreeParameters();
    }

    /** Prints {@link #getParameterInformation(OpFuncLinearTransform)} to stdout between separators. */
    public void dumpParameterInformation(OpFuncLinearTransform opFuncLinearTransform) {
        System.out.println("-----------------------------------\nMinerva graph free parameter space:\n" + getParameterInformation(opFuncLinearTransform) + "-----------------------------------\n");
    }

    /**
     * Builds a tab-separated table describing each free parameter.
     *
     * <p>Columns are: name, current value, the back-transformed images of {@code -1} and
     * {@code +1}, and the graph-space hard limits.
     *
     * @param opFuncLinearTransform transform to apply; if null, the -1/+1 columns are untransformed
     * @return the table text
     */
    public String getParameterInformation(OpFuncLinearTransform opFuncLinearTransform) {
        double[] freeParameters = this.model.getFreeParameters();
        double[][] dArr = new double[2][freeParameters.length];
        for (int i = 0; i < freeParameters.length; i++) {
            dArr[0][i] = -1.0d;
            dArr[1][i] = 1.0d;
        }
        if (opFuncLinearTransform != null) {
            dArr[0] = opFuncLinearTransform.transformBack(dArr[0]);
            dArr[1] = opFuncLinearTransform.transformBack(dArr[1]);
        }
        StringBuilder sb = new StringBuilder("Name\tValue\t(-1 xformed.)\t(+1 xformed.)\tLow hard\tHigh hard\n");
        for (int i = 0; i < freeParameters.length; i++) {
            sb.append(this.pNames[i]).append('\t')
                    .append(freeParameters[i]).append('\t')
                    .append(dArr[0][i]).append('\t')
                    .append(dArr[1][i]).append('\t')
                    .append(this.hardLimits[i][0]).append('\t')
                    .append(this.hardLimits[i][1]).append('\n');
        }
        sb.append('\n');
        return sb.toString();
    }

    /**
     * Lists the model's unobserved nodes, one per line.
     *
     * @return {@code startIndex \t dim \t path} for each unobserved node
     */
    public String freeParamPathList() {
        StringBuilder sb = new StringBuilder();
        int start = 0;
        for (ProbabilityNode node : this.model.getUnobservedNodes()) {
            sb.append(start).append('\t').append(node.dim()).append('\t').append(node.getPath()).append('\n');
            start += node.dim();
        }
        return sb.toString();
    }

    protected List<ProbabilityNode> getProbabilityNodes() {
        return this.nodes;
    }

    public List<ProbabilityNode> getFreeNodes() {
        return this.free;
    }

    /** @return the length of the argument vector */
    public int domainDim() {
        return this.numFree;
    }

    public GraphicalModel getModel() {
        return this.model;
    }

    /** Clamps {@code dArr} into the graph-space hard limits; see the 4-argument overload. */
    public boolean forceIntoLimits(boolean z, double[] dArr) {
        return forceIntoLimits(null, dArr, z);
    }

    /** Clamps {@code dArr} into the limits with no margin; see the 4-argument overload. */
    public boolean forceIntoLimits(OpFuncLinearTransform opFuncLinearTransform, double[] dArr, boolean z) {
        return forceIntoLimits(opFuncLinearTransform, dArr, z, Double.NaN);
    }

    /**
     * Clamps {@code dArr} in place into {@code [low + margin, high - margin]} for each component.
     *
     * @param opFuncLinearTransform if null, the graph-space hard limits are used; otherwise the
     *     transform's domain bounds are used
     * @param dArr vector to clamp (modified in place)
     * @param z whether to print a warning for each clamped component
     * @param d margin; NaN means no margin
     * @return true if any component was changed. When the margin leaves no room, the component is
     *     set to the centre value and a warning is always printed.
     */
    public boolean forceIntoLimits(OpFuncLinearTransform opFuncLinearTransform, double[] dArr, boolean z, double d) {
        double[][] bounds = opFuncLinearTransform == null ? this.hardLimits : opFuncLinearTransform.getDomain().getRectangularBounds();
        double margin = Double.isNaN(d) ? 0.0d : d;
        boolean changed = false;
        for (int i = 0; i < dArr.length; i++) {
            double low = bounds[i][0] + margin;
            double high = bounds[i][1] - margin;
            if (low >= high) {
                System.err.println("WARNING: Trying to force into limits, but overshoots for parameter " + i + "('" + this.pNames[i] + "') overlap, forcing to central value.");
                dArr[i] = (low + high) / 2.0d;
                changed = true;
            } else if (dArr[i] < low) {
                if (z) {
                    System.err.println("WARNING: Forcing parameter " + i + "('" + this.pNames[i] + "') above lower limit. !(" + bounds[i][0] + " < " + dArr[i] + " < " + bounds[i][1] + ")");
                }
                dArr[i] = low;
                changed = true;
            } else if (dArr[i] > high) {
                if (z) {
                    System.err.println("WARNING: Forcing parameter " + i + "('" + this.pNames[i] + "') below upper limit. !(" + bounds[i][0] + " < " + dArr[i] + " < " + bounds[i][1] + ")");
                }
                dArr[i] = high;
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Replaces each typical range with the conditional range found by
     * {@link #getConditionalRangeInfo}, then pushes the new ranges to the model.
     *
     * <p>The search treats the model's current free parameters as the MAP point. Those parameters
     * are restored afterwards, because probing moves the model away from that point.
     */
    public void rescaleToConditionals() {
        // Clone so that probing cannot alias the reference (MAP) vector.
        double[] freeParameters = this.model.getFreeParameters().clone();
        String[] parameterNames = getParameterNames();
        double logPdf = this.model.logPdf();
        try {
            for (int i = 0; i < freeParameters.length; i++) {
                System.out.print(String.valueOf(parameterNames[i]) + ": MAP=" + freeParameters[i] + ", oldRange=" + this.typicalRanges[i][0] + "-" + this.typicalRanges[i][1] + ", ");
                double relative = (freeParameters[i] - this.typicalRanges[i][0]) / (this.typicalRanges[i][1] - this.typicalRanges[i][0]);
                double[] info = getConditionalRangeInfo(freeParameters, logPdf, i, Math.min(-2.0d, 2.0d * relative), Math.max(2.0d, 2.0d * relative), 50, 0.1d);
                double center = (info[3] + info[0]) / 2.0d;
                double width = info[3] - info[0];
                this.typicalRanges[i][0] = center - (width / 2.0d);
                this.typicalRanges[i][1] = center + (width / 2.0d);
                System.out.println("left=" + info[0] + ", right=" + info[3] + ", newRange=" + this.typicalRanges[i][0] + "-" + this.typicalRanges[i][1]);
            }
        } finally {
            // getConditionalRangeInfo leaves the model at its last probe point; put it back.
            this.model.setFreeParameters(freeParameters);
        }
        this.model.overrideTypicalRangesForFreeParameters(this.typicalRanges);
    }

    /**
     * Bisects along parameter {@code i}, on each side of {@code dArr[i]}, for the point where the
     * model log-pdf equals {@code d - 4}. All other parameters are held at {@code dArr}.
     *
     * <p>Each side's outer start is {@code typicalLow + rel * (typicalHigh - typicalLow)}, with
     * {@code rel = d2} on the left and {@code d3} on the right. A side stops after {@code i2}
     * iterations, or earlier when, from the third iteration on, {@code |logPdf - target| < d4}.
     * A NaN log-pdf is treated as below the target.
     *
     * <p>Side effect: calls {@code model.setFreeParameters} for every probe and leaves the model at
     * the last probe.
     *
     * @param dArr reference parameter vector (not modified)
     * @param d reference log-pdf; the target is {@code d - 4}
     * @param i index of the parameter to scan
     * @param d2 left start, relative to the typical range
     * @param d3 right start, relative to the typical range
     * @param i2 maximum bisection iterations per side
     * @param d4 tolerance on the log-pdf
     * @return {@code [leftX, leftRel, leftLogPdf, rightX, rightRel, rightLogPdf, totalIterations]}
     */
    public double[] getConditionalRangeInfo(double[] dArr, double d, int i, double d2, double d3, int i2, double d4) {
        double[] dArr2 = (double[]) dArr.clone();
        double d5 = d - 4.0d; // target log-pdf
        int i3 = 0;
        // Left side: bisect between the outer start (d6) and the reference value (d7).
        double d6 = this.typicalRanges[i][0] + (d2 * (this.typicalRanges[i][1] - this.typicalRanges[i][0]));
        double d7 = dArr[i];
        double d8 = Double.NaN;
        double d9 = Double.NaN;
        for (int i4 = 0; i4 < i2; i4++) {
            d8 = (d6 + d7) / 2.0d;
            dArr2[i] = d8;
            this.model.setFreeParameters(dArr2);
            d9 = this.model.logPdf();
            if (i4 >= 2 && Math.abs(d9 - d5) < d4) {
                break;
            }
            if (Double.isNaN(d9) || d9 <= d5) {
                d6 = d8;
            } else {
                d7 = d8;
            }
            i3++;
        }
        double d10 = d8;
        double d11 = (d10 - this.typicalRanges[i][0]) / (this.typicalRanges[i][1] - this.typicalRanges[i][0]);
        double d12 = d9;
        // Right side: bisect between the reference value (d14) and the outer start (d13).
        double d13 = this.typicalRanges[i][0] + (d3 * (this.typicalRanges[i][1] - this.typicalRanges[i][0]));
        double d14 = dArr[i];
        double d15 = Double.NaN;
        double d16 = Double.NaN;
        for (int i5 = 0; i5 < i2; i5++) {
            d15 = (d14 + d13) / 2.0d;
            dArr2[i] = d15;
            this.model.setFreeParameters(dArr2);
            d16 = this.model.logPdf();
            if (i5 >= 2 && Math.abs(d16 - d5) < d4) {
                break;
            }
            if (Double.isNaN(d16) || d16 < d5) {
                d13 = d15;
            } else {
                d14 = d15;
            }
            i3++;
        }
        double d17 = d15;
        return new double[]{d10, d11, d12, d17, (d17 - this.typicalRanges[i][0]) / (this.typicalRanges[i][1] - this.typicalRanges[i][0]), d16, i3};
    }
}
