/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.chemclipse.chromatogram.xxd.process.supplier.pca.model;

import java.util.ArrayList;
import org.eclipse.chemclipse.chromatogram.xxd.process.supplier.pca.exception.MathIllegalArgumentException;
import org.eclipse.chemclipse.chromatogram.xxd.process.supplier.pca.model.IMultivariateCalculator;
import org.eclipse.chemclipse.model.statistics.ISample;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;

/**
 * Base class for multivariate calculators that collect observations row by row
 * into a fixed-size sample matrix and expose the loadings, scores and variance
 * figures derived from it.
 *
 * <p>Subclasses are expected to compute the loadings and scores and publish them
 * through {@link #setLoadings(DMatrixRMaj)} and {@link #setScores(DMatrixRMaj)};
 * this class itself never assigns them. The way the matrices are used here
 * implies that the loadings hold one row per component and one column per
 * variable, and that the scores hold one row per sample and one column per
 * component.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public abstract class AbstractMultivariateCalculator
implements IMultivariateCalculator {
    /** Loading matrix (components x variables); {@code null} until a subclass sets it. */
    private DMatrixRMaj loadings;
    /** Score matrix (samples x components); {@code null} until a subclass sets it. */
    private DMatrixRMaj scores;
    /** Per-variable mean, allocated with one entry per variable; filled by subclasses via {@link #getMean()}. */
    private final double[] mean;
    /** Number of components requested at construction time. */
    private final int numComps;
    /** Observation matrix (samples x variables); rows are filled by {@link #addObservation}. */
    private final DMatrixRMaj sampleData;
    /** Sample keys in insertion order; the list index equals the row index in {@link #sampleData}. */
    private final ArrayList<ISample> sampleKeys = new ArrayList<>();
    /** Group names in insertion order, parallel to {@link #sampleKeys}. */
    private final ArrayList<String> groupNames = new ArrayList<>();
    /** Row of {@link #sampleData} that the next observation is written to. */
    private int sampleIndex;
    /** Set once a computation has been reported as successful. */
    private boolean computeSuccess;

    /**
     * Allocates the sample matrix and the mean vector.
     *
     * @param numSamples number of observation rows; must be larger than zero
     * @param numVars number of variables (columns); must be larger than zero
     * @param numComponents number of components; must be larger than zero and
     *        must not exceed {@code numVars} (equality is accepted by the check)
     * @throws MathIllegalArgumentException if any of the constraints above is violated
     */
    public AbstractMultivariateCalculator(int numSamples, int numVars, int numComponents) throws MathIllegalArgumentException {
        // The order of the checks decides which message is reported when several
        // arguments are invalid at once, so it is kept as is.
        if (numComponents > numVars) {
            throw new MathIllegalArgumentException("Number of components must be smaller than number of variables.");
        }
        if (numVars <= 0) {
            throw new MathIllegalArgumentException("Number of variables must be larger than zero");
        }
        if (numSamples <= 0) {
            throw new MathIllegalArgumentException("Number of samples must be larger than zero.");
        }
        if (numComponents <= 0) {
            throw new MathIllegalArgumentException("Number of components must be larger than zero.");
        }
        this.sampleData = new DMatrixRMaj(numSamples, numVars);
        this.mean = new double[numVars];
        this.numComps = numComponents;
        this.sampleIndex = 0;
        this.computeSuccess = false;
    }

    /** Marks the computation as successful; there is no way to reset the flag. */
    @Override
    public void setComputeSuccess() {
        this.computeSuccess = true;
    }

    /**
     * @return {@code true} once {@link #setComputeSuccess()} has been called
     */
    @Override
    public boolean getComputeStatus() {
        return this.computeSuccess;
    }

    /**
     * Stores one observation in the next free row of the sample matrix and
     * records its key and group name at the same position.
     *
     * <p>If {@code obsData} is shorter than the number of variables, only the
     * leading columns are written and the remaining ones keep their current value.
     *
     * @param obsData variable values of the observation; must not be longer than
     *        the number of variables
     * @param sampleKey key under which the observation can be looked up later
     * @param groupName group the observation belongs to
     * @throws IllegalArgumentException if every row of the sample matrix is
     *         already filled or {@code obsData} has more entries than there are
     *         variables; the calculator is left unchanged in both cases
     */
    @Override
    public void addObservation(double[] obsData, ISample sampleKey, String groupName) {
        // Validate before touching any state so that a rejected observation
        // cannot leave the matrix and the key lists out of step.
        if (obsData.length > this.sampleData.numCols) {
            throw new IllegalArgumentException("Observation has " + obsData.length + " values but only " + this.sampleData.numCols + " variables are available.");
        }
        if (this.sampleIndex >= this.sampleData.numRows) {
            throw new IllegalArgumentException("Cannot add observation: all " + this.sampleData.numRows + " sample rows are already filled.");
        }
        for (int col = 0; col < obsData.length; col++) {
            this.sampleData.set(this.sampleIndex, col, obsData[col]);
        }
        this.sampleKeys.add(sampleKey);
        this.groupNames.add(groupName);
        this.sampleIndex++;
    }

    /**
     * @return the live list of group names in insertion order (not a copy)
     */
    protected ArrayList<String> getGroupNames() {
        return this.groupNames;
    }

    /**
     * @return the score matrix, or {@code null} if none has been set yet
     */
    public DMatrixRMaj getScores() {
        return this.scores;
    }

    /**
     * Projects an observation onto the components: subtracts the mean vector
     * and multiplies the result by the loading matrix.
     *
     * @param obs observation with one value per variable; it is copied and not modified
     * @return backing array of the resulting (components x 1) vector
     */
    private double[] applyLoadings(double[] obs) {
        int numVars = this.sampleData.getNumCols();
        // wrap() shares this.mean instead of copying it; it is only read here.
        DMatrixRMaj meanVector = DMatrixRMaj.wrap(numVars, 1, this.mean);
        DMatrixRMaj centered = new DMatrixRMaj(numVars, 1, true, obs);
        DMatrixRMaj projected = new DMatrixRMaj(this.numComps, 1);
        CommonOps_DDRM.subtract(centered, meanVector, centered);
        CommonOps_DDRM.mult(this.loadings, centered, projected);
        return projected.data;
    }

    /**
     * Computes the Euclidean distance between an observation and its
     * reconstruction from the component space.
     *
     * @param obs observation with one value per variable
     * @return the reconstruction error, or {@code 0.0} if no successful
     *         computation has been reported yet
     */
    @Override
    public double getErrorMetric(double[] obs) {
        if (!this.getComputeStatus()) {
            return 0.0;
        }
        double[] projected = this.applyLoadings(obs);
        double[] reconstructed = this.reproject(projected);
        double squaredError = 0.0;
        for (int i = 0; i < reconstructed.length; i++) {
            double diff = obs[i] - reconstructed[i];
            squaredError += diff * diff;
        }
        return Math.sqrt(squaredError);
    }

    /**
     * @return the loading matrix, or {@code null} if none has been set yet
     */
    public DMatrixRMaj getLoadings() {
        return this.loadings;
    }

    /**
     * Returns the loadings of a single component.
     *
     * @param var zero-based component index
     * @return a new array with one loading per variable
     * @throws IllegalArgumentException if {@code var} is negative or not smaller
     *         than the number of components
     */
    @Override
    public double[] getLoadingVector(int var) {
        if (var < 0 || var >= this.numComps) {
            throw new IllegalArgumentException("Invalid component");
        }
        int numVars = this.sampleData.numCols;
        DMatrixRMaj loadingVector = new DMatrixRMaj(1, numVars);
        // Copy row `var` of the loading matrix into the 1 x numVars result.
        CommonOps_DDRM.extract(this.loadings, var, var + 1, 0, numVars, loadingVector, 0, 0);
        return loadingVector.data;
    }

    /**
     * Sums the sample variances (denominator {@code numRows - 1}) of all
     * variables of the sample matrix.
     *
     * <p>Every row of the matrix takes part, including rows that have not been
     * filled by {@link #addObservation} yet. With a single row the divisor is
     * zero and the result is not meaningful.
     *
     * @return the sum of the per-variable variances
     */
    @Override
    public double getSummedVariance() {
        int numVars = this.sampleData.numCols;
        int numRows = this.sampleData.numRows;

        // Column means.
        DMatrixRMaj colMeans = new DMatrixRMaj(1, numVars);
        CommonOps_DDRM.sumCols(this.sampleData, colMeans);
        CommonOps_DDRM.divide(colMeans, (double)numRows);

        // Replace every value of a working copy by its squared deviation from
        // the column mean.
        DMatrixRMaj squaredDeviations = this.sampleData.copy();
        DMatrixRMaj column = new DMatrixRMaj(squaredDeviations.numRows, 1);
        for (int col = 0; col < squaredDeviations.numCols; col++) {
            CommonOps_DDRM.extractColumn(squaredDeviations, col, column);
            CommonOps_DDRM.add(column, colMeans.get(col) * -1.0);
            for (int row = 0; row < squaredDeviations.numRows; row++) {
                squaredDeviations.set(row, col, Math.pow(column.get(row), 2.0));
            }
        }

        // Per-column variance, then the total over all columns.
        DMatrixRMaj colVariances = new DMatrixRMaj(1, numVars);
        CommonOps_DDRM.sumCols(squaredDeviations, colVariances);
        CommonOps_DDRM.divide(colVariances, (double)(numRows - 1));
        return CommonOps_DDRM.elementSum(colVariances);
    }

    /**
     * Computes the sample variance (denominator {@code numRows - 1}) of the
     * scores of one component and divides it by 100.
     *
     * <p>NOTE(review): the reason for the division by 100 is not visible in this
     * class; confirm against the callers before changing it.
     *
     * @param var zero-based component index (column of the score matrix)
     * @return the variance of the component's scores divided by 100
     * @throws IllegalArgumentException if {@code var} is not a valid column of
     *         the score matrix
     */
    @Override
    public double getExplainedVariance(int var) {
        DMatrixRMaj allScores = this.getScores();
        // extractColumn performs no range check on the column index, so an
        // invalid index would read misaligned data; reject it here instead.
        if (var < 0 || var >= allScores.numCols) {
            throw new IllegalArgumentException("Invalid component");
        }
        int numRows = this.sampleData.getNumRows();
        DMatrixRMaj component = new DMatrixRMaj(numRows, 1);
        CommonOps_DDRM.extractColumn(allScores, var, component);
        double colMean = CommonOps_DDRM.elementSum(component) / (double)numRows;
        CommonOps_DDRM.add(component, colMean * -1.0);
        for (int row = 0; row < component.numRows; row++) {
            component.set(row, 0, Math.pow(component.get(row), 2.0));
        }
        CommonOps_DDRM.divide(component, (double)(this.sampleData.numRows - 1));
        return CommonOps_DDRM.elementSum(component) / 100.0;
    }

    /**
     * @return the live mean vector with one entry per variable (not a copy)
     */
    protected double[] getMean() {
        return this.mean;
    }

    /**
     * @return the number of components requested at construction time
     */
    protected int getNumComps() {
        return this.numComps;
    }

    /**
     * @return the live sample matrix (samples x variables), not a copy
     */
    protected DMatrixRMaj getSampleData() {
        return this.sampleData;
    }

    /**
     * Returns the scores of a previously added sample.
     *
     * @param sampleId key that was passed to {@link #addObservation}; the first
     *        matching entry is used
     * @return a new array with one score per component
     * @throws IllegalArgumentException if no observation was added under
     *         {@code sampleId}
     */
    @Override
    public double[] getScoreVector(ISample sampleId) {
        int row = this.sampleKeys.indexOf(sampleId);
        if (row < 0) {
            // indexOf reports "not found" as -1, which is not a valid matrix row.
            throw new IllegalArgumentException("Unknown sample: no observation was added for " + sampleId);
        }
        DMatrixRMaj scoreVector = new DMatrixRMaj(1, this.numComps);
        // Copy row `row` of the score matrix into the 1 x numComps result.
        CommonOps_DDRM.extract(this.scores, row, row + 1, 0, this.numComps, scoreVector, 0, 0);
        return scoreVector.data;
    }

    /**
     * Maps a score vector back into variable space: multiplies it by the
     * transposed loading matrix and adds the mean vector.
     *
     * @param scoreVector one score per component; it is wrapped, not copied,
     *        and only read
     * @return backing array of the reconstructed observation, one value per variable
     */
    protected double[] reproject(double[] scoreVector) {
        int numVars = this.sampleData.getNumCols();
        DMatrixRMaj reconstructed = new DMatrixRMaj(numVars, 1);
        DMatrixRMaj projected = DMatrixRMaj.wrap(this.numComps, 1, scoreVector);
        CommonOps_DDRM.multTransA(this.loadings, projected, reconstructed);
        DMatrixRMaj meanVector = DMatrixRMaj.wrap(numVars, 1, this.mean);
        CommonOps_DDRM.add(reconstructed, meanVector, reconstructed);
        return reconstructed.data;
    }

    /**
     * @param loadings loading matrix to use from now on; stored by reference
     */
    protected void setLoadings(DMatrixRMaj loadings) {
        this.loadings = loadings;
    }

    /**
     * @param scores score matrix to use from now on; stored by reference
     */
    protected void setScores(DMatrixRMaj scores) {
        this.scores = scores;
    }
}

