package org.signalml.domain.roc;

import java.beans.IntrospectionException;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.signalml.app.model.components.LabelledPropertyDescriptor;
import org.signalml.app.model.components.PropertyProvider;
import org.signalml.app.model.components.WriterExportableTable;
import org.signalml.app.util.i18n.SvarogI18n;
import org.signalml.method.iterator.IterableParameter;
import org.signalml.method.iterator.MethodIteratorData;
import org.signalml.method.iterator.ParameterIterationSettings;

/* loaded from: input_file:org/signalml/domain/roc/RocData.class */
/**
 * Holds the points of a ROC curve together with the iterated parameters that
 * produced them, and lazily derives summary statistics from those points.
 *
 * <p>The statistics (area under curve, maximal accuracy, points closest to the
 * "OSS line" {@code trueRate = 1 - falseRate} on either side, and the
 * intersection of the segment joining those two points with that line) are
 * recomputed on demand after the data set has been modified.
 *
 * <p>NOTE(review): the area computation walks the points in insertion order and
 * closes the curve with (0,0) and (1,1), so it presumably expects points to be
 * added in order of increasing false rate - confirm against callers.
 */
public class RocData implements WriterExportableTable, PropertyProvider {

    /**
     * Fixed column headers written by {@link #export} after the parameter names.
     * NOTE(review): "specifity" is misspelled, but it is part of the emitted
     * file format and is therefore kept unchanged.
     */
    private static final String[] EXPORT_COLUMN_HEADERS = {
        "TP", "FP", "TN", "FN", "FP rate", "TP rate", "sensitivity", "specifity",
        "accuracy", "positive_pred_value", "negative_pred_value", "false_discovery_rate"
    };

    private final IterableParameter[] parameters;
    private final List<RocDataPoint> rocDataPoints;

    /** True when the cached statistics below do not reflect the current points. */
    private boolean dirtyStatistics = true;

    private double areaUnderCurve;
    private int maxAccuracyIteration;
    private double maxAccuracy;
    private int pointBelowOSSIteration;
    private double pointBelowOSSDistance;
    private int pointAboveOSSIteration;
    private double pointAboveOSSDistance;
    private double ossIntersectionTP;

    /**
     * Creates an empty data set for the given parameters.
     *
     * @param iterableParameterArr the iterated parameters; the array is stored as given (not copied)
     * @throws NullPointerException if the array is {@code null}
     * @throws IllegalArgumentException if the array is empty
     */
    public RocData(IterableParameter[] iterableParameterArr) {
        if (iterableParameterArr == null) {
            throw new NullPointerException("No parameters");
        }
        if (iterableParameterArr.length == 0) {
            // The guard rejects only the empty array, so the message says "at least one".
            throw new IllegalArgumentException("At least one parameter needed");
        }
        this.parameters = iterableParameterArr;
        this.rocDataPoints = new ArrayList<>();
    }

    /**
     * Creates a data set for the given parameters, pre-filled with the given points.
     *
     * @param iterableParameterArr the iterated parameters; must be non-null and non-empty
     * @param rocDataPointArr the initial points, in order; {@code null} is treated as "no points"
     * @throws NullPointerException if {@code iterableParameterArr} is {@code null}
     * @throws IllegalArgumentException if {@code iterableParameterArr} is empty
     */
    public RocData(IterableParameter[] iterableParameterArr, RocDataPoint[] rocDataPointArr) {
        this(iterableParameterArr);
        // Previously the supplied points were silently discarded (the empty list was
        // copied onto itself); they are now actually stored.
        if (rocDataPointArr != null) {
            for (RocDataPoint point : rocDataPointArr) {
                this.rocDataPoints.add(point);
            }
        }
    }

    /**
     * Builds an empty data set whose parameters are those settings entries that
     * are marked as iterated.
     *
     * @param parameterIterationSettingsArr the settings to filter
     * @return a new, empty {@code RocData}
     * @throws IllegalArgumentException if none of the settings is iterated
     */
    public static RocData createForParameterIterationSettings(ParameterIterationSettings[] parameterIterationSettingsArr) {
        List<IterableParameter> iterated = new ArrayList<>();
        for (ParameterIterationSettings settings : parameterIterationSettingsArr) {
            if (settings.isIterated()) {
                iterated.add(settings.getParameter());
            }
        }
        return new RocData(iterated.toArray(new IterableParameter[iterated.size()]));
    }

    /**
     * Builds an empty data set from the parameter settings of the given iterator data.
     *
     * @param methodIteratorData source of the parameter iteration settings
     * @return a new, empty {@code RocData}
     */
    public static RocData createForMethodIteratorData(MethodIteratorData methodIteratorData) {
        return createForParameterIterationSettings(methodIteratorData.getParameters());
    }

    /**
     * Appends a point and invalidates the cached statistics.
     *
     * @param rocDataPoint the point to append
     */
    public void add(RocDataPoint rocDataPoint) {
        this.rocDataPoints.add(rocDataPoint);
        this.dirtyStatistics = true;
    }

    /** @return the number of iterated parameters */
    public int getParameterCount() {
        return this.parameters.length;
    }

    /**
     * @param i index of the parameter
     * @return the parameter at index {@code i}
     */
    public IterableParameter getParameterAt(int i) {
        return this.parameters[i];
    }

    /** @return the number of stored points */
    public int getSampleCount() {
        return this.rocDataPoints.size();
    }

    /**
     * @param i index of the point
     * @return the point at index {@code i}
     */
    public RocDataPoint getRocDataPointAt(int i) {
        return this.rocDataPoints.get(i);
    }

    /**
     * @param i index of the parameter
     * @param i2 index of the point
     * @return the value of parameter {@code i} recorded in point {@code i2}
     */
    public Object getParameterValueAt(int i, int i2) {
        return this.rocDataPoints.get(i2).getParameterValues()[i];
    }

    /**
     * @param i index of the point
     * @return the true positive count of that point
     */
    public int getTruePositiveCount(int i) {
        return this.rocDataPoints.get(i).getTruePositiveCount();
    }

    /**
     * @param i index of the point
     * @return the true negative count of that point
     */
    public int getTrueNegativeCount(int i) {
        return this.rocDataPoints.get(i).getTrueNegativeCount();
    }

    /**
     * @param i index of the point
     * @return the false positive count of that point
     */
    public int getFalsePositiveCount(int i) {
        return this.rocDataPoints.get(i).getFalsePositiveCount();
    }

    /**
     * @param i index of the point
     * @return the false negative count of that point
     */
    public int getFalseNegativeCount(int i) {
        return this.rocDataPoints.get(i).getFalseNegativeCount();
    }

    /**
     * @param i index of the point
     * @return the true rate of that point
     */
    public double getTrueRateAt(int i) {
        return this.rocDataPoints.get(i).getTrueRate();
    }

    /**
     * @param i index of the point
     * @return the false rate of that point
     */
    public double getFalseRateAt(int i) {
        return this.rocDataPoints.get(i).getFalseRate();
    }

    /** @return a new array with the true rate of every point, in insertion order */
    public double[] getTrueRates() {
        double[] rates = new double[this.rocDataPoints.size()];
        for (int i = 0; i < rates.length; i++) {
            rates[i] = this.rocDataPoints.get(i).getTrueRate();
        }
        return rates;
    }

    /** @return a new array with the false rate of every point, in insertion order */
    public double[] getFalseRates() {
        double[] rates = new double[this.rocDataPoints.size()];
        for (int i = 0; i < rates.length; i++) {
            rates[i] = this.rocDataPoints.get(i).getFalseRate();
        }
        return rates;
    }

    /** Tells whether the point lies on or above the line {@code trueRate = 1 - falseRate}. */
    private boolean isAboveOSS(RocDataPoint rocDataPoint) {
        return rocDataPoint.getTrueRate() >= 1.0d - rocDataPoint.getFalseRate();
    }

    /** Perpendicular distance of the point from the line {@code trueRate + falseRate = 1}. */
    private double getOSSDistance(RocDataPoint rocDataPoint) {
        return Math.abs(1.0d - (rocDataPoint.getFalseRate() + rocDataPoint.getTrueRate())) / Math.sqrt(2.0d);
    }

    /** Recomputes the cached statistics if points were added since the last computation. */
    private void ensureStatistics() {
        if (this.dirtyStatistics) {
            calculateStatistics();
        }
    }

    /**
     * Recomputes all cached statistics from the current points and clears the dirty flag.
     *
     * <p>Unavailable values are reported as {@code -1}; iteration numbers are
     * stored 1-based (point index + 1).
     */
    private void calculateStatistics() {
        int size = this.rocDataPoints.size();
        this.areaUnderCurve = 0.0d;

        if (size == 0) {
            this.maxAccuracy = 0.0d;
            this.maxAccuracyIteration = -1;
            this.pointAboveOSSDistance = -1.0d;
            this.pointAboveOSSIteration = -1;
            this.pointBelowOSSDistance = -1.0d;
            this.pointBelowOSSIteration = -1;
            // Same "not available" sentinel the non-empty path uses when no intersection exists.
            this.ossIntersectionTP = -1.0d;
            this.dirtyStatistics = false;
            return;
        }

        RocDataPoint current = this.rocDataPoints.get(0);

        // Triangle between the origin (0,0) and the first point.
        this.areaUnderCurve += (current.getFalseRate() * current.getTrueRate()) / 2.0d;

        this.maxAccuracy = current.getAccuracy();
        this.maxAccuracyIteration = 0;

        this.pointAboveOSSIteration = -1;
        this.pointBelowOSSIteration = -1;
        this.pointAboveOSSDistance = Double.MAX_VALUE;
        this.pointBelowOSSDistance = Double.MAX_VALUE;
        if (isAboveOSS(current)) {
            this.pointAboveOSSDistance = getOSSDistance(current);
            this.pointAboveOSSIteration = 0;
        } else {
            this.pointBelowOSSDistance = getOSSDistance(current);
            this.pointBelowOSSIteration = 0;
        }

        for (int i = 1; i < size; i++) {
            RocDataPoint previous = current;
            current = this.rocDataPoints.get(i);

            // Trapezoid between two consecutive points.
            this.areaUnderCurve += 0.5d
                    * (previous.getTrueRate() + current.getTrueRate())
                    * (current.getFalseRate() - previous.getFalseRate());

            double accuracy = current.getAccuracy();
            if (accuracy > this.maxAccuracy) {
                this.maxAccuracy = accuracy;
                this.maxAccuracyIteration = i;
            }

            // Track the closest point to the OSS line separately on each side of it.
            double distance = getOSSDistance(current);
            if (isAboveOSS(current)) {
                if (distance < this.pointAboveOSSDistance) {
                    this.pointAboveOSSDistance = distance;
                    this.pointAboveOSSIteration = i;
                }
            } else if (distance < this.pointBelowOSSDistance) {
                this.pointBelowOSSDistance = distance;
                this.pointBelowOSSIteration = i;
            }
        }

        // Trapezoid between the last point and (1,1).
        this.areaUnderCurve += 0.5d * (current.getTrueRate() + 1.0d) * (1.0d - current.getFalseRate());

        if (this.pointAboveOSSIteration < 0 || this.pointBelowOSSIteration < 0) {
            this.ossIntersectionTP = -1.0d;
        } else {
            // True rate at which the segment joining the closest point above and the
            // closest point below crosses the line trueRate = 1 - falseRate. The
            // denominator cannot be zero for finite rates: the "above" point has
            // falseRate + trueRate >= 1, the "below" point has it < 1.
            RocDataPoint above = this.rocDataPoints.get(this.pointAboveOSSIteration);
            double aboveFalseRate = above.getFalseRate();
            double aboveTrueRate = above.getTrueRate();
            RocDataPoint below = this.rocDataPoints.get(this.pointBelowOSSIteration);
            double belowFalseRate = below.getFalseRate();
            double belowTrueRate = below.getTrueRate();
            this.ossIntersectionTP =
                    ((belowTrueRate * (1.0d - aboveFalseRate)) - (aboveTrueRate * (1.0d - belowFalseRate)))
                    / ((belowFalseRate + belowTrueRate) - (aboveFalseRate + aboveTrueRate));
        }

        // Convert the 0-based point indices to 1-based iteration numbers, and replace
        // the MAX_VALUE placeholder with the -1 sentinel where no point was found.
        if (this.pointAboveOSSIteration < 0) {
            this.pointAboveOSSDistance = -1.0d;
        } else {
            this.pointAboveOSSIteration++;
        }
        if (this.pointBelowOSSIteration < 0) {
            this.pointBelowOSSDistance = -1.0d;
        } else {
            this.pointBelowOSSIteration++;
        }
        if (this.maxAccuracyIteration >= 0) {
            this.maxAccuracyIteration++;
        }

        this.dirtyStatistics = false;
    }

    /** @return the area under the curve, closed with (0,0) and (1,1); 0 when there are no points */
    public double getAreaUnderCurve() {
        ensureStatistics();
        return this.areaUnderCurve;
    }

    /** @return the highest accuracy among the points; 0 when there are no points */
    public double getMaxAccuracy() {
        ensureStatistics();
        return this.maxAccuracy;
    }

    /** @return the 1-based number of the first point with the highest accuracy, or -1 when there are no points */
    public int getMaxAccuracyIteration() {
        ensureStatistics();
        return this.maxAccuracyIteration;
    }

    /** @return the 1-based number of the point below the OSS line that is closest to it, or -1 if none */
    public int getPointBelowOSSIteration() {
        ensureStatistics();
        return this.pointBelowOSSIteration;
    }

    /** @return the distance from the OSS line of the closest point below it, or -1 if none */
    public double getPointBelowOSSDistance() {
        ensureStatistics();
        return this.pointBelowOSSDistance;
    }

    /** @return the 1-based number of the point on or above the OSS line that is closest to it, or -1 if none */
    public int getPointAboveOSSIteration() {
        ensureStatistics();
        return this.pointAboveOSSIteration;
    }

    /** @return the distance from the OSS line of the closest point on or above it, or -1 if none */
    public double getPointAboveOSSDistance() {
        ensureStatistics();
        return this.pointAboveOSSDistance;
    }

    /**
     * @return the true rate at which the segment joining the closest points above and
     *         below the OSS line crosses that line, or -1 when either point is missing
     */
    public double getOssIntersectionTP() {
        ensureStatistics();
        return this.ossIntersectionTP;
    }

    /**
     * Writes the data set as a delimited text table: one header row followed by
     * one row per point. Each row starts with the parameter columns, followed by
     * the fixed count and rate columns.
     *
     * @param writer destination of the table
     * @param str separator appended after every cell except the last one of a row
     * @param str2 terminator appended after the last cell of every row
     * @param obj not used by this implementation
     * @throws IOException if the writer fails
     */
    @Override // org.signalml.app.model.components.WriterExportableTable
    public void export(Writer writer, String str, String str2, Object obj) throws IOException {
        for (IterableParameter parameter : this.parameters) {
            writer.append(parameter.getName());
            writer.append(str);
        }
        appendCells(writer, str, str2, EXPORT_COLUMN_HEADERS);

        for (RocDataPoint point : this.rocDataPoints) {
            Object[] values = point.getParameterValues();
            for (int i = 0; i < this.parameters.length; i++) {
                writer.append(values[i].toString());
                writer.append(str);
            }
            appendCells(writer, str, str2,
                    Integer.toString(point.getTruePositiveCount()),
                    Integer.toString(point.getFalsePositiveCount()),
                    Integer.toString(point.getTrueNegativeCount()),
                    Integer.toString(point.getFalseNegativeCount()),
                    Double.toString(point.getFalseRate()),
                    Double.toString(point.getTrueRate()),
                    Double.toString(point.getSensitivity()),
                    Double.toString(point.getSpecifity()),
                    Double.toString(point.getAccuracy()),
                    Double.toString(point.getPositivePredictiveValue()),
                    Double.toString(point.getNegativePredictiveValue()),
                    Double.toString(point.getFalseDiscoveryRate()));
        }
    }

    /** Writes the cells separated by {@code separator}, with {@code terminator} after the last one. */
    private static void appendCells(Writer writer, String separator, String terminator, String... cells)
            throws IOException {
        int last = cells.length - 1;
        for (int i = 0; i <= last; i++) {
            writer.append(cells[i]);
            writer.append(i < last ? separator : terminator);
        }
    }

    /**
     * Describes the read-only statistics of this data set for property display.
     *
     * @return descriptors for the derived statistics, in display order
     * @throws IntrospectionException if a descriptor cannot be created
     */
    @Override // org.signalml.app.model.components.PropertyProvider
    public List<LabelledPropertyDescriptor> getPropertyList() throws IntrospectionException {
        List<LabelledPropertyDescriptor> descriptors = new ArrayList<>();
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("area under curve"), "areaUnderCurve", RocData.class, "getAreaUnderCurve", null));
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("maximal accuracy iteration"), "maxAccuracyIteration", RocData.class, "getMaxAccuracyIteration", null));
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("maximal accuracy"), "maxAccuracy", RocData.class, "getMaxAccuracy", null));
        // NOTE(review): "abose" looks like a typo for "above"; the literal is passed to the
        // i18n lookup, so it is left untouched here - fix together with the translation catalogs.
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("point abose OSS iteration"), "pointAboveOSSIteration", RocData.class, "getPointAboveOSSIteration", null));
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("point above OSS distance"), "pointAboveOSSDistance", RocData.class, "getPointAboveOSSDistance", null));
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("point below OSS iteration"), "pointBelowOSSIteration", RocData.class, "getPointBelowOSSIteration", null));
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("point below OSS distance"), "pointBelowOSSDistance", RocData.class, "getPointBelowOSSDistance", null));
        descriptors.add(new LabelledPropertyDescriptor(SvarogI18n._("OSS intersection TP"), "ossIntersectionTP", RocData.class, "getOssIntersectionTP", null));
        return descriptors;
    }
}
