package org.apache.tez.dag.api.client.rpc;

import com.google.protobuf.RpcController;
import com.google.protobuf.ServiceException;
import java.io.IOException;
import java.util.EnumSet;
import java.util.Set;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.tez.client.FrameworkClient;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.client.DAGClient;
import org.apache.tez.dag.api.client.DAGClientImpl;
import org.apache.tez.dag.api.client.DAGStatus;
import org.apache.tez.dag.api.client.StatusGetOpts;
import org.apache.tez.dag.api.client.VertexStatus;
import org.apache.tez.dag.api.client.rpc.DAGClientAMProtocolRPC;
import org.apache.tez.dag.api.records.DAGProtos;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatcher;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.internal.util.collections.Sets;

/* loaded from: input_file:org/apache/tez/dag/api/client/rpc/TestDAGClient.class */
public class TestDAGClient {
    private DAGClient dagClient;
    private ApplicationId mockAppId;
    private ApplicationReport mockAppReport;
    private String dagIdStr;
    private DAGClientAMProtocolBlockingPB mockProxy;
    private DAGProtos.VertexStatusProto vertexStatusProtoWithoutCounters;
    private DAGProtos.VertexStatusProto vertexStatusProtoWithCounters;
    private DAGProtos.DAGStatusProto dagStatusProtoWithoutCounters;
    private DAGProtos.DAGStatusProto dagStatusProtoWithCounters;

    /* loaded from: input_file:org/apache/tez/dag/api/client/rpc/TestDAGClient$DAGCounterRequestMatcher.class */
    private static class DAGCounterRequestMatcher extends ArgumentMatcher<DAGClientAMProtocolRPC.GetDAGStatusRequestProto> {
        private DAGCounterRequestMatcher() {
        }

        public boolean matches(Object obj) {
            if (!(obj instanceof DAGClientAMProtocolRPC.GetDAGStatusRequestProto)) {
                return false;
            }
            DAGClientAMProtocolRPC.GetDAGStatusRequestProto getDAGStatusRequestProto = (DAGClientAMProtocolRPC.GetDAGStatusRequestProto) obj;
            return getDAGStatusRequestProto.getStatusOptionsCount() != 0 && getDAGStatusRequestProto.getStatusOptionsList().get(0) == DAGProtos.StatusGetOptsProto.GET_COUNTERS;
        }
    }

    /* loaded from: input_file:org/apache/tez/dag/api/client/rpc/TestDAGClient$VertexCounterRequestMatcher.class */
    private static class VertexCounterRequestMatcher extends ArgumentMatcher<DAGClientAMProtocolRPC.GetVertexStatusRequestProto> {
        private VertexCounterRequestMatcher() {
        }

        public boolean matches(Object obj) {
            if (!(obj instanceof DAGClientAMProtocolRPC.GetVertexStatusRequestProto)) {
                return false;
            }
            DAGClientAMProtocolRPC.GetVertexStatusRequestProto getVertexStatusRequestProto = (DAGClientAMProtocolRPC.GetVertexStatusRequestProto) obj;
            return getVertexStatusRequestProto.getStatusOptionsCount() != 0 && getVertexStatusRequestProto.getStatusOptionsList().get(0) == DAGProtos.StatusGetOptsProto.GET_COUNTERS;
        }
    }

    private void setUpData() {
        DAGProtos.ProgressProto build = DAGProtos.ProgressProto.newBuilder().setFailedTaskCount(1).setKilledTaskCount(1).setRunningTaskCount(2).setSucceededTaskCount(2).setTotalTaskCount(6).build();
        DAGProtos.TezCountersProto build2 = DAGProtos.TezCountersProto.newBuilder().addCounterGroups(DAGProtos.TezCounterGroupProto.newBuilder().setName("DAGGroup").addCounters(DAGProtos.TezCounterProto.newBuilder().setDisplayName("dag_counter_1").setValue(99L))).build();
        this.dagStatusProtoWithoutCounters = DAGProtos.DAGStatusProto.newBuilder().addDiagnostics("Diagnostics_0").setState(DAGProtos.DAGStatusStateProto.DAG_RUNNING).setDAGProgress(build).addVertexProgress(DAGProtos.StringProgressPairProto.newBuilder().setKey("v1").setProgress(DAGProtos.ProgressProto.newBuilder().setFailedTaskCount(0).setSucceededTaskCount(0).setKilledTaskCount(0))).addVertexProgress(DAGProtos.StringProgressPairProto.newBuilder().setKey("v2").setProgress(DAGProtos.ProgressProto.newBuilder().setFailedTaskCount(1).setSucceededTaskCount(1).setKilledTaskCount(1))).build();
        this.dagStatusProtoWithCounters = DAGProtos.DAGStatusProto.newBuilder(this.dagStatusProtoWithoutCounters).setDagCounters(build2).build();
        DAGProtos.ProgressProto build3 = DAGProtos.ProgressProto.newBuilder().setFailedTaskCount(1).setKilledTaskCount(0).setRunningTaskCount(0).setSucceededTaskCount(1).build();
        DAGProtos.TezCountersProto build4 = DAGProtos.TezCountersProto.newBuilder().addCounterGroups(DAGProtos.TezCounterGroupProto.newBuilder().addCounters(DAGProtos.TezCounterProto.newBuilder().setDisplayName("vertex_counter_1").setValue(99L))).build();
        this.vertexStatusProtoWithoutCounters = DAGProtos.VertexStatusProto.newBuilder().addDiagnostics("V_Diagnostics_0").setProgress(build3).setState(DAGProtos.VertexStatusStateProto.VERTEX_SUCCEEDED).build();
        this.vertexStatusProtoWithCounters = DAGProtos.VertexStatusProto.newBuilder(this.vertexStatusProtoWithoutCounters).setVertexCounters(build4).build();
    }

    @Before
    public void setUp() throws YarnException, IOException, TezException, ServiceException {
        setUpData();
        this.mockAppId = (ApplicationId) Mockito.mock(ApplicationId.class);
        this.mockAppReport = (ApplicationReport) Mockito.mock(ApplicationReport.class);
        this.dagIdStr = "dag_9999_0001_1";
        this.mockProxy = (DAGClientAMProtocolBlockingPB) Mockito.mock(DAGClientAMProtocolBlockingPB.class);
        Mockito.when(this.mockProxy.getDAGStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetDAGStatusRequestProto) Matchers.any(DAGClientAMProtocolRPC.GetDAGStatusRequestProto.class))).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithoutCounters).build());
        Mockito.when(this.mockProxy.getDAGStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetDAGStatusRequestProto) Matchers.argThat(new DAGCounterRequestMatcher()))).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithCounters).build());
        Mockito.when(this.mockProxy.getVertexStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetVertexStatusRequestProto) Matchers.any(DAGClientAMProtocolRPC.GetVertexStatusRequestProto.class))).thenReturn(DAGClientAMProtocolRPC.GetVertexStatusResponseProto.newBuilder().setVertexStatus(this.vertexStatusProtoWithoutCounters).build());
        Mockito.when(this.mockProxy.getVertexStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetVertexStatusRequestProto) Matchers.argThat(new VertexCounterRequestMatcher()))).thenReturn(DAGClientAMProtocolRPC.GetVertexStatusResponseProto.newBuilder().setVertexStatus(this.vertexStatusProtoWithCounters).build());
        this.dagClient = new DAGClientImpl(this.mockAppId, this.dagIdStr, new TezConfiguration(), (FrameworkClient) null);
        DAGClientRPCImpl realClient = this.dagClient.getRealClient();
        realClient.appReport = this.mockAppReport;
        realClient.proxy = this.mockProxy;
    }

    @Test
    public void testApp() throws IOException, TezException, ServiceException {
        Assert.assertTrue(this.dagClient.getExecutionContext().contains(this.mockAppId.toString()));
        Assert.assertEquals(this.mockAppReport, this.dagClient.getRealClient().getApplicationReportInternal());
    }

    @Test
    public void testDAGStatus() throws Exception {
        DAGStatus dAGStatus = this.dagClient.getDAGStatus((Set) null);
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy, Mockito.times(1))).getDAGStatus((RpcController) null, DAGClientAMProtocolRPC.GetDAGStatusRequestProto.newBuilder().setDagId(this.dagIdStr).build());
        Assert.assertEquals(new DAGStatus(this.dagStatusProtoWithoutCounters), dAGStatus);
        System.out.println("DAGStatusWithoutCounter:" + dAGStatus);
        DAGStatus dAGStatus2 = this.dagClient.getDAGStatus(Sets.newSet(new StatusGetOpts[]{StatusGetOpts.GET_COUNTERS}));
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy, Mockito.times(1))).getDAGStatus((RpcController) null, DAGClientAMProtocolRPC.GetDAGStatusRequestProto.newBuilder().setDagId(this.dagIdStr).addStatusOptions(DAGProtos.StatusGetOptsProto.GET_COUNTERS).build());
        Assert.assertEquals(new DAGStatus(this.dagStatusProtoWithCounters), dAGStatus2);
        System.out.println("DAGStatusWithCounter:" + dAGStatus2);
    }

    @Test
    public void testVertexStatus() throws Exception {
        VertexStatus vertexStatus = this.dagClient.getVertexStatus("v1", (Set) null);
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy)).getVertexStatus((RpcController) null, DAGClientAMProtocolRPC.GetVertexStatusRequestProto.newBuilder().setDagId(this.dagIdStr).setVertexName("v1").build());
        Assert.assertEquals(new VertexStatus(this.vertexStatusProtoWithoutCounters), vertexStatus);
        System.out.println("VertexWithoutCounter:" + vertexStatus);
        VertexStatus vertexStatus2 = this.dagClient.getVertexStatus("v1", Sets.newSet(new StatusGetOpts[]{StatusGetOpts.GET_COUNTERS}));
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy)).getVertexStatus((RpcController) null, DAGClientAMProtocolRPC.GetVertexStatusRequestProto.newBuilder().setDagId(this.dagIdStr).setVertexName("v1").addStatusOptions(DAGProtos.StatusGetOptsProto.GET_COUNTERS).build());
        Assert.assertEquals(new VertexStatus(this.vertexStatusProtoWithCounters), vertexStatus2);
        System.out.println("VertexWithCounter:" + vertexStatus2);
    }

    @Test
    public void testTryKillDAG() throws Exception {
        this.dagClient.tryKillDAG();
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy, Mockito.times(1))).tryKillDAG((RpcController) null, DAGClientAMProtocolRPC.TryKillDAGRequestProto.newBuilder().setDagId(this.dagIdStr).build());
    }

    @Test
    public void testWaitForCompletion() throws Exception {
        Mockito.when(this.mockProxy.getDAGStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetDAGStatusRequestProto) Matchers.any(DAGClientAMProtocolRPC.GetDAGStatusRequestProto.class))).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithoutCounters).build()).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(DAGProtos.DAGStatusProto.newBuilder(this.dagStatusProtoWithoutCounters).setState(DAGProtos.DAGStatusStateProto.DAG_SUCCEEDED).build()).build());
        this.dagClient.waitForCompletion();
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy, Mockito.times(2))).getDAGStatus((RpcController) null, DAGClientAMProtocolRPC.GetDAGStatusRequestProto.newBuilder().setDagId(this.dagIdStr).build());
    }

    @Test
    public void testWaitForCompletionWithStatusUpdates() throws Exception {
        Mockito.when(this.mockProxy.getDAGStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetDAGStatusRequestProto) Matchers.any(DAGClientAMProtocolRPC.GetDAGStatusRequestProto.class))).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithoutCounters).build()).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithoutCounters).build()).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(DAGProtos.DAGStatusProto.newBuilder(this.dagStatusProtoWithoutCounters).setState(DAGProtos.DAGStatusStateProto.DAG_SUCCEEDED).build()).build());
        this.dagClient.waitForCompletionWithStatusUpdates((Set) null);
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy, Mockito.times(3))).getDAGStatus((RpcController) null, DAGClientAMProtocolRPC.GetDAGStatusRequestProto.newBuilder().setDagId(this.dagIdStr).build());
        Mockito.when(this.mockProxy.getDAGStatus((RpcController) Matchers.isNull(RpcController.class), (DAGClientAMProtocolRPC.GetDAGStatusRequestProto) Matchers.any(DAGClientAMProtocolRPC.GetDAGStatusRequestProto.class))).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithCounters).build()).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(this.dagStatusProtoWithCounters).build()).thenReturn(DAGClientAMProtocolRPC.GetDAGStatusResponseProto.newBuilder().setDagStatus(DAGProtos.DAGStatusProto.newBuilder(this.dagStatusProtoWithCounters).setState(DAGProtos.DAGStatusStateProto.DAG_SUCCEEDED).build()).build());
        this.dagClient.waitForCompletionWithStatusUpdates(Sets.newSet(new StatusGetOpts[]{StatusGetOpts.GET_COUNTERS}));
        ((DAGClientAMProtocolBlockingPB) Mockito.verify(this.mockProxy, Mockito.times(3))).getDAGStatus((RpcController) null, DAGClientAMProtocolRPC.GetDAGStatusRequestProto.newBuilder().setDagId(this.dagIdStr).addStatusOptions(DAGProtos.StatusGetOptsProto.GET_COUNTERS).build());
    }
}
