/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.stream.Collectors;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;

/**
 * Utilities for wiring the execution-level edges between the task vertices of an {@link
 * ExecutionJobVertex} and the partitions of an {@link IntermediateResult} they consume.
 *
 * <p>Every connection is registered on both sides: the consuming {@link ExecutionVertex} receives
 * a {@link ConsumedPartitionGroup}, and each {@link IntermediateResultPartition} receives a {@link
 * ConsumerVertexGroup}.
 */
public class EdgeManagerBuildUtil {

    /**
     * Connects all task vertices of {@code vertex} to the partitions of {@code intermediateResult}
     * according to {@code distributionPattern}.
     *
     * @param vertex the consuming job vertex whose task vertices are connected
     * @param intermediateResult the result whose partitions are consumed
     * @param distributionPattern how partitions are distributed among the consumers
     * @throws IllegalArgumentException if the distribution pattern is neither POINTWISE nor
     *     ALL_TO_ALL
     */
    static void connectVertexToResult(
            ExecutionJobVertex vertex,
            IntermediateResult intermediateResult,
            DistributionPattern distributionPattern) {
        switch (distributionPattern) {
            case POINTWISE:
                connectPointwise(vertex.getTaskVertices(), intermediateResult);
                break;
            case ALL_TO_ALL:
                connectAllToAll(vertex.getTaskVertices(), intermediateResult);
                break;
            default:
                throw new IllegalArgumentException("Unrecognized distribution pattern.");
        }
    }

    /**
     * Connects every task vertex to every partition. All consumers share a single {@link
     * ConsumedPartitionGroup} containing every partition, and all partitions share a single {@link
     * ConsumerVertexGroup} containing every consumer.
     */
    private static void connectAllToAll(
            ExecutionVertex[] taskVertices, IntermediateResult intermediateResult) {
        IntermediateResultPartition[] partitions = intermediateResult.getPartitions();

        ConsumedPartitionGroup consumedPartitions =
                ConsumedPartitionGroup.fromMultiplePartitions(
                        Arrays.stream(partitions)
                                .map(IntermediateResultPartition::getPartitionId)
                                .collect(Collectors.toList()));
        for (ExecutionVertex ev : taskVertices) {
            ev.addConsumedPartitionGroup(consumedPartitions);
        }

        ConsumerVertexGroup vertices =
                ConsumerVertexGroup.fromMultipleVertices(
                        Arrays.stream(taskVertices)
                                .map(ExecutionVertex::getID)
                                .collect(Collectors.toList()));
        for (IntermediateResultPartition partition : partitions) {
            partition.addConsumers(vertices);
        }
    }

    /**
     * Connects task vertices to partitions so that each side is split into contiguous ranges. The
     * strategy depends on how the partition count compares to the consumer count.
     */
    private static void connectPointwise(
            ExecutionVertex[] taskVertices, IntermediateResult intermediateResult) {
        IntermediateResultPartition[] partitions = intermediateResult.getPartitions();
        int sourceCount = partitions.length;
        int targetCount = taskVertices.length;

        if (sourceCount == targetCount) {
            connectOneToOne(taskVertices, partitions);
        } else if (sourceCount > targetCount) {
            connectManyPartitionsToEachVertex(taskVertices, partitions);
        } else {
            connectEachPartitionToManyVertices(taskVertices, partitions);
        }
    }

    /** Pairs partition {@code i} with task vertex {@code i}; both arrays have equal length. */
    private static void connectOneToOne(
            ExecutionVertex[] taskVertices, IntermediateResultPartition[] partitions) {
        for (int i = 0; i < partitions.length; i++) {
            ExecutionVertex executionVertex = taskVertices[i];
            IntermediateResultPartition partition = partitions[i];

            partition.addConsumers(ConsumerVertexGroup.fromSingleVertex(executionVertex.getID()));
            executionVertex.addConsumedPartitionGroup(
                    ConsumedPartitionGroup.fromSinglePartition(partition.getPartitionId()));
        }
    }

    /**
     * Handles more partitions than consumers. Task vertex {@code index} consumes the contiguous
     * partition range {@code [index * sourceCount / targetCount, (index + 1) * sourceCount /
     * targetCount)}.
     */
    private static void connectManyPartitionsToEachVertex(
            ExecutionVertex[] taskVertices, IntermediateResultPartition[] partitions) {
        int sourceCount = partitions.length;
        int targetCount = taskVertices.length;

        for (int index = 0; index < targetCount; index++) {
            ExecutionVertex executionVertex = taskVertices[index];
            ConsumerVertexGroup consumerVertexGroup =
                    ConsumerVertexGroup.fromSingleVertex(executionVertex.getID());

            // Floor-based split: consecutive ranges abut, so together they cover [0, sourceCount).
            int start = index * sourceCount / targetCount;
            int end = (index + 1) * sourceCount / targetCount;

            ArrayList<IntermediateResultPartitionID> consumedPartitions =
                    new ArrayList<>(end - start);
            for (int i = start; i < end; i++) {
                IntermediateResultPartition partition = partitions[i];
                partition.addConsumers(consumerVertexGroup);
                consumedPartitions.add(partition.getPartitionId());
            }
            executionVertex.addConsumedPartitionGroup(
                    ConsumedPartitionGroup.fromMultiplePartitions(consumedPartitions));
        }
    }

    /**
     * Handles fewer partitions than consumers. Partition {@code partitionNum} is consumed by the
     * contiguous task-vertex range {@code [ceil(partitionNum * targetCount / sourceCount),
     * ceil((partitionNum + 1) * targetCount / sourceCount))}.
     */
    private static void connectEachPartitionToManyVertices(
            ExecutionVertex[] taskVertices, IntermediateResultPartition[] partitions) {
        int sourceCount = partitions.length;
        int targetCount = taskVertices.length;

        for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
            IntermediateResultPartition partition = partitions[partitionNum];
            ConsumedPartitionGroup consumedPartitionGroup =
                    ConsumedPartitionGroup.fromSinglePartition(partition.getPartitionId());

            // (x + sourceCount - 1) / sourceCount is the integer ceiling of x / sourceCount.
            int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
            int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;

            ArrayList<ExecutionVertexID> consumers = new ArrayList<>(end - start);
            for (int i = start; i < end; i++) {
                ExecutionVertex executionVertex = taskVertices[i];
                executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
                consumers.add(executionVertex.getID());
            }
            partition.addConsumers(ConsumerVertexGroup.fromMultipleVertices(consumers));
        }
    }
}

