/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.Arrays;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that replaces simple for loops by equivalent vectorized operations.
 * Only loops with a literal increment of 1 whose body is a single last-level statement block
 * (no while/if/for) with exactly one DAG root are considered. Four patterns are recognized:
 * <ul>
 *   <li>scalar aggregate: {@code for(i in a:b) s = s + as.scalar(X[i,2])}
 *       becomes {@code s = s + sum(X[a:b,2])} (likewise {@code *}, {@code min}, {@code max})</li>
 *   <li>element-wise binary: {@code for(i in a:b) X[i,2] = Y[i,1] + Z[i,3]}
 *       becomes {@code X[a:b,2] = Y[a:b,1] + Z[a:b,3]}</li>
 *   <li>element-wise unary: {@code for(i in a:b) X[i,2] = abs(Y[i,1])}
 *       becomes {@code X[a:b,2] = abs(Y[a:b,1])}</li>
 *   <li>indexed copy: {@code for(i in a:b) X[7,i] = Y[1,i]} becomes {@code X[7,a:b] = Y[1,a:b]}</li>
 * </ul>
 * <p>A pattern only matches if the loop variable is a single-position index in exactly one
 * dimension. The other dimension must not reference it, so a diagonal access such as
 * {@code X[i,i]} is left untouched. When a rewrite applies, the for statement block is replaced
 * by its modified body block.
 */
public class RewriteForLoopVectorization
extends StatementBlockRewriteRule {
    /** Scalar binary operators that can be folded into a full aggregate (index-aligned with the targets). */
    private static final Types.OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new Types.OpOp2[]{Types.OpOp2.PLUS, Types.OpOp2.MULT, Types.OpOp2.MIN, Types.OpOp2.MAX};
    /** Aggregates replacing the corresponding entry of {@link #MAP_SCALAR_AGGREGATE_SOURCE_OPS}. */
    private static final Types.AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new Types.AggOp[]{Types.AggOp.SUM, Types.AggOp.PROD, Types.AggOp.MIN, Types.AggOp.MAX};

    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Tries to vectorize a for loop into its single body block.
     *
     * @param sb    candidate statement block; only {@link ForStatementBlock}s are rewritten
     * @param state rewrite status (unused)
     * @return a singleton list holding either the unchanged {@code sb} or the body block that now
     *         contains the vectorized operations
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
        if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fs = (ForStatement) fsb.getStatement(0);
            Hop from = fsb.getFromHops();
            Hop to = fsb.getToHops();
            Hop incr = fsb.getIncrementHops();
            String iterVar = fsb.getIterPredicate().getIterVar().getName();
            if (fs.getBody() != null && fs.getBody().size() == 1) {
                StatementBlock csb = fs.getBody().get(0);
                boolean lastLevel = !(csb instanceof WhileStatementBlock
                    || csb instanceof IfStatementBlock
                    || csb instanceof ForStatementBlock);
                if (lastLevel) {
                    // The patterns are mutually exclusive: a scalar root, or a matrix root whose
                    // left-indexing rhs is a BinaryOp, a UnaryOp, or an IndexingOp. Each call
                    // returns csb if it applied and its sb argument otherwise.
                    sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar);
                    sb = vectorizeElementwiseBinary(sb, csb, from, to, incr, iterVar);
                    sb = vectorizeElementwiseUnary(sb, csb, from, to, incr, iterVar);
                    sb = vectorizeIndexedCopy(sb, csb, from, to, incr, iterVar);
                }
            }
        }
        return Arrays.asList(sb);
    }

    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
        return sbs;
    }

    /**
     * Rewrites {@code s = s op as.scalar(X[i,c])} (or with swapped operands) into
     * {@code s = s op agg(X[from:to,c])}. The column case {@code X[r,i]} is handled analogously.
     *
     * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
        if (!isUnitIncrement(increment)) {
            return sb;
        }
        Hop root = getSingleRoot(csb, Types.DataType.SCALAR);
        if (root == null || !(root.getInput().get(0) instanceof BinaryOp)) {
            return sb;
        }
        BinaryOp bop = (BinaryOp) root.getInput().get(0);
        if (!HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)) {
            return sb;
        }
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);

        // Operand position of the as.scalar(X[..]) input within bop; the other operand is the accumulator.
        int castPos;
        if (isAccumulatorRead(left, root) && isCastOfIndexing(right)) {
            castPos = 1;
        } else if (isAccumulatorRead(right, root) && isCastOfIndexing(left)) {
            castPos = 0;
        } else {
            return sb;
        }
        Hop cast = bop.getInput().get(castPos);
        IndexingOp ix = (IndexingOp) cast.getInput().get(0);

        boolean rowIx;
        if (isIterVarIndexed(ix, ix.isRowLowerEqualsUpper(), 1, 3, itervar)) {
            rowIx = true;
        } else if (isIterVarIndexed(ix, ix.isColLowerEqualsUpper(), 3, 1, itervar)) {
            rowIx = false;
        } else {
            return sb;
        }
        // The indexed input itself must be loop-invariant, otherwise the aggregate is not equivalent.
        if (referencesIterVar(ix.getInput().get(0), itervar)) {
            return sb;
        }

        int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
        Types.AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
        AggUnaryOp newAgg = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Types.Direction.RowCol);
        HopRewriteUtils.removeChildReference(cast, ix);
        HopRewriteUtils.removeChildReference(bop, cast);
        HopRewriteUtils.addChildReference(bop, newAgg, castPos);

        // Replace the single loop-variable index by the loop range [from, to].
        int lowerPos = rowIx ? 1 : 3;
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(lowerPos), from, lowerPos);
        HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(lowerPos + 1), to, lowerPos + 1);
        if (rowIx) {
            ix.setRowLowerEqualsUpper(false);
        } else {
            ix.setColLowerEqualsUpper(false);
        }
        ix.refreshSizeInformation();
        LOG.debug("Applied vectorizeScalarSumForLoop.");
        return csb;
    }

    /**
     * Rewrites {@code X[i,c] = Y[i,c1] op Z[i,c2]} into the range-indexed equivalent over
     * {@code from:to} (row case; the column case is analogous).
     *
     * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeElementwiseBinary(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
        LeftIndexingOp lix = isUnitIncrement(increment) ? getLeftIndexingRoot(csb) : null;
        if (lix == null) {
            return sb;
        }
        Hop rhs = lix.getInput().get(1);
        // Require a matrix result: scalar-output binary ops (e.g. cov) would be broadcast after vectorization.
        if (!(rhs instanceof BinaryOp) || rhs.getDataType() != Types.DataType.MATRIX
            || !isIndexedDataRead(rhs.getInput().get(0)) || !isIndexedDataRead(rhs.getInput().get(1))) {
            return sb;
        }
        BinaryOp bop = (BinaryOp) rhs;
        IndexingOp rix0 = (IndexingOp) bop.getInput().get(0);
        IndexingOp rix1 = (IndexingOp) bop.getInput().get(1);
        boolean[] match = checkLeftAndRightIndexing(lix, itervar, rix0, rix1);
        if (!match[0]) {
            return sb;
        }
        vectorizeIndexRanges(match[1], from, to, lix, rix0, rix1);
        bop.refreshSizeInformation();
        lix.refreshSizeInformation();
        LOG.debug("Applied vectorizeElementwiseBinaryForLoop.");
        return csb;
    }

    /**
     * Rewrites {@code X[i,c] = uop(Y[i,c1])} into the range-indexed equivalent over
     * {@code from:to} (row case; the column case is analogous).
     *
     * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeElementwiseUnary(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
        LeftIndexingOp lix = isUnitIncrement(increment) ? getLeftIndexingRoot(csb) : null;
        if (lix == null) {
            return sb;
        }
        Hop rhs = lix.getInput().get(1);
        // Cumulative ops (cumsum, ...) are identities on 1x1 inputs but not on the vectorized range,
        // and scalar-output ops (nrow, as.scalar, ...) are not element-wise. Both are excluded.
        if (!(rhs instanceof UnaryOp) || rhs.getDataType() != Types.DataType.MATRIX
            || ((UnaryOp) rhs).isCumulativeUnaryOperation()
            || !isIndexedDataRead(rhs.getInput().get(0))) {
            return sb;
        }
        UnaryOp uop = (UnaryOp) rhs;
        IndexingOp rix = (IndexingOp) uop.getInput().get(0);
        boolean[] match = checkLeftAndRightIndexing(lix, itervar, rix);
        if (!match[0]) {
            return sb;
        }
        vectorizeIndexRanges(match[1], from, to, lix, rix);
        uop.refreshSizeInformation();
        lix.refreshSizeInformation();
        LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
        return csb;
    }

    /**
     * Rewrites {@code X[r,i] = Y[r2,i]} (or the row analogue) into the range-indexed copy over
     * {@code from:to}.
     *
     * @return {@code csb} if the rewrite was applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeIndexedCopy(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar) {
        LeftIndexingOp lix = isUnitIncrement(increment) ? getLeftIndexingRoot(csb) : null;
        if (lix == null || !isIndexedDataRead(lix.getInput().get(1))) {
            return sb;
        }
        IndexingOp rix = (IndexingOp) lix.getInput().get(1);
        boolean[] match = checkLeftAndRightIndexing(lix, itervar, rix);
        if (!match[0]) {
            return sb;
        }
        vectorizeIndexRanges(match[1], from, to, lix, rix);
        LOG.debug("Applied vectorizeIndexedCopy.");
        return csb;
    }

    /** Returns true if the loop increment is the literal 1 (null or non-literal increments are rejected). */
    private static boolean isUnitIncrement(Hop increment) {
        return increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0;
    }

    /** Returns the only DAG root of csb if it has data type dt and at least one input, otherwise null. */
    private static Hop getSingleRoot(StatementBlock csb, Types.DataType dt) {
        List<Hop> roots = csb.getHops();
        if (roots == null || roots.size() != 1) {
            return null;
        }
        Hop root = roots.get(0);
        return (root.getDataType() == dt && !root.getInput().isEmpty()) ? root : null;
    }

    /** Returns the left indexing op feeding the single matrix root of csb, if its target is a DataOp; otherwise null. */
    private static LeftIndexingOp getLeftIndexingRoot(StatementBlock csb) {
        Hop root = getSingleRoot(csb, Types.DataType.MATRIX);
        if (root == null || !(root.getInput().get(0) instanceof LeftIndexingOp)) {
            return null;
        }
        LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
        return (lix.getInput().get(0) instanceof DataOp) ? lix : null;
    }

    /** Returns true if operand is a scalar DataOp with the same name as the written root (the accumulator). */
    private static boolean isAccumulatorRead(Hop operand, Hop root) {
        return operand instanceof DataOp && operand.getDataType() == Types.DataType.SCALAR
            && root.getName().equals(operand.getName());
    }

    /** Returns true if hop is {@code as.scalar(...)} applied directly to a right indexing op. */
    private static boolean isCastOfIndexing(Hop hop) {
        return hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == Types.OpOp1.CAST_AS_SCALAR
            && hop.getInput().get(0) instanceof IndexingOp;
    }

    /** Returns true if hop is a right indexing op whose indexed input is a DataOp. */
    private static boolean isIndexedDataRead(Hop hop) {
        return hop instanceof IndexingOp && hop.getInput().get(0) instanceof DataOp;
    }

    /** Returns true if hop is a DataOp named like the loop variable. */
    private static boolean isIterVarRead(Hop hop, String itervar) {
        return hop instanceof DataOp && itervar.equals(hop.getName());
    }

    /**
     * Returns true if the sub-DAG rooted at hop contains a read of the loop variable.
     * Meant for small index expressions; shared sub-DAGs are revisited rather than memoized.
     */
    private static boolean referencesIterVar(Hop hop, String itervar) {
        if (isIterVarRead(hop, itervar)) {
            return true;
        }
        for (Hop input : hop.getInput()) {
            if (referencesIterVar(input, itervar)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks one indexing dimension of ix. It matches when the dimension is a single-position
     * index (lower == upper) whose lower bound (input lowerPos) is the loop variable, and the
     * bounds of the other dimension (inputs otherLowerPos and otherLowerPos+1) do not reference
     * the loop variable.
     */
    private static boolean isIterVarIndexed(Hop ix, boolean lowerEqualsUpper, int lowerPos, int otherLowerPos, String itervar) {
        List<Hop> inputs = ix.getInput();
        return lowerEqualsUpper
            && isIterVarRead(inputs.get(lowerPos), itervar)
            && !referencesIterVar(inputs.get(otherLowerPos), itervar)
            && !referencesIterVar(inputs.get(otherLowerPos + 1), itervar);
    }

    /**
     * Checks whether the left indexing op and all right indexing ops are indexed by the loop
     * variable in the same single dimension.
     *
     * @return {@code {applicable, rowIx}}; rowIx is only meaningful when applicable is true
     */
    private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix, String itervar, IndexingOp... rix) {
        // LeftIndexingOp inputs: 0 target, 1 source, 2/3 row bounds, 4/5 column bounds.
        boolean rows = isIterVarIndexed(lix, lix.isRowLowerEqualsUpper(), 2, 4, itervar);
        boolean cols = isIterVarIndexed(lix, lix.isColLowerEqualsUpper(), 4, 2, itervar);
        // IndexingOp inputs: 0 source, 1/2 row bounds, 3/4 column bounds.
        for (IndexingOp rixi : rix) {
            rows = rows && isIterVarIndexed(rixi, rixi.isRowLowerEqualsUpper(), 1, 3, itervar);
            cols = cols && isIterVarIndexed(rixi, rixi.isColLowerEqualsUpper(), 3, 1, itervar);
        }
        return new boolean[]{rows || cols, rows};
    }

    /**
     * Replaces the loop-variable index bounds of lix and every rix in the selected dimension by
     * [from, to], then clears the lower-equals-upper flags and refreshes sizes.
     */
    private static void vectorizeIndexRanges(boolean rowIx, Hop from, Hop to, LeftIndexingOp lix, IndexingOp... rix) {
        int lixLowerPos = rowIx ? 2 : 4;
        HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lixLowerPos), from, lixLowerPos);
        HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lixLowerPos + 1), to, lixLowerPos + 1);
        // Right indexing inputs are shifted by one position (no separate target input).
        int rixLowerPos = lixLowerPos - 1;
        for (IndexingOp rixi : rix) {
            HopRewriteUtils.replaceChildReference(rixi, rixi.getInput().get(rixLowerPos), from, rixLowerPos);
            HopRewriteUtils.replaceChildReference(rixi, rixi.getInput().get(rixLowerPos + 1), to, rixLowerPos + 1);
        }
        updateLeftAndRightIndexingSizes(rowIx, lix, rix);
    }

    /** Clears the lower-equals-upper flag of the vectorized dimension and refreshes all sizes. */
    private static void updateLeftAndRightIndexingSizes(boolean rowIx, LeftIndexingOp lix, IndexingOp ... rix) {
        if (rowIx) {
            lix.setRowLowerEqualsUpper(false);
            for (IndexingOp rixi : rix) {
                rixi.setRowLowerEqualsUpper(false);
            }
        } else {
            lix.setColLowerEqualsUpper(false);
            for (IndexingOp rixi : rix) {
                rixi.setColLowerEqualsUpper(false);
            }
        }
        for (IndexingOp rixi : rix) {
            rixi.refreshSizeInformation();
        }
        lix.refreshSizeInformation();
    }
}

