/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.workload;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.RewriteCompressedReblock;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.compress.workload.WTreeNode;

/**
 * Static helpers that walk a {@link DMLProgram} to (1) collect hops that satisfy the
 * compression condition and (2) build, per candidate, a {@link WTreeNode} tree that mirrors
 * the program's control structure and records the hops consuming compressed intermediates.
 *
 * <p>All traversals use the hop "visited" flag for memoization within a DAG and reset it
 * before and after each DAG walk.
 */
public class WorkloadAnalyzer {
    /**
     * Builds a pruned workload tree for every compression candidate of the program.
     *
     * @param prog the program to analyze
     * @return map from candidate hop ID to its (pruned) workload tree; the root node of each
     *         tree is kept even if pruning leaves it empty
     */
    public static Map<Long, WTreeNode> getAllCandidateWorkloads(DMLProgram prog) {
        Map<Long, WTreeNode> workloads = new HashMap<>();
        for (Hop cand : getCandidates(prog)) {
            WTreeNode tree = createWorkloadTree(prog, cand);
            pruneWorkloadTree(tree);
            workloads.put(cand.getHopID(), tree);
        }
        return workloads;
    }

    /**
     * Collects all hops of the program that satisfy
     * {@link RewriteCompressedReblock#satisfiesCompressionCondition(Hop)}, following function
     * calls reachable from the top-level statement blocks.
     *
     * @param prog the program to scan
     * @return the candidate hops in traversal order
     */
    public static List<Hop> getCandidates(DMLProgram prog) {
        List<Hop> candidates = new ArrayList<>();
        for (StatementBlock sb : prog.getStatementBlocks()) {
            // a fresh function stack per top-level block, as recursion tracking is per call chain
            getCandidates(sb, prog, candidates, new HashSet<String>());
        }
        return candidates;
    }

    /**
     * Creates the (unpruned) workload tree of a single candidate.
     *
     * @param prog the program to analyze
     * @param candidate the hop whose variable name seeds the set of compressed variables
     * @return the MAIN node with one child per top-level statement block
     */
    public static WTreeNode createWorkloadTree(DMLProgram prog, Hop candidate) {
        WTreeNode main = new WTreeNode(WTreeNode.WTNodeType.MAIN);
        // the set of compressed variable names is shared across top-level blocks so that
        // updates made by one block are visible to the following ones
        Set<String> compressed = new HashSet<>();
        compressed.add(candidate.getName());
        for (StatementBlock sb : prog.getStatementBlocks()) {
            main.addChild(createWorkloadTree(sb, prog, compressed, new HashSet<String>()));
        }
        return main;
    }

    /**
     * Recursively removes child nodes that have neither children nor compressed operations.
     *
     * @param node the root of the (sub)tree to prune; its child list is modified in place
     * @return true if the node itself is empty after pruning (no children, no compressed ops)
     */
    public static boolean pruneWorkloadTree(WTreeNode node) {
        Iterator<WTreeNode> iter = node.getChildNodes().iterator();
        while (iter.hasNext()) {
            if (pruneWorkloadTree(iter.next())) {
                iter.remove();
            }
        }
        return node.getChildNodes().isEmpty() && node.getCompressedOps().isEmpty();
    }

    /**
     * Collects candidates from a statement block, descending into the bodies of function,
     * while, if/else and for blocks and into the hop DAGs of all remaining blocks.
     *
     * @param sb the statement block to scan
     * @param prog the enclosing program, used to resolve function calls
     * @param cands output list receiving the candidates
     * @param fStack keys of the functions currently being traversed (recursion guard)
     */
    private static void getCandidates(StatementBlock sb, DMLProgram prog, List<Hop> cands, Set<String> fStack) {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                getCandidates(csb, prog, cands, fStack);
            }
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
            for (StatementBlock csb : wstmt.getBody()) {
                getCandidates(csb, prog, cands, fStack);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
            for (StatementBlock csb : istmt.getIfBody()) {
                getCandidates(csb, prog, cands, fStack);
            }
            for (StatementBlock csb : istmt.getElseBody()) {
                getCandidates(csb, prog, cands, fStack);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                getCandidates(csb, prog, cands, fStack);
            }
        } else {
            if (sb.getHops() == null) {
                return;
            }
            Hop.resetVisitStatus(sb.getHops());
            for (Hop hop : sb.getHops()) {
                getCandidates(hop, prog, cands, fStack);
            }
            Hop.resetVisitStatus(sb.getHops());
        }
    }

    /**
     * Collects candidates from a hop DAG in a memoized traversal and follows function calls
     * that are not already on the function stack.
     *
     * @param hop the current hop
     * @param prog the enclosing program, used to resolve function calls
     * @param cands output list receiving the candidates
     * @param fStack keys of the functions currently being traversed (recursion guard)
     */
    private static void getCandidates(Hop hop, DMLProgram prog, List<Hop> cands, Set<String> fStack) {
        if (hop.isVisited()) {
            return;
        }
        if (RewriteCompressedReblock.satisfiesCompressionCondition(hop)) {
            cands.add(hop);
        }
        for (Hop c : hop.getInput()) {
            getCandidates(c, prog, cands, fStack);
        }
        if (hop instanceof FunctionOp) {
            String fkey = ((FunctionOp) hop).getFunctionKey();
            if (!fStack.contains(fkey)) {
                FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
                // Skip calls without a resolvable statement block instead of failing with a
                // NullPointerException in the block-level traversal.
                // NOTE(review): presumably this happens for functions without a DML body
                // (e.g. built-in multi-return functions) - confirm against DMLProgram.
                if (fsb != null) {
                    fStack.add(fkey);
                    getCandidates(fsb, prog, cands, fStack);
                    fStack.remove(fkey);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Creates the workload (sub)tree of a statement block. Control-flow blocks yield a node of
     * the matching type whose children are the trees of the nested blocks; predicate and loop
     * bound DAGs contribute compressed operations directly to that node. All other blocks
     * yield a BASIC_BLOCK node, with one FCALL child per resolvable, non-recursive function
     * call among the block's root hops.
     *
     * @param sb the statement block to analyze
     * @param prog the enclosing program, used to resolve function calls
     * @param compressed names of the variables currently known to be compressed; updated in
     *        place according to the transient writes of basic blocks
     * @param fStack keys of the functions currently being traversed (recursion guard)
     * @return the workload node of the block, with the block's line numbers set
     */
    private static WTreeNode createWorkloadTree(StatementBlock sb, DMLProgram prog, Set<String> compressed, Set<String> fStack) {
        WTreeNode node;
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
            node = new WTreeNode(WTreeNode.WTNodeType.FCALL);
            for (StatementBlock csb : fstmt.getBody()) {
                node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
            }
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            node = new WTreeNode(WTreeNode.WTNodeType.WHILE);
            createWorkloadTree(wsb.getPredicateHops(), prog, node, compressed, fStack);
            for (StatementBlock csb : wstmt.getBody()) {
                node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            node = new WTreeNode(WTreeNode.WTNodeType.IF);
            createWorkloadTree(isb.getPredicateHops(), prog, node, compressed, fStack);
            for (StatementBlock csb : istmt.getIfBody()) {
                node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
            }
            for (StatementBlock csb : istmt.getElseBody()) {
                node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            WTreeNode.WTNodeType type = sb instanceof ParForStatementBlock
                ? WTreeNode.WTNodeType.PARFOR : WTreeNode.WTNodeType.FOR;
            node = new WTreeNode(type);
            createWorkloadTree(fsb.getFromHops(), prog, node, compressed, fStack);
            createWorkloadTree(fsb.getToHops(), prog, node, compressed, fStack);
            createWorkloadTree(fsb.getIncrementHops(), prog, node, compressed, fStack);
            for (StatementBlock csb : fstmt.getBody()) {
                node.addChild(createWorkloadTree(csb, prog, compressed, fStack));
            }
        } else {
            node = new WTreeNode(WTreeNode.WTNodeType.BASIC_BLOCK);
            if (sb.getHops() != null) {
                Hop.resetVisitStatus(sb.getHops());
                // IDs of the hops in this DAG whose output is considered compressed
                Set<Long> compressed2 = new HashSet<>();
                for (Hop hop : sb.getHops()) {
                    createWorkloadTree(hop, prog, node, compressed, compressed2, fStack);
                }
                // second pass over the DAG roots: descend into function calls and
                // propagate the compression status to the written variables
                for (Hop hop : sb.getHops()) {
                    if (hop instanceof FunctionOp) {
                        addFunctionCallWorkload((FunctionOp) hop, prog, node, compressed2, fStack);
                    } else if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)) {
                        // the variable is compressed after this block iff the written
                        // value is compressed (add/remove are no-ops if already in sync)
                        if (compressed2.contains(hop.getHopID())) {
                            compressed.add(hop.getName());
                        } else {
                            compressed.remove(hop.getName());
                        }
                    }
                }
                Hop.resetVisitStatus(sb.getHops());
            }
        }
        node.setLineNumbers(sb.getBeginLine(), sb.getEndLine());
        return node;
    }

    /**
     * Adds the workload tree of a called function as a child of the given node, unless the
     * function is already on the function stack or has no statement block in the program.
     * Function parameters are marked as compressed by position, according to which call
     * inputs are compressed.
     *
     * @param fop the function call
     * @param prog the enclosing program, used to resolve the function
     * @param parent the node receiving the function's workload tree
     * @param compressed2 IDs of the hops in the calling DAG whose output is compressed
     * @param fStack keys of the functions currently being traversed (recursion guard)
     */
    private static void addFunctionCallWorkload(FunctionOp fop, DMLProgram prog, WTreeNode parent, Set<Long> compressed2, Set<String> fStack) {
        String fkey = fop.getFunctionKey();
        if (fStack.contains(fkey)) {
            return;
        }
        FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
        if (fsb == null) {
            // no statement block registered for this key; nothing to analyze
            return;
        }
        fStack.add(fkey);
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        Set<String> fCompressed = new HashSet<>();
        List<DataIdentifier> fArgs = fstmt.getInputParams();
        // guard against calls that pass fewer inputs than declared parameters
        int numArgs = Math.min(fArgs.size(), fop.getInput().size());
        for (int i = 0; i < numArgs; i++) {
            if (compressed2.contains(fop.getInput(i).getHopID())) {
                fCompressed.add(fArgs.get(i).getName());
            }
        }
        parent.addChild(createWorkloadTree(fsb, prog, fCompressed, fStack));
        fStack.remove(fkey);
    }

    /**
     * Analyzes a stand-alone hop DAG (e.g., a predicate or loop bound) and records its
     * compressed operations in the given node. A null root is ignored.
     *
     * @param hop the root of the DAG, may be null
     * @param prog the enclosing program
     * @param parent the node receiving the compressed operations
     * @param compressed names of the variables currently known to be compressed
     * @param fStack keys of the functions currently being traversed (recursion guard)
     */
    private static void createWorkloadTree(Hop hop, DMLProgram prog, WTreeNode parent, Set<String> compressed, Set<String> fStack) {
        if (hop == null) {
            return;
        }
        hop.resetVisitStatus();
        createWorkloadTree(hop, prog, parent, compressed, new HashSet<Long>(), fStack);
        hop.resetVisitStatus();
    }

    /**
     * Memoized post-order traversal of a hop DAG that tracks which hops produce compressed
     * outputs and records the hops that consume them.
     *
     * @param hop the current hop, may be null
     * @param prog the enclosing program
     * @param parent the node receiving the compressed operations
     * @param compressed names of the variables currently known to be compressed
     * @param compressed2 IDs of the hops found to produce compressed outputs; updated in place
     * @param fStack keys of the functions currently being traversed (recursion guard)
     */
    private static void createWorkloadTree(Hop hop, DMLProgram prog, WTreeNode parent, Set<String> compressed, Set<Long> compressed2, Set<String> fStack) {
        if (hop == null || hop.isVisited()) {
            return;
        }
        // inputs first, so their compression status is known when looking at this hop
        for (Hop c : hop.getInput()) {
            createWorkloadTree(c, prog, parent, compressed, compressed2, fStack);
        }
        // reads of compressed variables are the sources of compressed intermediates
        if (HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.OpOpData.TRANSIENTREAD) && compressed.contains(hop.getName())) {
            compressed2.add(hop.getHopID());
        }
        if (hop.getInput().stream().anyMatch(h -> compressed2.contains(h.getHopID()))) {
            // data reads/writes are not recorded as operations of the workload
            if (!HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.OpOpData.TRANSIENTREAD, Types.OpOpData.TRANSIENTWRITE)) {
                parent.addCompressedOp(hop);
            }
            // the output counts as compressed only for matrices meeting the size constraints
            if (RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop) && hop.getDataType().isMatrix()) {
                compressed2.add(hop.getHopID());
            }
        }
        hop.setVisited();
    }
}

