/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import java.io.File;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.commons.io.FileUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.utils.ZipUtils;

public abstract class DLModelExecute
implements MLExecutable {
    @Generated
    private static final Logger log = LogManager.getLogger(DLModelExecute.class);
    public static final String MODEL_ZIP_FILE = "model_zip_file";
    public static final String MODEL_HELPER = "model_helper";
    public static final String ML_ENGINE = "ml_engine";
    protected ModelHelper modelHelper;
    protected MLEngine mlEngine;
    protected String modelId;
    protected Predictor<float[][], Output>[] predictors;
    protected ZooModel[] models;
    protected Device[] devices;
    protected AtomicInteger nextDevice = new AtomicInteger(0);

    @Override
    public abstract org.opensearch.ml.common.output.Output execute(Input var1) throws ExecuteException;

    protected Predictor<float[][], Output> getPredictor() {
        int currentDevice = this.nextDevice.getAndIncrement();
        if (currentDevice > this.devices.length - 1) {
            this.nextDevice.set((currentDevice %= this.devices.length) + 1);
        }
        return this.predictors[currentDevice];
    }

    @Override
    public void initModel(MLModel model, Map<String, Object> params) {
        if (Objects.requireNonNull(model.getModelFormat()) != MLModelFormat.TORCH_SCRIPT) {
            throw new IllegalArgumentException("unsupported engine");
        }
        String engine = "PyTorch";
        File modelZipFile = (File)params.get(MODEL_ZIP_FILE);
        this.modelHelper = (ModelHelper)params.get(MODEL_HELPER);
        this.mlEngine = (MLEngine)params.get(ML_ENGINE);
        if (modelZipFile == null) {
            throw new IllegalArgumentException("model file is null");
        }
        if (this.modelHelper == null) {
            throw new IllegalArgumentException("model helper is null");
        }
        if (this.mlEngine == null) {
            throw new IllegalArgumentException("ML engine is null");
        }
        this.modelId = model.getModelId();
        if (this.modelId == null) {
            throw new IllegalArgumentException("model id is null");
        }
        if (model.getAlgorithm() != FunctionName.METRICS_CORRELATION) {
            throw new IllegalArgumentException("wrong function name");
        }
        this.loadModel(modelZipFile, this.modelId, model.getName(), model.getVersion(), engine);
    }

    @Override
    public void close() {
        if (this.modelHelper != null && this.modelId != null) {
            this.modelHelper.deleteFileCache(this.modelId);
            if (this.predictors != null) {
                this.closePredictors(this.predictors);
                this.predictors = null;
            }
            if (this.models != null) {
                this.closeModels(this.models);
                this.models = null;
            }
        }
    }

    public abstract Translator getTranslator();

    private void loadModel(File modelZipFile, String modelId, String modelName, String version, String engine) {
        try {
            ArrayList predictorList = new ArrayList();
            ArrayList modelList = new ArrayList();
            AccessController.doPrivileged(() -> {
                ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
                try {
                    System.setProperty("PYTORCH_PRECXX11", "true");
                    System.setProperty("DJL_CACHE_DIR", this.mlEngine.getMlCachePath().toAbsolutePath().toString());
                    System.setProperty("java.library.path", this.mlEngine.getMlCachePath().toAbsolutePath().toString());
                    System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
                    System.setProperty("ai.djl.pytorch.num_threads", "1");
                    Thread.currentThread().setContextClassLoader(Model.class.getClassLoader());
                    Path modelPath = this.mlEngine.getModelCachePath(modelId, modelName, version);
                    File pathFile = new File(modelPath.toUri());
                    if (pathFile.exists()) {
                        FileUtils.deleteDirectory((File)pathFile);
                    }
                    ZipUtils.unzip(modelZipFile, modelPath);
                    boolean findModelFile = false;
                    for (File file : Objects.requireNonNull(pathFile.listFiles())) {
                        String name = file.getName();
                        if (!name.endsWith(".pt") && !name.endsWith(".onnx")) continue;
                        if (findModelFile) {
                            throw new IllegalArgumentException("found multiple models");
                        }
                        findModelFile = true;
                        int dotIndex = name.lastIndexOf(".");
                        String suffix = name.substring(dotIndex);
                        String targetModelFileName = modelPath.getFileName().toString();
                        if (targetModelFileName.equals(name.substring(0, dotIndex))) continue;
                        file.renameTo(new File(modelPath.resolve(targetModelFileName + suffix).toUri()));
                    }
                    this.devices = Engine.getEngine((String)engine).getDevices();
                    for (int i = 0; i < this.devices.length; ++i) {
                        log.debug("Deploy model {} on device {}: {}", (Object)modelId, (Object)i, (Object)this.devices[i]);
                        Criteria.Builder criteriaBuilder = Criteria.builder().setTypes(ai.djl.modality.Input.class, Output.class).optApplication(Application.UNDEFINED).optEngine(engine).optDevice(this.devices[i]).optModelPath(modelPath);
                        Translator translator = this.getTranslator();
                        if (translator != null) {
                            criteriaBuilder.optTranslator(translator);
                        }
                        Criteria criteria = criteriaBuilder.build();
                        ZooModel model = criteria.loadModel();
                        Predictor predictor = model.newPredictor();
                        predictorList.add(predictor);
                        modelList.add(model);
                    }
                    if (predictorList.size() > 0) {
                        this.predictors = predictorList.toArray(new Predictor[0]);
                        predictorList.clear();
                    }
                    if (modelList.size() > 0) {
                        this.models = modelList.toArray(new ZooModel[0]);
                        modelList.clear();
                    }
                    log.info("Model {} is successfully deployed on {} devices", (Object)modelId, (Object)this.devices.length);
                    Void void_ = null;
                    return void_;
                }
                catch (Throwable e) {
                    String errorMessage = "Failed to deploy model " + modelId;
                    log.error(errorMessage, e);
                    this.close();
                    if (predictorList.size() > 0) {
                        this.closePredictors(predictorList.toArray(new Predictor[0]));
                        predictorList.clear();
                    }
                    if (modelList.size() > 0) {
                        this.closeModels(modelList.toArray(new ZooModel[0]));
                        modelList.clear();
                    }
                    throw new MLException(errorMessage, e);
                }
                finally {
                    org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly(this.mlEngine.getDeployModelPath(modelId));
                    Thread.currentThread().setContextClassLoader(contextClassLoader);
                }
            });
        }
        catch (PrivilegedActionException e) {
            String errorMsg = "Failed to deploy model " + modelId;
            log.error(errorMsg, (Throwable)e);
            throw new MLException(errorMsg, (Throwable)e);
        }
    }

    protected void closePredictors(Predictor[] predictors) {
        log.debug("will close {} predictor for model {}", (Object)predictors.length, (Object)this.modelId);
        for (Predictor predictor : predictors) {
            predictor.close();
        }
    }

    protected void closeModels(ZooModel[] models) {
        log.debug("will close {} zoo model for model {}", (Object)models.length, (Object)this.modelId);
        for (ZooModel model : models) {
            model.close();
        }
    }
}

