/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;

/**
 * Base class for text-embedding models.
 *
 * <p>Each input document is sent through the model's predictor one at a time. Each output is then
 * converted into a {@link ModelTensors} entry via {@code parseModelTensorOutput}, which is inherited
 * from {@link DLModel}.
 */
public abstract class TextEmbeddingModel extends DLModel {

    /** Warm-up input used when no usable {@code modelMaxLength} is configured. */
    private static final String DEFAULT_WARM_UP_SENTENCE = "warm up sentence";

    /** Word repeated {@code modelMaxLength} times to build a maximum-length warm-up input. */
    private static final String WARM_UP_WORD = "sentence ";

    /**
     * Runs the predictor once per document and collects one result per document.
     *
     * @param modelId model identifier (not used by this implementation)
     * @param mlInput the prediction input; its dataset is cast to {@link TextDocsInputDataSet}, so
     *     any other dataset type fails with a {@link ClassCastException}
     * @return a {@link ModelTensorOutput} holding one {@link ModelTensors} per document, in the
     *     same order as the documents
     * @throws TranslateException if the predictor fails on any document
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
        TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) inputDataSet;
        // The same filter is applied to every document's output.
        ModelResultFilter resultFilter = textDocsInput.getResultFilter();

        ArrayList<ModelTensors> tensorOutputs = new ArrayList<>();
        for (String doc : textDocsInput.getDocs()) {
            Input input = new Input();
            input.add(doc);
            Output output = (Output) getPredictor().predict(input);
            tensorOutputs.add(parseModelTensorOutput(output, resultFilter));
        }
        return new ModelTensorOutput(tensorOutputs);
    }

    /**
     * Runs a single prediction so the model is exercised before serving real requests.
     *
     * <p>If the config provides a positive {@code modelMaxLength}, the warm-up input repeats a word
     * that many times. Otherwise a short default sentence is used.
     *
     * @param predictor the predictor to exercise
     * @param modelId model identifier (not used by this implementation)
     * @param modelConfig the model config, expected to be a {@link TextEmbeddingModelConfig}; may be
     *     {@code null}
     * @throws TranslateException if the warm-up prediction fails
     */
    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
        Integer modelMaxLength = null;
        if (modelConfig != null) {
            modelMaxLength = ((TextEmbeddingModelConfig) modelConfig).getModelMaxLength();
        }
        Input input = new Input();
        input.add(buildWarmUpSentence(modelMaxLength));
        predictor.predict(input);
    }

    /**
     * Builds the warm-up input text.
     *
     * <p>{@code null} and non-positive lengths fall back to the default sentence. Passing a negative
     * count to {@link String#repeat(int)} throws, and a zero count would produce an empty input.
     */
    private static String buildWarmUpSentence(Integer modelMaxLength) {
        if (modelMaxLength == null || modelMaxLength <= 0) {
            return DEFAULT_WARM_UP_SENTENCE;
        }
        return WARM_UP_WORD.repeat(modelMaxLength);
    }

    /**
     * Builds the argument map derived from the model config.
     *
     * @param modelConfig the model config, expected to be a {@link TextEmbeddingModelConfig}; may be
     *     {@code null}
     * @return a mutable map that contains {@code "modelMaxLength"} when the config sets it; an empty
     *     map when {@code modelConfig} is {@code null} or the length is unset
     */
    @Override
    public Map<String, Object> getArguments(MLModelConfig modelConfig) {
        HashMap<String, Object> arguments = new HashMap<>();
        if (modelConfig == null) {
            return arguments;
        }
        Integer modelMaxLength = ((TextEmbeddingModelConfig) modelConfig).getModelMaxLength();
        if (modelMaxLength != null) {
            arguments.put("modelMaxLength", modelMaxLength);
        }
        return arguments;
    }
}

