/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.SubtaskState;
import org.apache.flink.runtime.checkpoint.TaskState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskStateHandles;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;

/**
 * Assigns the state of a checkpoint to the execution job vertices of a job that is being restored.
 *
 * <p>For every {@link TaskState} of the checkpoint the matching {@link ExecutionJobVertex} is
 * looked up (falling back once to legacy {@link JobVertexID}s), the parallelism preconditions are
 * verified, and the state is distributed over the vertex's subtasks:
 * <ul>
 *   <li>operator state is repartitioned with the {@link RoundRobinOperatorStateRepartitioner} if
 *       the parallelism changed or broadcast state is present, otherwise every subtask keeps its
 *       own previous state;</li>
 *   <li>keyed state is split along key-group ranges if the parallelism changed, otherwise every
 *       subtask keeps its own previous keyed state;</li>
 *   <li>legacy (non-partitionable) operator state is only handed over when the parallelism is
 *       unchanged.</li>
 * </ul>
 * The resulting {@link TaskStateHandles} are set as the initial state of each subtask's current
 * execution attempt.
 */
public class StateAssignmentOperation {
    private final Logger logger;
    private final Map<JobVertexID, ExecutionJobVertex> tasks;
    private final Map<JobVertexID, TaskState> taskStates;
    private final boolean allowNonRestoredState;

    /**
     * Creates a new state assignment operation.
     *
     * @param logger logger for informational and debug messages; must not be null
     * @param tasks execution job vertices of the job, keyed by job vertex ID; must not be null
     * @param taskStates checkpointed task states, keyed by job vertex ID; must not be null
     * @param allowNonRestoredState if {@code true}, state without a matching vertex is skipped
     *     instead of failing the assignment
     */
    public StateAssignmentOperation(
            Logger logger,
            Map<JobVertexID, ExecutionJobVertex> tasks,
            Map<JobVertexID, TaskState> taskStates,
            boolean allowNonRestoredState) {
        this.logger = Preconditions.checkNotNull(logger);
        this.tasks = Preconditions.checkNotNull(tasks);
        this.taskStates = Preconditions.checkNotNull(taskStates);
        this.allowNonRestoredState = allowNonRestoredState;
    }

    /**
     * Assigns every checkpointed task state to its execution job vertex.
     *
     * @return always {@code true}
     * @throws IllegalStateException if a state has no matching vertex and non-restored state is
     *     not allowed, or if the parallelism preconditions of a vertex are violated
     */
    public boolean assignStates() throws Exception {
        boolean expandedToLegacyIds = false;
        Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks;

        for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : this.taskStates.entrySet()) {
            TaskState taskState = taskGroupStateEntry.getValue();
            ExecutionJobVertex executionJobVertex = localTasks.get(taskGroupStateEntry.getKey());

            // On the first miss, widen the lookup to legacy JobVertexIDs. This happens at most once.
            if (executionJobVertex == null && !expandedToLegacyIds) {
                localTasks = ExecutionJobVertex.includeLegacyJobVertexIDs(localTasks);
                executionJobVertex = localTasks.get(taskGroupStateEntry.getKey());
                expandedToLegacyIds = true;
                this.logger.info("Could not find ExecutionJobVertex. Including legacy JobVertexIDs in search.");
            }

            if (executionJobVertex == null) {
                if (this.allowNonRestoredState) {
                    this.logger.info("Skipped checkpoint state for operator {}.", taskState.getJobVertexID());
                    continue;
                }
                throw new IllegalStateException(
                        "There is no execution job vertex for the job vertex ID " + taskGroupStateEntry.getKey());
            }

            this.checkParallelismPreconditions(taskState, executionJobVertex);
            assignTaskStatesToOperatorInstances(taskState, executionJobVertex);
        }
        return true;
    }

    /**
     * Verifies that the state can be restored with the vertex's (maximum) parallelism.
     *
     * <p>If the maximum parallelism differs and was not explicitly configured on the vertex, the
     * vertex adopts the maximum parallelism of the state. Otherwise, a differing maximum
     * parallelism is rejected. Non-partitioned state additionally requires an unchanged
     * parallelism.
     */
    private void checkParallelismPreconditions(TaskState taskState, ExecutionJobVertex executionJobVertex) {
        final int stateMaxParallelism = taskState.getMaxParallelism();

        if (stateMaxParallelism != executionJobVertex.getMaxParallelism()) {
            if (executionJobVertex.isMaxParallelismConfigured()) {
                throw new IllegalStateException("The maximum parallelism (" + stateMaxParallelism
                        + ") with which the latest checkpoint of the execution job vertex " + executionJobVertex
                        + " has been taken and the current maximum parallelism ("
                        + executionJobVertex.getMaxParallelism()
                        + ") changed. This is currently not supported.");
            }
            this.logger.debug("Overriding maximum parallelism for JobVertex {} from {} to {}",
                    executionJobVertex.getJobVertexId(),
                    executionJobVertex.getMaxParallelism(),
                    stateMaxParallelism);
            executionJobVertex.setMaxParallelism(stateMaxParallelism);
        }

        final int oldParallelism = taskState.getParallelism();
        final int newParallelism = executionJobVertex.getParallelism();
        if (taskState.hasNonPartitionedState() && oldParallelism != newParallelism) {
            throw new IllegalStateException("Cannot restore the latest checkpoint because the operator "
                    + executionJobVertex.getJobVertexId()
                    + " has non-partitioned state and its parallelism changed. The operator "
                    + executionJobVertex.getJobVertexId() + " has parallelism " + newParallelism
                    + " whereas the corresponding state object has a parallelism of " + oldParallelism);
        }
    }

    /**
     * Distributes the state of one task over the subtasks of its execution job vertex.
     *
     * <p>The result is set as the initial state of each subtask's current execution attempt.
     */
    private static void assignTaskStatesToOperatorInstances(TaskState taskState, ExecutionJobVertex executionJobVertex) {
        final int oldParallelism = taskState.getParallelism();
        final int newParallelism = executionJobVertex.getParallelism();
        final int chainLength = taskState.getChainLength();

        List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(executionJobVertex.getMaxParallelism(), newParallelism);

        // Per chained operator: the operator state handle of each OLD subtask at its subtask index,
        // null where that subtask had no state. The entry stays null if no old subtask had state
        // for that operator. Index alignment is required for the unchanged-parallelism shortcut.
        List<List<OperatorStateHandle>> parallelOpStatesBackend = newListOfNulls(chainLength);
        List<List<OperatorStateHandle>> parallelOpStatesStream = newListOfNulls(chainLength);
        List<KeyGroupsStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
        List<KeyGroupsStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);

        for (int p = 0; p < oldParallelism; ++p) {
            SubtaskState subtaskState = taskState.getState(p);
            if (subtaskState == null) {
                continue;
            }
            collectParallelStatesByChainOperator(
                    parallelOpStatesBackend, subtaskState.getManagedOperatorState(), p, oldParallelism);
            collectParallelStatesByChainOperator(
                    parallelOpStatesStream, subtaskState.getRawOperatorState(), p, oldParallelism);

            KeyGroupsStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
            if (keyedStateBackend != null) {
                parallelKeyedStatesBackend.add(keyedStateBackend);
            }
            KeyGroupsStateHandle keyedStateStream = subtaskState.getRawKeyedState();
            if (keyedStateStream != null) {
                parallelKeyedStateStream.add(keyedStateStream);
            }
        }

        // Per chained operator: the operator state assigned to each NEW subtask, or null if the
        // operator has no state at all.
        OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
        List<List<Collection<OperatorStateHandle>>> partitionedParallelStatesBackend = new ArrayList<>(chainLength);
        List<List<Collection<OperatorStateHandle>>> partitionedParallelStatesStream = new ArrayList<>(chainLength);
        for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
            partitionedParallelStatesBackend.add(applyRepartitioner(
                    opStateRepartitioner, parallelOpStatesBackend.get(chainIdx), oldParallelism, newParallelism));
            partitionedParallelStatesStream.add(applyRepartitioner(
                    opStateRepartitioner, parallelOpStatesStream.get(chainIdx), oldParallelism, newParallelism));
        }

        final boolean sameParallelism = oldParallelism == newParallelism;
        for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) {
            // With unchanged parallelism, subtask i simply takes over the previous state of subtask i.
            SubtaskState previousSubtaskState = sameParallelism ? taskState.getState(subTaskIdx) : null;

            ChainedStateHandle<StreamStateHandle> nonPartitionableState =
                    previousSubtaskState != null ? previousSubtaskState.getLegacyOperatorState() : null;

            List<Collection<OperatorStateHandle>> operatorStateFromBackend =
                    selectSubtaskStates(partitionedParallelStatesBackend, subTaskIdx);
            List<Collection<OperatorStateHandle>> operatorStateFromStream =
                    selectSubtaskStates(partitionedParallelStatesStream, subTaskIdx);

            List<KeyGroupsStateHandle> newKeyedStatesBackend;
            List<KeyGroupsStateHandle> newKeyedStateStream;
            if (sameParallelism) {
                if (previousSubtaskState != null) {
                    KeyGroupsStateHandle oldKeyedStatesBackend = previousSubtaskState.getManagedKeyedState();
                    KeyGroupsStateHandle oldKeyedStatesStream = previousSubtaskState.getRawKeyedState();
                    newKeyedStatesBackend = oldKeyedStatesBackend != null
                            ? Collections.singletonList(oldKeyedStatesBackend) : null;
                    newKeyedStateStream = oldKeyedStatesStream != null
                            ? Collections.singletonList(oldKeyedStatesStream) : null;
                } else {
                    newKeyedStatesBackend = null;
                    newKeyedStateStream = null;
                }
            } else {
                KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
                newKeyedStatesBackend = getKeyGroupsStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
                newKeyedStateStream = getKeyGroupsStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
            }

            Execution currentExecutionAttempt =
                    executionJobVertex.getTaskVertices()[subTaskIdx].getCurrentExecutionAttempt();
            TaskStateHandles taskStateHandles = new TaskStateHandles(
                    nonPartitionableState,
                    operatorStateFromBackend,
                    operatorStateFromStream,
                    newKeyedStatesBackend,
                    newKeyedStateStream);
            currentExecutionAttempt.setInitialState(taskStateHandles);
        }
    }

    /**
     * Picks, for every chained operator, the state collection assigned to one subtask.
     *
     * @param partitionedStatesByOperator per operator, the per-subtask state collections; an
     *     operator entry may be null when it has no state
     * @param subtaskIndex index of the subtask whose states are picked
     * @return one entry per chained operator; null where the operator has no state
     */
    private static List<Collection<OperatorStateHandle>> selectSubtaskStates(
            List<List<Collection<OperatorStateHandle>>> partitionedStatesByOperator, int subtaskIndex) {
        List<Collection<OperatorStateHandle>> result = new ArrayList<>(partitionedStatesByOperator.size());
        for (List<Collection<OperatorStateHandle>> statesOfOperator : partitionedStatesByOperator) {
            result.add(statesOfOperator != null ? statesOfOperator.get(subtaskIndex) : null);
        }
        return result;
    }

    /**
     * Returns the intersections of the given keyed state handles with a subtask's key-group range.
     *
     * @param allKeyGroupsHandles keyed state handles to intersect
     * @param subtaskKeyGroupIds key-group range of the target subtask
     * @return the non-empty intersections, in iteration order of {@code allKeyGroupsHandles}
     */
    public static List<KeyGroupsStateHandle> getKeyGroupsStateHandles(
            Collection<KeyGroupsStateHandle> allKeyGroupsHandles, KeyGroupRange subtaskKeyGroupIds) {
        List<KeyGroupsStateHandle> subtaskKeyGroupStates = new ArrayList<>();
        for (KeyGroupsStateHandle storedKeyGroup : allKeyGroupsHandles) {
            KeyGroupsStateHandle intersection = storedKeyGroup.getKeyGroupIntersection(subtaskKeyGroupIds);
            if (intersection.getNumberOfKeyGroups() > 0) {
                subtaskKeyGroupStates.add(intersection);
            }
        }
        return subtaskKeyGroupStates;
    }

    /**
     * Splits the key groups into one key-group range per operator index.
     *
     * @param numberKeyGroups total number of key groups (the maximum parallelism)
     * @param parallelism number of operator instances
     * @return one range per operator index in {@code [0, parallelism)}
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(int numberKeyGroups, int parallelism) {
        if (numberKeyGroups < parallelism) {
            throw new IllegalArgumentException("The number of key groups (" + numberKeyGroups
                    + ") must not be smaller than the parallelism (" + parallelism + ").");
        }
        List<KeyGroupRange> result = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; ++i) {
            result.add(KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(numberKeyGroups, parallelism, i));
        }
        return result;
    }

    /**
     * Records the operator state handles of one old subtask, per chained operator, at that
     * subtask's index.
     *
     * @param chainParallelOpStates per chained operator, a list of size {@code parallelism} with one
     *     slot per old subtask; an operator's list is created lazily on its first non-null handle
     * @param chainOpState the chained operator state of the subtask; ignored if null
     * @param subtaskIndex index of the old subtask the handles belong to
     * @param parallelism old parallelism, used to size newly created per-operator lists
     * @throws IllegalStateException if the state has fewer entries than there are chained operators
     */
    private static void collectParallelStatesByChainOperator(
            List<List<OperatorStateHandle>> chainParallelOpStates,
            ChainedStateHandle<OperatorStateHandle> chainOpState,
            int subtaskIndex,
            int parallelism) {
        if (chainOpState == null) {
            return;
        }
        final int numOperators = chainParallelOpStates.size();
        final int numStates = chainOpState.getLength();
        // chainOpState.get(i) is accessed for every operator index below, so it must have enough entries.
        if (numStates < numOperators) {
            throw new IllegalStateException("Found fewer states than operators in the chain. Chain length: "
                    + numOperators + ", States: " + numStates);
        }
        for (int chainIdx = 0; chainIdx < numOperators; ++chainIdx) {
            OperatorStateHandle operatorState = chainOpState.get(chainIdx);
            if (operatorState == null) {
                continue;
            }
            List<OperatorStateHandle> statesOfOperator = chainParallelOpStates.get(chainIdx);
            if (statesOfOperator == null) {
                statesOfOperator = newListOfNulls(parallelism);
                chainParallelOpStates.set(chainIdx, statesOfOperator);
            }
            statesOfOperator.set(subtaskIndex, operatorState);
        }
    }

    /**
     * Computes the per-subtask operator state of one chained operator.
     *
     * @param opStateRepartitioner repartitioner used when redistribution is required
     * @param chainOpParallelStates the operator's state handle per old subtask index (entries may be
     *     null), or null if the operator has no state
     * @param oldParallelism parallelism the state was taken with
     * @param newParallelism parallelism of the restored vertex
     * @return null if the operator has no state. Otherwise, when the parallelism changed or any
     *     handle contains broadcast state, the repartitioner's result over all non-null handles.
     *     Otherwise a list aligned with subtask indices, where each subtask keeps its own handle
     *     (or null if it had none).
     */
    private static List<Collection<OperatorStateHandle>> applyRepartitioner(
            OperatorStateRepartitioner opStateRepartitioner,
            List<OperatorStateHandle> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        if (chainOpParallelStates == null) {
            return null;
        }

        List<OperatorStateHandle> presentStates = new ArrayList<>(chainOpParallelStates.size());
        for (OperatorStateHandle handle : chainOpParallelStates) {
            if (handle != null) {
                presentStates.add(handle);
            }
        }

        if (newParallelism != oldParallelism || containsBroadcastState(presentStates)) {
            return opStateRepartitioner.repartitionState(presentStates, newParallelism);
        }

        // Unchanged parallelism without broadcast state: keep every handle at its own subtask index.
        List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
        for (OperatorStateHandle handle : chainOpParallelStates) {
            if (handle != null) {
                repackStream.add(Collections.singletonList(handle));
            } else {
                repackStream.add(null);
            }
        }
        return repackStream;
    }

    /** Returns whether any of the given handles contains state in BROADCAST distribution mode. */
    private static boolean containsBroadcastState(Collection<OperatorStateHandle> handles) {
        for (OperatorStateHandle handle : handles) {
            Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets = handle.getStateNameToPartitionOffsets();
            for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
                if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Creates a mutable list of the given size whose elements are all null. */
    private static <T> List<T> newListOfNulls(int size) {
        return new ArrayList<>(Collections.<T>nCopies(size, null));
    }
}

