/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import java.time.Instant;
import java.util.UUID;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.breaker.MLCircuitBreakerService;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportResponseHandler;

/**
 * Runs ML training tasks on the local node.
 *
 * <p>A training request is turned into an {@link MLTask}, the model is trained through
 * {@link MLEngine#train}, and the resulting {@link MLModel} is written to the ML model index.
 * Requests flagged as async are acknowledged with the task id as soon as the task document has
 * been created; otherwise the caller's listener is completed once the model has been indexed.
 */
public class MLTrainingTaskRunner extends MLTaskRunner<MLTrainingTaskRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTrainingTaskRunner.class);

    /** Transport action name reported by {@link #getTransportActionName()}. */
    private static final String TRANSPORT_ACTION_NAME = "cluster:admin/opensearch/ml/train";
    /** Name of the thread pool executor on which training work is scheduled. */
    private static final String TASK_THREAD_POOL = "OPENSEARCH_ML_TASK_THREAD_POOL";
    /** Name of the index that trained models are written to. */
    private static final String ML_MODEL_INDEX = ".plugins-ml-model";

    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLIndicesHandler mlIndicesHandler;
    private final MLInputDatasetHandler mlInputDatasetHandler;

    /**
     * Creates a training task runner.
     *
     * @param threadPool              thread pool used to schedule training work
     * @param clusterService          cluster service; supplies the local node id recorded on each task
     * @param client                  client used to index the trained model
     * @param mlTaskManager           task manager, forwarded to the superclass
     * @param mlStats                 stats registry, forwarded to the superclass
     * @param mlIndicesHandler        handler used to create the model index when it is absent
     * @param mlInputDatasetHandler   handler used to resolve search-query input into a data frame
     * @param mlTaskDispatcher        task dispatcher, forwarded to the superclass
     * @param mlCircuitBreakerService circuit breaker service, forwarded to the superclass
     */
    public MLTrainingTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLIndicesHandler mlIndicesHandler, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService) {
        super(mlTaskManager, mlStats, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
    }

    /**
     * @return the transport action name used for training requests
     */
    @Override
    protected String getTransportActionName() {
        return TRANSPORT_ACTION_NAME;
    }

    /**
     * Builds the transport response handler that deserializes an {@link MLTaskResponse} and
     * forwards it to the given listener.
     *
     * @param listener listener to notify with the deserialized response or the failure
     * @return the response handler
     */
    @Override
    protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
        return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
    }

    /**
     * Executes a training request on this node.
     *
     * <p>Async requests: the task is first created through the task manager; the caller is then
     * answered immediately with the new task id and the task's current state, and training
     * continues with an internal listener that records completion or failure on the task.
     * Sync requests: the task receives a random UUID as its id and the caller's listener is
     * handed directly to the training pipeline.
     *
     * @param request  training request
     * @param listener listener that receives the task response or the failure
     */
    @Override
    protected void executeTask(MLTrainingTaskRequest request, ActionListener<MLTaskResponse> listener) {
        MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
        Instant now = Instant.now();
        MLTask mlTask = MLTask.builder()
            .taskType(MLTaskType.TRAINING)
            .inputType(inputDataType)
            .functionName(request.getMlInput().getFunctionName())
            .state(MLTaskState.CREATED)
            .workerNode(this.clusterService.localNode().getId())
            .createTime(now)
            .lastUpdateTime(now)
            .async(request.isAsync())
            .build();

        if (!request.isAsync()) {
            mlTask.setTaskId(UUID.randomUUID().toString());
            this.startTrainingTask(mlTask, request.getMlInput(), listener);
            return;
        }

        ActionListener<IndexResponse> taskCreatedListener = ActionListener.wrap(indexResponse -> {
            String taskId = indexResponse.getId();
            mlTask.setTaskId(taskId);
            // Answer the caller right away; the model id is not known yet, hence null.
            listener.onResponse(new MLTaskResponse(new MLTrainingOutput(null, taskId, mlTask.getState().name())));
            ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(trainingResponse -> {
                String modelId = ((MLTrainingOutput) trainingResponse.getOutput()).getModelId();
                log.info("ML model trained successfully, task id: {}, model id: {}", taskId, modelId);
                mlTask.setModelId(modelId);
                this.handleAsyncMLTaskComplete(mlTask);
            }, trainingFailure -> {
                // Pass the exception to the logger so the cause and stack trace are not lost.
                log.error("Failed to train ML model for task " + taskId, trainingFailure);
                this.handleAsyncMLTaskFailure(mlTask, trainingFailure);
            });
            this.startTrainingTask(mlTask, request.getMlInput(), internalListener);
        }, createFailure -> {
            log.error("Failed to create ML task", createFailure);
            listener.onFailure(createFailure);
        });
        this.mlTaskManager.createMLTask(mlTask, taskCreatedListener);
    }

    /**
     * Registers the task, bumps the request counters and schedules the training work on the ML
     * task thread pool. Search-query input is first resolved into a data frame; any other input
     * is trained on as-is.
     *
     * @param mlTask   task being run
     * @param mlInput  training input
     * @param listener listener to complete when training finishes or fails
     */
    private void startTrainingTask(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener) {
        // wrappedCleanupListener is inherited from the superclass; every outcome below is reported through it.
        ActionListener<MLTaskResponse> internalListener = this.wrappedCleanupListener(listener, mlTask.getTaskId());
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
        this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        this.mlTaskManager.add(mlTask);
        try {
            if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
                ActionListener<DataFrame> dataFrameActionListener = ActionListener.wrap(dataFrame -> {
                    // Swap the search-query dataset for the materialized data frame before training.
                    MLInput dataFrameInput = mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build();
                    this.train(mlTask, dataFrameInput, internalListener);
                }, parseFailure -> {
                    log.error("Failed to generate DataFrame from search query", parseFailure);
                    internalListener.onFailure(parseFailure);
                });
                // ThreadedActionListener hands the callback to the ML task thread pool.
                this.mlInputDatasetHandler.parseSearchQueryInput(
                    mlInput.getInputDataset(),
                    new ThreadedActionListener<>(log, this.threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)
                );
            } else {
                this.threadPool.executor(TASK_THREAD_POOL).execute(() -> this.train(mlTask, mlInput, internalListener));
            }
        }
        catch (Exception e) {
            log.error("Failed to train " + mlInput.getAlgorithm(), e);
            internalListener.onFailure(e);
        }
    }

    /**
     * Marks the task as running, trains the model and, once the model index is available,
     * persists the trained model. Every failure reported through this method also increments
     * the action-level and node-level failure counters.
     *
     * @param mlTask         task being run
     * @param mlInput        training input
     * @param actionListener listener to complete with the training output or the failure
     */
    private void train(MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
        // Route all failures through one place so the failure counters are always updated.
        ActionListener<MLTaskResponse> listener = ActionListener.wrap(actionListener::onResponse, failure -> {
            this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), ActionName.TRAIN, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_FAILURE_COUNT).increment();
            actionListener.onFailure(failure);
        });
        try {
            this.mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
            Model model = MLEngine.train(mlInput);
            this.mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(indexCreated -> {
                if (!indexCreated) {
                    // The index being initialized here is the model index, not the task index.
                    listener.onFailure(new RuntimeException("No response to create ML model index"));
                    return;
                }
                this.saveModel(mlTask, new MLModel(mlInput.getAlgorithm(), model), listener);
            }, initFailure -> {
                log.error("Failed to init ML model index", initFailure);
                listener.onFailure(initFailure);
            }));
        }
        catch (Exception e) {
            log.error("Failed to train " + mlInput.getAlgorithm(), e);
            listener.onFailure(e);
        }
    }

    /**
     * Indexes the trained model into the ML model index with an immediate refresh and completes
     * the listener with a {@code COMPLETED} training output carrying the new model id. The task
     * id is included in the output only for async tasks.
     *
     * @param mlTask   task the model was trained for
     * @param mlModel  trained model to persist
     * @param listener listener to complete with the training output or the failure
     */
    private void saveModel(MLTask mlTask, MLModel mlModel, ActionListener<MLTaskResponse> listener) {
        // The thread context is stashed while the index request is issued and restored just
        // before the response listener runs.
        // NOTE(review): presumably so the write is not performed under the caller's context — confirm.
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<IndexResponse> indexResponseListener = ActionListener.wrap(indexResponse -> {
                log.info("Model data indexing done, result:{}, model id: {}", indexResponse.getResult(), indexResponse.getId());
                this.mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_MODEL_COUNT).increment();
                String returnedTaskId = mlTask.isAsync() ? mlTask.getTaskId() : null;
                MLTrainingOutput output = new MLTrainingOutput(indexResponse.getId(), returnedTaskId, MLTaskState.COMPLETED.name());
                listener.onResponse(MLTaskResponse.builder().output(output).build());
            }, listener::onFailure);
            IndexRequest indexRequest = new IndexRequest(ML_MODEL_INDEX);
            indexRequest.source(mlModel.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
            indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            this.client.index(indexRequest, ActionListener.runBefore(indexResponseListener, () -> context.restore()));
        }
        catch (Exception e) {
            log.error("Failed to save ML model", e);
            listener.onFailure(e);
        }
    }
}

