package seed.minerva;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:seed/minerva/GraphicalModel.class */
/**
 * A graph whose {@link ProbabilityNode}s are treated as a probabilistic model.
 *
 * <p>The class offers three groups of operations, all implemented on top of the
 * traversal facilities inherited from {@code GraphImpl}:
 * <ul>
 *   <li>queries that collect or count probability nodes (observed / unobserved,
 *       active only or all);</li>
 *   <li>flattening of the unobserved nodes' values (and typical ranges) to and
 *       from plain {@code double} arrays;</li>
 *   <li>evaluation ({@link #logPdf()}, {@link #chi2Normalised()}) and sampling.</li>
 * </ul>
 *
 * <p>Flattened arrays are laid out node by node in the order in which the
 * corresponding node list is returned; a univariate node occupies one entry and
 * any other node occupies {@code dim()} consecutive entries.
 */
public class GraphicalModel extends GraphImpl {
    /**
     * Lower bound used as "effectively log(0)": {@link #logPdf(boolean, boolean, boolean)}
     * stops accumulating once the running sum drops below this value (when
     * early termination is requested).
     */
    public static final double EFFECTIVE_LOGZERO = -1.0E20d;

    /** Creates a model using the no-argument superclass constructor. */
    public GraphicalModel() {
    }

    /**
     * Creates a model, forwarding the argument to the superclass constructor.
     *
     * @param str value passed to {@code GraphImpl(String)} (presumably the graph name)
     */
    public GraphicalModel(String str) {
        super(str);
    }

    /**
     * Counts observed dimensions.
     *
     * @return the sum of {@code dim()} over all active, observed probability
     *         nodes returned by {@code getAllNodes()}
     */
    public int numObservations() {
        return countActiveDimensions(true);
    }

    /**
     * Counts free (unobserved) dimensions.
     *
     * @return the sum of {@code dim()} over all active, unobserved probability
     *         nodes returned by {@code getAllNodes()}
     */
    public int numFree() {
        return countActiveDimensions(false);
    }

    /**
     * Collects observed nodes with a default-configured {@code ObservedNodeVisitor}
     * (presumably restricted to active nodes; see {@link #getAllObservedNodes()}).
     *
     * @return the nodes gathered by the visitor during a fast traversal
     */
    public List<ProbabilityNode> getObservedNodes() {
        ObservedNodeVisitor visitor = new ObservedNodeVisitor();
        fastTraversal(visitor);
        return visitor.getVisitedNodes();
    }

    /**
     * Collects observed nodes with the visitor's "active nodes only" restriction
     * switched off.
     *
     * @return the nodes gathered by the visitor during a fast traversal
     */
    public List<ProbabilityNode> getAllObservedNodes() {
        ObservedNodeVisitor visitor = new ObservedNodeVisitor();
        visitor.setVisitActiveNodesOnly(false);
        fastTraversal(visitor);
        return visitor.getVisitedNodes();
    }

    /**
     * Collects unobserved nodes with a default-configured {@code UnobservedNodeVisitor}
     * (presumably restricted to active nodes; see {@link #getAllUnobservedNodes()}).
     *
     * @return the nodes gathered by the visitor during a fast traversal
     */
    public List<ProbabilityNode> getUnobservedNodes() {
        UnobservedNodeVisitor visitor = new UnobservedNodeVisitor();
        fastTraversal(visitor);
        return visitor.getVisitedNodes();
    }

    /**
     * Collects unobserved nodes with the visitor's "active nodes only"
     * restriction switched off.
     *
     * @return the nodes gathered by the visitor during a fast traversal
     */
    public List<ProbabilityNode> getAllUnobservedNodes() {
        UnobservedNodeVisitor visitor = new UnobservedNodeVisitor();
        visitor.setVisitActiveNodesOnly(false);
        fastTraversal(visitor);
        return visitor.getVisitedNodes();
    }

    /**
     * Lists every active probability node, in {@code getAllNodes()} order.
     *
     * @return a new, mutable list owned by the caller
     */
    public List<ProbabilityNode> getActiveProbabilityNodes() {
        List<ProbabilityNode> active = new ArrayList<ProbabilityNode>();
        for (Node node : getAllNodes()) {
            if (node instanceof ProbabilityNode) {
                ProbabilityNode candidate = (ProbabilityNode) node;
                if (candidate.isActive()) {
                    active.add(candidate);
                }
            }
        }
        return active;
    }

    /**
     * Collects probability nodes by running a {@code ProbabilityNodeVisitor}
     * through a topological sort of this graph.
     *
     * @return the nodes in the order the topological sort visited them
     */
    List<ProbabilityNode> getActiveProbabilityNodesAncestralOrdering() {
        ProbabilityNodeVisitor visitor = new ProbabilityNodeVisitor();
        topologicalSort(visitor);
        return visitor.getVisitedNodes();
    }

    /**
     * Collects probability nodes by running a {@code ProbabilityNodeVisitor}
     * through a depth-first ancestor traversal.
     *
     * @param node the node at which the ancestor traversal starts
     * @return the nodes gathered by the visitor
     */
    public List<ProbabilityNode> getActiveProbabilityNodeAncestors(Node node) {
        ProbabilityNodeVisitor visitor = new ProbabilityNodeVisitor();
        depthFirstAncestorTraversal(visitor, node);
        return visitor.getVisitedNodes();
    }

    /**
     * Resamples every probability node of the root graph in topological order.
     *
     * <p>All nodes are first flagged as changed and the change is broadcast once;
     * only then is {@code sampleAndSet()} invoked on each node, in sort order.
     */
    public void sampleAndSetPriorPredictive() {
        Graph rootGraph = getRootGraph();
        ProbabilityNodeVisitor visitor = new ProbabilityNodeVisitor();
        rootGraph.topologicalSort(visitor);
        List<ProbabilityNode> ordered = visitor.getVisitedNodes();
        for (ProbabilityNode node : ordered) {
            node.setChanged();
        }
        rootGraph.broadcastChanges();
        for (ProbabilityNode node : ordered) {
            node.sampleAndSet();
        }
    }

    /**
     * Resamples only the observed probability nodes of the root graph, in
     * topological order.
     *
     * <p>Observed nodes are flagged as changed, the change is broadcast once, and
     * then {@code sampleAndSet()} is invoked on each observed node.
     */
    public void sampleAndSetObservations() {
        Graph rootGraph = getRootGraph();
        ProbabilityNodeVisitor visitor = new ProbabilityNodeVisitor();
        rootGraph.topologicalSort(visitor);
        List<ProbabilityNode> ordered = visitor.getVisitedNodes();
        for (ProbabilityNode node : ordered) {
            if (node.isObserved()) {
                node.setChanged();
            }
        }
        rootGraph.broadcastChanges();
        for (ProbabilityNode node : ordered) {
            if (node.isObserved()) {
                node.sampleAndSet();
            }
        }
    }

    /**
     * Writes a flat parameter vector into the nodes returned by
     * {@link #getUnobservedNodes()}.
     *
     * @param dArr values laid out node by node; entries beyond the required
     *             length are ignored
     * @throws ArrayIndexOutOfBoundsException if {@code dArr} is shorter than
     *         required; no node is modified in that case
     */
    public void setFreeParameters(double[] dArr) {
        assignParameters(getUnobservedNodes(), dArr);
    }

    /**
     * Reads the values of the nodes returned by {@link #getUnobservedNodes()}
     * into a flat vector.
     *
     * @return a new array whose length is the sum of the nodes' {@code dim()}
     */
    public double[] getFreeParameters() {
        return collectParameters(getUnobservedNodes());
    }

    /**
     * Writes a flat parameter vector into the nodes returned by
     * {@link #getAllUnobservedNodes()}.
     *
     * @param dArr values laid out node by node; entries beyond the required
     *             length are ignored
     * @throws ArrayIndexOutOfBoundsException if {@code dArr} is shorter than
     *         required; no node is modified in that case
     */
    public void setAllUnobservedParameters(double[] dArr) {
        assignParameters(getAllUnobservedNodes(), dArr);
    }

    /**
     * Reads the values of the nodes returned by {@link #getAllUnobservedNodes()}
     * into a flat vector.
     *
     * @return a new array whose length is the sum of the nodes' {@code dim()}
     */
    public double[] getAllUnobservedParameters() {
        return collectParameters(getAllUnobservedNodes());
    }

    /**
     * Concatenates the typical ranges of the nodes returned by
     * {@link #getUnobservedNodes()}.
     *
     * @return an array with one row per free dimension; column 0 and column 1
     *         are copied from the corresponding row of each node's
     *         {@code getTypicalRange()} (presumably lower and upper bound)
     */
    public double[][] getTypicalRangeForFreeParameters() {
        List<ProbabilityNode> nodes = getUnobservedNodes();
        double[][] ranges = new double[totalDim(nodes)][2];
        int row = 0;
        for (ProbabilityNode node : nodes) {
            double[][] nodeRange = node.getTypicalRange();
            for (int k = 0; k < nodeRange.length; k++) {
                ranges[row][0] = nodeRange[k][0];
                ranges[row][1] = nodeRange[k][1];
                row++;
            }
        }
        return ranges;
    }

    /**
     * Alias for {@link #overrideTypicalRangesForFreeParameters(double[][])}.
     *
     * @param dArr one {@code [2]} row per free dimension
     */
    public void setTypicalRangeForFreeParameters(double[][] dArr) {
        overrideTypicalRangesForFreeParameters(dArr);
    }

    /**
     * Splits a flat table of ranges into per-node tables and installs each one
     * through {@code setTypicalRangeOverride} on the nodes returned by
     * {@link #getUnobservedNodes()}.
     *
     * @param dArr one {@code [2]} row per free dimension, in node order
     * @throws MinervaRuntimeException if the number of rows differs from the
     *         total dimension of the unobserved nodes
     */
    public void overrideTypicalRangesForFreeParameters(double[][] dArr) {
        List<ProbabilityNode> nodes = getUnobservedNodes();
        int expectedRows = totalDim(nodes);
        if (dArr.length != expectedRows) {
            throw new MinervaRuntimeException("Internal error: dimensions differ: " + dArr.length + " != " + expectedRows);
        }
        int row = 0;
        for (ProbabilityNode node : nodes) {
            double[][] nodeRange = new double[node.dim()][2];
            for (int k = 0; k < nodeRange.length; k++) {
                nodeRange[k][0] = dArr[row][0];
                nodeRange[k][1] = dArr[row][1];
                row++;
            }
            node.setTypicalRangeOverride(nodeRange);
        }
    }

    /**
     * Wraps this model in a {@code LogPdfFunction}.
     *
     * @return a new function object constructed from this model
     */
    public LogPdfFunction getLogPdfFunction() {
        return new LogPdfFunction(this);
    }

    /**
     * Evaluates the joint log-pdf including both unobserved and observed nodes,
     * with early termination enabled.
     *
     * @return {@code logPdf(true, true, true)}
     */
    public double logPdf() {
        return logPdf(true, true, true);
    }

    /**
     * Sums {@code logpdf()} over the probability nodes of the root graph in
     * topological order, after broadcasting pending changes.
     *
     * <p>A node whose log-pdf is NaN or positive infinity is reported on
     * {@code System.err}; its value still enters the sum if the node is selected.
     *
     * @param z  include the contribution of unobserved nodes
     * @param z2 include the contribution of observed nodes
     * @param z3 stop iterating as soon as the running sum falls below
     *           {@link #EFFECTIVE_LOGZERO}
     * @return the accumulated log-pdf (possibly a partial sum when {@code z3}
     *         triggered early termination)
     */
    public double logPdf(boolean z, boolean z2, boolean z3) {
        Graph rootGraph = getRootGraph();
        rootGraph.broadcastChanges();
        ProbabilityNodeVisitor visitor = new ProbabilityNodeVisitor();
        rootGraph.topologicalSort(visitor);
        double sum = 0.0d;
        for (ProbabilityNode node : visitor.getVisitedNodes()) {
            double logpdf = node.logpdf();
            if (Double.isNaN(logpdf) || (Double.isInfinite(logpdf) && logpdf > 0.0d)) {
                System.err.println("Node '" + node.getName() + "' has invalid logPdf = " + logpdf);
            }
            if ((node.isObserved() && z2) || (!node.isObserved() && z)) {
                sum += logpdf;
            }
            if (z3 && sum < EFFECTIVE_LOGZERO) {
                break;
            }
        }
        return sum;
    }

    /**
     * Computes chi-squared divided by the number of observed dimensions.
     *
     * <p>Every node returned by {@link #getObservedNodes()} must be a
     * {@code Normal} or a {@code MultivariateNormal} with diagonal covariance.
     * For a {@code Normal} the term is {@code ((x - mean) / sigma)^2}; for a
     * {@code MultivariateNormal} only the diagonal of the inverse covariance is
     * used.
     *
     * @return the normalised chi-squared; NaN when there are no observed
     *         dimensions, because the result is then {@code 0.0 / 0}
     * @throws MinervaRuntimeException if an observed node is of another type or
     *         has a non-diagonal covariance matrix
     */
    public double chi2Normalised() {
        int observedDim = 0;
        List<ProbabilityNode> observedNodes = getObservedNodes();
        // First pass: validate every node before doing any arithmetic.
        for (ProbabilityNode node : observedNodes) {
            if (!(node instanceof MultivariateNormal) && !(node instanceof Normal)) {
                throw new MinervaRuntimeException("Model " + getName() + " not applicable for normalised chi2 measure: contains non-normal observation:" + node.getClass().getCanonicalName());
            }
            if ((node instanceof MultivariateNormal) && !((MultivariateNormal) node).isDiagonal()) {
                throw new MinervaRuntimeException("Can't do chi2 calculation, covariance matrix is not diagonal for observation: " + node.getName());
            }
            observedDim += node.dim();
        }
        double chi2 = 0.0d;
        for (ProbabilityNode node : observedNodes) {
            if (node instanceof Normal) {
                Normal normal = (Normal) node;
                chi2 += Math.pow((normal.getDouble() - normal.mean1D()) / normal.sigma1D(), 2.0d);
            } else {
                MultivariateNormal mvn = (MultivariateNormal) node;
                double[][] covInv = mvn.covInv();
                double[] value = mvn.getDoubleArray();
                double[] mean = mvn.mean();
                for (int k = 0; k < covInv.length; k++) {
                    chi2 += covInv[k][k] * (value[k] - mean[k]) * (value[k] - mean[k]);
                }
            }
        }
        return chi2 / observedDim;
    }

    /**
     * Invokes {@code sampleAndSet()} on every node returned by
     * {@link #getObservedNodes()}, in list order.
     */
    public void sampleAndSet() {
        for (ProbabilityNode node : getObservedNodes()) {
            node.sampleAndSet();
        }
    }

    /**
     * Sets every node returned by {@link #getObservedNodes()} to its own
     * {@code mean()}: the first component for univariate nodes, the whole array
     * otherwise.
     */
    public void setObservationsToMeans() {
        for (ProbabilityNode node : getObservedNodes()) {
            if (node.isUnivariate()) {
                ((Univariate) node).setDouble(node.mean()[0]);
            } else {
                ((Multivariate) node).setDoubleArray(node.mean());
            }
        }
    }

    /** Sums {@code dim()} over active probability nodes whose observed flag equals {@code observed}. */
    private int countActiveDimensions(boolean observed) {
        int total = 0;
        for (Node node : getAllNodes()) {
            if (node instanceof ProbabilityNode) {
                ProbabilityNode candidate = (ProbabilityNode) node;
                if (candidate.isActive() && candidate.isObserved() == observed) {
                    total += candidate.dim();
                }
            }
        }
        return total;
    }

    /** Sums {@code dim()} over the given nodes. */
    private static int totalDim(List<ProbabilityNode> nodes) {
        int total = 0;
        for (ProbabilityNode node : nodes) {
            total += node.dim();
        }
        return total;
    }

    /**
     * Distributes a flat vector over the given nodes: one entry per univariate
     * node, {@code dim()} entries per other node.
     */
    private static void assignParameters(List<ProbabilityNode> nodes, double[] values) {
        // Work out how many entries will be consumed, counted exactly the way
        // the assignment loop below consumes them, so that a too-short array is
        // rejected before any node has been overwritten.
        int required = 0;
        for (ProbabilityNode node : nodes) {
            required += node.isUnivariate() ? 1 : node.dim();
        }
        if (values != null && values.length < required) {
            throw new ArrayIndexOutOfBoundsException(
                    "Parameter array too short: " + values.length + " < " + required);
        }
        int offset = 0;
        for (ProbabilityNode node : nodes) {
            if (node.isUnivariate()) {
                ((Univariate) node).setDouble(values[offset]);
                offset++;
            } else {
                int dim = node.dim();
                // Each node receives its own copy, never a view of the caller's array.
                double[] slice = new double[dim];
                for (int k = 0; k < dim; k++) {
                    slice[k] = values[offset];
                    offset++;
                }
                ((Multivariate) node).setDoubleArray(slice);
            }
        }
    }

    /**
     * Gathers the values of the given nodes into a new flat vector: one entry
     * per univariate node, {@code dim()} entries per other node.
     */
    private static double[] collectParameters(List<ProbabilityNode> nodes) {
        double[] flat = new double[totalDim(nodes)];
        int offset = 0;
        for (ProbabilityNode node : nodes) {
            if (node.isUnivariate()) {
                flat[offset] = ((Univariate) node).getDouble();
                offset++;
            } else {
                int dim = node.dim();
                double[] nodeValues = ((Multivariate) node).getDoubleArray();
                for (int k = 0; k < dim; k++) {
                    flat[offset] = nodeValues[k];
                    offset++;
                }
            }
        }
        return flat;
    }
}
