package diva.sketch.classification;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;

/* loaded from: input_file:diva/sketch/classification/CrossValidation.class */
/**
 * k-fold cross-validation of a {@link TrainableClassifier} over labelled
 * {@link FeatureSet} examples grouped by type name.
 */
public class CrossValidation {
    /** Default number of folds, used by {@link #CrossValidation(String, CVData)}. */
    public static final int K = 10;

    // Stored by the constructors; not otherwise read by code in this class.
    private final int _k;
    private final String _classifierType;
    private final CVData _data;

    /**
     * Input data for cross-validation: a mapping from type name to the
     * examples labelled with that type.
     */
    public static class CVData {
        HashMap<String, ArrayList<FeatureSet>> _mapTypeToExamples =
                new HashMap<String, ArrayList<FeatureSet>>();

        /** Returns the example list for {@code type}, registering an empty one if absent. */
        private ArrayList<FeatureSet> examplesFor(String type) {
            ArrayList<FeatureSet> examples = _mapTypeToExamples.get(type);
            if (examples == null) {
                examples = new ArrayList<FeatureSet>();
                _mapTypeToExamples.put(type, examples);
            }
            return examples;
        }

        /**
         * Adds every example in {@code examples} under {@code type}. The type is
         * registered even when {@code examples} is null or empty.
         *
         * @param type     type name the examples are labelled with
         * @param examples feature sets to add; each element must be a {@link FeatureSet}
         */
        public void addClass(String type, ArrayList<? extends FeatureSet> examples) {
            ArrayList<FeatureSet> target = examplesFor(type);
            if (examples != null) {
                // Add the individual examples, not the list itself, so that
                // getExamples() only ever sees FeatureSet elements.
                target.addAll(examples);
            }
        }

        /**
         * Adds a single example under {@code type}.
         *
         * @param type       type name the example is labelled with
         * @param featureSet the example to add
         */
        public void addExample(String type, FeatureSet featureSet) {
            examplesFor(type).add(featureSet);
        }

        /** Returns the number of distinct types registered. */
        public int getTypeCount() {
            return _mapTypeToExamples.size();
        }

        /** Returns an iterator over the registered type names. */
        public Iterator<String> types() {
            return _mapTypeToExamples.keySet().iterator();
        }

        /**
         * Returns a fresh array holding the examples for {@code type}, in
         * insertion order.
         *
         * @param type type name to look up
         * @return the examples, or an empty array if the type is unknown
         */
        public FeatureSet[] getExamples(String type) {
            ArrayList<FeatureSet> examples = _mapTypeToExamples.get(type);
            if (examples == null) {
                return new FeatureSet[0];
            }
            return examples.toArray(new FeatureSet[examples.size()]);
        }

        /** Returns one line per type of the form {@code "<count> <type>"}, pluralised with "s" unless count is 1. */
        @Override
        public String toString() {
            StringBuilder out = new StringBuilder();
            for (String type : _mapTypeToExamples.keySet()) {
                int count = _mapTypeToExamples.get(type).size();
                out.append(count + " " + type);
                out.append(count == 1 ? "\n" : "s\n");
            }
            return out.toString();
        }
    }

    /**
     * Per-type tallies of correctly and incorrectly classified test examples.
     */
    public static class CVResult {
        HashMap<String, Stat> _mapTypeToStat = new HashMap<String, Stat>();

        /** Correct/incorrect counts for one type. */
        public static class Stat {
            public int numCorrect;
            public int numIncorrect;

            private Stat() {
                this.numCorrect = 0;
                this.numIncorrect = 0;
            }

            /**
             * Returns the percentage (0-100) of examples classified correctly,
             * or {@link Double#NaN} if no examples have been counted.
             */
            public double getAccuracyRate() {
                int total = numCorrect + numIncorrect;
                // Floating-point division: integer division would truncate the percentage.
                return total == 0 ? Double.NaN : (100.0 * numCorrect) / total;
            }

            /**
             * Returns the percentage (0-100) of examples classified incorrectly,
             * or {@link Double#NaN} if no examples have been counted.
             */
            public double getErrorRate() {
                int total = numCorrect + numIncorrect;
                return total == 0 ? Double.NaN : (100.0 * numIncorrect) / total;
            }
        }

        /** Records one correctly classified example of {@code type}. */
        public void incrCorrect(String type) {
            getStat(type).numCorrect++;
        }

        /** Records one misclassified example of {@code type}. */
        public void incrIncorrect(String type) {
            getStat(type).numIncorrect++;
        }

        /**
         * Adds every tally in {@code other} into this result.
         *
         * @param other result whose counts are added to this one
         */
        public void combine(CVResult other) {
            for (String type : other._mapTypeToStat.keySet()) {
                int correct = other.getCorrectCount(type);
                int incorrect = other.getIncorrectCount(type);
                Stat stat = getStat(type);
                stat.numCorrect += correct;
                stat.numIncorrect += incorrect;
            }
        }

        /** Returns the tally for {@code type}, creating a zeroed one if absent. */
        private Stat getStat(String type) {
            Stat stat = _mapTypeToStat.get(type);
            if (stat == null) {
                stat = new Stat();
                _mapTypeToStat.put(type, stat);
            }
            return stat;
        }

        /** Returns the number of correct classifications for {@code type} (0 if unknown). */
        public int getCorrectCount(String type) {
            Stat stat = _mapTypeToStat.get(type);
            return stat == null ? 0 : stat.numCorrect;
        }

        /** Returns the number of misclassifications for {@code type} (0 if unknown). */
        public int getIncorrectCount(String type) {
            Stat stat = _mapTypeToStat.get(type);
            return stat == null ? 0 : stat.numIncorrect;
        }

        /** Returns the accuracy percentage for {@code type}, or NaN if unknown or empty. */
        public double getAccuracyRate(String type) {
            Stat stat = _mapTypeToStat.get(type);
            return stat == null ? Double.NaN : stat.getAccuracyRate();
        }

        /** Returns the error percentage for {@code type}, or NaN if unknown or empty. */
        public double getErrorRate(String type) {
            Stat stat = _mapTypeToStat.get(type);
            return stat == null ? Double.NaN : stat.getErrorRate();
        }

        /** Returns an iterator over the types that have a tally. */
        public Iterator<String> types() {
            return _mapTypeToStat.keySet().iterator();
        }

        /** Returns one summary line per type with counts and rates. */
        @Override
        public String toString() {
            StringBuilder out = new StringBuilder();
            for (String type : _mapTypeToStat.keySet()) {
                out.append(type + ": " + getCorrectCount(type) + " correct, "
                        + getIncorrectCount(type) + " misses, accuracy rate: "
                        + getAccuracyRate(type) + "%, error rate: "
                        + getErrorRate(type) + "%\n");
            }
            return out.toString();
        }
    }

    /**
     * @param k              number of folds
     * @param classifierType classifier type name (stored only)
     * @param data           data to cross-validate (stored only)
     */
    public CrossValidation(int k, String classifierType, CVData data) {
        this._k = k;
        this._classifierType = classifierType;
        this._data = data;
    }

    /** Equivalent to {@code CrossValidation(K, classifierType, data)}. */
    public CrossValidation(String classifierType, CVData data) {
        this(K, classifierType, data);
    }

    /**
     * Runs k-fold cross-validation. For each fold, every type with at least
     * {@code k} examples contributes a contiguous slice of {@code n / k}
     * examples to the test set and the rest to the training set. The remaining
     * {@code n % k} examples are always in the training set. Types with fewer than
     * {@code k} examples are left out entirely. The classifier is retrained on
     * each fold's training set and scored on its test set.
     *
     * @param k          number of folds; must be at least 1
     * @param classifier classifier to train and evaluate
     * @param data       labelled examples
     * @return the combined tallies over all folds, or an empty result if a
     *         {@link ClassifierException} occurred (its stack trace is printed)
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public static CVResult crossValidate(int k, TrainableClassifier classifier, CVData data) {
        if (k < 1) {
            throw new IllegalArgumentException("number of folds must be at least 1: " + k);
        }
        try {
            String[] types = typeNames(data);
            // Accumulate into a local result so a failure part-way through
            // yields an empty result rather than a partial one.
            CVResult folds = new CVResult();
            for (int fold = 0; fold < k; fold++) {
                TrainingSet trainSet = new TrainingSet();
                TrainingSet testSet = new TrainingSet();
                splitFold(data, types, k, fold, trainSet, testSet);
                classifier.train(trainSet);
                folds.combine(evaluate(classifier, testSet));
            }
            CVResult result = new CVResult();
            result.combine(folds);
            return result;
        } catch (ClassifierException e) {
            // Best-effort contract: report and return an empty result.
            e.printStackTrace();
            return new CVResult();
        }
    }

    /** Snapshots the type names of {@code data} so every fold visits them in the same order. */
    private static String[] typeNames(CVData data) {
        String[] names = new String[data.getTypeCount()];
        int i = 0;
        for (Iterator<String> it = data.types(); it.hasNext(); ) {
            names[i++] = it.next();
        }
        return names;
    }

    /** Distributes each type's examples between the training and test sets for one fold. */
    private static void splitFold(CVData data, String[] types, int k, int fold,
            TrainingSet trainSet, TrainingSet testSet) throws ClassifierException {
        for (String type : types) {
            FeatureSet[] examples = data.getExamples(type);
            int n = examples.length;
            if (n < k) {
                // Too few examples for every fold to test at least one.
                continue;
            }
            // Integer division gives exactly floor(n / k), free of float rounding.
            int foldSize = n / k;
            int testStart = fold * foldSize;
            int testEnd = testStart + foldSize; // exclusive
            for (int j = 0; j < n; j++) {
                if (j >= testStart && j < testEnd) {
                    testSet.addPositiveExample(type, examples[j]);
                } else {
                    trainSet.addPositiveExample(type, examples[j]);
                }
            }
        }
    }

    /** Classifies every positive example in {@code testSet} and tallies hits and misses per type. */
    private static CVResult evaluate(TrainableClassifier classifier, TrainingSet testSet)
            throws ClassifierException {
        CVResult result = new CVResult();
        Iterator<?> types = testSet.types();
        while (types.hasNext()) {
            String type = (String) types.next();
            Iterator<?> examples = testSet.positiveExamples(type);
            while (examples.hasNext()) {
                FeatureSet example = (FeatureSet) examples.next();
                // Compare from the known-non-null side so a null prediction counts as a miss.
                if (type.equals(classifier.classify(example).getHighestConfidenceType())) {
                    result.incrCorrect(type);
                } else {
                    result.incrIncorrect(type);
                }
            }
        }
        return result;
    }
}
