/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.Checkpoints;
import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateHandleDummyUtil;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.metadata.CheckpointMetadata;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/**
 * Tests for loading and validating checkpoint/savepoint metadata through
 * {@link Checkpoints#loadAndValidateCheckpoint}.
 *
 * <p>Each test serializes a small {@link CheckpointMetadata} into an in-memory storage location,
 * builds a (possibly empty) map of mocked job vertices and then checks how loading behaves.
 */
public class CheckpointMetadataLoadingTest {
    private final ClassLoader cl = getClass().getClassLoader();

    /**
     * Loads a savepoint whose single operator matches the single task of the job and verifies
     * that the job ID and checkpoint ID are carried over to the loaded checkpoint.
     */
    @Test
    public void testAllStateRestored() throws Exception {
        final JobID jobId = new JobID();
        final OperatorID operatorId = new OperatorID();
        // Deliberately larger than Integer.MAX_VALUE so the ID cannot fit into an int.
        final long checkpointId = Integer.MAX_VALUE + 123123L;
        final int parallelism = 128128;

        final CompletedCheckpointStorageLocation testSavepoint =
                createSavepointWithOperatorSubtaskState(checkpointId, operatorId, parallelism);
        final Map<JobVertexID, ExecutionJobVertex> tasks =
                createTasks(operatorId, parallelism, parallelism);

        final CompletedCheckpoint loaded =
                Checkpoints.loadAndValidateCheckpoint(jobId, tasks, testSavepoint, cl, false);

        Assert.assertEquals(jobId, loaded.getJobId());
        Assert.assertEquals(checkpointId, loaded.getCheckpointID());
    }

    /**
     * Loads a savepoint into a job whose task reports a max parallelism that differs from the
     * one stored in the savepoint and expects an {@link IllegalStateException} whose message
     * contains "Max parallelism mismatch".
     */
    @Test
    public void testMaxParallelismMismatch() throws Exception {
        final OperatorID operatorId = new OperatorID();
        final int parallelism = 128128;

        final CompletedCheckpointStorageLocation testSavepoint =
                createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
        // The task's max parallelism is off by one compared to the savepoint's.
        final Map<JobVertexID, ExecutionJobVertex> tasks =
                createTasks(operatorId, parallelism, parallelism + 1);

        try {
            Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
            Assert.fail("Did not throw expected Exception");
        } catch (IllegalStateException expected) {
            Assert.assertTrue(expected.getMessage().contains("Max parallelism mismatch"));
        }
    }

    /**
     * Loads a savepoint containing operator state into a job without any tasks, passing
     * {@code false} as the final flag (presumably "allowNonRestoredState", going by the expected
     * message), and expects an {@link IllegalStateException} mentioning "allowNonRestoredState".
     */
    @Test
    public void testNonRestoredStateWhenDisallowed() throws Exception {
        final OperatorID operatorId = new OperatorID();
        final int parallelism = 9;

        final CompletedCheckpointStorageLocation testSavepoint =
                createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
        final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();

        try {
            Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
            Assert.fail("Did not throw expected Exception");
        } catch (IllegalStateException expected) {
            Assert.assertTrue(expected.getMessage().contains("allowNonRestoredState"));
        }
    }

    /**
     * Loads a savepoint containing operator state into a job without any tasks, passing
     * {@code true} as the final flag, and verifies that loading succeeds with no operator states
     * in the resulting checkpoint.
     */
    @Test
    public void testNonRestoredStateWhenAllowed() throws Exception {
        final OperatorID operatorId = new OperatorID();
        final int parallelism = 9;

        final CompletedCheckpointStorageLocation testSavepoint =
                createSavepointWithOperatorSubtaskState(242L, operatorId, parallelism);
        final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();

        final CompletedCheckpoint loaded =
                Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, true);

        Assert.assertTrue(loaded.getOperatorStates().isEmpty());
    }

    /**
     * Loads a savepoint whose only operator state carries coordinator state (and no subtask
     * state) into a job without any tasks and expects an {@link IllegalStateException}
     * mentioning "allowNonRestoredState".
     */
    @Test
    public void testUnmatchedCoordinatorOnlyStateFails() throws Exception {
        final OperatorID operatorID = new OperatorID();
        final int maxParallelism = 1234;

        final OperatorState state =
                new OperatorState(operatorID, maxParallelism / 2, maxParallelism);
        state.setCoordinatorState(new ByteStreamStateHandle("coordinatorState", new byte[0]));

        final CompletedCheckpointStorageLocation testSavepoint =
                createSavepointWithOperatorState(42L, state);
        final Map<JobVertexID, ExecutionJobVertex> tasks = Collections.emptyMap();

        try {
            Checkpoints.loadAndValidateCheckpoint(new JobID(), tasks, testSavepoint, cl, false);
            Assert.fail("Did not throw expected Exception");
        } catch (IllegalStateException expected) {
            Assert.assertTrue(expected.getMessage().contains("allowNonRestoredState"));
        }
    }

    /**
     * Serializes metadata holding exactly the given operator state into an in-memory handle and
     * wraps it in a test storage location.
     *
     * @param checkpointId the checkpoint ID written into the metadata
     * @param state the single operator state stored in the metadata
     * @return a storage location backed by the serialized metadata bytes
     * @throws IOException if writing the metadata fails
     */
    private static CompletedCheckpointStorageLocation createSavepointWithOperatorState(
            long checkpointId, OperatorState state) throws IOException {
        final CheckpointMetadata savepoint =
                new CheckpointMetadata(
                        checkpointId, Collections.singletonList(state), Collections.emptyList());

        final ByteStreamStateHandle serializedMetadata;
        try (ByteArrayOutputStream os = new ByteArrayOutputStream()) {
            Checkpoints.storeCheckpointMetadata(savepoint, os);
            serializedMetadata = new ByteStreamStateHandle("checkpoint", os.toByteArray());
        }

        return new TestCompletedCheckpointStorageLocation(serializedMetadata, "dummy/pointer");
    }

    /**
     * Creates a savepoint for one operator that has subtask state registered for subtask 0. The
     * subtask state consists of a managed operator state handle, one input channel state handle
     * and one result subpartition state handle.
     *
     * @param checkpointId the checkpoint ID written into the metadata
     * @param operatorId the operator the state belongs to
     * @param parallelism used as both parallelism and max parallelism of the operator state
     * @return a storage location backed by the serialized metadata bytes
     * @throws IOException if writing the metadata fails
     */
    private static CompletedCheckpointStorageLocation createSavepointWithOperatorSubtaskState(
            long checkpointId, OperatorID operatorId, int parallelism) throws IOException {
        final Random rnd = new Random();

        // No StateObject up-casts here: the builder setters take collections typed to the
        // concrete handle types, so the element type must be inferred from the factory result.
        final OperatorSubtaskState subtaskState =
                OperatorSubtaskState.builder()
                        .setManagedOperatorState(
                                new OperatorStreamStateHandle(
                                        Collections.emptyMap(),
                                        new ByteStreamStateHandle("testHandler", new byte[0])))
                        .setInputChannelState(
                                StateObjectCollection.singleton(
                                        StateHandleDummyUtil.createNewInputChannelStateHandle(
                                                10, rnd)))
                        .setResultSubpartitionState(
                                StateObjectCollection.singleton(
                                        StateHandleDummyUtil.createNewResultSubpartitionStateHandle(
                                                10, rnd)))
                        .build();

        final OperatorState state = new OperatorState(operatorId, parallelism, parallelism);
        state.putState(0, subtaskState);

        return createSavepointWithOperatorState(checkpointId, state);
    }

    /**
     * Builds a task map containing one mocked {@link ExecutionJobVertex} whose vertex ID is
     * derived from the given operator ID.
     *
     * @param operatorId the operator reported by the mocked vertex
     * @param parallelism the parallelism reported by the mocked vertex
     * @param maxParallelism the max parallelism reported by the mocked vertex
     * @return a mutable map with the single mocked vertex keyed by its vertex ID
     */
    private static Map<JobVertexID, ExecutionJobVertex> createTasks(
            OperatorID operatorId, int parallelism, int maxParallelism) {
        final JobVertexID vertexId =
                new JobVertexID(operatorId.getLowerPart(), operatorId.getUpperPart());

        final ExecutionJobVertex vertex = Mockito.mock(ExecutionJobVertex.class);
        Mockito.when(vertex.getParallelism()).thenReturn(parallelism);
        Mockito.when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
        Mockito.when(vertex.getOperatorIDs())
                .thenReturn(
                        Collections.singletonList(OperatorIDPair.generatedIDOnly(operatorId)));

        // Only when the two values differ is the vertex stubbed to refuse max-parallelism
        // rescaling.
        if (parallelism != maxParallelism) {
            Mockito.when(vertex.canRescaleMaxParallelism(ArgumentMatchers.anyInt()))
                    .thenReturn(false);
        }

        final Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
        tasks.put(vertexId, vertex);
        return tasks;
    }
}

