/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.resource.cost;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.resource.cost.CPCostUtils;
import org.apache.sysds.resource.cost.IOCostUtils;
import org.apache.sysds.resource.cost.RDDStats;
import org.apache.sysds.resource.cost.VarStats;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySketchSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryFrameFrameSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryFrameMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CumulativeAggregateSPInstruction;
import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.PMapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
import org.apache.sysds.runtime.instructions.spark.PmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.Tsmm2SPInstruction;
import org.apache.sysds.runtime.instructions.spark.TsmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.ZipmmSPInstruction;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

public class SparkCostUtils {
    /**
     * Estimates the execution time of a Spark reblock instruction as two consecutive stages.
     * <p>
     * Stage one reads the input through Hadoop and shuffles an RDD whose size is the estimated
     * text representation of the input; stage two charges the CPU time of producing the output
     * from that shuffled RDD.
     *
     * @param opcode          opcode of the reblock instruction
     * @param input           statistics of the variable being reblocked
     * @param output          statistics of the reblocked result
     * @param executorMetrics I/O metrics of an executor node
     * @return the sum of both estimated stage times
     */
    public static double getReblockInstTime(String opcode, VarStats input, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        double hadoopReadTime = IOCostUtils.getHadoopReadTime(input, executorMetrics);
        Object formatInfo = input.fileInfo[1];
        long textSize = OptimizerUtils.estimateSizeTextOutput(
                input.getM(), input.getN(), input.getNNZ(), (Types.FileFormat) formatInfo);
        // NOTE(review): -1 presumably lets RDDStats derive the partition count itself — confirm
        RDDStats shuffledText = new RDDStats(textSize, -1);
        double stageOne = hadoopReadTime + IOCostUtils.getSparkShuffleTime(shuffledText, executorMetrics, false);
        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.Reblock, opcode, output, new VarStats[0]);
        double stageTwo = SparkCostUtils.getCPUTime(flops, shuffledText.numPartitions, executorMetrics, output.rddStats, shuffledText);
        return stageOne + stageTwo;
    }

    /**
     * Estimates the execution time of a Spark data-generation instruction (rand, frame or seq).
     * <p>
     * Only CPU time is charged: a per-cell flop count is scaled by {@code output.getCells()} and
     * spread over the output's partitions. No data transmission cost is added.
     *
     * @param opcode          "rand", "frame" or "seq"; "sample" is explicitly rejected
     * @param randType        for rand/frame: 0 (returns zero cost), 1 or 2; other values are rejected
     * @param output          statistics of the generated variable
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated CPU time of the generation
     * @throws RuntimeException    for "sample" or an unknown {@code randType}
     * @throws DMLRuntimeException for any other unsupported opcode
     */
    public static double getRandInstTime(String opcode, int randType, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        if (opcode.equals("sample")) {
            throw new RuntimeException("Spark operation Rand with opcode sample is not supported yet");
        }
        // Per-cell flop count, later scaled by the number of output cells.
        // NOTE(review): the decompiled original multiplied an undeclared `nflop`, so the per-cell
        // constants below were reconstructed — confirm them against the upstream sources.
        long nflop;
        if (opcode.equals("rand") || opcode.equals("frame")) {
            if (randType == 0) {
                return 0.0;
            } else if (randType == 1) {
                nflop = 8;
            } else if (randType == 2) {
                nflop = 32;
            } else {
                throw new RuntimeException("Unknown type of random instruction");
            }
        } else if (opcode.equals("seq")) {
            nflop = 1;
        } else {
            throw new DMLRuntimeException("Rand operation with opcode '" + opcode + "' is not supported by SystemDS");
        }
        nflop *= output.getCells();
        return SparkCostUtils.getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, new RDDStats[0]);
    }

    /**
     * Estimates the execution time of a Spark unary instruction.
     * <p>
     * Only the CPU time over the input partitions is charged (no data transmission). The output
     * inherits the hash-partitioning flag of the input.
     *
     * @param opcode          opcode of the unary instruction
     * @param input           statistics of the operand
     * @param output          statistics of the result; its partitioning flag is updated
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated CPU time
     */
    public static double getUnaryInstTime(String opcode, VarStats input, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.Unary, opcode, output, input);
        int inputPartitions = input.rddStats.numPartitions;
        double cpuTime = SparkCostUtils.getCPUTime(flops, inputPartitions, executorMetrics, output.rddStats, input.rddStats);
        // a map-side operation keeps the existing partitioning
        output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
        return cpuTime;
    }

    /**
     * Estimates the execution time of a Spark unary aggregation.
     * <p>
     * Cumulative aggregates shuffle the output. Single-block aggregates mark the output as
     * collected without a shuffle cost. Multi-block aggregates shuffle the input; for "uaktrace"
     * they shuffle only an estimated diagonal-block RDD. Any other aggregation type keeps the
     * input's partitioning and incurs no shuffle.
     *
     * @param inst            an aggregate-unary (or sketch) Spark instruction
     * @param input           statistics of the operand
     * @param output          statistics of the result; partitioning/collect fields are updated
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated shuffle time plus CPU time
     */
    public static double getAggUnaryInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        String opcode = inst.getOpcode();
        AggBinaryOp.SparkAggType aggType;
        if (inst instanceof AggregateUnarySPInstruction) {
            aggType = ((AggregateUnarySPInstruction) inst).getAggType();
        } else {
            aggType = ((AggregateUnarySketchSPInstruction) inst).getAggType();
        }

        double shuffleTime = 0.0;
        if (inst instanceof CumulativeAggregateSPInstruction) {
            shuffleTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
            output.rddStats.hashPartitioned = true;
        } else if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            output.rddStats.isCollected = true;
        } else if (aggType == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
            RDDStats shuffled;
            if (opcode.equals(Opcodes.UAKTRACE.toString())) {
                // trace: only an RDD sized like the diagonal blocks is shuffled
                long blen = input.characteristics.getBlocksize();
                long diagonalSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(blen * input.getM(), blen, blen, input.getNNZ());
                shuffled = new RDDStats(diagonalSize, input.rddStats.numPartitions);
            } else {
                shuffled = input.rddStats;
            }
            shuffleTime = IOCostUtils.getSparkShuffleTime(shuffled, executorMetrics, true);
            output.rddStats.hashPartitioned = true;
            output.rddStats.numPartitions = input.rddStats.numPartitions;
        } else {
            output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
            output.rddStats.numPartitions = input.rddStats.numPartitions;
        }

        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.AggregateUnary, opcode, output, input);
        double cpuTime = SparkCostUtils.getCPUTime(flops, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
        return shuffleTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark indexing instruction.
     * <p>
     * Right indexing collects the result when it fits into a single block (per the configured
     * block size) and shuffles it otherwise. Left indexing shuffles the second input, and all
     * other indexing opcodes broadcast it.
     *
     * @param inst            the indexing Spark instruction
     * @param input1          statistics of the indexed matrix
     * @param input2          statistics of the second operand, may be {@code null}
     * @param output          statistics of the result; may be marked as collected
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getIndexingInstTime(IndexingSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        String opcode = inst.getOpcode();
        double transferTime;
        if (opcode.equals(Opcodes.RIGHT_INDEX.toString())) {
            long blen = ConfigurationManager.getBlocksize();
            boolean fitsSingleBlock = output.getM() <= blen && output.getN() <= blen;
            if (fitsSingleBlock) {
                transferTime = IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
                output.rddStats.isCollected = true;
            } else {
                transferTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
            }
        } else if (opcode.equals(Opcodes.LEFT_INDEX.toString())) {
            transferTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
        } else {
            transferTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
        }

        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.MatrixIndexing, opcode, output, new VarStats[0]);
        // Only the output statistics are scanned: once without a second input, twice with one.
        // NOTE(review): presumably because filtering happens before the map work — confirm.
        RDDStats[] scanned = input2 == null
                ? new RDDStats[]{output.rddStats}
                : new RDDStats[]{output.rddStats, output.rddStats};
        double cpuTime = SparkCostUtils.getCPUTime(flops, input1.rddStats.numPartitions, executorMetrics, output.rddStats, scanned);
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark binary instruction (matrix-matrix join,
     * matrix-vector broadcast, or matrix-scalar).
     * <p>
     * For a matrix-matrix join, both inputs pay the shuffle-write cost. On the read side, an
     * input that is already hash-partitioned pays only the static shuffle-read cost, while an
     * input that must be re-partitioned pays the full shuffle-read cost. When both are
     * hash-partitioned with the same number of partitions, both pay only the static cost.
     *
     * @param inst            the binary Spark instruction
     * @param input1          statistics of the first operand
     * @param input2          statistics of the second operand
     * @param output          statistics of the result; its partitioning fields are updated
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     * @throws RuntimeException for frame binary instructions or any other unsupported type
     */
    public static double getBinaryInstTime(SPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        SPInstruction.SPType opType = inst.getSPInstructionType();
        String opcode = inst.getOpcode();
        // strip the "map" prefix so the flop estimate is looked up by the base opcode
        if (opcode.startsWith("map")) {
            opcode = opcode.substring(3);
        }

        double dataTransmissionTime;
        if (inst instanceof BinaryMatrixMatrixSPInstruction) {
            if (inst instanceof BinaryMatrixBVectorSPInstruction) {
                // the vector operand is broadcast; the output keeps input1's partitioning
                dataTransmissionTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
                output.rddStats.numPartitions = input1.rddStats.numPartitions;
                output.rddStats.hashPartitioned = input1.rddStats.hashPartitioned;
            } else {
                RDDStats left = input1.rddStats;
                RDDStats right = input2.rddStats;
                dataTransmissionTime = IOCostUtils.getSparkShuffleWriteTime(left, executorMetrics)
                        + IOCostUtils.getSparkShuffleWriteTime(right, executorMetrics);
                if (left.hashPartitioned) {
                    output.rddStats.numPartitions = left.numPartitions;
                    if (right.hashPartitioned && left.numPartitions == right.numPartitions) {
                        // co-partitioned inputs: both sides pay only the static read cost
                        dataTransmissionTime += IOCostUtils.getSparkShuffleReadStaticTime(left, executorMetrics)
                                + IOCostUtils.getSparkShuffleReadStaticTime(right, executorMetrics);
                    } else {
                        // only input2 has to be re-partitioned
                        dataTransmissionTime += IOCostUtils.getSparkShuffleReadStaticTime(left, executorMetrics)
                                + IOCostUtils.getSparkShuffleReadTime(right, executorMetrics);
                    }
                } else if (right.hashPartitioned) {
                    output.rddStats.numPartitions = right.numPartitions;
                    // only input1 has to be re-partitioned (the roles were previously swapped here)
                    dataTransmissionTime += IOCostUtils.getSparkShuffleReadTime(left, executorMetrics)
                            + IOCostUtils.getSparkShuffleReadStaticTime(right, executorMetrics);
                } else {
                    // NOTE(review): doubles the output's own partition count — confirm intent
                    output.rddStats.numPartitions = 2 * output.rddStats.numPartitions;
                    dataTransmissionTime += IOCostUtils.getSparkShuffleReadTime(left, executorMetrics)
                            + IOCostUtils.getSparkShuffleReadTime(right, executorMetrics);
                }
                output.rddStats.hashPartitioned = true;
            }
        } else if (inst instanceof BinaryMatrixScalarSPInstruction) {
            dataTransmissionTime = 0.0;
            // the output follows the partitioning of whichever operand is the matrix
            output.rddStats.hashPartitioned = input2.isScalar() ? input1.rddStats.hashPartitioned : input2.rddStats.hashPartitioned;
        } else {
            if (inst instanceof BinaryFrameMatrixSPInstruction || inst instanceof BinaryFrameFrameSPInstruction) {
                throw new RuntimeException("Handling binary instructions for frames not handled yet.");
            }
            throw new RuntimeException("Not supported binary instruction: " + inst);
        }

        long nflop = SparkCostUtils.getInstNFLOP(opType, opcode, output, input1, input2);
        double mapTime = SparkCostUtils.getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
        return dataTransmissionTime + mapTime;
    }

    /**
     * Estimates the execution time of a Spark append instruction.
     * <p>
     * The transmission cost depends on the concrete variant:
     * <ul>
     *   <li>map-side append broadcasts the second input;</li>
     *   <li>the "R" variant shuffles the output without the flag set;</li>
     *   <li>the G-aligned variant has no transmission cost;</li>
     *   <li>any other variant shuffles the second input.</li>
     * </ul>
     *
     * @param inst            the append Spark instruction
     * @param input1          statistics of the first operand
     * @param input2          statistics of the second operand
     * @param output          statistics of the result; may be marked as hash-partitioned
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getAppendInstTime(AppendSPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        double transferTime;
        boolean outputHashPartitioned = false;
        if (inst instanceof AppendMSPInstruction) {
            transferTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
            outputHashPartitioned = true;
        } else if (inst instanceof AppendRSPInstruction) {
            transferTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, false);
        } else if (inst instanceof AppendGAlignedSPInstruction) {
            transferTime = 0.0;
        } else {
            transferTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
            outputHashPartitioned = true;
        }
        if (outputHashPartitioned) {
            output.rddStats.hashPartitioned = true;
        }

        long flops = SparkCostUtils.getInstNFLOP(inst.getSPInstructionType(), "append", output, input1, input2);
        double cpuTime = SparkCostUtils.getCPUTime(flops, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark reorganization instruction.
     * <p>
     * Costs by opcode:
     * <ul>
     *   <li>"rshape" shuffles the input and "rev" shuffles the output;</li>
     *   <li>"r'" and "rdiag" are charged no transmission cost;</li>
     *   <li>any other opcode pays a full shuffle (write plus read) of the output, scaled by 2 when
     *       the fifth instruction part ("ixret") is "true" and by 4 otherwise.</li>
     * </ul>
     *
     * @param inst            the reorg Spark instruction
     * @param input           statistics of the operand
     * @param output          statistics of the result; partitioning fields may be updated
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getReorgInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        String opcode = inst.getOpcode();
        double transferTime;
        if (opcode.equals("rshape")) {
            transferTime = IOCostUtils.getSparkShuffleTime(input.rddStats, executorMetrics, true);
            output.rddStats.hashPartitioned = true;
        } else if (opcode.equals("r'")) {
            transferTime = 0.0;
            output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
        } else if (opcode.equals("rev")) {
            transferTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
            output.rddStats.hashPartitioned = true;
        } else if (opcode.equals("rdiag")) {
            transferTime = 0.0;
            output.rddStats.numPartitions = input.rddStats.numPartitions;
            output.rddStats.hashPartitioned = input.rddStats.hashPartitioned;
        } else {
            String[] parts = InstructionUtils.getInstructionParts(inst.getInstructionString());
            boolean ixret = parts[4].equalsIgnoreCase("true");
            int shuffleFactor = ixret ? 2 : 4;
            double oneShuffle = IOCostUtils.getSparkShuffleWriteTime(output.rddStats, executorMetrics)
                    + IOCostUtils.getSparkShuffleReadTime(output.rddStats, executorMetrics);
            transferTime = oneShuffle * (double) shuffleFactor;
        }

        long flops = SparkCostUtils.getInstNFLOP(inst.getSPInstructionType(), opcode, output, new VarStats[0]);
        double cpuTime = SparkCostUtils.getCPUTime(flops, output.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark TSMM (t(X)%*%X or X%*%t(X)) instruction.
     * <p>
     * The plain TSMM variant charges a collect of the output to the driver. The TSMM2 variant
     * first collects and broadcasts a temporary matrix, whose size is the input reduced by one
     * block along the non-kept dimension, and then also collects the output. In both variants
     * the output is marked as collected, matching the collect cost charged for it.
     *
     * @param inst            a {@link TsmmSPInstruction} or {@link Tsmm2SPInstruction}
     * @param input           statistics of the operand
     * @param output          statistics of the result; marked as collected
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getTSMMInstTime(UnarySPInstruction inst, VarStats input, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        MMTSJ.MMTSJType type;
        double dataTransmissionTime;
        if (inst instanceof TsmmSPInstruction) {
            type = ((TsmmSPInstruction) inst).getMMTSJType();
            dataTransmissionTime = IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
        } else {
            type = ((Tsmm2SPInstruction) inst).getMMTSJType();
            long blen = input.characteristics.getBlocksize();
            long rowsRange = type == MMTSJ.MMTSJType.LEFT ? input.getM() : input.getM() - blen;
            long colsRange = type != MMTSJ.MMTSJType.LEFT ? input.getN() : input.getN() - blen;
            VarStats tmpBroadcast = new VarStats("tmp1", new MatrixCharacteristics(rowsRange, colsRange));
            tmpBroadcast.rddStats = new RDDStats(tmpBroadcast);
            dataTransmissionTime = IOCostUtils.getSparkCollectTime(tmpBroadcast.rddStats, driverMetrics, executorMetrics)
                    + IOCostUtils.getSparkBroadcastTime(tmpBroadcast, driverMetrics, executorMetrics)
                    + IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
        }
        // both variants charge a collect of the output, so the output lives on the driver
        // (previously only the TsmmSPInstruction branch set this flag)
        output.rddStats.isCollected = true;

        String opcode = inst.getOpcode() + (type.isLeft() ? "_left" : "_right");
        long nflop = SparkCostUtils.getInstNFLOP(inst.getSPInstructionType(), opcode, output, input);
        double mapTime = SparkCostUtils.getCPUTime(nflop, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
        return dataTransmissionTime + mapTime;
    }

    /**
     * Estimates the execution time of a Spark central-moment instruction.
     * <p>
     * The flop lookup uses the opcode suffixed with the lower-cased aggregation type. A
     * weights input, when present, pays a full shuffle (write plus read) and is scanned
     * alongside the input. The output is always marked as collected.
     *
     * @param inst            the central-moment Spark instruction
     * @param input           statistics of the data operand
     * @param weights         statistics of the weights operand, or {@code null}
     * @param output          statistics of the result; marked as collected
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getCentralMomentInstTime(CentralMomentSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        CMOperator.AggregateOperationTypes cmType = ((CMOperator) inst.getOperator()).getAggOpType();
        String opcode = inst.getOpcode() + "_" + cmType.name().toLowerCase();

        double transferTime = 0.0;
        RDDStats[] scanned;
        if (weights == null) {
            scanned = new RDDStats[]{input.rddStats};
        } else {
            transferTime = IOCostUtils.getSparkShuffleWriteTime(weights.rddStats, executorMetrics)
                    + IOCostUtils.getSparkShuffleReadTime(weights.rddStats, executorMetrics);
            scanned = new RDDStats[]{input.rddStats, weights.rddStats};
        }
        output.rddStats.isCollected = true;

        long flops = SparkCostUtils.getInstNFLOP(inst.getSPInstructionType(), opcode, output, input);
        double cpuTime = SparkCostUtils.getCPUTime(flops, input.rddStats.numPartitions, executorMetrics, output.rddStats, scanned);
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark cast instruction.
     * <p>
     * When the input has more columns than its block size, a full shuffle (write plus read)
     * of the input is charged and the output is marked as hash-partitioned.
     *
     * @param inst            the cast Spark instruction
     * @param input           statistics of the operand
     * @param output          statistics of the result; may be marked as hash-partitioned
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated shuffle time plus CPU time
     */
    public static double getCastInstTime(CastSPInstruction inst, VarStats input, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        double shuffleTime;
        long blen = input.characteristics.getBlocksize();
        if (input.getN() > blen) {
            shuffleTime = IOCostUtils.getSparkShuffleWriteTime(input.rddStats, executorMetrics)
                    + IOCostUtils.getSparkShuffleReadTime(input.rddStats, executorMetrics);
            output.rddStats.hashPartitioned = true;
        } else {
            shuffleTime = 0.0;
        }
        long flops = SparkCostUtils.getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input);
        double cpuTime = SparkCostUtils.getCPUTime(flops, input.rddStats.numPartitions, executorMetrics, output.rddStats, input.rddStats);
        return shuffleTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark quantile-sort instruction.
     * <p>
     * The output always pays a full shuffle (write plus read) and is marked as hash-partitioned.
     * With a weights input, the opcode gets the suffix "_wts", the weights pay an extra full
     * shuffle, and they are scanned alongside the input.
     *
     * @param inst            the quantile-sort Spark instruction
     * @param input           statistics of the data operand
     * @param weights         statistics of the weights operand, or {@code null}
     * @param output          statistics of the result; marked as hash-partitioned
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated shuffle time plus CPU time
     */
    public static double getQSortInstTime(QuantileSortSPInstruction inst, VarStats input, VarStats weights, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        String opcode = inst.getOpcode();
        boolean weighted = weights != null;
        double shuffleTime = 0.0;
        if (weighted) {
            opcode = opcode + "_wts";
            shuffleTime += IOCostUtils.getSparkShuffleWriteTime(weights.rddStats, executorMetrics)
                    + IOCostUtils.getSparkShuffleReadTime(weights.rddStats, executorMetrics);
        }
        shuffleTime += IOCostUtils.getSparkShuffleWriteTime(output.rddStats, executorMetrics)
                + IOCostUtils.getSparkShuffleReadTime(output.rddStats, executorMetrics);
        output.rddStats.hashPartitioned = true;

        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.QSort, opcode, output, input, weights);
        RDDStats[] scanned = weighted
                ? new RDDStats[]{input.rddStats, weights.rddStats}
                : new RDDStats[]{input.rddStats};
        double cpuTime = SparkCostUtils.getCPUTime(flops, input.rddStats.numPartitions, executorMetrics, output.rddStats, scanned);
        return shuffleTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark matrix-multiplication instruction.
     * <p>
     * Cost by variant:
     * <ul>
     *   <li>CPMM and RMM shuffle an RDD sized like both inputs together and then collect or
     *       shuffle the output.</li>
     *   <li>MAPMM and PMM broadcast the second input and then collect or shuffle the output.
     *       Their CPU time scans only input1.</li>
     *   <li>ZIPMM shuffles and then collects the output.</li>
     * </ul>
     *
     * @param inst            the matrix-multiplication Spark instruction
     * @param input1          statistics of the left operand
     * @param input2          statistics of the right operand
     * @param output          statistics of the result; partitioning/collect fields are updated
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     * @throws RuntimeException for PMAPMM or any other unhandled instruction type
     */
    public static double getMatMulInstTime(BinarySPInstruction inst, VarStats input1, VarStats input2, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        double transferTime;
        int mappingPartitions;
        if (inst instanceof CpmmSPInstruction) {
            AggBinaryOp.SparkAggType aggType = ((CpmmSPInstruction) inst).getAggType();
            long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize;
            RDDStats joined = new RDDStats(joinedSize, -1);
            transferTime = IOCostUtils.getSparkShuffleTime(joined, executorMetrics, true);
            if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
                transferTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
                output.rddStats.isCollected = true;
            } else {
                transferTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
                output.rddStats.hashPartitioned = true;
            }
            mappingPartitions = joined.numPartitions;
        } else if (inst instanceof RmmSPInstruction) {
            long joinedSize = input1.rddStats.distributedSize + input2.rddStats.distributedSize;
            RDDStats joined = new RDDStats(joinedSize, -1);
            transferTime = IOCostUtils.getSparkShuffleTime(joined, executorMetrics, true);
            transferTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, false);
            output.rddStats.hashPartitioned = true;
            mappingPartitions = joined.numPartitions;
        } else if (inst instanceof MapmmSPInstruction) {
            transferTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
            AggBinaryOp.SparkAggType aggType = ((MapmmSPInstruction) inst).getAggType();
            if (aggType == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
                transferTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
                output.rddStats.isCollected = true;
            } else {
                transferTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
                output.rddStats.hashPartitioned = true;
            }
            mappingPartitions = input1.rddStats.numPartitions;
        } else if (inst instanceof PmmSPInstruction) {
            transferTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
            output.rddStats.numPartitions = input1.rddStats.numPartitions;
            transferTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
            output.rddStats.hashPartitioned = true;
            mappingPartitions = input1.rddStats.numPartitions;
        } else if (inst instanceof ZipmmSPInstruction) {
            transferTime = IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, false);
            transferTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
            mappingPartitions = input1.rddStats.numPartitions;
            output.rddStats.isCollected = true;
        } else if (inst instanceof PMapmmSPInstruction) {
            throw new RuntimeException("PMapmmSPInstruction instruction is still experimental and not supported yet");
        } else {
            throw new RuntimeException(inst.getClass().getName() + " instruction is not handled by the current method");
        }

        long flops = SparkCostUtils.getInstNFLOP(inst.getSPInstructionType(), inst.getOpcode(), output, input1, input2);
        double cpuTime;
        if (inst instanceof MapmmSPInstruction || inst instanceof PmmSPInstruction) {
            // broadcast-based variants scan only the left RDD
            cpuTime = SparkCostUtils.getCPUTime(flops, mappingPartitions, executorMetrics, output.rddStats, input1.rddStats);
        } else {
            cpuTime = SparkCostUtils.getCPUTime(flops, mappingPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats);
        }
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark MAPMMCHAIN instruction.
     * <p>
     * The second input, and the third when present, are broadcast. The CPU time scans input1,
     * and the output is collected to the driver and marked as collected.
     *
     * @param inst            the mapmm-chain Spark instruction
     * @param input1          statistics of the RDD operand
     * @param input2          statistics of the broadcast operand
     * @param input3          statistics of the optional third operand, or {@code null}
     * @param output          statistics of the result; marked as collected
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getMatMulChainInstTime(MapmmChainSPInstruction inst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        double transferTime = 0.0;
        boolean hasThirdInput = input3 != null;
        if (hasThirdInput) {
            transferTime += IOCostUtils.getSparkBroadcastTime(input3, driverMetrics, executorMetrics);
        }
        transferTime += IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
        output.rddStats.isCollected = true;

        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.MAPMMCHAIN, inst.getOpcode(), output, input1, input2);
        double cpuTime = SparkCostUtils.getCPUTime(flops, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
        transferTime += IOCostUtils.getSparkCollectTime(output.rddStats, driverMetrics, executorMetrics);
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark ctable instruction.
     * <p>
     * The matrix operands among input2 and input3 are shuffled. For "ctableexpand", only input2
     * is shuffled. In every case the output pays an additional shuffle and is marked as
     * hash-partitioned.
     *
     * @param tableInst       the ctable Spark instruction
     * @param input1          statistics of the first operand
     * @param input2          statistics of the second operand (matrix or scalar)
     * @param input3          statistics of the third operand (matrix or scalar)
     * @param output          statistics of the result; marked as hash-partitioned
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated shuffle time plus CPU time
     */
    public static double getCtableInstTime(CtableSPInstruction tableInst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        String opcode = tableInst.getOpcode();
        boolean expand = opcode.equals(Opcodes.CTABLEEXPAND.toString());
        double shuffleTime;
        if (expand || (!input2.isScalar() && input3.isScalar())) {
            shuffleTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true);
        } else if (input2.isScalar()) {
            shuffleTime = input3.isScalar()
                    ? 0.0
                    : IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics, true);
        } else {
            shuffleTime = IOCostUtils.getSparkShuffleTime(input2.rddStats, executorMetrics, true)
                    + IOCostUtils.getSparkShuffleTime(input3.rddStats, executorMetrics, true);
        }
        output.rddStats.hashPartitioned = true;

        long flops = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.Ctable, opcode, output, input1, input2, input3);
        double cpuTime = SparkCostUtils.getCPUTime(flops, output.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats, input2.rddStats, input3.rddStats);
        shuffleTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
        return shuffleTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark parameterized-builtin instruction.
     * <p>
     * Supported opcodes:
     * <ul>
     *   <li>"rmempty" broadcasts input2 when it has no RDD statistics and otherwise shuffles
     *       input1; the output is then shuffled as well.</li>
     *   <li>"contains" broadcasts a non-scalar input2 and marks the output as collected.</li>
     *   <li>"replace", "lowertri" and "uppertri" have no transmission cost.</li>
     * </ul>
     *
     * @param paramInst       the parameterized-builtin Spark instruction
     * @param input1          statistics of the main operand
     * @param input2          statistics of the auxiliary operand (used by rmempty/contains)
     * @param output          statistics of the result; may be marked as collected
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     * @throws RuntimeException for any other opcode
     */
    public static double getParameterizedBuiltinInstTime(ParameterizedBuiltinSPInstruction paramInst, VarStats input1, VarStats input2, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        String opcode = paramInst.getOpcode();
        double transferTime;
        if (opcode.equals("rmempty")) {
            if (input2.rddStats == null) {
                transferTime = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
            } else {
                transferTime = IOCostUtils.getSparkShuffleTime(input1.rddStats, executorMetrics, true);
            }
            transferTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
        } else if (opcode.equals("contains")) {
            transferTime = input2.isScalar() ? 0.0 : IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
            output.rddStats.isCollected = true;
        } else if (opcode.equals("replace") || opcode.equals("lowertri") || opcode.equals("uppertri")) {
            transferTime = 0.0;
        } else {
            throw new RuntimeException("Spark operation ParameterizedBuiltin with opcode " + opcode + " is not supported yet");
        }

        long flops = SparkCostUtils.getInstNFLOP(paramInst.getSPInstructionType(), opcode, output, input1);
        double cpuTime = SparkCostUtils.getCPUTime(flops, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
        return transferTime + cpuTime;
    }

    /**
     * Estimates the execution time of a Spark ternary instruction.
     * <p>
     * The non-scalar (matrix) operands are collected in operand order and scanned for the CPU
     * estimate. When two or more operands are matrices, each one is charged a shuffle for the
     * join, using its own hash-partitioning flag. A single matrix operand is charged no shuffle.
     * <p>
     * This replaces a condition chain in which the all-matrix branch was unreachable and
     * single-matrix cases either scanned nothing or read a scalar operand's RDD statistics.
     * NOTE(review): assumes the runtime joins the matrix operands only when there are at least
     * two of them and otherwise maps over the single matrix — confirm against TernarySPInstruction.
     *
     * @param tInst           the ternary Spark instruction
     * @param input1          statistics of the first operand (matrix or scalar)
     * @param input2          statistics of the second operand (matrix or scalar)
     * @param input3          statistics of the third operand (matrix or scalar)
     * @param output          statistics of the result
     * @param executorMetrics I/O metrics of an executor node
     * @return estimated data transmission time plus CPU time
     */
    public static double getTernaryInstTime(TernarySPInstruction tInst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOCostUtils.IOMetrics executorMetrics) {
        VarStats[] operands = new VarStats[]{input1, input2, input3};
        int numMatrices = 0;
        for (VarStats operand : operands) {
            if (!operand.isScalar()) {
                numMatrices++;
            }
        }
        // RDD statistics of the matrix operands only, in operand order
        RDDStats[] inputRddStats = new RDDStats[numMatrices];
        int pos = 0;
        for (VarStats operand : operands) {
            if (!operand.isScalar()) {
                inputRddStats[pos++] = operand.rddStats;
            }
        }

        double dataTransmissionTime = 0.0;
        if (numMatrices > 1) {
            // matrix operands have to be joined
            for (RDDStats stats : inputRddStats) {
                dataTransmissionTime += IOCostUtils.getSparkShuffleTime(stats, executorMetrics, stats.hashPartitioned);
            }
        }

        long nflop = SparkCostUtils.getInstNFLOP(SPInstruction.SPType.Ternary, tInst.getOpcode(), output, input1, input2, input3);
        double mapTime = SparkCostUtils.getCPUTime(nflop, output.rddStats.numPartitions, executorMetrics, output.rddStats, inputRddStats);
        return dataTransmissionTime + mapTime;
    }

    /**
     * Estimates the execution time of a map-side Spark quaternary instruction.
     * <p>
     * The second and third operands are charged as broadcasts. For {@code mapwsloss} and
     * {@code mapwcemm} the output RDD is flagged as collected; for {@code mapwdivmm} a shuffle
     * of the output is added. Reduce-side ({@code red*}) opcodes are rejected.
     *
     * @param quatInst        the quaternary instruction
     * @param input1          statistics of the first (main) operand
     * @param input2          statistics of the second operand (broadcast)
     * @param input3          statistics of the third operand (broadcast)
     * @param output          statistics of the output; may be mutated ({@code isCollected})
     * @param driverMetrics   I/O metrics of the driver
     * @param executorMetrics I/O and CPU metrics of a single executor
     * @return estimated data transmission time plus map (compute) time
     * @throws RuntimeException for reduce-side opcodes
     */
    public static double getQuaternaryInstTime(QuaternarySPInstruction quatInst, VarStats input1, VarStats input2, VarStats input3, VarStats output, IOCostUtils.IOMetrics driverMetrics, IOCostUtils.IOMetrics executorMetrics) {
        final String op = quatInst.getOpcode();
        if (op.startsWith("red")) {
            throw new RuntimeException("Spark Quaternary reduce-operations are not supported yet");
        }
        double broadcastIn2 = IOCostUtils.getSparkBroadcastTime(input2, driverMetrics, executorMetrics);
        double broadcastIn3 = IOCostUtils.getSparkBroadcastTime(input3, driverMetrics, executorMetrics);
        double transferTime = broadcastIn2 + broadcastIn3;
        switch (op) {
            case "mapwsloss":
            case "mapwcemm":
                output.rddStats.isCollected = true;
                break;
            case "mapwdivmm":
                transferTime += IOCostUtils.getSparkShuffleTime(output.rddStats, executorMetrics, true);
                break;
            default:
                // no additional transfer for the remaining map-side opcodes
                break;
        }
        long flops = SparkCostUtils.getInstNFLOP(quatInst.getSPInstructionType(), op, output, input1);
        double computeTime = SparkCostUtils.getCPUTime(flops, input1.rddStats.numPartitions, executorMetrics, output.rddStats, input1.rddStats);
        return transferTime + computeTime;
    }

    /**
     * Estimates the executor-side compute time for processing the given RDDs.
     * <p>
     * The memory-scan time of all non-null inputs and the CPU time of the per-task share of
     * {@code nflop} are combined with {@code Math.max} (treated as overlapping); the time to
     * write the output (if any) is added on top. Tasks are assumed to execute in waves whose
     * width is the default Spark parallelism.
     *
     * @param nflop           total floating point operations of the operation
     * @param numPartitions   number of partitions/tasks; values below 1 are treated as 1
     * @param executorMetrics I/O and CPU metrics of a single executor
     * @param output          output RDD statistics, or {@code null} if no output is written
     * @param inputs          input RDD statistics; {@code null} entries are ignored
     * @return estimated compute time
     */
    public static double getCPUTime(long nflop, int numPartitions, IOCostUtils.IOMetrics executorMetrics, RDDStats output, RDDStats ... inputs) {
        double memScanTime = 0.0;
        for (RDDStats input : inputs) {
            if (input != null) {
                memScanTime += IOCostUtils.getMemReadTime(input, executorMetrics);
            }
        }
        // guard against a zero partition count, which would otherwise produce 0/0 = NaN
        // and poison every cost estimate built on top of this value
        int partitions = Math.max(numPartitions, 1);
        double numWaves = Math.ceil((double) partitions / (double) SparkExecutionContext.getDefaultParallelism(false));
        double scaledNFLOP = numWaves * (double) nflop / (double) partitions;
        double cpuComputationTime = scaledNFLOP / (double) executorMetrics.cpuFLOPS;
        double memWriteTime = output != null ? IOCostUtils.getMemWriteTime(output, executorMetrics) : 0.0;
        return Math.max(memScanTime, cpuComputationTime) + memWriteTime;
    }

    /**
     * Attaches fresh RDD statistics to the output of a Spark instruction.
     * <p>
     * For a non-scalar output whose cell count is still unknown (negative), its dimensions are
     * inferred from the inputs first. A new {@link RDDStats} is assigned in every case,
     * including scalar outputs.
     *
     * @param inst   the Spark instruction producing {@code output}
     * @param output statistics of the output variable (mutated)
     * @param inputs statistics of the instruction's inputs
     */
    public static void assignOutputRDDStats(SPInstruction inst, VarStats output, VarStats ... inputs) {
        boolean dimsUnknown = !output.isScalar() && output.getCells() < 0L;
        if (dimsUnknown) {
            SparkCostUtils.inferStats(inst.getSPInstructionType(), inst.getOpcode(), output, inputs);
        }
        output.rddStats = new RDDStats(output);
    }

    /**
     * Infers the output dimensions of a Spark instruction by delegating to the CP
     * inference formula of the corresponding CP instruction type.
     *
     * @param instType the Spark instruction type
     * @param opcode   the instruction opcode
     * @param output   statistics of the output (dimensions are filled in)
     * @param inputs   statistics of the inputs
     * @throws RuntimeException if no formula exists for the type, or if the formula
     *                          leaves the output cell count unknown
     */
    private static void inferStats(SPInstruction.SPType instType, String opcode, VarStats output, VarStats ... inputs) {
        switch (instType) {
            case Unary: 
            case Builtin: {
                CPCostUtils.inferStats(CPInstruction.CPType.Unary, opcode, output, inputs);
                break;
            }
            case AggregateUnary: 
            case AggregateUnarySketch: {
                CPCostUtils.inferStats(CPInstruction.CPType.AggregateUnary, opcode, output, inputs);
                // previously missing: without it, control fell through into MatrixIndexing inference
                break;
            }
            case MatrixIndexing: {
                CPCostUtils.inferStats(CPInstruction.CPType.MatrixIndexing, opcode, output, inputs);
                break;
            }
            case Reorg: {
                CPCostUtils.inferStats(CPInstruction.CPType.Reorg, opcode, output, inputs);
                break;
            }
            case Binary: {
                CPCostUtils.inferStats(CPInstruction.CPType.Binary, opcode, output, inputs);
                break;
            }
            case CPMM: 
            case RMM: 
            case MAPMM: 
            case PMM: 
            case ZIPMM: {
                CPCostUtils.inferStats(CPInstruction.CPType.AggregateBinary, opcode, output, inputs);
                break;
            }
            case ParameterizedBuiltin: {
                CPCostUtils.inferStats(CPInstruction.CPType.ParameterizedBuiltin, opcode, output, inputs);
                break;
            }
            case Rand: {
                CPCostUtils.inferStats(CPInstruction.CPType.Rand, opcode, output, inputs);
                break;
            }
            case Ctable: {
                CPCostUtils.inferStats(CPInstruction.CPType.Ctable, opcode, output, inputs);
                break;
            }
            default: {
                throw new RuntimeException("Operation of type " + instType + " with opcode '" + opcode + "' has no formula for inferring dimensions");
            }
        }
        // the delegated formula must have produced a known cell count
        if (output.getCells() < 0L) {
            throw new RuntimeException("Operation of type " + instType + " with opcode '" + opcode + "' has incomplete formula for inferring dimensions");
        }
    }

    /**
     * Returns the estimated number of floating point operations of a Spark instruction.
     * <p>
     * Most types delegate to the formula of the corresponding CP instruction type; a few
     * (Reblock, Cast, CumsumAggregate, MAPMMCHAIN) use Spark-specific formulas, and the
     * matrix-multiplication variants are charged twice the CP cost.
     *
     * @param instructionType the Spark instruction type
     * @param opcode          the opcode (matched case-insensitively)
     * @param output          statistics of the output
     * @param inputs          statistics of the inputs
     * @return estimated FLOP count
     * @throws DMLRuntimeException for opcodes/types not implemented by SystemDS
     * @throws RuntimeException    for Spark operations not supported by the cost model yet
     */
    private static long getInstNFLOP(SPInstruction.SPType instructionType, String opcode, VarStats output, VarStats ... inputs) {
        // locale-independent lowercasing so matching does not depend on the JVM default locale
        opcode = opcode.toLowerCase(java.util.Locale.ROOT);
        switch (instructionType) {
            case Reblock: {
                if (opcode.startsWith("libsvm")) {
                    return output.getCellsWithSparsity();
                }
                return output.getCells();
            }
            case Unary: 
            case Builtin: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Unary, opcode, output, inputs);
            }
            case AggregateUnary: 
            case AggregateUnarySketch: {
                switch (opcode) {
                    case "uacdr": 
                    case "uacdc": {
                        throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS");
                    }
                }
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.AggregateUnary, opcode, output, inputs);
            }
            case CumsumAggregate: {
                double costs;
                switch (opcode) {
                    case "ucumack+": 
                    case "ucumac*": 
                    case "ucumacmin": 
                    case "ucumacmax": {
                        costs = 1.0;
                        break;
                    }
                    case "ucumac+*": {
                        costs = 2.0;
                        break;
                    }
                    default: {
                        throw new DMLRuntimeException(opcode + " opcode is not implemented by SystemDS");
                    }
                }
                return (long)(costs * (double)inputs[0].getCells() + costs * (double)output.getN());
            }
            case TSMM: 
            case TSMM2: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.MMTSJ, opcode, output, inputs);
            }
            case Reorg: 
            case MatrixReshape: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Reorg, opcode, output, inputs);
            }
            case MatrixIndexing: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.MatrixIndexing, opcode, output, inputs);
            }
            case Cast: {
                return output.getCellsWithSparsity();
            }
            case QSort: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.QSort, opcode, output, inputs);
            }
            case CentralMoment: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.CentralMoment, opcode, output, inputs);
            }
            case UaggOuterChain: 
            case Dnn: 
            case BinUaggChain: {
                // BinUaggChain previously fell through to a message-less RuntimeException
                throw new RuntimeException("Spark operation type '" + instructionType + "' is not supported yet");
            }
            case Binary: {
                switch (opcode) {
                    case "+*": 
                    case "-*": {
                        throw new RuntimeException("Spark operation with opcode '" + opcode + "' is not supported yet");
                    }
                }
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Binary, opcode, output, inputs);
            }
            case CPMM: 
            case RMM: 
            case MAPMM: 
            case PMM: 
            case ZIPMM: 
            case PMAPMM: {
                return 2L * CPCostUtils.getInstNFLOP(CPInstruction.CPType.AggregateBinary, opcode, output, inputs);
            }
            case MAPMMCHAIN: {
                return 2L * inputs[0].getCells() * inputs[0].getN() + 2L * inputs[0].getM() * inputs[1].getN() + 2L * inputs[0].getCellsWithSparsity() * inputs[1].getN() + inputs[1].getM() * output.getM();
            }
            case MAppend: 
            case RAppend: 
            case GAppend: 
            case GAlignedAppend: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Append, opcode, output, inputs);
            }
            case BuiltinNary: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.BuiltinNary, opcode, output, inputs);
            }
            case Ctable: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Ctable, opcode, output, inputs);
            }
            case ParameterizedBuiltin: {
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.ParameterizedBuiltin, opcode, output, inputs);
            }
            case Ternary: {
                // only the output statistics are passed to the CP formula
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Ternary, opcode, output);
            }
            case Quaternary: {
                // strip the 3-character "map"/"red" prefix to obtain the CP opcode
                String opcodeRoot = opcode.substring(3);
                return CPCostUtils.getInstNFLOP(CPInstruction.CPType.Quaternary, opcodeRoot, output, inputs);
            }
            default: {
                throw new DMLRuntimeException("Spark operation type '" + instructionType + "' is not supported by SystemDS");
            }
        }
    }
}

