/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.calcite.plan.PlanTooComplexError;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.ConflictDetectionHelper;
import org.apache.calcite.rel.rules.HyperEdge;
import org.apache.calcite.rel.rules.HyperGraph;
import org.apache.calcite.rel.rules.LongBitmap;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.trace.CalciteTrace;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;

/**
 * Join-reorder enumerator over a {@link HyperGraph}.
 *
 * <p>It enumerates pairs of a connected subgraph ("csg") and a connected
 * complement ("cmp"). For every pair that can be joined, it builds a join
 * (both input orders when the join type is commutative). It keeps the plan
 * with the lower cumulative cost in a dp table keyed by the bitmap of the
 * nodes the plan covers.
 *
 * <p>If the dp table would grow beyond {@code bloat} entries, enumeration is
 * abandoned. In that case {@link #getBestPlan()} may return {@code null}.
 */
public class DpHyp {
    private static final Logger LOGGER = CalciteTrace.getDpHypJoinReorderTracer();

    /** Copy of the input hypergraph; edge bookkeeping during enumeration is done on this copy. */
    protected final HyperGraph hyperGraph;
    /** Best plan found so far for each enumerated subgraph, keyed by node bitmap. */
    private final Map<Long, RelNode> dpTable;
    /** For each dp table entry, the order in which the original inputs appear in that plan. */
    private final Map<Long, ImmutableList<HyperGraph.NodeState>> resultInputOrder;
    protected final RelBuilder builder;
    private final RelMetadataQuery mq;
    /** Maximum number of dp table entries before enumeration is abandoned. */
    private final int bloat;

    /**
     * Creates an enumerator.
     *
     * @param hyperGraph       graph to reorder; it is copied, not mutated
     * @param builder          builder used to create join and project nodes
     * @param relMetadataQuery metadata used to cost candidate plans
     * @param bloat            dp table size limit; exceeding it aborts enumeration
     */
    public DpHyp(HyperGraph hyperGraph, RelBuilder builder, RelMetadataQuery relMetadataQuery, int bloat) {
        this.hyperGraph = hyperGraph.copy(hyperGraph.getTraitSet(), hyperGraph.getInputs());
        this.dpTable = new HashMap<>();
        this.resultInputOrder = new HashMap<>();
        this.builder = builder;
        this.mq = relMetadataQuery;
        this.bloat = bloat;
    }

    /**
     * Seeds the dp table with every single input, then enumerates all csg-cmp pairs.
     *
     * <p>If the dp table exceeds {@code bloat} entries, the enumeration stops early
     * and an error is logged. Whatever was already recorded stays in the table.
     */
    public void startEnumerateJoin() {
        int size = this.hyperGraph.getInputs().size();
        for (int i = 0; i < size; ++i) {
            long singleNode = LongBitmap.newBitmap(i);
            RelNode input = this.hyperGraph.getInput(i);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Initialize the dp table. Node {{}} is:\n {}", i, RelOptUtil.toString(input));
            }
            this.dpTable.put(singleNode, input);
            this.resultInputOrder.put(singleNode, ImmutableList.of(new HyperGraph.NodeState(i, true)));
            this.hyperGraph.initEdgeBitMap(singleNode);
        }
        try {
            // Walk start nodes from the second-to-last down to 0. csg - 1 masks
            // every bit below the start node's bit (assuming newBitmap(i) == 1L << i),
            // so lower-indexed nodes are forbidden while growing this csg.
            for (int i = size - 2; i >= 0; --i) {
                long csg = LongBitmap.newBitmap(i);
                long forbidden = csg - 1L;
                this.emitCsg(csg);
                this.enumerateCsgRec(csg, forbidden);
            }
        } catch (PlanTooComplexError e) {
            // Deliberate best-effort: keep the partial dp table; getBestPlan() reports the failure.
            LOGGER.error("The dp table is too large (more than {} entries), and the enumeration ends automatically.",
                    this.bloat);
        }
    }

    /**
     * For a fixed {@code csg}, emits pairs with each neighbor (as cmp) and with the
     * complements grown from each neighbor.
     */
    private void emitCsg(long csg) {
        long forbidden = csg | LongBitmap.getBvBitmap(csg);
        long neighbors = this.hyperGraph.getNeighborBitmap(csg, forbidden);
        for (long cmp : new LongBitmap.ReverseIterator(neighbors)) {
            List<HyperEdge> edges = this.hyperGraph.connectCsgCmp(csg, cmp);
            if (!edges.isEmpty()) {
                this.emitCsgCmp(csg, cmp, edges);
            }
            // While growing this cmp, forbid cmp itself and (presumably, via getBvBitmap)
            // the neighbors ordered below it, so a complement reachable from two
            // connected neighbors is not produced twice.
            long newForbidden = ((cmp | LongBitmap.getBvBitmap(cmp)) & neighbors) | forbidden;
            this.enumerateCmpRec(csg, cmp, newForbidden);
        }
    }

    /**
     * Grows {@code csg} by every non-empty subset of its allowed neighbors.
     * Subgraphs that already have a dp entry are emitted, then the growth recurses.
     */
    private void enumerateCsgRec(long csg, long forbidden) {
        long neighbors = this.hyperGraph.getNeighborBitmap(csg, forbidden);
        LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors);
        for (long subNeighbor : subsetIterator) {
            this.hyperGraph.updateEdgesForUnion(csg, subNeighbor);
            long newCsg = csg | subNeighbor;
            // Only subgraphs that already have a plan can serve as a csg.
            if (this.dpTable.containsKey(newCsg)) {
                this.emitCsg(newCsg);
            }
        }
        long newForbidden = forbidden | neighbors;
        subsetIterator.reset();
        for (long subNeighbor : subsetIterator) {
            this.enumerateCsgRec(csg | subNeighbor, newForbidden);
        }
    }

    /**
     * Grows {@code cmp} by every non-empty subset of its allowed neighbors.
     * Each grown complement that has a plan and is connected to {@code csg} is
     * emitted, then the growth recurses.
     */
    private void enumerateCmpRec(long csg, long cmp, long forbidden) {
        long neighbors = this.hyperGraph.getNeighborBitmap(cmp, forbidden);
        LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors);
        for (long subNeighbor : subsetIterator) {
            long newCmp = cmp | subNeighbor;
            this.hyperGraph.updateEdgesForUnion(cmp, subNeighbor);
            if (!this.dpTable.containsKey(newCmp)) {
                continue;
            }
            List<HyperEdge> edges = this.hyperGraph.connectCsgCmp(csg, newCmp);
            if (!edges.isEmpty()) {
                this.emitCsgCmp(csg, newCmp, edges);
            }
        }
        long newForbidden = forbidden | neighbors;
        subsetIterator.reset();
        for (long subNeighbor : subsetIterator) {
            this.enumerateCmpRec(csg, cmp | subNeighbor, newForbidden);
        }
    }

    /**
     * Builds join plan(s) for a csg-cmp pair connected by {@code edges}.
     * The cheaper plan is stored in the dp entry for {@code csg | cmp} if it beats
     * the existing entry.
     *
     * @throws PlanTooComplexError if a new dp entry is needed while the table
     *                             already holds more than {@code bloat} entries
     */
    private void emitCsgCmp(long csg, long cmp, List<HyperEdge> edges) {
        RelNode child1 = this.dpTable.get(csg);
        RelNode child2 = this.dpTable.get(cmp);
        ImmutableList<HyperGraph.NodeState> csgOrder = this.resultInputOrder.get(csg);
        ImmutableList<HyperGraph.NodeState> cmpOrder = this.resultInputOrder.get(cmp);
        assert child1 != null && child2 != null && csgOrder != null && cmpOrder != null;
        assert Long.bitCount(csg) == csgOrder.size() && Long.bitCount(cmp) == cmpOrder.size();
        JoinRelType joinType = this.hyperGraph.extractJoinType(edges);
        if (joinType == null) {
            return;
        }
        long union = csg | cmp;
        if (!this.hyperGraph.applicable(union, edges)) {
            return;
        }

        // Plan 1: csg on the left, cmp on the right.
        List<HyperGraph.NodeState> unionOrder = new ArrayList<>(csgOrder);
        unionOrder.addAll(cmpOrder);
        RexNode joinCond1 = this.hyperGraph.extractJoinCond(unionOrder, csgOrder.size(), edges, joinType);
        RelNode newPlan1 = this.builder.push(child1).push(child2).join(joinType, joinCond1).build();
        RelNode winPlan = newPlan1;
        ImmutableList<HyperGraph.NodeState> winOrder = ImmutableList.copyOf(unionOrder);
        assert this.verifyDpResultRowType(newPlan1, unionOrder);

        // Plan 2 (commutative joins only): cmp on the left, csg on the right.
        if (ConflictDetectionHelper.isCommutative(joinType)) {
            List<HyperGraph.NodeState> swappedOrder = new ArrayList<>(cmpOrder);
            swappedOrder.addAll(csgOrder);
            RexNode joinCond2 = this.hyperGraph.extractJoinCond(swappedOrder, cmpOrder.size(), edges, joinType);
            RelNode newPlan2 = this.builder.push(child2).push(child1).join(joinType, joinCond2).build();
            winPlan = this.chooseBetterPlan(winPlan, newPlan2);
            assert this.verifyDpResultRowType(newPlan2, swappedOrder);
            // chooseBetterPlan returns one of its arguments, so identity tells which one won.
            if (winPlan == newPlan2) {
                winOrder = ImmutableList.copyOf(swappedOrder);
            }
        }

        // Building the condition and querying cost/row count is costly; only do it when traced.
        if (LOGGER.isDebugEnabled()) {
            RexNode joinCondition = RexUtil.composeConjunction(this.builder.getRexBuilder(),
                    edges.stream().map(HyperEdge::getCondition).collect(Collectors.toList()));
            LOGGER.debug("Found set {} and {}, connected by condition {}. [cost={}, rows={}]",
                    LongBitmap.printBitmap(csg), LongBitmap.printBitmap(cmp), joinCondition,
                    this.mq.getCumulativeCost(winPlan), this.mq.getRowCount(winPlan));
        }

        RelNode oriPlan = this.dpTable.get(union);
        boolean dpTableUpdated = true;
        if (oriPlan != null) {
            winPlan = this.chooseBetterPlan(winPlan, oriPlan);
            if (winPlan == oriPlan) {
                winOrder = this.resultInputOrder.get(union);
                dpTableUpdated = false;
            }
        } else if (this.dpTable.size() > this.bloat) {
            // A new entry would be added; abort once the table exceeds the limit.
            throw new PlanTooComplexError();
        }
        assert winOrder != null;
        if (dpTableUpdated && LOGGER.isDebugEnabled()) {
            LOGGER.debug("Dp table is updated. The better plan for subgraph {} now is:\n {}",
                    LongBitmap.printBitmap(union), RelOptUtil.toString(winPlan));
        }
        this.dpTable.put(union, winPlan);
        this.resultInputOrder.put(union, winOrder);
    }

    /**
     * Returns the best plan covering every input. A projection on top restores
     * the field order of the original hypergraph.
     *
     * @return the reordered plan, or {@code null} if enumeration did not produce a
     *         plan for the whole graph (e.g. it was aborted by the bloat limit)
     */
    public @Nullable RelNode getBestPlan() {
        int size = this.hyperGraph.getInputs().size();
        long wholeGraph = LongBitmap.newBitmapBetween(0, size);
        RelNode orderedJoin = this.dpTable.get(wholeGraph);
        if (orderedJoin == null) {
            LOGGER.error("The optimal plan was not generated because the enumeration ended prematurely");
            return null;
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Enumeration completed. The best plan is:\n {}", RelOptUtil.toString(orderedJoin));
        }
        ImmutableList<HyperGraph.NodeState> resultOrder = this.resultInputOrder.get(wholeGraph);
        assert resultOrder != null && resultOrder.size() == size;
        List<RexNode> projects =
                this.hyperGraph.restoreProjectionOrder(resultOrder, orderedJoin.getRowType().getFieldList());
        return this.builder.push(orderedJoin).project(projects).build();
    }

    /**
     * Returns whichever plan has the strictly lower cumulative cost.
     *
     * <p>Ties go to {@code plan2}. If only one cost is known, that plan is returned.
     * If neither is known, {@code plan2} is returned. The result is always one of
     * the two arguments (identity is preserved).
     */
    private RelNode chooseBetterPlan(RelNode plan1, RelNode plan2) {
        RelOptCost cost1 = this.mq.getCumulativeCost(plan1);
        RelOptCost cost2 = this.mq.getCumulativeCost(plan2);
        if (cost1 != null && cost2 != null) {
            return cost1.isLt(cost2) ? plan1 : plan2;
        }
        if (cost1 != null) {
            return plan1;
        }
        return plan2;
    }

    /**
     * Assertion helper. For a plan covering every input, checks that projecting
     * it back into the original order yields the hypergraph's row type (field
     * names ignored). Plans covering only part of the graph are accepted without
     * checking.
     */
    protected boolean verifyDpResultRowType(RelNode plan, List<HyperGraph.NodeState> resultOrder) {
        if (resultOrder.size() != this.hyperGraph.getInputs().size()) {
            return true;
        }
        List<RexNode> projects = this.hyperGraph.restoreProjectionOrder(resultOrder, plan.getRowType().getFieldList());
        RelNode resultNode = this.builder.push(plan).project(projects).build();
        return RelOptUtil.areRowTypesEqual(resultNode.getRowType(), this.hyperGraph.getRowType(), false);
    }
}

