/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.speciation;

import dr.evolution.tree.TreeTrait;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.speciation.BirthDeathEpisodicSeriallySampledModel;
import dr.evomodel.speciation.CachedGradientDelegate;
import dr.evomodel.speciation.EfficientSpeciationLikelihood;
import dr.evomodel.speciation.NewBirthDeathSerialSamplingModel;
import dr.evomodel.speciation.SpeciationLikelihoodGradient;
import dr.evomodel.speciation.SpeciationModel;
import dr.evomodel.speciation.SpeciationModelGradientProvider;
import dr.evomodel.tree.TreeModel;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.AbstractModel;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.Reportable;

/**
 * Gradient of an {@link EfficientSpeciationLikelihood} with respect to the parameter selected by a
 * {@link SpeciationLikelihoodGradient.WrtParameter}.
 *
 * <p>For {@code NODE_HEIGHT} the gradient is computed directly through the likelihood's
 * {@link SpeciationModelGradientProvider}. It is cached until the speciation model or the tree
 * intervals fire a model-changed event. The cache, together with its validity flag, takes part
 * in store/restore.
 *
 * <p>For every other parameter, the full gradient is read from a shared {@link TreeTrait}
 * registered on the likelihood under {@value #GRADIENT_KEY}. That gradient is then narrowed with
 * {@code wrtParameter.filter(...)}, using the speciation model's {@code numIntervals}.
 */
public class EfficientSpeciationLikelihoodGradient
extends AbstractModel
implements GradientWrtParameterProvider,
Reportable,
Loggable {
    /** Key under which the shared gradient trait is registered on the likelihood. */
    static final String GRADIENT_KEY = "speciationGradient";
    private final EfficientSpeciationLikelihood likelihood;
    private final Parameter parameter;
    private final SpeciationLikelihoodGradient.WrtParameter wrtParameter;
    private final TreeModel tree;
    private final SpeciationModel speciationModel;
    private final BigFastTreeIntervals treeIntervals;
    private final SpeciationModelGradientProvider provider;
    private final TreeTrait gradientProvider;
    // Cached NODE_HEIGHT gradient; valid only while gradientKnown is true.
    private boolean gradientKnown;
    private double[] gradient;
    // Snapshot taken in storeState(). The flag must travel with the array; otherwise a restore
    // could pair a stale (or null) array with gradientKnown == true.
    private boolean storedGradientKnown;
    private double[] storedGradient;

    /**
     * Creates a gradient provider for {@code efficientSpeciationLikelihood}.
     *
     * @param efficientSpeciationLikelihood the likelihood whose log-density is differentiated
     * @param wrtParameter selects the parameter the gradient is taken with respect to
     */
    public EfficientSpeciationLikelihoodGradient(EfficientSpeciationLikelihood efficientSpeciationLikelihood, SpeciationLikelihoodGradient.WrtParameter wrtParameter) {
        super("efficientSpeciationLikelihoodGradient");
        this.likelihood = efficientSpeciationLikelihood;
        this.wrtParameter = wrtParameter;
        this.tree = efficientSpeciationLikelihood.getTreeModel();
        this.speciationModel = efficientSpeciationLikelihood.getSpeciationModel();
        this.treeIntervals = efficientSpeciationLikelihood.getTreeIntervals();
        this.provider = efficientSpeciationLikelihood.getGradientProvider();
        this.parameter = wrtParameter.getParameter(this.provider, this.tree);
        efficientSpeciationLikelihood.addModel(this);
        if (wrtParameter == SpeciationLikelihoodGradient.WrtParameter.NODE_HEIGHT) {
            // Only the NODE_HEIGHT path keeps its own cache, so only it needs change notifications.
            this.speciationModel.addModelListener(this);
            this.treeIntervals.addModelListener(this);
        }
        this.gradientKnown = false;
        this.storedGradientKnown = false;
        this.gradientProvider = this.getGradientDelegateSingleton(efficientSpeciationLikelihood);
    }

    /**
     * Returns the likelihood's gradient trait under {@link #GRADIENT_KEY}. If none exists yet, a
     * new {@link CachedGradientDelegate} is created and registered, so that several gradient
     * objects over the same likelihood share one delegate.
     */
    private TreeTrait getGradientDelegateSingleton(EfficientSpeciationLikelihood efficientSpeciationLikelihood) {
        TreeTrait treeTrait = efficientSpeciationLikelihood.getTreeTrait(GRADIENT_KEY);
        if (treeTrait == null) {
            CachedGradientDelegate cachedGradientDelegate = new CachedGradientDelegate(efficientSpeciationLikelihood);
            this.addModel(cachedGradientDelegate);
            treeTrait = cachedGradientDelegate;
            efficientSpeciationLikelihood.addTrait(treeTrait);
        }
        return treeTrait;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.parameter.getDimension();
    }

    /**
     * Returns the gradient of the log-density with respect to {@link #getParameter()}.
     *
     * <p>For {@code NODE_HEIGHT} the returned array is the internal cache; callers should not
     * modify it. For other parameters, the speciation model is expected to be either a
     * {@link BirthDeathEpisodicSeriallySampledModel} or a {@link NewBirthDeathSerialSamplingModel}.
     * Any other type triggers a {@link ClassCastException}.
     *
     * @return the gradient vector
     */
    @Override
    public double[] getGradientLogDensity() {
        if (this.wrtParameter == SpeciationLikelihoodGradient.WrtParameter.NODE_HEIGHT) {
            if (!this.gradientKnown) {
                this.gradient = this.wrtParameter.getGradientLogDensity(this.provider, this.tree);
                this.gradientKnown = true;
            }
            return this.gradient;
        }
        double[] fullGradient = (double[]) this.gradientProvider.getTrait(null, null);
        if (this.speciationModel instanceof BirthDeathEpisodicSeriallySampledModel) {
            return this.wrtParameter.filter(fullGradient, ((BirthDeathEpisodicSeriallySampledModel) this.speciationModel).numIntervals);
        }
        return this.wrtParameter.filter(fullGradient, ((NewBirthDeathSerialSamplingModel) this.speciationModel).numIntervals);
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "SpeciationLikelihoodGradient check");
    }

    /**
     * Builds a report comparing this gradient against a check computed by
     * {@link GradientWrtParameterProvider#getReportAndCheckForError}.
     */
    @Override
    public String getReport() {
        // NOTE(review): the arguments 0.0, +Infinity and 0.001 presumably are the lower/upper
        // bounds and the tolerance of the check — confirm against GradientWrtParameterProvider.
        return GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 0.001);
    }

    /**
     * Invalidates the cached NODE_HEIGHT gradient when the speciation model or the tree
     * intervals change.
     *
     * @throws IllegalArgumentException if the event comes from any other model
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        // NOTE(review): the CachedGradientDelegate registered via addModel() is not accepted here.
        // If it ever fires a model-changed event, this throws — confirm it never does.
        if (model != this.speciationModel && model != this.treeIntervals) {
            throw new IllegalArgumentException("Unknown model: " + model.getId());
        }
        this.gradientKnown = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        throw new IllegalArgumentException("Unknown variable: " + variable.getId());
    }

    /** Snapshots the cached NODE_HEIGHT gradient together with its validity flag. */
    @Override
    protected void storeState() {
        if (this.wrtParameter != SpeciationLikelihoodGradient.WrtParameter.NODE_HEIGHT) {
            return;
        }
        this.storedGradientKnown = this.gradientKnown;
        if (this.gradient != null) {
            if (this.storedGradient == null || this.storedGradient.length != this.gradient.length) {
                this.storedGradient = new double[this.gradient.length];
            }
            System.arraycopy(this.gradient, 0, this.storedGradient, 0, this.gradient.length);
        }
    }

    /**
     * Restores the NODE_HEIGHT gradient snapshot by swapping buffers, and restores the validity
     * flag so that a stale or null buffer is never reported as known.
     */
    @Override
    protected void restoreState() {
        if (this.wrtParameter != SpeciationLikelihoodGradient.WrtParameter.NODE_HEIGHT) {
            return;
        }
        double[] swap = this.gradient;
        this.gradient = this.storedGradient;
        this.storedGradient = swap;
        this.gradientKnown = this.storedGradientKnown;
    }

    @Override
    protected void acceptState() {
    }
}

