/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.matrix.data.Pair;

/**
 * Code generation template for cell-wise fused operators ({@link CNodeCell}).
 *
 * <p>The template covers element-wise unary/binary/ternary operations, matrix-valued n-ary
 * min/max/plus, bias_add/bias_mult, column lookups via indexing, and seq. It may be closed by
 * an unary aggregate with one of {@link #SUPPORTED_AGG} or by a matrix multiply with a 1x1
 * output.
 */
public class TemplateCell
extends TemplateBase {
    /** Aggregation functions with which a cell template may be validly closed. */
    private static final Types.AggOp[] SUPPORTED_AGG = new Types.AggOp[]{Types.AggOp.SUM, Types.AggOp.SUM_SQ, Types.AggOp.MIN, Types.AggOp.MAX};

    public TemplateCell() {
        super(TemplateBase.TemplateType.CELL);
    }

    public TemplateCell(TemplateBase.CloseType ctype) {
        super(TemplateBase.TemplateType.CELL, ctype);
    }

    public TemplateCell(TemplateBase.TemplateType type, TemplateBase.CloseType ctype) {
        super(type, ctype);
    }

    /**
     * Decides whether a new cell template can be opened at the given hop.
     *
     * @param hop candidate hop
     * @return true for a supported matrix operation with known, non-1x1 dimensions; a column
     *         lookup via indexing; a literal-input seq with only unary/binary parents; a
     *         matrix-valued n-ary min/max/plus; or a bias_add/bias_mult with known input dims
     */
    @Override
    public boolean open(Hop hop) {
        return (hop.dimsKnown() && isValidOperation(hop)
                && (hop.getDim1() != 1L || hop.getDim2() != 1L))
            || (hop instanceof IndexingOp && hop.getInput().get(0).getDim2() >= 0L
                && (((IndexingOp) hop).isColLowerEqualsUpper() || hop.getDim2() == 1L))
            || (HopRewriteUtils.isDataGenOpWithLiteralInputs(hop, Types.OpOpDG.SEQ)
                && HopRewriteUtils.hasOnlyUnaryBinaryParents(hop, true))
            || isMatrixNary(hop)
            || isBiasOpWithKnownInputDims(hop);
    }

    /**
     * Decides whether the given hop can be fused into this template, which currently ends at
     * {@code input}.
     *
     * @param hop   consumer hop to fuse
     * @param input input of {@code hop} that this template currently covers
     * @return true if fusion is admissible
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        // NOTE(review): the n-ary and bias-op disjuncts are evaluated outside the !isClosed()
        // guard (preserved from the original); confirm this is intended for closed templates.
        return (!isClosed()
                && (isValidOperation(hop)
                    || HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG)
                    || (HopRewriteUtils.isMatrixMultiply(hop)
                        && hop.getDim1() == 1L && hop.getDim2() == 1L
                        && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
                    || (HopRewriteUtils.isTransposeOperation(hop)
                        && hop.getDim1() == 1L && hop.getDim2() > 1L)))
            || isMatrixNary(hop)
            || isBiasOpWithKnownInputDims(hop);
    }

    /**
     * Decides whether an existing template rooted at {@code input} can be merged into the
     * template of {@code hop}.
     *
     * @param hop   consumer hop
     * @param input input of {@code hop} whose template would be merged
     * @return true if merging is admissible
     */
    @Override
    public boolean merge(Hop hop, Hop input) {
        // NOTE(review): the seq, n-ary and bias-op disjuncts are evaluated outside the
        // !isClosed() guard (preserved from the original); confirm this is intended.
        return (!isClosed()
                && (isValidOperation(hop)
                    || (hop instanceof AggBinaryOp && hop.getInput().indexOf(input) == 0
                        && HopRewriteUtils.isTransposeOperation(input))))
            || (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, Types.OpOpDG.SEQ)
                && HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
            || isMatrixNary(hop)
            || isBiasOpWithKnownInputDims(hop);
    }

    /**
     * Determines the close status after fusing the given hop.
     *
     * @param hop the most recently fused hop
     * @return CLOSED_VALID for supported unary aggregates and matrix multiplies with a 1x1
     *         output, CLOSED_INVALID for any other aggregate, OPEN_VALID otherwise
     */
    @Override
    public TemplateBase.CloseType close(Hop hop) {
        if (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG)
            || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1() == 1L && hop.getDim2() == 1L)) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }
        if (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp) {
            return TemplateBase.CloseType.CLOSED_INVALID;
        }
        return TemplateBase.CloseType.OPEN_VALID;
    }

    /**
     * Constructs the cell-wise cplan rooted at the given hop.
     *
     * @param hop             root hop of the fused operator
     * @param memo            memo table of candidate plans
     * @param compileLiterals whether scalar literals are compiled into the generated code
     * @return the ordered template inputs (literal scalars excluded) and the cplan template
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        // translate the fused hop DAG; tmp maps hop IDs to their generated cnodes
        HashSet<Hop> inHops = new HashSet<>();
        HashMap<Long, CNode> tmp = new HashMap<>();
        hop.resetVisitStatus();
        rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
        hop.resetVisitStatus();

        // drop scalar inputs compiled as literals and order the rest (see HopInputComparator);
        // sinHops[0] is treated as the main input for the sparse-safety check below
        Hop[] sinHops = inHops.stream()
            .filter(h -> !h.getDataType().isScalar() || !tmp.get(h.getHopID()).isLiteral())
            .sorted(new HopInputComparator())
            .toArray(Hop[]::new);

        ArrayList<CNode> inputs = new ArrayList<>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }

        CNode output = tmp.get(hop.getHopID());
        CNodeCell tpl = new CNodeCell(inputs, output);
        tpl.setCellType(TemplateUtils.getCellType(hop));
        tpl.setAggOp(TemplateUtils.getAggOp(hop));
        // NOTE(review): assumes at least one non-literal input; an empty sinHops would throw here
        tpl.setSparseSafe(isSparseSafe(Arrays.asList(hop), sinHops[0],
            Arrays.asList(tpl.getOutput()), Arrays.asList(tpl.getAggOp()), false));
        tpl.setContainsSeq(rContainsSeq(tpl.getOutput(), new HashSet<Long>()));
        tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
        tpl.setBeginLine(hop.getBeginLine());
        return new Pair<Hop[], CNodeTpl>(sinHops, tpl);
    }

    /**
     * Recursively translates {@code hop} and its fused inputs into cnodes, registering
     * non-fused inputs as template data inputs.
     *
     * @param hop             hop to translate
     * @param memo            memo table of candidate plans
     * @param tmp             hop ID to cnode map (memoization and result)
     * @param inHops          collected template input hops
     * @param compileLiterals whether scalar literals are compiled into the generated code
     * @throws HopsException if the hop type is not supported by this template
     */
    protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
        // each hop is translated at most once
        if (tmp.containsKey(hop.getHopID())) {
            return;
        }

        // hops whose best memo entry is a ROW or OUTER plan become opaque data inputs
        CPlanMemoTable.MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.CELL);
        if (me != null && me.type.isIn(TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.OUTER)) {
            tmp.put(hop.getHopID(), TemplateUtils.createCNodeData(hop, compileLiterals));
            inHops.add(hop);
            return;
        }

        // translate fused inputs recursively; all other inputs become data inputs
        for (int i = 0; i < hop.getInput().size(); ++i) {
            Hop c2 = hop.getInput().get(i);
            if (me != null && me.isPlanRef(i) && !(c2 instanceof DataOp)
                && (me.type != TemplateBase.TemplateType.MAGG
                    || memo.contains(c2.getHopID(), TemplateBase.TemplateType.CELL))) {
                rConstructCplan(c2, memo, tmp, inHops, compileLiterals);
            }
            else if (me != null
                && (me.type == TemplateBase.TemplateType.MAGG || me.type == TemplateBase.TemplateType.CELL)
                && HopRewriteUtils.isMatrixMultiply(hop) && i == 0) {
                // left matrix-multiply input: bypass it and translate its own first input
                // (presumably the transposed operand, cf. fuse/merge)
                Hop c2in = c2.getInput().get(0);
                if (c2in instanceof DataOp) {
                    tmp.put(c2in.getHopID(), TemplateUtils.createCNodeData(c2in, compileLiterals));
                    inHops.add(c2in);
                }
                else {
                    rConstructCplan(c2in, memo, tmp, inHops, compileLiterals);
                }
            }
            else {
                tmp.put(c2.getHopID(), TemplateUtils.createCNodeData(c2, compileLiterals));
                inHops.add(c2);
            }
        }

        CNode out = null;
        if (hop instanceof UnaryOp) {
            Hop in1 = hop.getInput().get(0);
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in1.getHopID()), in1);
            String primitiveOpName = ((UnaryOp) hop).getOp().name();
            out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(primitiveOpName));
        }
        else if (hop instanceof BinaryOp) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            CNode cdata1 = tmp.get(in1.getHopID());
            CNode cdata2 = tmp.get(in2.getHopID());
            String primitiveOpName = ((BinaryOp) hop).getOp().name();
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, in1);
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, in2);
            out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(primitiveOpName));
        }
        else if (hop instanceof TernaryOp) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            Hop in3 = hop.getInput().get(2);
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in1.getHopID()), in1);
            CNode cdata2 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in2.getHopID()), in2);
            CNode cdata3 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in3.getHopID()), in3);
            out = new CNodeTernary(cdata1, cdata2, cdata3,
                CNodeTernary.TernaryType.valueOf(((TernaryOp) hop).getOp().name()));
        }
        else if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in1.getHopID()), in1);
            CNode cdata2 = tmp.get(in2.getHopID());
            // third operand: columns of the data input divided by rows of the bias input
            long c3 = in1.getDim2() / in2.getDim1();
            CNodeData cdata3 = TemplateUtils.createCNodeData(new LiteralOp(c3), true);
            out = new CNodeTernary(cdata1, cdata2, cdata3,
                CNodeTernary.TernaryType.valueOf(((DnnOp) hop).getOp().name()));
        }
        else if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)) {
            // unroll the n-ary operation into a left-deep chain of binary cnodes
            String op = ((NaryOp) hop).getOp().name();
            CNode[] inputs = hop.getInput().stream()
                .map(c -> TemplateUtils.wrapLookupIfNecessary(tmp.get(c.getHopID()), c))
                .toArray(CNode[]::new);
            CNodeBinary.BinType btype = CNodeBinary.BinType.valueOf(op);
            out = new CNodeBinary(inputs[0], inputs[1], btype);
            for (int i = 2; i < inputs.length; ++i) {
                out = new CNodeBinary(out, inputs[i], btype);
            }
        }
        else if (hop instanceof ParameterizedBuiltinOp) {
            // presumably replace(target, pattern, replacement), the only parameterized builtin
            // admitted by isValidOperation; NaN patterns get a dedicated ternary type
            ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(
                tmp.get(pbop.getTargetHop().getHopID()), hop.getInput().get(0));
            CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
            CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
            CNodeTernary.TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN"))
                ? CNodeTernary.TernaryType.REPLACE_NAN
                : CNodeTernary.TernaryType.REPLACE;
            out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
        }
        else if (hop instanceof IndexingOp) {
            // column lookup (LOOKUP_RC1) from input 0, given its column count and input 4
            // (presumably the upper column bound)
            Hop in1 = hop.getInput().get(0);
            CNode cdata1 = tmp.get(in1.getHopID());
            out = new CNodeTernary(cdata1,
                TemplateUtils.createCNodeData(new LiteralOp(in1.getDim2()), true),
                TemplateUtils.createCNodeData(hop.getInput().get(4), true),
                CNodeTernary.TernaryType.LOOKUP_RC1);
        }
        else if (HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.SEQ)) {
            DataGenOp dgop = (DataGenOp) hop;
            CNodeData from = TemplateUtils.getLiteral(tmp.get(dgop.getParam("from").getHopID()));
            CNodeData to = TemplateUtils.getLiteral(tmp.get(dgop.getParam("to").getHopID()));
            CNodeData incr = TemplateUtils.getLiteral(tmp.get(dgop.getParam("incr").getHopID()));
            // descending range with a positive increment: negate the increment
            if (Double.parseDouble(from.getVarname()) > Double.parseDouble(to.getVarname())
                && Double.parseDouble(incr.getVarname()) > 0.0) {
                incr = TemplateUtils.createCNodeData(new LiteralOp("-" + incr.getVarname()), true);
            }
            out = new CNodeBinary(from, incr, CNodeBinary.BinType.SEQ_RIX);
        }
        else if (HopRewriteUtils.isTransposeOperation(hop)) {
            out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
            // flip existing vector lookups unless a matrix multiply consumes this transpose
            if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class)) {
                TemplateUtils.rFlipVectorLookups(out);
            }
            if (out instanceof CNodeData) {
                inHops.add(hop.getInput().get(0)); // set semantics: no-op if already present
            }
        }
        else if (hop instanceof AggUnaryOp) {
            // the aggregation itself is captured by the template's agg op (see constructCplan)
            out = tmp.get(hop.getInput().get(0).getHopID());
        }
        else if (hop instanceof AggBinaryOp) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(in1, in2)) {
                // t(X) %*% X: cell-wise square of X
                CNode cdata1 = tmp.get(in2.getHopID());
                if (TemplateUtils.isColVector(cdata1)) {
                    cdata1 = new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.POW2);
            }
            else {
                // cell-wise product of both operands (left operand with its transpose skipped)
                CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(in1.getHopID()), in1, tmp, compileLiterals);
                if (cdata1 instanceof CNodeData) {
                    inHops.add(in1.getInput().get(0)); // set semantics: no-op if already present
                }
                if (TemplateUtils.isColVector(cdata1)) {
                    cdata1 = new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                }
                CNode cdata2 = tmp.get(in2.getHopID());
                if (TemplateUtils.isColVector(cdata2)) {
                    cdata2 = new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.MULT);
            }
        }

        if (out == null) {
            throw new HopsException(hop.getHopID() + " " + hop.getOpString());
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Checks whether the given hop is a matrix operation supported by the cell template.
     *
     * @param hop candidate hop
     * @return true for matrix-valued, supported unary ops; non-outer binary ops of type
     *         matrix-scalar, matrix-vector or equal-size matrix-matrix; ternary ops of type
     *         vector-scalar-vector or dense equal-size matrix-scalar-matrix; matrix ifelse;
     *         and replace
     */
    protected static boolean isValidOperation(Hop hop) {
        boolean isBinaryMatrixScalar = false;
        boolean isBinaryMatrixVector = false;
        boolean isBinaryMatrixMatrix = false;
        if (hop instanceof BinaryOp && hop.getDataType().isMatrix() && !((BinaryOp) hop).isOuter()) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(1);
            isBinaryMatrixScalar = left.getDataType().isScalar() || right.getDataType().isScalar();
            isBinaryMatrixVector = hop.dimsKnown()
                && ((left.getDataType().isMatrix() && TemplateUtils.isVectorOrScalar(right))
                    || (right.getDataType().isMatrix() && TemplateUtils.isVectorOrScalar(left)));
            isBinaryMatrixMatrix = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right)
                && left.getDataType().isMatrix() && right.getDataType().isMatrix();
        }

        boolean isTernaryVectorScalarVector = false;
        boolean isTernaryMatrixScalarMatrixDense = false;
        boolean isTernaryIfElse = HopRewriteUtils.isTernary(hop, Types.OpOp3.IFELSE) && hop.getDataType().isMatrix();
        if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown()
            && HopRewriteUtils.checkInputDataTypes(hop, Types.DataType.MATRIX, Types.DataType.SCALAR, Types.DataType.MATRIX)) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(2);
            isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
            isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right)
                && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
        }

        return hop.getDataType() == Types.DataType.MATRIX && TemplateUtils.isOperationSupported(hop)
            && (hop instanceof UnaryOp
                || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrix
                || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || isTernaryIfElse
                || (hop instanceof ParameterizedBuiltinOp
                    && ((ParameterizedBuiltinOp) hop).getOp() == Types.ParamBuiltinOp.REPLACE));
    }

    /**
     * Checks whether all given outputs are sparse-safe with respect to the main input.
     *
     * @param roots     root hops, one per output (aggregate roots are replaced by their input)
     * @param mainInput main template input
     * @param outputs   output cnodes
     * @param aggOps    aggregation functions, one per output
     * @param onlySum   if true, additionally require SUM or SUM_SQ aggregation
     * @return true if every output is sparse-safe
     */
    protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<Types.AggOp> aggOps, boolean onlySum) {
        boolean ret = true;
        for (int i = 0; i < outputs.size() && ret; ++i) {
            Hop root = (roots.get(i) instanceof AggUnaryOp || roots.get(i) instanceof AggBinaryOp)
                ? roots.get(i).getInput().get(0)
                : roots.get(i);
            ret &= (HopRewriteUtils.isBinarySparseSafe(root) && root.getInput().contains(mainInput))
                || (HopRewriteUtils.isBinary(root, Types.OpOp2.DIV) && root.getInput().get(0) == mainInput)
                || (TemplateUtils.rIsSparseSafeOnly(outputs.get(i), CNodeBinary.BinType.MULT)
                    && TemplateUtils.rContainsInput(outputs.get(i), mainInput.getHopID()));
            if (onlySum) {
                ret &= aggOps.get(i) == Types.AggOp.SUM || aggOps.get(i) == Types.AggOp.SUM_SQ;
            }
        }
        return ret;
    }

    /**
     * Checks whether the cnode DAG rooted at {@code node} contains a SEQ_RIX binary node.
     *
     * @param node root cnode
     * @param memo IDs of already visited cnodes (each is inspected at most once)
     * @return true if a SEQ_RIX node is reachable from a not-yet-visited node
     */
    protected boolean rContainsSeq(CNode node, HashSet<Long> memo) {
        if (memo.contains(node.getID())) {
            return false;
        }
        boolean ret = TemplateUtils.isBinary(node, CNodeBinary.BinType.SEQ_RIX);
        for (CNode c : node.getInput()) {
            ret |= rContainsSeq(c, memo);
        }
        memo.add(node.getID());
        return ret;
    }

    /** Matrix-valued n-ary min/max/plus, which this template unrolls into binary cnodes. */
    private static boolean isMatrixNary(Hop hop) {
        return HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)
            && hop.isMatrix();
    }

    /** bias_add/bias_mult whose first two inputs both have known dimensions. */
    private static boolean isBiasOpWithKnownInputDims(Hop hop) {
        return HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)
            && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown();
    }

    /**
     * Orders template inputs: the optional driver hop first, then the remaining hops by
     * descending cell count (unknown dimensions count as largest, scalars as smallest).
     * Ties between scalars are broken by hop ID. Ties between matrices are broken by nnz
     * (ascending) when both nnz are known, differ, and at least one input is sparse, and
     * otherwise by hop ID.
     */
    public static class HopInputComparator
    implements Comparator<Hop> {
        private final Hop _driver;

        public HopInputComparator() {
            this(null);
        }

        public HopInputComparator(Hop driver) {
            this._driver = driver;
        }

        @Override
        public int compare(Hop h1, Hop h2) {
            // Comparator contract: sgn(compare(x, x)) == 0
            if (h1 == h2) {
                return 0;
            }
            // The driver always sorts first. It is checked before the sizes so the ordering
            // stays antisymmetric even when the driver is not the largest input.
            if (_driver != null && h1 == _driver) {
                return -1;
            }
            if (_driver != null && h2 == _driver) {
                return 1;
            }
            long ncells1 = numCells(h1);
            long ncells2 = numCells(h2);
            if (ncells1 != ncells2) {
                return ncells1 > ncells2 ? -1 : 1;
            }
            if (h1.isScalar() && h2.isScalar()) {
                return Long.compare(h1.getHopID(), h2.getHopID());
            }
            boolean byNnz = h1.dimsKnown(true) && h2.dimsKnown(true) && h1.getNnz() != h2.getNnz()
                && (HopRewriteUtils.isSparse(h1, 1.0) || HopRewriteUtils.isSparse(h2, 1.0));
            return byNnz
                ? Long.compare(h1.getNnz(), h2.getNnz())
                : Long.compare(h1.getHopID(), h2.getHopID());
        }

        /** Cell count used for ordering: MIN_VALUE for scalars, MAX_VALUE for unknown dims. */
        private static long numCells(Hop h) {
            if (h.isScalar()) {
                return Long.MIN_VALUE;
            }
            return h.dimsKnown() ? h.getLength() : Long.MAX_VALUE;
        }
    }
}

