dify 本地源码启动
Ollama 安装部署
dify 智能体实践

dify 源码分析(一)功能概述
dify 源码分析(二)源码结构
dify 源码分析(三)agent
dify 源码分析(四)tools
dify 源码分析(五)chatflow
dify 源码分析(六)event
dify 源码分析(七)ratelimiter

文章目录

1. 工具调用概述

1.1. provider 继承关系

在这里插入图片描述

1.2. 工具类继承关系

在这里插入图片描述

1.3. 执行逻辑

1.3.1. ToolManager

/api/core/tools/tool_manager.py
ToolManager 管理不同的 *ProviderController

1.3.2. *ProviderController

core/tools/__base/tool_provider.py
core/tools/builtin_tool/provider.py
core/tools/custom_tool/provider.py
core/tools/workflow_as_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/mcp_tool/provider.py
*ProviderController 管理不同的提供者

内置工具 provider
core/tools/builtin_tool/providers/audio/audio.py
core/tools/builtin_tool/providers/code/code.py
core/tools/builtin_tool/providers/time/time.py
core/tools/builtin_tool/providers/webscraper/webscraper.py

1.3.3. *Tool

BuiltinToolProviderController 负责加载其下的各个内置工具(*Tool)

2. BaseAgentRunner 调用流程(以子类 CotAgentRunner.run 为例)

class CotAgentRunner(BaseAgentRunner, ABC):
    # Class-level flag; not read anywhere in this excerpt.
    _is_first_iteration = True
    # Providers for which run() does NOT append the "Observation" stop word.
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage]
    # run() appends one AgentScratchpadUnit per loop iteration.
    _agent_scratchpad: list[AgentScratchpadUnit]
    # Set in run() from the app's simple prompt template after input substitution.
    _instruction: str
    _query: str
    # Prompt tool descriptions produced by _init_prompt_tools(); emptied on the last iteration.
    _prompt_messages_tools: Sequence[PromptMessageTool]

    def run(
        self,
        message: Message,
        query: str,
        inputs: Mapping[str, str],
    ) -> Generator:
        """
        Run Cot agent application

        Each loop iteration creates an agent-thought record, invokes the LLM, parses the
        streamed output into thought text and an optional action, and then either takes
        the final answer or invokes the requested tool and loops again. Thought text is
        yielded as ``LLMResultChunk`` objects while it streams; after the loop one more
        chunk carrying ``final_answer`` is yielded and a ``QueueMessageEndEvent`` is
        published.

        :param message: message whose id the agent thoughts are attached to
        :param query: user query, passed to ``_init_react_state``
        :param inputs: values filled into the instruction template
        :return: generator of ``LLMResultChunk``
        """

        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)
        self._init_react_state(query)

        trace_manager = app_generate_entity.trace_manager

        # check model mode
        # Add "Observation" as a stop word unless it is already present or the provider is ignored.
        if "Observation" not in app_generate_entity.model_conf.stop:
            if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
                app_generate_entity.model_conf.stop.append("Observation")

        app_config = self.app_config
        assert app_config.agent

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template or ""
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        # Configured maximum is capped at 99; the extra +1 step is the tool-less final iteration below.
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1

        # convert tools into ModelRuntime Tool format
        # Initialise the tool set.
        # Author's note: _init_prompt_tools merges all available tools (tools configured on the app + dataset tools)
        # and yields tool_instances (used for the actual invocation) and prompt_messages_tools (the tool list shown to the model).
        tool_instances, prompt_messages_tools = self._init_prompt_tools()
        self._prompt_messages_tools = prompt_messages_tools

        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""

        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            ......

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                self._prompt_messages_tools = []

            message_file_ids: list[str] = []

            # 1. Create the thought record for this iteration (author's note: it is persisted to the database)
            agent_thought_id = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            # recalc llm max tokens
            # 2. Invoke the LLM to produce a thought or a tool-call instruction
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            # streamed model output
            chunks = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=[],
                stop=app_generate_entity.model_conf.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            usage_dict: dict[str, Optional[LLMUsage]] = {}
            # 3. Parse the LLM output, extracting the thought and any tool-call instruction
            # The scratchpad records the details of this iteration (thought text, action, ...)
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):  # chunk is a tool-call instruction
                    action = chunk
                    # detect action
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += json.dumps(chunk.model_dump())
                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:  # chunk is natural-language thought text
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += chunk
                    assert scratchpad.thought is not None
                    scratchpad.thought += chunk
                    # stream the thought text to the caller as it arrives
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint="",
                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                    )

            assert scratchpad.thought is not None
            # Fall back to a placeholder thought when the model produced no thought text.
            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            self._agent_scratchpad.append(scratchpad)

            # get llm usage
            ......

            self.save_agent_thought(......)

            # 4. Decide whether a tool needs to be invoked
            if not scratchpad.is_final():
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ""
            else:
                # the agent decided to output the final answer
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except TypeError:
                        final_answer = f"{scratchpad.action.action_input}"
                # the agent decided to call a tool
                else:
                    # keep looping: schedule another iteration
                    function_call_state = True
                    # action is tool call, invoke tool
                    # invoke the tool and collect its result
                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                        action=scratchpad.action,
                        tool_instances=tool_instances,
                        message_file_ids=message_file_ids,
                        trace_manager=trace_manager,
                    )
                    # record the tool's result
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response

                    # message_file_ids is the same list handed to _handle_invoke_action above
                    # (author's note: it carries the file ids returned by the tool)
                    self.save_agent_thought(
                        agent_thought_id=agent_thought_id,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought or "",
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict["usage"],
                    )

                    self.queue_manager.publish(
                        QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                    )

                # update prompt tool message
                for prompt_tool in self._prompt_messages_tools:
                    self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # Emit the final answer as one last chunk, carrying the accumulated usage.
        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
            ),
            system_fingerprint="",
        )

        # save agent thought
        self.save_agent_thought(......)
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

2.1. _init_prompt_tools(初始化并转换工具信息)

class BaseAgentRunner(AppRunner):
    ......
    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Collect every tool available to the agent.

        Agent tools configured on the app come first, followed by the dataset
        retriever tools. A configured tool whose conversion raises is skipped.

        :return: ``(tool instances keyed by tool name, prompt message tools)``
        """
        instances_by_name = {}
        prompt_tools = []

        # No agent config, or an agent without tools, both mean "no configured tools".
        agent_config = self.app_config.agent
        configured_tools = (agent_config.tools or []) if agent_config else []

        # Turn each configured agent tool into its prompt-message form.
        for agent_tool in configured_tools:
            try:
                converted, runtime_tool = self._convert_tool_to_prompt_message_tool(agent_tool)
            except Exception:
                # api tool may be deleted
                continue
            instances_by_name[agent_tool.tool_name] = runtime_tool
            prompt_tools.append(converted)

        # Dataset retriever tools are converted the same way and keyed by their identity name.
        for retriever_tool in self.dataset_tools:
            converted = self._convert_dataset_retriever_tool_to_prompt_message_tool(retriever_tool)
            prompt_tools.append(converted)
            instances_by_name[retriever_tool.entity.identity.name] = retriever_tool

        return instances_by_name, prompt_tools
    ......

2.1.1. _convert_tool_to_prompt_message_tool

先通过 ToolManager.get_agent_tool_runtime 加载智能体工具,
再通过 get_merged_runtime_parameters 合并工具参数,最终生成 PromptMessageTool

class BaseAgentRunner(AppRunner):
    ......
    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        Build the LLM-facing description of an agent tool.

        Loads the tool runtime through ``ToolManager.get_agent_tool_runtime`` and turns
        its merged runtime parameters into the ``properties`` / ``required`` entries of
        a ``PromptMessageTool``. Only parameters whose form is ``LLM`` are exposed, and
        file-typed parameters are left out.

        :param tool: agent tool configuration to convert
        :return: ``(prompt message tool, loaded tool runtime)``
        """
        # Load the runtime for this agent tool.
        runtime_tool = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        assert runtime_tool.entity.description

        prompt_tool = PromptMessageTool(
            name=tool.tool_name,
            description=runtime_tool.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        file_parameter_types = {
            ToolParameter.ToolParameterType.SYSTEM_FILES,
            ToolParameter.ToolParameterType.FILE,
            ToolParameter.ToolParameterType.FILES,
        }

        # Walk the merged parameter definitions and describe each LLM-filled one.
        for param in runtime_tool.get_merged_runtime_parameters():
            if param.form != ToolParameter.ToolParameterForm.LLM:
                continue

            normal_type = param.type.as_normal_type()
            if param.type in file_parameter_types:
                continue

            # Only SELECT parameters contribute an enum of allowed values.
            if param.type == ToolParameter.ToolParameterType.SELECT and param.options:
                choices = [option.value for option in param.options]
            else:
                choices = []

            # A parameter's own input_schema, when present, is used as-is.
            if param.input_schema is None:
                property_schema = {
                    "type": normal_type,
                    "description": param.llm_description or "",
                }
            else:
                property_schema = param.input_schema
            prompt_tool.parameters["properties"][param.name] = property_schema

            if choices:
                prompt_tool.parameters["properties"][param.name]["enum"] = choices

            if param.required:
                prompt_tool.parameters["required"].append(param.name)

        return prompt_tool, runtime_tool
    ......
  • PromptMessageTool 的 parameters 实际是一个 JSON Schema 风格的字典(type / properties / required)
    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        ......
        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

# The tool name (name) and description (description) tell the model what the tool can do;
# the parameter definition (parameters) tells the model which arguments to pass when calling it
# (e.g. the time range of a query: start_date, end_date).
# Illustrative example of a filled-in PromptMessageTool:
# name="dataset_search",
# description="Query product sales data for a given time range",
# parameters={
#    "properties": {
#        "start_date": {
#            "type": "string",
#            "description": "Start date, format YYYY-MM-DD"
#        },
#        "end_date": {
#            "type": "string",
#            "description": "End date, format YYYY-MM-DD"
#        }
#    },
#    "required": [
#        "start_date",
#        "end_date"
#    ]
# }
2.1.1.1. get_agent_tool_runtime(智能体工具逻辑)

core/tools/tool_manager.py

class ToolManager:
    ......
    # 智能体工具
    @classmethod
    def get_agent_tool_runtime(
        cls,
        tenant_id: str,
        app_id: str,
        agent_tool: AgentToolEntity,  # author's note: the workflow variant takes a different parameter here
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        variable_pool: Optional[VariablePool] = None,
    ) -> Tool:
        """
        get the agent tool runtime

        Loads the tool through ``get_tool_runtime`` (with ``ToolInvokeFrom.AGENT``),
        converts the agent's configured ``tool_parameters`` against the tool's merged
        parameter definitions, decrypts them, and merges the result into the tool's
        ``runtime.runtime_parameters``.

        :param tenant_id: the tenant id
        :param app_id: the app id, used to build the ``AGENT.{app_id}`` identity for decryption
        :param agent_tool: the agent tool configuration
        :param invoke_from: invoke from
        :param variable_pool: optional variable pool forwarded to the parameter conversion
        :return: the tool with its runtime parameters updated
        :raises ValueError: if the loaded tool has no runtime or no runtime parameters
        """
        # Load the tool runtime for the agent.
        tool_entity = cls.get_tool_runtime(
            provider_type=agent_tool.provider_type,
            provider_id=agent_tool.provider_id,
            tool_name=agent_tool.tool_name,
            tenant_id=tenant_id,
            invoke_from=invoke_from,
            tool_invoke_from=ToolInvokeFrom.AGENT,
            credential_id=agent_tool.credential_id,
        )

        # Convert the configured values against the tool's merged parameter definitions.
        parameters = tool_entity.get_merged_runtime_parameters()
        runtime_parameters = cls._convert_tool_parameters_type(
            parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
        )

        # decrypt runtime parameters
        encryption_manager = ToolParameterConfigurationManager(
            tenant_id=tenant_id,
            tool_runtime=tool_entity,
            provider_name=agent_tool.provider_id,
            provider_type=agent_tool.provider_type,
            identity_id=f"AGENT.{app_id}",
        )
        runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

        # Author's note: the workflow counterpart does not perform this check
        # (not verifiable from this excerpt).
        if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
            raise ValueError("runtime not found or runtime parameters not found")

        tool_entity.runtime.runtime_parameters.update(runtime_parameters)
        return tool_entity
        
    # 工作流工具
    @classmethod
    def get_workflow_tool_runtime(......)
    ......
2.1.1.2. get_tool_runtime(工具加载,区分类型)
  • 内置工具
  • 接口工具
  • 工作流
  • APP(尚未实现,直接抛出 NotImplementedError)
  • 插件
  • MCP
    在这里插入图片描述
    core/tools/tool_manager.py
class ToolManager:
    ......
    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: ToolProviderType,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
        credential_id: Optional[str] = None,
    ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
        """
        get the tool runtime

        Dispatches on ``provider_type``:

        * ``BUILT_IN``: look up the provider and tool; when credentials are needed, load
          the stored provider record, decrypt its credentials, refresh them if they are
          about to expire, and fork the tool with a runtime carrying them.
        * ``API``: fork the API tool with decrypted credentials.
        * ``WORKFLOW``: fork the first tool of the workflow provider.
        * ``APP``: not implemented.
        * ``PLUGIN`` / ``MCP``: return the tool from the corresponding provider.

        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: invoke from
        :param tool_invoke_from: the tool invoke from
        :param credential_id: the credential id

        :return: the tool
        :raises ToolProviderNotFoundError: if the provider, tool, or provider type cannot be found
        :raises NotImplementedError: for ``ToolProviderType.APP``
        """
        # built-in tools
        if provider_type == ToolProviderType.BUILT_IN:
            # check if the builtin tool need credentials
            provider_controller = cls.get_builtin_provider(provider_id, tenant_id)

            # look the tool up by name
            builtin_tool = provider_controller.get_tool(tool_name)
            if not builtin_tool:
                raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

            # no credentials required: return immediately with empty credentials
            if not provider_controller.need_credentials:
                return cast(
                    BuiltinTool,
                    builtin_tool.fork_tool_runtime(
                        runtime=ToolRuntime(
                            tenant_id=tenant_id,
                            credentials={},
                            invoke_from=invoke_from,
                            tool_invoke_from=tool_invoke_from,
                        )
                    ),
                )

            # credentials required: resolve the stored provider record
            builtin_provider = None
            if isinstance(provider_controller, PluginToolProviderController):
                provider_id_entity = ToolProviderID(provider_id)
                # get specific credentials
                if is_valid_uuid(credential_id):
                    try:
                        builtin_provider_stmt = select(BuiltinToolProvider).where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            BuiltinToolProvider.id == credential_id,
                        )
                        builtin_provider = db.session.scalar(builtin_provider_stmt)
                    except Exception as e:
                        builtin_provider = None
                        logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
                    # if the provider has been deleted, raise an error
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

                # fallback to the default provider
                if builtin_provider is None:
                    # use the default provider
                    builtin_provider = (
                        db.session.query(BuiltinToolProvider)
                        .where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            (BuiltinToolProvider.provider == str(provider_id_entity))
                            | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                        )
                        .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                        .first()
                    )
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
            else:
                builtin_provider = (
                    db.session.query(BuiltinToolProvider)
                    .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                    .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                    .first()
                )

                if builtin_provider is None:
                    raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

            encrypter, cache = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[
                    x.to_basic_provider_config()
                    for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
                ],
                cache=ToolProviderCredentialsCache(
                    tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
                ),
            )

            # decrypt the credentials
            decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

            # check if the credentials is expired
            # (refresh 60 seconds ahead of expires_at; expires_at == -1 skips the check,
            # presumably meaning "never expires" - TODO confirm)
            if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
                # TODO: circular import
                from services.tools.builtin_tools_manage_service import BuiltinToolManageService

                # refresh the credentials
                tool_provider = ToolProviderID(provider_id)
                provider_name = tool_provider.provider_name
                redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
                system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
                oauth_handler = OAuthHandler()
                refreshed_credentials = oauth_handler.refresh_credentials(
                    tenant_id=tenant_id,
                    user_id=builtin_provider.user_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=provider_name,
                    redirect_uri=redirect_uri,
                    system_credentials=system_credentials or {},
                    credentials=decrypted_credentials,
                )
                # update the credentials: store the re-encrypted values and the new expiry,
                # then drop the cached credentials
                builtin_provider.encrypted_credentials = (
                    TypeAdapter(dict[str, Any])
                    .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
                    .decode("utf-8")
                )
                builtin_provider.expires_at = refreshed_credentials.expires_at
                db.session.commit()
                decrypted_credentials = refreshed_credentials.credentials
                cache.delete()

            # return the tool bound to a runtime that carries the credentials
            return cast(
                BuiltinTool,
                builtin_tool.fork_tool_runtime(
                    runtime=ToolRuntime(
                        tenant_id=tenant_id,
                        credentials=dict(decrypted_credentials),
                        credential_type=CredentialType.of(builtin_provider.credential_type),
                        runtime_parameters={},
                        invoke_from=invoke_from,
                        tool_invoke_from=tool_invoke_from,
                    )
                ),
            )
        # API tools
        elif provider_type == ToolProviderType.API:
            api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
            encrypter, _ = create_tool_provider_encrypter(
                tenant_id=tenant_id,
                controller=api_provider,
            )
            return api_provider.get_tool(tool_name).fork_tool_runtime(
                runtime=ToolRuntime(
                    tenant_id=tenant_id,
                    credentials=encrypter.decrypt(credentials),
                    invoke_from=invoke_from,
                    tool_invoke_from=tool_invoke_from,
                )
            )
        # workflow tools
        elif provider_type == ToolProviderType.WORKFLOW:
            workflow_provider_stmt = select(WorkflowToolProvider).where(
                WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
            )
            workflow_provider = db.session.scalar(workflow_provider_stmt)

            if workflow_provider is None:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

            controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
            controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
            if not controller_tools:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

            # Reuse the list fetched above instead of calling get_tools() a second time.
            return controller_tools[0].fork_tool_runtime(
                runtime=ToolRuntime(
                    tenant_id=tenant_id,
                    credentials={},
                    invoke_from=invoke_from,
                    tool_invoke_from=tool_invoke_from,
                )
            )
        # app provider (not implemented)
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
        # plugin tools
        elif provider_type == ToolProviderType.PLUGIN:
            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
        # MCP tools
        elif provider_type == ToolProviderType.MCP:
            return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
    ......
2.1.1.3. get_merged_runtime_parameters(尚未分析)
后续分析

2.1.2. _convert_dataset_retriever_tool_to_prompt_message_tool

后续分析

2.2. handle_react_stream_output(模型返回分析)

2.3. _handle_invoke_action(调用工具)

class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage]
    _agent_scratchpad: list[AgentScratchpadUnit]
    _instruction: str
    _query: str
    _prompt_messages_tools: Sequence[PromptMessageTool]

    def run(......) -> Generator:
        """
        Run Cot agent application
        """
        ......
        
    def _handle_invoke_action(
        self,
        action: AgentScratchpadUnit.Action,
        tool_instances: Mapping[str, Tool],
        message_file_ids: list[str],
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[str, ToolInvokeMeta]:
        """
        Invoke the tool requested by an agent action.

        :param action: parsed action carrying the tool name and its input
        :param tool_instances: available tools keyed by name
        :param message_file_ids: list extended in place with the ids of files produced by the tool
        :param trace_manager: optional trace manager forwarded to the tool engine
        :return: ``(observation text, invocation meta)``
        """
        requested_name = action.action_name
        call_args = action.action_input

        # Look the tool instance up by the name the model asked for.
        selected_tool = tool_instances.get(requested_name)
        if not selected_tool:
            # Unknown tool: hand the message back as the observation instead of raising.
            error_text = f"there is not a tool named {requested_name}"
            return error_text, ToolInvokeMeta.error_instance(error_text)

        # A string input may be JSON-encoded arguments; keep the raw string when it is not valid JSON.
        if isinstance(call_args, str):
            try:
                decoded = json.loads(call_args)
            except json.JSONDecodeError:
                pass
            else:
                call_args = decoded

        # Run the tool through the tool engine.
        observation, produced_file_ids, invoke_meta = ToolEngine.agent_invoke(
            tool=selected_tool,
            tool_parameters=call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback,
            trace_manager=trace_manager,
        )

        # Publish each file returned by the tool and remember its id for the caller.
        for file_id in produced_file_ids:
            file_event = QueueMessageFileEvent(message_file_id=file_id)
            self.queue_manager.publish(file_event, PublishFrom.APPLICATION_MANAGER)
            message_file_ids.append(file_id)

        return observation, invoke_meta

2.3.1. ToolEngine.agent_invoke(工具执行)

/api/core/tools/tool_engine.py

class ToolEngine:
    """
    Tool runtime engine take care of the tool executions.
    """

    @staticmethod
    def agent_invoke(
        tool: Tool,
        tool_parameters: Union[str, dict],
        user_id: str,
        tenant_id: str,
        message: Message,
        invoke_from: InvokeFrom,
        agent_tool_callback: DifyAgentCallbackHandler,
        trace_manager: Optional[TraceQueueManager] = None,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> tuple[str, list[str], ToolInvokeMeta]:
        """
        Agent invokes the tool with the given arguments.

        :param tool: the tool instance to execute
        :param tool_parameters: tool arguments; either a dict, or a raw string that is
            mapped onto the tool's single LLM-form parameter or parsed as JSON
        :param user_id: id of the user on whose behalf the tool runs
        :param tenant_id: tenant id, forwarded to the message transformer
        :param message: the agent message; its conversation id and id are used below
        :param invoke_from: forwarded to ``_create_message_files``
        :param agent_tool_callback: callback notified on tool start / end / error
        :param trace_manager: optional trace manager forwarded to ``on_tool_end``
        :param conversation_id: forwarded to ``ToolEngine._invoke``
        :param app_id: forwarded to ``ToolEngine._invoke``
        :param message_id: forwarded to ``ToolEngine._invoke``
        :return: (plain text response, message files, invocation meta); on a handled
            error the text is the error message, the file list is empty and the meta
            is an error instance
        :raises ValueError: if string parameters can neither be bound to a single
            LLM parameter nor parsed into a dict
        """
        # check if arguments is a string
        # normalise string arguments into a dict before invoking
        if isinstance(tool_parameters, str):
            # check if this tool has only one parameter
            parameters = [
                parameter
                for parameter in tool.get_runtime_parameters()
                if parameter.form == ToolParameter.ToolParameterForm.LLM
            ]
            if parameters and len(parameters) == 1:
                # exactly one LLM-filled parameter: the whole string is its value
                tool_parameters = {parameters[0].name: tool_parameters}
            else:
                # otherwise try to parse the string as JSON; parse failures are
                # suppressed here and rejected by the isinstance check below
                with contextlib.suppress(Exception):
                    tool_parameters = json.loads(tool_parameters)
                if not isinstance(tool_parameters, dict):
                    raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")

        try:
            # hit the callback handler
            # notify the callback handler that the tool is starting
            agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)

            # run the tool; `_invoke` returns a generator, so nothing executes until it is consumed
            messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
            invocation_meta_dict: dict[str, ToolInvokeMeta] = {}

            # filter the stream: stash the ToolInvokeMeta item in the dict and
            # pass every other message through unchanged
            def message_callback(
                invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
            ):
                for message in messages:
                    if isinstance(message, ToolInvokeMeta):
                        invocation_meta_dict["meta"] = message
                    else:
                        yield message

            # transform the tool's output messages
            messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
                messages=message_callback(invocation_meta_dict, messages),
                user_id=user_id,
                tenant_id=tenant_id,
                conversation_id=message.conversation_id,
            )

            # materialise the stream; this is what actually drives the tool and
            # fills invocation_meta_dict["meta"]
            message_list = list(messages)

            # extract binary data from tool invoke message
            binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
            # create message file
            message_files = ToolEngine._create_message_files(
                tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
            )

            # flatten all messages into one plain-text string
            plain_text = ToolEngine._convert_tool_response_to_str(message_list)

            # only available after the stream above has been fully consumed
            meta = invocation_meta_dict["meta"]

            # hit the callback handler
            # notify the callback handler that the tool has finished
            agent_tool_callback.on_tool_end(
                tool_name=tool.entity.identity.name,
                tool_inputs=tool_parameters,
                tool_outputs=plain_text,
                message_id=message.id,
                trace_manager=trace_manager,
            )

            # transform tool invoke message to get LLM friendly message
            # return the LLM-friendly text together with the files and meta
            return plain_text, message_files, meta
        except ToolProviderCredentialValidationError as e:
            ......
        except Exception as e:
            error_response = f"unknown error: {e}"
            agent_tool_callback.on_tool_error(e)

        # reached only from an except branch that set error_response
        return error_response, [], ToolInvokeMeta.error_instance(error_response)

    @staticmethod
    def _invoke(
        tool: Tool,
        tool_parameters: dict,
        user_id: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
        """
        Invoke the tool with the given arguments.

        Yields every message produced by ``tool.invoke`` and finally one
        ``ToolInvokeMeta`` carrying the elapsed time, the tool configuration
        and, on failure, the error text.

        :param tool: the tool instance to execute
        :param tool_parameters: arguments passed to ``tool.invoke``
        :param user_id: id of the invoking user
        :param conversation_id: forwarded to ``tool.invoke``
        :param app_id: forwarded to ``tool.invoke``
        :param message_id: forwarded to ``tool.invoke``
        :raises ToolEngineInvokeError: wraps any exception raised by the tool;
            the original exception is attached as ``__cause__``
        """
        started_at = datetime.now(UTC)
        meta = ToolInvokeMeta(
            time_cost=0.0,
            error=None,
            tool_config={
                "tool_name": tool.entity.identity.name,
                "tool_provider": tool.entity.identity.provider,
                "tool_provider_type": tool.tool_provider_type().value,
                # snapshot, so later changes to the runtime parameters do not alter the meta
                "tool_parameters": deepcopy(tool.runtime.runtime_parameters),
                "tool_icon": tool.entity.identity.icon,
            },
        )
        try:
            # core entry point of the tool base class
            yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
        except Exception as e:
            meta.error = str(e)
            # chain explicitly so the tool's own traceback is kept as the cause
            raise ToolEngineInvokeError(meta) from e
        finally:
            # the meta record is yielded from `finally`, so it is emitted on
            # both the success and the failure path
            ended_at = datetime.now(UTC)
            meta.time_cost = (ended_at - started_at).total_seconds()
            yield meta
2.3.1.1. on_tool_start
2.3.1.2. tool.invoke(调用子类的重载函数)
class Tool(ABC):
    """
    The base class of a tool
    """
    ......
    def invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage]:
        """
        Run the tool and always hand back a generator of messages.

        Runtime parameters (when present) are merged into ``tool_parameters``
        in place, the parameters are type-converted, and the subclass
        implementation ``_invoke`` is called. A single message or a list of
        messages returned by the subclass is wrapped in a generator; any other
        result is returned as-is.

        :param user_id: id of the invoking user
        :param tool_parameters: tool arguments; updated in place with runtime parameters
        :param conversation_id: forwarded to ``_invoke``
        :param app_id: forwarded to ``_invoke``
        :param message_id: forwarded to ``_invoke``
        :return: generator of tool invoke messages
        """
        if self.runtime and self.runtime.runtime_parameters:
            tool_parameters.update(self.runtime.runtime_parameters)

        # coerce the raw parameter values into their declared types
        typed_parameters = self._transform_tool_parameters_type(tool_parameters)

        # dispatch to the implementation provided by the concrete tool class
        outcome = self._invoke(
            user_id=user_id,
            tool_parameters=typed_parameters,
            conversation_id=conversation_id,
            app_id=app_id,
            message_id=message_id,
        )

        def replay(items) -> Generator[ToolInvokeMessage, None, None]:
            # expose an already-materialised result through the generator protocol
            yield from items

        if isinstance(outcome, ToolInvokeMessage):
            return replay((outcome,))
        if isinstance(outcome, list):
            return replay(outcome)
        return outcome
2.3.1.3. transform_tool_invoke_messages

解析工具返回的数据
文件、图片还支持下载。
此处先默认解析为json数据。

2.3.1.4. ToolEngine._extract_tool_response_binary_and_text

从工具调用消息中提取二进制数据

2.3.1.5. ToolEngine._create_message_files

创建文件

2.3.1.6. ToolEngine._convert_tool_response_to_str
    @staticmethod
    def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
        """
        Flatten tool invoke messages into one plain-text string.

        Text is copied verbatim, links and images become fixed hint sentences,
        JSON payloads are serialised, and any other message type falls back to
        ``str(message)``. The pieces are concatenated without a separator.

        :param tool_response: messages produced by the tool
        :return: concatenated text, empty when there are no messages
        """
        fragments: list[str] = []
        for item in tool_response:
            kind = item.type
            if kind == ToolInvokeMessage.MessageType.TEXT:
                fragment = cast(ToolInvokeMessage.TextMessage, item.message).text
            elif kind == ToolInvokeMessage.MessageType.LINK:
                link = cast(ToolInvokeMessage.TextMessage, item.message).text
                fragment = f"result link: {link}. please tell user to check it."
            elif kind in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
                fragment = (
                    "image has been created and sent to user already, "
                    "you do not need to create it, just tell the user to check it now."
                )
            elif kind == ToolInvokeMessage.MessageType.JSON:
                payload = cast(ToolInvokeMessage.JsonMessage, item.message).json_object
                fragment = json.dumps(safe_json_value(payload), ensure_ascii=False)
            else:
                fragment = str(item.message)
            fragments.append(fragment)

        return "".join(fragments)
2.3.1.7. on_tool_end
2.3.1.8. on_tool_error

3. 工具初始化

3.1. 内置工具(初始化)

class ToolManager:
    ......
    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: ToolProviderType,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
        credential_id: Optional[str] = None,
    ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
        """
        get the tool runtime

        Dispatches on ``provider_type``. For built-in providers the tool is
        forked with a ``ToolRuntime`` that carries the tenant and, when the
        provider needs them, the decrypted (and if expired, refreshed)
        credentials.

        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: invoke from
        :param tool_invoke_from: the tool invoke from
        :param credential_id: the credential id

        :return: the tool
        :raises ToolProviderNotFoundError: if the tool, the provider record or
            the provider type cannot be found
        :raises NotImplementedError: for the APP provider type
        """
        # built-in tools
        if provider_type == ToolProviderType.BUILT_IN:
            # check if the builtin tool need credentials
            # resolve the provider controller from the provider id and tenant id
            provider_controller = cls.get_builtin_provider(provider_id, tenant_id)

            # look the tool up by name on that provider
            builtin_tool = provider_controller.get_tool(tool_name)
            if not builtin_tool:
                raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

            # no credentials required: return a runtime with empty credentials right away
            if not provider_controller.need_credentials:
                return cast(
                    BuiltinTool,
                    builtin_tool.fork_tool_runtime(
                        runtime=ToolRuntime(
                            tenant_id=tenant_id,
                            credentials={},
                            invoke_from=invoke_from,
                            tool_invoke_from=tool_invoke_from,
                        )
                    ),
                )

            # credentials required: find the stored provider record that holds them
            builtin_provider = None
            if isinstance(provider_controller, PluginToolProviderController):
                provider_id_entity = ToolProviderID(provider_id)
                # get specific credentials
                if is_valid_uuid(credential_id):
                    try:
                        builtin_provider_stmt = select(BuiltinToolProvider).where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            BuiltinToolProvider.id == credential_id,
                        )
                        builtin_provider = db.session.scalar(builtin_provider_stmt)
                    except Exception as e:
                        # a failed query is treated like a missing record
                        builtin_provider = None
                        logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
                    # if the provider has been deleted, raise an error
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

                # fallback to the default provider
                if builtin_provider is None:
                    # use the default provider
                    # match either the full provider id string or the bare provider name;
                    # default records sort first, then the oldest record
                    builtin_provider = (
                        db.session.query(BuiltinToolProvider)
                        .where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            (BuiltinToolProvider.provider == str(provider_id_entity))
                            | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                        )
                        .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                        .first()
                    )
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
            else:
                # non-plugin controller: look the record up by the raw provider id
                builtin_provider = (
                    db.session.query(BuiltinToolProvider)
                    .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                    .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                    .first()
                )

                if builtin_provider is None:
                    raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

            # build the encrypter (and its cache handle) from the credential schema
            # that matches the record's credential type
            encrypter, cache = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[
                    x.to_basic_provider_config()
                    for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
                ],
                cache=ToolProviderCredentialsCache(
                    tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
                ),
            )

            # decrypt the credentials
            decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

            # check if the credentials is expired
            # -1 disables the check; otherwise refresh once less than 60 seconds remain
            if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
                # TODO: circular import
                from services.tools.builtin_tools_manage_service import BuiltinToolManageService

                # refresh the credentials
                tool_provider = ToolProviderID(provider_id)
                provider_name = tool_provider.provider_name
                redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
                system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
                oauth_handler = OAuthHandler()
                # refresh the credentials
                refreshed_credentials = oauth_handler.refresh_credentials(
                    tenant_id=tenant_id,
                    user_id=builtin_provider.user_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=provider_name,
                    redirect_uri=redirect_uri,
                    system_credentials=system_credentials or {},
                    credentials=decrypted_credentials,
                )
                # update the credentials
                # persist the re-encrypted credentials and the new expiry, then
                # drop the cache entry so stale credentials are not served
                builtin_provider.encrypted_credentials = (
                    TypeAdapter(dict[str, Any])
                    .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
                    .decode("utf-8")
                )
                builtin_provider.expires_at = refreshed_credentials.expires_at
                db.session.commit()
                decrypted_credentials = refreshed_credentials.credentials
                cache.delete()

            # return the tool bound to a runtime that carries the credentials
            return cast(
                BuiltinTool,
                builtin_tool.fork_tool_runtime(
                    runtime=ToolRuntime(
                        tenant_id=tenant_id,
                        credentials=dict(decrypted_credentials),
                        credential_type=CredentialType.of(builtin_provider.credential_type),
                        runtime_parameters={},
                        invoke_from=invoke_from,
                        tool_invoke_from=tool_invoke_from,
                    )
                ),
            )
        # API tools
        elif provider_type == ToolProviderType.API:
            ......
        # workflow tools
        elif provider_type == ToolProviderType.WORKFLOW:
            ......
        # APP
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
        # plugin tools
        elif provider_type == ToolProviderType.PLUGIN:
            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
        # MCP tools
        elif provider_type == ToolProviderType.MCP:
            return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
    ......

3.1.1. get_builtin_provider (获取工具信息)

class ToolManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}		# 存储信息
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
    ......
    # Resolve the provider controller for a provider id within a tenant
    @classmethod
    def get_builtin_provider(
        cls, provider: str, tenant_id: str
    ) -> BuiltinToolProviderController | PluginToolProviderController:
        """
        get the builtin provider

        Hardcoded providers are loaded on first use. A provider that is not
        hardcoded is looked up as a plugin provider.

        :param provider: the name of the provider (callers pass the provider id)
        :param tenant_id: the id of the tenant
        :return: the provider
        """
        # first call: the hardcoded registry is still empty, populate it from the python files
        if not cls._hardcoded_providers:
            cls.load_hardcoded_providers_cache()

        hardcoded = cls._hardcoded_providers
        if provider in hardcoded:
            return hardcoded[provider]

        # not a hardcoded provider: fall back to the plugin providers
        plugin_provider = cls.get_plugin_provider(provider, tenant_id)
        if plugin_provider:
            return plugin_provider

        # only reached if the plugin lookup returned a falsy value; the
        # subscript then raises KeyError for the unknown provider
        return cls._hardcoded_providers[provider]
    ......
3.1.1.1. load_hardcoded_providers_cache(通过硬编码加载)
class ToolManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
    ......
    # 1. Warm the hardcoded provider cache
    @classmethod
    def load_hardcoded_providers_cache(cls):
        """
        Drain ``list_hardcoded_providers`` so that every provider gets loaded.

        The yielded providers themselves are discarded; only the side effects
        of iterating the generator matter here.
        """
        providers = cls.list_hardcoded_providers()
        for _provider in providers:
            continue

    # 2. Serve providers from the cache when loaded; otherwise load them from
    #    the python files (which also fills the cache)
    @classmethod
    def list_hardcoded_providers(cls):
        """
        Yield every hardcoded provider controller.

        Once ``_builtin_providers_loaded`` is set, a snapshot of the cached
        controllers is yielded. Otherwise the providers are loaded via
        ``_list_hardcoded_providers`` while ``_builtin_provider_lock`` is held.
        """
        # use cache first
        # fast path without taking the lock
        if cls._builtin_providers_loaded:
            yield from list(cls._hardcoded_providers.values())
            return

        with cls._builtin_provider_lock:
            # re-check under the lock: the flag may have been set while waiting for it
            if cls._builtin_providers_loaded:
                yield from list(cls._hardcoded_providers.values())
                return
            # load from the provider directories on disk; the lock stays held
            # while this generator is being consumed
            yield from cls._list_hardcoded_providers()

    # 3. Load the providers from the local python files
    @classmethod
    def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
        """
        list all the builtin providers

        Scans the ``builtin_tool/providers`` directory next to this file. Every
        sub-directory whose name does not start with ``__`` is treated as a
        provider: its ``<name>/<name>.py`` module is loaded, the provider class
        is instantiated and registered in ``_hardcoded_providers``, and the
        labels of its tools are recorded in ``_builtin_tools_labels``.

        A provider that fails to load is logged and skipped. After the scan has
        been fully consumed, ``_builtin_providers_loaded`` is set.

        :return: generator over the successfully loaded provider controllers
        """
        # api/core/tools/builtin_tool/providers -- computed once, it does not change per entry
        providers_dir = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")

        for provider_path in listdir(providers_dir):
            # skip dunder entries such as __pycache__ / __init__.py
            if provider_path.startswith("__"):
                continue

            # only directories are providers
            provider_dir = path.join(providers_dir, provider_path)
            if not path.isdir(provider_dir):
                continue

            # init provider
            try:
                # import the provider module from its source file
                provider_class = load_single_subclass_from_source(
                    module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
                    script_path=path.join(provider_dir, f"{provider_path}.py"),
                    parent_type=BuiltinToolProviderController,
                )
                provider: BuiltinToolProviderController = provider_class()
                cls._hardcoded_providers[provider.entity.identity.name] = provider
                # record the label of every tool the provider offers
                for tool in provider.get_tools():
                    cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
                yield provider

            except Exception:
                logger.exception("load builtin provider %s", provider_path)
                continue
        # set builtin providers loaded
        cls._builtin_providers_loaded = True
    ......

在这里插入图片描述

3.1.1.1.1. load_single_subclass_from_source(从文件中加载)

从文件中加载,这个需要详细分析

3.1.1.1.2. BuiltinToolProviderController 基类

/api/core/tools/builtin_tool/provider.py

class BuiltinToolProviderController(ToolProviderController):
    tools: list[BuiltinTool]

    def __init__(self, **data: Any) -> None:
        """
        Build the provider controller from its YAML definition and load its tools.

        The provider name is the last component of the subclass' module name;
        the definition is read from ``providers/<name>/<name>.yaml`` next to
        this file.

        :param data: accepted for signature compatibility; not used here
        :raises ToolProviderNotFoundError: if the provider YAML cannot be loaded
        """
        self.tools = []

        # load provider yaml
        provider = self.__class__.__module__.split(".")[-1]
        yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
        try:
            provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
        except Exception as e:
            # keep the original failure attached as the cause
            raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") from e

        # `credentials_for_provider` may be absent or explicitly null in the
        # YAML; both cases mean "no credentials" and must not break the loop
        credentials = provider_yaml.get("credentials_for_provider") or {}
        credentials_schema = []
        for credential_name, credential_dict in credentials.items():
            # set credentials name
            credential_dict["name"] = credential_name
            credentials_schema.append(credential_dict)

        oauth_schema = None
        oauth_yaml = provider_yaml.get("oauth_schema")
        if oauth_yaml is not None:
            oauth_schema = OAuthSchema(
                client_schema=oauth_yaml.get("client_schema", []),
                credentials_schema=oauth_yaml.get("credentials_schema", []),
            )

        super().__init__(
            entity=ToolProviderEntity(
                identity=provider_yaml["identity"],
                credentials_schema=credentials_schema,
                oauth_schema=oauth_schema,
            ),
        )

        self._load_tools()

    # A provider bundles several tools; each one is described by its own YAML
    # file and has to be loaded individually
    def _load_tools(self):
        """
        Load every tool of this provider into ``self.tools``.

        Each ``*.yaml`` file (not starting with ``__``) in
        ``providers/<provider>/tools`` describes one tool. The tool class is
        loaded from the module ``core.tools.builtin_tool.providers.<provider>.tools.<tool_name>``
        and instantiated with an entity built from the YAML content and a
        runtime that has an empty tenant id.
        """
        provider = self.entity.identity.name
        tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")

        tools = []
        # get all the yaml files in the tool path
        for tool_file in listdir(tool_path):
            if not tool_file.endswith(".yaml") or tool_file.startswith("__"):
                continue

            # get tool name
            tool_name = tool_file.split(".")[0]
            tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)

            # get tool class, import the module
            # The source file sits next to the YAML in `tool_path`, matching the
            # dotted module name. The previous path inserted an additional
            # "builtin_tool" segment, which is inconsistent with `tool_path`.
            # NOTE(review): assumes this file lives in core/tools/builtin_tool/ -- confirm
            assistant_tool_class: type = load_single_subclass_from_source(
                module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
                script_path=path.join(tool_path, f"{tool_name}.py"),
                parent_type=BuiltinTool,
            )
            tool["identity"]["provider"] = provider
            tools.append(
                assistant_tool_class(
                    provider=provider,
                    entity=ToolEntity(**tool),
                    runtime=ToolRuntime(tenant_id=""),
                )
            )

        self.tools = tools
    ......

此处也用到了 load_single_subclass_from_source 。
在这里插入图片描述

3.1.1.2. get_plugin_provider(插件工具管理)
class ToolManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
    ......
    @classmethod
    def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
        """
        get the plugin provider

        Controllers are cached in the ``contexts.plugin_tool_providers``
        context variable, keyed by provider. On a cache miss the provider is
        fetched through ``PluginToolManager`` while the context's lock is held.

        :param provider: the provider to look up
        :param tenant_id: the id of the tenant
        :return: the plugin provider controller
        :raises ToolProviderNotFoundError: if the plugin provider cannot be fetched
        """
        providers_var = contexts.plugin_tool_providers
        lock_var = contexts.plugin_tool_providers_lock

        # check if context is set; initialise the cache and its lock otherwise
        try:
            providers_var.get()
        except LookupError:
            providers_var.set({})
            lock_var.set(Lock())

        cached = providers_var.get()
        if provider in cached:
            return cached[provider]

        with lock_var.get():
            # double check
            cached = providers_var.get()
            if provider not in cached:
                # cache miss: fetch the provider declaration and wrap it in a controller
                provider_entity = PluginToolManager().fetch_tool_provider(tenant_id, provider)
                if not provider_entity:
                    raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

                cached[provider] = PluginToolProviderController(
                    entity=provider_entity.declaration,
                    plugin_id=provider_entity.plugin_id,
                    plugin_unique_identifier=provider_entity.plugin_unique_identifier,
                    tenant_id=tenant_id,
                )
            return cached[provider]
    ......
  • contexts 封装(有点看不懂)

/api/contexts/__init__.py

from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING

from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
    from core.model_runtime.entities.model_entities import AIModelEntity
    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
    from core.tools.plugin_tool.provider import PluginToolProviderController
    from core.workflow.entities.variable_pool import VariablePool


# NOTE: the bare string below is a no-op expression statement used as a
# comment: every variable here is wrapped in RecyclableContextVar instead of
# being a plain ContextVar.
"""
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
"""
# plugin tool provider controllers, stored as a dict keyed by a string
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
    ContextVar("plugin_tool_providers")
)

# lock paired with `plugin_tool_providers`
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

# plugin model provider entities (a list, or None)
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
    ContextVar("plugin_model_providers")
)

# lock paired with `plugin_model_providers`
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("plugin_model_providers_lock")
)

# lock paired with `plugin_model_schemas`
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))

# model schemas, stored as a dict keyed by a string
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
    ContextVar("plugin_model_schemas")
)
  • RecyclableContextVar

/api/contexts/wrapper.py

from contextvars import ContextVar
from typing import Generic, TypeVar

T = TypeVar("T")


class HiddenValue:
    pass


_default = HiddenValue()


class RecyclableContextVar(Generic[T]):
    """
    RecyclableContextVar is a wrapper around ContextVar
    It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now

    NOTE: you need to call `increment_thread_recycles` before requests
    """

    _thread_recycles: ContextVar[int] = ContextVar("thread_recycles")

    @classmethod
    def increment_thread_recycles(cls):
        try:
            recycles = cls._thread_recycles.get()
            cls._thread_recycles.set(recycles + 1)
        except LookupError:
            cls._thread_recycles.set(0)

    def __init__(self, context_var: ContextVar[T]):
        self._context_var = context_var
        self._updates = ContextVar[int](context_var.name + "_updates", default=0)

    def get(self, default: T | HiddenValue = _default) -> T:
        thread_recycles = self._thread_recycles.get(0)
        self_updates = self._updates.get()
        if thread_recycles > self_updates:
            self._updates.set(thread_recycles)

        # check if thread is recycled and should be updated
        if thread_recycles < self_updates:
            return self._context_var.get()
        else:
            # thread_recycles >= self_updates, means current context is invalid
            if isinstance(default, HiddenValue) or default is _default:
                raise LookupError
            else:
                return default

    def set(self, value: T):
        # it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
        # increase it manually
        thread_recycles = self._thread_recycles.get(0)
        self_updates = self._updates.get()
        if thread_recycles > self_updates:
            self._updates.set(thread_recycles)

        if self._updates.get() == self._thread_recycles.get(0):
            # after increment,
            self._updates.set(self._updates.get() + 1)

        # set the context
        self._context_var.set(value)

3.1.2. get_tool(根据工具名返回工具)

/api/core/tools/builtin_tool/provider.py

class BuiltinToolProviderController(ToolProviderController):
    tools: list[BuiltinTool]
    ......
    # Loads all tools offered by this provider
    def _load_tools(self):
        """
        Load every tool of this provider into ``self.tools``.

        Each ``*.yaml`` file (not starting with ``__``) in
        ``providers/<provider>/tools`` describes one tool. The tool class is
        loaded from the module ``core.tools.builtin_tool.providers.<provider>.tools.<tool_name>``
        and instantiated with a ``ToolEntity`` built from the YAML content and
        a ``ToolRuntime`` that only has an empty tenant id.
        """
        provider = self.entity.identity.name
        tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")

        tools = []
        # get all the yaml files in the tool path
        for tool_file in listdir(tool_path):
            if not tool_file.endswith(".yaml") or tool_file.startswith("__"):
                continue

            # get tool name
            tool_name = tool_file.split(".")[0]
            tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)

            # get tool class, import the module
            # The source file sits next to the YAML in `tool_path`, matching the
            # dotted module name. The previous path inserted an additional
            # "builtin_tool" segment, which is inconsistent with `tool_path`.
            # NOTE(review): assumes this file lives in core/tools/builtin_tool/ -- confirm
            assistant_tool_class: type = load_single_subclass_from_source(
                module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
                script_path=path.join(tool_path, f"{tool_name}.py"),
                parent_type=BuiltinTool,
            )

            # the entity is initialised from the YAML content; the runtime is
            # not, it only receives an empty tenant id
            tool["identity"]["provider"] = provider
            tools.append(
                assistant_tool_class(
                    provider=provider,
                    entity=ToolEntity(**tool),
                    runtime=ToolRuntime(tenant_id=""),
                )
            )

        # read back through `_get_builtin_tools`
        self.tools = tools
        
    # Internal accessor used by `get_tools`
    def _get_builtin_tools(self) -> list[BuiltinTool]:
        """
        returns a list of tools that the provider can provide

        The list is the one stored in ``self.tools``; it is returned as-is,
        not copied.

        :return: list of tools
        """
        return self.tools

    # Returns all tools offered by this provider
    def get_tools(self) -> list[BuiltinTool]:
        """
        returns a list of tools that the provider can provide

        Delegates to ``_get_builtin_tools``.

        :return: list of tools
        """
        return self._get_builtin_tools()
    ......
3.1.2.1. ToolEntity(使用 json 初始化)

/api/core/tools/entities/tool_entities.py

# 被核心类使用
class ToolParameter(PluginParameter):
    """
    Tool-facing parameter definition.

    Extends ``PluginParameter`` and overrides ``type`` with the tool-specific
    ``ToolParameterType`` enum declared below.
    """

    class ToolParameterType(enum.StrEnum):
        """
        Parameter types available to tools.

        Re-exports values of ``PluginParameterType`` (TOOLS_SELECTOR is left out)
        and adds the ARRAY / OBJECT values taken from ``MCPServerParameterType``.
        """

        STRING = PluginParameterType.STRING.value
        NUMBER = PluginParameterType.NUMBER.value
        BOOLEAN = PluginParameterType.BOOLEAN.value
        SELECT = PluginParameterType.SELECT.value
        SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
        FILE = PluginParameterType.FILE.value
        FILES = PluginParameterType.FILES.value
        APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
        MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
        ANY = PluginParameterType.ANY.value
        DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value

        # Array and object parameter types come from the MCP enum
        ARRAY = MCPServerParameterType.ARRAY.value
        OBJECT = MCPServerParameterType.OBJECT.value

        # Deprecated; do not use in new code.
        SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value

        def as_normal_type(self):
            """Delegate to the module-level ``as_normal_type`` helper for this member."""
            return as_normal_type(self)

        def cast_value(self, value: Any):
            """Delegate to ``cast_parameter_value`` to cast *value* for this type."""
            return cast_parameter_value(self, value)

    class ToolParameterForm(Enum):
        """Who supplies the parameter value, and when."""

        SCHEMA = "schema"  # set while the tool is being added
        FORM = "form"  # set before the tool is invoked
        LLM = "llm"  # filled in by the LLM

    type: ToolParameterType = Field(..., description="The type of the parameter")
    human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
    form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
    llm_description: Optional[str] = None
    # Holds the schema for MCP object and array type parameters
    input_schema: Optional[dict] = None

    @classmethod
    def get_simple_instance(
        cls,
        name: str,
        llm_description: str,
        typ: ToolParameterType,
        required: bool,
        options: Optional[list[str]] = None,
    ) -> "ToolParameter":
        """
        Build a minimal LLM-form parameter with empty label and human description.

        :param name: the name of the parameter
        :param llm_description: the description presented to the LLM
        :param typ: the type of the parameter
        :param required: if the parameter is required
        :param options: plain option strings; each one becomes a
            ``PluginParameterOption`` whose value and both labels are that string.
            ``None`` or an empty list yields no options.
        :return: the constructed parameter
        """
        # A missing or empty ``options`` collapses to an empty option list.
        choices = [
            PluginParameterOption(value=text, label=I18nObject(en_US=text, zh_Hans=text))
            for text in (options or [])
        ]

        return cls(
            name=name,
            type=typ,
            required=required,
            form=cls.ToolParameterForm.LLM,
            llm_description=llm_description,
            label=I18nObject(en_US="", zh_Hans=""),
            human_description=I18nObject(en_US="", zh_Hans=""),
            placeholder=None,
            options=choices,
        )

    def init_frontend_parameter(self, value: Any):
        """Delegate to the module-level ``init_frontend_parameter`` helper with this parameter's type."""
        return init_frontend_parameter(self, self.type, value)
        
# 被核心类使用
class ToolIdentity(BaseModel):
    # Identifying metadata of a tool: author, name, display label and owning provider.
    # All four fields below are required (declared with `...`).
    author: str = Field(..., description="The author of the tool")
    name: str = Field(..., description="The name of the tool")
    label: I18nObject = Field(..., description="The label of the tool")
    provider: str = Field(..., description="The provider of the tool")
    # The only optional field; defaults to None when omitted.
    icon: Optional[str] = None

# 被核心类使用
class ToolDescription(BaseModel):
    # Two descriptions of the same tool for two audiences; both are required.
    # `human` is an I18nObject, `llm` is a plain string.
    human: I18nObject = Field(..., description="The description presented to the user")
    llm: str = Field(..., description="The description presented to the LLM")

# 核心类
class ToolEntity(BaseModel):
    """
    Declarative description of a tool: its identity, parameters, optional
    description, optional output schema and a runtime-parameters flag.
    """

    identity: ToolIdentity
    parameters: list[ToolParameter] = Field(default_factory=list)
    description: Optional[ToolDescription] = None
    output_schema: Optional[dict] = None
    has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")

    # pydantic model configuration
    model_config = ConfigDict(protected_namespaces=())

    @field_validator("parameters", mode="before")
    @classmethod
    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
        """
        Replace a falsy raw ``parameters`` input (e.g. ``None``) with an empty list.

        Registered with ``mode="before"``; any truthy input is passed through
        untouched.
        """
        if not v:
            return []
        return v
3.1.2.2. ToolRuntime(传参初始化)

/api/core/tools/__base/tool_runtime.py

class ToolRuntime(BaseModel):
    """
    Meta data of a tool call processing

    Only ``tenant_id`` is required; every other field declares a default.
    """

    # Required: no default is declared.
    tenant_id: str
    # Optional context of the call; each defaults to None.
    tool_id: Optional[str] = None
    invoke_from: Optional[InvokeFrom] = None
    tool_invoke_from: Optional[ToolInvokeFrom] = None
    # default_factory gives every instance its own empty dict.
    credentials: dict[str, Any] = Field(default_factory=dict)
    # Defaults to CredentialType.API_KEY when not given.
    credential_type: CredentialType = Field(default=CredentialType.API_KEY)
    runtime_parameters: dict[str, Any] = Field(default_factory=dict)


class FakeToolRuntime(ToolRuntime):
    """
    ToolRuntime pre-filled with fixed placeholder values, for use in tests.
    """

    def __init__(self):
        # Fixed placeholder values; credential_type is not passed, so the
        # parent's declared default applies.
        fake_fields = {
            "tenant_id": "fake_tenant_id",
            "tool_id": "fake_tool_id",
            "invoke_from": InvokeFrom.DEBUGGER,
            "tool_invoke_from": ToolInvokeFrom.AGENT,
            "credentials": {},
            "runtime_parameters": {},
        }
        super().__init__(**fake_fields)

3.1.3. 验证工具

目前看到的工具凭证基本都没有提供,后续再分析。

3.1.4. 返回工具(至此返回工具的 json 信息)

            ......
            # 不需要验证 直接返回
            if not provider_controller.need_credentials:
                return cast(
                    BuiltinTool,
                    builtin_tool.fork_tool_runtime(
                        runtime=ToolRuntime(
                            tenant_id=tenant_id,
                            credentials={},
                            invoke_from=invoke_from,
                            tool_invoke_from=tool_invoke_from,
                        )
                    ),
                )
            ......
            # 需要验证
            # 区别在于 credentials 是否为空
            return cast(
                BuiltinTool,
                builtin_tool.fork_tool_runtime(
                    runtime=ToolRuntime(
                        tenant_id=tenant_id,
                        credentials=dict(decrypted_credentials),
                        credential_type=CredentialType.of(builtin_provider.credential_type),
                        runtime_parameters={},
                        invoke_from=invoke_from,
                        tool_invoke_from=tool_invoke_from,
                    )
                ),
            )
            ......
3.1.4.1. fork_tool_runtime(返回 json)

entity 定义在父类,会在读取 yaml 文件、创建工具时传入,实际为 pydantic 结构体。

class Tool(ABC):
    """
    The base class of a tool.

    Holds the two objects every tool is built from: its ``entity`` and its
    ``runtime``.
    """

    # ``entity`` is already supplied by the caller when the tool is created
    # from its definition file.
    def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
        """
        Store the tool definition and the runtime it will run with.

        :param entity: the tool's definition
        :param runtime: the runtime metadata for this tool instance
        """
        self.entity, self.runtime = entity, runtime
    ......
class BuiltinTool(Tool):
    """
    Builtin tool

    :param provider: the provider this tool belongs to
    :param kwargs: forwarded unchanged to the parent constructor
    """

    def __init__(self, provider: str, **kwargs):
        super().__init__(**kwargs)
        self.provider = provider

    def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
        """
        Create a new tool of the same class bound to a different runtime.

        The new instance receives ``self.entity.model_copy()``, the given
        ``runtime`` and the same ``provider``; this instance is left untouched.

        :param runtime: the runtime the forked tool should carry
        :return: the new tool
        """
        fork_kwargs = {
            "entity": self.entity.model_copy(),
            "runtime": runtime,
            "provider": self.provider,
        }
        return type(self)(**fork_kwargs)

3.2. 接口工具

class ToolManager:
    ......
    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: ToolProviderType,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
        credential_id: Optional[str] = None,
    ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
        """
        get the tool runtime

        Dispatches on ``provider_type``; each branch resolves the provider and
        returns one of its tools. (The built-in and workflow branches are
        elided in this excerpt.)

        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: invoke from
        :param tool_invoke_from: the tool invoke from
        :param credential_id: the credential id

        :return: the tool
        :raises NotImplementedError: when ``provider_type`` is ``ToolProviderType.APP``
        :raises ToolProviderNotFoundError: when ``provider_type`` matches no branch
        """
        # Built-in tools
        if provider_type == ToolProviderType.BUILT_IN:
        ......
        # API tools: decrypt the stored credentials and fork the tool with a new runtime
        elif provider_type == ToolProviderType.API:
            api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
            encrypter, _ = create_tool_provider_encrypter(
                tenant_id=tenant_id,
                controller=api_provider,
            )
            return api_provider.get_tool(tool_name).fork_tool_runtime(
                runtime=ToolRuntime(
                    tenant_id=tenant_id,
                    credentials=encrypter.decrypt(credentials),
                    invoke_from=invoke_from,
                    tool_invoke_from=tool_invoke_from,
                )
            )
        # Workflow tools
        elif provider_type == ToolProviderType.WORKFLOW:
            ......
        # APP providers are not supported: always raises
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
        # Plugin tools: the provider's tool is returned directly, no fork_tool_runtime call here
        elif provider_type == ToolProviderType.PLUGIN:
            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
        # MCP tools: likewise returned directly from the provider controller
        elif provider_type == ToolProviderType.MCP:
            return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
    ......

3.3. 工作流

3.4. APP(直接抛出异常)

3.5. 插件

3.6. MCP

Logo

智能硬件社区聚焦AI智能硬件技术生态,汇聚嵌入式AI、物联网硬件开发者,打造交流分享平台,同步全国赛事资讯、开展 OPC 核心人才招募,助力技术落地与开发者成长。

更多推荐