/*
 * Decompiled with CFR 0.152.
 */
package dr.oldevomodel.ibd;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.oldevomodel.substmodel.AbstractSubstitutionModel;
import dr.oldevomodel.substmodel.HKY;
import dr.oldevomodel.treelikelihood.NodePosteriorTreeLikelihood;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Tree-trait reporter for the expected number of tips identical by descent (IBD) with each
 * tip, averaged over site patterns and conditioned on the observed data through the node
 * posteriors of a {@link NodePosteriorTreeLikelihood}.
 *
 * <p>The computation is a two-pass message passing over the tree:
 * <ul>
 *   <li>{@link #forwardIBD()} fills {@link #ibdForward} from the tips towards the root;</li>
 *   <li>{@link #backwardIBD(NodeRef)} fills {@link #ibdBackward} from the root towards the tips;</li>
 *   <li>{@link #expectedIBD()} combines the backward messages at each tip into {@link #ibdweights}.</li>
 * </ul>
 * Every branch contributes the per-state factor {@code exp(-q_s * t) / P(t)[s][s]}, where
 * {@code q_s} is the exit rate of state {@code s} (see {@link #getDiagonalRates(double[])}),
 * {@code t} is the branch rate times the branch duration and {@code P(t)} is the matrix returned
 * by {@code getNodeMatrix} for that branch. NOTE(review): this reads as "probability of no
 * substitution on the branch given the start and end state are both s" -- confirm against the
 * model derivation.
 *
 * <p>Message arrays are laid out as {@code [nodeNumber][pattern * stateCount + state]}. Results
 * are cached and recomputed lazily after a registered model or variable changes. The
 * substitution model must be {@link HKY}; the rate code assumes exactly four states.
 */
public class AvgPosteriorIBDReporter
extends AbstractModel
implements TreeTraitProvider {
    /** Expected IBD weight per tip, indexed by external node number. */
    protected double[] ibdweights;
    /** Forward (tip-to-root) messages, {@code [node][pattern * stateCount + state]}. */
    protected double[][] ibdForward;
    /** Backward (root-to-tip) messages, same layout as {@link #ibdForward}. */
    protected double[][] ibdBackward;
    /** Scratch buffer for the per-state exit rates computed by {@link #getDiagonalRates(double[])}. */
    protected double[] diag;
    /** Whether {@link #ibdweights} reflects the current model state. */
    protected boolean weightsKnown;
    protected HKY substitutionModel;
    protected TreeModel treeModel;
    protected BranchRateModel branchRateModel;
    /** Rate multiplier; only element 0 is read. */
    protected Parameter mutationParameter;
    /** Source of per-node state posteriors, branch matrices and pattern weights. */
    protected NodePosteriorTreeLikelihood likelihoodReporter;
    /** Scratch buffer (stateCount * stateCount) for a branch matrix; only its diagonal is read. */
    protected double[] probabilities;

    /** Node trait exposing {@code ibdweights[tip] + 1.0} on tips and {@code null} on internal nodes. */
    TreeTrait avgPosteriorIBDWeight = new TreeTrait.D(){

        @Override
        public String getTraitName() {
            return "AvgPosteriorIBDWeight";
        }

        @Override
        public TreeTrait.Intent getIntent() {
            return TreeTrait.Intent.NODE;
        }

        @Override
        public Double getTrait(Tree tree, NodeRef nodeRef) {
            // Recompute lazily: the cache is invalidated by model/variable change events.
            if (!AvgPosteriorIBDReporter.this.weightsKnown) {
                AvgPosteriorIBDReporter.this.expectedIBD();
                AvgPosteriorIBDReporter.this.weightsKnown = true;
            }
            if (tree.isExternal(nodeRef)) {
                int n = nodeRef.getNumber();
                // NOTE(review): the +1.0 presumably counts the tip itself.
                return AvgPosteriorIBDReporter.this.ibdweights[n] + 1.0;
            }
            return null;
        }
    };
    public static final String IBD_REPORTER_LIKELIHOOD = "avgPosteriorIBDReporter";
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{new ElementRule(TreeModel.class), new ElementRule(BranchRateModel.class, true), new ElementRule(AbstractSubstitutionModel.class), new ElementRule(Parameter.class), new ElementRule(NodePosteriorTreeLikelihood.class)};

        @Override
        public String getParserName() {
            return AvgPosteriorIBDReporter.IBD_REPORTER_LIKELIHOOD;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            TreeModel treeModel = (TreeModel)xMLObject.getChild(TreeModel.class);
            Parameter parameter = (Parameter)xMLObject.getChild(Parameter.class);
            AbstractSubstitutionModel abstractSubstitutionModel = (AbstractSubstitutionModel)xMLObject.getChild(AbstractSubstitutionModel.class);
            // The exit-rate formula is HKY-specific; report a parse error instead of letting the
            // constructor fail with a ClassCastException.
            if (!(abstractSubstitutionModel instanceof HKY)) {
                throw new XMLParseException("The <" + AvgPosteriorIBDReporter.IBD_REPORTER_LIKELIHOOD
                        + "> element requires an HKY substitution model, found "
                        + (abstractSubstitutionModel == null ? "none" : abstractSubstitutionModel.getClass().getName()));
            }
            BranchRateModel branchRateModel = (BranchRateModel)xMLObject.getChild(BranchRateModel.class);
            if (branchRateModel == null) {
                branchRateModel = new DefaultBranchRateModel();
            }
            NodePosteriorTreeLikelihood nodePosteriorTreeLikelihood = (NodePosteriorTreeLikelihood)xMLObject.getChild(NodePosteriorTreeLikelihood.class);
            return new AvgPosteriorIBDReporter(nodePosteriorTreeLikelihood, parameter, treeModel, branchRateModel, abstractSubstitutionModel);
        }

        @Override
        public String getParserDescription() {
            return "This element represents a reporter for average expected number of tips ibd conditional on observed patterns.";
        }

        @Override
        public Class getReturnType() {
            // The parsed object is this reporter (a model and tree-trait provider), not a Likelihood.
            return AvgPosteriorIBDReporter.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the reporter and registers the models and variable whose changes invalidate the
     * cached weights.
     *
     * @param nodePosteriorTreeLikelihood source of node posteriors, branch matrices and pattern weights
     * @param parameter rate multiplier (element 0 is used)
     * @param treeModel the tree
     * @param branchRateModel per-branch rates
     * @param abstractSubstitutionModel must be an {@link HKY} instance (cast here)
     */
    AvgPosteriorIBDReporter(NodePosteriorTreeLikelihood nodePosteriorTreeLikelihood, Parameter parameter, TreeModel treeModel, BranchRateModel branchRateModel, AbstractSubstitutionModel abstractSubstitutionModel) {
        super("AvgPosteriorIBDReporter");
        this.substitutionModel = (HKY)abstractSubstitutionModel;
        this.addModel(this.substitutionModel);
        this.treeModel = treeModel;
        this.addModel(this.treeModel);
        this.branchRateModel = branchRateModel;
        this.addModel(this.branchRateModel);
        this.mutationParameter = parameter;
        this.addVariable(this.mutationParameter);
        // NOTE(review): the likelihood reporter is not registered via addModel, so the
        // likelihoodReporter branch of handleModelChangedEvent is only reached if it is wired
        // elsewhere -- confirm whether it should be registered here.
        this.likelihoodReporter = nodePosteriorTreeLikelihood;
        this.probabilities = new double[abstractSubstitutionModel.getStateCount() * abstractSubstitutionModel.getStateCount()];
    }

    /** Allocates the message/weight buffers on first use (sized from the current tree and data). */
    private void ensureBuffers() {
        if (this.ibdweights == null) {
            int stateCount = this.substitutionModel.getStateCount();
            int nodeCount = this.treeModel.getNodeCount();
            int patternCount = this.likelihoodReporter.getPatternCount();
            this.ibdweights = new double[this.treeModel.getExternalNodeCount()];
            this.ibdForward = new double[nodeCount][stateCount * patternCount];
            this.ibdBackward = new double[nodeCount][stateCount * patternCount];
            this.diag = new double[stateCount];
        }
    }

    /** Branch rate times branch duration for the branch from {@code parent} down to {@code node}. */
    private double scaledBranchLength(NodeRef node, NodeRef parent) {
        return this.branchRateModel.getBranchRate(this.treeModel, node) * (this.treeModel.getNodeHeight(parent) - this.treeModel.getNodeHeight(node));
    }

    /**
     * Fills the forward messages {@link #ibdForward} for every non-root node.
     *
     * <p>For a node with parent branch length {@code t}:
     * {@code F[node][p,s] = base * post_node[p,s] * exp(-q_s t) / P(t)[s][s]}, where
     * {@code base} is 1 for a tip and the sum of the children's {@code F[child][p,s]} otherwise.
     * Nodes are visited in post-order from the root so every child's message is complete before
     * its parent reads it; {@link TreeModel} node numbers need not follow post-order after
     * topology changes.
     */
    public void forwardIBD() {
        this.ensureBuffers();
        this.getDiagonalRates(this.diag);
        int stateCount = this.substitutionModel.getStateCount();
        int patternCount = this.likelihoodReporter.getPatternCount();
        this.forwardFrom(this.treeModel.getRoot(), stateCount, patternCount);
    }

    /** Post-order worker for {@link #forwardIBD()}: fills the forward messages of {@code node}'s subtree. */
    private void forwardFrom(NodeRef node, int stateCount, int patternCount) {
        int childCount = this.treeModel.getChildCount(node);
        for (int c = 0; c < childCount; ++c) {
            this.forwardFrom(this.treeModel.getChild(node, c), stateCount, patternCount);
        }
        NodeRef parent = this.treeModel.getParent(node);
        if (parent == null) {
            // The root has no parent branch; its forward message is never read.
            return;
        }
        int nodeNumber = node.getNumber();
        boolean external = this.treeModel.isExternal(node);
        this.likelihoodReporter.getNodeMatrix(nodeNumber, this.probabilities);
        double[] posteriors = this.likelihoodReporter.getPosteriors(nodeNumber);
        double[] forward = this.ibdForward[nodeNumber];
        double branchLength = this.scaledBranchLength(node, parent);
        for (int s = 0; s < stateCount; ++s) {
            double noChange = Math.exp(-this.diag[s] * branchLength) / this.probabilities[s + s * stateCount];
            for (int p = 0; p < patternCount; ++p) {
                int idx = p * stateCount + s;
                double base;
                if (external) {
                    base = 1.0;
                } else {
                    base = 0.0;
                    for (int c = 0; c < childCount; ++c) {
                        base += this.ibdForward[this.treeModel.getChild(node, c).getNumber()][idx];
                    }
                }
                forward[idx] = base * (posteriors[idx] * noChange);
            }
        }
    }

    /**
     * Fills the backward messages {@link #ibdBackward} for the subtree below {@code nodeRef}.
     *
     * <p>For each child {@code c} of a node:
     * {@code B[c][p,s] = (B[node][p,s] + sum over siblings o != c of F[o][p,s])
     * * post_node[p,s] * exp(-q_s t_c) / P_c(t_c)[s][s]}. Requires {@link #forwardIBD()} to have
     * been run for the current state.
     *
     * @param nodeRef the node whose children are updated (recursively); {@code null} means start
     *     at the root, whose backward message is reset to zero first
     */
    public void backwardIBD(NodeRef nodeRef) {
        this.ensureBuffers();
        int stateCount = this.substitutionModel.getStateCount();
        int patternCount = this.likelihoodReporter.getPatternCount();
        if (nodeRef == null) {
            nodeRef = this.treeModel.getRoot();
            double[] rootBackward = this.ibdBackward[nodeRef.getNumber()];
            for (int k = 0; k < patternCount * stateCount; ++k) {
                rootBackward[k] = 0.0;
            }
        }
        // Rates depend only on the model state, so compute them once for the whole pass.
        this.getDiagonalRates(this.diag);
        this.backwardFrom(nodeRef, stateCount, patternCount);
    }

    /** Pre-order worker for {@link #backwardIBD(NodeRef)}: updates {@code node}'s children, then recurses. */
    private void backwardFrom(NodeRef node, int stateCount, int patternCount) {
        int childCount = this.treeModel.getChildCount(node);
        int nodeNumber = node.getNumber();
        double[] posteriors = this.likelihoodReporter.getPosteriors(nodeNumber);
        double[] parentBackward = this.ibdBackward[nodeNumber];
        for (int c = 0; c < childCount; ++c) {
            NodeRef child = this.treeModel.getChild(node, c);
            int childNumber = child.getNumber();
            this.likelihoodReporter.getNodeMatrix(childNumber, this.probabilities);
            double branchLength = this.scaledBranchLength(child, node);
            double[] childBackward = this.ibdBackward[childNumber];
            for (int p = 0; p < patternCount; ++p) {
                for (int s = 0; s < stateCount; ++s) {
                    int idx = p * stateCount + s;
                    double sum = parentBackward[idx];
                    for (int o = 0; o < childCount; ++o) {
                        if (o == c) {
                            continue;
                        }
                        sum += this.ibdForward[this.treeModel.getChild(node, o).getNumber()][idx];
                    }
                    childBackward[idx] = sum * (posteriors[idx] * Math.exp(-this.diag[s] * branchLength) / this.probabilities[s + s * stateCount]);
                }
            }
        }
        for (int c = 0; c < childCount; ++c) {
            this.backwardFrom(this.treeModel.getChild(node, c), stateCount, patternCount);
        }
    }

    /**
     * Recomputes {@link #ibdweights}: runs both passes, then for each tip sums
     * {@code B[tip][p,s] * post_tip[p,s] * w_p / sum(w)} over patterns and states, where
     * {@code w_p} are the pattern weights. Assumes tips are numbered {@code 0..externalCount-1}.
     */
    public void expectedIBD() {
        this.ensureBuffers();
        this.forwardIBD();
        this.backwardIBD(null);
        int stateCount = this.substitutionModel.getStateCount();
        int patternCount = this.likelihoodReporter.getPatternCount();
        double[] patternWeights = this.likelihoodReporter.getPatternWeights();
        double totalWeight = 0.0;
        for (int p = 0; p < patternCount; ++p) {
            totalWeight += patternWeights[p];
        }
        int tipCount = this.treeModel.getExternalNodeCount();
        for (int tip = 0; tip < tipCount; ++tip) {
            double[] posteriors = this.likelihoodReporter.getPosteriors(tip);
            double[] backward = this.ibdBackward[tip];
            double weight = 0.0;
            for (int p = 0; p < patternCount; ++p) {
                for (int s = 0; s < stateCount; ++s) {
                    int idx = p * stateCount + s;
                    weight += backward[idx] * posteriors[idx] * patternWeights[p] / totalWeight;
                }
            }
            this.ibdweights[tip] = weight;
        }
    }

    /**
     * Writes the exit rate (negated diagonal of the rate matrix) of each of the four states into
     * {@code rates}, using the HKY off-diagonal rates (kappa on transitions 0<->2 and 1<->3),
     * normalized by {@code 0.5 / [(pi0+pi2)(pi1+pi3) + kappa(pi0 pi2 + pi1 pi3)]} and scaled by
     * element 0 of the mutation parameter. NOTE(review): states presumably ordered A, C, G, T.
     *
     * @param rates output array of length at least 4
     */
    protected void getDiagonalRates(double[] rates) {
        double kappa = this.substitutionModel.getKappa();
        double[] freqs = this.substitutionModel.getFrequencyModel().getFrequencies();
        double mu = this.mutationParameter.getParameterValue(0);
        double beta = 0.5 / ((freqs[0] + freqs[2]) * (freqs[1] + freqs[3]) + kappa * (freqs[0] * freqs[2] + freqs[1] * freqs[3]));
        rates[0] = (freqs[1] + freqs[3] + freqs[2] * kappa) * mu * beta;
        rates[1] = (freqs[0] + freqs[2] + freqs[3] * kappa) * mu * beta;
        rates[2] = (freqs[1] + freqs[3] + freqs[0] * kappa) * mu * beta;
        rates[3] = (freqs[0] + freqs[2] + freqs[1] * kappa) * mu * beta;
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return new TreeTrait[]{this.avgPosteriorIBDWeight};
    }

    /**
     * Looks up a trait by name.
     *
     * @return the IBD-weight trait if {@code string} equals its name, otherwise {@code null}
     */
    @Override
    public TreeTrait getTreeTrait(String string) {
        return this.avgPosteriorIBDWeight.getTraitName().equals(string) ? this.avgPosteriorIBDWeight : null;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.branchRateModel || model == this.treeModel || model == this.substitutionModel || model == this.likelihoodReporter) {
            this.weightsKnown = false;
        } else {
            System.err.println("Weird call back to IBDReporter from " + model.getModelName());
        }
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable == this.mutationParameter) {
            this.weightsKnown = false;
        } else {
            System.err.println("Weird call back to IBDReporter from " + variable.getVariableName());
        }
    }

    @Override
    protected void storeState() {
        // Nothing to store: weights are a lazily recomputed cache.
    }

    @Override
    protected void restoreState() {
        // Any cached weights may belong to the rejected state; force recomputation.
        this.weightsKnown = false;
    }

    @Override
    protected void acceptState() {
        // Nothing to do.
    }
}

