/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.processors.utils;

import java.util.Collections;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.flink.annotation.Internal;
import org.apache.flink.streaming.api.transformations.ShuffleMode;
import org.apache.flink.table.planner.plan.nodes.exec.AbstractExecNodeExactlyOnceVisitor;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecExchange;
import org.apache.flink.table.planner.plan.processors.utils.InputPriorityGraphGenerator;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;

/**
 * An {@link InputPriorityGraphGenerator} whose conflict-resolution callback puts a
 * {@link BatchExecExchange} carrying the configured {@link ShuffleMode} on the
 * lower-priority input of the conflicting node.
 *
 * <p>Depending on what currently sits on that input, the resolver either reconfigures the
 * existing exchange in place, replaces it with a copy, or inserts a brand-new exchange.
 */
@Internal
public class InputPriorityConflictResolver
extends InputPriorityGraphGenerator {
    /** Shuffle mode applied to every exchange this resolver creates or reconfigures. */
    private final ShuffleMode shuffleMode;

    /**
     * Creates a resolver.
     *
     * @param roots root nodes handed to the superclass
     * @param safeDamBehavior dam behavior handed to the superclass
     * @param shuffleMode shuffle mode to set on exchanges touched by this resolver
     */
    public InputPriorityConflictResolver(List<ExecNode<?, ?>> roots, ExecEdge.DamBehavior safeDamBehavior, ShuffleMode shuffleMode) {
        // The superclass also takes a set argument; this resolver always passes an empty one.
        super(roots, Collections.emptySet(), safeDamBehavior);
        this.shuffleMode = shuffleMode;
    }

    /**
     * Entry point: delegates to the inherited {@code createTopologyGraph()}.
     *
     * <p>NOTE(review): presumably the superclass invokes
     * {@link #resolveInputPriorityConflict(ExecNode, int, int)} while building the graph;
     * confirm against {@link InputPriorityGraphGenerator}.
     */
    public void detectAndResolve() {
        createTopologyGraph();
    }

    /**
     * Resolves a conflict between two inputs of {@code node} by acting on the lower input.
     *
     * @param node node whose inputs conflict
     * @param higherInput index of the higher-priority input
     * @param lowerInput index of the lower-priority input
     */
    @Override
    protected void resolveInputPriorityConflict(ExecNode<?, ?> node, int higherInput, int lowerInput) {
        ExecNode<?, ?> higherPriorityNode = node.getInputNodes().get(higherInput);
        ExecNode<?, ?> lowerPriorityNode = node.getInputNodes().get(lowerInput);

        // No exchange on the lower input yet: insert a fresh one and we are done.
        if (!(lowerPriorityNode instanceof BatchExecExchange)) {
            node.replaceInputNode(lowerInput, createExchange(node, lowerInput));
            return;
        }

        BatchExecExchange existing = (BatchExecExchange) lowerPriorityNode;
        if (!isConflictCausedByExchange(higherPriorityNode, existing)) {
            // The exchange is not reachable from the higher input, so adjust it in place.
            existing.setRequiredShuffleMode(shuffleMode);
            return;
        }

        // The same exchange instance is also reachable from the higher input: leave it
        // untouched and give the lower input its own copy with the required shuffle mode.
        BatchExecExchange duplicate = existing.copy(existing.getTraitSet(), existing.getInput(), existing.getDistribution());
        duplicate.setRequiredShuffleMode(shuffleMode);
        node.replaceInputNode(lowerInput, duplicate);
    }

    /**
     * Tells whether {@code lowerNode} is {@code higherNode} itself or is reachable from it
     * through input nodes.
     *
     * @param higherNode node where the search starts
     * @param lowerNode exchange instance searched for (identity comparison)
     * @return {@code true} if the exchange was encountered during the traversal
     */
    private boolean isConflictCausedByExchange(ExecNode<?, ?> higherNode, BatchExecExchange lowerNode) {
        ConflictCausedByExchangeChecker reachability = new ConflictCausedByExchangeChecker(lowerNode);
        reachability.visit(higherNode);
        return reachability.found;
    }

    /**
     * Builds a new exchange on top of the {@code idx}-th input of {@code node}, distributed as
     * required by the corresponding input edge and using this resolver's shuffle mode.
     *
     * @param node node whose input receives the exchange
     * @param idx index of that input
     * @return the newly created exchange
     * @throws IllegalStateException if the input edge requires a broadcast shuffle
     */
    private BatchExecExchange createExchange(ExecNode<?, ?> node, int idx) {
        RelNode input = (RelNode) node.getInputNodes().get(idx);
        FlinkRelDistribution distribution =
                toExchangeDistribution(node.getInputEdges().get(idx).getRequiredShuffle());

        BatchExecExchange created =
                new BatchExecExchange(input.getCluster(), input.getTraitSet().replace(distribution), input, distribution);
        created.setRequiredShuffleMode(shuffleMode);
        return created;
    }

    /**
     * Maps the shuffle required by an input edge to the distribution of the exchange to insert:
     * hash keeps its keys, singleton maps to singleton, and every other non-broadcast type
     * maps to "any".
     *
     * @param requiredShuffle shuffle requirement of the input edge
     * @return the matching distribution
     * @throws IllegalStateException if the required shuffle type is broadcast
     */
    private static FlinkRelDistribution toExchangeDistribution(ExecEdge.RequiredShuffle requiredShuffle) {
        if (requiredShuffle.getType() == ExecEdge.ShuffleType.HASH) {
            return FlinkRelDistribution.hash(requiredShuffle.getKeys(), true);
        } else if (requiredShuffle.getType() == ExecEdge.ShuffleType.BROADCAST) {
            throw new IllegalStateException("Trying to resolve input priority conflict on broadcast side. This is not expected.");
        } else if (requiredShuffle.getType() == ExecEdge.ShuffleType.SINGLETON) {
            return FlinkRelDistribution.SINGLETON();
        } else {
            return FlinkRelDistribution.ANY();
        }
    }

    /**
     * Visitor that records whether a given exchange instance is encountered while walking a
     * node and its inputs.
     */
    private static class ConflictCausedByExchangeChecker
    extends AbstractExecNodeExactlyOnceVisitor {
        /** Exchange instance to look for. */
        private final BatchExecExchange exchange;
        /** Set to {@code true} once the exchange has been seen; never reset. */
        private boolean found;

        private ConflictCausedByExchangeChecker(BatchExecExchange exchange) {
            this.exchange = exchange;
        }

        @Override
        protected void visitNode(ExecNode<?, ?> node) {
            if (node == exchange) {
                found = true;
            }
            for (ExecNode<?, ?> input : node.getInputNodes()) {
                visit(input);
                // Stop descending into further inputs as soon as the exchange has been seen.
                if (found) {
                    return;
                }
            }
        }
    }
}

