/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.maxent.quasinewton.Function;
import opennlp.tools.ml.maxent.quasinewton.LineSearch;

/**
 * Minimizes a {@link Function} with a limited-memory quasi-Newton scheme.
 *
 * <p>The search direction is built from the last {@code m} point/gradient
 * difference pairs using a two-loop recursion (see {@link #computeDirection}).
 * An optional L2 penalty is folded into the objective through
 * {@link L2RegFunction}; an optional L1 penalty is handled through a
 * pseudo-gradient and a constrained line search.
 *
 * <p>Instances keep per-run state in fields ({@code dimension},
 * {@code updateInfo}), so a single instance must not run {@link #minimize}
 * concurrently from several threads.
 */
public class QNMinimizer {
    /** Threshold on the function change rate below which the search stops. */
    public static final double CONVERGE_TOLERANCE = 1.0E-4;
    /** Threshold on the gradient norm relative to max(1, |x|) below which the search stops. */
    public static final double REL_GRAD_NORM_TOL = 1.0E-4;
    /** Initial step size handed to the line search from the second iteration on. */
    public static final double INITIAL_STEP_SIZE = 1.0;
    /** Step size below which the search stops. */
    public static final double MIN_STEP_SIZE = 1.0E-10;
    /** Default L1 cost (no L1 regularization). */
    public static final double L1COST_DEFAULT = 0.0;
    /** Default L2 cost (no L2 regularization). */
    public static final double L2COST_DEFAULT = 0.0;
    /** Default maximum number of iterations. */
    public static final int NUM_ITERATIONS_DEFAULT = 100;
    /** Default number of stored correction pairs. */
    public static final int M_DEFAULT = 15;
    /** Default cap on the number of function evaluations. */
    public static final int MAX_FCT_EVAL_DEFAULT = 30000;

    private final double l1Cost;
    private final double l2Cost;
    private final int iterations;
    private final int m;
    private final int maxFctEval;
    private final boolean verbose;

    // Per-run state, (re)initialized at the start of every minimize() call.
    private int dimension;
    private UpdateInfo updateInfo;

    // Optional hook used only for progress output in verbose mode.
    private Evaluator evaluator;

    /** Creates a minimizer with default costs, iteration count and history size. */
    public QNMinimizer() {
        this(L1COST_DEFAULT, L2COST_DEFAULT);
    }

    /**
     * @param l1Cost L1 regularization cost, must be &gt;= 0
     * @param l2Cost L2 regularization cost, must be &gt;= 0
     */
    public QNMinimizer(double l1Cost, double l2Cost) {
        this(l1Cost, l2Cost, NUM_ITERATIONS_DEFAULT);
    }

    /**
     * @param l1Cost     L1 regularization cost, must be &gt;= 0
     * @param l2Cost     L2 regularization cost, must be &gt;= 0
     * @param iterations maximum number of iterations, must be &gt; 0
     */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations) {
        this(l1Cost, l2Cost, iterations, M_DEFAULT, MAX_FCT_EVAL_DEFAULT);
    }

    /**
     * @param l1Cost     L1 regularization cost, must be &gt;= 0
     * @param l2Cost     L2 regularization cost, must be &gt;= 0
     * @param iterations maximum number of iterations, must be &gt; 0
     * @param m          number of stored correction pairs, must be &gt; 0
     * @param maxFctEval cap on function evaluations, must be &gt; 0
     */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations, int m, int maxFctEval) {
        this(l1Cost, l2Cost, iterations, m, maxFctEval, true);
    }

    /**
     * @param l1Cost     L1 regularization cost, must be &gt;= 0
     * @param l2Cost     L2 regularization cost, must be &gt;= 0
     * @param iterations maximum number of iterations, must be &gt; 0
     * @param m          number of stored correction pairs, must be &gt; 0
     * @param maxFctEval cap on function evaluations, must be &gt; 0
     * @param verbose    whether progress is printed to standard output
     * @throws IllegalArgumentException if any numeric argument is out of range
     */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations, int m, int maxFctEval, boolean verbose) {
        if (l1Cost < 0.0 || l2Cost < 0.0) {
            throw new IllegalArgumentException("L1-cost and L2-cost must not be less than zero");
        }
        if (iterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be larger than zero");
        }
        if (m <= 0) {
            throw new IllegalArgumentException("Number of Hessian updates must be larger than zero");
        }
        if (maxFctEval <= 0) {
            throw new IllegalArgumentException("Maximum number of function evaluations must be larger than zero");
        }
        this.l1Cost = l1Cost;
        this.l2Cost = l2Cost;
        this.iterations = iterations;
        this.m = m;
        this.maxFctEval = maxFctEval;
        this.verbose = verbose;
    }

    /** @return the evaluator used for verbose progress output, or {@code null} if none is set */
    public Evaluator getEvaluator() {
        return this.evaluator;
    }

    /** @param evaluator evaluator whose score is printed per iteration in verbose mode; may be {@code null} */
    public void setEvaluator(Evaluator evaluator) {
        this.evaluator = evaluator;
    }

    /**
     * Searches for a minimizer of {@code function}, starting from the zero vector.
     *
     * @param function objective to minimize
     * @return a fresh array holding the final point
     */
    public double[] minimize(Function function) {
        L2RegFunction l2RegFunction = new L2RegFunction(function, this.l2Cost);
        this.dimension = l2RegFunction.getDimension();
        this.updateInfo = new UpdateInfo(this.m, this.dimension);
        final boolean useL1 = this.l1Cost > 0.0;

        // The starting point is the origin (freshly allocated array).
        double[] currPoint = new double[this.dimension];
        double currValue = l2RegFunction.valueAt(currPoint);
        // Copy the gradient: the function may hand back an array it reuses.
        double[] currGrad = new double[this.dimension];
        System.arraycopy(l2RegFunction.gradientAt(currPoint), 0, currGrad, 0, this.dimension);

        double[] pseudoGrad = null;
        if (useL1) {
            currValue += this.l1Cost * ArrayMath.l1norm(currPoint);
            pseudoGrad = new double[this.dimension];
            this.computePseudoGrad(currPoint, currGrad, pseudoGrad);
        }

        LineSearch.LineSearchResult lsr = useL1
                ? LineSearch.LineSearchResult.getInitialObjectForL1(currValue, currGrad, pseudoGrad, currPoint)
                : LineSearch.LineSearchResult.getInitialObject(currValue, currGrad, currPoint);

        if (this.verbose) {
            this.display("\nSolving convex optimization problem.");
            this.display("\nObjective function has " + this.dimension + " variable(s).");
            this.display("\n\nPerforming " + this.iterations + " iterations with L1Cost=" + this.l1Cost + " and L2Cost=" + this.l2Cost + "\n");
        }

        double[] direction = new double[this.dimension];
        long startTime = System.currentTimeMillis();

        // The very first line search starts from 1/|gradient|; later ones use INITIAL_STEP_SIZE.
        double initialStepSize = useL1
                ? ArrayMath.invL2norm(lsr.getPseudoGradAtNext())
                : ArrayMath.invL2norm(lsr.getGradAtNext());

        for (int iter = 1; iter <= this.iterations; ++iter) {
            // Seed the direction with the (pseudo-)gradient; computeDirection turns it into a descent direction.
            double[] seed = useL1 ? lsr.getPseudoGradAtNext() : lsr.getGradAtNext();
            System.arraycopy(seed, 0, direction, 0, direction.length);
            this.computeDirection(direction);

            if (useL1) {
                pseudoGrad = lsr.getPseudoGradAtNext();
                // Zero every component that does not strictly oppose the pseudo-gradient.
                for (int i = 0; i < this.dimension; ++i) {
                    if (direction[i] * pseudoGrad[i] >= 0.0) {
                        direction[i] = 0.0;
                    }
                }
                LineSearch.doConstrainedLineSearch(l2RegFunction, direction, lsr, this.l1Cost, initialStepSize);
                this.computePseudoGrad(lsr.getNextPoint(), lsr.getGradAtNext(), pseudoGrad);
                lsr.setPseudoGradAtNext(pseudoGrad);
            } else {
                LineSearch.doLineSearch(l2RegFunction, direction, lsr, initialStepSize);
            }

            this.updateInfo.update(lsr);

            if (this.verbose) {
                // Right-align the iteration number in a three-character column.
                if (iter < 10) {
                    this.display("  " + iter + ":  ");
                } else if (iter < 100) {
                    this.display(" " + iter + ":  ");
                } else {
                    this.display(iter + ":  ");
                }
                if (this.evaluator != null) {
                    this.display("\t" + lsr.getValueAtNext() + "\t" + lsr.getFuncChangeRate() + "\t" + this.evaluator.evaluate(lsr.getNextPoint()) + "\n");
                } else {
                    this.display("\t " + lsr.getValueAtNext() + "\t" + lsr.getFuncChangeRate() + "\n");
                }
            }

            if (this.isConverged(lsr)) {
                break;
            }
            initialStepSize = INITIAL_STEP_SIZE;
        }

        // With both penalties active the solution is rescaled by sqrt(1 + l2Cost).
        // NOTE(review): presumably the usual elastic-net correction - confirm against the reference algorithm.
        if (useL1 && this.l2Cost > 0.0) {
            double[] x = lsr.getNextPoint();
            double scale = StrictMath.sqrt(1.0 + this.l2Cost);
            for (int i = 0; i < this.dimension; ++i) {
                x[i] = scale * x[i];
            }
        }

        long duration = System.currentTimeMillis() - startTime;
        // Guarded like every other progress message so that verbose=false is really silent.
        if (this.verbose) {
            this.display("Running time: " + (double)duration / 1000.0 + "s\n");
        }

        // Drop the history buffers; the gc() call is only a hint to the JVM.
        this.updateInfo = null;
        System.gc();

        double[] parameters = new double[this.dimension];
        System.arraycopy(lsr.getNextPoint(), 0, parameters, 0, this.dimension);
        return parameters;
    }

    /**
     * Writes the L1 pseudo-gradient for point {@code x} with gradient {@code g} into {@code pg}.
     *
     * <p>For a non-zero coordinate the L1 cost is subtracted (negative {@code x}) or added
     * (positive {@code x}). For any other coordinate the gradient is shrunk towards zero by the
     * L1 cost, yielding 0 when it lies within {@code [-l1Cost, l1Cost]}.
     */
    private void computePseudoGrad(double[] x, double[] g, double[] pg) {
        for (int i = 0; i < this.dimension; ++i) {
            if (x[i] < 0.0) {
                pg[i] = g[i] - this.l1Cost;
            } else if (x[i] > 0.0) {
                pg[i] = g[i] + this.l1Cost;
            } else if (g[i] < -this.l1Cost) {
                pg[i] = g[i] + this.l1Cost;
            } else if (g[i] > this.l1Cost) {
                pg[i] = g[i] - this.l1Cost;
            } else {
                pg[i] = 0.0;
            }
        }
    }

    /**
     * Transforms {@code direction} in place with the two-loop recursion over the stored
     * correction pairs, then negates it.
     *
     * @param direction on entry the (pseudo-)gradient, on exit the search direction
     */
    private void computeDirection(double[] direction) {
        int k = this.updateInfo.kCounter;
        double[] rho = this.updateInfo.rho;
        double[] alpha = this.updateInfo.alpha;
        double[][] S = this.updateInfo.S;
        double[][] Y = this.updateInfo.Y;

        // First loop: newest pair to oldest.
        for (int i = k - 1; i >= 0; --i) {
            alpha[i] = rho[i] * ArrayMath.innerProduct(S[i], direction);
            for (int j = 0; j < this.dimension; ++j) {
                direction[j] = direction[j] - alpha[i] * Y[i][j];
            }
        }
        // Second loop: oldest pair to newest.
        for (int i = 0; i < k; ++i) {
            double beta = rho[i] * ArrayMath.innerProduct(Y[i], direction);
            for (int j = 0; j < this.dimension; ++j) {
                direction[j] = direction[j] + S[i][j] * (alpha[i] - beta);
            }
        }
        for (int i = 0; i < this.dimension; ++i) {
            direction[i] = -direction[i];
        }
    }

    /**
     * Checks the four stopping criteria in order: function change rate, relative gradient norm,
     * minimum step size and the function evaluation budget. In verbose mode the reason is printed.
     *
     * @return {@code true} if any criterion is met
     */
    private boolean isConverged(LineSearch.LineSearchResult lsr) {
        if (lsr.getFuncChangeRate() < CONVERGE_TOLERANCE) {
            if (this.verbose) {
                this.display("Function change rate is smaller than the threshold " + CONVERGE_TOLERANCE + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        // The gradient norm is measured relative to |x|, floored at 1 so that small x does not inflate it.
        double xNorm = StrictMath.max(1.0, ArrayMath.l2norm(lsr.getNextPoint()));
        double gradNorm = this.l1Cost > 0.0
                ? ArrayMath.l2norm(lsr.getPseudoGradAtNext())
                : ArrayMath.l2norm(lsr.getGradAtNext());
        if (gradNorm / xNorm < REL_GRAD_NORM_TOL) {
            if (this.verbose) {
                this.display("Relative L2-norm of the gradient is smaller than the threshold " + REL_GRAD_NORM_TOL + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        if (lsr.getStepSize() < MIN_STEP_SIZE) {
            if (this.verbose) {
                this.display("Step size is smaller than the minimum step size " + MIN_STEP_SIZE + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        if (lsr.getFctEvalCount() > this.maxFctEval) {
            if (this.verbose) {
                this.display("Maximum number of function evaluations has exceeded the threshold " + this.maxFctEval + ".\nTraining will stop.\n\n");
            }
            return true;
        }
        return false;
    }

    /** Prints {@code s} to standard output without a trailing newline; callers decide on verbosity. */
    private void display(String s) {
        System.out.print(s);
    }

    /** Scores a parameter vector; used only for per-iteration progress output. */
    public static interface Evaluator {
        /**
         * @param var1 the current parameter vector
         * @return a score for that vector
         */
        public double evaluate(double[] var1);
    }

    /**
     * Wraps a {@link Function} and adds the penalty {@code l2Cost * <x, x>} to its value and
     * {@code 2 * l2Cost * x} to its gradient. With {@code l2Cost <= 0} it is a pass-through.
     */
    public static class L2RegFunction
    implements Function {
        private final Function f;
        private final double l2Cost;

        /**
         * @param f      function to wrap
         * @param l2Cost L2 regularization cost
         */
        public L2RegFunction(Function f, double l2Cost) {
            this.f = f;
            this.l2Cost = l2Cost;
        }

        @Override
        public int getDimension() {
            return this.f.getDimension();
        }

        /**
         * @throws IllegalArgumentException if {@code x} does not have the function's dimension
         */
        @Override
        public double valueAt(double[] x) {
            this.checkDimension(x);
            double value = this.f.valueAt(x);
            if (this.l2Cost > 0.0) {
                value += this.l2Cost * ArrayMath.innerProduct(x, x);
            }
            return value;
        }

        /**
         * Returns the wrapped function's gradient array with the penalty gradient added in place,
         * so the array returned by the wrapped function is modified.
         *
         * @throws IllegalArgumentException if {@code x} does not have the function's dimension
         */
        @Override
        public double[] gradientAt(double[] x) {
            this.checkDimension(x);
            double[] gradient = this.f.gradientAt(x);
            if (this.l2Cost > 0.0) {
                for (int i = 0; i < x.length; ++i) {
                    gradient[i] = gradient[i] + 2.0 * this.l2Cost * x[i];
                }
            }
            return gradient;
        }

        private void checkDimension(double[] x) {
            if (x.length != this.getDimension()) {
                throw new IllegalArgumentException("x's dimension is not the same as function's dimension");
            }
        }
    }

    /**
     * Sliding window over the last {@code m} correction pairs: {@code S[i]} is a point
     * difference, {@code Y[i]} the matching gradient difference and {@code rho[i]} is
     * {@code 1 / <S[i], Y[i]>}. Index 0 is the oldest stored pair.
     */
    private static final class UpdateInfo {
        private final double[][] S;
        private final double[][] Y;
        private final double[] rho;
        private final double[] alpha;
        private final int m;
        private final int dimension;
        private int kCounter;

        UpdateInfo(int numCorrection, int dimension) {
            this.m = numCorrection;
            this.dimension = dimension;
            this.kCounter = 0;
            this.S = new double[this.m][dimension];
            this.Y = new double[this.m][dimension];
            this.rho = new double[this.m];
            this.alpha = new double[this.m];
        }

        /**
         * Appends the pair described by {@code lsr} (next minus current point and gradient),
         * evicting the oldest pair when the window is full.
         */
        public void update(LineSearch.LineSearchResult lsr) {
            double[] currPoint = lsr.getCurrPoint();
            double[] gradAtCurr = lsr.getGradAtCurr();
            double[] nextPoint = lsr.getNextPoint();
            double[] gradAtNext = lsr.getGradAtNext();

            int slot;
            if (this.kCounter < this.m) {
                slot = this.kCounter++;
            } else {
                // Window is full: shift everything one position towards index 0 and recycle the
                // evicted row buffers as the new last row. Shifting row references without
                // recycling would leave rows m-2 and m-1 pointing at the same array, so writing
                // the new pair would also overwrite the previous one.
                double[] recycledS = this.S[0];
                double[] recycledY = this.Y[0];
                System.arraycopy(this.S, 1, this.S, 0, this.m - 1);
                System.arraycopy(this.Y, 1, this.Y, 0, this.m - 1);
                System.arraycopy(this.rho, 1, this.rho, 0, this.m - 1);
                slot = this.m - 1;
                this.S[slot] = recycledS;
                this.Y[slot] = recycledY;
            }

            double[] s = this.S[slot];
            double[] y = this.Y[slot];
            double sy = 0.0;
            for (int j = 0; j < this.dimension; ++j) {
                s[j] = nextPoint[j] - currPoint[j];
                y[j] = gradAtNext[j] - gradAtCurr[j];
                sy += s[j] * y[j];
            }
            this.rho[slot] = 1.0 / sy;
        }
    }
}

