/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.rest.RestStatus;

/**
 * Tracks ML tasks known to this node and persists task metadata to the ML task index.
 *
 * <p>Tasks are cached in memory in a {@link ConcurrentHashMap} keyed by task id. Methods that
 * read-modify-write a cached task ({@link #add}, {@link #updateTaskStateAndError} and its
 * overloads) are {@code synchronized}; plain lookups rely on the map's own thread-safety.
 * Index updates for a task are gated by that task's update semaphore (from {@link MLTaskCache}),
 * when one is present.
 */
public class MLTaskManager {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskManager.class);
    /** Name of the index that stores ML task documents. */
    private static final String ML_TASK_INDEX = ".plugins-ml-task";
    // NOTE(review): not referenced in this class; presumably a per-node task limit enforced by callers — confirm.
    public static final int MAX_ML_TASK_PER_NODE = 10;
    private final Map<String, MLTaskCache> taskCaches;
    private final Client client;
    private final MLIndicesHandler mlIndicesHandler;

    /**
     * Creates a task manager with an empty task cache.
     *
     * @param client client used to write task documents
     * @param mlIndicesHandler handler used to make sure the ML task index exists
     */
    public MLTaskManager(Client client, MLIndicesHandler mlIndicesHandler) {
        this.client = client;
        this.mlIndicesHandler = mlIndicesHandler;
        this.taskCaches = new ConcurrentHashMap<>();
    }

    /**
     * Adds a task to the in-memory cache.
     *
     * @param mlTask task to cache, keyed by its task id
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask) {
        String taskId = mlTask.getTaskId();
        if (contains(taskId)) {
            throw new IllegalArgumentException("Duplicate taskId");
        }
        taskCaches.put(taskId, new MLTaskCache(mlTask));
        log.info("add ML task to cache {}", taskId);
    }

    /**
     * Updates only the state of a cached task (and of its index document for async tasks).
     * The cached error is left untouched.
     */
    public synchronized void updateTaskState(String taskId, MLTaskState state, boolean isAsyncTask) {
        updateTaskStateAndError(taskId, state, null, isAsyncTask);
    }

    /**
     * Updates only the error of a cached task (and of its index document for async tasks).
     * The cached state is left untouched.
     */
    public synchronized void updateTaskError(String taskId, String error, boolean isAsyncTask) {
        updateTaskStateAndError(taskId, null, error, isAsyncTask);
    }

    /**
     * Applies the non-null fields among {@code state} and {@code error} to the cached task. For
     * async tasks, the same fields are also written to the task's index document.
     *
     * <p>A {@code null} argument means "leave this field unchanged". This matches what is
     * written to the index, so the cache and the index document stay consistent.
     *
     * <p>The index update uses a zero semaphore timeout. If another update for the same task is
     * still in flight, this index write fails immediately, and the failure is only logged.
     *
     * @throws IllegalArgumentException if the task is not cached
     */
    public synchronized void updateTaskStateAndError(String taskId, MLTaskState state, String error, boolean isAsyncTask) {
        MLTask task = get(taskId);
        if (task == null) {
            throw new IllegalArgumentException("Task not found");
        }
        Map<String, Object> updatedFields = new HashMap<>();
        if (state != null) {
            task.setState(state);
            updatedFields.put("state", state.name());
        }
        if (error != null) {
            task.setError(error);
            updatedFields.put("error", error);
        }
        // An empty update would only be rejected by updateMLTask, so skip the round trip.
        if (isAsyncTask && !updatedFields.isEmpty()) {
            updateMLTask(taskId, updatedFields, 0L);
        }
    }

    /** Returns whether a task with the given id is cached. */
    public boolean contains(String taskId) {
        return taskCaches.containsKey(taskId);
    }

    /** Removes a task from the cache; a missing id is a no-op. */
    public void remove(String taskId) {
        // Use remove()'s return value instead of check-then-act, so the log line reflects an actual removal.
        if (taskCaches.remove(taskId) != null) {
            log.info("remove ML task from cache {}", taskId);
        }
    }

    /**
     * Returns the cached task for {@code taskId}, or {@code null} if it is not cached.
     */
    public MLTask get(String taskId) {
        // Single lookup: a concurrent remove() between contains() and get() must not cause an NPE.
        MLTaskCache cache = taskCaches.get(taskId);
        return cache == null ? null : cache.getMlTask();
    }

    /** Counts cached tasks whose state is {@link MLTaskState#RUNNING}. */
    public int getRunningTaskCount() {
        int running = 0;
        for (MLTaskCache cache : taskCaches.values()) {
            if (cache.getMlTask().getState() == MLTaskState.RUNNING) {
                running++;
            }
        }
        return running;
    }

    /** Removes every task from the cache. */
    public void clear() {
        taskCaches.clear();
    }

    /**
     * Ensures the ML task index exists, then indexes {@code mlTask} into it with an immediate refresh.
     *
     * @param mlTask task to persist
     * @param listener notified with the index response, or with the failure
     */
    public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
        mlIndicesHandler.initMLTaskIndex(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                listener.onFailure(new RuntimeException("No response to create ML task index"));
                return;
            }
            IndexRequest request = new IndexRequest(ML_TASK_INDEX);
            // Write with a stashed thread context; it is restored before the listener runs.
            try (XContentBuilder builder = XContentFactory.jsonBuilder();
                 ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                request.source(mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS)).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                client.index(request, ActionListener.runBefore(listener, () -> context.restore()));
            } catch (Exception e) {
                log.error("Failed to create ML task for " + mlTask.getFunctionName() + ", " + mlTask.getTaskType(), e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create ML index", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Updates a task document and only logs the outcome.
     *
     * @see #updateMLTask(String, Map, ActionListener, long)
     */
    public void updateMLTask(String taskId, Map<String, Object> updatedFields, long timeoutInMillis) {
        updateMLTask(taskId, updatedFields, ActionListener.wrap(response -> {
            if (response.status() == RestStatus.OK) {
                log.debug("Updated ML task successfully: {}, task id: {}", response.status(), taskId);
            } else {
                log.error("Failed to update ML task {}, status: {}", taskId, response.status());
            }
        }, e -> log.error("Failed to update ML task: " + taskId, e)), timeoutInMillis);
    }

    /**
     * Partially updates the index document of a cached task with {@code updatedFields}, plus a
     * {@code last_update_time} timestamp, using an immediate refresh.
     *
     * <p>If the task's cache entry has an update semaphore, a permit is acquired first, waiting
     * up to {@code timeoutInMillis}. The permit is released once {@code listener} has been
     * notified, or immediately on any failure before the request is sent.
     *
     * @param taskId id of a cached task
     * @param updatedFields fields to write; must be non-null and non-empty
     * @param listener notified with the update response or failure
     * @param timeoutInMillis maximum time to wait for the update semaphore
     */
    public void updateMLTask(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        if (taskCache == null) {
            listener.onFailure(new RuntimeException("Can't find task"));
            return;
        }
        // Validate before taking a permit: returning early while holding it would block every later update.
        if (updatedFields == null || updatedFields.isEmpty()) {
            listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
            return;
        }
        Semaphore semaphore = taskCache.getUpdateTaskIndexSemaphore();
        try {
            if (semaphore != null && !semaphore.tryAcquire(timeoutInMillis, TimeUnit.MILLISECONDS)) {
                listener.onFailure(new RuntimeException("Other updating request not finished yet"));
                return;
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt for code further up the stack.
            Thread.currentThread().interrupt();
            log.error("Failed to acquire semaphore for ML task " + taskId, e);
            listener.onFailure(e);
            return;
        }

        UpdateRequest updateRequest;
        try {
            updateRequest = buildTaskUpdateRequest(taskId, updatedFields);
        } catch (Exception e) {
            // The listener never got wrapped, so the permit (if any) must be returned here.
            if (semaphore != null) {
                semaphore.release();
            }
            log.error("Failed to update ML task " + taskId, e);
            listener.onFailure(e);
            return;
        }

        // From here on, the permit is released by the wrapped listener after notification.
        ActionListener<UpdateResponse> releasingListener = semaphore == null
            ? listener
            : ActionListener.runAfter(listener, () -> semaphore.release());
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            client.update(updateRequest, ActionListener.runBefore(releasingListener, () -> context.restore()));
        } catch (Exception e) {
            releasingListener.onFailure(e);
        }
    }

    /** Builds the partial-document update request for a task, stamping {@code last_update_time}. */
    private static UpdateRequest buildTaskUpdateRequest(String taskId, Map<String, Object> updatedFields) {
        UpdateRequest updateRequest = new UpdateRequest(ML_TASK_INDEX, taskId);
        Map<String, Object> updatedContent = new HashMap<>(updatedFields);
        updatedContent.put("last_update_time", Instant.now().toEpochMilli());
        updateRequest.doc(updatedContent);
        updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        return updateRequest;
    }
}

