/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.agent;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agui.RunFinishedEvent;
import org.opensearch.ml.common.agui.ToolCallResultEvent;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionStreamTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.client.Client;

/**
 * Helper used by the chat agent runner to hide the difference between a
 * streaming and a non-streaming agent execution.
 *
 * <p>The wrapper is in streaming mode exactly when it was constructed with a
 * non-null {@link TransportChannel}. In streaming mode the {@code send*}
 * methods push {@link MLTaskResponse} chunks to that channel; in non-streaming
 * mode they do nothing (or, for {@link #sendFinalResponse}, delegate to
 * {@link MLChatAgentRunner#returnFinalResponse}).
 */
public class StreamingWrapper {
    @Generated
    private static final Logger log = LogManager.getLogger(StreamingWrapper.class);
    private final TransportChannel channel;
    private final boolean isStreaming;
    private final Client client;
    private final Map<String, String> parameters;

    /**
     * Creates a wrapper around the given channel and client.
     *
     * @param channel    channel that receives streamed chunks; {@code null} selects non-streaming mode
     * @param client     client used by {@link #executeRequest} to run prediction actions
     * @param parameters agent parameters; read by {@link #sendRunFinishedAndCloseStream} for the
     *                   {@code agui_thread_id} and {@code agui_run_id} keys
     */
    public StreamingWrapper(TransportChannel channel, Client client, Map<String, String> parameters) {
        this.channel = channel;
        this.client = client;
        this.parameters = parameters;
        this.isStreaming = channel != null;
    }

    /**
     * In streaming mode, adds {@code "role": "assistant"} to the last interaction when that
     * interaction is a JSON object that has a {@code tool_calls} key but no {@code role} key.
     * The list entry is replaced in place with the re-serialized JSON.
     *
     * <p>Does nothing in non-streaming mode, or when {@code interactions} is {@code null} or empty.
     * Parsing or serialization failures are logged and otherwise ignored.
     *
     * @param interactions JSON-encoded interaction messages; the last element may be rewritten
     */
    public void fixInteractionRole(List<String> interactions) {
        if (!this.isStreaming || interactions == null || interactions.isEmpty()) {
            return;
        }
        try {
            int lastIndex = interactions.size() - 1;
            String lastInteraction = interactions.get(lastIndex);
            // Gson hands back a raw Map for Map.class; JSON object keys are always strings.
            @SuppressWarnings("unchecked")
            Map<String, Object> messageMap = StringUtils.gson.fromJson(lastInteraction, Map.class);
            if (messageMap == null) {
                // Nothing to fix: the entry did not deserialize to a JSON object.
                return;
            }
            if (!messageMap.containsKey("role") && messageMap.containsKey("tool_calls")) {
                messageMap.put("role", "assistant");
                interactions.set(lastIndex, StringUtils.toJson(messageMap));
            }
        }
        catch (Exception e) {
            log.error("Failed to fix assistant message role after parseLLMOutput", e);
        }
    }

    /**
     * Builds a remote-inference prediction request for the given LLM.
     *
     * <p>Note that the {@code parameters} argument, not the wrapper's own parameter map, is
     * placed into the request's input data set. The third constructor argument of
     * {@link MLPredictionTaskRequest} is set to {@code true} in non-streaming mode and
     * {@code false} in streaming mode; the fourth is {@code null}.
     *
     * @param llm        LLM spec supplying the model id
     * @param parameters parameters for the remote inference input data set
     * @param tenantId   tenant id forwarded to the request
     * @return the prediction request
     */
    public ActionRequest createPredictionRequest(LLMSpec llm, Map<String, String> parameters, String tenantId) {
        RemoteInferenceInputDataSet inputDataset = RemoteInferenceInputDataSet.builder()
            .parameters(parameters)
            .build();
        return new MLPredictionTaskRequest(
            llm.getModelId(),
            RemoteInferenceMLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataset).build(),
            !this.isStreaming,
            null,
            tenantId
        );
    }

    /**
     * Executes the request through the client. In streaming mode the request is first given
     * the streaming channel and is run with {@link MLPredictionStreamTaskAction}; otherwise it
     * is run with {@link MLPredictionTaskAction}.
     *
     * @param request  request to execute; must be an {@link MLPredictionTaskRequest} in streaming mode
     * @param listener listener notified with the task response
     * @throws ClassCastException in streaming mode if {@code request} is not an {@link MLPredictionTaskRequest}
     */
    public void executeRequest(ActionRequest request, ActionListener<MLTaskResponse> listener) {
        if (this.isStreaming) {
            ((MLPredictionTaskRequest)request).setStreamingChannel(this.channel);
            this.client.execute(MLPredictionStreamTaskAction.INSTANCE, request, listener);
            return;
        }
        this.client.execute(MLPredictionTaskAction.INSTANCE, request, listener);
    }

    /**
     * In streaming mode, sends a chunk with empty content and {@code is_last = true}.
     * Send failures are logged at warn level and otherwise ignored. No-op when not streaming.
     *
     * @param sessionId           value for the chunk's {@code memory_id} tensor
     * @param parentInteractionId value for the chunk's {@code parent_interaction_id} tensor
     */
    public void sendCompletionChunk(String sessionId, String parentInteractionId) {
        if (!this.isStreaming) {
            return;
        }
        MLTaskResponse completionChunk = this.createStreamChunk("", sessionId, parentInteractionId, true);
        try {
            this.channel.sendResponseBatch(completionChunk);
        }
        catch (Exception e) {
            log.warn("Failed to send completion chunk: {}", e.getMessage());
        }
    }

    /**
     * Completes the listener. In streaming mode the listener only receives the literal
     * {@code "Streaming completed"}; otherwise all arguments are forwarded to
     * {@link MLChatAgentRunner#returnFinalResponse}.
     *
     * @param sessionId           session id forwarded in non-streaming mode
     * @param listener            listener to complete
     * @param parentInteractionId parent interaction id forwarded in non-streaming mode
     * @param verbose             verbose flag forwarded in non-streaming mode
     * @param cotModelTensors     model tensors forwarded in non-streaming mode
     * @param additionalInfo      additional info forwarded in non-streaming mode
     * @param finalAnswer         final answer forwarded in non-streaming mode
     */
    public void sendFinalResponse(String sessionId, ActionListener<Object> listener, String parentInteractionId, boolean verbose, List<ModelTensors> cotModelTensors, Map<String, Object> additionalInfo, String finalAnswer) {
        if (this.isStreaming) {
            listener.onResponse("Streaming completed");
        } else {
            MLChatAgentRunner.returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer);
        }
    }

    /**
     * In streaming mode, sends {@code toolOutput} as a non-final chunk. Failures while building
     * or sending the chunk (including a {@code null} {@code toolOutput}, which {@code Map.of}
     * rejects) are logged and otherwise ignored. No-op when not streaming.
     *
     * @param toolOutput          chunk content
     * @param sessionId           value for the chunk's {@code memory_id} tensor
     * @param parentInteractionId value for the chunk's {@code parent_interaction_id} tensor
     */
    public void sendToolResponse(String toolOutput, String sessionId, String parentInteractionId) {
        if (!this.isStreaming) {
            return;
        }
        try {
            MLTaskResponse toolChunk = this.createStreamChunk(toolOutput, sessionId, parentInteractionId, false);
            this.channel.sendResponseBatch(toolChunk);
        }
        catch (Exception e) {
            log.error("Failed to send tool response chunk", e);
        }
    }

    /**
     * In streaming mode, wraps the tool result in a {@link ToolCallResultEvent} (message id
     * {@code "msg_" + System.nanoTime()}) and sends its JSON form as a non-final chunk.
     * If that fails, the failure is logged and the raw {@code toolResult} is sent through
     * {@link #sendToolResponse} instead. No-op when not streaming.
     *
     * @param toolCallId          id of the tool call the result belongs to
     * @param toolResult          tool result text
     * @param sessionId           value for the chunk's {@code memory_id} tensor
     * @param parentInteractionId value for the chunk's {@code parent_interaction_id} tensor
     */
    public void sendBackendToolResult(String toolCallId, String toolResult, String sessionId, String parentInteractionId) {
        if (!this.isStreaming) {
            // Without a channel there is nothing to send to; avoid a guaranteed NPE and error log.
            return;
        }
        try {
            ToolCallResultEvent toolCallResultEvent = new ToolCallResultEvent("msg_" + System.nanoTime(), toolCallId, toolResult);
            MLTaskResponse toolChunk = this.createStreamChunk(toolCallResultEvent.toJsonString(), sessionId, parentInteractionId, false);
            this.channel.sendResponseBatch(toolChunk);
        }
        catch (Exception e) {
            // Trailing throwable keeps the stack trace in the log in addition to the message.
            log.error("Failed to send backend tool AGUI events for toolCallId '{}': {}", toolCallId, e.getMessage(), e);
            this.sendToolResponse(toolResult, sessionId, parentInteractionId);
        }
    }

    /**
     * In streaming mode, sends a {@link RunFinishedEvent} as a chunk flagged {@code is_last = true}.
     *
     * <p>The thread and run ids are read from the wrapper's parameters under
     * {@code agui_thread_id} and {@code agui_run_id}. When an id is absent (or the parameter
     * map itself is {@code null}) a warning is logged and a value derived from
     * {@link System#nanoTime()} is used. Failures are logged and otherwise ignored.
     * No-op when not streaming.
     *
     * @param sessionId           not used by this method
     * @param parentInteractionId not used by this method
     */
    public void sendRunFinishedAndCloseStream(String sessionId, String parentInteractionId) {
        if (!this.isStreaming) {
            // Without a channel there is nothing to send to; avoid a guaranteed NPE and error log.
            return;
        }
        try {
            String threadId = this.parameters == null ? null : this.parameters.get("agui_thread_id");
            String runId = this.parameters == null ? null : this.parameters.get("agui_run_id");
            if (threadId == null) {
                log.warn("AG-UI threadId is null, using generated value. This may cause frontend errors.");
                threadId = "thread_" + System.nanoTime();
            }
            if (runId == null) {
                log.warn("AG-UI runId is null, using generated value. This may cause frontend errors.");
                runId = "run_" + System.nanoTime();
            }
            RunFinishedEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
            // Unlike createStreamChunk, this chunk carries only the "response" tensor.
            Map<String, Object> dataMap = Map.of("content", runFinishedEvent.toJsonString(), "is_last", true);
            List<ModelTensor> modelTensors = new ArrayList<>();
            modelTensors.add(ModelTensor.builder().name("response").dataAsMap(dataMap).build());
            ModelTensorOutput output = ModelTensorOutput.builder()
                .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(modelTensors).build()))
                .build();
            this.channel.sendResponseBatch(new MLTaskResponse(output));
        }
        catch (Exception e) {
            log.error("Failed to send run finished event and close stream", e);
        }
    }

    /**
     * Builds a stream chunk with three tensors: {@code response} (a map with {@code content}
     * and {@code is_last}), {@code memory_id} and {@code parent_interaction_id}.
     *
     * @param toolOutput          chunk content; must not be {@code null} because {@code Map.of} rejects null values
     * @param sessionId           result of the {@code memory_id} tensor
     * @param parentInteractionId result of the {@code parent_interaction_id} tensor
     * @param isLast              value of the {@code is_last} entry
     * @return the response wrapping the chunk
     * @throws NullPointerException if {@code toolOutput} is {@code null}
     */
    private MLTaskResponse createStreamChunk(String toolOutput, String sessionId, String parentInteractionId, boolean isLast) {
        Map<String, Object> responseData = Map.of("content", toolOutput, "is_last", isLast);
        List<ModelTensor> tensors = Arrays.asList(
            ModelTensor.builder().name("response").dataAsMap(responseData).build(),
            ModelTensor.builder().name("memory_id").result(sessionId).build(),
            ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build()
        );
        ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(tensors).build();
        ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build();
        return new MLTaskResponse(output);
    }
}

