/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.fedplanner;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Memo table for federated plan enumeration.
 *
 * <p>Each entry is keyed by a (hop ID, federated output type) pair and maps to the
 * {@link FedPlanVariants} collected for that combination. Entries are created lazily by
 * {@link #addFedPlan}.
 */
public class FederatedMemoTable {
    // Key: (hopID, fedOutType). Values are never null once inserted.
    private final Map<Pair<Long, FEDInstruction.FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

    /** Builds the typed memo-table key for a (hop ID, output type) combination. */
    private static Pair<Long, FEDInstruction.FederatedOutput> key(long hopID, FEDInstruction.FederatedOutput fedOutType) {
        return ImmutablePair.of(hopID, fedOutType);
    }

    /**
     * Creates a new plan for {@code hop} with the given output type and appends it to the
     * variants stored under (hop ID, {@code fedOutType}), creating that entry on first use.
     *
     * @param hop the hop the plan belongs to
     * @param fedOutType the federated output type of the plan
     * @param planChilds memo-table keys of the child plans referenced by the new plan
     * @return the newly created plan (already registered in its variants list)
     */
    public FedPlan addFedPlan(Hop hop, FEDInstruction.FederatedOutput fedOutType, List<Pair<Long, FEDInstruction.FederatedOutput>> planChilds) {
        // One hash lookup instead of contains() + get()/put().
        FedPlanVariants fedPlanVariantList = hopMemoTable.computeIfAbsent(
            key(hop.getHopID(), fedOutType),
            k -> new FedPlanVariants(hop, fedOutType));
        FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList);
        fedPlanVariantList.addFedPlan(newPlan);
        return newPlan;
    }

    /**
     * Returns the plan with the lowest total cost stored under {@code fedPlanPair}.
     *
     * @param fedPlanPair (hop ID, output type) key
     * @return the cheapest plan, or {@code null} if the entry holds no plans
     * @throws NullPointerException if no entry exists for {@code fedPlanPair}
     */
    public FedPlan getMinCostFedPlan(Pair<Long, FEDInstruction.FederatedOutput> fedPlanPair) {
        return hopMemoTable.get(fedPlanPair).findMinCostPlan();
    }

    /**
     * Returns the variants stored for (hopID, fedOutType), or {@code null} if there are none.
     */
    public FedPlanVariants getFedPlanVariants(long hopID, FEDInstruction.FederatedOutput fedOutType) {
        return hopMemoTable.get(key(hopID, fedOutType));
    }

    /**
     * Returns the variants stored under {@code fedPlanPair}, or {@code null} if there are none.
     */
    public FedPlanVariants getFedPlanVariants(Pair<Long, FEDInstruction.FederatedOutput> fedPlanPair) {
        return hopMemoTable.get(fedPlanPair);
    }

    /**
     * Returns the first stored plan for (hopID, fedOutType); after {@link #pruneFedPlan} this
     * is the single remaining (cheapest) plan.
     *
     * @throws NullPointerException if no entry exists for the key
     * @throws IndexOutOfBoundsException if the entry holds no plans
     */
    public FedPlan getFedPlanAfterPrune(long hopID, FEDInstruction.FederatedOutput fedOutType) {
        return getFedPlanAfterPrune(key(hopID, fedOutType));
    }

    /**
     * Returns the first stored plan under {@code fedPlanPair}; after pruning this is the
     * single remaining (cheapest) plan.
     *
     * @throws NullPointerException if no entry exists for the key
     * @throws IndexOutOfBoundsException if the entry holds no plans
     */
    public FedPlan getFedPlanAfterPrune(Pair<Long, FEDInstruction.FederatedOutput> fedPlanPair) {
        return hopMemoTable.get(fedPlanPair)._fedPlanVariants.get(0);
    }

    /** Returns whether an entry exists for (hopID, fedOutType). */
    public boolean contains(long hopID, FEDInstruction.FederatedOutput fedOutType) {
        return hopMemoTable.containsKey(key(hopID, fedOutType));
    }

    /**
     * Reduces the variants stored for (hopID, federatedOutput) to the single cheapest plan.
     *
     * @throws NullPointerException if no entry exists for the key
     */
    public void pruneFedPlan(long hopID, FEDInstruction.FederatedOutput federatedOutput) {
        hopMemoTable.get(key(hopID, federatedOutput)).prune();
    }

    /**
     * A single candidate plan: its own total cost plus the memo-table keys of its children.
     * Hop reference, output type, self cost and net-transfer cost are read from (and written to)
     * the owning {@link FedPlanVariants}, so they are shared by all plans of that variants entry.
     */
    public static class FedPlan {
        private double totalCost = 0.0;
        private final FedPlanVariants fedPlanVariants;
        private final List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlans;

        /**
         * @param childFedPlans memo-table keys of the child plans
         * @param fedPlanVariants the variants entry this plan belongs to
         */
        public FedPlan(List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlans, FedPlanVariants fedPlanVariants) {
            this.childFedPlans = childFedPlans;
            this.fedPlanVariants = fedPlanVariants;
        }

        public void setTotalCost(double totalCost) {
            this.totalCost = totalCost;
        }

        /** Sets the self cost on the shared per-variants state (affects all sibling plans). */
        public void setSelfCost(double selfCost) {
            fedPlanVariants.hopCommon.selfCost = selfCost;
        }

        /** Sets the net-transfer cost on the shared per-variants state (affects all sibling plans). */
        public void setNetTransferCost(double netTransferCost) {
            fedPlanVariants.hopCommon.netTransferCost = netTransferCost;
        }

        public Hop getHopRef() {
            return fedPlanVariants.hopCommon.hopRef;
        }

        public long getHopID() {
            return fedPlanVariants.hopCommon.hopRef.getHopID();
        }

        public FEDInstruction.FederatedOutput getFedOutType() {
            return fedPlanVariants.fedOutType;
        }

        public double getTotalCost() {
            return totalCost;
        }

        public double getSelfCost() {
            return fedPlanVariants.hopCommon.selfCost;
        }

        public double getNetTransferCost() {
            return fedPlanVariants.hopCommon.netTransferCost;
        }

        public List<Pair<Long, FEDInstruction.FederatedOutput>> getChildFedPlans() {
            return childFedPlans;
        }

        /**
         * Returns the net-transfer cost conditioned on the parent's output type: zero when the
         * parent uses the same output type as this plan, otherwise this hop's net-transfer cost.
         */
        public double getCondNetTransferCost(FEDInstruction.FederatedOutput parentFedOutType) {
            if (parentFedOutType == getFedOutType()) {
                return 0.0;
            }
            return fedPlanVariants.hopCommon.netTransferCost;
        }
    }

    /** All candidate plans for one (hop, output type) combination, plus their shared hop state. */
    public static class FedPlanVariants {
        protected HopCommon hopCommon;
        private final FEDInstruction.FederatedOutput fedOutType;
        protected List<FedPlan> _fedPlanVariants;

        public FedPlanVariants(Hop hopRef, FEDInstruction.FederatedOutput fedOutType) {
            this.hopCommon = new HopCommon(hopRef);
            this.fedOutType = fedOutType;
            this._fedPlanVariants = new ArrayList<>();
        }

        public void addFedPlan(FedPlan fedPlan) {
            _fedPlanVariants.add(fedPlan);
        }

        /** Returns the live (mutable) list of plans. */
        public List<FedPlan> getFedPlanVariants() {
            return _fedPlanVariants;
        }

        public boolean isEmpty() {
            return _fedPlanVariants.isEmpty();
        }

        /** Returns the plan with the lowest total cost, or {@code null} if there are none. */
        private FedPlan findMinCostPlan() {
            return _fedPlanVariants.stream()
                .min(Comparator.comparingDouble(FedPlan::getTotalCost))
                .orElse(null);
        }

        /** Keeps only the cheapest plan when more than one plan is stored. */
        public void prune() {
            if (_fedPlanVariants.size() > 1) {
                FedPlan minCostPlan = findMinCostPlan();
                _fedPlanVariants.clear();
                _fedPlanVariants.add(minCostPlan);
            }
        }
    }

    /** Hop reference and costs shared by all plans of one {@link FedPlanVariants}. */
    public static class HopCommon {
        protected final Hop hopRef;
        protected double selfCost;
        protected double netTransferCost;

        protected HopCommon(Hop hopRef) {
            this.hopRef = hopRef;
            this.selfCost = 0.0;
            this.netTransferCost = 0.0;
        }
    }
}

