/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.coordinator.group.streams.assignor;

import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.coordinator.group.streams.assignor.AssignmentMemberSpec;
import org.apache.kafka.coordinator.group.streams.assignor.GroupAssignment;
import org.apache.kafka.coordinator.group.streams.assignor.GroupSpec;
import org.apache.kafka.coordinator.group.streams.assignor.MemberAssignment;
import org.apache.kafka.coordinator.group.streams.assignor.ProcessState;
import org.apache.kafka.coordinator.group.streams.assignor.TaskAssignor;
import org.apache.kafka.coordinator.group.streams.assignor.TaskAssignorException;
import org.apache.kafka.coordinator.group.streams.assignor.TaskId;
import org.apache.kafka.coordinator.group.streams.assignor.TopologyDescriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TaskAssignor} that tries to keep tasks on the members that previously owned them ("sticky")
 * while spreading load across processes.
 *
 * <p>Active tasks are placed in three passes:
 * <ol>
 *   <li>on the member that previously owned the task as an active;</li>
 *   <li>otherwise on a member that previously held it as a standby;</li>
 *   <li>otherwise on the least-loaded member.</li>
 * </ol>
 * Standby replicas are then placed on processes that do not already host the task.
 *
 * <p>Per-call working state is kept in an instance field. A single instance must therefore not be used
 * for concurrent {@link #assign} calls.
 */
public class StickyTaskAssignor implements TaskAssignor {
    private static final String STICKY_ASSIGNOR_NAME = "sticky";
    private static final String NUM_STANDBY_REPLICAS_CONFIG = "num.standby.replicas";
    private static final Logger log = LoggerFactory.getLogger(StickyTaskAssignor.class);

    /** Working state of the assignment in progress; {@code null} between calls. */
    private LocalState localState;

    @Override
    public String name() {
        return STICKY_ASSIGNOR_NAME;
    }

    /**
     * Computes a new assignment for every member of the group.
     *
     * @param groupSpec         the members, their previously owned tasks and the assignment configs
     * @param topologyDescriber the subtopologies, their partition counts and whether they are stateful
     * @return an assignment containing an entry for every member of {@code groupSpec}
     * @throws TaskAssignorException if an active task cannot be placed on any member
     * @throws NumberFormatException if {@code num.standby.replicas} is present but is not an integer
     */
    @Override
    public GroupAssignment assign(final GroupSpec groupSpec, final TopologyDescriber topologyDescriber) throws TaskAssignorException {
        try {
            initialize(groupSpec, topologyDescriber);
            return doAssign(groupSpec, topologyDescriber);
        } finally {
            // Drop the working state on failure too, so it is not retained until the next call.
            localState = null;
        }
    }

    /** Assigns active tasks, then standby replicas (if requested), then builds the per-member result. */
    private GroupAssignment doAssign(final GroupSpec groupSpec, final TopologyDescriber topologyDescriber) {
        final Set<TaskId> activeTasks = taskIds(topologyDescriber, true);
        assignActive(activeTasks);

        final int numStandbyReplicas = numStandbyReplicas(groupSpec);
        if (numStandbyReplicas > 0) {
            final Set<TaskId> statefulTasks = taskIds(topologyDescriber, false);
            assignStandby(statefulTasks, numStandbyReplicas);
        }
        return buildGroupAssignment(groupSpec.members().keySet());
    }

    /**
     * Reads the requested number of standby replicas.
     *
     * @return the configured value, or 0 when the configs are empty or do not contain the key
     * @throws NumberFormatException if the configured value is not an integer
     */
    private static int numStandbyReplicas(final GroupSpec groupSpec) {
        final String value = groupSpec.assignmentConfigs().get(NUM_STANDBY_REPLICAS_CONFIG);
        return value == null ? 0 : Integer.parseInt(value);
    }

    /**
     * Enumerates task ids, one per partition in {@code [0, maxNumInputPartitions)} of each subtopology.
     *
     * @param isActive {@code true} to include every subtopology, {@code false} to include only stateful ones
     * @return a new, mutable set of task ids
     */
    private Set<TaskId> taskIds(final TopologyDescriber topologyDescriber, final boolean isActive) {
        final Set<TaskId> ret = new HashSet<>();
        for (final String subtopology : topologyDescriber.subtopologies()) {
            if (!isActive && !topologyDescriber.isStateful(subtopology)) {
                continue;
            }
            final int numberOfPartitions = topologyDescriber.maxNumInputPartitions(subtopology);
            for (int partition = 0; partition < numberOfPartitions; ++partition) {
                ret.add(new TaskId(subtopology, partition));
            }
        }
        return ret;
    }

    /**
     * Builds a fresh {@link LocalState}. It records:
     * <ul>
     *   <li>the task count and the per-member quota;</li>
     *   <li>one {@link ProcessState} per process, populated with its members;</li>
     *   <li>which member previously owned each active task;</li>
     *   <li>which members previously held each standby task.</li>
     * </ul>
     */
    private void initialize(final GroupSpec groupSpec, final TopologyDescriber topologyDescriber) {
        localState = new LocalState();

        localState.allTasks = 0;
        for (final String subtopology : topologyDescriber.subtopologies()) {
            localState.allTasks += topologyDescriber.maxNumInputPartitions(subtopology);
        }
        localState.totalCapacity = groupSpec.members().size();
        localState.tasksPerMember = computeTasksPerMember(localState.allTasks, localState.totalCapacity);
        // Number of unordered pairs of distinct tasks. Computed in long arithmetic because the
        // int product overflows once there are more than ~46k tasks.
        localState.taskPairs = new TaskPairs((long) localState.allTasks * (localState.allTasks - 1) / 2);
        localState.processIdToState = new HashMap<>();
        localState.activeTaskToPrevMember = new HashMap<>();
        localState.standbyTaskToPrevMember = new HashMap<>();

        for (final Map.Entry<String, AssignmentMemberSpec> entry : groupSpec.members().entrySet()) {
            final String memberId = entry.getKey();
            final AssignmentMemberSpec memberSpec = entry.getValue();
            final String processId = memberSpec.processId();
            final Member member = new Member(processId, memberId);

            localState.processIdToState.computeIfAbsent(processId, id -> new ProcessState(id)).addMember(memberId);

            for (final Map.Entry<String, Set<Integer>> active : memberSpec.activeTasks().entrySet()) {
                for (final int partition : active.getValue()) {
                    localState.activeTaskToPrevMember.put(new TaskId(active.getKey(), partition), member);
                }
            }
            for (final Map.Entry<String, Set<Integer>> standby : memberSpec.standbyTasks().entrySet()) {
                for (final int partition : standby.getValue()) {
                    localState.standbyTaskToPrevMember
                        .computeIfAbsent(new TaskId(standby.getKey(), partition), taskId -> new HashSet<>())
                        .add(member);
                }
            }
        }
    }

    /**
     * Converts the per-process assignment state into a {@link GroupAssignment}.
     *
     * @param members ids of all members. Every one gets an entry, possibly with no tasks.
     */
    private GroupAssignment buildGroupAssignment(final Set<String> members) {
        final Map<String, Set<TaskId>> activeTasksAssignments = tasksByMember(true);
        final Map<String, Set<TaskId>> standbyTasksAssignments = tasksByMember(false);

        final Map<String, MemberAssignment> memberAssignments = new HashMap<>();
        for (final String memberId : members) {
            final Map<String, Set<Integer>> activeTasks =
                toCompactedTaskIds(activeTasksAssignments.getOrDefault(memberId, Set.of()));
            final Map<String, Set<Integer>> standbyTasks =
                toCompactedTaskIds(standbyTasksAssignments.getOrDefault(memberId, Set.of()));
            memberAssignments.put(memberId, new MemberAssignment(activeTasks, standbyTasks, new HashMap<>()));
        }
        return new GroupAssignment(memberAssignments);
    }

    /**
     * Collects the assigned active ({@code active == true}) or standby tasks of every member, across all
     * processes. Sets are copied when merged, so the processes' own sets are never modified.
     */
    private Map<String, Set<TaskId>> tasksByMember(final boolean active) {
        final Map<String, Set<TaskId>> result = new HashMap<>();
        for (final ProcessState process : localState.processIdToState.values()) {
            final Map<String, Set<TaskId>> byMember = active
                ? process.assignedActiveTasksByMember()
                : process.assignedStandbyTasksByMember();
            for (final Map.Entry<String, Set<TaskId>> entry : byMember.entrySet()) {
                result.merge(entry.getKey(), entry.getValue(), (existing, extra) -> {
                    final Set<TaskId> merged = new HashSet<>(existing);
                    merged.addAll(extra);
                    return merged;
                });
            }
        }
        return result;
    }

    /** Groups task ids by subtopology id, mapping each subtopology to its set of partitions. */
    private Map<String, Set<Integer>> toCompactedTaskIds(final Set<TaskId> taskIds) {
        final Map<String, Set<Integer>> ret = new HashMap<>();
        for (final TaskId taskId : taskIds) {
            ret.computeIfAbsent(taskId.subtopologyId(), subtopology -> new HashSet<>()).add(taskId.partition());
        }
        return ret;
    }

    /**
     * Places every task in {@code activeTasks} as an active task. Tasks are removed from the set as they
     * are placed, so the set is empty on normal return.
     *
     * @throws TaskAssignorException if a task cannot be placed on any member
     */
    private void assignActive(final Set<TaskId> activeTasks) {
        // Pass 1: keep a task on its previous active owner while that member is under its quota.
        Iterator<TaskId> it = activeTasks.iterator();
        while (it.hasNext()) {
            final TaskId task = it.next();
            final Member prevMember = localState.activeTaskToPrevMember.get(task);
            if (prevMember != null && hasUnfulfilledQuota(prevMember)) {
                assignTask(prevMember, task, true);
                it.remove();
            }
        }

        // Pass 2: otherwise prefer a member that previously held the task as a standby.
        it = activeTasks.iterator();
        while (it.hasNext()) {
            final TaskId task = it.next();
            final Member prevMember =
                findMemberWithLeastLoad(localState.standbyTaskToPrevMember.get(task), task, true);
            if (prevMember != null && hasUnfulfilledQuota(prevMember)) {
                assignTask(prevMember, task, true);
                it.remove();
            }
        }

        // Pass 3: place whatever is left on the least-loaded member. The member set itself does not
        // change during assignment (only the loads do), so it is built once.
        final Set<Member> allMembers = membersOf(localState.processIdToState.keySet());
        it = activeTasks.iterator();
        while (it.hasNext()) {
            final TaskId task = it.next();
            final Member member = findMemberWithLeastLoad(allMembers, task, false);
            if (member == null) {
                log.error("Unable to assign active task {} to any member.", task);
                throw new TaskAssignorException("No member available to assign active task " + task + ".");
            }
            assignTask(member, task, true);
            it.remove();
        }
    }

    /**
     * Shrinks the pool once the given active-task count equals the current quota. It drops one unit of
     * capacity, removes those tasks from the pool, and recomputes the quota for the rest.
     */
    private void maybeUpdateTasksPerMember(final int activeTasksNo) {
        if (activeTasksNo == localState.tasksPerMember) {
            --localState.totalCapacity;
            localState.allTasks -= activeTasksNo;
            localState.tasksPerMember = computeTasksPerMember(localState.allTasks, localState.totalCapacity);
        }
    }

    /**
     * Picks a member from {@code members} that lives on the least-loaded process.
     *
     * <p>Members whose process would form at least one new task pair with {@code taskId} are preferred.
     * If there are none, all of {@code members} are considered.
     *
     * @param returnSameMember if {@code true}, return a member of {@code members} that lives on the chosen
     *                         process; callers pass the task's previous standby owners here. Otherwise,
     *                         return the member with the fewest tasks on the chosen process.
     * @return the chosen member, or {@code null} if {@code members} is {@code null} or empty
     */
    private Member findMemberWithLeastLoad(final Set<Member> members, final TaskId taskId, final boolean returnSameMember) {
        if (members == null || members.isEmpty()) {
            return null;
        }
        Set<Member> candidates = members.stream()
            .filter(member -> localState.taskPairs.hasNewPair(
                taskId, localState.processIdToState.get(member.processId).assignedTasks()))
            .collect(Collectors.toSet());
        if (candidates.isEmpty()) {
            candidates = members;
        }
        final ProcessState leastLoaded = candidates.stream()
            .map(member -> localState.processIdToState.get(member.processId))
            .min(Comparator.comparingDouble(ProcessState::load))
            .orElse(null);
        if (leastLoaded == null) {
            return null;
        }
        if (returnSameMember) {
            return members.stream()
                .filter(member -> member.processId.equals(leastLoaded.processId()))
                .findFirst()
                .orElseGet(() -> memberWithLeastLoad(leastLoaded));
        }
        return memberWithLeastLoad(leastLoaded);
    }

    /** Returns the member of {@code process} with the smallest task count, or {@code null} if it has none. */
    private Member memberWithLeastLoad(final ProcessState process) {
        return process.memberToTaskCounts().entrySet().stream()
            .min(Map.Entry.comparingByValue())
            .map(entry -> new Member(process.processId(), entry.getKey()))
            .orElse(null);
    }

    /** Whether {@code member}'s task count is still below the current per-member quota. */
    private boolean hasUnfulfilledQuota(final Member member) {
        return localState.processIdToState.get(member.processId).memberToTaskCounts().get(member.memberId)
            < localState.tasksPerMember;
    }

    /** Records {@code task} on {@code member}, then updates the pair and quota bookkeeping. */
    private void assignTask(final Member member, final TaskId task, final boolean isActive) {
        localState.processIdToState.get(member.processId).addTask(member.memberId, task, isActive);
        updateHelpers(member, task, isActive);
    }

    /** Returns one {@link Member} for every member id in each listed process's task counts. */
    private Set<Member> membersOf(final Set<String> processIds) {
        final Set<Member> members = new HashSet<>();
        for (final String processId : processIds) {
            for (final String memberId : localState.processIdToState.get(processId).memberToTaskCounts().keySet()) {
                members.add(new Member(processId, memberId));
            }
        }
        return members;
    }

    /**
     * Places up to {@code numStandbyReplicas} standby copies of each task. Each copy goes on a process
     * that does not already host the task. When a copy cannot be placed, a warning is logged and the
     * task's remaining copies are skipped.
     */
    private void assignStandby(final Set<TaskId> standbyTasks, final int numStandbyReplicas) {
        for (final TaskId task : standbyTasks) {
            for (int i = 0; i < numStandbyReplicas; ++i) {
                final Set<String> availableProcesses = localState.processIdToState.values().stream()
                    .filter(process -> !process.hasTask(task))
                    .map(ProcessState::processId)
                    .collect(Collectors.toSet());
                if (availableProcesses.isEmpty()) {
                    log.warn("{} There is not enough available capacity. You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.", errorMessage(numStandbyReplicas, i, task));
                    break;
                }
                final Member standby = findStandbyMember(task, availableProcesses);
                if (standby == null) {
                    log.warn("{} Error in standby task assignment!", errorMessage(numStandbyReplicas, i, task));
                    break;
                }
                assignTask(standby, task, false);
            }
        }
    }

    /**
     * Chooses where the next standby copy of {@code task} goes. In order of preference:
     * <ol>
     *   <li>the previous active owner, if its process is available and load-balanced and the placement
     *       forms a new task pair;</li>
     *   <li>a previous standby owner on an available, load-balanced process;</li>
     *   <li>the least-loaded member of any available process.</li>
     * </ol>
     * As a side effect, previous standby owners on unavailable processes are removed from the
     * previous-standby set.
     *
     * @return the chosen member, or {@code null} if none was found
     */
    private Member findStandbyMember(final TaskId task, final Set<String> availableProcesses) {
        final Member prevActive = localState.activeTaskToPrevMember.get(task);
        if (prevActive != null
            && availableProcesses.contains(prevActive.processId)
            && isLoadBalanced(prevActive.processId)
            && localState.taskPairs.hasNewPair(task, localState.processIdToState.get(prevActive.processId).assignedTasks())) {
            return prevActive;
        }

        final Set<Member> prevStandbys = localState.standbyTaskToPrevMember.get(task);
        if (prevStandbys != null && !prevStandbys.isEmpty()) {
            prevStandbys.removeIf(member -> !availableProcesses.contains(member.processId));
            final Member prevStandby = findMemberWithLeastLoad(prevStandbys, task, true);
            if (prevStandby != null && isLoadBalanced(prevStandby.processId)) {
                return prevStandby;
            }
        }

        return findMemberWithLeastLoad(membersOf(availableProcesses), task, false);
    }

    private String errorMessage(final int numStandbyReplicas, final int i, final TaskId task) {
        return "Unable to assign " + (numStandbyReplicas - i) + " of " + numStandbyReplicas + " standby tasks for task [" + String.valueOf(task) + "].";
    }

    /**
     * Whether a process may take another task. That is the case if it reports spare capacity, or if its
     * load is no higher than that of any other process.
     */
    private boolean isLoadBalanced(final String processId) {
        final ProcessState process = localState.processIdToState.get(processId);
        final double load = process.load();
        final boolean isLeastLoadedProcess = localState.processIdToState.values().stream().allMatch(p -> p.load() >= load);
        return process.hasCapacity() || isLeastLoadedProcess;
    }

    /** Bookkeeping run after {@code taskId} has been added to {@code member}'s process. */
    private void updateHelpers(final Member member, final TaskId taskId, final boolean isActive) {
        // Record the pairs formed between the new task and everything already on that process.
        localState.taskPairs.addPairs(taskId, localState.processIdToState.get(member.processId).assignedTasks());
        if (isActive) {
            // NOTE(review): this passes the process's active-task count, while the quota is per member.
            // Confirm this is intended when a process hosts several members.
            maybeUpdateTasksPerMember(localState.processIdToState.get(member.processId).assignedActiveTasks().size());
        }
    }

    /** Returns {@code ceil(numberOfTasks / numberOfMembers)}, or 0 when there are no members. */
    private static int computeTasksPerMember(final int numberOfTasks, final int numberOfMembers) {
        if (numberOfMembers == 0) {
            return 0;
        }
        int tasksPerMember = numberOfTasks / numberOfMembers;
        if (numberOfTasks % numberOfMembers > 0) {
            ++tasksPerMember;
        }
        return tasksPerMember;
    }

    /** Mutable working state for a single {@link #assign} call. */
    private static class LocalState {
        private TaskPairs taskPairs;
        Map<TaskId, Member> activeTaskToPrevMember;
        Map<TaskId, Set<Member>> standbyTaskToPrevMember;
        Map<String, ProcessState> processIdToState;
        /** Tasks still counted toward the quota computation. */
        int allTasks;
        /** Members still counted toward the quota computation. */
        int totalCapacity;
        int tasksPerMember;
    }

    /**
     * Tracks which unordered pairs of distinct tasks have already been placed together on one process.
     */
    private static class TaskPairs {
        private final Set<Pair> pairs;
        private final long maxPairs;

        TaskPairs(final long maxPairs) {
            this.maxPairs = maxPairs;
            // Deliberately not presized: maxPairs grows quadratically with the task count. Presizing to
            // it would allocate a very large table up front.
            this.pairs = new HashSet<>();
        }

        /**
         * Whether putting {@code task1} next to {@code taskIds} would form at least one pair not seen
         * before. Always {@code false} once every possible pair has been recorded.
         */
        boolean hasNewPair(final TaskId task1, final Set<TaskId> taskIds) {
            if (pairs.size() >= maxPairs) {
                return false;
            }
            if (taskIds.isEmpty()) {
                return true;
            }
            for (final TaskId taskId : taskIds) {
                if (!pairs.contains(pair(task1, taskId))) {
                    return true;
                }
            }
            return false;
        }

        /** Records a pair of {@code taskId} with each other task in {@code assigned}. */
        void addPairs(final TaskId taskId, final Set<TaskId> assigned) {
            for (final TaskId id : assigned) {
                if (!id.equals(taskId)) {
                    pairs.add(pair(id, taskId));
                }
            }
        }

        /** Builds an order-independent pair by sorting the two ids. */
        Pair pair(final TaskId task1, final TaskId task2) {
            return task1.compareTo(task2) < 0 ? new Pair(task1, task2) : new Pair(task2, task1);
        }

        private static class Pair {
            private final TaskId task1;
            private final TaskId task2;

            Pair(final TaskId task1, final TaskId task2) {
                this.task1 = task1;
                this.task2 = task2;
            }

            @Override
            public boolean equals(final Object o) {
                if (this == o) {
                    return true;
                }
                if (o == null || getClass() != o.getClass()) {
                    return false;
                }
                final Pair pair = (Pair) o;
                return Objects.equals(task1, pair.task1) && Objects.equals(task2, pair.task2);
            }

            @Override
            public int hashCode() {
                return Objects.hash(task1, task2);
            }
        }
    }

    /**
     * A (process id, member id) pair. Equality is value-based, so hashed sets of members iterate in a
     * deterministic order.
     */
    static class Member {
        private final String processId;
        private final String memberId;

        public Member(final String processId, final String memberId) {
            this.processId = processId;
            this.memberId = memberId;
        }

        @Override
        public boolean equals(final Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            final Member other = (Member) o;
            return Objects.equals(processId, other.processId) && Objects.equals(memberId, other.memberId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(processId, memberId);
        }
    }
}

