/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.RemoveUseless;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Multinomial logistic regression with a ridge (L2) penalty on the
 * non-intercept coefficients, fitted by minimising the penalised negative
 * log-likelihood with {@link weka.core.Optimization} (see {@link #globalInfo()}).
 * <p>
 * The last class is the reference class: the model holds one coefficient
 * column per remaining class.
 * <p>
 * Options: {@code -D} turns on debug output, {@code -R <ridge>} sets the
 * ridge value (default 1.0E-8), and {@code -M <number>} sets the maximum
 * number of iterations (default -1; any negative value means "iterate until
 * convergence").
 */
public class Logistic
extends Classifier
implements OptionHandler,
WeightedInstancesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 3932117032546553727L;

    /**
     * Fitted coefficients on the original (unstandardised) attribute scale.
     * Row 0 holds the intercepts and row j the j-th predictor; there is one
     * column per non-reference class. Null until a model is built.
     */
    protected double[][] m_Par;

    /**
     * Standardised training matrix whose column 0 is the constant 1.
     * It is populated only while {@link #buildClassifier(Instances)} runs
     * and is reset to null afterwards.
     */
    protected double[][] m_Data;

    /** Number of predictor attributes after filtering (attributes minus the class). */
    protected int m_NumPredictors;

    /** Index of the class attribute in the filtered data. */
    protected int m_ClassIndex;

    /** Number of class values. */
    protected int m_NumClasses;

    /** Ridge value multiplying the squared non-intercept coefficients. */
    protected double m_Ridge = 1.0E-8;

    /** Filter removing useless attributes; fitted on the training data. */
    private RemoveUseless m_AttFilter;

    /** Filter converting nominal attributes to binary ones; fitted on the training data. */
    private NominalToBinary m_NominalToBinary;

    /** Filter replacing missing values; fitted on the training data. */
    private ReplaceMissingValues m_ReplaceMissingValues;

    /** Whether to print progress and diagnostic output to standard out. */
    protected boolean m_Debug;

    /** Negated minimum of the optimised objective (the penalised log-likelihood). */
    protected double m_LL;

    /** Iteration limit for the optimiser; any negative value means "until convergence". */
    private int m_MaxIts = -1;

    /**
     * Returns a description of the classifier for display in the GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Class for building and using a multinomial logistic regression model with a ridge estimator.\n\nThere are some modifications, however, compared to the paper of leCessie and van Houwelingen(1992): \n\nIf there are k classes for n instances with m attributes, the parameter matrix B to be calculated will be an m*(k-1) matrix.\n\nThe probability for class j with the exception of the last class is\n\nPj(Xi) = exp(XiBj)/((sum[j=1..(k-1)]exp(Xi*Bj))+1) \n\nThe last class has probability\n\n1-(sum[j=1..(k-1)]Pj(Xi)) \n\t= 1/((sum[j=1..(k-1)]exp(Xi*Bj))+1)\n\nThe (negative) multinomial log-likelihood is thus: \n\nL = -sum[i=1..n]{\n\tsum[j=1..(k-1)](Yij * ln(Pj(Xi)))\n\t+(1 - (sum[j=1..(k-1)]Yij)) \n\t* ln(1 - sum[j=1..(k-1)]Pj(Xi))\n\t} + ridge * (B^2)\n\nIn order to find the matrix B for which L is minimised, a Quasi-Newton Method is used to search for the optimized values of the m*(k-1) variables.  Note that before we use the optimization procedure, we 'squeeze' the matrix B into a m*(k-1) vector.  For details of the optimization procedure, please check weka.core.Optimization class.\n\nAlthough original Logistic Regression does not deal with instance weights, we modify the algorithm a little bit to handle the instance weights.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString() + "\n\n" + "Note: Missing values are replaced using a ReplaceMissingValuesFilter, and " + "nominal attributes are transformed into numeric attributes using a " + "NominalToBinaryFilter.";
    }

    /**
     * Returns the bibliographic reference for the method
     * (le Cessie and van Houwelingen, 1992).
     *
     * @return the technical information about this class
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        info.setValue(TechnicalInformation.Field.AUTHOR, "le Cessie, S. and van Houwelingen, J.C.");
        info.setValue(TechnicalInformation.Field.YEAR, "1992");
        info.setValue(TechnicalInformation.Field.TITLE, "Ridge Estimators in Logistic Regression");
        info.setValue(TechnicalInformation.Field.JOURNAL, "Applied Statistics");
        info.setValue(TechnicalInformation.Field.VOLUME, "41");
        info.setValue(TechnicalInformation.Field.NUMBER, "1");
        info.setValue(TechnicalInformation.Field.PAGES, "191-201");
        return info;
    }

    /**
     * Lists the supported command-line options: -D, -R and -M.
     *
     * @return an enumeration of {@link Option} descriptors
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(3);
        options.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        options.addElement(new Option("\tSet the ridge in the log-likelihood.", "R", 1, "-R <ridge>"));
        options.addElement(new Option("\tSet the maximum number of iterations (default -1, until convergence).", "M", 1, "-M <number>"));
        return options.elements();
    }

    /**
     * Parses the options. Absent options reset to their defaults
     * (ridge 1.0E-8, max iterations -1).
     *
     * @param options the option strings (consumed by {@link Utils})
     * @throws Exception if an option value cannot be parsed
     */
    public void setOptions(String[] options) throws Exception {
        this.setDebug(Utils.getFlag('D', options));
        String ridge = Utils.getOption('R', options);
        this.m_Ridge = ridge.length() != 0 ? Double.parseDouble(ridge) : 1.0E-8;
        String maxIts = Utils.getOption('M', options);
        this.m_MaxIts = maxIts.length() != 0 ? Integer.parseInt(maxIts) : -1;
    }

    /**
     * Returns the current option settings.
     *
     * @return a fixed-length array of 5 entries, padded with empty strings
     */
    public String[] getOptions() {
        String[] options = new String[5];
        int pos = 0;
        if (this.getDebug()) {
            options[pos++] = "-D";
        }
        options[pos++] = "-R";
        options[pos++] = "" + this.m_Ridge;
        options[pos++] = "-M";
        options[pos++] = "" + this.m_MaxIts;
        while (pos < options.length) {
            options[pos++] = "";
        }
        return options;
    }

    /** @return tip text for the debug property */
    public String debugTipText() {
        return "Output debug information to the console.";
    }

    /** @param debug whether to print debug output */
    public void setDebug(boolean debug) {
        this.m_Debug = debug;
    }

    /** @return whether debug output is enabled */
    public boolean getDebug() {
        return this.m_Debug;
    }

    /** @return tip text for the ridge property */
    public String ridgeTipText() {
        return "Set the Ridge value in the log-likelihood.";
    }

    /** @param ridge the ridge value */
    public void setRidge(double ridge) {
        this.m_Ridge = ridge;
    }

    /** @return the ridge value */
    public double getRidge() {
        return this.m_Ridge;
    }

    /** @return tip text for the maxIts property */
    public String maxItsTipText() {
        return "Maximum number of iterations to perform.";
    }

    /** @return the iteration limit; negative means "until convergence" */
    public int getMaxIts() {
        return this.m_MaxIts;
    }

    /** @param maxIts the iteration limit; any negative value means "until convergence" */
    public void setMaxIts(int maxIts) {
        this.m_MaxIts = maxIts;
    }

    /**
     * Declares the data this classifier accepts: nominal, numeric and date
     * attributes, missing values, and a nominal class (possibly missing).
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the model from training data.
     * <p>
     * The steps are:
     * <ol>
     * <li>Drop instances with a missing class.</li>
     * <li>Replace missing values, remove useless attributes and binarise
     * nominal attributes.</li>
     * <li>Standardise the predictors with their weighted mean and standard
     * deviation.</li>
     * <li>Minimise the penalised negative log-likelihood.</li>
     * <li>Map the coefficients back to the original attribute scale.</li>
     * </ol>
     *
     * @param instances the training data; the caller's object is not modified
     * @throws Exception if the data fails the capability test, if the total
     *         instance weight is at most 1 while there is more than one
     *         instance, or if a filter or the optimiser fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);

        // Work on a copy so the caller's data set is left untouched.
        instances = new Instances(instances);
        instances.deleteWithMissingClass();

        this.m_ReplaceMissingValues = new ReplaceMissingValues();
        this.m_ReplaceMissingValues.setInputFormat(instances);
        instances = Filter.useFilter(instances, this.m_ReplaceMissingValues);

        this.m_AttFilter = new RemoveUseless();
        this.m_AttFilter.setInputFormat(instances);
        instances = Filter.useFilter(instances, this.m_AttFilter);

        this.m_NominalToBinary = new NominalToBinary();
        this.m_NominalToBinary.setInputFormat(instances);
        instances = Filter.useFilter(instances, this.m_NominalToBinary);

        this.m_ClassIndex = instances.classIndex();
        this.m_NumClasses = instances.numClasses();

        int numFreeClasses = this.m_NumClasses - 1;   // the last class is the reference
        int numPredictors = this.m_NumPredictors = instances.numAttributes() - 1;
        int numInstances = instances.numInstances();
        int dim = numPredictors + 1;                   // intercept + predictors

        this.m_Data = new double[numInstances][dim];
        int[] classLabels = new int[numInstances];
        double[] means = new double[dim];
        double[] stdDevs = new double[dim];
        double[] classCounts = new double[numFreeClasses + 1];
        double[] weights = new double[numInstances];
        double totalWeight = 0.0;
        this.m_Par = new double[dim][numFreeClasses];

        if (this.m_Debug) {
            System.out.println("Extracting data...");
        }
        for (int i = 0; i < numInstances; i++) {
            Instance inst = instances.instance(i);
            classLabels[i] = (int) inst.classValue();
            weights[i] = inst.weight();
            totalWeight += weights[i];

            double[] row = this.m_Data[i];
            row[0] = 1.0;
            int col = 1;
            for (int att = 0; att <= numPredictors; att++) {
                if (att == this.m_ClassIndex) {
                    continue;
                }
                double value = inst.value(att);
                row[col] = value;
                // Weighted first and second moments, used for standardisation below.
                means[col] += weights[i] * value;
                stdDevs[col] += weights[i] * value * value;
                col++;
            }
            classCounts[classLabels[i]] += 1.0;   // unweighted count, used for start values
        }

        if (totalWeight <= 1.0 && numInstances > 1) {
            throw new Exception("Sum of weights of instances less than 1, please reweight!");
        }

        // Column 0 is the intercept: mean 0 and SD 1 leave it unchanged.
        means[0] = 0.0;
        stdDevs[0] = 1.0;
        for (int j = 1; j <= numPredictors; j++) {
            means[j] = means[j] / totalWeight;
            stdDevs[j] = totalWeight > 1.0
                    ? Math.sqrt(Math.abs(stdDevs[j] - totalWeight * means[j] * means[j]) / (totalWeight - 1.0))
                    : 0.0;
        }

        if (this.m_Debug) {
            System.out.println("Descriptives...");
            for (int c = 0; c <= numFreeClasses; c++) {
                System.out.println(classCounts[c] + " cases have class " + c);
            }
            System.out.println("\n Variable     Avg       SD    ");
            for (int j = 1; j <= numPredictors; j++) {
                System.out.println(Utils.doubleToString(j, 8, 4) + Utils.doubleToString(means[j], 10, 4) + Utils.doubleToString(stdDevs[j], 10, 4));
            }
        }

        // Standardise each column that has non-zero spread.
        for (int i = 0; i < numInstances; i++) {
            double[] row = this.m_Data[i];
            for (int j = 0; j <= numPredictors; j++) {
                if (stdDevs[j] != 0.0) {
                    row[j] = (row[j] - means[j]) / stdDevs[j];
                }
            }
        }

        if (this.m_Debug) {
            System.out.println("\nIteration History...");
        }

        // Flattened parameter vector: block [c*dim, (c+1)*dim) holds the
        // intercept followed by the predictor coefficients of free class c.
        double[] x = new double[dim * numFreeClasses];
        // All bounds are NaN; presumably "unconstrained" in
        // weka.core.Optimization's convention. TODO confirm.
        double[][] bounds = new double[2][x.length];
        for (int c = 0; c < numFreeClasses; c++) {
            int offset = c * dim;
            // Start each intercept at the smoothed log-odds against the reference class.
            x[offset] = Math.log(classCounts[c] + 1.0) - Math.log(classCounts[numFreeClasses] + 1.0);
            for (int j = 0; j < dim; j++) {
                bounds[0][offset + j] = Double.NaN;
                bounds[1][offset + j] = Double.NaN;
            }
        }

        OptEng opt = new OptEng();
        opt.setDebug(this.m_Debug);
        opt.setWeights(weights);
        opt.setClassLabels(classLabels);

        // A null result from findArgmin means the optimiser stopped before converging.
        if (this.m_MaxIts < 0) {
            // No limit: keep restarting from the current variable values until convergence.
            x = opt.findArgmin(x, bounds);
            while (x == null) {
                x = opt.getVarbValues();
                if (this.m_Debug) {
                    System.out.println("200 iterations finished, not enough!");
                }
                x = opt.findArgmin(x, bounds);
            }
            if (this.m_Debug) {
                System.out.println(" -------------<Converged>--------------");
            }
        } else {
            opt.setMaxIteration(this.m_MaxIts);
            x = opt.findArgmin(x, bounds);
            if (x == null) {
                // Limit reached: fall back to the optimiser's current variable values.
                x = opt.getVarbValues();
            }
        }

        this.m_LL = -opt.getMinFunction();
        this.m_Data = null;   // release the training matrix

        // Undo the standardisation so the coefficients apply to raw attribute values.
        for (int c = 0; c < numFreeClasses; c++) {
            int offset = c * dim;
            this.m_Par[0][c] = x[offset];
            for (int j = 1; j <= numPredictors; j++) {
                this.m_Par[j][c] = x[offset + j];
                if (stdDevs[j] == 0.0) {
                    continue;
                }
                this.m_Par[j][c] /= stdDevs[j];
                this.m_Par[0][c] -= this.m_Par[j][c] * means[j];
            }
        }
    }

    /**
     * Computes class membership probabilities for an instance, after passing
     * it through the filters fitted during training.
     *
     * @param instance the instance to classify
     * @return one probability per class, in class-index order
     * @throws Exception if a filter fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        this.m_ReplaceMissingValues.input(instance);
        instance = this.m_ReplaceMissingValues.output();
        this.m_AttFilter.input(instance);
        instance = this.m_AttFilter.output();
        this.m_NominalToBinary.input(instance);
        instance = this.m_NominalToBinary.output();

        // Leading 1 for the intercept, then every non-class attribute value.
        double[] predictors = new double[this.m_NumPredictors + 1];
        predictors[0] = 1.0;
        int col = 1;
        for (int att = 0; att <= this.m_NumPredictors; att++) {
            if (att != this.m_ClassIndex) {
                predictors[col++] = instance.value(att);
            }
        }
        return this.evaluateProbability(predictors);
    }

    /**
     * Evaluates class probabilities for a predictor vector whose first entry
     * is the constant 1.
     * <p>
     * Free class c gets score v_c = sum_j m_Par[j][c] * data[j]; the
     * reference class has score 0. The probability of class n is computed as
     * {@code 1 / (sum_i exp(v_i - v_n) + exp(-v_n))}, an algebraic
     * rearrangement of the softmax. If an exponential overflows to infinity,
     * the probability becomes 0 rather than NaN.
     *
     * @param data the predictor vector, length m_NumPredictors + 1
     * @return the probability of each class
     */
    private double[] evaluateProbability(double[] data) {
        int numClasses = this.m_NumClasses;
        double[] probs = new double[numClasses];
        double[] scores = new double[numClasses];
        for (int c = 0; c < numClasses - 1; c++) {
            for (int j = 0; j <= this.m_NumPredictors; j++) {
                scores[c] += this.m_Par[j][c] * data[j];
            }
        }
        scores[numClasses - 1] = 0.0;   // reference class
        for (int n = 0; n < numClasses; n++) {
            double sum = 0.0;
            for (int i = 0; i < numClasses - 1; i++) {
                sum += Math.exp(scores[i] - scores[n]);
            }
            probs[n] = 1.0 / (sum + Math.exp(-scores[n]));
        }
        return probs;
    }

    /**
     * Describes the model: coefficients, intercepts and odds ratios per
     * predictor and non-reference class.
     *
     * @return the description, or a short notice if no model has been built
     */
    public String toString() {
        String header = "Logistic Regression with ridge parameter of " + this.m_Ridge;
        if (this.m_Par == null) {
            return header + ": No model built yet.";
        }
        int numFree = this.m_NumClasses - 1;
        StringBuilder sb = new StringBuilder(header);

        sb.append("\nCoefficients...\nVariable      Coeff.\n");
        for (int j = 1; j <= this.m_NumPredictors; j++) {
            sb.append(Utils.doubleToString(j, 8, 0));
            for (int c = 0; c < numFree; c++) {
                sb.append(' ').append(Utils.doubleToString(this.m_Par[j][c], 12, 4));
            }
            sb.append('\n');
        }

        sb.append("Intercept ");
        for (int c = 0; c < numFree; c++) {
            sb.append(' ').append(Utils.doubleToString(this.m_Par[0][c], 10, 4));
        }
        sb.append('\n');

        sb.append("\nOdds Ratios...\nVariable         O.R.\n");
        for (int j = 1; j <= this.m_NumPredictors; j++) {
            sb.append(Utils.doubleToString(j, 8, 0));
            for (int c = 0; c < numFree; c++) {
                double oddsRatio = Math.exp(this.m_Par[j][c]);
                sb.append(' ');
                // Very large ratios are printed in full rather than in fixed width.
                if (oddsRatio > 1.0E10) {
                    sb.append(oddsRatio);
                } else {
                    sb.append(Utils.doubleToString(oddsRatio, 12, 4));
                }
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        Logistic.runClassifier(new Logistic(), argv);
    }

    /**
     * Objective function and gradient handed to {@link Optimization}: the
     * weighted multinomial negative log-likelihood plus the ridge penalty.
     * It reads {@link Logistic#m_Data}, so it is only usable while
     * {@link Logistic#buildClassifier(Instances)} is running.
     */
    private class OptEng
    extends Optimization {
        /** Per-instance weights (stored by reference). */
        private double[] weights;
        /** Per-instance class indices (stored by reference). */
        private int[] cls;

        private OptEng() {
        }

        /** @param w the per-instance weights (not copied) */
        public void setWeights(double[] w) {
            this.weights = w;
        }

        /** @param labels the per-instance class indices (not copied) */
        public void setClassLabels(int[] labels) {
            this.cls = labels;
        }

        /**
         * Computes the linear score of training instance {@code i} for every
         * non-reference class.
         *
         * @param i   the row of m_Data
         * @param x   the flattened parameter vector
         * @param dim the block length per class (predictors + 1)
         * @return a new array of m_NumClasses - 1 scores
         */
        private double[] linearScores(int i, double[] x, int dim) {
            int numFree = Logistic.this.m_NumClasses - 1;
            double[] row = Logistic.this.m_Data[i];
            double[] scores = new double[numFree];
            for (int c = 0; c < numFree; c++) {
                int offset = c * dim;
                double s = 0.0;
                for (int j = 0; j < dim; j++) {
                    s += row[j] * x[offset + j];
                }
                scores[c] = s;
            }
            return scores;
        }

        /**
         * Returns the log-sum-exp shift: the largest score, counting the
         * reference class's implicit score of 0. After subtracting it, every
         * exponent is at most 0, so no term can overflow. This fixes the
         * previous shift, which ignored the reference class and overflowed
         * {@code exp(-max)} when all scores were very negative.
         *
         * @param scores the non-reference class scores (may be empty)
         * @return max(0, max(scores))
         */
        private double logSumExpShift(double[] scores) {
            double max = 0.0;
            for (double s : scores) {
                if (s > max) {
                    max = s;
                }
            }
            return max;
        }

        /**
         * Evaluates the weighted negative log-likelihood plus the ridge penalty
         * on all non-intercept coefficients.
         *
         * @param x the flattened parameter vector
         * @return the objective value
         */
        protected double objectiveFunction(double[] x) {
            int dim = Logistic.this.m_NumPredictors + 1;
            int numFree = Logistic.this.m_NumClasses - 1;
            double nll = 0.0;
            for (int i = 0; i < this.cls.length; i++) {
                double[] scores = this.linearScores(i, x, dim);
                double shift = this.logSumExpShift(scores);
                // Every exponent is <= 0, so for finite scores denom lies in [1, numFree + 1].
                double denom = Math.exp(-shift);
                for (int c = 0; c < numFree; c++) {
                    denom += Math.exp(scores[c] - shift);
                }
                double num = this.cls[i] == numFree ? -shift : scores[this.cls[i]] - shift;
                nll -= this.weights[i] * (num - Math.log(denom));
            }
            // Ridge penalty: index 0 of each block is the intercept and is not penalised.
            for (int c = 0; c < numFree; c++) {
                for (int j = 1; j < dim; j++) {
                    double b = x[c * dim + j];
                    nll += Logistic.this.m_Ridge * b * b;
                }
            }
            return nll;
        }

        /**
         * Evaluates the gradient of {@link #objectiveFunction(double[])}.
         *
         * @param x the flattened parameter vector
         * @return the gradient, same length as {@code x}
         */
        protected double[] evaluateGradient(double[] x) {
            int dim = Logistic.this.m_NumPredictors + 1;
            int numFree = Logistic.this.m_NumClasses - 1;
            double[] grad = new double[x.length];
            for (int i = 0; i < this.cls.length; i++) {
                double[] row = Logistic.this.m_Data[i];
                double[] probs = this.linearScores(i, x, dim);
                double shift = this.logSumExpShift(probs);
                double denom = Math.exp(-shift);
                for (int c = 0; c < numFree; c++) {
                    probs[c] = Math.exp(probs[c] - shift);
                    denom += probs[c];
                }
                // denom >= 1 for finite scores, so normalisation never divides by zero.
                Utils.normalize(probs, denom);   // probs[c] = P(class c | instance i)

                // Weighted expected-feature term for every free class.
                for (int c = 0; c < numFree; c++) {
                    int offset = c * dim;
                    double wp = this.weights[i] * probs[c];
                    for (int j = 0; j < dim; j++) {
                        grad[offset + j] += wp * row[j];
                    }
                }
                // Observed-feature term, present only for non-reference labels.
                int label = this.cls[i];
                if (label != numFree) {
                    int offset = label * dim;
                    for (int j = 0; j < dim; j++) {
                        grad[offset + j] -= this.weights[i] * row[j];
                    }
                }
            }
            // Derivative of the ridge penalty (intercepts excluded).
            for (int c = 0; c < numFree; c++) {
                for (int j = 1; j < dim; j++) {
                    int k = c * dim + j;
                    grad[k] += 2.0 * Logistic.this.m_Ridge * x[k];
                }
            }
            return grad;
        }
    }
}

