/*
 * Decompiled with CFR 0.152.
 */
package org.apache.helix.controller.rebalancer.strategy;

import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.helix.HelixException;
import org.apache.helix.ZNRecord;
import org.apache.helix.controller.LogUtil;
import org.apache.helix.controller.dataproviders.ResourceControllerDataProvider;
import org.apache.helix.controller.rebalancer.strategy.RebalanceStrategy;
import org.apache.helix.controller.rebalancer.strategy.crushMapping.CRUSHPlacementAlgorithm;
import org.apache.helix.controller.rebalancer.topology.InstanceNode;
import org.apache.helix.controller.rebalancer.topology.Node;
import org.apache.helix.controller.rebalancer.topology.Topology;
import org.apache.helix.model.InstanceConfig;
import org.apache.helix.util.JenkinsHash;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CRUSH-based rebalance strategy that assigns replicas in two phases.
 *
 * <p>Phase 1 spreads every partition's replicas over distinct fault zones with CRUSH; replica slot
 * {@code k} receives the {@code k}-th state in the iteration order of the state-count map given to
 * {@link #init}. Phase 2 works zone by zone and state by state: the partitions that landed in a zone
 * with a given state are mapped to end nodes of that zone over up to {@link #MAX_ITERATION} rounds.
 * After each single placement the node weights are recomputed so that nodes below their
 * weight-proportional share attract more partitions, and nodes holding more than one partition above
 * their share hand the excess back for re-placement in the next round.
 *
 * <p>{@link #computePartitionAssignment} stores the cluster topology in a field that the helper
 * methods read, so concurrent calls on one instance would interfere with each other.
 */
public class MultiRoundCrushRebalanceStrategy
implements RebalanceStrategy<ResourceControllerDataProvider> {
    private static final Logger Log = LoggerFactory.getLogger(MultiRoundCrushRebalanceStrategy.class.getName());
    /** Maximum number of balancing rounds performed inside a single fault zone. */
    private static final int MAX_ITERATION = 3;
    /** Number of re-hashed placement attempts {@link #select} makes before giving up. */
    private static final int MAX_RETRY = 100;

    private String _resourceName;
    private List<String> _partitions;
    private Topology _clusterTopo;
    /** Total number of replicas per partition: the sum of all state counts. */
    private int _replicas;
    private LinkedHashMap<String, Integer> _stateCountMap;
    private final JenkinsHash hashFun = new JenkinsHash();
    private final CRUSHPlacementAlgorithm placementAlgorithm = new CRUSHPlacementAlgorithm();

    /**
     * Records the resource, its partitions and the per-state replica counts.
     *
     * @param resourceName resource the assignment is computed for
     * @param partitions partitions to place, in the order they are processed
     * @param states state name to replica count; its iteration order decides which replica slot gets
     *     which state
     * @param maximumPerNode not used by this strategy
     */
    @Override
    public void init(String resourceName, List<String> partitions, LinkedHashMap<String, Integer> states, int maximumPerNode) {
        this._resourceName = resourceName;
        this._partitions = partitions;
        this._replicas = this.countStateReplicas(states);
        this._stateCountMap = states;
    }

    /**
     * Computes the preference lists for all partitions.
     *
     * @param allNodes all instances, used to build the topology
     * @param liveNodes live instances, used to build the topology
     * @param currentMapping not used by this strategy
     * @param clusterData source of instance configs, cluster config and the event id used in logs
     * @return a record named after the resource whose list fields map each partition to its
     *     instances, ordered by state as in the state-count map. A partition gets fewer instances
     *     (possibly none) when a zone chosen for it could not place it.
     * @throws HelixException if CRUSH cannot find enough distinct zones/nodes within
     *     {@link #MAX_RETRY} attempts
     */
    @Override
    public ZNRecord computePartitionAssignment(List<String> allNodes, List<String> liveNodes, Map<String, Map<String, String>> currentMapping, ResourceControllerDataProvider clusterData) throws HelixException {
        Map<String, InstanceConfig> instanceConfigMap = clusterData.getInstanceConfigMap();
        this._clusterTopo = new Topology(allNodes, liveNodes, instanceConfigMap, clusterData.getClusterConfig());

        // Phase 1: one fault zone per replica slot for every partition.
        Map<String, List<Node>> zoneMapping = this.selectZonesForPartitions(this._clusterTopo.getRootNode());
        Map<Integer, String> idxStateMap = this.buildReplicaIndexToState();

        // Phase 2: inside each zone, place the partitions of each state onto end nodes.
        Map<String, Map<String, List<Node>>> partitionStateMapping = new HashMap<>();
        for (Node zone : this._clusterTopo.getFaultZones()) {
            Map<String, List<String>> statePartitionMap = groupZonePartitionsByState(zone, zoneMapping, idxStateMap);
            for (String state : this._stateCountMap.keySet()) {
                List<String> partitions = statePartitionMap.get(state);
                if (partitions == null || partitions.isEmpty()) {
                    continue;
                }
                Map<String, Node> assignments = this.singleZoneMapping(zone, partitions);
                for (Map.Entry<String, Node> assignment : assignments.entrySet()) {
                    Map<String, List<Node>> stateMapping = partitionStateMapping.get(assignment.getKey());
                    if (stateMapping == null) {
                        stateMapping = new HashMap<>();
                        partitionStateMapping.put(assignment.getKey(), stateMapping);
                    }
                    listFor(stateMapping, state).add(assignment.getValue());
                }
            }
        }
        return this.generateZNRecord(this._resourceName, this._partitions, partitionStateMapping, clusterData.getClusterEventId());
    }

    /** Selects {@code _replicas} distinct fault zones for every partition, keyed by partition name. */
    private Map<String, List<Node>> selectZonesForPartitions(Node root) {
        Map<String, List<Node>> zoneMapping = new HashMap<>();
        for (String partitionName : this._partitions) {
            long pData = partitionName.hashCode();
            zoneMapping.put(partitionName, this.select(root, this._clusterTopo.getFaultZoneType(), pData, this._replicas));
        }
        return zoneMapping;
    }

    /** Maps replica slot index k (0-based) to the state that slot carries. */
    private Map<Integer, String> buildReplicaIndexToState() {
        Map<Integer, String> idxStateMap = new HashMap<>();
        int idx = 0;
        for (Map.Entry<String, Integer> e : this._stateCountMap.entrySet()) {
            int count = e.getValue();
            for (int j = 0; j < count; ++j) {
                idxStateMap.put(idx + j, e.getKey());
            }
            idx += count;
        }
        return idxStateMap;
    }

    /** Groups, by state, the partitions that have a replica slot located in {@code zone}. */
    private static Map<String, List<String>> groupZonePartitionsByState(Node zone, Map<String, List<Node>> zoneMapping, Map<Integer, String> idxStateMap) {
        Map<String, List<String>> statePartitionMap = new LinkedHashMap<>();
        for (Map.Entry<String, List<Node>> e : zoneMapping.entrySet()) {
            List<Node> zones = e.getValue();
            for (int k = 0; k < zones.size(); ++k) {
                if (zones.get(k).equals(zone)) {
                    listFor(statePartitionMap, idxStateMap.get(k)).add(e.getKey());
                }
            }
        }
        return statePartitionMap;
    }

    /**
     * Converts the node assignment into per-partition preference lists of instance names.
     *
     * <p>Partitions or states without an entry in {@code partitionStateMapping} (their zone could not
     * place them) contribute no instances instead of failing the whole rebalance.
     */
    private ZNRecord generateZNRecord(String resource, List<String> partitions, Map<String, Map<String, List<Node>>> partitionStateMapping, String eventId) {
        Map<String, List<String>> newPreferences = new HashMap<>();
        for (String partitionName : partitions) {
            Map<String, List<Node>> stateNodeMap = partitionStateMapping.get(partitionName);
            for (String state : this._stateCountMap.keySet()) {
                List<String> preferenceList = listFor(newPreferences, partitionName);
                List<Node> nodes = stateNodeMap == null ? null : stateNodeMap.get(state);
                if (nodes == null) {
                    continue;
                }
                for (Node selectedNode : nodes) {
                    if (selectedNode instanceof InstanceNode) {
                        preferenceList.add(((InstanceNode) selectedNode).getInstanceName());
                    } else {
                        LogUtil.logError(Log, eventId, "Selected node is not associated with an instance: " + selectedNode.toString());
                    }
                }
            }
        }
        ZNRecord result = new ZNRecord(resource);
        result.setListFields(newPreferences);
        return result;
    }

    /**
     * Maps each partition to one end node of {@code zone}, rebalancing for up to
     * {@link #MAX_ITERATION} rounds.
     *
     * @return partition to node; empty if the zone is failed, has zero weight or no partitions were
     *     given. If a selection throws {@link IllegalStateException}, the assignment captured at the
     *     start of that round is returned; it is empty when this happens in the first round.
     */
    private Map<String, Node> singleZoneMapping(Node zone, List<String> partitions) {
        if (zone.isFailed() || zone.getWeight() == 0L || partitions.isEmpty()) {
            return Collections.emptyMap();
        }
        long totalWeight = zone.getWeight();
        int totalPartition = partitions.size();
        Map<Node, List<String>> nodePartitionsMap = new HashMap<>();
        Map<Node, List<String>> prevNodePartitionsMap = new HashMap<>();
        List<String> partitionsToAssign = new ArrayList<>(partitions);
        Map<Node, List<String>> toRemovedMap = new HashMap<>();
        // CRUSH runs against a re-weighted clone of the zone once recalculateWeight produces one.
        Node root = zone;
        boolean noAssignmentFound = false;
        for (int iteration = 0; iteration < MAX_ITERATION && !noAssignmentFound; ++iteration) {
            this.copyAssignment(nodePartitionsMap, prevNodePartitionsMap);
            // NOTE(review): toRemovedMap is only ever written (never cleared), so an entry from an
            // earlier round is applied again here and its partitions are queued for re-placement even
            // if they now live on another node. Confirm whether it should be cleared after this loop.
            for (Map.Entry<Node, List<String>> e : toRemovedMap.entrySet()) {
                List<String> toRemoved = e.getValue();
                nodePartitionsMap.get(e.getKey()).removeAll(toRemoved);
                partitionsToAssign.addAll(toRemoved);
            }
            for (String p : partitionsToAssign) {
                long pData = p.hashCode();
                List<Node> nodes;
                try {
                    nodes = this.select(root, this._clusterTopo.getEndNodeType(), pData, 1);
                } catch (IllegalStateException e) {
                    // NOTE(review): select() signals exhausted retries with HelixException, which is
                    // not caught here and propagates to the caller; only IllegalStateException
                    // (presumably raised by the placement algorithm — verify) triggers this fallback.
                    nodePartitionsMap = prevNodePartitionsMap;
                    noAssignmentFound = true;
                    break;
                }
                for (Node n : nodes) {
                    listFor(nodePartitionsMap, n).add(p);
                }
                root = this.recalculateWeight(zone, totalWeight, totalPartition, nodePartitionsMap, toRemovedMap);
            }
            partitionsToAssign.clear();
        }
        Map<String, Node> partitionMap = new HashMap<>();
        for (Map.Entry<Node, List<String>> e : nodePartitionsMap.entrySet()) {
            for (String p : e.getValue()) {
                partitionMap.put(p, e.getKey());
            }
        }
        return partitionMap;
    }

    /** Copies each node's partition list from the current assignment into the snapshot map. */
    private void copyAssignment(Map<Node, List<String>> nodePartitionsMap, Map<Node, List<String>> prevNodePartitionMap) {
        for (Map.Entry<Node, List<String>> e : nodePartitionsMap.entrySet()) {
            prevNodePartitionMap.put(e.getKey(), new ArrayList<>(e.getValue()));
        }
    }

    /**
     * Recomputes node weights inside {@code zone} from the current assignment.
     *
     * <p>A node's target is {@code floor(weight / totalWeight * totalPartition)}. Nodes below target
     * get weight {@code 10 * missing}. Failed nodes and nodes at or above target are marked
     * completed. Nodes holding more than {@code target + 1} partitions record their first excess
     * partitions in {@code toRemovedMap}.
     *
     * @return a re-weighted clone of {@code zone}, or {@code zone} itself if no node is below target
     */
    private Node recalculateWeight(Node zone, long totalWeight, int totalPartition, Map<Node, List<String>> nodePartitionsMap, Map<Node, List<String>> toRemovedMap) {
        Map<Node, Integer> newNodeWeight = new HashMap<>();
        Set<Node> completedNodes = new HashSet<>();
        for (Node node : Topology.getAllLeafNodes(zone)) {
            if (node.isFailed()) {
                completedNodes.add(node);
                continue;
            }
            long weight = node.getWeight();
            double ratio = (double) weight / (double) totalWeight;
            int target = (int) Math.floor(ratio * (double) totalPartition);
            List<String> assignedPartitions = nodePartitionsMap.get(node);
            int numPartitions = assignedPartitions == null ? 0 : assignedPartitions.size();
            if (numPartitions > target + 1) {
                int remove = numPartitions - target - 1;
                toRemovedMap.put(node, new ArrayList<>(assignedPartitions.subList(0, remove)));
            }
            int missing = target - numPartitions;
            if (missing > 0) {
                newNodeWeight.put(node, missing * 10);
            } else {
                completedNodes.add(node);
            }
        }
        return newNodeWeight.isEmpty() ? zone : Topology.clone(zone, newNodeWeight, completedNodes);
    }

    /**
     * Selects {@code rf} distinct nodes of {@code nodeType} below {@code topNode} with CRUSH,
     * re-hashing the input and retrying while fewer than {@code rf} were found.
     *
     * @return exactly {@code rf} distinct nodes
     * @throws HelixException if {@code rf} nodes cannot be found within {@link #MAX_RETRY} attempts
     */
    private List<Node> select(Node topNode, String nodeType, long data, int rf) throws HelixException {
        List<Node> nodes = new ArrayList<>();
        long input = data;
        int tries = 0;
        while (nodes.size() < rf) {
            List<Node> selected = this.placementAlgorithm.select(topNode, input, rf, nodeType, this.nodeAlreadySelected(new HashSet<Node>(nodes)));
            // The algorithm is asked for rf nodes every round; a retry round must only fill the gap,
            // otherwise the result could contain more than rf nodes.
            int remaining = rf - nodes.size();
            nodes.addAll(selected.size() > remaining ? selected.subList(0, remaining) : selected);
            if (nodes.size() < rf) {
                input = this.hashFun.hash(input);
                if (++tries >= MAX_RETRY) {
                    throw new HelixException(String.format("could not find all mappings after %d tries", tries));
                }
            }
        }
        return nodes;
    }

    /** Returns a predicate that accepts only nodes NOT contained in {@code selectedNodes}. */
    private Predicate<Node> nodeAlreadySelected(Set<Node> selectedNodes) {
        return Predicates.not(Predicates.<Node>in(selectedNodes));
    }

    /** Sums the replica counts of all states. */
    private int countStateReplicas(Map<String, Integer> stateCountMap) {
        int total = 0;
        for (Integer count : stateCountMap.values()) {
            total += count.intValue();
        }
        return total;
    }

    /** Returns the list stored under {@code key}, inserting a new empty list first if none exists. */
    private static <K, V> List<V> listFor(Map<K, List<V>> map, K key) {
        List<V> list = map.get(key);
        if (list == null) {
            list = new ArrayList<V>();
            map.put(key, list);
        }
        return list;
    }
}

