package seed.mcmc;

import seed.digeom.Function;
import seed.digeom.FunctionWithDryEval;

/* loaded from: input_file:seed/mcmc/MetropolisHastingsSampler.class */
/**
 * Metropolis-Hastings MCMC sampler over a {@link Function} interpreted as a log density
 * (only differences of its values are used, via {@code exp(heat * (new - current))}).
 *
 * <p>Each {@link #iterate()} draws a proposal from the {@link ProposalDistribution},
 * optionally rejects it outright if it leaves the domain's rectangular hard limits,
 * otherwise evaluates the log density and applies the (tempered) MH acceptance test.
 * An optional {@link ProposalAdapter} is notified after every step.
 *
 * <p>When {@link #getSyncronizeFunctionState()} is {@code true}, the sampler makes sure
 * that after each step the last point {@code logPdf} was evaluated at is the current
 * position, for functions that keep state from their most recent evaluation.
 *
 * <p>Not thread-safe: all state is mutated in place without synchronization.
 */
public class MetropolisHastingsSampler {
    /** Target log density. */
    protected Function logPdf;
    /** Source of proposals; also supplies the proposal density for asymmetric kernels. */
    protected ProposalDistribution proposalDistribution;
    /** Optional adapter updated after every iteration; may be {@code null}. */
    protected ProposalAdapter proposalAdapter;
    /** Current position; updated in place on acceptance. */
    protected double[] pos;
    /** {@code logPdf} value at {@link #pos}. */
    protected double logValue;
    /** Copy of the most recent proposal. */
    protected double[] lastProposedPos;
    /** {@code logPdf} value of the most recent proposal ({@code -Infinity} if it broke a hard limit). */
    protected double lastProposedLogValue;
    /** Uniform draw used for the most recent test ({@code +Infinity} if the proposal broke a hard limit). */
    protected double lastProposalAlpha;
    /** Whether the most recent proposal was accepted. */
    protected boolean accepted;
    /** If set, re-evaluate {@code logPdf} at {@link #pos} after a rejected proposal. */
    protected boolean syncronizeFunctionState;
    /** Per-coordinate {@code [lower, upper]} bounds, or {@code null} when hard limits are off. */
    protected double[][] hardLimits;
    /** Inverse-temperature factor applied to the log-density difference. */
    protected double heat;

    /**
     * Creates a sampler without hard limits.
     *
     * @param function         target log density
     * @param proposalDistribution proposal kernel
     * @param dArr             starting position (stored by reference, not copied)
     * @param proposalAdapter  optional adapter, may be {@code null}
     */
    public MetropolisHastingsSampler(Function function, ProposalDistribution proposalDistribution, double[] dArr, ProposalAdapter proposalAdapter) {
        this(function, proposalDistribution, dArr, proposalAdapter, false);
    }

    /**
     * Creates a sampler.
     *
     * @param function         target log density
     * @param proposalDistribution proposal kernel
     * @param dArr             starting position (stored by reference; when hard limits are
     *                         enabled, coordinates on or outside a bound are moved inside in place)
     * @param proposalAdapter  optional adapter, may be {@code null}
     * @param z                if {@code true}, reject proposals outside the domain's rectangular bounds
     */
    public MetropolisHastingsSampler(Function function, ProposalDistribution proposalDistribution, double[] dArr, ProposalAdapter proposalAdapter, boolean z) {
        this.syncronizeFunctionState = false;
        this.heat = 1.0d;
        this.logPdf = function;
        if (z) {
            this.hardLimits = function.getDomain().getRectangularBounds();
            // setPos(null) is supported below, so do not NPE on a null start here.
            if (dArr != null) {
                moveInsideHardLimits(dArr);
            }
        } else {
            this.hardLimits = null;
        }
        setPos(dArr);
        setProposalDistribution(proposalDistribution);
        setProposalAdapter(proposalAdapter);
    }

    /**
     * Nudges every coordinate lying on or beyond a hard limit to just inside it
     * (by 1e-5 of the bound width), warning on stderr for each adjustment.
     */
    private void moveInsideHardLimits(double[] x) {
        for (int k = 0; k < x.length; k++) {
            double lower = this.hardLimits[k][0];
            double upper = this.hardLimits[k][1];
            if (x[k] <= lower) {
                System.err.println("Pushing param " + k + "(" + this.logPdf.getDomain().getType(k).getName() + ") above lower limit: " + lower + " <= " + x[k] + "<=" + upper + ".");
                x[k] = lower + (1.0E-5d * (upper - lower));
            } else if (x[k] >= upper) {
                System.err.println("Pulling param " + k + "(" + this.logPdf.getDomain().getType(k).getName() + ") below upper limit: " + lower + " <= " + x[k] + "<=" + upper + ".");
                x[k] = upper - (1.0E-5d * (upper - lower));
            }
        }
    }

    /**
     * Runs {@code i} iterations using a temporary adapter, then restores the previous one.
     *
     * @param i               number of iterations
     * @param proposalAdapter burn-in adapter; if {@code null}, this is just {@link #iterate(int)}
     *                        with the current adapter
     */
    public void burnin(int i, ProposalAdapter proposalAdapter) {
        if (proposalAdapter == null) {
            iterate(i);
            return;
        }
        ProposalAdapter previousAdapter = this.proposalAdapter;
        try {
            this.proposalAdapter = proposalAdapter;
            proposalAdapter.setSampler(this);
            proposalAdapter.reset();
            for (int step = 0; step < i; step++) {
                iterate();
            }
        } finally {
            // Restore the regular adapter even if an iteration throws.
            this.proposalAdapter = previousAdapter;
        }
    }

    /** Installs {@code proposalAdapter} (may be {@code null}), binding and resetting it. */
    public void setProposalAdapter(ProposalAdapter proposalAdapter) {
        this.proposalAdapter = proposalAdapter;
        if (proposalAdapter != null) {
            proposalAdapter.setSampler(this);
            proposalAdapter.reset();
        }
    }

    /** @return the active adapter, possibly {@code null} */
    public ProposalAdapter getProposalAdapter() {
        return this.proposalAdapter;
    }

    /** Installs the proposal kernel and binds it to this sampler; must not be {@code null}. */
    public void setProposalDistribution(ProposalDistribution proposalDistribution) {
        this.proposalDistribution = proposalDistribution;
        proposalDistribution.setSampler(this);
    }

    /** @return the proposal kernel */
    public ProposalDistribution getProposalDistribution() {
        return this.proposalDistribution;
    }

    /**
     * Performs one Metropolis-Hastings step.
     *
     * <p>A proposal outside the hard limits is rejected without evaluating {@code logPdf}
     * or drawing a uniform; it is recorded with log value {@code -Infinity} and
     * alpha {@code +Infinity}.
     */
    public void iterate() {
        double[] proposal = this.proposalDistribution.sample();
        double proposalLogValue;
        double alpha;
        if (isOutsideHardLimits(proposal)) {
            proposalLogValue = Double.NEGATIVE_INFINITY;
            alpha = Double.POSITIVE_INFINITY; // guarantees rejection in proposePosition
        } else {
            proposalLogValue = this.logPdf.eval(proposal);
            alpha = RandomManager.instance().nextUniform(0.0d, 1.0d);
        }
        proposePosition(proposal, proposalLogValue, alpha);
        if (this.proposalAdapter != null) {
            this.proposalAdapter.update();
        }
        // After a rejection logPdf was last evaluated at the proposal, not at pos.
        if (this.syncronizeFunctionState && !this.accepted) {
            resyncFunctionState();
        }
    }

    /** @return true if hard limits are active and any coordinate of {@code x} is outside them */
    private boolean isOutsideHardLimits(double[] x) {
        if (this.hardLimits == null) {
            return false;
        }
        for (int k = 0; k < this.pos.length; k++) {
            if (x[k] < this.hardLimits[k][0] || x[k] > this.hardLimits[k][1]) {
                return true;
            }
        }
        return false;
    }

    /** Re-evaluates {@code logPdf} at the current position, via {@code dryEval} when available. */
    private void resyncFunctionState() {
        if (this.logPdf instanceof FunctionWithDryEval) {
            ((FunctionWithDryEval) this.logPdf).dryEval(this.pos);
        } else {
            this.logPdf.eval(this.pos);
        }
    }

    /**
     * Applies the (tempered) Metropolis-Hastings acceptance test to a proposal.
     *
     * <p>Accepts iff {@code alpha < exp(heat * (proposalLogValue - logValue) + correction)},
     * where {@code correction} is zero for symmetric kernels and
     * {@code q.logPdf(pos, proposal) - q.logPdf(proposal, pos)} otherwise.
     * On acceptance, the proposal is copied into {@link #pos}.
     *
     * @param proposal         proposed position (copied, not retained)
     * @param proposalLogValue {@code logPdf} at {@code proposal}
     * @param alpha            uniform draw for the test
     * @return whether the proposal was accepted
     */
    public boolean proposePosition(double[] proposal, double proposalLogValue, double alpha) {
        System.arraycopy(proposal, 0, this.lastProposedPos, 0, this.lastProposedPos.length);
        this.lastProposedLogValue = proposalLogValue;
        this.lastProposalAlpha = alpha;
        double logAcceptance = this.heat * (proposalLogValue - this.logValue);
        if (!this.proposalDistribution.isSymmetric()) {
            // Hastings correction for asymmetric proposals.
            logAcceptance = logAcceptance
                    + this.proposalDistribution.logPdf(this.pos, proposal)
                    - this.proposalDistribution.logPdf(proposal, this.pos);
        }
        this.accepted = alpha < Math.exp(logAcceptance);
        if (this.accepted) {
            System.arraycopy(proposal, 0, this.pos, 0, this.pos.length);
            this.logValue = proposalLogValue;
        }
        return this.accepted;
    }

    /**
     * Runs {@code i} iterations. If function-state synchronization is on, it is suspended
     * during the loop and the function is resynchronized once, at the final position.
     */
    public void iterate(int i) {
        if (!getSyncronizeFunctionState()) {
            for (int step = 0; step < i; step++) {
                iterate();
            }
            return;
        }
        setSyncronizeFunctionState(false);
        try {
            for (int step = 0; step < i; step++) {
                iterate();
            }
            // Sync to where the chain ended up, not where it started.
            resyncFunctionState();
        } finally {
            setSyncronizeFunctionState(true);
        }
    }

    /** @return the current position (internal array, not a copy) */
    public double[] getPos() {
        return this.pos;
    }

    /** @return {@code logPdf} at the current position */
    public double getLogValue() {
        return this.logValue;
    }

    /**
     * Sets the current position (by reference) and evaluates {@code logPdf} there.
     * A {@code null} position is stored without evaluation.
     */
    public void setPos(double[] dArr) {
        this.pos = dArr;
        if (dArr != null) {
            this.logValue = this.logPdf.eval(dArr);
            this.lastProposedPos = new double[dArr.length];
        }
    }

    /** @return copy buffer holding the most recent proposal */
    public double[] getLastProposedPos() {
        return this.lastProposedPos;
    }

    /** @return {@code logPdf} value of the most recent proposal */
    public double getLastProposedLogValue() {
        return this.lastProposedLogValue;
    }

    /** @return uniform draw used in the most recent acceptance test */
    public double getAlpha() {
        return this.lastProposalAlpha;
    }

    /** @return the target log density */
    public Function getLogPdf() {
        return this.logPdf;
    }

    /** @return whether the most recent proposal was accepted */
    public boolean isAccepted() {
        return this.accepted;
    }

    /** Enables or disables resynchronizing {@code logPdf} state after rejections. */
    public void setSyncronizeFunctionState(boolean z) {
        this.syncronizeFunctionState = z;
    }

    /** @return whether function-state synchronization is enabled */
    public boolean getSyncronizeFunctionState() {
        return this.syncronizeFunctionState;
    }

    /** @return the inverse-temperature factor */
    public double getHeat() {
        return this.heat;
    }

    /** Sets the inverse-temperature factor applied to log-density differences. */
    public void setHeat(double d) {
        this.heat = d;
    }
}
