/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc.deprecated;

import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.deprecated.AbstractHamiltonianMCOperator;
import dr.math.MathUtils;

/**
 * Hamiltonian Monte Carlo operator that jointly proposes new values for every latent factor of a
 * {@link LatentFactorModel}, using leapfrog integration of a trajectory.
 *
 * <p>The factor matrix (factor x taxon) is the position variable. Momentum is stored factor-major
 * in {@code momentum[factor * ntaxa + taxon]}. The gradient combines two terms:
 * <ul>
 *   <li>a residual term built from the model residuals, the loadings and the diagonal of the
 *       column precision;</li>
 *   <li>a term built from the tree's conditional means and per-taxon precision factors.</li>
 * </ul>
 * NOTE(review): this is presumably the gradient of the potential energy (negative log density)
 * with respect to the factors, assuming {@code getResidual()} returns data minus fit. Confirm.
 *
 * @deprecated This class lives in the {@code hmc.deprecated} package; prefer a non-deprecated
 *     Hamiltonian operator for new analyses.
 */
@Deprecated
public class LatentFactorHamiltonianMC
extends AbstractHamiltonianMCOperator {
    private final LatentFactorModel lfm;
    private final FullyConjugateMultivariateTraitLikelihood tree;
    private final MatrixParameterInterface factors;
    private final MatrixParameterInterface loadings;
    private final MatrixParameterInterface precision;
    private final int nfac;
    private final int ntaxa;
    private final int ntraits;
    /** Leapfrog step size; adapted on the log scale by the superclass. */
    private double stepSize;
    private final int nSteps;
    /** May be null; entries equal to 1.0 are excluded from the residual gradient term. */
    private final Parameter missingIndicator;

    /**
     * Creates the operator.
     *
     * @param latentFactorModel model supplying factors, loadings, column precision, residuals and
     *     the (possibly null) missing indicator
     * @param fullyConjugateMultivariateTraitLikelihood tree likelihood supplying conditional means
     *     and precision factors used in the gradient
     * @param weight operator weight, passed to {@code setWeight}
     * @param adaptationMode adaptation mode, passed to the superclass
     * @param stepSize initial leapfrog step size
     * @param nSteps number of leapfrog steps per proposal (expected to be at least 1)
     * @param momentumSd value passed to the superclass constructor; presumably the momentum
     *     standard deviation later returned by {@code getMomentumSd()} (confirm)
     */
    public LatentFactorHamiltonianMC(LatentFactorModel latentFactorModel, FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, double weight, AdaptationMode adaptationMode, double stepSize, int nSteps, double momentumSd) {
        super(adaptationMode, momentumSd);
        this.setWeight(weight);
        this.lfm = latentFactorModel;
        this.tree = fullyConjugateMultivariateTraitLikelihood;
        this.factors = latentFactorModel.getFactors();
        this.loadings = latentFactorModel.getLoadings();
        this.precision = latentFactorModel.getColumnPrecision();
        this.nfac = latentFactorModel.getFactorDimension();
        this.ntaxa = latentFactorModel.getFactors().getColumnDimension();
        this.ntraits = this.precision.getRowDimension();
        this.stepSize = stepSize;
        this.nSteps = nSteps;
        this.missingIndicator = latentFactorModel.getMissingIndicator();
    }

    @Override
    protected double getAdaptableParameterValue() {
        return Math.log(this.stepSize);
    }

    @Override
    protected void setAdaptableParameterValue(double d) {
        this.stepSize = Math.exp(d);
    }

    @Override
    public double getRawParameter() {
        return this.stepSize;
    }

    @Override
    public String getOperatorName() {
        return "Latent Factor Hamiltonian Monte Carlo";
    }

    /**
     * Computes the residual contribution to the gradient.
     *
     * <p>{@code out[f][t] = -sum_j loadings(j, f) * precision(j, j) * residual[t * ntraits + j]},
     * skipping (trait, taxon) entries whose missing indicator equals 1.0.
     *
     * @param residual residual vector laid out taxon-major ({@code taxon * ntraits + trait})
     * @return a new {@code nfac x ntaxa} array
     */
    private double[][] residualGradient(double[] residual) {
        double[][] out = new double[this.nfac][this.ntaxa];
        double[] coefficient = new double[this.nfac];
        for (int j = 0; j < this.ntraits; ++j) {
            double traitPrecision = this.precision.getParameterValue(j, j);
            for (int f = 0; f < this.nfac; ++f) {
                coefficient[f] = this.loadings.getParameterValue(j, f) * traitPrecision;
            }
            for (int t = 0; t < this.ntaxa; ++t) {
                int index = t * this.ntraits + j;
                if (this.missingIndicator != null && this.missingIndicator.getParameterValue(index) == 1.0) {
                    continue;
                }
                double r = residual[index];
                for (int f = 0; f < this.nfac; ++f) {
                    // Accumulates over traits j in ascending order, as the original loop did.
                    out[f][t] -= coefficient[f] * r;
                }
            }
        }
        return out;
    }

    /**
     * Computes the full gradient with respect to the current factor values.
     *
     * @param conditionalMeans tree conditional means, indexed {@code [taxon][factor]}
     * @param precisionFactors tree precision factors, indexed by taxon
     * @return a new {@code nfac x ntaxa} gradient array
     */
    private double[][] factorGradient(double[][] conditionalMeans, double[] precisionFactors) {
        double[][] grad = this.residualGradient(this.lfm.getResidual());
        for (int f = 0; f < this.nfac; ++f) {
            for (int t = 0; t < this.ntaxa; ++t) {
                grad[f][t] += (this.factors.getParameterValue(f, t) - conditionalMeans[t][f]) * precisionFactors[t];
            }
        }
        return grad;
    }

    /** Momentum update: {@code momentum -= scale * grad}, element-wise in factor-major layout. */
    private void kickMomentum(double[][] grad, double scale) {
        for (int f = 0; f < this.nfac; ++f) {
            for (int t = 0; t < this.ntaxa; ++t) {
                this.momentum[f * this.ntaxa + t] -= scale * grad[f][t];
            }
        }
    }

    /**
     * Position update: {@code factor += step * momentum / momentumVariance}.
     * A single change event is fired once all factors have been written.
     */
    private void driftFactors(double step, double momentumVariance) {
        for (int f = 0; f < this.nfac; ++f) {
            for (int t = 0; t < this.ntaxa; ++t) {
                this.factors.setParameterValueQuietly(f, t, this.factors.getParameterValue(f, t) + step * this.momentum[f * this.ntaxa + t] / momentumVariance);
            }
        }
        this.factors.fireParameterChangedEvent();
    }

    /** Returns {@code sum_i momentum[i]^2 / (2 * momentumSd^2)}. */
    private double kineticEnergyOf(double momentumSd) {
        double denominator = 2.0 * momentumSd * momentumSd;
        double sum = 0.0;
        for (double p : this.momentum) {
            sum += p * p / denominator;
        }
        return sum;
    }

    /**
     * Draws a fresh momentum and runs {@code nSteps} leapfrog steps over the factors.
     *
     * @return initial minus final kinetic energy. NOTE(review): presumably this is used as the log
     *     Hastings ratio, with the potential-energy change supplied by the model likelihoods. Confirm.
     */
    @Override
    public double doOperation() {
        double[][] conditionalMeans = this.tree.getConditionalMeans();
        // The drawn value is unused. The call is kept so that the operator consumes the same
        // random numbers as before, preserving seeded reproducibility of the rest of the chain.
        MathUtils.nextDouble();
        double step = this.stepSize;
        double[] precisionFactors = this.tree.getPrecisionFactors();
        double[][] grad = this.factorGradient(conditionalMeans, precisionFactors);

        this.drawMomentum(this.nfac * this.ntaxa);
        // Assumed constant over a single trajectory; read once instead of inside every loop.
        double momentumSd = this.getMomentumSd();
        double momentumVariance = momentumSd * momentumSd;
        double initialKinetic = this.kineticEnergyOf(momentumSd);

        // Leapfrog: half kick, then nSteps drifts interleaved with nSteps - 1 full kicks,
        // then a closing half kick.
        this.kickMomentum(grad, step / 2.0);
        for (int s = 0; s < this.nSteps; ++s) {
            this.driftFactors(step, momentumVariance);
            // The original guard (s == nSteps) was never true, so it always applied a full kick
            // here and then the closing half kick as well (1.5 steps). Skip it on the last step.
            if (s < this.nSteps - 1) {
                grad = this.factorGradient(conditionalMeans, precisionFactors);
                this.kickMomentum(grad, step);
            }
        }
        grad = this.factorGradient(conditionalMeans, precisionFactors);
        this.kickMomentum(grad, step / 2.0);

        return initialKinetic - this.kineticEnergyOf(momentumSd);
    }
}

