/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;

/**
 * Statement-block rewrite that marks hops for compression.
 *
 * <p>For every last-level statement block, and only when compression is enabled in the
 * configuration, each hop DAG is walked bottom-up. {@code setRequiresCompression()} is called on
 * hops that satisfy the condition belonging to the configured mode:
 * <ul>
 *   <li>{@code TRUE}: sufficiently large non-scalar persistent reads and transform-encode outputs;</li>
 *   <li>{@code AUTO}: the {@code TRUE} candidates, restricted to Spark execution mode, a memory
 *       estimate at or above the local memory budget, known dimensions, out-of-core size, not
 *       ultra-sparse, and a passing whole-program probe;</li>
 *   <li>{@code COST}: a wider set of candidate operations with known dimensions whose downstream
 *       use passes a cost-oriented whole-program probe.</li>
 * </ul>
 */
public class RewriteCompressedReblock
extends StatementBlockRewriteRule {
    private static final Log LOG = LogFactory.getLog(RewriteCompressedReblock.class.getName());

    /** Prefix of the synthetic names the probe uses to track compressed intermediates by hop ID. */
    private static final String TMP_PREFIX = "__cmtx";

    /** Largest long whose square still fits in a long, i.e. floor(sqrt(Long.MAX_VALUE)). */
    private static final long MAX_SQUARABLE = 3037000499L;

    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Injects compression directives into the hops of a last-level statement block.
     *
     * @param sb    the statement block to rewrite
     * @param state rewrite status (unused by this rule)
     * @return a singleton list containing {@code sb}, modified in place
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
        if (!HopRewriteUtils.isLastLevelStatementBlock(sb) || sb.getHops() == null) {
            return Arrays.asList(sb);
        }
        Compression.CompressConfig compress = ConfigurationManager.getCompressConfig();
        if (compress.isEnabled()) {
            // The DAG walk relies on the visit flags, so start and finish from a clean state.
            Hop.resetVisitStatus(sb.getHops());
            for (Hop h : sb.getHops()) {
                injectCompressionDirective(h, compress, sb.getDMLProg());
            }
            Hop.resetVisitStatus(sb.getHops());
        }
        return Arrays.asList(sb);
    }

    /** This rule does not operate on lists of statement blocks; returns {@code sbs} unchanged. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
        return sbs;
    }

    /**
     * Post-order walk that marks {@code hop} (and, recursively, its inputs) for compression
     * according to {@code compress}.
     *
     * <p>Hops already visited or already marked are skipped. Scalar persistent reads, transient
     * reads/writes and literals are never marked.
     */
    private static void injectCompressionDirective(Hop hop, Compression.CompressConfig compress, DMLProgram prog) {
        if (hop.isVisited() || hop.requiresCompression()) {
            return;
        }
        for (Hop hi : hop.getInput()) {
            injectCompressionDirective(hi, compress, prog);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.DataType.SCALAR)
            || HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD, Types.OpOpData.TRANSIENTWRITE)
            || hop instanceof LiteralOp) {
            return;
        }
        switch (compress) {
            case TRUE:
                if (satisfiesCompressionCondition(hop)) {
                    hop.setRequiresCompression();
                }
                break;
            case AUTO:
                if (OptimizerUtils.isSparkExecutionMode() && satisfiesAutoCompressionCondition(hop, prog)) {
                    hop.setRequiresCompression();
                }
                break;
            case COST:
                if (satisfiesCostCompressionCondition(hop, prog)) {
                    hop.setRequiresCompression();
                }
                break;
        }
        if (satisfiesDeCompressionCondition(hop)) {
            hop.setRequiresDeCompression();
        }
        hop.setVisited();
    }

    /**
     * Checks whether the (possibly partially known) output size of {@code hop} makes compression
     * worthwhile.
     *
     * <ul>
     *   <li>Known columns: true if rows^2 >= columns * 1024, or if sparsity &lt; 1e-4 with more than
     *       100 columns. With unknown rows (-1) only the sparsity branch can succeed.</li>
     *   <li>Known rows only: true if there are more than 10000 rows.</li>
     *   <li>Both unknown: true.</li>
     * </ul>
     */
    public static boolean satisfiesSizeConstraintsForCompression(Hop hop) {
        if (hop.getDim2() >= 1L) {
            final long rows = hop.getDim1();
            final long cols = hop.getDim2();
            // rows * rows overflows a long beyond MAX_SQUARABLE. In that range the square certainly
            // exceeds cols << 10, so short-circuit instead of comparing against a wrapped product.
            final boolean enoughRows = rows > MAX_SQUARABLE || cols << 10 <= rows * rows;
            return enoughRows || hop.getSparsity() < 1.0E-4 && cols > 100L;
        }
        if (hop.getDim1() >= 1L) {
            return hop.getDim1() > 10000L;
        }
        return true;
    }

    /**
     * Compression condition of the {@code TRUE} mode: a size-qualifying non-scalar persistent
     * read, or a transform-encode output.
     */
    public static boolean satisfiesCompressionCondition(Hop hop) {
        boolean satisfies = false;
        if (satisfiesSizeConstraintsForCompression(hop)) {
            satisfies |= HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD) && !hop.isScalar();
            satisfies |= HopRewriteUtils.isTransformEncode(hop);
        }
        return satisfies;
    }

    /**
     * Wider candidate set used by the {@code COST} mode. Among size-qualifying hops it accepts ctable,
     * non-scalar persistent reads, round/floor/not/ceil, and comparison/logical/modulus binaries.
     */
    public static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
        boolean satisfies = false;
        if (satisfiesSizeConstraintsForCompression(hop)) {
            satisfies |= HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE)
                && hop.getInput(0).getDataType().isMatrix()
                && hop.getInput(1).getDataType().isMatrix();
            satisfies |= HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD) && !hop.isScalar();
            satisfies |= HopRewriteUtils.isUnary(hop, Types.OpOp1.ROUND, Types.OpOp1.FLOOR, Types.OpOp1.NOT, Types.OpOp1.CEIL);
            satisfies |= HopRewriteUtils.isBinary(hop, Types.OpOp2.EQUAL, Types.OpOp2.NOTEQUAL, Types.OpOp2.LESS,
                Types.OpOp2.LESSEQUAL, Types.OpOp2.GREATER, Types.OpOp2.GREATEREQUAL, Types.OpOp2.AND,
                Types.OpOp2.OR, Types.OpOp2.MODULUS);
            // NOTE(review): this unconditional CTABLE check subsumes the matrix-input guard above,
            // so ctable with non-matrix inputs also qualifies. Confirm which behavior is intended.
            satisfies |= HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE);
        }
        if (LOG.isDebugEnabled() && satisfies) {
            LOG.debug("Operation Satisfies: " + hop);
        }
        return satisfies;
    }

    /** Decompression directive hook; currently never requests decompression. */
    private static boolean satisfiesDeCompressionCondition(Hop hop) {
        return false;
    }

    /** True if the estimated partitioned size of {@code hop} exceeds the Spark data memory budget. */
    private static boolean outOfCore(Hop hop) {
        return OptimizerUtils.estimatePartitionedSizeExactSparsity(hop)
            > SparkExecutionContext.getDataMemoryBudget(true, true);
    }

    /** True if the sparsity of {@code hop} is below 4e-5. */
    private static boolean ultraSparse(Hop hop) {
        return OptimizerUtils.getSparsity(hop) < 4.0E-5;
    }

    /**
     * Compression condition of the {@code AUTO} mode.
     *
     * <p>Requires the {@code TRUE}-mode condition, a memory estimate not below the local memory
     * budget, known dimensions, out-of-core size, not ultra-sparse, and a valid program probe.
     */
    private static boolean satisfiesAutoCompressionCondition(Hop hop, DMLProgram prog) {
        // Written as !(a >= b) so that a NaN estimate is also rejected.
        if (!satisfiesCompressionCondition(hop) || !(hop.getMemEstimate() >= OptimizerUtils.getLocalMemBudget())) {
            return false;
        }
        if (hop.dimsKnown(true) && outOfCore(hop) && !ultraSparse(hop)) {
            return analyseProgram(hop, prog).isValidAutoCompression();
        }
        return false;
    }

    /**
     * Compression condition of the {@code COST} mode.
     *
     * <p>Short-circuits so the whole-program probe only runs for actual candidates. The probe is
     * expensive, and it resets visit flags on every last-level DAG of the program.
     */
    private static boolean satisfiesCostCompressionCondition(Hop hop, DMLProgram prog) {
        return satisfiesAggressiveCompressionCondition(hop)
            && hop.dimsKnown(false)
            && analyseProgram(hop, prog).isValidAggressiveCompression();
    }

    /** Runs a {@link ProbeStatus} over all top-level statement blocks of {@code prog}, starting at {@code hop}. */
    private static ProbeStatus analyseProgram(Hop hop, DMLProgram prog) {
        ProbeStatus status = new ProbeStatus(hop.getHopID(), prog);
        for (StatementBlock sb : prog.getStatementBlocks()) {
            status.rAnalyzeProgram(sb);
        }
        return status;
    }

    /**
     * Whole-program probe that follows the output of the start hop (by hop ID) through transient
     * writes/reads and function calls. Along the way it counts operations that would run on
     * compressed versus decompressed data.
     */
    private static class ProbeStatus {
        private final long startHopID;
        private final DMLProgram prog;
        private int numberCompressedOpsExecuted = 0;
        private int numberDecompressedOpsExecuted = 0;
        private int inefficientSupportedOpsExecuted = 0;
        private boolean foundStart = false;
        private boolean usedInLoop = false;
        private boolean condUpdate = false;
        private boolean nonApplicable = false;
        /** Function keys already analyzed (memoization; also prevents infinite recursion). */
        private final HashSet<String> procFn = new HashSet<>();
        /** Variable names and TMP_PREFIX hop names currently known to hold compressed data. */
        private final HashSet<String> compMtx = new HashSet<>();

        private ProbeStatus(long hopID, DMLProgram p) {
            this.startHopID = hopID;
            this.prog = p;
        }

        /** Child probe for a function body: copies flags and processed functions, but not counters or compMtx. */
        private ProbeStatus(ProbeStatus status) {
            this.startHopID = status.startHopID;
            this.prog = status.prog;
            this.foundStart = status.foundStart;
            this.usedInLoop = status.usedInLoop;
            this.condUpdate = status.condUpdate;
            this.nonApplicable = status.nonApplicable;
            this.procFn.addAll(status.procFn);
        }

        /**
         * Recursively analyzes a statement block. Loops reading a compressed variable set
         * {@code usedInLoop}; if-blocks updating one set {@code condUpdate}. Last-level blocks have
         * their hop DAGs analyzed, after which per-block temporary hop names are dropped.
         */
        private void rAnalyzeProgram(StatementBlock sb) {
            if (sb instanceof FunctionStatementBlock) {
                FunctionStatementBlock fsb = (FunctionStatementBlock) sb;
                FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
                for (StatementBlock csb : fstmt.getBody()) {
                    this.rAnalyzeProgram(csb);
                }
            } else if (sb instanceof WhileStatementBlock) {
                WhileStatementBlock wsb = (WhileStatementBlock) sb;
                WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
                for (StatementBlock csb : wstmt.getBody()) {
                    this.rAnalyzeProgram(csb);
                }
                if (wsb.variablesRead().containsAnyName(this.compMtx)) {
                    this.usedInLoop = true;
                }
            } else if (sb instanceof IfStatementBlock) {
                IfStatementBlock isb = (IfStatementBlock) sb;
                IfStatement istmt = (IfStatement) isb.getStatement(0);
                for (StatementBlock csb : istmt.getIfBody()) {
                    this.rAnalyzeProgram(csb);
                }
                for (StatementBlock csb : istmt.getElseBody()) {
                    this.rAnalyzeProgram(csb);
                }
                if (isb.variablesUpdated().containsAnyName(this.compMtx)) {
                    this.condUpdate = true;
                }
            } else if (sb instanceof ForStatementBlock) {
                ForStatementBlock fsb = (ForStatementBlock) sb;
                ForStatement fstmt = (ForStatement) fsb.getStatement(0);
                for (StatementBlock csb : fstmt.getBody()) {
                    this.rAnalyzeProgram(csb);
                }
                if (fsb.variablesRead().containsAnyName(this.compMtx)) {
                    this.usedInLoop = true;
                }
            } else if (sb.getHops() != null) {
                ArrayList<Hop> roots = sb.getHops();
                Hop.resetVisitStatus(roots);
                for (Hop root : roots) {
                    this.rAnalyzeHopDag(root);
                }
                // Hop-ID based names are only meaningful within one DAG.
                this.compMtx.removeIf(n -> n.startsWith(TMP_PREFIX));
                Hop.resetVisitStatus(roots);
            }
        }

        /** Post-order walk of a hop DAG that propagates compressed status and classifies consumers. */
        private void rAnalyzeHopDag(Hop current) {
            if (current.isVisited()) {
                return;
            }
            for (Hop input : current.getInput()) {
                this.rAnalyzeHopDag(input);
            }
            if (current.getHopID() == this.startHopID) {
                this.compMtx.add(getTmpName(current));
                this.foundStart = true;
            }
            if (HopRewriteUtils.isData(current, Types.OpOpData.TRANSIENTWRITE)
                && this.compMtx.contains(getTmpName(current.getInput().get(0)))) {
                // Compressed intermediate bound to a variable name.
                this.compMtx.add(current.getName());
            } else if (HopRewriteUtils.isData(current, Types.OpOpData.TRANSIENTREAD)
                && this.compMtx.contains(current.getName())) {
                // Compressed variable read back into this DAG.
                this.compMtx.add(getTmpName(current));
            } else if (this.hasCompressedInput(current)) {
                if (current instanceof FunctionOp) {
                    this.handleFunctionOps(current);
                } else {
                    this.handleApplicableOps(current);
                }
            }
            current.setVisited();
        }

        /** True if any direct input of {@code hop} is tracked as compressed. */
        private boolean hasCompressedInput(Hop hop) {
            if (this.compMtx.isEmpty()) {
                return false;
            }
            for (Hop input : hop.getInput()) {
                if (this.compMtx.contains(getTmpName(input))) {
                    return true;
                }
            }
            return false;
        }

        private static String getTmpName(Hop hop) {
            return TMP_PREFIX + hop.getHopID();
        }

        private boolean isCompressed(Hop hop) {
            return this.compMtx.contains(getTmpName(hop));
        }

        /**
         * Analyzes a called function once per function key. Compressed arguments are mapped to
         * parameters, the body is analyzed in a child probe, and its flags and counters are merged
         * back. Compressed function outputs are mapped to the call's output variables.
         */
        private void handleFunctionOps(Hop current) {
            FunctionOp fop = (FunctionOp) current;
            String fkey = fop.getFunctionKey();
            if (this.procFn.contains(fkey)) {
                return;
            }
            this.procFn.add(fkey);
            FunctionStatementBlock fsb = this.prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            ProbeStatus status2 = new ProbeStatus(this);
            for (int i = 0; i < fop.getInput().size(); ++i) {
                if (this.compMtx.contains(getTmpName(fop.getInput().get(i)))) {
                    status2.compMtx.add(fstmt.getInputParams().get(i).getName());
                }
            }
            status2.rAnalyzeProgram(fsb);
            this.foundStart |= status2.foundStart;
            this.usedInLoop |= status2.usedInLoop;
            this.condUpdate |= status2.condUpdate;
            this.nonApplicable |= status2.nonApplicable;
            this.numberCompressedOpsExecuted += status2.numberCompressedOpsExecuted;
            this.numberDecompressedOpsExecuted += status2.numberDecompressedOpsExecuted;
            // Previously dropped: inefficient ops inside the callee must count toward the cost decision.
            this.inefficientSupportedOpsExecuted += status2.inefficientSupportedOpsExecuted;
            String[] outputs = fop.getOutputVariableNames();
            for (int i = 0; i < outputs.length; ++i) {
                if (status2.compMtx.contains(fstmt.getOutputParams().get(i).getName())) {
                    this.compMtx.add(outputs[i]);
                }
            }
        }

        /**
         * Classifies a non-function consumer of compressed data. Supported ops increment the
         * compressed-op counter, others the decompressed-op counter. Ops whose output stays
         * compressed are tracked further.
         */
        private void handleApplicableOps(Hop current) {
            // Supported on compressed input, uncompressed output.
            boolean compUCOut = false;
            compUCOut |= current instanceof AggBinaryOp;
            compUCOut |= HopRewriteUtils.isBinaryMatrixColVectorOperation(current);
            boolean isAggregate = HopRewriteUtils.isAggUnaryOp(current, Types.AggOp.SUM, Types.AggOp.SUM_SQ,
                Types.AggOp.MIN, Types.AggOp.MAX, Types.AggOp.MEAN);
            // Aggregates whose output has < 2 columns and >= 1000 rows (presumably row aggregates).
            if (isAggregate && current.getDim2() < 2L && current.getDim1() >= 1000L) {
                ++this.inefficientSupportedOpsExecuted;
            }
            // Supported with compressed output.
            boolean compCOut = false;
            compCOut |= HopRewriteUtils.isBinaryMatrixScalarOperation(current);
            compCOut |= HopRewriteUtils.isBinaryMatrixRowVectorOperation(current);
            compCOut |= current instanceof AggBinaryOp && this.isCompressed(current.getInput().get(0));
            if (compCOut) {
                compUCOut = false;
            } else {
                compUCOut |= isAggregate;
            }
            compCOut |= HopRewriteUtils.isBinary(current, Types.OpOp2.CBIND);
            boolean metaOp = HopRewriteUtils.isUnary(current, Types.OpOp1.NROW, Types.OpOp1.NCOL);
            if (HopRewriteUtils.isTernary(current, Types.OpOp3.CTABLE)) {
                // ctable is weighted as several compressed ops (+4 here, +1 below).
                this.numberCompressedOpsExecuted += 4;
                compCOut = true;
            }
            boolean applicable = compUCOut || compCOut || metaOp;
            if (applicable) {
                ++this.numberCompressedOpsExecuted;
            } else {
                LOG.warn("Decompession op: " + current);
                ++this.numberDecompressedOpsExecuted;
            }
            this.nonApplicable |= !applicable;
            if (compCOut) {
                this.compMtx.add(getTmpName(current));
            }
        }

        /** AUTO mode: start found, used in a loop, not conditionally updated, and all consumers supported. */
        private boolean isValidAutoCompression() {
            return this.foundStart && this.usedInLoop && !this.condUpdate && !this.nonApplicable;
        }

        /**
         * COST mode: more compressed than inefficient ops, used in a loop or more than 3 compressed
         * ops, and no decompressing op.
         */
        private boolean isValidAggressiveCompression() {
            if (LOG.isDebugEnabled()) {
                LOG.debug(this.toString());
            }
            return this.inefficientSupportedOpsExecuted < this.numberCompressedOpsExecuted
                && (this.usedInLoop || this.numberCompressedOpsExecuted > 3)
                && this.numberDecompressedOpsExecuted < 1;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Compressed ProbeStatus : hopID =" + this.startHopID);
            sb.append("\n CLA Ops         : " + this.numberCompressedOpsExecuted);
            sb.append("\n Decompress Ops  : " + this.numberDecompressedOpsExecuted);
            sb.append("\n Inefficient Ops : " + this.inefficientSupportedOpsExecuted);
            sb.append("\n foundStart " + this.foundStart + " , inLoop :" + this.usedInLoop + " , condUpdate : " + this.condUpdate + " , nonApplicable : " + this.nonApplicable);
            sb.append("\n compressed Matrix: " + this.compMtx);
            sb.append("\n Prog Fn " + this.procFn);
            return sb.toString();
        }
    }
}

