/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import org.apache.flink.core.memory.HeapMemorySegment;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateReaderImpl;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.memory.NonPersistentMetadataCheckpointStorageLocation;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.BiFunctionWithException;
import org.junit.Assert;
import org.junit.Test;

/**
 * Round-trip test for channel state: random bytes are written for one input channel and one
 * result subpartition through {@link ChannelStateWriterImpl}, read back through
 * {@link ChannelStateReaderImpl}, and compared byte-for-byte with what was written.
 */
public class ChannelPersistenceITCase {
    private static final Random RANDOM = new Random(System.currentTimeMillis());

    /**
     * Writes 1024 random bytes for an input channel and for a result subpartition, then checks
     * that reading each of them back yields exactly the written bytes.
     */
    @Test
    public void testReadWritten() throws Exception {
        long checkpointId = 1L;

        InputChannelInfo inputChannelInfo = new InputChannelInfo(2, 3);
        byte[] inputChannelInfoData = randomBytes(1024);

        ResultSubpartitionInfo resultSubpartitionInfo = new ResultSubpartitionInfo(4, 5);
        byte[] resultSubpartitionInfoData = randomBytes(1024);

        ChannelStateWriter.ChannelStateWriteResult handles = write(
            checkpointId,
            Collections.singletonMap(inputChannelInfo, inputChannelInfoData),
            Collections.singletonMap(resultSubpartitionInfo, resultSubpartitionInfoData));

        Assert.assertArrayEquals(
            inputChannelInfoData,
            read(
                toTaskStateSnapshot(handles),
                inputChannelInfoData.length,
                (reader, mem) -> reader.readInputData(
                    inputChannelInfo,
                    new NetworkBuffer(mem, FreeingBufferRecycler.INSTANCE))));

        Assert.assertArrayEquals(
            resultSubpartitionInfoData,
            read(
                toTaskStateSnapshot(handles),
                resultSubpartitionInfoData.length,
                (reader, mem) -> reader.readOutputData(
                    resultSubpartitionInfo,
                    new BufferBuilder(mem, FreeingBufferRecycler.INSTANCE))));
    }

    /** Returns a new array of {@code size} bytes filled from the shared {@link #RANDOM}. */
    private byte[] randomBytes(int size) {
        byte[] bytes = new byte[size];
        RANDOM.nextBytes(bytes);
        return bytes;
    }

    /**
     * Writes the given per-channel payloads as channel state of one checkpoint.
     *
     * @param checkpointId id of the checkpoint to write under
     * @param icMap payload per input channel
     * @param rsMap payload per result subpartition
     * @return the write result obtained from the writer for {@code checkpointId}
     */
    private ChannelStateWriter.ChannelStateWriteResult write(
            long checkpointId,
            Map<InputChannelInfo, byte[]> icMap,
            Map<ResultSubpartitionInfo, byte[]> rsMap) throws Exception {
        // NOTE(review): the extra 16 bytes are headroom on top of the raw payload sizes; this
        // looks like an inlined constant expression (e.g. two longs) - confirm.
        int maxStateSize = sizeOfBytes(icMap) + sizeOfBytes(rsMap) + 16;
        Map<InputChannelInfo, Buffer> icBuffers = wrapWithBuffers(icMap);
        Map<ResultSubpartitionInfo, Buffer> rsBuffers = wrapWithBuffers(rsMap);
        try (ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl("test", getStreamFactoryFactory(maxStateSize))) {
            writer.open();
            writer.start(
                checkpointId,
                new CheckpointOptions(
                    CheckpointType.CHECKPOINT,
                    new CheckpointStorageLocationReference("poly".getBytes())));
            // NOTE(review): -2 below looks like an inlined sequence-number constant of
            // ChannelStateWriter - confirm and replace with the named constant.
            for (Map.Entry<InputChannelInfo, Buffer> entry : icBuffers.entrySet()) {
                writer.addInputData(
                    checkpointId,
                    entry.getKey(),
                    -2,
                    CloseableIterator.ofElements(Buffer::recycleBuffer, entry.getValue()));
            }
            writer.finishInput(checkpointId);
            for (Map.Entry<ResultSubpartitionInfo, Buffer> entry : rsBuffers.entrySet()) {
                writer.addOutputData(checkpointId, entry.getKey(), -2, entry.getValue());
            }
            writer.finishOutput(checkpointId);
            ChannelStateWriter.ChannelStateWriteResult result =
                writer.getAndRemoveWriteResult(checkpointId);
            // Block until the output handles are available before try-with-resources closes the
            // writer; presumably closing earlier would fail the pending result - TODO confirm.
            result.getResultSubpartitionStateHandles().join();
            return result;
        }
    }

    /** Same as {@link #getStreamFactoryFactory(int)} with a maximum state size of 42. */
    public static CheckpointStorageWorkerView getStreamFactoryFactory() {
        return getStreamFactoryFactory(42);
    }

    /**
     * Creates a storage view that resolves every location reference to a new
     * {@link NonPersistentMetadataCheckpointStorageLocation} constructed with {@code maxStateSize}.
     * Task-owned state streams are not supported.
     *
     * @param maxStateSize value passed to each created storage location
     */
    public static CheckpointStorageWorkerView getStreamFactoryFactory(final int maxStateSize) {
        return new CheckpointStorageWorkerView() {

            @Override
            public CheckpointStreamFactory resolveCheckpointStorageLocation(
                    long checkpointId, CheckpointStorageLocationReference reference) {
                return new NonPersistentMetadataCheckpointStorageLocation(maxStateSize);
            }

            @Override
            public CheckpointStreamFactory.CheckpointStateOutputStream createTaskOwnedStateStream() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Reads channel state from {@code taskStateSnapshot} into a new array of {@code size} bytes.
     *
     * @param taskStateSnapshot snapshot handed to a new {@link ChannelStateReaderImpl}
     * @param size length of the destination array
     * @param readFn performs the actual read into the memory segment wrapping the destination
     * @return the destination array after the read
     * @throws IllegalStateException via {@link Preconditions#checkState(boolean)} if
     *     {@code readFn} does not return {@code NO_MORE_DATA}
     */
    private byte[] read(
            TaskStateSnapshot taskStateSnapshot,
            int size,
            BiFunctionWithException<ChannelStateReader, MemorySegment, ChannelStateReader.ReadResult, Exception> readFn)
            throws Exception {
        byte[] dst = new byte[size];
        // The segment wraps dst, so whatever readFn writes into it lands in the returned array.
        HeapMemorySegment mem = HeapMemorySegment.FACTORY.wrap(dst);
        try {
            ChannelStateReader.ReadResult readResult =
                readFn.apply(new ChannelStateReaderImpl(taskStateSnapshot), mem);
            Preconditions.checkState(ChannelStateReader.ReadResult.NO_MORE_DATA == readResult);
        } finally {
            mem.free();
        }
        return dst;
    }

    /**
     * Wraps the input-channel and result-subpartition handles of {@code t} into a snapshot that
     * holds a single {@link OperatorSubtaskState} under a new {@link OperatorID}; the four
     * other state collections are left empty.
     */
    private TaskStateSnapshot toTaskStateSnapshot(ChannelStateWriter.ChannelStateWriteResult t) throws Exception {
        return new TaskStateSnapshot(
            Collections.singletonMap(
                new OperatorID(),
                new OperatorSubtaskState(
                    StateObjectCollection.empty(),
                    StateObjectCollection.empty(),
                    StateObjectCollection.empty(),
                    StateObjectCollection.empty(),
                    new StateObjectCollection<>(t.getInputChannelStateHandles().get()),
                    new StateObjectCollection<>(t.getResultSubpartitionStateHandles().get()))));
    }

    /** Returns the elements of {@code handles} that are instances of {@code clazz}, cast to it. */
    private <C> List<C> collect(Collection<StateObject> handles, Class<C> clazz) {
        return handles.stream()
            .filter(clazz::isInstance)
            .map(clazz::cast)
            .collect(Collectors.toList());
    }

    /** Sums the lengths of all byte arrays stored as values of {@code map}. */
    private static int sizeOfBytes(Map<?, byte[]> map) {
        return map.values().stream().mapToInt(d -> d.length).sum();
    }

    /** Maps every value of {@code icMap} to a buffer holding its bytes; keys are kept as-is. */
    private <K> Map<K, Buffer> wrapWithBuffers(Map<K, byte[]> icMap) {
        return icMap.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> wrapWithBuffer(e.getValue())));
    }

    /** Allocates an unpooled heap segment of {@code data.length} bytes and writes {@code data} into it. */
    private static Buffer wrapWithBuffer(byte[] data) {
        NetworkBuffer buffer = new NetworkBuffer(
            HeapMemorySegment.FACTORY.allocateUnpooledSegment(data.length, null),
            FreeingBufferRecycler.INSTANCE);
        buffer.writeBytes(data);
        return buffer;
    }
}

