dify 源码分析(四)tools
以下是 workflow 执行流程总结:读取配置信息;初始化 trace_manager,用于跟踪任务;初始化 application_generate_entity,用于存放运行所需的信息;初始化 queue_manager,通过队列传输线程中的结果;启动线程,调用 _generate_worker:a. 根据节点信息构建 graph;b. 依次运行各节点,把结果包装成对应的 Event 并放进 queue_manager……
dify 本地源码启动
Ollama 安装部署
dify 智能体实践
dify 源码分析(一)功能概述
dify 源码分析(二)源码结构
dify 源码分析(三)agent
dify 源码分析(四)tools
dify 源码分析(五)chatflow
dify 源码分析(六)event
dify 源码分析(七)ratelimiter
文章目录
- 1. 工具调用概述
- 2. BaseAgentRunner 调用流程
- 3. 工具初始化
1. 工具调用概述
1.1. provider 继承关系

1.2. 工具类继承关系

1.3. 执行逻辑
1.3.1. ToolManager
/api/core/tools/tool_manager.py
ToolManager 管理不同的 *ProviderController
1.3.2. *ProviderController
core/tools/__base/tool_provider.py
core/tools/builtin_tool/provider.py
core/tools/custom_tool/provider.py
core/tools/workflow_as_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/mcp_tool/provider.py
每种 *ProviderController 负责管理对应类型的工具提供者(provider)。
内置工具的 provider 示例:
core/tools/builtin_tool/providers/audio/audio.py
core/tools/builtin_tool/providers/code/code.py
core/tools/builtin_tool/providers/time/time.py
core/tools/builtin_tool/providers/webscraper/webscraper.py
1.3.3. *Tool
BuiltinToolProviderController 负责加载其下的各个工具(*Tool)。
2. BaseAgentRunner 调用流程
class CotAgentRunner(BaseAgentRunner, ABC):
    # Excerpt of dify's ReAct (chain-of-thought) agent runner; "......" marks elided code.
    _is_first_iteration = True
    # Providers for which run() does NOT add "Observation" to the stop words.
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage]
    _agent_scratchpad: list[AgentScratchpadUnit]
    _instruction: str
    _query: str
    _prompt_messages_tools: Sequence[PromptMessageTool]

    def run(
        self,
        message: Message,
        query: str,
        inputs: Mapping[str, str],
    ) -> Generator:
        """
        Run Cot agent application

        Loop: ask the LLM for a thought / action, invoke the chosen tool and record its
        observation, and repeat until the model gives a final answer, no action can be
        extracted, or the iteration limit is reached.

        :param message: the message record the agent thoughts are attached to
        :param query: the user query
        :param inputs: app inputs used to fill the instruction template
        :return: generator of LLMResultChunk (streamed thought text, then the final answer)
        """
        app_generate_entity = self.application_generate_entity
        self._repack_app_generate_entity(app_generate_entity)
        self._init_react_state(query)

        trace_manager = app_generate_entity.trace_manager

        # check model mode
        # Add "Observation" as a stop word unless the provider is in _ignore_observation_providers
        # (presumably so the model stops before writing its own observation — confirm).
        if "Observation" not in app_generate_entity.model_conf.stop:
            if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
                app_generate_entity.model_conf.stop.append("Observation")

        app_config = self.app_config
        assert app_config.agent

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template or ""
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        # The configured limit is capped at 99; the extra (+1) round is the last one,
        # in which all tools are removed (see the loop below).
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1

        # convert tools into ModelRuntime Tool format
        # Initialise the tool set: _init_prompt_tools merges every available tool
        # (app-configured tools + dataset tools) and returns tool_instances (tool objects
        # used for the actual invocation) and prompt_messages_tools (the tool list the
        # model can understand).
        tool_instances, prompt_messages_tools = self._init_prompt_tools()
        self._prompt_messages_tools = prompt_messages_tools

        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""

        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            ......

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            # Only the tool-call branch below sets this back to True.
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                self._prompt_messages_tools = []

            message_file_ids: list[str] = []

            # 1. Create the agent-thought record for this round (saved in the database).
            agent_thought_id = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # From the second round on, the thought event is published before calling the LLM;
            # the first round publishes it after invoke_llm (see below).
            if iteration_step > 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            # recalc llm max tokens
            # 2. Call the LLM to produce a thought or a tool-call instruction.
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            # Model output. No native tools are passed (tools=[]); presumably the tool list
            # reaches the model through the prompt built by _organize_prompt_messages — confirm.
            chunks = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=[],
                stop=app_generate_entity.model_conf.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            usage_dict: dict[str, Optional[LLMUsage]] = {}
            # 3. Parse the LLM output, extracting thoughts and tool-call instructions.
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            # Record the details of this round (thought text, action instruction, ...).
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):  # the chunk is a tool-call instruction
                    action = chunk
                    # detect action
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += json.dumps(chunk.model_dump())
                    scratchpad.action_str = json.dumps(chunk.model_dump())
                    scratchpad.action = action
                else:  # the chunk is natural-language thought text
                    assert scratchpad.agent_response is not None
                    scratchpad.agent_response += chunk
                    assert scratchpad.thought is not None
                    scratchpad.thought += chunk
                    # Only thought text is streamed to the caller; action chunks are not.
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint="",
                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                    )

            assert scratchpad.thought is not None
            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            self._agent_scratchpad.append(scratchpad)

            # get llm usage
            ......

            self.save_agent_thought(......)

            # 4. Decide whether a tool needs to be called.
            if not scratchpad.is_final():
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ""
            else:
                # The agent decided to output the final answer.
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except TypeError:
                        # json.dumps raises TypeError for non-serialisable values; fall back to str().
                        final_answer = f"{scratchpad.action.action_input}"
                else:
                    # The agent decided to call a tool: keep looping for another iteration.
                    function_call_state = True
                    # action is tool call, invoke tool
                    # Invoke the tool and collect its result.
                    tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                        action=scratchpad.action,
                        tool_instances=tool_instances,
                        message_file_ids=message_file_ids,
                        trace_manager=trace_manager,
                    )
                    # Record the tool result as this round's observation.
                    scratchpad.observation = tool_invoke_response
                    scratchpad.agent_response = tool_invoke_response

                    # message_file_ids now also holds the ids of files the tool returned
                    # (appended by _handle_invoke_action).
                    self.save_agent_thought(
                        agent_thought_id=agent_thought_id,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought or "",
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict["usage"],
                    )

                    self.queue_manager.publish(
                        QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                    )

                # update prompt tool message
                for prompt_tool in self._prompt_messages_tools:
                    self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # Emit the final answer as one last chunk, carrying the accumulated usage.
        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
            ),
            system_fingerprint="",
        )

        # save agent thought
        self.save_agent_thought(......)
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )
2.1. _init_prompt_tools(初始化并转换工具信息)
class BaseAgentRunner(AppRunner):
    ......

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools

        Collect every tool available to the agent (app-configured tools and dataset
        retriever tools) and convert each one for the model.

        :return: (tool_instances, prompt_messages_tools) — tool_instances maps a tool name
            to the tool object used for invocation; prompt_messages_tools is the list of
            tool descriptions handed to the model.
        """
        tool_instances: dict[str, Tool] = {}
        prompt_messages_tools: list[PromptMessageTool] = []

        # Convert every configured tool into a prompt-message tool.
        # The iterable parses as `(self.app_config.agent.tools or []) if self.app_config.agent else []`.
        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                # Any conversion error silently drops this tool from both results.
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        # Dataset retriever tools.
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.entity.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    ......
2.1.1. _convert_tool_to_prompt_message_tool
先通过 get_agent_tool_runtime 加载智能体工具,
再合并工具参数,最终生成 PromptMessageTool。
class BaseAgentRunner(AppRunner):
    ......

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool

        Load the runtime of an agent tool and build the PromptMessageTool (name,
        description and an object-typed parameter schema) that is shown to the model.

        :param tool: the agent tool configuration
        :return: (prompt message tool, tool runtime instance)
        """
        # Convert the tool into a prompt-message tool.
        # This is the core function of agent tool calling.
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        assert tool_entity.entity.description
        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        # Merge the tool parameters.
        parameters = tool_entity.get_merged_runtime_parameters()
        for parameter in parameters:
            # Skip every parameter whose form is not LLM; only LLM-form parameters are exposed.
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            # File-typed parameters are not exposed to the model.
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options] if parameter.options else []

            # Use the parameter's own input_schema when present; otherwise build a minimal entry.
            message_tool.parameters["properties"][parameter.name] = (
                {
                    "type": parameter_type,
                    "description": parameter.llm_description or "",
                }
                if parameter.input_schema is None
                else parameter.input_schema
            )

            # NOTE(review): when input_schema was used above, this writes "enum" into that same
            # dict object (no copy is made), mutating the parameter's schema in place.
            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        # Return both values.
        return message_tool, tool_entity

    ......
- PromptMessageTool 实际上就是一段 JSON 描述:name、description,以及类似 JSON Schema 的 parameters(type / properties / required)。
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
    # Excerpt: only the construction of the PromptMessageTool skeleton is shown.
    ......
    message_tool = PromptMessageTool(
        name=tool.tool_name,
        description=tool_entity.entity.description.llm,
        parameters={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )
    # The tool name (name) and description (description) tell the model "what this tool can do";
    # the parameter definition (parameters) tells the model "which arguments to pass when calling it"
    # (for example the query time range start_date / end_date).
    # Example of a filled-in tool:
    # name="dataset_search",
    # description="Query product sales data for a given time range",
    # parameters={
    #     "type": "object",
    #     "properties": {
    #         "start_date": {
    #             "type": "string",
    #             "description": "Start date, format YYYY-MM-DD"
    #         },
    #         "end_date": {
    #             "type": "string",
    #             "description": "End date, format YYYY-MM-DD"
    #         }
    #     },
    #     "required": [
    #         "start_date",
    #         "end_date"
    #     ]
    # }
2.1.1.1. get_agent_tool_runtime(智能体工具逻辑)
core/tools/tool_manager.py
class ToolManager:
    ......

    # Agent tools
    @classmethod
    def get_agent_tool_runtime(
        cls,
        tenant_id: str,
        app_id: str,
        agent_tool: AgentToolEntity,  # the workflow variant takes a different parameter here
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        variable_pool: Optional[VariablePool] = None,
    ) -> Tool:
        """
        get the agent tool runtime

        Load the tool runtime, convert and decrypt the configured tool parameters, and
        merge them into the tool's runtime parameters.

        :param tenant_id: the tenant id
        :param app_id: the app id (used in the "AGENT.<app_id>" identity for decryption)
        :param agent_tool: the agent tool configuration
        :param invoke_from: invoke from
        :param variable_pool: optional variable pool passed to parameter conversion
        :return: the tool, with runtime.runtime_parameters updated
        :raises ValueError: if the tool has no runtime or no runtime parameters
        """
        # Get the tool runtime for the agent.
        tool_entity = cls.get_tool_runtime(
            provider_type=agent_tool.provider_type,
            provider_id=agent_tool.provider_id,
            tool_name=agent_tool.tool_name,
            tenant_id=tenant_id,
            invoke_from=invoke_from,
            tool_invoke_from=ToolInvokeFrom.AGENT,
            credential_id=agent_tool.credential_id,
        )
        # NOTE(review): this initial value is overwritten below before it is ever read.
        runtime_parameters = {}
        # Parameter definitions of the tool (get_merged_runtime_parameters is not analysed here).
        parameters = tool_entity.get_merged_runtime_parameters()
        # Convert the configured agent_tool.tool_parameters against those definitions.
        runtime_parameters = cls._convert_tool_parameters_type(
            parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
        )
        # decrypt runtime parameters
        encryption_manager = ToolParameterConfigurationManager(
            tenant_id=tenant_id,
            tool_runtime=tool_entity,
            provider_name=agent_tool.provider_id,
            provider_type=agent_tool.provider_type,
            identity_id=f"AGENT.{app_id}",
        )
        runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
        # The workflow variant does not perform this check.
        if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None:
            raise ValueError("runtime not found or runtime parameters not found")
        tool_entity.runtime.runtime_parameters.update(runtime_parameters)
        return tool_entity

    # Workflow tools
    @classmethod
    def get_workflow_tool_runtime(......)
        ......
2.1.1.2. get_tool_runtime(工具加载,区分类型)
- 内置工具
- 接口工具
- 工作流
- APP(尚未实现,直接抛出 NotImplementedError)
- 插件
- MCP

core/tools/tool_manager.py
class ToolManager:
    ......

    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: ToolProviderType,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
        credential_id: Optional[str] = None,
    ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
        """
        get the tool runtime

        Dispatch on provider_type (built-in, API, workflow, app, plugin, MCP) and return
        the requested tool for this tenant.

        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: invoke from
        :param tool_invoke_from: the tool invoke from
        :param credential_id: the credential id (only read for plugin-backed built-in providers)
        :return: the tool
        :raises ToolProviderNotFoundError: if the provider, the tool or the credential record is
            not found, or the provider type is not handled
        :raises NotImplementedError: for ToolProviderType.APP
        """
        # Built-in tools
        if provider_type == ToolProviderType.BUILT_IN:
            # check if the builtin tool need credentials
            # Look up the built-in provider controller by provider id and tenant id.
            provider_controller = cls.get_builtin_provider(provider_id, tenant_id)

            # Get the tool from the provider by name.
            builtin_tool = provider_controller.get_tool(tool_name)
            if not builtin_tool:
                raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

            # No credentials needed: return the tool with an empty-credential runtime.
            if not provider_controller.need_credentials:
                return cast(
                    BuiltinTool,
                    builtin_tool.fork_tool_runtime(
                        runtime=ToolRuntime(
                            tenant_id=tenant_id,
                            credentials={},
                            invoke_from=invoke_from,
                            tool_invoke_from=tool_invoke_from,
                        )
                    ),
                )

            # Credentials are required: resolve the BuiltinToolProvider record that stores them.
            builtin_provider = None
            if isinstance(provider_controller, PluginToolProviderController):
                provider_id_entity = ToolProviderID(provider_id)
                # get specific credentials
                if is_valid_uuid(credential_id):
                    try:
                        builtin_provider_stmt = select(BuiltinToolProvider).where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            BuiltinToolProvider.id == credential_id,
                        )
                        builtin_provider = db.session.scalar(builtin_provider_stmt)
                    except Exception as e:
                        builtin_provider = None
                        logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
                    # if the provider has been deleted, raise an error
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

                # fallback to the default provider
                # Reached only without a valid credential_id: match the full provider id or the
                # bare provider name, preferring the is_default record, then the oldest one.
                if builtin_provider is None:
                    # use the default provider
                    builtin_provider = (
                        db.session.query(BuiltinToolProvider)
                        .where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            (BuiltinToolProvider.provider == str(provider_id_entity))
                            | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                        )
                        .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                        .first()
                    )
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
            else:
                # Non-plugin provider: same ordering, matched on provider_id only.
                builtin_provider = (
                    db.session.query(BuiltinToolProvider)
                    .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                    .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                    .first()
                )

                if builtin_provider is None:
                    raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

            encrypter, cache = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[
                    x.to_basic_provider_config()
                    for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
                ],
                cache=ToolProviderCredentialsCache(
                    tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
                ),
            )

            # decrypt the credentials
            decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

            # check if the credentials is expired
            # expires_at == -1 skips the check; otherwise refresh once the credentials are
            # within 60 seconds of expiring.
            if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
                # TODO: circular import
                from services.tools.builtin_tools_manage_service import BuiltinToolManageService

                # refresh the credentials
                tool_provider = ToolProviderID(provider_id)
                provider_name = tool_provider.provider_name
                redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
                system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
                oauth_handler = OAuthHandler()
                # refresh the credentials
                refreshed_credentials = oauth_handler.refresh_credentials(
                    tenant_id=tenant_id,
                    user_id=builtin_provider.user_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=provider_name,
                    redirect_uri=redirect_uri,
                    system_credentials=system_credentials or {},
                    credentials=decrypted_credentials,
                )
                # update the credentials
                # Persist the re-encrypted credentials and the new expiry, then drop the cached copy.
                builtin_provider.encrypted_credentials = (
                    TypeAdapter(dict[str, Any])
                    .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
                    .decode("utf-8")
                )
                builtin_provider.expires_at = refreshed_credentials.expires_at
                db.session.commit()
                decrypted_credentials = refreshed_credentials.credentials
                cache.delete()

            # Return the tool bound to a runtime carrying the decrypted credentials.
            return cast(
                BuiltinTool,
                builtin_tool.fork_tool_runtime(
                    runtime=ToolRuntime(
                        tenant_id=tenant_id,
                        credentials=dict(decrypted_credentials),
                        credential_type=CredentialType.of(builtin_provider.credential_type),
                        runtime_parameters={},
                        invoke_from=invoke_from,
                        tool_invoke_from=tool_invoke_from,
                    )
                ),
            )

        # API tools
        elif provider_type == ToolProviderType.API:
            api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
            encrypter, _ = create_tool_provider_encrypter(
                tenant_id=tenant_id,
                controller=api_provider,
            )
            return api_provider.get_tool(tool_name).fork_tool_runtime(
                runtime=ToolRuntime(
                    tenant_id=tenant_id,
                    credentials=encrypter.decrypt(credentials),
                    invoke_from=invoke_from,
                    tool_invoke_from=tool_invoke_from,
                )
            )
        # Workflow tools
        elif provider_type == ToolProviderType.WORKFLOW:
            workflow_provider_stmt = select(WorkflowToolProvider).where(
                WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
            )
            workflow_provider = db.session.scalar(workflow_provider_stmt)

            if workflow_provider is None:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

            controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
            controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
            if controller_tools is None or len(controller_tools) == 0:
                raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

            # NOTE(review): get_tools() is called a second time here; controller_tools[0] could be reused.
            return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
                runtime=ToolRuntime(
                    tenant_id=tenant_id,
                    credentials={},
                    invoke_from=invoke_from,
                    tool_invoke_from=tool_invoke_from,
                )
            )
        # App tools (not implemented)
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
        # Plugin tools
        # NOTE(review): the plugin and MCP tools are returned without fork_tool_runtime, so
        # invoke_from / tool_invoke_from / credential_id are not applied in these two branches.
        elif provider_type == ToolProviderType.PLUGIN:
            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
        # MCP tools
        elif provider_type == ToolProviderType.MCP:
            return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")

    ......
2.1.1.3. get_merged_runtime_parameters(尚未分析)
后续分析
2.1.2. _convert_dataset_retriever_tool_to_prompt_message_tool
后续分析
2.2. handle_react_stream_output(模型返回分析)
…
2.3. _handle_invoke_action(调用工具)
class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage]
    _agent_scratchpad: list[AgentScratchpadUnit]
    _instruction: str
    _query: str
    _prompt_messages_tools: Sequence[PromptMessageTool]

    def run(......) -> Generator:
        """
        Run Cot agent application
        """
        ......

    def _handle_invoke_action(
        self,
        action: AgentScratchpadUnit.Action,
        tool_instances: Mapping[str, Tool],
        message_file_ids: list[str],
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action

        Look up the tool named by the action, invoke it through ToolEngine.agent_invoke,
        and publish every message file the tool produced.

        :param action: action
        :param tool_instances: tool instances
        :param message_file_ids: message file ids; ids of files produced by the tool are appended in place
        :param trace_manager: trace manager
        :return: observation, meta
        """
        # action is tool call, invoke tool
        # Resolve the tool instance.
        tool_call_name = action.action_name
        tool_call_args = action.action_input
        tool_instance = tool_instances.get(tool_call_name)

        if not tool_instance:
            # Unknown tool: the error text is returned as the observation instead of raising.
            answer = f"there is not a tool named {tool_call_name}"
            return answer, ToolInvokeMeta.error_instance(answer)

        # Parse the tool arguments (e.g. a JSON string into a dict);
        # a string that is not valid JSON is passed on unchanged.
        if isinstance(tool_call_args, str):
            try:
                tool_call_args = json.loads(tool_call_args)
            except json.JSONDecodeError:
                pass

        # invoke tool
        tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
            tool=tool_instance,
            tool_parameters=tool_call_args,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            message=self.message,
            invoke_from=self.application_generate_entity.invoke_from,
            agent_tool_callback=self.agent_callback,
            trace_manager=trace_manager,
        )

        # publish files
        # Publish the files returned by the tool (e.g. charts).
        for message_file_id in message_files:
            # publish message file
            self.queue_manager.publish(
                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
            message_file_ids.append(message_file_id)

        return tool_invoke_response, tool_invoke_meta
2.3.1. ToolEngine.agent_invoke(工具执行)
/api/core/tools/tool_engine.py
class ToolEngine:
    """
    Tool runtime engine take care of the tool executions.
    """

    @staticmethod
    def agent_invoke(
        tool: Tool,
        tool_parameters: Union[str, dict],
        user_id: str,
        tenant_id: str,
        message: Message,
        invoke_from: InvokeFrom,
        agent_tool_callback: DifyAgentCallbackHandler,
        trace_manager: Optional[TraceQueueManager] = None,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> tuple[str, list[str], ToolInvokeMeta]:
        """
        Agent invokes the tool with the given arguments.

        :return: (plain-text observation for the LLM, ids of the created message files, invoke meta)
        :raises ValueError: if a string tool_parameters cannot be turned into a dict
            (raised before the try block, so it is not converted into an error response)
        """
        # check if arguments is a string
        # Normalise the arguments.
        if isinstance(tool_parameters, str):
            # check if this tool has only one parameter
            parameters = [
                parameter
                for parameter in tool.get_runtime_parameters()
                if parameter.form == ToolParameter.ToolParameterForm.LLM
            ]
            # A single LLM-form parameter receives the whole string; otherwise try to parse JSON.
            if parameters and len(parameters) == 1:
                tool_parameters = {parameters[0].name: tool_parameters}
            else:
                with contextlib.suppress(Exception):
                    tool_parameters = json.loads(tool_parameters)
                if not isinstance(tool_parameters, dict):
                    raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")

        try:
            # hit the callback handler
            # Notify the callback handler (start).
            agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)

            # Execute the tool: a generator of messages, followed by one ToolInvokeMeta.
            messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
            invocation_meta_dict: dict[str, ToolInvokeMeta] = {}

            # Split the stream: store the ToolInvokeMeta under "meta" and yield every other message.
            # The loop variable `message` is local to this function and does not rebind the
            # outer `message` argument.
            def message_callback(
                invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
            ):
                for message in messages:
                    if isinstance(message, ToolInvokeMeta):
                        invocation_meta_dict["meta"] = message
                    else:
                        yield message

            # Transform the data returned by the tool.
            messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
                messages=message_callback(invocation_meta_dict, messages),
                user_id=user_id,
                tenant_id=tenant_id,
                conversation_id=message.conversation_id,
            )

            # Materialise the whole stream; invocation_meta_dict["meta"] is read only after this.
            message_list = list(messages)

            # extract binary data from tool invoke message
            binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
            # create message file
            message_files = ToolEngine._create_message_files(
                tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
            )

            # Flatten the messages into the plain-text observation.
            plain_text = ToolEngine._convert_tool_response_to_str(message_list)

            # _invoke yields its ToolInvokeMeta from a finally block, so "meta" is expected here.
            meta = invocation_meta_dict["meta"]

            # hit the callback handler
            # Notify the callback handler (end).
            # NOTE(review): Tool.invoke calls tool_parameters.update(...) on the dict it receives,
            # so tool_parameters may by now also contain the tool's runtime parameters.
            agent_tool_callback.on_tool_end(
                tool_name=tool.entity.identity.name,
                tool_inputs=tool_parameters,
                tool_outputs=plain_text,
                message_id=message.id,
                trace_manager=trace_manager,
            )

            # transform tool invoke message to get LLM friendly message
            return plain_text, message_files, meta
        except ToolProviderCredentialValidationError as e:
            ......
        except Exception as e:
            # Any other failure becomes an "unknown error" observation instead of propagating.
            error_response = f"unknown error: {e}"
            agent_tool_callback.on_tool_error(e)
            return error_response, [], ToolInvokeMeta.error_instance(error_response)

    @staticmethod
    def _invoke(
        tool: Tool,
        tool_parameters: dict,
        user_id: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
        """
        Invoke the tool with the given arguments.

        Yields every message produced by tool.invoke, then a ToolInvokeMeta carrying the
        elapsed time (and the error text if the tool raised).

        :raises ToolEngineInvokeError: wrapping the meta when the tool raises
        """
        started_at = datetime.now(UTC)
        meta = ToolInvokeMeta(
            time_cost=0.0,
            error=None,
            tool_config={
                "tool_name": tool.entity.identity.name,
                "tool_provider": tool.entity.identity.provider,
                "tool_provider_type": tool.tool_provider_type().value,
                "tool_parameters": deepcopy(tool.runtime.runtime_parameters),
                "tool_icon": tool.entity.identity.icon,
            },
        )
        try:
            # The core call: Tool.invoke on the tool base class.
            yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
        except Exception as e:
            meta.error = str(e)
            raise ToolEngineInvokeError(meta)
        finally:
            # Runs on both the success and the error path, so the meta is yielded in either case.
            ended_at = datetime.now(UTC)
            meta.time_cost = (ended_at - started_at).total_seconds()
            yield meta
2.3.1.1. on_tool_start
2.3.1.2. tool.invoke(调用子类的重载函数)
class Tool(ABC):
    """
    The base class of a tool
    """

    ......

    def invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage]:
        """
        Merge the runtime parameters into the call arguments, convert the arguments to
        their declared types, and delegate to the subclass's _invoke.

        A single ToolInvokeMessage or a list of them is wrapped into a generator; any other
        result is returned as-is (presumably already a generator — confirm per subclass).
        """
        # Runtime parameters (e.g. those merged in by ToolManager when the runtime was loaded)
        # overwrite same-named arguments.
        # NOTE(review): update() mutates the caller's tool_parameters dict in place.
        if self.runtime and self.runtime.runtime_parameters:
            tool_parameters.update(self.runtime.runtime_parameters)

        # try parse tool parameters into the correct type
        tool_parameters = self._transform_tool_parameters_type(tool_parameters)

        # Call the subclass override.
        result = self._invoke(
            user_id=user_id,
            tool_parameters=tool_parameters,
            conversation_id=conversation_id,
            app_id=app_id,
            message_id=message_id,
        )

        if isinstance(result, ToolInvokeMessage):

            def single_generator() -> Generator[ToolInvokeMessage, None, None]:
                yield result

            return single_generator()
        elif isinstance(result, list):

            def generator() -> Generator[ToolInvokeMessage, None, None]:
                yield from result

            return generator()
        else:
            return result
2.3.1.3. transform_tool_invoke_messages
解析工具返回的数据。
对于文件、图片类消息,还支持下载。
此处先按默认情况,将返回数据视为 JSON 数据进行分析。
2.3.1.4. ToolEngine._extract_tool_response_binary_and_text
从工具调用消息中提取二进制数据
2.3.1.5. ToolEngine._create_message_files
创建文件
2.3.1.6. ToolEngine._convert_tool_response_to_str
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
    """
    Handle tool response

    Flatten tool invoke messages into a single plain-text observation for the LLM:
    text is copied verbatim, links and images become short hints for the model,
    JSON payloads are serialised (non-ASCII kept), and anything else uses str().
    """
    kinds = ToolInvokeMessage.MessageType
    pieces: list[str] = []
    for item in tool_response:
        kind = item.type
        if kind == kinds.TEXT:
            pieces.append(cast(ToolInvokeMessage.TextMessage, item.message).text)
        elif kind == kinds.LINK:
            link_text = cast(ToolInvokeMessage.TextMessage, item.message).text
            pieces.append(f"result link: {link_text}." + " please tell user to check it.")
        elif kind in {kinds.IMAGE_LINK, kinds.IMAGE}:
            pieces.append(
                "image has been created and sent to user already, "
                + "you do not need to create it, just tell the user to check it now."
            )
        elif kind == kinds.JSON:
            payload = cast(ToolInvokeMessage.JsonMessage, item.message).json_object
            pieces.append(json.dumps(safe_json_value(payload), ensure_ascii=False))
        else:
            pieces.append(str(item.message))
    # Join once at the end instead of growing a string inside the loop.
    return "".join(pieces)
2.3.1.7. on_tool_end
2.3.1.8. on_tool_error
3. 工具初始化
3.1. 内置工具(初始化)
class ToolManager:
    ......

    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: ToolProviderType,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
        credential_id: Optional[str] = None,
    ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
        """
        get the tool runtime

        Dispatch on provider_type and return the requested tool for this tenant; this
        listing focuses on the built-in branch (API and workflow branches are elided).

        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: invoke from
        :param tool_invoke_from: the tool invoke from
        :param credential_id: the credential id (only read for plugin-backed built-in providers)
        :return: the tool
        :raises ToolProviderNotFoundError: if the provider, the tool or the credential record is
            not found, or the provider type is not handled
        :raises NotImplementedError: for ToolProviderType.APP
        """
        # Built-in tools
        if provider_type == ToolProviderType.BUILT_IN:
            # check if the builtin tool need credentials
            # Look up the built-in provider controller by provider id and tenant id.
            provider_controller = cls.get_builtin_provider(provider_id, tenant_id)

            # Get the tool from the provider by name.
            builtin_tool = provider_controller.get_tool(tool_name)
            if not builtin_tool:
                raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

            # No credentials needed: return the tool with an empty-credential runtime.
            if not provider_controller.need_credentials:
                return cast(
                    BuiltinTool,
                    builtin_tool.fork_tool_runtime(
                        runtime=ToolRuntime(
                            tenant_id=tenant_id,
                            credentials={},
                            invoke_from=invoke_from,
                            tool_invoke_from=tool_invoke_from,
                        )
                    ),
                )

            # Credentials are required: resolve the BuiltinToolProvider record that stores them.
            builtin_provider = None
            if isinstance(provider_controller, PluginToolProviderController):
                provider_id_entity = ToolProviderID(provider_id)
                # get specific credentials
                if is_valid_uuid(credential_id):
                    try:
                        builtin_provider_stmt = select(BuiltinToolProvider).where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            BuiltinToolProvider.id == credential_id,
                        )
                        builtin_provider = db.session.scalar(builtin_provider_stmt)
                    except Exception as e:
                        builtin_provider = None
                        logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
                    # if the provider has been deleted, raise an error
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")

                # fallback to the default provider
                # Reached only without a valid credential_id: match the full provider id or the
                # bare provider name, preferring the is_default record, then the oldest one.
                if builtin_provider is None:
                    # use the default provider
                    builtin_provider = (
                        db.session.query(BuiltinToolProvider)
                        .where(
                            BuiltinToolProvider.tenant_id == tenant_id,
                            (BuiltinToolProvider.provider == str(provider_id_entity))
                            | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                        )
                        .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                        .first()
                    )
                    if builtin_provider is None:
                        raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
            else:
                # Non-plugin provider: same ordering, matched on provider_id only.
                builtin_provider = (
                    db.session.query(BuiltinToolProvider)
                    .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
                    .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
                    .first()
                )

                if builtin_provider is None:
                    raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

            encrypter, cache = create_provider_encrypter(
                tenant_id=tenant_id,
                config=[
                    x.to_basic_provider_config()
                    for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
                ],
                cache=ToolProviderCredentialsCache(
                    tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
                ),
            )

            # decrypt the credentials
            decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)

            # check if the credentials is expired
            # expires_at == -1 skips the check; otherwise refresh once the credentials are
            # within 60 seconds of expiring.
            if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
                # TODO: circular import
                from services.tools.builtin_tools_manage_service import BuiltinToolManageService

                # refresh the credentials
                tool_provider = ToolProviderID(provider_id)
                provider_name = tool_provider.provider_name
                redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
                system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
                oauth_handler = OAuthHandler()
                # refresh the credentials
                refreshed_credentials = oauth_handler.refresh_credentials(
                    tenant_id=tenant_id,
                    user_id=builtin_provider.user_id,
                    plugin_id=tool_provider.plugin_id,
                    provider=provider_name,
                    redirect_uri=redirect_uri,
                    system_credentials=system_credentials or {},
                    credentials=decrypted_credentials,
                )
                # update the credentials
                # Persist the re-encrypted credentials and the new expiry, then drop the cached copy.
                builtin_provider.encrypted_credentials = (
                    TypeAdapter(dict[str, Any])
                    .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
                    .decode("utf-8")
                )
                builtin_provider.expires_at = refreshed_credentials.expires_at
                db.session.commit()
                decrypted_credentials = refreshed_credentials.credentials
                cache.delete()

            # Return the tool bound to a runtime carrying the decrypted credentials.
            return cast(
                BuiltinTool,
                builtin_tool.fork_tool_runtime(
                    runtime=ToolRuntime(
                        tenant_id=tenant_id,
                        credentials=dict(decrypted_credentials),
                        credential_type=CredentialType.of(builtin_provider.credential_type),
                        runtime_parameters={},
                        invoke_from=invoke_from,
                        tool_invoke_from=tool_invoke_from,
                    )
                ),
            )

        # API tools
        elif provider_type == ToolProviderType.API:
            ......
        # Workflow tools
        elif provider_type == ToolProviderType.WORKFLOW:
            ......
        # App tools (not implemented)
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
        # Plugin tools
        # NOTE(review): the plugin and MCP tools are returned without fork_tool_runtime, so
        # invoke_from / tool_invoke_from / credential_id are not applied in these two branches.
        elif provider_type == ToolProviderType.PLUGIN:
            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
        # MCP tools
        elif provider_type == ToolProviderType.MCP:
            return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")

    ......
3.1.1. get_builtin_provider(获取工具提供者 provider)
class ToolManager:
    _builtin_provider_lock = Lock()
    # provider name -> controller of every builtin (hard-coded) provider
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    # ...... (other members omitted in this excerpt)

    @classmethod
    def get_builtin_provider(
        cls, provider: str, tenant_id: str
    ) -> BuiltinToolProviderController | PluginToolProviderController:
        """
        Get the controller for a builtin provider, falling back to plugin providers.

        :param provider: the provider id / name to look up
        :param tenant_id: the id of the tenant (used for the plugin lookup)
        :return: the hard-coded builtin controller registered under ``provider``,
            otherwise the plugin provider controller for it
        :raises ToolProviderNotFoundError: propagated from ``get_plugin_provider``
            when ``provider`` is neither builtin nor a known plugin provider
        """
        # Gate on the "fully loaded" flag, not on the dict being non-empty: the
        # dict is filled incrementally while the providers are loaded, so a
        # non-empty dict does not mean every builtin provider is registered yet.
        # The loader takes _builtin_provider_lock and returns at once when loaded.
        if not cls._builtin_providers_loaded:
            cls.load_hardcoded_providers_cache()

        hardcoded = cls._hardcoded_providers.get(provider)
        if hardcoded is not None:
            return hardcoded

        # Not a builtin provider: resolve it as a plugin provider. That lookup
        # either returns a controller or raises, so no further fallback exists.
        return cls.get_plugin_provider(provider, tenant_id)

    # ......
3.1.1.1. load_hardcoded_providers_cache(通过硬编码加载)
class ToolManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    # ...... (other members omitted in this excerpt)

    @classmethod
    def load_hardcoded_providers_cache(cls):
        """Fill the builtin-provider cache by draining ``list_hardcoded_providers()``."""
        for _ in cls.list_hardcoded_providers():
            pass

    @classmethod
    def list_hardcoded_providers(cls):
        """
        Yield every builtin provider, loading them from disk on first use.

        After loading has completed, a snapshot (list copy) of the cached
        controllers is yielded. Otherwise loading runs under
        ``_builtin_provider_lock``. The loaded flag is re-checked after the lock
        is acquired, so concurrent callers do not load twice.

        Note: while loading, the lock is held across every ``yield``, so callers
        should drain this generator promptly.
        """
        if cls._builtin_providers_loaded:
            yield from list(cls._hardcoded_providers.values())
            return
        with cls._builtin_provider_lock:
            if cls._builtin_providers_loaded:
                yield from list(cls._hardcoded_providers.values())
                return
            yield from cls._list_hardcoded_providers()

    @classmethod
    def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
        """
        Load, register and yield every builtin provider under ``builtin_tool/providers``.

        Each sub-directory ``<name>`` (names starting with ``__`` are skipped) must
        contain ``<name>.py`` defining a BuiltinToolProviderController subclass.
        The controller is stored in ``_hardcoded_providers`` under its identity
        name. The label of each of its tools is stored in
        ``_builtin_tools_labels``, keyed by tool name. A provider that fails to
        load is logged and skipped. The loaded flag is set only after the whole
        directory has been processed.
        """
        # api/core/tools/builtin_tool/providers -- computed once, not per entry
        providers_root = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")
        for provider_path in listdir(providers_root):
            # skip private entries such as __init__.py / __pycache__
            if provider_path.startswith("__"):
                continue
            provider_dir = path.join(providers_root, provider_path)
            if not path.isdir(provider_dir):
                continue
            try:
                # import <name>/<name>.py and pick its provider controller class
                provider_class = load_single_subclass_from_source(
                    module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
                    script_path=path.join(provider_dir, f"{provider_path}.py"),
                    parent_type=BuiltinToolProviderController,
                )
                provider: BuiltinToolProviderController = provider_class()
                cls._hardcoded_providers[provider.entity.identity.name] = provider
                for tool in provider.get_tools():
                    cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
                yield provider
            except Exception:
                logger.exception("load builtin provider %s", provider_path)
        cls._builtin_providers_loaded = True

    # ......

3.1.1.1.1. load_single_subclass_from_source(从文件中加载)
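从文件中加载:该函数位于 /api/core/helper/module_import_helper.py,借助 importlib 按 module_name 和 script_path 导入对应的 py 模块,并返回其中继承自 parent_type 的子类(调用方拿到类后直接实例化)。内部细节需要进一步详细分析。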
3.1.1.1.2. BuiltinToolProviderController 基类
/api/core/tools/builtin_tool/provider.py
class BuiltinToolProviderController(ToolProviderController):
    tools: list[BuiltinTool]

    def __init__(self, **data: Any) -> None:
        """
        Build the controller from ``providers/<provider>/<provider>.yaml`` and load its tools.

        ``<provider>`` is the last component of the concrete subclass's module name.

        :raises ToolProviderNotFoundError: if the provider yaml cannot be loaded
        """
        self.tools = []
        provider = self.__class__.__module__.split(".")[-1]
        yaml_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, f"{provider}.yaml")
        try:
            provider_yaml = load_yaml_file(yaml_path, ignore_error=False)
        except Exception as e:
            raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") from e

        # Credential / OAuth sections (the article notes they are mostly empty for
        # builtin providers). `credentials_for_provider` may be missing OR explicitly
        # null in the yaml; both mean "no credentials" (iterating None would raise).
        credentials = provider_yaml.get("credentials_for_provider") or {}
        credentials_schema = []
        for credential_name, credential_dict in credentials.items():
            # every credential entry carries its own yaml key as "name"
            credential_dict["name"] = credential_name
            credentials_schema.append(credential_dict)

        oauth_schema = None
        oauth_yaml = provider_yaml.get("oauth_schema")
        if oauth_yaml is not None:
            oauth_schema = OAuthSchema(
                client_schema=oauth_yaml.get("client_schema", []),
                credentials_schema=oauth_yaml.get("credentials_schema", []),
            )

        super().__init__(
            entity=ToolProviderEntity(
                identity=provider_yaml["identity"],
                credentials_schema=credentials_schema,
                oauth_schema=oauth_schema,
            ),
        )
        self._load_tools()

    def _load_tools(self):
        """
        Instantiate every tool declared under ``providers/<provider>/tools``.

        Each ``<tool>.yaml`` (names starting with ``__`` are skipped) is paired with
        the ``<tool>.py`` next to it, which must define a BuiltinTool subclass.
        The yaml content (with ``identity.provider`` forced to this provider) becomes
        the ToolEntity. The runtime is a placeholder with an empty tenant id.
        """
        provider = self.entity.identity.name
        tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
        tool_files = [f for f in listdir(tool_path) if f.endswith(".yaml") and not f.startswith("__")]
        tools = []
        for tool_file in tool_files:
            tool_name = tool_file.split(".")[0]
            tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
            # The .py file lives next to the .yaml in `tool_path`. The previous path
            # joined an extra "builtin_tool" segment onto this file's directory
            # (already builtin_tool/), i.e. builtin_tool/builtin_tool/providers/...
            assistant_tool_class: type = load_single_subclass_from_source(
                module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
                script_path=path.join(tool_path, f"{tool_name}.py"),
                parent_type=BuiltinTool,
            )
            tool["identity"]["provider"] = provider
            tools.append(
                assistant_tool_class(
                    provider=provider,
                    entity=ToolEntity(**tool),
                    runtime=ToolRuntime(tenant_id=""),
                )
            )
        self.tools = tools

    # ......
此处也用到了 load_single_subclass_from_source,用于加载每个工具对应的 py 文件。
3.1.1.2. get_plugin_provider(插件工具管理)
class ToolManager:
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    # ...... (other members omitted in this excerpt)

    @classmethod
    def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
        """
        Return the plugin tool provider controller for ``provider``, cached per context.

        Controllers are memoised in ``contexts.plugin_tool_providers``. On a cache
        miss the provider is fetched with ``PluginToolManager.fetch_tool_provider``
        while holding ``contexts.plugin_tool_providers_lock``.

        :param provider: the plugin provider id
        :param tenant_id: the id of the tenant
        :raises ToolProviderNotFoundError: if no such plugin provider is found
        """
        # First use in this context: create the empty cache and its lock.
        try:
            contexts.plugin_tool_providers.get()
        except LookupError:
            contexts.plugin_tool_providers.set({})
            contexts.plugin_tool_providers_lock.set(Lock())

        # Fast path, no lock taken.
        if provider in (cache := contexts.plugin_tool_providers.get()):
            return cache[provider]

        with contexts.plugin_tool_providers_lock.get():
            # Re-read under the lock: another caller may have filled it meanwhile.
            cache = contexts.plugin_tool_providers.get()
            if provider in cache:
                return cache[provider]

            entity = PluginToolManager().fetch_tool_provider(tenant_id, provider)
            if not entity:
                raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

            controller = PluginToolProviderController(
                entity=entity.declaration,
                plugin_id=entity.plugin_id,
                plugin_unique_identifier=entity.plugin_unique_identifier,
                tenant_id=tenant_id,
            )
            cache[provider] = controller
            return controller

    # ......
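- contexts 封装:用 RecyclableContextVar 包装 ContextVar,按请求上下文缓存插件工具 provider 以及保护该缓存的锁;gunicorn 复用线程后,上一个请求留下的值会被视为未设置,从而避免线程复用带来的竞态问题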
/api/contexts/__init__.py
from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING

from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
    # Only needed for string annotations in this module; not imported at runtime.
    from core.model_runtime.entities.model_entities import AIModelEntity
    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
    from core.tools.plugin_tool.provider import PluginToolProviderController
    from core.workflow.entities.variable_pool import VariablePool

# To avoid race conditions caused by gunicorn thread recycling, every per-context
# cache below is a RecyclableContextVar wrapping a plain ContextVar. Provided that
# RecyclableContextVar.increment_thread_recycles() is called before each request,
# values set during a previous request on a recycled thread read as unset.

# plugin tool provider controllers keyed by provider, plus the lock guarding them
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
    ContextVar("plugin_tool_providers")
)
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
    ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("plugin_model_providers_lock")
)

plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
    ContextVar("plugin_model_schemas")
)
- RecyclableContextVar
/api/contexts/wrapper.py
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
pass
_default = HiddenValue()
class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
NOTE: you need to call `increment_thread_recycles` before requests
"""
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
@classmethod
def increment_thread_recycles(cls):
try:
recycles = cls._thread_recycles.get()
cls._thread_recycles.set(recycles + 1)
except LookupError:
cls._thread_recycles.set(0)
def __init__(self, context_var: ContextVar[T]):
self._context_var = context_var
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
def get(self, default: T | HiddenValue = _default) -> T:
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)
# check if thread is recycled and should be updated
if thread_recycles < self_updates:
return self._context_var.get()
else:
# thread_recycles >= self_updates, means current context is invalid
if isinstance(default, HiddenValue) or default is _default:
raise LookupError
else:
return default
def set(self, value: T):
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
# increase it manually
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)
if self._updates.get() == self._thread_recycles.get(0):
# after increment,
self._updates.set(self._updates.get() + 1)
# set the context
self._context_var.set(value)
3.1.2. get_tool(根据工具名返回工具)
/api/core/tools/builtin_tool/provider.py
class BuiltinToolProviderController(ToolProviderController):
    tools: list[BuiltinTool]

    # ...... (other members omitted in this excerpt)

    def _load_tools(self):
        """
        Instantiate every tool the provider declares under ``providers/<provider>/tools``.

        Each ``<tool>.yaml`` (names starting with ``__`` are skipped) is paired with
        the ``<tool>.py`` next to it, which must define a BuiltinTool subclass.
        The ToolEntity is built from the yaml dict, with ``identity.provider``
        forced to this provider. The ToolRuntime gets no yaml data, only an empty
        tenant id.
        """
        provider = self.entity.identity.name
        tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
        tool_files = [f for f in listdir(tool_path) if f.endswith(".yaml") and not f.startswith("__")]
        tools = []
        for tool_file in tool_files:
            tool_name = tool_file.split(".")[0]
            tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
            # The .py file lives next to the .yaml in `tool_path`. The previous path
            # joined an extra "builtin_tool" segment onto this file's directory
            # (already builtin_tool/), i.e. builtin_tool/builtin_tool/providers/...
            assistant_tool_class: type = load_single_subclass_from_source(
                module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
                script_path=path.join(tool_path, f"{tool_name}.py"),
                parent_type=BuiltinTool,
            )
            tool["identity"]["provider"] = provider
            tools.append(
                assistant_tool_class(
                    provider=provider,
                    entity=ToolEntity(**tool),
                    runtime=ToolRuntime(tenant_id=""),
                )
            )
        # read back by _get_builtin_tools / get_tools below
        self.tools = tools

    def _get_builtin_tools(self) -> list[BuiltinTool]:
        """
        returns a list of tools that the provider can provide
        :return: list of tools
        """
        return self.tools

    def get_tools(self) -> list[BuiltinTool]:
        """
        Return every tool this provider offers (the list built by ``_load_tools``).

        :return: list of tools
        """
        return self._get_builtin_tools()

    # ......
3.1.2.1. ToolEntity(使用 yaml 文件解析得到的 dict 初始化)
/api/core/tools/entities/tool_entities.py
# Used by the core class ToolEntity (its `parameters` field).
class ToolParameter(PluginParameter):
    """
    A plugin parameter specialised for tools.

    Narrows ``type`` to ToolParameterType and adds the tool-only fields
    ``human_description``, ``form``, ``llm_description`` and ``input_schema``.
    """

    class ToolParameterType(enum.StrEnum):
        """
        Parameter types a tool may declare.

        Mirrors PluginParameterType without TOOLS_SELECTOR, plus the MCP
        ARRAY / OBJECT types.
        """

        STRING = PluginParameterType.STRING.value
        NUMBER = PluginParameterType.NUMBER.value
        BOOLEAN = PluginParameterType.BOOLEAN.value
        SELECT = PluginParameterType.SELECT.value
        SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
        FILE = PluginParameterType.FILE.value
        FILES = PluginParameterType.FILES.value
        APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
        MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
        ANY = PluginParameterType.ANY.value
        DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
        # MCP object and array type parameters
        ARRAY = MCPServerParameterType.ARRAY.value
        OBJECT = MCPServerParameterType.OBJECT.value
        # deprecated, should not use.
        SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value

        def as_normal_type(self):
            # delegates to the module-level helper of the same name
            return as_normal_type(self)

        def cast_value(self, value: Any):
            # delegates to the module-level cast_parameter_value helper
            return cast_parameter_value(self, value)

    class ToolParameterForm(Enum):
        SCHEMA = "schema"  # should be set while adding tool
        FORM = "form"  # should be set before invoking tool
        LLM = "llm"  # will be set by LLM

    type: ToolParameterType = Field(..., description="The type of the parameter")
    human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
    form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
    llm_description: Optional[str] = None
    # MCP object and array type parameters use this field to store the schema
    input_schema: Optional[dict] = None

    @classmethod
    def get_simple_instance(
        cls,
        name: str,
        llm_description: str,
        typ: ToolParameterType,
        required: bool,
        options: Optional[list[str]] = None,
    ) -> "ToolParameter":
        """
        Build an LLM-filled parameter whose human-facing texts are empty.

        :param name: the parameter name
        :param llm_description: the description shown to the LLM
        :param typ: the parameter type
        :param required: whether the parameter is mandatory
        :param options: plain option values; each becomes a PluginParameterOption
            whose en_US / zh_Hans label is the value itself
        :return: a new parameter whose form is always LLM
        """
        choice_objs = [
            PluginParameterOption(value=choice, label=I18nObject(en_US=choice, zh_Hans=choice))
            for choice in options or []
        ]
        return cls(
            name=name,
            label=I18nObject(en_US="", zh_Hans=""),
            placeholder=None,
            human_description=I18nObject(en_US="", zh_Hans=""),
            type=typ,
            form=cls.ToolParameterForm.LLM,
            llm_description=llm_description,
            required=required,
            options=choice_objs,
        )

    def init_frontend_parameter(self, value: Any):
        """Delegate to the module-level ``init_frontend_parameter`` with this parameter and its type."""
        return init_frontend_parameter(self, self.type, value)
# Used by the core class ToolEntity (its `identity` field).
class ToolIdentity(BaseModel):
    """Identifying metadata of a tool: author, name, display label, owning provider, optional icon."""

    author: str = Field(..., description="The author of the tool")
    name: str = Field(..., description="The name of the tool")
    label: I18nObject = Field(..., description="The label of the tool")
    # for builtin tools, BuiltinToolProviderController._load_tools sets this to the provider name
    provider: str = Field(..., description="The provider of the tool")
    icon: Optional[str] = None
# Used by the core class ToolEntity (its `description` field).
class ToolDescription(BaseModel):
    """A tool's description in two forms: localized text for users and plain text for the LLM."""

    human: I18nObject = Field(..., description="The description presented to the user")
    llm: str = Field(..., description="The description presented to the LLM")
# Core class: the static definition of a tool. For builtin tools it is built as
# ToolEntity(**yaml_dict) from the tool's yaml file.
class ToolEntity(BaseModel):
    """Static tool definition: identity, parameters, optional descriptions and output schema."""

    identity: ToolIdentity
    parameters: list[ToolParameter] = Field(default_factory=list)
    description: Optional[ToolDescription] = None
    output_schema: Optional[dict] = None
    has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")

    # pydantic configs: an empty protected_namespaces disables pydantic's
    # protected-namespace ("model_") check on field names
    model_config = ConfigDict(protected_namespaces=())

    @field_validator("parameters", mode="before")
    @classmethod
    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
        """Turn a null / empty ``parameters`` value (e.g. from yaml) into an empty list before validation."""
        return v or []
3.1.2.2. ToolRuntime(传参初始化)
/api/core/tools/__base/tool_runtime.py
class ToolRuntime(BaseModel):
    """
    Meta data of a tool call processing

    Per-invocation context bound to a tool via ``fork_tool_runtime``: tenant,
    invocation source, credentials and runtime parameters.
    """

    tenant_id: str
    tool_id: Optional[str] = None
    invoke_from: Optional[InvokeFrom] = None
    tool_invoke_from: Optional[ToolInvokeFrom] = None
    # decrypted credentials; empty for providers that need none
    credentials: dict[str, Any] = Field(default_factory=dict)
    credential_type: CredentialType = Field(default=CredentialType.API_KEY)
    runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeToolRuntime(ToolRuntime):
    """
    A ToolRuntime pre-filled with fixed fake ids, meant for tests.

    Invoked from the debugger on behalf of an agent, with no credentials and
    no runtime parameters.
    """

    def __init__(self):
        super().__init__(
            tool_id="fake_tool_id",
            tenant_id="fake_tenant_id",
            tool_invoke_from=ToolInvokeFrom.AGENT,
            invoke_from=InvokeFrom.DEBUGGER,
            runtime_parameters={},
            credentials={},
        )
3.1.3. 验证工具
目前看到的工具凭证基本都没有提供,后续再分析。
3.1.4. 返回工具(至此返回绑定了运行时信息的工具实例)
......
# No credentials required: return a copy of the tool bound to a runtime with
# empty credentials (credential_type keeps ToolRuntime's default, API_KEY).
if not provider_controller.need_credentials:
    return cast(
        BuiltinTool,
        builtin_tool.fork_tool_runtime(
            runtime=ToolRuntime(
                tenant_id=tenant_id,
                credentials={},
                invoke_from=invoke_from,
                tool_invoke_from=tool_invoke_from,
            )
        ),
    )
......
# Credentials required: the only difference from the branch above is that the
# runtime carries the decrypted credentials and the provider's credential type.
return cast(
    BuiltinTool,
    builtin_tool.fork_tool_runtime(
        runtime=ToolRuntime(
            tenant_id=tenant_id,
            credentials=dict(decrypted_credentials),
            credential_type=CredentialType.of(builtin_provider.credential_type),
            runtime_parameters={},
            invoke_from=invoke_from,
            tool_invoke_from=tool_invoke_from,
        )
    ),
)
......
3.1.4.1. fork_tool_runtime(返回新的工具实例)
entity 定义在父类 Tool 中,在读取 yaml 文件、创建工具时传入,实际是一个 pydantic 模型(ToolEntity)。fork_tool_runtime 会复制 entity,并用传入的 ToolRuntime 创建一个同类的新工具实例返回。
class Tool(ABC):
    """
    The base class of a tool
    """

    # The entity is supplied when the tool is created from its definition
    # (for builtin tools: from the yaml file in BuiltinToolProviderController._load_tools).
    def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
        """
        :param entity: the tool's static definition (identity, parameters, descriptions)
        :param runtime: per-invocation context (tenant, credentials, invocation source)
        """
        self.entity = entity
        self.runtime = runtime

    ......
class BuiltinTool(Tool):
    """
    A tool belonging to a builtin provider.

    :param provider: name of the provider that owns this tool
    """

    def __init__(self, provider: str, **kwargs):
        super().__init__(**kwargs)
        self.provider = provider

    def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
        """
        Create a new tool of the same class bound to ``runtime``.

        The entity is duplicated with ``model_copy()``. That is a shallow copy: the
        new tool gets its own ToolEntity object, but nested values are shared.

        :param runtime: the runtime the new tool will use
        :return: the new tool instance
        """
        tool_cls = type(self)
        return tool_cls(
            provider=self.provider,
            entity=self.entity.model_copy(),
            runtime=runtime,
        )
3.2. 接口工具
class ToolManager:
    ......

    @classmethod
    def get_tool_runtime(
        cls,
        provider_type: ToolProviderType,
        provider_id: str,
        tool_name: str,
        tenant_id: str,
        invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
        tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
        credential_id: Optional[str] = None,
    ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
        """
        get the tool runtime
        :param provider_type: the type of the provider
        :param provider_id: the id of the provider
        :param tool_name: the name of the tool
        :param tenant_id: the tenant id
        :param invoke_from: invoke from
        :param tool_invoke_from: the tool invoke from
        :param credential_id: the credential id
        :return: the tool
        :raises NotImplementedError: for app providers
        :raises ToolProviderNotFoundError: for an unknown provider type
        """
        # builtin tools (branch body elided in this excerpt)
        if provider_type == ToolProviderType.BUILT_IN:
            ......
        # API tools: decrypt the provider's stored credentials and bind them to a
        # new runtime through fork_tool_runtime()
        elif provider_type == ToolProviderType.API:
            api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
            encrypter, _ = create_tool_provider_encrypter(
                tenant_id=tenant_id,
                controller=api_provider,
            )
            return api_provider.get_tool(tool_name).fork_tool_runtime(
                runtime=ToolRuntime(
                    tenant_id=tenant_id,
                    credentials=encrypter.decrypt(credentials),
                    invoke_from=invoke_from,
                    tool_invoke_from=tool_invoke_from,
                )
            )
        # workflow-as-tool (branch body elided in this excerpt)
        elif provider_type == ToolProviderType.WORKFLOW:
            ......
        # app tools are not supported
        elif provider_type == ToolProviderType.APP:
            raise NotImplementedError("app provider not implemented")
        # plugin tools
        # NOTE(review): unlike the API branch, the tool is returned without
        # fork_tool_runtime(), so invoke_from / tool_invoke_from / credential_id are
        # not applied here -- confirm this is intended.
        elif provider_type == ToolProviderType.PLUGIN:
            return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
        # MCP tools (same NOTE(review) as the plugin branch)
        elif provider_type == ToolProviderType.MCP:
            return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
        else:
            raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")

    ......
3.3. 工作流
3.4. APP(直接抛出异常)
3.5. 插件
3.6. MCP
更多推荐




所有评论(0)