/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.JsonArray;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BooleanSupplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeSearchResponse;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

/**
 * Search response processor that performs retrieval-augmented generation.
 *
 * <p>It takes the configured context fields from the top search hits and sends them to an LLM.
 * The request also carries the user's question and, when a conversation id is supplied, prior
 * interactions from conversational memory. The processor returns a {@link GenerativeSearchResponse}
 * that wraps the original search response together with the LLM answer or error message.
 *
 * <p>When a conversation id is present and the LLM call succeeds, the question/answer pair is stored
 * as a new interaction through {@link ConversationalMemoryClient}.
 */
public class GenerativeQAResponseProcessor
extends AbstractProcessor
implements SearchResponseProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(GenerativeQAResponseProcessor.class);
    /** Number of past interactions fetched when the request gives no (or -1) interaction size. */
    private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;
    /** Timeout in seconds passed to the LLM call when the request gives no (or -1) timeout. */
    private static final int DEFAULT_PROCESSOR_TIME_IN_SECONDS = 30;
    /** Model configured on the pipeline; a model named in the request takes precedence. */
    private final String llmModel;
    /** Source fields read from each hit and passed to the LLM as context. */
    private final List<String> contextFields;
    private final String systemPrompt;
    private final String userInstructions;
    private ConversationalMemoryClient memoryClient;
    private Llm llm;
    /** Re-checked on every request; the processor refuses to run while the feature is disabled. */
    private final BooleanSupplier featureFlagSupplier;

    protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure, Llm llm, String llmModel, List<String> contextFields, String systemPrompt, String userInstructions, BooleanSupplier supplier) {
        super(tag, description, ignoreFailure);
        this.llmModel = llmModel;
        this.contextFields = contextFields;
        this.systemPrompt = systemPrompt;
        this.userInstructions = userInstructions;
        this.llm = llm;
        this.memoryClient = new ConversationalMemoryClient(client);
        this.featureFlagSupplier = supplier;
    }

    /**
     * Runs the LLM over the search hits and attaches its answer to the response.
     *
     * @param request  the search request; generative QA parameters are read from it
     * @param response the search response whose hits supply the LLM context
     * @return a {@link GenerativeSearchResponse} carrying either the answer or the LLM error message
     * @throws MLException              if the feature flag is disabled
     * @throws IllegalArgumentException if neither the request nor the pipeline names an LLM model
     * @throws RuntimeException         if a hit lacks one of the configured context fields
     */
    @Override
    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
        log.info("Entering processResponse.");
        if (!this.featureFlagSupplier.getAsBoolean()) {
            throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
        }
        GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);

        Integer timeout = params.getTimeout();
        // null or -1 means the caller did not choose a timeout.
        if (timeout == null || timeout == -1) {
            timeout = DEFAULT_PROCESSOR_TIME_IN_SECONDS;
        }
        log.info("Timeout for this request: {} seconds.", timeout);

        String llmQuestion = params.getLlmQuestion();
        // A model given in the request overrides the pipeline-level model.
        String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
        if (llmModel == null) {
            throw new IllegalArgumentException("llm_model cannot be null.");
        }
        String conversationId = params.getConversationId();
        log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);

        Instant start = Instant.now();
        Integer interactionSize = params.getInteractionSize();
        // null or -1 means the caller did not choose a history window.
        if (interactionSize == null || interactionSize == -1) {
            interactionSize = DEFAULT_CHAT_HISTORY_WINDOW;
        }
        log.info("Using interaction size of {}", interactionSize);
        List<Interaction> chatHistory = conversationId == null
            ? Collections.emptyList()
            : this.memoryClient.getInteractions(conversationId, interactionSize);
        log.info("Retrieved chat history. ({})", getDuration(start));

        Integer topN = params.getContextSize();
        // -1 tells getSearchResults to use every hit.
        List<String> searchResults = getSearchResults(response, topN == null ? -1 : topN);
        log.info("system_prompt: {}", this.systemPrompt);
        log.info("user_instructions: {}", this.userInstructions);

        start = Instant.now();
        ChatCompletionOutput output = this.llm.doChatCompletion(
            LlmIOUtil.createChatCompletionInput(this.systemPrompt, this.userInstructions, llmModel, llmQuestion, chatHistory, searchResults, timeout)
        );
        log.info("doChatCompletion complete. ({})", getDuration(start));

        String answer = null;
        String errorMessage = null;
        String interactionId = null;
        if (output.isErrorOccurred()) {
            errorMessage = output.getErrors().get(0);
        } else {
            answer = (String) output.getAnswers().get(0);
            // Only successful answers are recorded in conversational memory.
            if (conversationId != null) {
                start = Instant.now();
                interactionId = this.memoryClient.createInteraction(
                    conversationId,
                    llmQuestion,
                    PromptUtil.getPromptTemplate(this.systemPrompt, this.userInstructions),
                    answer,
                    "retrieval_augmented_generation",
                    jsonArrayToString(searchResults)
                );
                log.info("Created a new interaction: {} ({})", interactionId, getDuration(start));
            }
        }
        return insertAnswer(response, answer, errorMessage, interactionId);
    }

    /** Returns the milliseconds elapsed between {@code start} and now. */
    long getDuration(Instant start) {
        return Duration.between(start, Instant.now()).toMillis();
    }

    @Override
    public String getType() {
        return "retrieval_augmented_generation";
    }

    /**
     * Wraps {@code response} in a {@link GenerativeSearchResponse} carrying the generated answer.
     *
     * <p>The shard counts, scroll id, failures, clusters and took time are copied from the original
     * response. The 8th argument is the took time in milliseconds, which mirrors the
     * {@code SearchResponse} constructor. The previous code passed the successful shard count there by
     * mistake. NOTE(review): confirm against the GenerativeSearchResponse constructor.
     */
    private SearchResponse insertAnswer(SearchResponse response, String answer, String errorMessage, String interactionId) {
        return new GenerativeSearchResponse(
            answer,
            errorMessage,
            response.getInternalResponse(),
            response.getScrollId(),
            response.getTotalShards(),
            response.getSuccessfulShards(),
            response.getSkippedShards(),
            response.getTook().millis(),
            response.getShardFailures(),
            response.getClusters(),
            interactionId
        );
    }

    /**
     * Collects the configured context fields from the leading search hits.
     *
     * <p>Hits are processed in order. For each hit, every entry of {@code contextFields} is appended
     * in configuration order as its {@code toString()} value.
     *
     * @param topN number of hits to use; -1 means all hits, and larger values are capped at the hit count
     * @throws RuntimeException if a hit's source does not contain one of the context fields
     */
    private List<String> getSearchResults(SearchResponse response, int topN) {
        List<String> searchResults = new ArrayList<>();
        SearchHit[] hits = response.getHits().getHits();
        int total = hits.length;
        int end = topN != -1 ? Math.min(topN, total) : total;
        for (int i = 0; i < end; ++i) {
            Map<String, Object> docSourceMap = hits[i].getSourceAsMap();
            for (String contextField : this.contextFields) {
                Object context = docSourceMap.get(contextField);
                if (context == null) {
                    String message = "Context " + contextField + " not found in search hit " + hits[i];
                    log.error(message);
                    throw new RuntimeException(message);
                }
                searchResults.add(context.toString());
            }
        }
        return searchResults;
    }

    /** Serializes a list of strings as a JSON array string, e.g. {@code ["a","b"]}. */
    private static String jsonArrayToString(List<String> listOfStrings) {
        JsonArray array = new JsonArray(listOfStrings.size());
        listOfStrings.forEach(array::add);
        return array.toString();
    }

    @Generated
    public void setMemoryClient(ConversationalMemoryClient memoryClient) {
        this.memoryClient = memoryClient;
    }

    @Generated
    public Llm getLlm() {
        return this.llm;
    }

    @Generated
    public void setLlm(Llm llm) {
        this.llm = llm;
    }

    /** Builds {@link GenerativeQAResponseProcessor} instances from pipeline configuration. */
    public static final class Factory
    implements Processor.Factory<SearchResponseProcessor> {
        private final Client client;
        private final BooleanSupplier featureFlagSupplier;

        public Factory(Client client, BooleanSupplier supplier) {
            this.client = client;
            this.featureFlagSupplier = supplier;
        }

        /**
         * Reads the processor's settings from {@code config} and creates the processor.
         *
         * <p>The optional settings are {@code model_id}, {@code llm_model}, {@code system_prompt} and
         * {@code user_instructions}. The required setting is {@code context_field_list}, which must be
         * non-empty.
         *
         * @throws MLException if the feature flag is disabled
         */
        @Override
        public SearchResponseProcessor create(Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, String tag, String description, boolean ignoreFailure, Map<String, Object> config, Processor.PipelineContext pipelineContext) throws Exception {
            if (!this.featureFlagSupplier.getAsBoolean()) {
                throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG);
            }
            String modelId = ConfigurationUtils.readOptionalStringProperty("retrieval_augmented_generation", tag, config, "model_id");
            String llmModel = ConfigurationUtils.readOptionalStringProperty("retrieval_augmented_generation", tag, config, "llm_model");
            List<String> contextFields = ConfigurationUtils.readList("retrieval_augmented_generation", tag, config, "context_field_list");
            if (contextFields.isEmpty()) {
                throw ConfigurationUtils.newConfigurationException("retrieval_augmented_generation", tag, "context_field_list", "required property can't be empty.");
            }
            String systemPrompt = ConfigurationUtils.readOptionalStringProperty("retrieval_augmented_generation", tag, config, "system_prompt");
            String userInstructions = ConfigurationUtils.readOptionalStringProperty("retrieval_augmented_generation", tag, config, "user_instructions");
            log.info("model_id {}, llm_model {}, context_field_list {}, system_prompt {}, user_instructions {}", modelId, llmModel, contextFields, systemPrompt, userInstructions);
            return new GenerativeQAResponseProcessor(this.client, tag, description, ignoreFailure, ModelLocator.getLlm(modelId, this.client), llmModel, contextFields, systemPrompt, userInstructions, this.featureFlagSupplier);
        }
    }
}

