package seed.minerva.handlers.profiles;

import algorithmrepository.LinearInterpolation1D;
import java.util.Arrays;
import java.util.Comparator;
import oneLiners.BinaryMatrixFile;
import oneLiners.BinaryMatrixWriter;
import oneLiners.OneLiners;
import seed.minerva.GraphicalModel;
import seed.minerva.ProbabilityNodeImpl;
import seed.minerva.TruncatedMultivariateNormal;
import seed.minerva.nodetypes.Interpolation1DNode;
import seed.minerva.nodetypes.ScalarFunction1D;
import seed.minerva.toBeGeneral.FixedKnot1DProfileConfig;

/* loaded from: input_file:seed/minerva/handlers/profiles/InterpolationProfileHandler.class */
/**
 * Profile handler that represents a one-dimensional profile by its values at a fixed set of
 * knots, evaluated through an {@link Interpolation1DNode}.
 *
 * <p>Knot positions and the prior on knot values come from a {@link FixedKnot1DProfileConfig};
 * the knot values themselves are a {@link TruncatedMultivariateNormal} node. All nodes live in a
 * sub-model named after the handler's base name.
 */
public class InterpolationProfileHandler implements OneDimensionalProfileHandler {

    /** Orders (x, y) rows by their x (first) column. Stateless, so shared by all instances. */
    private static final Comparator<double[]> BY_FIRST_COLUMN = new Comparator<double[]>() {
        @Override
        public int compare(double[] a, double[] b) {
            return Double.compare(a[0], b[0]);
        }
    };

    /** Sub-model holding this handler's nodes; created in {@link #partBuild}. */
    public GraphicalModel g;
    /** Node holding the profile value at each knot. */
    public TruncatedMultivariateNormal knotVals;
    /** Supplies knot positions and the knot-value prior (mean, covariance, bounds). */
    public FixedKnot1DProfileConfig cfg;
    /** Interpolating node evaluating the profile from the knot positions and values. */
    public Interpolation1DNode profileNode;
    /** Prefix for all node names and output file names of this handler. */
    protected String basename;
    /** Positions at which sampled profiles are written; {@code null} disables sample output. */
    private double[] sampleOutputX;
    /** Writer for sampled profiles; {@code null} while no sample output is active. */
    private BinaryMatrixWriter sampleWriter;
    /** Scale passed to {@code dualDump} on save; loaded knot values are divided by it. */
    protected double unit;

    /**
     * Builds the handler with a unit of 1.
     *
     * @param parent   model the handler's sub-model is added to
     * @param name     base name for the handler's nodes and files
     * @param diagonal selects the default prior form; see {@link #setIsDiagonal(boolean)}
     */
    public void build(GraphicalModel parent, String name, boolean diagonal) {
        build(parent, name, diagonal, 1.0d);
    }

    /**
     * Builds the sub-model (config + interpolation node) and then the knot-value node.
     *
     * @param parent   model the handler's sub-model is added to
     * @param name     base name for the handler's nodes and files
     * @param diagonal selects the default prior form; see {@link #setIsDiagonal(boolean)}
     * @param unit     scale used by {@link #saveState(String)} and {@link #loadState(String)}
     */
    public void build(GraphicalModel parent, String name, boolean diagonal, double unit) {
        this.unit = unit;
        partBuild(parent, name);
        connectPriorNode(diagonal);
    }

    /**
     * Creates the sub-model, the knot config and the interpolation node, and connects the node's
     * knot positions to the config. The knot-value node is not created here.
     *
     * @param parent model the new sub-model is added to
     * @param name   base name for the sub-model and its nodes
     */
    protected void partBuild(GraphicalModel parent, String name) {
        this.basename = name;
        this.g = new GraphicalModel(name);
        parent.add(this.g);
        this.cfg = new FixedKnot1DProfileConfig(name + "_cfg");
        // NOTE(review): (1, 2, 0.0) are Interpolation1DNode settings; confirm their meaning there.
        this.profileNode = new Interpolation1DNode(name + "_prof", 1, 2, 0.0d);
        this.g.add(this.cfg, this.profileNode);
        this.profileNode.setConnection("knotXs", this.cfg, "getKnotX");
    }

    /**
     * (Re)creates the knot-value node, wires its prior inputs to the config and feeds it into the
     * profile node.
     *
     * <p>When the config is initialised, the new node receives the previous node's values (or
     * zeros if there was no previous node). If the config lacks the prior needed for the requested
     * form, a default is installed: {@code setPriorNormal(0, 1)} when {@code diagonal}, otherwise a
     * zero mean with an identity matrix.
     *
     * @param diagonal {@code true} to check/install the diagonal prior, {@code false} for the
     *                 full-matrix prior
     */
    private void connectPriorNode(boolean diagonal) {
        // Snapshot current values so replacing the node does not lose them.
        double[] previousValues = null;
        if (this.knotVals != null) {
            previousValues = this.knotVals.getDoubleArray().clone();
        }
        // NOTE(review): each call constructs a new "<basename>_par" node on g; confirm that
        // GraphicalModel tolerates/replaces the earlier node of the same name.
        this.knotVals = new TruncatedMultivariateNormal(this.g, this.basename + "_par");
        this.knotVals.setConnection("mean", this.cfg, "getKnotYMean");
        this.knotVals.setConnection("cov", this.cfg, "getKnotYCovDiag");
        this.knotVals.setConnection("invcov", this.cfg, "getKnotYInvCov");
        if (this.cfg.isInitialised()) {
            if (previousValues == null) {
                previousValues = OneLiners.fillArray(0.0d, this.cfg.getNKnots());
            }
            if (diagonal && this.cfg.getKnotYCovDiag() == null) {
                this.cfg.setPriorNormal(0.0d, 1.0d);
            } else if (!diagonal && this.cfg.getKnotYInvCov() == null) {
                int nKnots = this.cfg.getNKnots();
                double[][] identity = new double[nKnots][nKnots];
                for (int k = 0; k < nKnots; k++) {
                    identity[k][k] = 1.0d;
                }
                this.cfg.setPriorNormal(OneLiners.fillArray(0.0d, nKnots), identity);
            }
            this.knotVals.setDoubleArray(previousValues);
        }
        this.knotVals.setConnection("low", this.cfg, "getKnotYMin");
        this.knotVals.setConnection("high", this.cfg, "getKnotYMax");
        this.knotVals.setConnection(ProbabilityNodeImpl.TYPMIN, this.cfg, "getKnotYTypMin");
        this.knotVals.setConnection(ProbabilityNodeImpl.TYPMAX, this.cfg, "getKnotYTypMax");
        this.profileNode.setConnection("knotYs", this.knotVals);
    }

    /**
     * Rebuilds the knot-value node with the diagonal or full-matrix default prior.
     *
     * @param diagonal {@code true} for the diagonal prior, {@code false} for the full matrix
     */
    public void setIsDiagonal(boolean diagonal) {
        connectPriorNode(diagonal);
    }

    /**
     * Sets every knot value to {@code value} and marks the node changed.
     *
     * @param value value assigned to all knots
     */
    public void makeFlat(double value) {
        // getDoubleArray() is modified in place, as in the other setters of this class.
        Arrays.fill(this.knotVals.getDoubleArray(), value);
        this.knotVals.setChanged();
    }

    /**
     * Sets a two-level step profile and marks the node changed.
     *
     * @param innerValue value for knots with x &lt; 1
     * @param outerValue value for knots with x &gt;= 1
     */
    public void makeFlat(double innerValue, double outerValue) {
        double[] knotX = this.cfg.getKnotX();
        double[] values = this.knotVals.getDoubleArray();
        for (int k = 0; k < values.length; k++) {
            values[k] = knotX[k] >= 1.0d ? outerValue : innerValue;
        }
        this.knotVals.setChanged();
    }

    /**
     * Sets a linear profile that is {@code valueAtZero} at x = 0 and {@code pedestal} at x = 1,
     * and zero for knots with x &gt; 1; then marks the node changed.
     *
     * @param pedestal    value at x = 1
     * @param valueAtZero value at x = 0
     */
    public void setPedestalAndLinear(double pedestal, double valueAtZero) {
        double[] values = this.knotVals.getDoubleArray();
        double[] knotX = this.cfg.getKnotX();
        for (int k = 0; k < values.length; k++) {
            values[k] = knotX[k] > 1.0d
                    ? 0.0d
                    : pedestal + ((1.0d - knotX[k]) * (valueAtZero - pedestal));
        }
        this.knotVals.setChanged();
    }

    /** Saves the profile under {@code dir} using this handler's {@link #unit}. */
    @Override // seed.minerva.handlers.Handler
    public void saveState(String dir) {
        saveState(dir, this.unit);
    }

    /**
     * Dumps the profile node to {@code <dir>/<basename>} via {@code dualDump}.
     *
     * @param dir   output directory
     * @param scale scale passed to {@code dualDump}
     */
    public void saveState(String dir, double scale) {
        this.profileNode.dualDump(dir + "/" + this.basename, true, scale);
    }

    /** Loads the profile from {@code dir} using this handler's {@link #unit}. */
    @Override // seed.minerva.handlers.Handler
    public void loadState(String dir) {
        loadState(dir, this.unit);
    }

    /**
     * Loads knot positions (row 0) and values (row 1) from {@code <dir>/<basename>-knots.bin},
     * divides the values by {@code scale}, and applies them. If the stored positions equal the
     * configured ones the values are used directly; otherwise they are re-interpolated onto the
     * configured knots with {@link #setFromData(double[], double[])}.
     *
     * @param dir   directory containing the knot file
     * @param scale divisor applied to the stored knot values
     * @throws IllegalStateException if the file holds fewer than two rows
     */
    public void loadState(String dir, double scale) {
        String fileName = dir + "/" + this.basename + "-knots.bin";
        double[][] stored = BinaryMatrixFile.mustLoad(fileName, true);
        if (stored == null || stored.length < 2) {
            throw new IllegalStateException(
                    "Expected knot positions and values (2 rows) in " + fileName);
        }
        double[] storedX = stored[0];
        double[] storedY = OneLiners.arrayMultiply(stored[1], 1.0d / scale);
        // Element-wise comparison of positions (the decompiled deepEquals form did not compile).
        if (storedX.length == this.cfg.getNKnots() && Arrays.equals(storedX, this.cfg.getKnotX())) {
            this.knotVals.setDoubleArray(storedY);
        } else {
            setFromData(storedX, storedY);
        }
    }

    /**
     * Sets the knot values by linearly interpolating the given (x, y) points onto the configured
     * knot positions. The points need not be sorted.
     *
     * @param x abscissae of the data points
     * @param y values at {@code x}
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length
     */
    @Override // seed.minerva.handlers.profiles.OneDimensionalProfileHandler
    public void setFromData(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "x and y lengths differ: " + x.length + " != " + y.length);
        }
        // Sort (x, y) pairs by x; presumably LinearInterpolation1D expects increasing x.
        double[][] rows = OneLiners.transpose(new double[][] {x, y});
        Arrays.sort(rows, BY_FIRST_COLUMN);
        double[][] sorted = OneLiners.transpose(rows);
        LinearInterpolation1D interpolation = new LinearInterpolation1D(sorted[0], sorted[1], 2, 0.0d);
        this.knotVals.setDoubleArray(interpolation.eval(this.cfg.getKnotX()));
    }

    /**
     * Configures {@code nPoints} evenly spaced output positions and opens the sample file.
     *
     * @param dir     output directory
     * @param nPoints number of output positions across the knot range
     */
    public void initSampling(String dir, int nPoints) {
        setupProfileSampling(nPoints);
        initSampling(dir);
    }

    /**
     * Opens {@code <dir>/<basename>-samples.bin} and writes a header row (index -1) holding the
     * output positions. Does nothing but clear the writer if no output positions are configured.
     *
     * @param dir output directory
     */
    @Override // seed.minerva.handlers.Handler
    public void initSampling(String dir) {
        if (this.sampleOutputX == null) {
            this.sampleWriter = null;
            return;
        }
        // Width = one index column plus one column per output position.
        this.sampleWriter = new BinaryMatrixWriter(
                dir + "/" + this.basename + "-samples.bin", this.sampleOutputX.length + 1);
        this.sampleWriter.writeRow(-1, this.sampleOutputX);
    }

    /**
     * Sets the output positions to {@code nPoints} points from the first to the last knot
     * position (via {@code OneLiners.linSpace}).
     *
     * @param nPoints number of output positions
     */
    public void setupProfileSampling(int nPoints) {
        double[] knotX = this.cfg.getKnotX();
        this.sampleOutputX = OneLiners.linSpace(knotX[0], knotX[knotX.length - 1], nPoints);
    }

    /**
     * Sets the output positions directly. The array is stored without copying.
     *
     * @param positions output positions, or {@code null} to disable sample output
     */
    public void setupProfileSampling(double[] positions) {
        this.sampleOutputX = positions;
    }

    /**
     * Writes one row: the sample index followed by the profile evaluated at the output positions.
     * Does nothing if sampling is not active.
     *
     * @param sampleIndex index written in the row's first column
     */
    @Override // seed.minerva.handlers.Handler
    public void writeSample(int sampleIndex) {
        if (this.sampleWriter != null) {
            this.sampleWriter.writeRow(Integer.valueOf(sampleIndex), this.profileNode.eval(this.sampleOutputX));
        }
    }

    /** Closes the sample writer, if open, and deactivates sampling. */
    @Override // seed.minerva.handlers.Handler
    public void endSampling() {
        if (this.sampleWriter != null) {
            this.sampleWriter.close();
            // Drop the closed writer so later writeSample/endSampling calls are no-ops.
            this.sampleWriter = null;
        }
    }

    /** Returns the interpolation node that evaluates this profile. */
    @Override // seed.minerva.handlers.profiles.OneDimensionalProfileHandler
    public ScalarFunction1D getProfileNode() {
        return this.profileNode;
    }
}
