feat: 支持流式输出
This commit is contained in:
137
tests/test_stream.py
Normal file
137
tests/test_stream.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
from logger import get_logger, get_stream_logger
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
stream_logger = get_stream_logger()
|
||||
|
||||
def test_api_stream():
|
||||
"""
|
||||
测试API流式响应功能
|
||||
"""
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 获取API配置
|
||||
api_url = os.getenv('API_URL')
|
||||
api_key = os.getenv('API_KEY')
|
||||
api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||
|
||||
logger.info(f"开始测试API流式响应,API URL: {api_url}, MODEL: {api_model}")
|
||||
|
||||
# 检查API配置
|
||||
if not api_url:
|
||||
logger.error("API URL未配置,无法进行测试")
|
||||
return
|
||||
|
||||
if not api_key:
|
||||
logger.error("API Key未配置,无法进行测试")
|
||||
return
|
||||
|
||||
# 标准化API URL
|
||||
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
|
||||
if api_url.endswith('/v1'):
|
||||
api_url = f"{api_url}/chat/completions"
|
||||
elif api_url.endswith('/'):
|
||||
api_url = f"{api_url}v1/chat/completions"
|
||||
else:
|
||||
api_url = f"{api_url}/v1/chat/completions"
|
||||
|
||||
logger.debug(f"标准化后的API URL: {api_url}")
|
||||
|
||||
# 构建简单的测试提示
|
||||
prompt = "这是一个API流式响应测试。请给出一个简短的股票分析样例。"
|
||||
|
||||
# 构建请求头和请求体
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": api_model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": True # 明确设置stream参数为True
|
||||
}
|
||||
|
||||
logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
|
||||
|
||||
try:
|
||||
logger.info(f"发起流式API请求: {api_url}")
|
||||
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60,
|
||||
stream=True
|
||||
)
|
||||
|
||||
logger.info(f"API流式响应状态码: {response.status_code}")
|
||||
logger.debug(f"响应头: {response.headers}")
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info("成功获取API流式响应,开始处理")
|
||||
|
||||
buffer = ""
|
||||
chunk_count = 0
|
||||
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode('utf-8')
|
||||
logger.info(f"原始流式行: {line_str}")
|
||||
|
||||
# 跳过保持连接的空行
|
||||
if line_str.strip() == '':
|
||||
logger.debug("跳过空行")
|
||||
continue
|
||||
|
||||
# 数据行通常以"data: "开头
|
||||
if line_str.startswith('data: '):
|
||||
data_content = line_str[6:].strip() # 移除 "data: " 前缀并去除前后空格
|
||||
logger.info(f"数据内容: {data_content}")
|
||||
|
||||
# 检查是否为流的结束
|
||||
if data_content == '[DONE]':
|
||||
logger.info("收到流结束标记 [DONE]")
|
||||
break
|
||||
|
||||
try:
|
||||
# 解析JSON数据
|
||||
json_data = json.loads(data_content)
|
||||
logger.debug(f"JSON结构: {json.dumps(json_data, indent=2)}")
|
||||
|
||||
if 'choices' in json_data:
|
||||
delta = json_data['choices'][0].get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
|
||||
if content:
|
||||
chunk_count += 1
|
||||
buffer += content
|
||||
logger.info(f"内容片段 #{chunk_count}: {content}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {e}, 内容: {data_content}")
|
||||
else:
|
||||
logger.warning(f"收到非'data:'开头的行: {line_str}")
|
||||
|
||||
logger.info(f"流式处理完成,共收到 {chunk_count} 个内容片段")
|
||||
logger.info(f"完整内容:\n{buffer}")
|
||||
|
||||
else:
|
||||
try:
|
||||
error_response = response.json()
|
||||
error_text = json.dumps(error_response, indent=2)
|
||||
except:
|
||||
error_text = response.text[:500] if response.text else "无响应内容"
|
||||
|
||||
logger.error(f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试过程中发生异常: {str(e)}")
|
||||
logger.exception(e)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_api_stream()
|
||||
Reference in New Issue
Block a user