/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.trees.lmt.LogisticBase;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * Classifier that builds linear logistic regression models, fitted with LogitBoost using simple
 * regression functions as base learners (see {@link #globalInfo()}).
 *
 * <p>Training data is preprocessed with {@link ReplaceMissingValues} followed by
 * {@link NominalToBinary}. The same fitted filter instances are reused on every instance passed to
 * {@link #distributionForInstance(Instance)}. The model itself is a {@link LogisticBase}, which is
 * configured from this object's option fields in {@link #buildClassifier(Instances)}.
 */
public class SimpleLogistic extends AbstractClassifier
        implements OptionHandler, AdditionalMeasureProducer, WeightedInstancesHandler,
        TechnicalInformationHandler {
    static final long serialVersionUID = 7397710626304705059L;

    /** The underlying model; {@code null} until {@link #buildClassifier(Instances)} succeeds. */
    protected LogisticBase m_boostedModel;
    /** Nominal-to-binary filter fitted on the training data in buildClassifier. */
    protected NominalToBinary m_NominalToBinary = null;
    /** Missing-value replacement filter fitted on the training data in buildClassifier. */
    protected ReplaceMissingValues m_ReplaceMissingValues = null;
    /** Fixed number of boosting iterations handed to LogisticBase (option -I). */
    protected int m_numBoostingIterations;
    /** Maximum number of boosting iterations (option -M, default 500). */
    protected int m_maxBoostingIterations = 500;
    /** Early-stopping heuristic parameter (option -H, default 50). */
    protected int m_heuristicStop = 50;
    /** Whether the iteration count is cross-validated; cleared by option -S. */
    protected boolean m_useCrossValidation;
    /** Whether error on probabilities is used as the stopping error measure (option -P). */
    protected boolean m_errorOnProbabilities;
    /** Beta for weight trimming in LogitBoost; 0 disables trimming (option -W). */
    protected double m_weightTrimBeta = 0.0;
    /** Whether the AIC is used to choose the iteration (option -A). */
    private boolean m_useAIC = false;

    /** Creates a classifier with 0 fixed iterations, cross-validation on, and AIC off. */
    public SimpleLogistic() {
        this.m_numBoostingIterations = 0;
        this.m_useCrossValidation = true;
        this.m_errorOnProbabilities = false;
        this.m_weightTrimBeta = 0.0;
        this.m_useAIC = false;
    }

    /**
     * Creates a classifier with the given iteration and stopping settings.
     *
     * @param numBoostingIterations fixed number of boosting iterations handed to LogisticBase
     * @param useCrossValidation whether the iteration count is cross-validated
     * @param errorOnProbabilities whether error on probabilities is used as the error measure
     */
    public SimpleLogistic(int numBoostingIterations, boolean useCrossValidation, boolean errorOnProbabilities) {
        this.m_numBoostingIterations = numBoostingIterations;
        this.m_useCrossValidation = useCrossValidation;
        this.m_errorOnProbabilities = errorOnProbabilities;
        this.m_weightTrimBeta = 0.0;
        this.m_useAIC = false;
    }

    /**
     * Command-line entry point; delegates to {@code runClassifier}.
     *
     * @param argv command-line options
     */
    public static void main(String[] argv) {
        SimpleLogistic.runClassifier(new SimpleLogistic(), argv);
    }

    /**
     * Returns the supported capabilities: nominal, numeric and date attributes, missing values,
     * and a nominal class that may have missing values.
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the model.
     *
     * <p>Validates the data against {@link #getCapabilities()} and works on a copy, so the
     * caller's data is not modified. It then drops instances with a missing class, fits and
     * applies the missing-value and nominal-to-binary filters, and trains a {@link LogisticBase}
     * configured from this object's options.
     *
     * @param data the training data
     * @throws Exception if the data fails the capabilities test, or if filtering or model
     *     building fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        // Copy so deleteWithMissingClass and filtering never touch the caller's dataset.
        data = new Instances(data);
        data.deleteWithMissingClass();
        this.m_ReplaceMissingValues = new ReplaceMissingValues();
        this.m_ReplaceMissingValues.setInputFormat(data);
        data = Filter.useFilter(data, this.m_ReplaceMissingValues);
        this.m_NominalToBinary = new NominalToBinary();
        this.m_NominalToBinary.setInputFormat(data);
        data = Filter.useFilter(data, this.m_NominalToBinary);
        this.m_boostedModel = new LogisticBase(this.m_numBoostingIterations, this.m_useCrossValidation, this.m_errorOnProbabilities);
        this.m_boostedModel.setMaxIterations(this.m_maxBoostingIterations);
        this.m_boostedModel.setHeuristicStop(this.m_heuristicStop);
        this.m_boostedModel.setWeightTrimBeta(this.m_weightTrimBeta);
        this.m_boostedModel.setUseAIC(this.m_useAIC);
        this.m_boostedModel.setNumDecimalPlaces(this.m_numDecimalPlaces);
        this.m_boostedModel.buildClassifier(data);
    }

    /**
     * Returns the class distribution for one instance.
     *
     * <p>Pushes the instance through the filters fitted at training time, in the same order, and
     * delegates to the underlying model. It requires {@link #buildClassifier(Instances)} to have
     * been called first: the filter fields are {@code null} until then.
     * NOTE(review): the filters are shared mutable state, so concurrent calls on one instance are
     * presumably unsafe. Confirm before using this across threads.
     *
     * @param inst the instance to classify
     * @return the class probability distribution from the underlying model
     * @throws Exception if filtering or prediction fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        this.m_ReplaceMissingValues.input(inst);
        inst = this.m_ReplaceMissingValues.output();
        this.m_NominalToBinary.input(inst);
        inst = this.m_NominalToBinary.output();
        return this.m_boostedModel.distributionForInstance(inst);
    }

    /** Lists the options specific to this classifier, followed by the superclass options. */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tSet fixed number of iterations for LogitBoost", "I", 1, "-I <iterations>"));
        newVector.addElement(new Option("\tUse stopping criterion on training set (instead of\n\tcross-validation)", "S", 0, "-S"));
        newVector.addElement(new Option("\tUse error on probabilities (rmse) instead of\n\tmisclassification error for stopping criterion", "P", 0, "-P"));
        newVector.addElement(new Option("\tSet maximum number of boosting iterations", "M", 1, "-M <iterations>"));
        newVector.addElement(new Option("\tSet parameter for heuristic for early stopping of\n\tLogitBoost.\n\tIf enabled, the minimum is selected greedily, stopping\n\tif the current minimum has not changed for iter iterations.\n\tBy default, heuristic is enabled with value 50. Set to\n\tzero to disable heuristic.", "H", 1, "-H <iterations>"));
        newVector.addElement(new Option("\tSet beta for weight trimming for LogitBoost. Set to 0 for no weight trimming.\n", "W", 1, "-W <beta>"));
        newVector.addElement(new Option("\tThe AIC is used to choose the best iteration (instead of CV or training error).\n", "A", 0, "-A"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Returns the current settings as an option array, followed by the superclass options.
     *
     * <p>-H is emitted with the stored value. A heuristic stop originally set to 0 has already
     * been resolved to the maximum iteration count (see {@link #setHeuristicStop(int)}).
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-I");
        options.add(String.valueOf(this.getNumBoostingIterations()));
        if (!this.getUseCrossValidation()) {
            options.add("-S");
        }
        if (this.getErrorOnProbabilities()) {
            options.add("-P");
        }
        options.add("-M");
        options.add(String.valueOf(this.getMaxBoostingIterations()));
        options.add("-H");
        options.add(String.valueOf(this.getHeuristicStop()));
        options.add("-W");
        options.add(String.valueOf(this.getWeightTrimBeta()));
        if (this.getUseAIC()) {
            options.add("-A");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Parses the options listed by {@link #listOptions()}, then passes the rest to the superclass.
     *
     * @param options the option array; recognized options are consumed from it
     * @throws Exception if a numeric option is malformed or unrecognized options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String optionString = Utils.getOption('I', options);
        if (!optionString.isEmpty()) {
            this.setNumBoostingIterations(Integer.parseInt(optionString));
        }
        this.setUseCrossValidation(!Utils.getFlag('S', options));
        this.setErrorOnProbabilities(Utils.getFlag('P', options));
        optionString = Utils.getOption('M', options);
        if (!optionString.isEmpty()) {
            this.setMaxBoostingIterations(Integer.parseInt(optionString));
        }
        // -H must be applied after -M: setHeuristicStop(0) resolves to the current maximum.
        optionString = Utils.getOption('H', options);
        if (!optionString.isEmpty()) {
            this.setHeuristicStop(Integer.parseInt(optionString));
        }
        optionString = Utils.getOption('W', options);
        if (!optionString.isEmpty()) {
            this.setWeightTrimBeta(Double.parseDouble(optionString));
        }
        this.setUseAIC(Utils.getFlag('A', options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /** @return the fixed number of boosting iterations */
    public int getNumBoostingIterations() {
        return this.m_numBoostingIterations;
    }

    /** @param n the fixed number of boosting iterations */
    public void setNumBoostingIterations(int n) {
        this.m_numBoostingIterations = n;
    }

    /** @return whether the number of iterations is cross-validated */
    public boolean getUseCrossValidation() {
        return this.m_useCrossValidation;
    }

    /** @param l whether the number of iterations is cross-validated */
    public void setUseCrossValidation(boolean l) {
        this.m_useCrossValidation = l;
    }

    /** @return whether error on probabilities is used as the error measure */
    public boolean getErrorOnProbabilities() {
        return this.m_errorOnProbabilities;
    }

    /** @param l whether error on probabilities is used as the error measure */
    public void setErrorOnProbabilities(boolean l) {
        this.m_errorOnProbabilities = l;
    }

    /** @return the maximum number of boosting iterations */
    public int getMaxBoostingIterations() {
        return this.m_maxBoostingIterations;
    }

    /** @param n the maximum number of boosting iterations */
    public void setMaxBoostingIterations(int n) {
        this.m_maxBoostingIterations = n;
    }

    /** @return the stored heuristic-stop parameter */
    public int getHeuristicStop() {
        return this.m_heuristicStop;
    }

    /**
     * Sets the early-stopping heuristic parameter.
     *
     * <p>A value of 0 is replaced by the <em>current</em> maximum number of boosting iterations.
     * A later change to the maximum does not update it.
     *
     * @param n the heuristic parameter, or 0 to use the current maximum iteration count
     */
    public void setHeuristicStop(int n) {
        this.m_heuristicStop = n == 0 ? this.m_maxBoostingIterations : n;
    }

    /** @return the beta value used for weight trimming */
    public double getWeightTrimBeta() {
        return this.m_weightTrimBeta;
    }

    /** @param n the beta value used for weight trimming (0 disables trimming) */
    public void setWeightTrimBeta(double n) {
        this.m_weightTrimBeta = n;
    }

    /** @return whether the AIC is used to choose the iteration */
    public boolean getUseAIC() {
        return this.m_useAIC;
    }

    /** @param c whether the AIC is used to choose the iteration */
    public void setUseAIC(boolean c) {
        this.m_useAIC = c;
    }

    /**
     * Returns the number of regressions from the built model; requires a built model.
     *
     * @return the underlying model's number of regressions
     */
    public int getNumRegressions() {
        return this.m_boostedModel.getNumRegressions();
    }

    /** Returns a textual description of the model, or "No model built" before training. */
    @Override
    public String toString() {
        if (this.m_boostedModel == null) {
            return "No model built";
        }
        return "SimpleLogistic:\n" + this.m_boostedModel.toString();
    }

    /**
     * Returns the percentage of attributes used by the built model; requires a built model.
     *
     * @return the underlying model's percentAttributesUsed()
     */
    public double measureAttributesUsed() {
        return this.m_boostedModel.percentAttributesUsed();
    }

    /** Enumerates the additional measures supported by {@link #getMeasure(String)}. */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(2);
        newVector.addElement("measureAttributesUsed");
        newVector.addElement("measureNumIterations");
        return newVector.elements();
    }

    /**
     * Returns the value of a named additional measure (matched case-insensitively).
     *
     * @param additionalMeasureName "measureAttributesUsed" or "measureNumIterations"
     * @return the measure value
     * @throws IllegalArgumentException if the measure name is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.compareToIgnoreCase("measureAttributesUsed") == 0) {
            return this.measureAttributesUsed();
        }
        if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") == 0) {
            return this.getNumRegressions();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (SimpleLogistic)");
    }

    /** Returns a description of the classifier, including its technical references. */
    public String globalInfo() {
        return "Classifier for building linear logistic regression models. LogitBoost with simple regression functions as base learners is used for fitting the logistic models. The optimal number of LogitBoost iterations to perform is cross-validated, which leads to automatic attribute selection. For more information see:\n" + this.getTechnicalInformation().toString();
    }

    /** Returns the bibliographic references for the underlying method. */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Niels Landwehr and Mark Hall and Eibe Frank");
        result.setValue(TechnicalInformation.Field.TITLE, "Logistic Model Trees");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Machine Learning");
        result.setValue(TechnicalInformation.Field.YEAR, "2005");
        result.setValue(TechnicalInformation.Field.VOLUME, "95");
        result.setValue(TechnicalInformation.Field.PAGES, "161-205");
        result.setValue(TechnicalInformation.Field.NUMBER, "1-2");
        TechnicalInformation additional = result.add(TechnicalInformation.Type.INPROCEEDINGS);
        additional.setValue(TechnicalInformation.Field.AUTHOR, "Marc Sumner and Eibe Frank and Mark Hall");
        additional.setValue(TechnicalInformation.Field.TITLE, "Speeding up Logistic Model Tree Induction");
        additional.setValue(TechnicalInformation.Field.BOOKTITLE, "9th European Conference on Principles and Practice of Knowledge Discovery in Databases");
        additional.setValue(TechnicalInformation.Field.YEAR, "2005");
        additional.setValue(TechnicalInformation.Field.PAGES, "675-683");
        additional.setValue(TechnicalInformation.Field.PUBLISHER, "Springer");
        return result;
    }

    public String numBoostingIterationsTipText() {
        return "Set fixed number of iterations for LogitBoost. If >= 0, this sets the number of LogitBoost iterations to perform. If < 0, the number is cross-validated or a stopping criterion on the training set is used (depending on the value of useCrossValidation).";
    }

    public String useCrossValidationTipText() {
        return "Sets whether the number of LogitBoost iterations is to be cross-validated or the stopping criterion on the training set should be used. If not set (and no fixed number of iterations was given), the number of LogitBoost iterations is used that minimizes the error on the training set (misclassification error or error on probabilities depending on errorOnProbabilities).";
    }

    public String errorOnProbabilitiesTipText() {
        return "Use error on the probabilties as error measure when determining the best number of LogitBoost iterations. If set, the number of LogitBoost iterations is chosen that minimizes the root mean squared error (either on the training set or in the cross-validation, depending on useCrossValidation).";
    }

    public String maxBoostingIterationsTipText() {
        return "Sets the maximum number of iterations for LogitBoost. Default value is 500, for very small/large datasets a lower/higher value might be preferable.";
    }

    public String heuristicStopTipText() {
        return "If heuristicStop > 0, the heuristic for greedy stopping while cross-validating the number of LogitBoost iterations is enabled. This means LogitBoost is stopped if no new error minimum has been reached in the last heuristicStop iterations. It is recommended to use this heuristic, it gives a large speed-up especially on small datasets. The default value is 50.";
    }

    public String weightTrimBetaTipText() {
        return "Set the beta value used for weight trimming in LogitBoost. Only instances carrying (1 - beta)% of the weight from previous iteration are used in the next iteration. Set to 0 for no weight trimming. The default value is 0.";
    }

    public String useAICTipText() {
        return "The AIC is used to determine when to stop LogitBoost iterations (instead of cross-validation or training error).";
    }

    @Override
    public String numDecimalPlacesTipText() {
        return "The number of decimal places to be used for the output of coefficients.";
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11569 $");
    }
}

