import json
import logging
from collections.abc import Generator
from datetime import datetime
from typing import Any, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueMessageEndEvent
from core.ai_provider.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.ai_provider.entities.message_entities import AssistantPromptMessage
from core.ai_provider.entities.model_entities import ModelType, ModelFeature
from models.model import Message
from services.ai_provider_service import AIProviderService as ModelProviderService

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
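    """Drives an LLM agent with OpenAI-style function calling: it offers the
    configured tools to the model, executes any requested tool calls
    (optionally force-invoking knowledge-base "dataset_" tools), feeds the
    results back into the prompt, and streams the final answer.
    """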

    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        try:
            response = ''
            prompt_messages = []
            model_instance = self.model_instance
            self.inputs = kwargs.get('inputs', {})
            run_st_0 = datetime.now()

            self.query = query

            self._update_timestamp("run start")
            logger.info(
                f"FunctionCallAgentRunner started, query: {query}, start time: {self._timestamp_new.strftime('%Y-%m-%d %H:%M:%S')}")
            logger.info(f"Knowledge base config: {self.kb_config()}")
            self.tool_call_results = []
            app_generate_entity = self.application_generate_entity
            app_config = self.app_config

            # convert tools into ModelRuntime Tool format
            tool_instances, prompt_messages_tools = self._init_prompt_tools()

            # Cache tool ID → name and tool ID → metadata maps
            for tool_id, tool in tool_instances.items():
                # Cache the tool display name
                if tool_id not in self.tool_id_name_map:
                    if tool_id.startswith("dataset_"):
                        self.tool_id_name_map[tool_id] = tool.retrival_tool.display_name
                    else:
                        self.tool_id_name_map[tool_id] = tool.identity.name
                # Cache the tool metadata
                if tool_id not in self.tool_id_meta_map:
                    _meta = {}
                    if tool_id.startswith("dataset_"):
                        try:
                            _identity = tool.identity
                            _description = tool.description
                            _meta = {
                                'provider': _identity.provider,
                                'name': tool.retrival_tool.display_name,
                                'description': (_description.llm or _description.human.zh_Hans or _description.human.en_US) if _description else "",
                            }
                            logger.info(f"Dataset tool metadata: {_meta}")
                        except Exception as e:
                            logger.error(f"Failed to fetch tool metadata: {e}")
                    else:
                        try:
                            _identity = tool.identity
                            _description = tool.description
                            _meta = {
                                'provider': _identity.provider,
                                'name': (_identity.label.zh_Hans or _identity.label.en_US or _identity.name) if _identity else "",
                                'description': (_description.llm or _description.human.zh_Hans or _description.human.en_US) if _description else "",
                            }
                        except Exception as e:
                            logger.error(f"Failed to fetch tool metadata: {e}")

                    self.tool_id_meta_map[tool_id] = _meta

            # Collect all tools whose ID starts with "dataset_"
            dataset_tools = [(tool_id, tool.retrival_tool.display_name, tool) for tool_id, tool in
                             tool_instances.items() if tool_id.startswith('dataset_')]

            is_first_call = True
            has_yield = False

            iteration_step = 1
            max_iteration_steps = 3  # fixed at 3 iterations

            function_call_state = True
            final_answer = ''

            # Track tool-invocation state
            dataset_tool_called = False
            has_tool_been_called = False

            if self.kb_config().rewrite_config.enabled:
                self.full_question()

            # Setup complete; start reasoning
            self._sync_status(tool_name="LLM reasoning",
                              status="started",
                              msg="",
                              tool_input={"query": query},
                              answer="",
                              thought="The LLM is reasoning about the next step",
                              llm_usage=self.llm_usage,
                              observation={"query": query,
                                           "answer": "",
                                           "usage": self.llm_usage.to_dict()})

            # No tools configured: use model_instance directly to produce the final answer
            if not tool_instances:
                prompt_messages, tool_rsp_list = self._organize_prompt_messages()
                self.recalc_llm_max_tokens(self.model_config, prompt_messages)

                # Add detailed logging before LLM invocation
                self._log_messages_before_llm(prompt_messages=prompt_messages, tools=None)
                self._update_timestamp("no tools, before LLM call")

                chunks = model_instance.invoke_llm(
                    prompt_messages=prompt_messages,
                    model_parameters=app_generate_entity.model_conf.parameters,
                    tools=None,
                    stop=app_generate_entity.model_conf.stop,
                    stream=self.stream_tool_call,
                    user=self.user_id,
                    callbacks=[],
                    model_schema=self.application_generate_entity.model_conf.model_schema,
                )

                response = ''

                if self.stream_tool_call:
                    for chunk in chunks:
                        if chunk.delta.message and chunk.delta.message.content:
                            if isinstance(chunk.delta.message.content, list):
                                for content in chunk.delta.message.content:
                                    response += content.data
                            else:
                                response += chunk.delta.message.content

                        if chunk.delta.usage:
                            usage = chunk.delta.usage
                            usage.provider = model_instance.provider
                            usage.model = model_instance.model
                            usage.features = ModelProviderService().get_model_features(tenant_id=self.tenant_id,
                                                                                       provider=model_instance.provider,
                                                                                       model=model_instance.model)
                            self.increase_usage(usage)

                        has_yield = True
                        yield chunk
                else:
                    result = chunks
                    if result.message and result.message.content:
                        if isinstance(result.message.content, list):
                            for content in result.message.content:
                                response += content.data
                        else:
                            response += result.message.content

                    if result.usage:
                        usage = result.usage
                        usage.provider = model_instance.provider
                        usage.model = model_instance.model
                        usage.features = ModelProviderService().get_model_features(tenant_id=self.tenant_id,
                                                                                   provider=model_instance.provider,
                                                                                   model=model_instance.model)
                        self.increase_usage(usage)

                    has_yield = True
                    yield LLMResultChunk(
                        model=model_instance.model,
                        prompt_messages=result.prompt_messages,
                        system_fingerprint=result.system_fingerprint,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=result.message,
                            usage=result.usage,
                        )
                    )

                self.update_db_variables(self.variables_pool, self.db_variables_pool)

                self._update_timestamp("[DONE] no tools, LLM call finished")
                self._sync_status(tool_name="LLM reasoning",
                                  status="completed",
                                  msg="",
                                  tool_input={"query": query},
                                  answer="",
                                  thought="The LLM finished reasoning",
                                  llm_usage=self.llm_usage,
                                  observation={"query": query,
                                               "answer": response,
                                               "messages": [{
                                                   "ROLE": msg.role.value.upper(),
                                                   "CONTENT": msg.read_content_str()
                                               } for msg in prompt_messages],
                                               "usage": self.llm_usage.to_dict()})

                # Publish the end event
                self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(
                        content=response
                    ),
                    usage=self.llm_usage,
                    quote=self.quote,
                    system_fingerprint=''
                )), PublishFrom.APPLICATION_MANAGER)

                return
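
            # Iterative function-calling loop (at most max_iteration_steps rounds):
            # round 1 lets the LLM answer or request tools; round 2 force-invokes
            # the configured dataset_ tools when forced retrieval is enabled and
            # nothing was retrieved in round 1; once any tool has run, the tool
            # list is cleared so a final round produces the answer.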

            while function_call_state and iteration_step <= max_iteration_steps:
                function_call_state = False
                message_file_ids = []
                tool_calls = []

                logger.info(f"是否第一次调用:{is_first_call}\n"
                            f"是否强制调用dataset工具:{self.kb_force_retrieve(dataset_tools)}\n"
                            f"是否执行yield:{not is_first_call or (is_first_call and not self.kb_force_retrieve(dataset_tools))}")

                # Check whether dataset tools must be force-invoked
                if iteration_step == 2 and not dataset_tool_called and dataset_tools:
                    logger.info(
                        f"Dataset tools are configured but the first LLM iteration called none; "
                        f"iteration {iteration_step} invokes them only if forced retrieval is enabled")
                    # When the app forces knowledge-base retrieval and the first call produced
                    # no tool calls, force-invoke every configured dataset tool this round
                    force_retrieve = self.kb_force_retrieve(dataset_tools)
                    if force_retrieve:
                        tool_calls = [(tool_id, display_name,
                                       {'query': self.query,
                                        'keyword_extraction_config': self.kb_config().keyword_extraction_config.to_dict()})
                                      for tool_id, display_name, tool in dataset_tools]

                    # Invoke the requested tools
                    has_tool_been_called, dataset_tool_called, message_file_ids, tool_responses = self.tools_invoke_parallel(
                        has_tool_been_called=has_tool_been_called,
                        dataset_tool_called=dataset_tool_called,
                        tool_instances=tool_instances,
                        tool_calls=tool_calls,
                    )

                    logger.info(
                        f"FunctionCallAgentRunner dataset tool invocation took {(self._timestamp_new - self._timestamp_old).total_seconds()} s")

                    function_call_state = True
                    is_first_call = False
                    iteration_step += 1
                    continue

                # Pick the model instance for this iteration
                stream_tool_call = self.stream_tool_call
                model_schema = app_generate_entity.model_conf.model_schema

                if iteration_step == 1 and not has_tool_been_called:
                    # The first iteration uses the dedicated tool-call model
                    from core.ai_provider.ai_model_manager import ModelManager
                    try:
                        model_manager = ModelManager()
                        model_instance = model_manager.get_default_model_instance(
                            tenant_id=app_config.tenant_id,
                            model_type=ModelType.LLM_TOOL_CALL,
                        )
                        model_schema = model_instance.model_type_instance.get_model_schema(model_instance.model,
                                                                                           model_instance.credentials)
                        stream_tool_call = ModelFeature.STREAM_TOOL_CALL in model_schema.features

                        logger.info(f"使用默认的工具调用模型: {model_instance.provider} - {model_instance.model}")
                    except Exception as e:
                        logger.warning(f"Failed to get default function call model, using configured model. Error: {e}")
                        model_instance = self.model_instance

                else:
                    # Later iterations use the model configured for the app
                    model_instance = self.model_instance
                    logger.info(
                        f"Using the app-configured model for iteration {iteration_step}: {model_instance.provider} - {model_instance.model}")

                # Recompute the LLM max tokens
                prompt_messages, tool_rsp_list = self._organize_prompt_messages()
                logger.info(f"Organized prompt messages for iteration {iteration_step}")
                self.recalc_llm_max_tokens(self.model_config, prompt_messages)

                # Add detailed logging before LLM invocation
                self._log_messages_before_llm(prompt_messages=prompt_messages, tools=prompt_messages_tools)

                # Determine the tools available in this iteration
                current_tools = []
                if not has_tool_been_called:
                    # No tool has run yet: offer all tools
                    current_tools = prompt_messages_tools
                    logger.info(f"Iteration {iteration_step}: Providing all tools ({len(current_tools)} tools)")
                else:
                    # A tool has already run: clear the tool list so the model produces the answer
                    logger.info(
                        f"Iteration {iteration_step}: Tools have been called. Clearing tool list for final reasoning.")

                # Add more detailed logging about the request
                logger.info(f"Model: {model_instance.provider} - {model_instance.model}")
                logger.info(f"Tools count: {len(current_tools)}")

                # Check if prompt messages have proper format
                for i, msg in enumerate(prompt_messages):
                    if hasattr(msg, 'content') and msg.content is None:
                        logger.warning(f"Message {i} has None content, type: {type(msg)}")
                        # Fix None content
                        msg.content = ''

                self._update_timestamp(f"调用LLM前")
                logger.info(f"工具 22:{current_tools}")


                chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                    prompt_messages=prompt_messages,
                    model_parameters=app_generate_entity.model_conf.parameters,
                    tools=current_tools,
                    stop=app_generate_entity.model_conf.stop,
                    stream=stream_tool_call,
                    user=self.user_id,
                    callbacks=[],
                    model_schema=model_schema
                )

                tool_calls: list[tuple[str, str, dict[str, Any]]] = []

                # save full response
                response = ''

                if stream_tool_call:
                    for chunk in chunks:
                        # check if there is any tool call
                        if self.check_tool_calls(chunk):
                            # Only set function_call_state to True when no tool has been called yet
                            if not has_tool_been_called:
                                function_call_state = True
                                tool_calls.extend(self.extract_tool_calls(chunk))

                        if chunk.delta.message and chunk.delta.message.content:
                            if isinstance(chunk.delta.message.content, list):
                                for content in chunk.delta.message.content:
                                    response += content.data
                            else:
                                response += chunk.delta.message.content

                        if chunk.delta.usage:
                            usage = chunk.delta.usage
                            usage.provider = model_instance.provider
                            usage.model = model_instance.model
                            usage.features = ModelProviderService().get_model_features(tenant_id=self.tenant_id,
                                                                                       provider=model_instance.provider,
                                                                                       model=model_instance.model)
                            self.increase_usage(usage)

                        if not is_first_call or (is_first_call and not self.kb_force_retrieve(dataset_tools)):
                            has_yield = True
                            yield chunk
                else:
                    result: LLMResult = chunks
                    # check if there is any tool call
                    if self.check_blocking_tool_calls(result):
                        # Only set function_call_state to True when no tool has been called yet
                        if not has_tool_been_called:
                            function_call_state = True
                            tool_calls.extend(self.extract_blocking_tool_calls(result))

                    if result.usage:
                        usage = result.usage
                        usage.provider = model_instance.provider
                        usage.model = model_instance.model
                        usage.features = ModelProviderService().get_model_features(tenant_id=self.tenant_id,
                                                                                   provider=model_instance.provider,
                                                                                   model=model_instance.model)
                        self.increase_usage(usage)

                    if result.message and result.message.content:
                        if isinstance(result.message.content, list):
                            for content in result.message.content:
                                response += content.data
                        else:
                            response += result.message.content

                    if not result.message.content:
                        result.message.content = ''

                    if not is_first_call or (is_first_call and not self.kb_force_retrieve(dataset_tools)):
                        has_yield = True
                        yield LLMResultChunk(
                            model=model_instance.model,
                            prompt_messages=result.prompt_messages,
                            system_fingerprint=result.system_fingerprint,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=result.message,
                                usage=result.usage,
                            )
                        )

                assistant_message = AssistantPromptMessage(
                    content='',
                    tool_calls=[]
                )

                logger.info(f"工具调用迭代 {iteration_step} - 工具调用数量:{len(tool_calls)} 函数:{tool_calls}")

                # If the first LLM call triggered a dataset tool, the config decides whether to force-invoke all knowledge-base tools
                if iteration_step == 1 and len(tool_calls) > 0:
                    _has_kb_called = False
                    kb_args = {"query": "", "kb_keywords": ""}
                    for tool_call_id, _, tool_args in tool_calls:
                        if tool_call_id.startswith('dataset_'):
                            _has_kb_called = True
                            kb_args['query'] = kb_args['query'] + f" {tool_args.get('query', '')}"

                    if _has_kb_called:
                        # Keyword extraction is only needed when a KB tool was actually called
                        if self.need_keyword_extraction():
                            kb_args['keyword_extraction_config'] = self.kb_config().keyword_extraction_config.to_dict()
                            # Merge the extraction config into each dataset tool call
                            for i, (tool_call_id, tool_call_name, tool_call_args) in enumerate(tool_calls):
                                if tool_call_id.startswith('dataset_'):
                                    tool_calls[i] = (tool_call_id, tool_call_name, {**tool_call_args, **kb_args})

                        if self.kb_force_retrieve(dataset_tools):
                            # Manually force-invoke every knowledge base
                            tool_call_ids = [tool_call[0] for tool_call in tool_calls]
                            # Append any dataset tool the LLM did not already call
                            for tool_id, display_name, tool in dataset_tools:
                                if tool_id not in tool_call_ids:
                                    tool_calls.append((tool_id, display_name, kb_args))

                if tool_calls:
                    assistant_message.tool_calls = [
                        AssistantPromptMessage.ToolCall(
                            id=tool_call[0],
                            type='function',
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=tool_call[1],
                                arguments=json.dumps(tool_call[2], ensure_ascii=False)
                            )
                        ) for tool_call in tool_calls
                    ]
                else:
                    assistant_message.content = response

                self._current_thoughts.append(assistant_message)

                # After the first iteration, check whether any dataset tool was called
                logger.info(
                    f"Iteration {iteration_step} - Dataset tool called: {dataset_tool_called}")
                if iteration_step == 1 and not dataset_tool_called and self.kb_force_retrieve(dataset_tools):
                    function_call_state = True

                self._update_timestamp(f"[DONE] 调用LLM完成")
                # call tools
                has_tool_been_called, dataset_tool_called, message_file_ids, tool_responses = self.tools_invoke_parallel(
                    has_tool_been_called=has_tool_been_called,
                    dataset_tool_called=dataset_tool_called,
                    tool_instances=tool_instances,
                    tool_calls=tool_calls,
                )

                logger.info(f"完成工具调用,有结果的工具有:{len(tool_responses)}个")

                if len(tool_responses) > 0:
                    # Tools ran this round, so continue into the next iteration
                    function_call_state = True
                    # The first call is over
                    is_first_call = False

                # update prompt tool
                for prompt_tool in prompt_messages_tools:
                    self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

                iteration_step += 1
                # Even if no tool was called, the first call is over
                if not function_call_state:
                    is_first_call = False

            # If final_answer is empty (e.g. only the first call ran), fall back to the last response
            if not final_answer.strip() and response:
                final_answer = response
                logger.info(f"Using last response as final answer: {final_answer[:100]}..." if len(
                    final_answer) > 100 else final_answer)

            self.update_db_variables(self.variables_pool, self.db_variables_pool)

            _run_duration = (datetime.now() - run_st_0).total_seconds()
            self._sync_status(
                tool_name="LLM reasoning",
                status="completed",
                msg="",
                tool_input={"query": query},
                thought="The LLM finished reasoning",
                answer="",
                llm_usage=self.llm_usage,
                observation={
                    "query": query,
                    "answer": final_answer,
                    "messages": [{
                        "ROLE": msg.role.value.upper(),
                        "CONTENT": msg.read_content_str()
                    } for msg in prompt_messages],
                    "duration": f"{_run_duration} s",
                    "usage": self.llm_usage.to_dict()}
            )

            if not has_yield:
                yield LLMResultChunk(
                    model=self.model_instance.model,
                    prompt_messages=prompt_messages,
                    system_fingerprint='',
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(content=final_answer),
                        usage=self.llm_usage,
                    )
                )

            # publish end event
            self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
                model=model_instance.model,
                prompt_messages=prompt_messages,
                message=AssistantPromptMessage(
                    content=final_answer
                ),
                usage=self.llm_usage,
                quote=self.quote,
                system_fingerprint=''
            )), PublishFrom.APPLICATION_MANAGER)
        except Exception as e:
            logger.exception(f"Error in FunctionCallAgentRunner.run: {e}")
            # Create an error message to return to the user
            error_message = f"An error occurred while processing your request: {str(e)}"

            # Generate an error response
            yield LLMResultChunk(
                model="error",
                prompt_messages=[],
                system_fingerprint="",
                delta=LLMResultChunkDelta(
                    index=0,
                    message=AssistantPromptMessage(content=error_message),
                    usage=LLMUsage.empty_usage(),
                )
            )

            # Publish an end event with the error
            self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
                model="error",
                prompt_messages=[],
                message=AssistantPromptMessage(content=error_message),
                usage=LLMUsage.empty_usage(),
                quote=self.quote,
                system_fingerprint=""
            )), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        return bool(llm_result_chunk.delta.message.tool_calls)

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        return bool(llm_result.message.tool_calls)

    def extract_tool_calls(
            self, llm_result_chunk: LLMResultChunk
    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != '':
                try:
                    args = json.loads(prompt_message.function.arguments)
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse tool arguments: {e}, arguments: {prompt_message.function.arguments}")
                    args = {"error": "Failed to parse arguments"}

            tool_id = prompt_message.function.name
            tool_name = prompt_message.function.name
            if tool_id.startswith('dataset_'):
                tool_name = self.tool_id_name_map.get(tool_id, tool_id)

            tool_calls.append((tool_id, tool_name, args))

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != '':
                try:
                    args = json.loads(prompt_message.function.arguments)
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to parse tool arguments: {e}, arguments: {prompt_message.function.arguments}")
                    args = {"error": "Failed to parse arguments"}

            tool_id = prompt_message.function.name
            tool_name = prompt_message.function.name
            if tool_id.startswith('dataset_'):
                tool_name = self.tool_id_name_map.get(tool_id, tool_id)

            tool_calls.append((tool_id, tool_name, args))

        return tool_calls
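
    # Illustrative shape of the extracted tool-call tuples (values are made up):
    #   [("dataset_abc123", "Product FAQ", {"query": "how do I reset my password"}),
    #    ("web_search", "web_search", {"q": "latest release notes"})]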

    # def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
    #     """
    #     Organize user query
    #     """
    #     prompt_messages.append(UserPromptMessage(content=query))
    #
    #     return prompt_messages

    # def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    #     """
    #     As for now, gpt supports both fc and vision at the first iteration.
    #     We need to remove the image messages from the prompt messages at the first iteration.
    #     """
    #     prompt_messages = deepcopy(prompt_messages)
    #     logger.info("Clearing image messages from prompt messages")
    #
    #     for i, prompt_message in enumerate(prompt_messages):
    #         if isinstance(prompt_message, UserPromptMessage):
    #             if isinstance(prompt_message.content, list):
    #                 logger.info(f"Converting list content to string for message {i}")
    #                 text_content = '\n'.join([
    #                     content.data if content.type == PromptMessageContentType.TEXT else
    #                     '[image]' if content.type == PromptMessageContentType.IMAGE else
    #                     '[file]'
    #                     for content in prompt_message.content
    #                 ])
    #                 prompt_message.content = text_content
    #                 logger.info(f"Converted content: {text_content[:100]}...")
    #
    #     return prompt_messages
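
# A minimal usage sketch (assumptions: the constructor arguments are defined by
# BaseAgentRunner and are not shown in this file; `message` is a models.model.Message):
#
#     runner = FunctionCallAgentRunner(...)  # wire up app config, model, queue manager, etc.
#     for chunk in runner.run(message=message, query="How do I reset my password?"):
#         if chunk.delta.message and chunk.delta.message.content:
#             print(chunk.delta.message.content, end="", flush=True)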
