package seed.minerva.handlers.profiles;

import oneLiners.BinaryMatrixFile;
import oneLiners.BinaryMatrixWriter;
import oneLiners.OneLiners;
import seed.minerva.GraphicalModel;
import seed.minerva.LogPdfFunction;
import seed.minerva.Node;
import seed.minerva.Normal;
import seed.minerva.ProbabilityNodeImpl;
import seed.minerva.TruncatedNormal;
import seed.minerva.nodetypes.DoubleValue;
import seed.minerva.nodetypes.ScalarFunction1D;
import seed.minerva.nodetypes.ScalarFunction1DSwitch;
import seed.minerva.toBeGeneral.LinearEdge;
import seed.minerva.toBeGeneral.ModifiedTanhProfile;
import seed.minerva.toBeGeneral.ModifiedTanhProfileConfig;

/* loaded from: input_file:seed/minerva/handlers/profiles/ModifiedTanhAndInterpProfileHandler.class */
/**
 * Profile handler that extends {@link InterpolationProfileHandler} with an additional
 * edge node: either a {@link ModifiedTanhProfile} or, when requested, a {@link LinearEdge}.
 *
 * <p>The edge node is parameterised by four {@link TruncatedNormal} nodes (psi0, width,
 * height, offset) whose prior settings are read from a {@link ModifiedTanhProfileConfig}.
 * The inherited interpolation profile node and the edge node are combined through a
 * {@link ScalarFunction1DSwitch}, which is what {@link #getProfileNode()} returns. Two
 * observed {@link Normal} nodes are attached to the edge node's
 * {@code getMatchingGradient} and {@code getMatchingSecondDiff} outputs.
 */
public class ModifiedTanhAndInterpProfileHandler extends InterpolationProfileHandler {
    public TruncatedNormal tanhPsi0Node;
    public TruncatedNormal tanhWidthNode;
    public TruncatedNormal tanhHeightNode;
    public TruncatedNormal tanhOffsetNode;
    public ModifiedTanhProfileConfig mtanhCfg;
    Node mtanh;
    public ScalarFunction1DSwitch switchNode;
    public Normal mtanhFirstDiffConstraint;
    public Normal mtanhSecondDiffConstraint;

    /** Abscissa written as the first row of the sample file; fixed by {@link #initSampling}. */
    private double[] sampleOutputPsi;
    private BinaryMatrixWriter sampleWriter;

    /** When true, {@link #build} creates a {@link LinearEdge} instead of a {@link ModifiedTanhProfile}. */
    private boolean linearEdge = false;

    /**
     * Selects the edge node type and then builds the handler.
     *
     * @param graphicalModel forwarded unchanged to {@link #build(GraphicalModel, String, boolean, double)}
     * @param str            forwarded unchanged; also used as the prefix of the edge node name
     * @param z              forwarded unchanged
     * @param d              forwarded unchanged
     * @param z2             {@code true} to use a {@link LinearEdge}, {@code false} for a
     *                       {@link ModifiedTanhProfile}
     */
    public void build(GraphicalModel graphicalModel, String str, boolean z, double d, boolean z2) {
        this.linearEdge = z2;
        build(graphicalModel, str, z, d);
    }

    /**
     * Builds the inherited interpolation profile, then adds the edge node, its four
     * parameter nodes, the switch node and the two observed matching constraints.
     *
     * @param graphicalModel passed to {@code super.build}
     * @param str            passed to {@code super.build}; the edge node is named {@code str + "MTanh"}
     * @param z              passed to {@code super.build}
     * @param d              passed to {@code super.build}
     */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler
    public void build(GraphicalModel graphicalModel, String str, boolean z, double d) {
        super.build(graphicalModel, str, z, d);
        this.mtanhCfg = new ModifiedTanhProfileConfig();

        String edgeName = str + "MTanh";
        this.mtanh = this.linearEdge ? new LinearEdge(edgeName) : new ModifiedTanhProfile(edgeName);
        this.g.add(this.mtanh);

        this.tanhPsi0Node = new TruncatedNormal(this.g, "tanhPsi0");
        this.tanhWidthNode = new TruncatedNormal(this.g, "tanhWidth");
        this.tanhHeightNode = new TruncatedNormal(this.g, "tanhHeight");
        this.tanhOffsetNode = new TruncatedNormal(this.g, "tanhOffset");

        connectPrior(this.tanhPsi0Node, "Psi0");
        connectPrior(this.tanhWidthNode, "Width");
        connectPrior(this.tanhHeightNode, "Height");
        connectPrior(this.tanhOffsetNode, "Offset");

        this.mtanh.setConnection("centreX", this.tanhPsi0Node);
        this.mtanh.setConnection("width", this.tanhWidthNode);
        this.mtanh.setConnection("height", this.tanhHeightNode);
        this.mtanh.setConnection("offset", this.tanhOffsetNode);
        // The edge node is also wired to the interpolation knots of the parent handler.
        this.mtanh.setConnection("knotX", this.cfg, "getKnotX");
        this.mtanh.setConnection("knotY", this.knotVals);

        // The explicit (DoubleValue) cast on null is kept from the original call; it may be
        // needed to select the intended constructor overload.
        this.switchNode = new ScalarFunction1DSwitch(this.g, "sum-TanhPlusInterp", this.profileNode,
                (ScalarFunction1D) this.mtanh, (DoubleValue) null);
        this.switchNode.setConnection("switchX", this.cfg, "getKnotX");

        this.mtanhFirstDiffConstraint = new Normal(this.g, "mtanhFirstDiffConstraint");
        this.mtanhFirstDiffConstraint.setConnection("mean", this.mtanh, "getMatchingGradient");
        this.mtanhFirstDiffConstraint.setConnection("value", this.mtanhCfg, "getMatchingGradMean");
        this.mtanhFirstDiffConstraint.setConnection(Normal.SIGMA, this.mtanhCfg, "getMatchingGradSigma");
        this.mtanhFirstDiffConstraint.setObserved(true);

        this.mtanhSecondDiffConstraint = new Normal(this.g, "mtanhSecondDiffConstraint");
        this.mtanhSecondDiffConstraint.setConnection("mean", this.mtanh, "getMatchingSecondDiff");
        this.mtanhSecondDiffConstraint.setConnection("value", this.mtanhCfg, "getMatchingSecondDiffMean");
        this.mtanhSecondDiffConstraint.setConnection(Normal.SIGMA, this.mtanhCfg, "getMatchingSecondDiffSigma");
        this.mtanhSecondDiffConstraint.setObserved(true);
    }

    /**
     * Connects the six prior inputs of one parameter node to the matching getters on
     * {@link #mtanhCfg}. For {@code quantity = "Width"} the getters are {@code getWidthMean},
     * {@code getWidthSigma}, {@code getWidthMin}, {@code getWidthMax}, {@code getWidthTypMin}
     * and {@code getWidthTypMax}.
     */
    private void connectPrior(TruncatedNormal node, String quantity) {
        String getter = "get" + quantity;
        node.setConnection("mean", this.mtanhCfg, getter + "Mean");
        node.setConnection(Normal.SIGMA, this.mtanhCfg, getter + "Sigma");
        node.setConnection("low", this.mtanhCfg, getter + "Min");
        node.setConnection("high", this.mtanhCfg, getter + "Max");
        node.setConnection(ProbabilityNodeImpl.TYPMIN, this.mtanhCfg, getter + "TypMin");
        node.setConnection(ProbabilityNodeImpl.TYPMAX, this.mtanhCfg, getter + "TypMax");
    }

    /**
     * Sets the edge parameters and the inherited pedestal/linear part in one call.
     * The edge offset is set to one hundredth of the height.
     *
     * @param d  value for the psi0 node
     * @param d2 value for the width node
     * @param d3 value for the height node; also the first argument of
     *           {@code super.setPedestalAndLinear}
     * @param d4 second argument of {@code super.setPedestalAndLinear}
     */
    public void setLinearAndTanh(double d, double d2, double d3, double d4) {
        setTanhParams(d, d2, d3, d3 / 100.0);
        super.setPedestalAndLinear(d3, d4);
    }

    /**
     * Sets the values of the four edge parameter nodes.
     *
     * @param d  value for {@link #tanhPsi0Node}
     * @param d2 value for {@link #tanhWidthNode}
     * @param d3 value for {@link #tanhHeightNode}
     * @param d4 value for {@link #tanhOffsetNode}
     */
    public void setTanhParams(double d, double d2, double d3, double d4) {
        this.tanhPsi0Node.setDouble(d);
        this.tanhWidthNode.setDouble(d2);
        this.tanhHeightNode.setDouble(d3);
        this.tanhOffsetNode.setDouble(d4);
    }

    /**
     * Sets the active flag of each edge parameter node.
     *
     * @param z  active flag for {@link #tanhPsi0Node}
     * @param z2 active flag for {@link #tanhWidthNode}
     * @param z3 active flag for {@link #tanhHeightNode}
     * @param z4 active flag for {@link #tanhOffsetNode}
     */
    public void freeTanh(boolean z, boolean z2, boolean z3, boolean z4) {
        this.tanhPsi0Node.setActive(z);
        this.tanhWidthNode.setActive(z2);
        this.tanhHeightNode.setActive(z3);
        this.tanhOffsetNode.setActive(z4);
    }

    /** Saves the state using the inherited {@code unit} as the scale factor. */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler, seed.minerva.handlers.Handler
    public void saveState(String str) {
        saveState(str, this.unit);
    }

    /**
     * Writes three matrix files derived from the path prefix {@code str}:
     * <ul>
     *   <li>{@code str + "-tanh.bin"}: one row {psi0, width, height * d, offset * d};</li>
     *   <li>{@code str + ".bin"}: the evaluation grid and the switch node evaluated on it,
     *       multiplied by {@code d};</li>
     *   <li>{@code str + "-knots.bin"}: the knot positions and the knot values multiplied
     *       by {@code d}.</li>
     * </ul>
     *
     * @param str path prefix of the output files
     * @param d   factor applied to height, offset and all profile values on output
     */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler
    public void saveState(String str, double d) {
        // Each payload is a matrix (array of rows), so it must be created as double[][].
        double[][] tanhParams = new double[][] {
            {
                this.tanhPsi0Node.getDouble(),
                this.tanhWidthNode.getDouble(),
                this.tanhHeightNode.getDouble() * d,
                this.tanhOffsetNode.getDouble() * d
            }
        };
        BinaryMatrixFile.mustWrite(str + "-tanh.bin", tanhParams, true);

        double[] knotX = this.cfg.getKnotX();
        double[] scaledKnotVals = OneLiners.arrayMultiply(this.knotVals.getDoubleArray(), d);
        double[] psiGrid = evaluationGrid(knotX);
        double[] scaledProfile = OneLiners.arrayMultiply(this.switchNode.eval(psiGrid), d);

        BinaryMatrixFile.mustWrite(str + ".bin", new double[][] {psiGrid, scaledProfile}, true);
        BinaryMatrixFile.mustWrite(str + "-knots.bin", new double[][] {knotX, scaledKnotVals}, true);
    }

    /** Loads the state using the inherited {@code unit} as the scale factor. */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler, seed.minerva.handlers.Handler
    public void loadState(String str) {
        loadState(str, this.unit);
    }

    /**
     * Restores the four edge parameters from {@code str + "-tanh.bin"} (height and offset
     * are divided by {@code d}, mirroring {@link #saveState(String, double)}) and then
     * delegates to {@code super.loadState}.
     *
     * @param str path prefix used when the state was saved
     * @param d   factor that was applied on output and is divided out again here
     */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler
    public void loadState(String str, double d) {
        double[] row = BinaryMatrixFile.mustLoad(str + "-tanh.bin", true)[0];
        this.tanhPsi0Node.setDouble(row[0]);
        this.tanhWidthNode.setDouble(row[1]);
        this.tanhHeightNode.setDouble(row[2] / d);
        this.tanhOffsetNode.setDouble(row[3] / d);
        super.loadState(str, d);
    }

    /**
     * Fixes the evaluation grid from the current psi0/width values, opens the sample writer
     * at {@code str + "/" + basename} and writes the grid as the first row.
     *
     * @param str directory part of the sample file path
     * @param i   second argument of the {@link BinaryMatrixWriter} constructor
     */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler
    public void initSampling(String str, int i) {
        this.sampleOutputPsi = evaluationGrid(this.cfg.getKnotX());
        this.sampleWriter = new BinaryMatrixWriter(str + "/" + this.basename, i);
        this.sampleWriter.writeRow(this.sampleOutputPsi);
    }

    /**
     * Appends the switch node evaluated on the grid fixed by {@link #initSampling} as one row.
     * Requires {@link #initSampling} to have been called first.
     */
    public void writeSample() {
        this.sampleWriter.writeRow(this.switchNode.eval(this.sampleOutputPsi));
    }

    /** Closes the sample writer opened by {@link #initSampling}. */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler, seed.minerva.handlers.Handler
    public void endSampling() {
        this.sampleWriter.close();
    }

    /** Returns the switch node combining the interpolation profile and the edge node. */
    @Override // seed.minerva.handlers.profiles.InterpolationProfileHandler, seed.minerva.handlers.profiles.OneDimensionalProfileHandler
    public ScalarFunction1D getProfileNode() {
        return this.switchNode;
    }

    /**
     * Builds the grid on which the combined profile is written out. It spans from the
     * smaller of the first knot and {@code psi0 - 2 * width} to the larger of the last knot
     * and {@code psi0 + 2 * width}, using the current values of the psi0 and width nodes.
     *
     * @param knotX knot positions; must contain at least one element
     */
    private double[] evaluationGrid(double[] knotX) {
        double psi0 = this.tanhPsi0Node.getDouble();
        double reach = 2.0 * this.tanhWidthNode.getDouble();
        double lower = Math.min(knotX[0], psi0 - reach);
        double upper = Math.max(knotX[knotX.length - 1], psi0 + reach);
        // NOTE(review): the third linSpace argument is taken from an unrelated-looking
        // LogPdfFunction constant; this may be a decompiler constant substitution for a
        // plain literal — confirm against the original source.
        return OneLiners.linSpace(lower, upper, LogPdfFunction.hardLimitsNumericalPrecisionMarginSteps);
    }
}
