/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.cartesianproduct;

import com.google.common.math.LongMath;
import com.google.common.primitives.Ints;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductCombination;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductEdgeManager;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerReal;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.apache.tez.runtime.library.utils.Grouper;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class FairCartesianProductVertexManager
extends CartesianProductVertexManagerReal {
    private static final Logger LOG = LoggerFactory.getLogger(FairCartesianProductVertexManager.class);
    private CartesianProductUserPayload.CartesianProductConfigProto config;
    private List<String> sourceList;
    private Map<String, Source> sourcesByName = new HashMap<String, Source>();
    private Map<String, SrcVertex> srcVerticesByName = new HashMap<String, SrcVertex>();
    private boolean enableGrouping;
    private int maxParallelism;
    private int numPartitions;
    private long minOpsPerWorker;
    private long minNumRecordForEstimation;
    private boolean vertexReconfigured = false;
    private boolean vertexStarted = false;
    private boolean vertexStartSchedule = false;
    private int numCPSrcNotInConfigureState = 0;
    private int numBroadcastSrcNotInRunningState = 0;
    private Queue<TaskAttemptIdentifier> completedSrcTaskToProcess = new LinkedList<TaskAttemptIdentifier>();
    private RoaringBitmap scheduledTasks = new RoaringBitmap();
    private int parallelism;
    private int[] numChunksPerSrc;
    private Grouper grouper = new Grouper();

    public FairCartesianProductVertexManager(VertexManagerPluginContext context) {
        super(context);
    }

    @Override
    public void initialize(CartesianProductUserPayload.CartesianProductConfigProto config) throws Exception {
        this.config = config;
        this.maxParallelism = config.hasMaxParallelism() ? config.getMaxParallelism() : 1000;
        this.enableGrouping = config.hasEnableGrouping() ? config.getEnableGrouping() : true;
        this.minOpsPerWorker = config.hasMinOpsPerWorker() ? config.getMinOpsPerWorker() : 1000000L;
        this.sourceList = config.getSourcesList();
        this.numPartitions = config.hasNumPartitionsForFairCase() ? config.getNumPartitionsForFairCase() : (int)Math.pow(this.maxParallelism, 1.0 / (double)this.sourceList.size());
        for (Map.Entry e : this.getContext().getInputVertexEdgeProperties().entrySet()) {
            if (((EdgeProperty)e.getValue()).getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM && ((EdgeProperty)e.getValue()).getEdgeManagerDescriptor().getClassName().equals(CartesianProductEdgeManager.class.getName())) {
                this.srcVerticesByName.put((String)e.getKey(), new SrcVertex());
                this.srcVerticesByName.get(e.getKey()).name = (String)e.getKey();
                this.getContext().registerForVertexStateUpdates((String)e.getKey(), EnumSet.of(VertexState.CONFIGURED));
                ++this.numCPSrcNotInConfigureState;
                continue;
            }
            this.getContext().registerForVertexStateUpdates((String)e.getKey(), EnumSet.of(VertexState.RUNNING));
            ++this.numBroadcastSrcNotInRunningState;
        }
        Map srcGroups = this.getContext().getInputVertexGroups();
        for (int i = 0; i < this.sourceList.size(); ++i) {
            String srcName = this.sourceList.get(i);
            Source source = new Source();
            source.position = i;
            if (srcGroups.containsKey(srcName)) {
                source.name = srcName;
                for (String srcVName : (List)srcGroups.get(srcName)) {
                    source.srcVertices.add(this.srcVerticesByName.get(srcVName));
                    this.srcVerticesByName.get((Object)srcVName).source = source;
                }
            } else {
                source.name = srcName;
                source.srcVertices.add(this.srcVerticesByName.get(srcName));
                this.srcVerticesByName.get((Object)srcName).source = source;
            }
            this.sourcesByName.put(srcName, source);
        }
        this.minNumRecordForEstimation = (long)Math.pow(this.minOpsPerWorker * (long)this.maxParallelism, 1.0 / (double)this.sourceList.size());
        this.numChunksPerSrc = new int[this.sourcesByName.size()];
        this.getContext().vertexReconfigurationPlanned();
    }

    @Override
    public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions) throws Exception {
        this.vertexStarted = true;
        if (completions != null) {
            LOG.info("OnVertexStarted with " + completions.size() + " completed source task");
            for (TaskAttemptIdentifier attempt : completions) {
                this.addCompletedSrcTaskToProcess(attempt);
            }
        }
        this.tryScheduleTasks();
    }

    @Override
    public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws IOException {
        String vertex = stateUpdate.getVertexName();
        VertexState state = stateUpdate.getVertexState();
        if (state == VertexState.CONFIGURED) {
            this.srcVerticesByName.get((Object)vertex).numTask = this.getContext().getVertexNumTasks(vertex);
            --this.numCPSrcNotInConfigureState;
        } else if (state == VertexState.RUNNING) {
            --this.numBroadcastSrcNotInRunningState;
        }
        this.tryScheduleTasks();
    }

    @Override
    public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) throws Exception {
        this.addCompletedSrcTaskToProcess(attempt);
        this.tryScheduleTasks();
    }

    private void addCompletedSrcTaskToProcess(TaskAttemptIdentifier attempt) {
        int taskId = attempt.getTaskIdentifier().getIdentifier();
        String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
        SrcVertex srcV = this.srcVerticesByName.get(vertex);
        if (srcV != null && !srcV.taskCompleted.contains(taskId)) {
            srcV.taskCompleted.add(taskId);
            this.completedSrcTaskToProcess.add(attempt);
        }
    }

    private boolean tryStartSchedule() {
        this.vertexStartSchedule = this.vertexReconfigured && this.vertexStarted && this.numBroadcastSrcNotInRunningState == 0;
        return this.vertexStartSchedule;
    }

    @Override
    public synchronized void onVertexManagerEventReceived(VertexManagerEvent vmEvent) throws IOException {
        if (this.vertexReconfigured) {
            return;
        }
        if (vmEvent.getUserPayload() != null) {
            String srcVertex = vmEvent.getProducerAttemptIdentifier().getTaskIdentifier().getVertexIdentifier().getName();
            SrcVertex srcV = this.srcVerticesByName.get(srcVertex);
            if (srcV == null) {
                return;
            }
            ShuffleUserPayloads.VertexManagerEventPayloadProto proto = ShuffleUserPayloads.VertexManagerEventPayloadProto.parseFrom(ByteString.copyFrom((ByteBuffer)vmEvent.getUserPayload()));
            srcV.numRecord += proto.getNumRecord();
            srcV.taskWithVMEvent.add(vmEvent.getProducerAttemptIdentifier().getTaskIdentifier().getIdentifier());
        }
        this.tryScheduleTasks();
    }

    private void reconfigureWithZeroTask() {
        this.getContext().reconfigureVertex(0, null, null);
        this.vertexReconfigured = true;
        this.getContext().doneReconfiguringVertex();
    }

    private boolean tryReconfigure() throws IOException {
        if (this.numCPSrcNotInConfigureState > 0) {
            return false;
        }
        for (Source src : this.sourcesByName.values()) {
            if (src.getNumTask() != 0) continue;
            this.parallelism = 0;
            this.reconfigureWithZeroTask();
            return true;
        }
        if (this.config.hasGroupingFraction() && this.config.getGroupingFraction() > 0.0f) {
            for (SrcVertex srcV : this.srcVerticesByName.values()) {
                if (srcV.taskCompleted.getCardinality() >= srcV.numTask || !((float)srcV.numTask * this.config.getGroupingFraction() > (float)srcV.taskCompleted.getCardinality()) && srcV.numRecord != 0L) continue;
                return false;
            }
        } else {
            for (SrcVertex srcV : this.srcVerticesByName.values()) {
                if (srcV.numRecord >= this.minNumRecordForEstimation || srcV.taskWithVMEvent.getCardinality() >= srcV.numTask) continue;
                return false;
            }
        }
        LOG.info("Start reconfiguring vertex " + this.getContext().getVertexName() + ", max parallelism: " + this.maxParallelism + ", min-ops-per-worker: " + this.minOpsPerWorker + ", num partition: " + this.numPartitions);
        for (Source src : this.sourcesByName.values()) {
            LOG.info(src.toString());
        }
        long totalOps = 1L;
        for (Source src : this.sourcesByName.values()) {
            src.numRecord = src.estimateNumRecord();
            if (src.numRecord == 0L) {
                LOG.info("Set parallelism to 0 because source " + src.name + " has 0 output recorc");
                this.reconfigureWithZeroTask();
                return true;
            }
            try {
                totalOps = LongMath.checkedMultiply((long)totalOps, (long)src.numRecord);
            }
            catch (ArithmeticException e) {
                LOG.info("totalOps exceeds 9223372036854775807, capping to 9223372036854775807");
                totalOps = Long.MAX_VALUE;
            }
        }
        this.parallelism = totalOps / this.minOpsPerWorker >= (long)this.maxParallelism ? this.maxParallelism : (int)((totalOps + this.minOpsPerWorker - 1L) / this.minOpsPerWorker);
        LOG.info("Total ops " + totalOps + ", initial parallelism " + this.parallelism);
        if (this.enableGrouping) {
            this.determineNumChunks(this.sourcesByName, this.parallelism);
        } else {
            for (Source src : this.sourcesByName.values()) {
                src.numChunk = src.getSrcVertexWithMostOutput().numTask;
            }
        }
        this.parallelism = 1;
        for (Source src : this.sourcesByName.values()) {
            this.parallelism *= src.numChunk;
        }
        LOG.info("After reconfigure, final parallelism " + this.parallelism);
        for (Source src : this.sourcesByName.values()) {
            LOG.info(src.toString());
        }
        for (int i = 0; i < this.numChunksPerSrc.length; ++i) {
            this.numChunksPerSrc[i] = this.sourcesByName.get((Object)this.sourceList.get((int)i)).numChunk;
        }
        CartesianProductUserPayload.CartesianProductConfigProto.Builder builder = CartesianProductUserPayload.CartesianProductConfigProto.newBuilder(this.config);
        builder.addAllNumChunks(Ints.asList((int[])this.numChunksPerSrc));
        Map edgeProperties = this.getContext().getInputVertexEdgeProperties();
        Iterator iter = edgeProperties.entrySet().iterator();
        while (iter.hasNext()) {
            Map.Entry e = iter.next();
            if (((EdgeProperty)e.getValue()).getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM) continue;
            iter.remove();
        }
        for (Source src : this.sourcesByName.values()) {
            builder.clearNumTaskPerVertexInGroup();
            for (int i = 0; i < src.srcVertices.size(); ++i) {
                SrcVertex srcV = src.srcVertices.get(i);
                builder.setPositionInGroup(i);
                ((EdgeProperty)edgeProperties.get(srcV.name)).getEdgeManagerDescriptor().setUserPayload(UserPayload.create((ByteBuffer)ByteBuffer.wrap(builder.build().toByteArray())));
                builder.addNumTaskPerVertexInGroup(srcV.numTask);
            }
        }
        this.getContext().reconfigureVertex(this.parallelism, null, edgeProperties);
        this.vertexReconfigured = true;
        this.getContext().doneReconfiguringVertex();
        return true;
    }

    private void determineNumChunks(Map<String, Source> sourcesByName, int parallelism) {
        double k = Math.log10(parallelism);
        for (Source src : sourcesByName.values()) {
            k -= Math.log10(src.numRecord);
        }
        k = Math.pow(10.0, k / (double)sourcesByName.size());
        for (Source src : sourcesByName.values()) {
            if (!((double)src.numRecord * k < 2.0)) continue;
            src.numChunk = 1;
        }
        k = Math.log10(parallelism);
        int numLargeSrc = 0;
        for (Source src : sourcesByName.values()) {
            if (src.numChunk == 1) continue;
            k -= Math.log10(src.numRecord);
            ++numLargeSrc;
        }
        k = Math.pow(10.0, k / (double)numLargeSrc);
        for (Source src : sourcesByName.values()) {
            if (src.numChunk == 1) continue;
            src.numChunk = Math.min(this.maxParallelism, Math.min(src.getSrcVertexWithMostOutput().numTask * this.numPartitions, Math.max(1, (int)((double)src.numRecord * k))));
        }
    }

    private void tryScheduleTasks() throws IOException {
        if (!this.vertexReconfigured && !this.tryReconfigure()) {
            return;
        }
        if (!this.vertexStartSchedule && !this.tryStartSchedule()) {
            return;
        }
        while (!this.completedSrcTaskToProcess.isEmpty()) {
            this.scheduleTasksDependOnCompletion(this.completedSrcTaskToProcess.poll());
        }
    }

    private void scheduleTasksDependOnCompletion(TaskAttemptIdentifier attempt) {
        if (this.parallelism == 0) {
            return;
        }
        int taskId = attempt.getTaskIdentifier().getIdentifier();
        String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
        SrcVertex srcV = this.srcVerticesByName.get(vertex);
        Source src = srcV.source;
        ArrayList<VertexManagerPluginContext.ScheduleTaskRequest> requests = new ArrayList<VertexManagerPluginContext.ScheduleTaskRequest>();
        CartesianProductCombination combination = new CartesianProductCombination(this.numChunksPerSrc, src.position);
        this.grouper.init(srcV.numTask * this.numPartitions, src.numChunk);
        int firstRelevantChunk = this.grouper.getGroupId(taskId * this.numPartitions);
        int lastRelevantChunk = this.grouper.getGroupId(taskId * this.numPartitions + this.numPartitions - 1);
        for (int chunkId = firstRelevantChunk; chunkId <= lastRelevantChunk; ++chunkId) {
            combination.firstTaskWithFixedChunk(chunkId);
            do {
                List<Integer> list = combination.getCombination();
                if (this.scheduledTasks.contains(combination.getTaskId())) continue;
                boolean readyToSchedule = src.isChunkCompleted(list.get(src.position));
                for (int srcId = 0; readyToSchedule && srcId < list.size(); ++srcId) {
                    if (srcId == src.position) continue;
                    readyToSchedule = this.sourcesByName.get(this.sourceList.get(srcId)).isChunkCompleted(list.get(srcId));
                }
                if (!readyToSchedule) continue;
                requests.add(VertexManagerPluginContext.ScheduleTaskRequest.create((int)combination.getTaskId(), null));
                this.scheduledTasks.add(combination.getTaskId());
            } while (combination.nextTaskWithFixedChunk());
        }
        if (!requests.isEmpty()) {
            this.getContext().scheduleTasks(requests);
        }
    }

    class SrcVertex {
        Source source;
        String name;
        int numTask;
        RoaringBitmap taskCompleted = new RoaringBitmap();
        RoaringBitmap taskWithVMEvent = new RoaringBitmap();
        long numRecord;

        SrcVertex() {
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("vertex ").append(this.name).append(", ");
            sb.append(this.numTask).append(" tasks, ");
            sb.append(this.taskWithVMEvent.getCardinality()).append(" VMEvents, ");
            sb.append("numRecord ").append(this.numRecord).append(", ");
            sb.append("estimated # output records ").append(this.estimateNumRecord());
            return sb.toString();
        }

        public long estimateNumRecord() {
            if (this.taskWithVMEvent.isEmpty()) {
                return 0L;
            }
            return this.numRecord * (long)this.numTask / (long)this.taskWithVMEvent.getCardinality();
        }

        public boolean isChunkCompleted(int chunkId) {
            FairCartesianProductVertexManager.this.grouper.init(this.numTask * FairCartesianProductVertexManager.this.numPartitions, this.source.numChunk);
            int firstRelevantTask = FairCartesianProductVertexManager.this.grouper.getFirstItemInGroup(chunkId) / FairCartesianProductVertexManager.this.numPartitions;
            int lastRelevantTask = FairCartesianProductVertexManager.this.grouper.getLastItemInGroup(chunkId) / FairCartesianProductVertexManager.this.numPartitions;
            for (int relevantTask = firstRelevantTask; relevantTask <= lastRelevantTask; ++relevantTask) {
                if (this.taskCompleted.contains(relevantTask)) continue;
                return false;
            }
            return true;
        }
    }

    static class Source {
        List<SrcVertex> srcVertices = new ArrayList<SrcVertex>();
        int position;
        String name;
        int numChunk;
        long numRecord;

        Source() {
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Source at position ");
            sb.append(this.position);
            if (this.name != null) {
                sb.append(", ");
                sb.append("name ");
                sb.append(this.name);
            }
            sb.append(", num chunk ").append(this.numChunk);
            sb.append(": {");
            for (SrcVertex srcV : this.srcVertices) {
                sb.append("[");
                sb.append(srcV.toString());
                sb.append("], ");
            }
            sb.deleteCharAt(sb.length() - 1);
            sb.setCharAt(sb.length() - 1, '}');
            return sb.toString();
        }

        public long estimateNumRecord() {
            long estimation = 0L;
            for (SrcVertex srcV : this.srcVertices) {
                estimation += srcV.estimateNumRecord();
            }
            return estimation;
        }

        private boolean isChunkCompleted(int chunkId) {
            for (SrcVertex srcV : this.srcVertices) {
                if (srcV.isChunkCompleted(chunkId)) continue;
                return false;
            }
            return true;
        }

        public int getNumTask() {
            int numTask = 0;
            for (SrcVertex srcV : this.srcVertices) {
                numTask += srcV.numTask;
            }
            return numTask;
        }

        public SrcVertex getSrcVertexWithMostOutput() {
            SrcVertex srcVWithMaxOutput = null;
            for (SrcVertex srcV : this.srcVertices) {
                if (srcVWithMaxOutput != null && srcV.numRecord <= srcVWithMaxOutput.numRecord) continue;
                srcVWithMaxOutput = srcV;
            }
            return srcVWithMaxOutput;
        }
    }
}

