/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.checkpoint.StateHandleDummyUtil;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.TestingExecutionGraphBuilder;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;

public class StateAssignmentOperationTest
extends TestLogger {
    @Test
    public void testRepartitionSplitDistributeStates() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(1);
        metaInfoMap1.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[30]));
        operatorState.putState(0, new OperatorSubtaskState((OperatorStateHandle)osh1, null, null, null, null, null));
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(1);
        metaInfoMap2.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 15L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[40]));
        operatorState.putState(1, new OperatorSubtaskState((OperatorStateHandle)osh2, null, null, null, null, null));
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testRepartitionUnionState() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(2);
        metaInfoMap1.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap1.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{22L, 44L}, OperatorStateHandle.Mode.UNION));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[50]));
        operatorState.putState(0, new OperatorSubtaskState((OperatorStateHandle)osh1, null, null, null, null, null));
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(1);
        metaInfoMap2.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[20]));
        operatorState.putState(1, new OperatorSubtaskState((OperatorStateHandle)osh2, null, null, null, null, null));
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testRepartitionBroadcastState() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(2);
        metaInfoMap1.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap1.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[60]));
        operatorState.putState(0, new OperatorSubtaskState((OperatorStateHandle)osh1, null, null, null, null, null));
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(2);
        metaInfoMap2.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap2.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[60]));
        operatorState.putState(1, new OperatorSubtaskState((OperatorStateHandle)osh2, null, null, null, null, null));
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testReDistributeCombinedPartitionableStates() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(6);
        metaInfoMap1.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap1.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{22L, 44L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap1.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{52L, 63L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        metaInfoMap1.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{67L, 74L, 75L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap1.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{77L, 88L, 92L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap1.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{101L, 123L, 127L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[130]));
        operatorState.putState(0, new OperatorSubtaskState((OperatorStateHandle)osh1, null, null, null, null, null));
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(3);
        metaInfoMap2.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap2.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{20L, 27L, 28L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap2.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 44L, 48L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap2.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{57L, 79L, 83L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[86]));
        operatorState.putState(1, new OperatorSubtaskState((OperatorStateHandle)osh2, null, null, null, null, null));
        this.verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 3);
        this.verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 1);
        this.verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 2);
    }

    private void verifyAndCollectStateInfo(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism, Map<String, Integer> stateInfoCounts) {
        Map newManagedOperatorStates = StateAssignmentOperation.reDistributePartitionableStates(Collections.singletonList(operatorState), (int)newParallelism, Collections.singletonList(OperatorIDPair.generatedIDOnly((OperatorID)operatorID)), OperatorSubtaskState::getManagedOperatorState, (OperatorStateRepartitioner)RoundRobinOperatorStateRepartitioner.INSTANCE);
        for (List operatorStateHandles : newManagedOperatorStates.values()) {
            EnumMap stateModeOffsets = new EnumMap(OperatorStateHandle.Mode.class);
            for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
                stateModeOffsets.put(mode, new HashMap());
            }
            for (OperatorStateHandle operatorStateHandle : operatorStateHandles) {
                for (Map.Entry stateNameToMetaInfo : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
                    String stateName = (String)stateNameToMetaInfo.getKey();
                    stateInfoCounts.merge(stateName, 1, (count, inc) -> count + inc);
                    OperatorStateHandle.StateMetaInfo stateMetaInfo = (OperatorStateHandle.StateMetaInfo)stateNameToMetaInfo.getValue();
                    ((Map)stateModeOffsets.get(stateMetaInfo.getDistributionMode())).merge(stateName, stateMetaInfo.getOffsets().length, (count, inc) -> count + inc);
                }
            }
            for (Map.Entry entry : stateModeOffsets.entrySet()) {
                OperatorStateHandle.Mode mode = (OperatorStateHandle.Mode)entry.getKey();
                Map stateOffsets = (Map)entry.getValue();
                if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals((Object)mode)) {
                    if (oldParallelism < newParallelism) {
                        stateOffsets.values().forEach(length -> Assert.assertEquals((long)1L, (long)length.intValue()));
                        continue;
                    }
                    stateOffsets.values().forEach(length -> Assert.assertEquals((long)2L, (long)length.intValue()));
                    continue;
                }
                if (OperatorStateHandle.Mode.UNION.equals((Object)mode)) {
                    stateOffsets.values().forEach(length -> Assert.assertEquals((long)2L, (long)length.intValue()));
                    continue;
                }
                stateOffsets.values().forEach(length -> Assert.assertEquals((long)3L, (long)length.intValue()));
            }
        }
    }

    private void verifyOneKindPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID) {
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 3);
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 1);
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 2);
    }

    private void verifyOneKindPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism) {
        HashMap<String, Integer> stateInfoCounts = new HashMap<String, Integer>();
        this.verifyAndCollectStateInfo(operatorState, operatorID, oldParallelism, newParallelism, stateInfoCounts);
        Assert.assertEquals((long)2L, (long)stateInfoCounts.size());
        if (stateInfoCounts.containsKey("t-1")) {
            if (oldParallelism < newParallelism) {
                Assert.assertEquals((long)2L, (long)((Integer)stateInfoCounts.get("t-1")).intValue());
                Assert.assertEquals((long)2L, (long)((Integer)stateInfoCounts.get("t-2")).intValue());
            } else {
                Assert.assertEquals((long)1L, (long)((Integer)stateInfoCounts.get("t-1")).intValue());
                Assert.assertEquals((long)1L, (long)((Integer)stateInfoCounts.get("t-2")).intValue());
            }
        }
        if (stateInfoCounts.containsKey("t-3")) {
            Assert.assertEquals((long)(2 * newParallelism), (long)((Integer)stateInfoCounts.get("t-3")).intValue());
            Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-4")).intValue());
        }
        if (stateInfoCounts.containsKey("t-5")) {
            Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-5")).intValue());
            Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-6")).intValue());
        }
    }

    private void verifyCombinedPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism) {
        HashMap<String, Integer> stateInfoCounts = new HashMap<String, Integer>();
        this.verifyAndCollectStateInfo(operatorState, operatorID, oldParallelism, newParallelism, stateInfoCounts);
        Assert.assertEquals((long)6L, (long)stateInfoCounts.size());
        Assert.assertEquals((long)(2 * newParallelism), (long)((Integer)stateInfoCounts.get("t-1")).intValue());
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-2")).intValue());
        if (oldParallelism < newParallelism) {
            Assert.assertEquals((long)2L, (long)((Integer)stateInfoCounts.get("t-3")).intValue());
        } else {
            Assert.assertEquals((long)1L, (long)((Integer)stateInfoCounts.get("t-3")).intValue());
        }
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-4")).intValue());
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-5")).intValue());
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-6")).intValue());
    }

    @Test
    public void testChannelStateAssignmentStability() throws JobException, JobExecutionException {
        int numOperators = 10;
        int numSubTasks = 100;
        Set<OperatorID> operatorIds = this.buildOperatorIds(numOperators);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, numSubTasks);
        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, numSubTasks);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();
        for (OperatorID operatorId : operatorIds) {
            for (int subtaskIdx = 0; subtaskIdx < numSubTasks; ++subtaskIdx) {
                Assert.assertEquals((Object)states.get(operatorId).getState(subtaskIdx), (Object)this.getAssignedState(vertices.get(operatorId), operatorId, subtaskIdx));
            }
        }
    }

    @Test
    public void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell() throws JobException, JobExecutionException {
        int numSubTasks = 1;
        OperatorID operatorId = new OperatorID();
        OperatorID userDefinedOperatorId = new OperatorID();
        Set<OperatorID> operatorIds = Collections.singleton(userDefinedOperatorId);
        ExecutionJobVertex executionJobVertex = this.buildExecutionJobVertex(operatorId, userDefinedOperatorId, 1);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, numSubTasks);
        new StateAssignmentOperation(0L, Collections.singleton(executionJobVertex), states, false).assignStates();
        Assert.assertEquals((Object)states.get(userDefinedOperatorId).getState(0), (Object)this.getAssignedState(executionJobVertex, operatorId, 0));
    }

    private Set<OperatorID> buildOperatorIds(int operators) {
        HashSet<OperatorID> set = new HashSet<OperatorID>();
        for (int j = 0; j < operators; ++j) {
            set.add(new OperatorID());
        }
        return set;
    }

    private Map<OperatorID, OperatorState> buildOperatorStates(Set<OperatorID> operators, int numSubTasks) {
        Random random = new Random();
        return operators.stream().collect(Collectors.toMap(Function.identity(), operatorID -> {
            OperatorState state = new OperatorState(operatorID, numSubTasks, numSubTasks);
            for (int i = 0; i < numSubTasks; ++i) {
                state.putState(i, new OperatorSubtaskState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random))), new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random))), StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of((int)i, (int)i))), StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of((int)i, (int)i))), new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, random), StateHandleDummyUtil.createNewInputChannelStateHandle(10, random))), new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random), StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random)))));
            }
            return state;
        }));
    }

    private Map<OperatorID, ExecutionJobVertex> buildVertices(Set<OperatorID> operators, int parallelism) {
        return operators.stream().collect(Collectors.toMap(Function.identity(), operatorID -> {
            try {
                return this.buildExecutionJobVertex((OperatorID)operatorID, parallelism);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }));
    }

    private ExecutionJobVertex buildExecutionJobVertex(OperatorID operatorID, int parallelism) throws JobException, JobExecutionException {
        return this.buildExecutionJobVertex(operatorID, operatorID, parallelism);
    }

    private ExecutionJobVertex buildExecutionJobVertex(OperatorID operatorID, OperatorID userDefinedOperatorId, int parallelism) throws JobException, JobExecutionException {
        ExecutionGraph graph = TestingExecutionGraphBuilder.newBuilder().build();
        JobVertex jobVertex = new JobVertex(operatorID.toHexString(), new JobVertexID(), Collections.singletonList(OperatorIDPair.of((OperatorID)operatorID, (OperatorID)userDefinedOperatorId)));
        return new ExecutionJobVertex(graph, jobVertex, parallelism, 1, Time.seconds((long)1L), 1L, 1L);
    }

    private OperatorSubtaskState getAssignedState(ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
        return executionJobVertex.getTaskVertices()[subtaskIdx].getCurrentExecutionAttempt().getTaskRestore().getTaskStateSnapshot().getSubtaskStateByOperatorID(operatorId);
    }
}

