大模型流式输出简谈
·
在大模型应用开发中,流式输出是绕不开的一环。本文将通过一个简单的示例,带你了解流式输出并快速上手
python代码准备
- 本次需要用到的包
要用到StreamingResponse来处理流式输出
import os
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI,Body
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel
-
流式接口编写
- 构建自定义请求体ChatRequest,接收用户输入的字符串
- 构建异步方法(异步生成器)event_generator(),逐块产出大模型的输出
- 将event_generator()传入StreamingResponse中,并设置media_type参数为text/plain
class ChatRequest(BaseModel):
    """JSON request body accepted by the streaming chat endpoint."""

    # Text entered by the user; forwarded to the chain as ``user_input``.
    user_input: str
@app.post('/chat/stream_resp')
async def chat_stream_resp(request: ChatRequest):
    """Stream the chain's reply to ``request.user_input`` as plain text.

    Args:
        request: Parsed request body carrying the user's input string.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain`` whose body is
        produced incrementally by ``event_generator``.
    """
    chain_input = {'user_input': request.user_input}

    async def event_generator():
        # Forward every non-empty chunk from the chain as UTF-8 bytes.
        async for piece in chain.astream(chain_input):
            if not piece:
                continue
            yield piece.encode('utf-8')

    body_stream = event_generator()
    return StreamingResponse(body_stream, media_type='text/plain')
前端代码准备
- 主要讲解一下javascript中流式输出方法的实现
- 构建POST请求与fastapi后端接口对接。注意fetch不会像axios那样自动把对象序列化为JSON,所以要用JSON.stringify手动把POST请求的请求体转换成JSON字符串
- 为了做流式输出,要利用decoder:先构造reader,指定其为response.body.getReader();再构建decoder,即新建TextDecoder('utf-8');并创建变量buffer,用来接收后端的响应
- 构建循环,对reader.read()的结果进行解构赋值,拿到结束标志done和返回值value。如果流式输出结束,done将为true;否则说明流式响应仍在进行,利用decoder对分片value进行解码,并指定stream为true(这样被截断在两个分片之间的多字节字符不会被解码成乱码)
- 将接收到的流式响应拼接到buffer,并把buffer显示到页面上
<script>
/**
 * Send the text of #user_input to the streaming chat endpoint and render
 * the reply into #output incrementally as chunks arrive.
 *
 * Errors (network failure or a non-2xx status) are logged to the console
 * and appended to #output.
 */
async function startStream() {
    const output = document.getElementById('output');
    const user_input = document.getElementById('user_input').value;
    // Clear the previous reply; a single textContent assignment is enough.
    output.textContent = '';
    try {
        const response = await fetch('http://127.0.0.1:8000/chat/stream_resp', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json'
            },
            // important: fetch does not serialise objects itself, so the
            // body must be sent as a JSON string.
            body: JSON.stringify({
                'user_input': user_input
            })
        });
        if (!response.ok) throw new Error('网络响应失败!');

        const reader = response.body.getReader();
        const decoder = new TextDecoder('utf-8');
        let buffer = '';

        // Show the text received so far and keep the view at the bottom.
        const render = () => {
            output.textContent = buffer;
            output.scrollTop = output.scrollHeight;
        };

        while (true) {
            const {done, value} = await reader.read();
            if (done) break;
            // stream: true makes the decoder hold back an incomplete
            // multi-byte character until the next chunk completes it.
            buffer += decoder.decode(value, {stream: true});
            render();
        }

        // Flush the decoder so bytes still held back at end of stream are
        // emitted instead of being silently dropped.
        buffer += decoder.decode();
        render();
    } catch (err) {
        console.error('请求出错', err);
        output.textContent += '\n出错了!' + err.message;
    }
}
</script>
整体代码
python
import os
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI,Body
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chat_models import init_chat_model
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel
# Load variables from a local .env file so the os.getenv lookups that
# configure the model can find them.
load_dotenv()

# Create the application.
app = FastAPI()

# Allow cross-origin calls from any origin so the demo page can reach the API.
# Credentials stay disabled: wildcard origins/methods/headers must not be
# combined with allow_credentials=True, and the page sends no cookies or
# auth headers anyway.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialise the chat model. The API key and base URL are read from the
# environment variables ``zhipu_key`` and ``zhipu_base_url``.
# NOTE(review): model_provider is 'openai', so the base URL is presumably an
# OpenAI-compatible endpoint -- confirm against the provider's docs.
llm = init_chat_model(
    model='glm-4.7',
    model_provider='openai',
    api_key=os.getenv('zhipu_key'),
    base_url=os.getenv('zhipu_base_url'),
)

# Prompt: a fixed system message followed by the user's input.
prompt_template = ChatPromptTemplate(
    messages=[
        ('system', '你现在是一个小说家,你会讲小说'),
        ('human', '{user_input}'),
    ],
)

# LCEL pipeline: prompt -> model -> plain-string output.
chain = prompt_template | llm | StrOutputParser()
# important: option 2 -- define a dedicated request-body model, which is the
# more conventional approach.
class ChatRequest(BaseModel):
    """JSON request body accepted by the streaming chat endpoint."""

    # Text entered by the user; forwarded to the chain as ``user_input``.
    user_input: str
@app.post('/chat/stream_resp')
async def chat_stream_resp(request: ChatRequest):
    """Stream the chain's reply to ``request.user_input`` as plain text.

    Args:
        request: Parsed request body carrying the user's input string.

    Returns:
        A ``StreamingResponse`` with media type ``text/plain`` whose body is
        produced incrementally by ``event_generator``.
    """
    chain_input = {'user_input': request.user_input}

    async def event_generator():
        # Forward every non-empty chunk from the chain as UTF-8 bytes.
        async for piece in chain.astream(chain_input):
            if not piece:
                continue
            yield piece.encode('utf-8')

    body_stream = event_generator()
    return StreamingResponse(body_stream, media_type='text/plain')
# Start a uvicorn server on 127.0.0.1:8000 only when this file is run
# directly as a script.
if __name__ == '__main__':
    uvicorn.run(app, host='127.0.0.1', port=8000)
前端
<!DOCTYPE html>
<!-- The page's visible text is Chinese, so declare that language. -->
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
    <style>
        #output {
            width: 200px;
            height: 300px;
            background-color: #c3e6cb;
            margin-bottom: 30px;
            /* Scroll inside the box so the script's scrollTop update works. */
            overflow-y: auto;
            /* Keep line breaks from the streamed text and wrap long lines. */
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
<div id="output"></div>
用户输入:<input id="user_input">
<button onclick="startStream()">发送</button>
<script>
    /**
     * Send the text of #user_input to the streaming chat endpoint and render
     * the reply into #output incrementally as chunks arrive.
     *
     * Errors (network failure or a non-2xx status) are logged to the console
     * and appended to #output.
     */
    async function startStream() {
        const output = document.getElementById('output');
        const user_input = document.getElementById('user_input').value;
        // Clear the previous reply; a single textContent assignment is enough.
        output.textContent = '';
        try {
            const response = await fetch('http://127.0.0.1:8000/chat/stream_resp', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                // important: fetch does not serialise objects itself, so the
                // body must be sent as a JSON string.
                body: JSON.stringify({
                    'user_input': user_input
                })
            });
            if (!response.ok) throw new Error('网络响应失败!');

            const reader = response.body.getReader();
            const decoder = new TextDecoder('utf-8');
            let buffer = '';

            // Show the text received so far and keep the view at the bottom.
            const render = () => {
                output.textContent = buffer;
                output.scrollTop = output.scrollHeight;
            };

            while (true) {
                const {done, value} = await reader.read();
                if (done) break;
                // stream: true makes the decoder hold back an incomplete
                // multi-byte character until the next chunk completes it.
                buffer += decoder.decode(value, {stream: true});
                render();
            }

            // Flush the decoder so bytes still held back at end of stream
            // are emitted instead of being silently dropped.
            buffer += decoder.decode();
            render();
        } catch (err) {
            console.error('请求出错', err);
            output.textContent += '\n出错了!' + err.message;
        }
    }
</script>
</body>
</html>
希望对你有帮助
更多推荐


所有评论(0)