feat: 支持流式输出
This commit is contained in:
@@ -3,8 +3,14 @@ import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import requests
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple, Generator
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
from logger import get_logger, get_stream_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
stream_logger = get_stream_logger()
|
||||
|
||||
class StockAnalyzer:
|
||||
def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None):
|
||||
@@ -17,6 +23,8 @@ class StockAnalyzer:
|
||||
self.API_KEY = custom_api_key or os.getenv('API_KEY')
|
||||
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||
|
||||
logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}")
|
||||
|
||||
# 配置参数
|
||||
self.params = {
|
||||
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
|
||||
@@ -205,9 +213,10 @@ class StockAnalyzer:
|
||||
print(f"计算评分时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_ai_analysis(self, df, stock_code):
|
||||
def get_ai_analysis(self, df, stock_code, stream=False):
|
||||
"""使用 OpenAI 进行 AI 分析"""
|
||||
try:
|
||||
logger.info(f"开始AI分析股票 {stock_code}, 流式模式: {stream}")
|
||||
recent_data = df.tail(14).to_dict('records')
|
||||
|
||||
technical_summary = {
|
||||
@@ -237,47 +246,202 @@ class StockAnalyzer:
|
||||
请基于技术指标和市场动态进行分析,给出具体数据支持。
|
||||
"""
|
||||
|
||||
# OpenAI API 调用
|
||||
api_urls = [
|
||||
f"{self.API_URL}/chat/completions",
|
||||
f"{self.API_URL}/v1/chat/completions"
|
||||
]
|
||||
logger.debug(f"生成的AI分析提示词: {prompt[:100]}...")
|
||||
|
||||
last_error = None
|
||||
for api_url in api_urls:
|
||||
# 检查API配置
|
||||
if not self.API_URL:
|
||||
error_msg = "API URL未配置,无法进行AI分析"
|
||||
logger.error(error_msg)
|
||||
return error_msg if not stream else (yield json.dumps({"error": error_msg}))
|
||||
|
||||
if not self.API_KEY:
|
||||
error_msg = "API Key未配置,无法进行AI分析"
|
||||
logger.error(error_msg)
|
||||
return error_msg if not stream else (yield json.dumps({"error": error_msg}))
|
||||
|
||||
# 标准化API URL
|
||||
api_url = self.API_URL
|
||||
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
|
||||
if api_url.endswith('/v1'):
|
||||
api_url = f"{api_url}/chat/completions"
|
||||
elif api_url.endswith('/'):
|
||||
api_url = f"{api_url}v1/chat/completions"
|
||||
else:
|
||||
api_url = f"{api_url}/v1/chat/completions"
|
||||
|
||||
logger.debug(f"标准化后的API URL: {api_url}")
|
||||
|
||||
# 构建请求头和请求体
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": self.API_MODEL,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}
|
||||
|
||||
# 流式处理设置
|
||||
if stream:
|
||||
logger.debug(f"配置流式参数,使用API URL: {api_url}")
|
||||
payload["stream"] = True # 明确设置stream参数为True
|
||||
|
||||
try:
|
||||
logger.debug(f"发起流式API请求: {api_url}")
|
||||
logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
|
||||
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60, # 增加超时时间
|
||||
stream=True
|
||||
)
|
||||
|
||||
logger.debug(f"API流式响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(f"成功获取API流式响应,开始处理")
|
||||
yield from self._process_ai_stream(response, stock_code)
|
||||
else:
|
||||
try:
|
||||
error_response = response.json()
|
||||
error_text = json.dumps(error_response, indent=2)
|
||||
except:
|
||||
error_text = response.text[:500] if response.text else "无响应内容"
|
||||
|
||||
error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
|
||||
logger.error(error_msg)
|
||||
yield json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"流式API请求异常: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
yield json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||
else:
|
||||
# 非流式处理
|
||||
logger.debug(f"发起非流式API请求: {api_url}")
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
api_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": self.API_MODEL,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
},
|
||||
timeout=30
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
logger.debug(f"API非流式响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
api_response = response.json()
|
||||
content = api_response['choices'][0]['message']['content']
|
||||
logger.info(f"成功获取AI分析结果,长度: {len(content)}")
|
||||
logger.debug(f"AI分析结果前100字符: {content[:100]}...")
|
||||
return content
|
||||
else:
|
||||
last_error = f"API 错误: {response.status_code} - {response.text}"
|
||||
continue
|
||||
try:
|
||||
error_response = response.json()
|
||||
error_text = json.dumps(error_response, indent=2)
|
||||
except:
|
||||
error_text = response.text[:500] if response.text else "无响应内容"
|
||||
|
||||
error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
continue
|
||||
|
||||
print(f"AI 分析暂时无法使用: {last_error}")
|
||||
return f"AI 分析暂时无法使用: {last_error}"
|
||||
error_msg = f"非流式API请求异常: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
return error_msg
|
||||
|
||||
except Exception as e:
|
||||
print(f"AI 分析发生错误: {str(e)}")
|
||||
return f"AI 分析过程中发生错误: {str(e)}"
|
||||
error_msg = f"AI 分析过程中发生错误: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
|
||||
if stream:
|
||||
logger.debug("在流式模式下返回异常信息")
|
||||
error_json = json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||
stream_logger.info(f"流式异常输出: {error_json}")
|
||||
yield error_json
|
||||
else:
|
||||
return error_msg
|
||||
|
||||
def _process_ai_stream(self, response, stock_code) -> Generator[str, None, None]:
|
||||
"""处理AI流式响应"""
|
||||
logger.info(f"开始处理股票 {stock_code} 的AI流式响应")
|
||||
buffer = ""
|
||||
chunk_count = 0
|
||||
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode('utf-8')
|
||||
stream_logger.info(f"原始流式行: {line}")
|
||||
|
||||
# 跳过保持连接的空行
|
||||
if line.strip() == '':
|
||||
logger.debug("跳过空行")
|
||||
continue
|
||||
|
||||
# 数据行通常以"data: "开头
|
||||
if line.startswith('data: '):
|
||||
data_content = line[6:] # 移除 "data: " 前缀
|
||||
stream_logger.info(f"数据内容: {data_content}")
|
||||
|
||||
# 检查是否为流的结束
|
||||
if data_content.strip() == '[DONE]':
|
||||
logger.debug("收到流结束标记 [DONE]")
|
||||
break
|
||||
|
||||
try:
|
||||
json_data = json.loads(data_content)
|
||||
logger.debug(f"解析的JSON数据: {json.dumps(json_data)[:100]}...")
|
||||
|
||||
if 'choices' in json_data:
|
||||
delta = json_data['choices'][0].get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
|
||||
if content:
|
||||
chunk_count += 1
|
||||
buffer += content
|
||||
logger.debug(f"收到内容片段 #{chunk_count}: {content}")
|
||||
stream_logger.info(f"发送内容片段: {content}")
|
||||
|
||||
# 创建包含AI分析片段的JSON
|
||||
chunk_json = json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"ai_analysis_chunk": content
|
||||
})
|
||||
stream_logger.info(f"流式输出JSON: {chunk_json}")
|
||||
yield chunk_json
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析错误: {str(e)}, 行内容: {data_content}")
|
||||
# 忽略无法解析的JSON
|
||||
pass
|
||||
else:
|
||||
logger.warning(f"收到非'data:'开头的行: {line}")
|
||||
|
||||
logger.info(f"AI流式处理完成,共收到 {chunk_count} 个内容片段,总长度: {len(buffer)}")
|
||||
|
||||
# 如果buffer不为空,最后一次发送完整内容
|
||||
if buffer and not buffer.endswith('\n'):
|
||||
logger.debug("发送换行符")
|
||||
yield json.dumps({"stock_code": stock_code, "ai_analysis_chunk": "\n"})
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理AI流式响应时出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
yield json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||
|
||||
|
||||
def get_recommendation(self, score):
|
||||
"""根据得分给出建议"""
|
||||
logger.debug(f"根据评分 {score} 生成投资建议")
|
||||
if score >= 80:
|
||||
return '强烈推荐买入'
|
||||
elif score >= 60:
|
||||
@@ -288,18 +452,24 @@ class StockAnalyzer:
|
||||
return '建议卖出'
|
||||
else:
|
||||
return '强烈建议卖出'
|
||||
|
||||
def analyze_stock(self, stock_code, market_type='A'):
|
||||
|
||||
def analyze_stock(self, stock_code, market_type='A', stream=False):
|
||||
"""分析单个股票"""
|
||||
try:
|
||||
logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}, 流式模式: {stream}")
|
||||
|
||||
# 获取股票数据
|
||||
logger.debug(f"获取股票 {stock_code} 数据")
|
||||
df = self.get_stock_data(stock_code, market_type)
|
||||
|
||||
# 计算技术指标
|
||||
logger.debug(f"计算股票 {stock_code} 技术指标")
|
||||
df = self.calculate_indicators(df)
|
||||
|
||||
# 评分系统
|
||||
logger.debug(f"计算股票 {stock_code} 评分")
|
||||
score = self.calculate_score(df)
|
||||
logger.info(f"股票 {stock_code} 评分结果: {score}")
|
||||
|
||||
# 获取最新数据
|
||||
latest = df.iloc[-1]
|
||||
@@ -316,29 +486,92 @@ class StockAnalyzer:
|
||||
'rsi': latest['RSI'],
|
||||
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
|
||||
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
|
||||
'recommendation': self.get_recommendation(score),
|
||||
'ai_analysis': self.get_ai_analysis(df, stock_code)
|
||||
'recommendation': self.get_recommendation(score)
|
||||
}
|
||||
logger.debug(f"生成股票 {stock_code} 基础报告: {json.dumps(report)[:100]}...")
|
||||
|
||||
return report
|
||||
if stream:
|
||||
logger.info(f"以流式模式返回股票 {stock_code} 分析结果")
|
||||
# 先返回基本报告结构
|
||||
base_report = dict(report)
|
||||
base_report['ai_analysis'] = ''
|
||||
base_report_json = json.dumps(base_report)
|
||||
logger.debug(f"基础报告JSON: {base_report_json[:100]}...")
|
||||
stream_logger.info(f"发送基础报告: {base_report_json}")
|
||||
yield base_report_json
|
||||
|
||||
# 然后流式返回AI分析部分
|
||||
logger.debug(f"开始获取股票 {stock_code} 的流式AI分析")
|
||||
ai_chunks_count = 0
|
||||
for ai_chunk in self.get_ai_analysis(df, stock_code, stream=True):
|
||||
ai_chunks_count += 1
|
||||
stream_logger.info(f"股票 {stock_code} 流式块 #{ai_chunks_count}: {ai_chunk}")
|
||||
yield ai_chunk
|
||||
logger.info(f"股票 {stock_code} 流式AI分析完成,共发送 {ai_chunks_count} 个块")
|
||||
else:
|
||||
logger.info(f"以非流式模式返回股票 {stock_code} 分析结果")
|
||||
logger.debug(f"开始获取股票 {stock_code} 的AI分析")
|
||||
report['ai_analysis'] = self.get_ai_analysis(df, stock_code)
|
||||
logger.debug(f"AI分析结果长度: {len(report['ai_analysis'])}")
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
print(f"分析股票时出错: {str(e)}")
|
||||
raise
|
||||
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
|
||||
def scan_market(self, stock_list, min_score=60, market_type='A'):
|
||||
if stream:
|
||||
error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
|
||||
stream_logger.info(f"流式错误输出: {error_json}")
|
||||
yield error_json
|
||||
else:
|
||||
raise
|
||||
|
||||
def scan_market(self, stock_list, min_score=60, market_type='A', stream=False):
|
||||
"""扫描市场,寻找符合条件的股票"""
|
||||
recommendations = []
|
||||
logger.info(f"开始扫描市场,股票数量: {len(stock_list)}, 最低分数: {min_score}, 市场: {market_type}, 流式模式: {stream}")
|
||||
|
||||
for stock_code in stock_list:
|
||||
try:
|
||||
report = self.analyze_stock(stock_code, market_type)
|
||||
if report['score'] >= min_score:
|
||||
recommendations.append(report)
|
||||
except Exception as e:
|
||||
print(f"分析股票 {stock_code} 时出错: {str(e)}")
|
||||
continue
|
||||
|
||||
# 按得分排序
|
||||
recommendations.sort(key=lambda x: x['score'], reverse=True)
|
||||
return recommendations
|
||||
if not stream:
|
||||
recommendations = []
|
||||
|
||||
for stock_code in stock_list:
|
||||
try:
|
||||
logger.debug(f"分析股票: {stock_code}")
|
||||
report = self.analyze_stock(stock_code, market_type)
|
||||
if report['score'] >= min_score:
|
||||
logger.info(f"股票 {stock_code} 评分 {report['score']} >= {min_score},添加到推荐列表")
|
||||
recommendations.append(report)
|
||||
else:
|
||||
logger.debug(f"股票 {stock_code} 评分 {report['score']} < {min_score},不添加到推荐列表")
|
||||
except Exception as e:
|
||||
logger.error(f"分析股票 {stock_code} 时出错: {str(e)}")
|
||||
logger.exception(e)
|
||||
continue
|
||||
|
||||
# 按得分排序
|
||||
recommendations.sort(key=lambda x: x['score'], reverse=True)
|
||||
logger.info(f"扫描完成,找到 {len(recommendations)} 个推荐股票")
|
||||
return recommendations
|
||||
else:
|
||||
# 流式处理每个股票
|
||||
logger.info(f"开始流式扫描 {len(stock_list)} 只股票")
|
||||
stock_count = 0
|
||||
for stock_code in stock_list:
|
||||
stock_count += 1
|
||||
logger.debug(f"流式分析股票 {stock_code} ({stock_count}/{len(stock_list)})")
|
||||
try:
|
||||
# 分析单只股票并获取流式结果
|
||||
chunk_count = 0
|
||||
for chunk in self.analyze_stock(stock_code, market_type, stream=True):
|
||||
chunk_count += 1
|
||||
stream_logger.info(f"股票 {stock_code} 流式块 #{chunk_count}: {chunk}")
|
||||
yield chunk
|
||||
logger.debug(f"股票 {stock_code} 流式分析完成,共 {chunk_count} 个块")
|
||||
except Exception as e:
|
||||
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
|
||||
stream_logger.info(f"流式错误输出: {error_json}")
|
||||
yield error_json
|
||||
logger.info(f"流式扫描完成,处理了 {stock_count} 只股票")
|
||||
|
||||
Reference in New Issue
Block a user