/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropyR;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoidR;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * High-level operator (hop) for fused quaternary operations. The operation is
 * selected by a {@link Types.OpOp4} code, and the hop is compiled into one of
 * the weighted lops: {@link WeightedSquaredLoss}, {@link WeightedSigmoid},
 * {@link WeightedDivMM}, {@link WeightedCrossEntropy} or
 * {@link WeightedUnaryMM} (or their {@code *R} counterparts on Spark).
 *
 * <p>The hop always has three inputs at positions 0-2 and, depending on the
 * constructor used, an optional fourth input at position 3.
 */
public class QuaternaryOp
extends MultiThreadedHop {
    /**
     * When {@code true}, Spark lop construction never selects the map-side lop
     * and always passes {@code cacheU = cacheV = false} to the {@code *R} lop.
     */
    public static boolean FORCE_REPLICATION = false;

    /** Operation code that selects the lop family in {@link #constructLops()}. */
    private Types.OpOp4 _op = null;
    /** WSLOSS: distinguishes POST / POST_NZ from PRE / NONE weights. */
    private boolean _postWeights = false;
    /** WSIGMOID: selects the LOG / LOG_MINUS variants. */
    private boolean _logout = false;
    /** WSIGMOID: selects the MINUS / LOG_MINUS variants. */
    private boolean _minusin = false;
    /** WDIVMM / WCEMM: integer variant code, see {@link #checkWDivMMType()} and {@link #checkWCeMMType()}. */
    private int _baseType = -1;
    /** WDIVMM: multiply (true) versus divide (false) variant for base types 1 and 2. */
    private boolean _mult = false;
    /** WDIVMM: selects the MULT_MINUS variants for base types 1 and 2. */
    private boolean _minus = false;
    /** WUMM: multiply (true) versus divide (false) variant. */
    private boolean _umult = false;
    /** WUMM: explicit unary operator; when null the operator is derived from {@link #_sop}. */
    private Types.OpOp1 _uop = null;
    /** WUMM: scalar binary operator used to derive the unary operator when {@link #_uop} is null. */
    private Types.OpOp2 _sop = null;

    /** Used by {@link #clone()} only; all state is copied in afterwards. */
    private QuaternaryOp() {
    }

    /**
     * Creates a four-input hop with a weights input and a post-weighting flag
     * (the flag is read when deriving the WSLOSS weights type).
     *
     * @param l    hop name
     * @param dt   output data type
     * @param vt   output value type
     * @param o    quaternary operation code
     * @param inX  input at position 0
     * @param inU  input at position 1
     * @param inV  input at position 2
     * @param inW  input at position 3 (must not be null)
     * @param post value for the post-weights flag
     */
    public QuaternaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp4 o, Hop inX, Hop inU, Hop inV, Hop inW, boolean post) {
        this(l, dt, vt, o, inX, inU, inV);
        this.getInput().add(3, inW);
        inW.getParent().add(this);
        this._postWeights = post;
    }

    /**
     * Creates a three-input hop with the two flags that are read when deriving
     * the WSIGMOID type.
     *
     * @param l     hop name
     * @param dt    output data type
     * @param vt    output value type
     * @param o     quaternary operation code
     * @param inX   input at position 0
     * @param inU   input at position 1
     * @param inV   input at position 2
     * @param flag1 value for the log-output flag
     * @param flag2 value for the minus-input flag
     */
    public QuaternaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp4 o, Hop inX, Hop inU, Hop inV, boolean flag1, boolean flag2) {
        this(l, dt, vt, o, inX, inU, inV);
        this._logout = flag1;
        this._minusin = flag2;
    }

    /**
     * Creates a hop with a base-type code and two flags (read when deriving the
     * WDIVMM and WCEMM types). The fourth input is only attached when non-null.
     *
     * <p>NOTE(review): the WDIVMM/WCEMM lop builders and
     * {@link #checkWDivMMType()} read input 3 unconditionally, so a null
     * {@code inW} looks usable only if a fourth input is attached before
     * {@link #constructLops()} runs - confirm against callers.
     *
     * @param l        hop name
     * @param dt       output data type
     * @param vt       output value type
     * @param o        quaternary operation code
     * @param inX      input at position 0
     * @param inU      input at position 1
     * @param inV      input at position 2
     * @param inW      optional input at position 3, may be null
     * @param baseType integer variant code
     * @param flag1    value for the multiply flag
     * @param flag2    value for the minus flag
     */
    public QuaternaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp4 o, Hop inX, Hop inU, Hop inV, Hop inW, int baseType, boolean flag1, boolean flag2) {
        this(l, dt, vt, o, inX, inU, inV);
        if (inW != null) {
            this.getInput().add(3, inW);
            inW.getParent().add(this);
        }
        this._baseType = baseType;
        this._mult = flag1;
        this._minus = flag2;
    }

    /**
     * Creates a three-input hop carrying the operator information that is read
     * when building the WUMM lops.
     *
     * @param l     hop name
     * @param dt    output data type
     * @param vt    output value type
     * @param o     quaternary operation code
     * @param inW   input at position 0
     * @param inU   input at position 1
     * @param inV   input at position 2
     * @param umult value for the multiply flag (false selects the divide variant)
     * @param uop   unary operator, or null to derive it from {@code sop}
     * @param sop   scalar binary operator used when {@code uop} is null
     */
    public QuaternaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp4 o, Hop inW, Hop inU, Hop inV, boolean umult, Types.OpOp1 uop, Types.OpOp2 sop) {
        this(l, dt, vt, o, inW, inU, inV);
        this._umult = umult;
        this._uop = uop;
        this._sop = sop;
    }

    /**
     * Creates a three-input hop and registers this hop as a parent of each input.
     *
     * @param l   hop name
     * @param dt  output data type
     * @param vt  output value type
     * @param o   quaternary operation code
     * @param inX input at position 0
     * @param inU input at position 1
     * @param inV input at position 2
     */
    public QuaternaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp4 o, Hop inX, Hop inU, Hop inV) {
        super(l, dt, vt);
        this._op = o;
        this.getInput().add(0, inX);
        this.getInput().add(1, inU);
        this.getInput().add(2, inV);
        inX.getParent().add(this);
        inU.getParent().add(this);
        inV.getParent().add(this);
    }

    /** Validates that this hop has exactly three or four inputs. */
    @Override
    public void checkArity() {
        HopsException.check(this._input.size() == 3 || this._input.size() == 4, this, "should have arity 3 or 4 but has arity %d", this._input.size());
    }

    /**
     * @return the quaternary operation code of this hop
     */
    public Types.OpOp4 getOp() {
        return this._op;
    }

    /** @return always {@code false} */
    @Override
    public boolean isGPUEnabled() {
        return false;
    }

    /** @return always {@code true} */
    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Builds (or returns the already built) lop for this hop. The lop family is
     * chosen by the operation code, the concrete lop by the execution type
     * returned from {@link #optFindExecType()}.
     *
     * @return the lop attached to this hop
     * @throws HopsException if the execution type or operation code is not
     *         supported, or if lop construction raises a {@link LopsException}
     */
    @Override
    public Lop constructLops() {
        // lops are constructed once and then reused
        if (this.getLops() != null) {
            return this.getLops();
        }
        try {
            Types.ExecType et = this.optFindExecType();
            switch (this._op) {
                case WSLOSS: {
                    WeightedSquaredLoss.WeightsType wtype = this.checkWeightsType();
                    if (et == Types.ExecType.CP) {
                        this.constructCPLopsWeightedSquaredLoss(wtype);
                    } else if (et == Types.ExecType.SPARK) {
                        this.constructSparkLopsWeightedSquaredLoss(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wsloss exec type: " + et);
                    }
                    break;
                }
                case WSIGMOID: {
                    WeightedSigmoid.WSigmoidType wtype = this.checkWSigmoidType();
                    if (et == Types.ExecType.CP) {
                        this.constructCPLopsWeightedSigmoid(wtype);
                    } else if (et == Types.ExecType.SPARK) {
                        this.constructSparkLopsWeightedSigmoid(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wsigmoid exec type: " + et);
                    }
                    break;
                }
                case WDIVMM: {
                    WeightedDivMM.WDivMMType wtype = this.checkWDivMMType();
                    if (et == Types.ExecType.CP) {
                        this.constructCPLopsWeightedDivMM(wtype);
                    } else if (et == Types.ExecType.SPARK) {
                        this.constructSparkLopsWeightedDivMM(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wdivmm exec type: " + et);
                    }
                    break;
                }
                case WCEMM: {
                    WeightedCrossEntropy.WCeMMType wtype = this.checkWCeMMType();
                    if (et == Types.ExecType.CP) {
                        this.constructCPLopsWeightedCeMM(wtype);
                    } else if (et == Types.ExecType.SPARK) {
                        this.constructSparkLopsWeightedCeMM(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wcemm exec type: " + et);
                    }
                    break;
                }
                case WUMM: {
                    WeightedUnaryMM.WUMMType wtype = this._umult ? WeightedUnaryMM.WUMMType.MULT : WeightedUnaryMM.WUMMType.DIV;
                    if (et == Types.ExecType.CP) {
                        this.constructCPLopsWeightedUMM(wtype);
                    } else if (et == Types.ExecType.SPARK) {
                        this.constructSparkLopsWeightedUMM(wtype);
                    } else {
                        throw new HopsException("Unsupported quaternaryop-wumm exec type: " + et);
                    }
                    break;
                }
                default: {
                    throw new HopsException(this.printErrorLocation() + "Unknown QuaternaryOp (" + this._op + ") while constructing Lops");
                }
            }
        }
        catch (LopsException e) {
            throw new HopsException(this.printErrorLocation() + "error constructing lops for QuaternaryOp.", e);
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    /** @return {@code "q(<op>)"} where {@code <op>} is the operation code */
    @Override
    public String getOpString() {
        return "q(" + this._op.toString() + ")";
    }

    /** @return always {@code true} */
    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /** Copies output dimensions and line numbers onto the lop and attaches it to this hop. */
    private void attachLop(Lop lop) {
        this.setOutputDimensions(lop);
        this.setLineNumbers(lop);
        this.setLops(lop);
    }

    /** True if U and V together fit the executor budget and each fits twice into the local budget. */
    private static boolean fitsMapSide(double m1Size, double m2Size, double memBudgetExec, double memBudgetLocal) {
        return m1Size + m2Size < memBudgetExec
            && 2.0 * m1Size < memBudgetLocal
            && 2.0 * m2Size < memBudgetLocal;
    }

    /** Value of the cacheU flag for the replication-based lops: U alone must fit both budgets. */
    private static boolean canCacheU(double m1Size, double memBudgetExec, double memBudgetLocal) {
        return !FORCE_REPLICATION && m1Size < memBudgetExec && 2.0 * m1Size < memBudgetLocal;
    }

    /** Value of the cacheV flag for the replication-based lops; if U is cached, V must fit next to it. */
    private static boolean canCacheV(boolean cacheU, double m1Size, double m2Size, double memBudgetExec, double memBudgetLocal) {
        double requiredExec = cacheU ? m1Size + m2Size : m2Size;
        return !FORCE_REPLICATION && requiredExec < memBudgetExec && 2.0 * m2Size < memBudgetLocal;
    }

    /** Unary operator for WUMM lops: the explicit one if set, otherwise POW2 for POW and MULT2 for anything else. */
    private Types.OpOp1 resolveUnaryOp() {
        if (this._uop != null) {
            return this._uop;
        }
        return this._sop == Types.OpOp2.POW ? Types.OpOp1.POW2 : Types.OpOp1.MULT2;
    }

    /** True if every input that is actually present reports dimensions below the threshold. */
    private boolean allInputDimsBelowThreshold() {
        for (Hop in : this.getInput()) {
            if (!in.areDimsBelowThreshold()) {
                return false;
            }
        }
        return true;
    }

    /** Builds the CP weighted-squared-loss lop from inputs 0-3 and attaches it. */
    private void constructCPLopsWeightedSquaredLoss(WeightedSquaredLoss.WeightsType wtype) {
        WeightedSquaredLoss wsloss = new WeightedSquaredLoss(
            this.getInput().get(0).constructLops(),
            this.getInput().get(1).constructLops(),
            this.getInput().get(2).constructLops(),
            this.getInput().get(3).constructLops(),
            this.getDataType(), this.getValueType(), wtype, Types.ExecType.CP);
        wsloss.setNumThreads(OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        this.attachLop(wsloss);
    }

    /**
     * Builds the Spark weighted-squared-loss lop. The map-side lop is used only
     * for weight types without four inputs whose U and V fit the budgets;
     * otherwise the replication-based lop is used.
     */
    private void constructSparkLopsWeightedSquaredLoss(WeightedSquaredLoss.WeightsType wtype) {
        double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
        double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
        Hop X = this.getInput().get(0);
        Hop U = this.getInput().get(1);
        Hop V = this.getInput().get(2);
        Hop W = this.getInput().get(3);
        double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2());
        double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2());
        boolean isMapWsloss = !wtype.hasFourInputs() && fitsMapSide(m1Size, m2Size, memBudgetExec, memBudgetLocal);
        Lop wsloss;
        if (!FORCE_REPLICATION && isMapWsloss) {
            wsloss = new WeightedSquaredLoss(X.constructLops(), U.constructLops(), V.constructLops(), W.constructLops(), Types.DataType.SCALAR, Types.ValueType.FP64, wtype, Types.ExecType.SPARK);
        } else {
            boolean cacheU = canCacheU(m1Size, memBudgetExec, memBudgetLocal);
            boolean cacheV = canCacheV(cacheU, m1Size, m2Size, memBudgetExec, memBudgetLocal);
            wsloss = new WeightedSquaredLossR(X.constructLops(), U.constructLops(), V.constructLops(), W.constructLops(), Types.DataType.SCALAR, Types.ValueType.FP64, wtype, cacheU, cacheV, Types.ExecType.SPARK);
        }
        this.attachLop(wsloss);
    }

    /** Builds the CP weighted-sigmoid lop from inputs 0-2 and attaches it. */
    private void constructCPLopsWeightedSigmoid(WeightedSigmoid.WSigmoidType wtype) {
        WeightedSigmoid wsig = new WeightedSigmoid(
            this.getInput().get(0).constructLops(),
            this.getInput().get(1).constructLops(),
            this.getInput().get(2).constructLops(),
            this.getDataType(), this.getValueType(), wtype, Types.ExecType.CP);
        wsig.setNumThreads(OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        this.attachLop(wsig);
    }

    /** Builds the Spark weighted-sigmoid lop (map-side if U and V fit the budgets, replication-based otherwise). */
    private void constructSparkLopsWeightedSigmoid(WeightedSigmoid.WSigmoidType wtype) {
        double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
        double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
        Hop X = this.getInput().get(0);
        Hop U = this.getInput().get(1);
        Hop V = this.getInput().get(2);
        double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2());
        double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2());
        boolean isMapWsig = fitsMapSide(m1Size, m2Size, memBudgetExec, memBudgetLocal);
        Lop wsigmoid;
        if (!FORCE_REPLICATION && isMapWsig) {
            wsigmoid = new WeightedSigmoid(X.constructLops(), U.constructLops(), V.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, Types.ExecType.SPARK);
        } else {
            boolean cacheU = canCacheU(m1Size, memBudgetExec, memBudgetLocal);
            boolean cacheV = canCacheV(cacheU, m1Size, m2Size, memBudgetExec, memBudgetLocal);
            wsigmoid = new WeightedSigmoidR(X.constructLops(), U.constructLops(), V.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, cacheU, cacheV, Types.ExecType.SPARK);
        }
        this.attachLop(wsigmoid);
    }

    /** Builds the CP weighted-divide-matrix-multiply lop from inputs 0-3 and attaches it. */
    private void constructCPLopsWeightedDivMM(WeightedDivMM.WDivMMType wtype) {
        WeightedDivMM wdiv = new WeightedDivMM(
            this.getInput().get(0).constructLops(),
            this.getInput().get(1).constructLops(),
            this.getInput().get(2).constructLops(),
            this.getInput().get(3).constructLops(),
            this.getDataType(), this.getValueType(), wtype, Types.ExecType.CP);
        wdiv.setNumThreads(OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        this.attachLop(wdiv);
    }

    /**
     * Builds the Spark weighted-divide-matrix-multiply lop. The map-side lop is
     * used only for types without four inputs, or with a scalar, whose U and V
     * fit the budgets; otherwise the replication-based lop is used.
     */
    private void constructSparkLopsWeightedDivMM(WeightedDivMM.WDivMMType wtype) {
        double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
        double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
        Hop W = this.getInput().get(0);
        Hop U = this.getInput().get(1);
        Hop V = this.getInput().get(2);
        Hop X = this.getInput().get(3);
        double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2());
        double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2());
        boolean isMapWdivmm = (!wtype.hasFourInputs() || wtype.hasScalar()) && fitsMapSide(m1Size, m2Size, memBudgetExec, memBudgetLocal);
        Lop wdivmm;
        if (!FORCE_REPLICATION && isMapWdivmm) {
            wdivmm = new WeightedDivMM(W.constructLops(), U.constructLops(), V.constructLops(), X.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, Types.ExecType.SPARK);
        } else {
            boolean cacheU = canCacheU(m1Size, memBudgetExec, memBudgetLocal);
            boolean cacheV = canCacheV(cacheU, m1Size, m2Size, memBudgetExec, memBudgetLocal);
            wdivmm = new WeightedDivMMR(W.constructLops(), U.constructLops(), V.constructLops(), X.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, cacheU, cacheV, Types.ExecType.SPARK);
        }
        this.attachLop(wdivmm);
    }

    /** Builds the CP weighted-cross-entropy lop from inputs 0-3 and attaches it. */
    private void constructCPLopsWeightedCeMM(WeightedCrossEntropy.WCeMMType wtype) {
        WeightedCrossEntropy wcemm = new WeightedCrossEntropy(
            this.getInput().get(0).constructLops(),
            this.getInput().get(1).constructLops(),
            this.getInput().get(2).constructLops(),
            this.getInput().get(3).constructLops(),
            this.getDataType(), this.getValueType(), wtype, Types.ExecType.CP);
        wcemm.setNumThreads(OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        this.attachLop(wcemm);
    }

    /** Builds the Spark weighted-cross-entropy lop (map-side if U and V fit the budgets, replication-based otherwise). */
    private void constructSparkLopsWeightedCeMM(WeightedCrossEntropy.WCeMMType wtype) {
        double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
        double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
        Hop X = this.getInput().get(0);
        Hop U = this.getInput().get(1);
        Hop V = this.getInput().get(2);
        Hop eps = this.getInput().get(3);
        double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2());
        double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2());
        boolean isMapWcemm = fitsMapSide(m1Size, m2Size, memBudgetExec, memBudgetLocal);
        Lop wcemm;
        if (!FORCE_REPLICATION && isMapWcemm) {
            wcemm = new WeightedCrossEntropy(X.constructLops(), U.constructLops(), V.constructLops(), eps.constructLops(), Types.DataType.SCALAR, Types.ValueType.FP64, wtype, Types.ExecType.SPARK);
        } else {
            boolean cacheU = canCacheU(m1Size, memBudgetExec, memBudgetLocal);
            boolean cacheV = canCacheV(cacheU, m1Size, m2Size, memBudgetExec, memBudgetLocal);
            wcemm = new WeightedCrossEntropyR(X.constructLops(), U.constructLops(), V.constructLops(), eps.constructLops(), Types.DataType.SCALAR, Types.ValueType.FP64, wtype, cacheU, cacheV, Types.ExecType.SPARK);
        }
        this.attachLop(wcemm);
    }

    /** Builds the CP weighted-unary-matrix-multiply lop from inputs 0-2 and attaches it. */
    private void constructCPLopsWeightedUMM(WeightedUnaryMM.WUMMType wtype) {
        Types.OpOp1 uop = this.resolveUnaryOp();
        WeightedUnaryMM wumm = new WeightedUnaryMM(
            this.getInput().get(0).constructLops(),
            this.getInput().get(1).constructLops(),
            this.getInput().get(2).constructLops(),
            this.getDataType(), this.getValueType(), wtype, uop, Types.ExecType.CP);
        wumm.setNumThreads(OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        this.attachLop(wumm);
    }

    /** Builds the Spark weighted-unary-matrix-multiply lop (map-side if U and V fit the budgets, replication-based otherwise). */
    private void constructSparkLopsWeightedUMM(WeightedUnaryMM.WUMMType wtype) {
        Types.OpOp1 uop = this.resolveUnaryOp();
        double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
        double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
        Hop X = this.getInput().get(0);
        Hop U = this.getInput().get(1);
        Hop V = this.getInput().get(2);
        double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2());
        double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2());
        boolean isMapWumm = fitsMapSide(m1Size, m2Size, memBudgetExec, memBudgetLocal);
        Lop wumm;
        if (!FORCE_REPLICATION && isMapWumm) {
            wumm = new WeightedUnaryMM(X.constructLops(), U.constructLops(), V.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, uop, Types.ExecType.SPARK);
        } else {
            boolean cacheU = canCacheU(m1Size, memBudgetExec, memBudgetLocal);
            boolean cacheV = canCacheV(cacheU, m1Size, m2Size, memBudgetExec, memBudgetLocal);
            wumm = new WeightedUnaryMMR(X.constructLops(), U.constructLops(), V.constructLops(), Types.DataType.MATRIX, Types.ValueType.FP64, wtype, uop, cacheU, cacheV, Types.ExecType.SPARK);
        }
        this.attachLop(wumm);
    }

    /**
     * Derives the WSLOSS weights type: a non-literal fourth input yields POST
     * or PRE depending on the post-weights flag; a literal fourth input yields
     * POST_NZ if the flag is set and NONE otherwise.
     */
    private WeightedSquaredLoss.WeightsType checkWeightsType() {
        boolean literalWeights = this.getInput().get(3) instanceof LiteralOp;
        if (!literalWeights) {
            return this._postWeights ? WeightedSquaredLoss.WeightsType.POST : WeightedSquaredLoss.WeightsType.PRE;
        }
        return this._postWeights ? WeightedSquaredLoss.WeightsType.POST_NZ : WeightedSquaredLoss.WeightsType.NONE;
    }

    /** Derives the WSIGMOID type from the log-output and minus-input flags. */
    private WeightedSigmoid.WSigmoidType checkWSigmoidType() {
        if (this._logout) {
            return this._minusin ? WeightedSigmoid.WSigmoidType.LOG_MINUS : WeightedSigmoid.WSigmoidType.LOG;
        }
        return this._minusin ? WeightedSigmoid.WSigmoidType.MINUS : WeightedSigmoid.WSigmoidType.BASIC;
    }

    /**
     * Derives the WDIVMM type from the base-type code (0: basic, 1: left,
     * 2: right, 3: left with eps, 4: right with eps), the data type of the
     * fourth input and the multiply/minus flags.
     *
     * @throws HopsException if the base-type code is not one of 0-4; previously
     *         this returned null, which only surfaced later as a
     *         NullPointerException during lop construction
     */
    private WeightedDivMM.WDivMMType checkWDivMMType() {
        switch (this._baseType) {
            case 0: {
                return WeightedDivMM.WDivMMType.MULT_BASIC;
            }
            case 1: {
                if (this.getInput().get(3).getDataType() == Types.DataType.MATRIX) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_4_LEFT;
                }
                if (this._minus) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_LEFT;
                }
                return this._mult ? WeightedDivMM.WDivMMType.MULT_LEFT : WeightedDivMM.WDivMMType.DIV_LEFT;
            }
            case 2: {
                if (this.getInput().get(3).getDataType() == Types.DataType.MATRIX) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_4_RIGHT;
                }
                if (this._minus) {
                    return WeightedDivMM.WDivMMType.MULT_MINUS_RIGHT;
                }
                return this._mult ? WeightedDivMM.WDivMMType.MULT_RIGHT : WeightedDivMM.WDivMMType.DIV_RIGHT;
            }
            case 3: {
                return WeightedDivMM.WDivMMType.DIV_LEFT_EPS;
            }
            case 4: {
                return WeightedDivMM.WDivMMType.DIV_RIGHT_EPS;
            }
            default: {
                throw new HopsException(this.printErrorLocation() + "Unknown WDivMM base type (" + this._baseType + ") for QuaternaryOp");
            }
        }
    }

    /** Derives the WCEMM type: BASIC_EPS for base type 1, BASIC otherwise. */
    private WeightedCrossEntropy.WCeMMType checkWCeMMType() {
        return this._baseType == 1 ? WeightedCrossEntropy.WCeMMType.BASIC_EPS : WeightedCrossEntropy.WCeMMType.BASIC;
    }

    /**
     * Estimates the output memory: a fixed 8.0 for WSLOSS and WCEMM (presumably
     * the size of one double - confirm), a sparsity-based size estimate for
     * WSIGMOID, WDIVMM and WUMM, and 0.0 for any other operation code.
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        switch (this._op) {
            case WSLOSS:
            case WCEMM: {
                return 8.0;
            }
            case WSIGMOID:
            case WDIVMM:
            case WUMM: {
                double sp = OptimizerUtils.getSparsity(dim1, dim2, nnz);
                return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sp);
            }
            default: {
                return 0.0;
            }
        }
    }

    /** @return always 0.0; no intermediate memory is accounted for */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0.0;
    }

    /**
     * Infers the output characteristics from the memoized input statistics.
     *
     * @param memo memo table holding input statistics
     * @return the inferred characteristics, or {@code null} for WSLOSS and
     *         WCEMM (handled alike because {@link #computeOutputMemEstimate}
     *         gives both the same fixed-size estimate)
     * @throws RuntimeException for operation codes that are not handled
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics ret = null;
        switch (this._op) {
            case WSLOSS:
            case WCEMM: {
                // nothing to infer; WCEMM previously fell into the default branch and threw
                break;
            }
            case WSIGMOID:
            case WUMM: {
                // output takes the dimensions and non-zeros of input 0
                DataCharacteristics mcW = memo.getAllInputStats(this.getInput().get(0));
                ret = new MatrixCharacteristics(mcW.getRows(), mcW.getCols(), -1, mcW.getNonZeros());
                break;
            }
            case WDIVMM: {
                if (this._baseType == 0) {
                    ret = memo.getAllInputStats(this.getInput().get(0));
                } else if (this._baseType == 1 || this._baseType == 3) {
                    // left variants: characteristics of V with unknown non-zeros
                    DataCharacteristics mcV = memo.getAllInputStats(this.getInput().get(2));
                    ret = mcV.setNonZeros(-1L);
                } else {
                    // remaining variants: characteristics of U with unknown non-zeros
                    DataCharacteristics mcU = memo.getAllInputStats(this.getInput().get(1));
                    ret = mcU.setNonZeros(-1L);
                }
                break;
            }
            default: {
                throw new RuntimeException("Memory for operation (" + this._op + ") can not be estimated.");
            }
        }
        return ret;
    }

    /**
     * Determines the execution type: the forced type if one is set, otherwise
     * either the memory-estimate based choice or, for non-memory-based
     * optimization levels, CP when all inputs are below the dimension
     * threshold and SPARK otherwise.
     *
     * <p>The threshold check covers only the inputs that are present. The
     * earlier code read input 3 unconditionally and failed with an
     * {@link IndexOutOfBoundsException} for three-input hops.
     *
     * @return the selected execution type
     */
    @Override
    protected Types.ExecType optFindExecType() {
        this.checkAndSetForcedPlatform();
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = this.findExecTypeByMemEstimate();
            } else if (this.allInputDimsBelowThreshold()) {
                this._etype = Types.ExecType.CP;
            } else {
                this._etype = Types.ExecType.SPARK;
            }
            this.checkAndSetInvalidCPDimsAndSize();
        }
        this.setRequiresRecompileIfNecessary();
        return this._etype;
    }

    /**
     * Refreshes the output dimensions and non-zeros from the inputs. WSIGMOID,
     * WUMM and basic WDIVMM copy input 0; the other WDIVMM variants copy the
     * dimensions of V (base types 1 and 3) or U and reset the non-zeros to -1.
     * Other operation codes are left untouched.
     */
    @Override
    public void refreshSizeInformation() {
        switch (this._op) {
            case WSIGMOID:
            case WUMM: {
                Hop inW = this.getInput().get(0);
                this.setDim1(inW.getDim1());
                this.setDim2(inW.getDim2());
                this.setNnz(inW.getNnz());
                break;
            }
            case WDIVMM: {
                if (this._baseType == 0) {
                    Hop inW = this.getInput().get(0);
                    this.setDim1(inW.getDim1());
                    this.setDim2(inW.getDim2());
                    this.setNnz(inW.getNnz());
                } else if (this._baseType == 1 || this._baseType == 3) {
                    Hop inV = this.getInput().get(2);
                    this.setDim1(inV.getDim1());
                    this.setDim2(inV.getDim2());
                    this.setNnz(-1L);
                } else {
                    Hop inU = this.getInput().get(1);
                    this.setDim1(inU.getDim1());
                    this.setDim2(inU.getDim2());
                    this.setNnz(-1L);
                }
                break;
            }
            default: {
                // WSLOSS and all remaining codes: nothing to refresh
                break;
            }
        }
    }

    /**
     * Creates a copy of this hop carrying the same operation code, flags,
     * operators and thread limit.
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        QuaternaryOp ret = new QuaternaryOp();
        ret.clone(this, false);
        ret._op = this._op;
        ret._postWeights = this._postWeights;
        ret._logout = this._logout;
        ret._minusin = this._minusin;
        ret._baseType = this._baseType;
        ret._mult = this._mult;
        ret._minus = this._minus;
        ret._umult = this._umult;
        ret._uop = this._uop;
        ret._sop = this._sop;
        ret._maxNumThreads = this._maxNumThreads;
        return ret;
    }

    /**
     * Compares this hop with another for structural equivalence: same
     * operation code, the very same input hop instances (reference equality)
     * in the same positions, and identical flags, operators and thread limit.
     *
     * @param that hop to compare against
     * @return {@code true} if both hops are equivalent quaternary operations
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof QuaternaryOp)) {
            return false;
        }
        QuaternaryOp other = (QuaternaryOp)that;
        int numInputs = this.getInput().size();
        if (this._op != other._op || numInputs != other.getInput().size()) {
            return false;
        }
        // inputs are compared by identity, position by position
        for (int i = 0; i < numInputs; i++) {
            if (this.getInput().get(i) != other.getInput().get(i)) {
                return false;
            }
        }
        return this._postWeights == other._postWeights
            && this._logout == other._logout
            && this._minusin == other._minusin
            && this._baseType == other._baseType
            && this._mult == other._mult
            && this._minus == other._minus
            && this._umult == other._umult
            && this._uop == other._uop
            && this._sop == other._sop
            && this._maxNumThreads == other._maxNumThreads;
    }
}

