Merge branch 'pr-1' into dev
# Conflicts: # stock_analyzer.py
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
# 使用 Python 3.9 作为基础镜像
|
# 使用 Python 3.10 作为基础镜像
|
||||||
FROM python:3.10-slim
|
FROM python:3.10-slim
|
||||||
|
|
||||||
# 设置工作目录
|
# 设置工作目录
|
||||||
@@ -15,7 +15,6 @@ COPY . /app/
|
|||||||
|
|
||||||
# 安装 Python 依赖
|
# 安装 Python 依赖
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
RUN pip install akshare --upgrade -i https://pypi.org/simple
|
|
||||||
|
|
||||||
# 设置环境变量
|
# 设置环境变量
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
|
|||||||
58
logger.py
Normal file
58
logger.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from loguru import logger
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 获取当前时间作为日志文件名的一部分
|
||||||
|
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
# 创建日志目录
|
||||||
|
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logger.remove() # 移除默认的处理器
|
||||||
|
|
||||||
|
# 添加标准输出处理器(控制台)
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||||
|
level="DEBUG"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加文件处理器(debug级别)
|
||||||
|
logger.add(
|
||||||
|
os.path.join(log_dir, f"debug_{current_time}.log"),
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
|
||||||
|
level="DEBUG",
|
||||||
|
rotation="100 MB",
|
||||||
|
retention="1 week"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加文件处理器(error级别)
|
||||||
|
logger.add(
|
||||||
|
os.path.join(log_dir, f"error_{current_time}.log"),
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
|
||||||
|
level="ERROR",
|
||||||
|
rotation="100 MB",
|
||||||
|
retention="1 month"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加流处理器(用于记录流式输出)
|
||||||
|
logger.add(
|
||||||
|
os.path.join(log_dir, f"stream_{current_time}.log"),
|
||||||
|
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {message}",
|
||||||
|
filter=lambda record: "STREAM" in record["extra"],
|
||||||
|
level="INFO"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建专用于流式输出的日志器
|
||||||
|
stream_logger = logger.bind(STREAM=True)
|
||||||
|
|
||||||
|
def get_logger():
|
||||||
|
"""获取通用日志器"""
|
||||||
|
return logger
|
||||||
|
|
||||||
|
def get_stream_logger():
|
||||||
|
"""获取流式输出专用日志器"""
|
||||||
|
return stream_logger
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
|
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
# 基础科学计算和数据处理库
|
# 基础科学计算和数据处理库
|
||||||
numpy==2.1.2
|
numpy==2.1.2
|
||||||
pandas==2.2.2
|
pandas==2.2.2
|
||||||
scipy==1.15.1
|
scipy==1.15.1
|
||||||
|
|
||||||
# 数据获取和分析库
|
# 数据获取和分析库
|
||||||
akshare
|
akshare==1.16.22
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,340 +1,583 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple, Generator
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
import json
|
||||||
class StockAnalyzer:
|
from logger import get_logger, get_stream_logger
|
||||||
def __init__(self, initial_cash=1000000):
|
|
||||||
|
# 获取日志器
|
||||||
# 加载环境变量
|
logger = get_logger()
|
||||||
load_dotenv()
|
stream_logger = get_stream_logger()
|
||||||
|
|
||||||
# 设置 API
|
class StockAnalyzer:
|
||||||
self.API_URL = os.getenv('API_URL')
|
def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=60):
|
||||||
self.API_KEY = os.getenv('API_KEY')
|
|
||||||
self.API_TIMEOUT = int(os.getenv('API_TIMEOUT', '60'))
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
# 配置参数
|
|
||||||
self.params = {
|
# 设置 API 配置,优先使用自定义配置,否则使用环境变量
|
||||||
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
|
self.API_URL = custom_api_url or os.getenv('API_URL')
|
||||||
'rsi_period': 14,
|
self.API_KEY = custom_api_key or os.getenv('API_KEY')
|
||||||
'bollinger_period': 20,
|
self.API_TIMEOUT = custom_api_timeout or int(os.getenv('API_TIMEOUT', '60'))
|
||||||
'bollinger_std': 2,
|
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||||
'volume_ma_period': 20,
|
|
||||||
'atr_period': 14
|
logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}")
|
||||||
}
|
|
||||||
|
# 配置参数
|
||||||
|
self.params = {
|
||||||
def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
|
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
|
||||||
"""获取股票数据"""
|
'rsi_period': 14,
|
||||||
import akshare as ak
|
'bollinger_period': 20,
|
||||||
|
'bollinger_std': 2,
|
||||||
if start_date is None:
|
'volume_ma_period': 20,
|
||||||
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
|
'atr_period': 14
|
||||||
if end_date is None:
|
}
|
||||||
end_date = datetime.now().strftime('%Y%m%d')
|
|
||||||
|
|
||||||
try:
|
def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
|
||||||
# 根据市场类型获取数据
|
"""获取股票数据"""
|
||||||
if market_type == 'A':
|
import akshare as ak
|
||||||
df = ak.stock_zh_a_hist(
|
|
||||||
symbol=stock_code,
|
if start_date is None:
|
||||||
start_date=start_date,
|
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
|
||||||
end_date=end_date,
|
if end_date is None:
|
||||||
adjust="qfq"
|
end_date = datetime.now().strftime('%Y%m%d')
|
||||||
)
|
|
||||||
# A股数据列名映射
|
try:
|
||||||
elif market_type == 'HK':
|
# 根据市场类型获取数据
|
||||||
df = ak.stock_hk_daily(
|
if market_type == 'A':
|
||||||
symbol=stock_code,
|
df = ak.stock_zh_a_hist(
|
||||||
adjust="qfq"
|
symbol=stock_code,
|
||||||
)
|
start_date=start_date,
|
||||||
elif market_type == 'US':
|
end_date=end_date,
|
||||||
df = ak.stock_us_hist(
|
adjust="qfq"
|
||||||
symbol=stock_code,
|
)
|
||||||
start_date=start_date,
|
# A股数据列名映射
|
||||||
end_date=end_date,
|
elif market_type == 'HK':
|
||||||
adjust="qfq"
|
df = ak.stock_hk_daily(
|
||||||
)
|
symbol=stock_code,
|
||||||
# elif market_type == 'CRYPTO':
|
adjust="qfq"
|
||||||
# df = ak.crypto_js_spot(
|
)
|
||||||
# symbol=stock_code
|
elif market_type == 'US':
|
||||||
# )
|
df = ak.stock_us_hist(
|
||||||
else:
|
symbol=stock_code,
|
||||||
raise ValueError(f"不支持的市场类型: {market_type}")
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
# 重命名列名以匹配分析需求
|
adjust="qfq"
|
||||||
df = df.rename(columns={
|
)
|
||||||
"日期": "date",
|
# elif market_type == 'CRYPTO':
|
||||||
"开盘": "open",
|
# df = ak.crypto_js_spot(
|
||||||
"收盘": "close",
|
# symbol=stock_code
|
||||||
"最高": "high",
|
# )
|
||||||
"最低": "low",
|
else:
|
||||||
"成交量": "volume"
|
raise ValueError(f"不支持的市场类型: {market_type}")
|
||||||
})
|
|
||||||
|
# 重命名列名以匹配分析需求
|
||||||
# 确保日期格式正确
|
df = df.rename(columns={
|
||||||
df['date'] = pd.to_datetime(df['date'])
|
"日期": "date",
|
||||||
|
"开盘": "open",
|
||||||
# 数据类型转换
|
"收盘": "close",
|
||||||
numeric_columns = ['open', 'close', 'high', 'low', 'volume']
|
"最高": "high",
|
||||||
df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
|
"最低": "low",
|
||||||
|
"成交量": "volume"
|
||||||
# 删除空值
|
})
|
||||||
df = df.dropna()
|
|
||||||
|
# 确保日期格式正确
|
||||||
return df.sort_values('date')
|
df['date'] = pd.to_datetime(df['date'])
|
||||||
|
|
||||||
except Exception as e:
|
# 数据类型转换
|
||||||
raise Exception(f"获取股票数据失败: {str(e)}")
|
numeric_columns = ['open', 'close', 'high', 'low', 'volume']
|
||||||
|
df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
|
||||||
def calculate_ema(self, series, period):
|
|
||||||
"""计算指数移动平均线"""
|
# 删除空值
|
||||||
return series.ewm(span=period, adjust=False).mean()
|
df = df.dropna()
|
||||||
|
|
||||||
def calculate_rsi(self, series, period):
|
return df.sort_values('date')
|
||||||
"""计算RSI指标"""
|
|
||||||
delta = series.diff()
|
except Exception as e:
|
||||||
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
raise Exception(f"获取股票数据失败: {str(e)}")
|
||||||
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
|
||||||
rs = gain / loss
|
def calculate_ema(self, series, period):
|
||||||
return 100 - (100 / (1 + rs))
|
"""计算指数移动平均线"""
|
||||||
|
return series.ewm(span=period, adjust=False).mean()
|
||||||
def calculate_macd(self, series):
|
|
||||||
"""计算MACD指标"""
|
def calculate_rsi(self, series, period):
|
||||||
exp1 = series.ewm(span=12, adjust=False).mean()
|
"""计算RSI指标"""
|
||||||
exp2 = series.ewm(span=26, adjust=False).mean()
|
delta = series.diff()
|
||||||
macd = exp1 - exp2
|
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
||||||
signal = macd.ewm(span=9, adjust=False).mean()
|
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
||||||
hist = macd - signal
|
rs = gain / loss
|
||||||
return macd, signal, hist
|
return 100 - (100 / (1 + rs))
|
||||||
|
|
||||||
def calculate_bollinger_bands(self, series, period, std_dev):
|
def calculate_macd(self, series):
|
||||||
"""计算布林带"""
|
"""计算MACD指标"""
|
||||||
middle = series.rolling(window=period).mean()
|
exp1 = series.ewm(span=12, adjust=False).mean()
|
||||||
std = series.rolling(window=period).std()
|
exp2 = series.ewm(span=26, adjust=False).mean()
|
||||||
upper = middle + (std * std_dev)
|
macd = exp1 - exp2
|
||||||
lower = middle - (std * std_dev)
|
signal = macd.ewm(span=9, adjust=False).mean()
|
||||||
return upper, middle, lower
|
hist = macd - signal
|
||||||
|
return macd, signal, hist
|
||||||
def calculate_atr(self, df, period):
|
|
||||||
"""计算ATR指标"""
|
def calculate_bollinger_bands(self, series, period, std_dev):
|
||||||
high = df['high']
|
"""计算布林带"""
|
||||||
low = df['low']
|
middle = series.rolling(window=period).mean()
|
||||||
close = df['close'].shift(1)
|
std = series.rolling(window=period).std()
|
||||||
|
upper = middle + (std * std_dev)
|
||||||
tr1 = high - low
|
lower = middle - (std * std_dev)
|
||||||
tr2 = abs(high - close)
|
return upper, middle, lower
|
||||||
tr3 = abs(low - close)
|
|
||||||
|
def calculate_atr(self, df, period):
|
||||||
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
|
"""计算ATR指标"""
|
||||||
return tr.rolling(window=period).mean()
|
high = df['high']
|
||||||
|
low = df['low']
|
||||||
def calculate_indicators(self, df):
|
close = df['close'].shift(1)
|
||||||
"""计算技术指标"""
|
|
||||||
try:
|
tr1 = high - low
|
||||||
# 计算移动平均线
|
tr2 = abs(high - close)
|
||||||
df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
|
tr3 = abs(low - close)
|
||||||
df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
|
|
||||||
df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
|
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
|
||||||
|
return tr.rolling(window=period).mean()
|
||||||
# 计算RSI
|
|
||||||
df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
|
def calculate_indicators(self, df):
|
||||||
|
"""计算技术指标"""
|
||||||
# 计算MACD
|
try:
|
||||||
df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
|
# 计算移动平均线
|
||||||
|
df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
|
||||||
# 计算布林带
|
df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
|
||||||
df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
|
df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
|
||||||
df['close'],
|
|
||||||
self.params['bollinger_period'],
|
# 计算RSI
|
||||||
self.params['bollinger_std']
|
df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
|
||||||
)
|
|
||||||
|
# 计算MACD
|
||||||
# 成交量分析
|
df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
|
||||||
df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
|
|
||||||
df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
|
# 计算布林带
|
||||||
|
df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
|
||||||
# 计算ATR和波动率
|
df['close'],
|
||||||
df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
|
self.params['bollinger_period'],
|
||||||
df['Volatility'] = df['ATR'] / df['close'] * 100
|
self.params['bollinger_std']
|
||||||
|
)
|
||||||
# 动量指标
|
|
||||||
df['ROC'] = df['close'].pct_change(periods=10) * 100
|
# 成交量分析
|
||||||
|
df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
|
||||||
return df
|
df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
|
||||||
|
|
||||||
except Exception as e:
|
# 计算ATR和波动率
|
||||||
print(f"计算技术指标时出错: {str(e)}")
|
df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
|
||||||
raise
|
df['Volatility'] = df['ATR'] / df['close'] * 100
|
||||||
|
|
||||||
def calculate_score(self, df):
|
# 动量指标
|
||||||
"""计算股票评分"""
|
df['ROC'] = df['close'].pct_change(periods=10) * 100
|
||||||
try:
|
|
||||||
score = 0
|
return df
|
||||||
latest = df.iloc[-1]
|
|
||||||
|
except Exception as e:
|
||||||
# 趋势得分 (30分)
|
print(f"计算技术指标时出错: {str(e)}")
|
||||||
if latest['MA5'] > latest['MA20']:
|
raise
|
||||||
score += 15
|
|
||||||
if latest['MA20'] > latest['MA60']:
|
def calculate_score(self, df):
|
||||||
score += 15
|
"""计算股票评分"""
|
||||||
|
try:
|
||||||
# RSI得分 (20分)
|
score = 0
|
||||||
if 30 <= latest['RSI'] <= 70:
|
latest = df.iloc[-1]
|
||||||
score += 20
|
|
||||||
elif latest['RSI'] < 30: # 超卖
|
# 趋势得分 (30分)
|
||||||
score += 15
|
if latest['MA5'] > latest['MA20']:
|
||||||
|
score += 15
|
||||||
# MACD得分 (20分)
|
if latest['MA20'] > latest['MA60']:
|
||||||
if latest['MACD'] > latest['Signal']:
|
score += 15
|
||||||
score += 20
|
|
||||||
|
# RSI得分 (20分)
|
||||||
# 成交量得分 (30分)
|
if 30 <= latest['RSI'] <= 70:
|
||||||
if latest['Volume_Ratio'] > 1.5:
|
score += 20
|
||||||
score += 30
|
elif latest['RSI'] < 30: # 超卖
|
||||||
elif latest['Volume_Ratio'] > 1:
|
score += 15
|
||||||
score += 15
|
|
||||||
|
# MACD得分 (20分)
|
||||||
return score
|
if latest['MACD'] > latest['Signal']:
|
||||||
|
score += 20
|
||||||
except Exception as e:
|
|
||||||
print(f"计算评分时出错: {str(e)}")
|
# 成交量得分 (30分)
|
||||||
raise
|
if latest['Volume_Ratio'] > 1.5:
|
||||||
|
score += 30
|
||||||
def get_ai_analysis(self, df, stock_code):
|
elif latest['Volume_Ratio'] > 1:
|
||||||
"""使用 Gemini 进行 AI 分析"""
|
score += 15
|
||||||
try:
|
|
||||||
recent_data = df.tail(14).to_dict('records')
|
return score
|
||||||
|
|
||||||
technical_summary = {
|
except Exception as e:
|
||||||
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
|
print(f"计算评分时出错: {str(e)}")
|
||||||
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
|
raise
|
||||||
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
|
|
||||||
'rsi_level': df.iloc[-1]['RSI']
|
def get_ai_analysis(self, df, stock_code, stream=False):
|
||||||
}
|
"""使用 OpenAI 进行 AI 分析"""
|
||||||
|
try:
|
||||||
prompt = f"""
|
logger.info(f"开始AI分析股票 {stock_code}, 流式模式: {stream}")
|
||||||
分析股票 {stock_code}:
|
recent_data = df.tail(14).to_dict('records')
|
||||||
|
|
||||||
技术指标概要:
|
technical_summary = {
|
||||||
{technical_summary}
|
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
|
||||||
|
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
|
||||||
近14日交易数据:
|
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
|
||||||
{recent_data}
|
'rsi_level': df.iloc[-1]['RSI']
|
||||||
|
}
|
||||||
请提供:
|
|
||||||
1. 趋势分析(包含支撑位和压力位)
|
prompt = f"""
|
||||||
2. 成交量分析及其含义
|
分析股票 {stock_code}:
|
||||||
3. 风险评估(包含波动率分析)
|
|
||||||
4. 短期和中期目标价位
|
技术指标概要:
|
||||||
5. 关键技术位分析
|
{technical_summary}
|
||||||
6. 具体交易建议(包含止损位)
|
|
||||||
|
近14日交易数据:
|
||||||
请基于技术指标和市场动态进行分析,给出具体数据支持。
|
{recent_data}
|
||||||
"""
|
|
||||||
|
请提供:
|
||||||
headers = {
|
1. 趋势分析(包含支撑位和压力位)
|
||||||
"Authorization": f"Bearer {self.API_KEY}",
|
2. 成交量分析及其含义
|
||||||
"Content-Type": "application/json"
|
3. 风险评估(包含波动率分析)
|
||||||
}
|
4. 短期和中期目标价位
|
||||||
|
5. 关键技术位分析
|
||||||
data = {
|
6. 具体交易建议(包含止损位)
|
||||||
"model": os.getenv('API_MODEL'),
|
|
||||||
"messages": [{"role": "user", "content": prompt}]
|
请基于技术指标和市场动态进行分析,给出具体数据支持。
|
||||||
}
|
"""
|
||||||
|
|
||||||
if self.API_URL.endswith('/'):
|
logger.debug(f"生成的AI分析提示词: {prompt[:100]}...")
|
||||||
api_url = f"{self.API_URL}chat/completions"
|
|
||||||
else:
|
# 检查API配置
|
||||||
api_url = f"{self.API_URL}/v1/chat/completions"
|
if not self.API_URL:
|
||||||
|
error_msg = "API URL未配置,无法进行AI分析"
|
||||||
response = requests.post(
|
logger.error(error_msg)
|
||||||
api_url,
|
return error_msg if not stream else (yield json.dumps({"error": error_msg}))
|
||||||
headers=headers,
|
|
||||||
json=data,
|
if not self.API_KEY:
|
||||||
timeout=self.API_TIMEOUT
|
error_msg = "API Key未配置,无法进行AI分析"
|
||||||
)
|
logger.error(error_msg)
|
||||||
|
return error_msg if not stream else (yield json.dumps({"error": error_msg}))
|
||||||
print(api_url)
|
|
||||||
print(data)
|
# 标准化API URL
|
||||||
print(response.json())
|
if self.API_URL.endswith('/'):
|
||||||
|
api_url = f"{self.API_URL}chat/completions"
|
||||||
if response.status_code == 200:
|
else:
|
||||||
return response.json()['choices'][0]['message']['content']
|
api_url = f"{self.API_URL}/v1/chat/completions"
|
||||||
else:
|
# 标准化API URL
|
||||||
return "AI 分析暂时无法使用"
|
# api_url = self.API_URL
|
||||||
|
# if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
|
||||||
except Exception as e:
|
# if api_url.endswith('/v1'):
|
||||||
print(f"AI 分析发生错误: {str(e)}")
|
# api_url = f"{api_url}/chat/completions"
|
||||||
return "AI 分析过程中发生错误"
|
# elif api_url.endswith('/'):
|
||||||
|
# api_url = f"{api_url}v1/chat/completions"
|
||||||
def get_recommendation(self, score):
|
# else:
|
||||||
"""根据得分给出建议"""
|
# api_url = f"{api_url}/v1/chat/completions"
|
||||||
if score >= 80:
|
|
||||||
return '强烈推荐买入'
|
logger.debug(f"标准化后的API URL: {api_url}")
|
||||||
elif score >= 60:
|
|
||||||
return '建议买入'
|
# 构建请求头和请求体
|
||||||
elif score >= 40:
|
headers = {
|
||||||
return '观望'
|
"Authorization": f"Bearer {self.API_KEY}",
|
||||||
elif score >= 20:
|
"Content-Type": "application/json"
|
||||||
return '建议卖出'
|
}
|
||||||
else:
|
|
||||||
return '强烈建议卖出'
|
payload = {
|
||||||
|
"model": self.API_MODEL,
|
||||||
def analyze_stock(self, stock_code, market_type='A'):
|
"messages": [{"role": "user", "content": prompt}]
|
||||||
"""分析单个股票"""
|
}
|
||||||
try:
|
|
||||||
# 获取股票数据
|
# 流式处理设置
|
||||||
df = self.get_stock_data(stock_code, market_type)
|
if stream:
|
||||||
|
logger.debug(f"配置流式参数,使用API URL: {api_url}")
|
||||||
# 计算技术指标
|
payload["stream"] = True # 明确设置stream参数为True
|
||||||
df = self.calculate_indicators(df)
|
|
||||||
|
try:
|
||||||
# 评分系统
|
logger.debug(f"发起流式API请求: {api_url}")
|
||||||
score = self.calculate_score(df)
|
logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
|
||||||
|
|
||||||
# 获取最新数据
|
response = requests.post(
|
||||||
latest = df.iloc[-1]
|
api_url,
|
||||||
prev = df.iloc[-2]
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
# 生成报告(保持原有格式)
|
timeout=60, # 增加超时时间
|
||||||
report = {
|
stream=True
|
||||||
'stock_code': stock_code,
|
)
|
||||||
'analysis_date': datetime.now().strftime('%Y-%m-%d'),
|
|
||||||
'score': score,
|
logger.debug(f"API流式响应状态码: {response.status_code}")
|
||||||
'price': latest['close'],
|
|
||||||
'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
|
if response.status_code == 200:
|
||||||
'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
|
logger.info(f"成功获取API流式响应,开始处理")
|
||||||
'rsi': latest['RSI'],
|
yield from self._process_ai_stream(response, stock_code)
|
||||||
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
|
else:
|
||||||
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
|
try:
|
||||||
'recommendation': self.get_recommendation(score),
|
error_response = response.json()
|
||||||
'ai_analysis': self.get_ai_analysis(df, stock_code)
|
error_text = json.dumps(error_response, indent=2)
|
||||||
}
|
except:
|
||||||
|
error_text = response.text[:500] if response.text else "无响应内容"
|
||||||
return report
|
|
||||||
|
error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
|
||||||
except Exception as e:
|
logger.error(error_msg)
|
||||||
print(f"分析股票时出错: {str(e)}")
|
yield json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||||
raise
|
|
||||||
|
except Exception as e:
|
||||||
def scan_market(self, stock_list, min_score=60, market_type='A'):
|
error_msg = f"流式API请求异常: {str(e)}"
|
||||||
"""扫描市场,寻找符合条件的股票"""
|
logger.error(error_msg)
|
||||||
recommendations = []
|
logger.exception(e)
|
||||||
|
yield json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||||
for stock_code in stock_list:
|
else:
|
||||||
try:
|
# 非流式处理
|
||||||
report = self.analyze_stock(stock_code, market_type)
|
logger.debug(f"发起非流式API请求: {api_url}")
|
||||||
if report['score'] >= min_score:
|
|
||||||
recommendations.append(report)
|
try:
|
||||||
except Exception as e:
|
response = requests.post(
|
||||||
print(f"分析股票 {stock_code} 时出错: {str(e)}")
|
api_url,
|
||||||
continue
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
# 按得分排序
|
timeout=60
|
||||||
recommendations.sort(key=lambda x: x['score'], reverse=True)
|
)
|
||||||
return recommendations
|
|
||||||
|
logger.debug(f"API非流式响应状态码: {response.status_code}")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
api_response = response.json()
|
||||||
|
content = api_response['choices'][0]['message']['content']
|
||||||
|
logger.info(f"成功获取AI分析结果,长度: {len(content)}")
|
||||||
|
logger.debug(f"AI分析结果前100字符: {content[:100]}...")
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
error_response = response.json()
|
||||||
|
error_text = json.dumps(error_response, indent=2)
|
||||||
|
except:
|
||||||
|
error_text = response.text[:500] if response.text else "无响应内容"
|
||||||
|
|
||||||
|
error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"非流式API请求异常: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.exception(e)
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"AI 分析过程中发生错误: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.exception(e)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
logger.debug("在流式模式下返回异常信息")
|
||||||
|
error_json = json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||||
|
stream_logger.info(f"流式异常输出: {error_json}")
|
||||||
|
yield error_json
|
||||||
|
else:
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
def _process_ai_stream(self, response, stock_code) -> Generator[str, None, None]:
|
||||||
|
"""处理AI流式响应"""
|
||||||
|
logger.info(f"开始处理股票 {stock_code} 的AI流式响应")
|
||||||
|
buffer = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
line = line.decode('utf-8')
|
||||||
|
stream_logger.info(f"原始流式行: {line}")
|
||||||
|
|
||||||
|
# 跳过保持连接的空行
|
||||||
|
if line.strip() == '':
|
||||||
|
logger.debug("跳过空行")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 数据行通常以"data: "开头
|
||||||
|
if line.startswith('data: '):
|
||||||
|
data_content = line[6:] # 移除 "data: " 前缀
|
||||||
|
stream_logger.info(f"数据内容: {data_content}")
|
||||||
|
|
||||||
|
# 检查是否为流的结束
|
||||||
|
if data_content.strip() == '[DONE]':
|
||||||
|
logger.debug("收到流结束标记 [DONE]")
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
json_data = json.loads(data_content)
|
||||||
|
logger.debug(f"解析的JSON数据: {json.dumps(json_data)[:100]}...")
|
||||||
|
|
||||||
|
if 'choices' in json_data:
|
||||||
|
delta = json_data['choices'][0].get('delta', {})
|
||||||
|
content = delta.get('content', '')
|
||||||
|
|
||||||
|
if content:
|
||||||
|
chunk_count += 1
|
||||||
|
buffer += content
|
||||||
|
logger.debug(f"收到内容片段 #{chunk_count}: {content}")
|
||||||
|
stream_logger.info(f"发送内容片段: {content}")
|
||||||
|
|
||||||
|
# 创建包含AI分析片段的JSON
|
||||||
|
chunk_json = json.dumps({
|
||||||
|
"stock_code": stock_code,
|
||||||
|
"ai_analysis_chunk": content
|
||||||
|
})
|
||||||
|
stream_logger.info(f"流式输出JSON: {chunk_json}")
|
||||||
|
yield chunk_json
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON解析错误: {str(e)}, 行内容: {data_content}")
|
||||||
|
# 忽略无法解析的JSON
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(f"收到非'data:'开头的行: {line}")
|
||||||
|
|
||||||
|
logger.info(f"AI流式处理完成,共收到 {chunk_count} 个内容片段,总长度: {len(buffer)}")
|
||||||
|
|
||||||
|
# 如果buffer不为空,最后一次发送完整内容
|
||||||
|
if buffer and not buffer.endswith('\n'):
|
||||||
|
logger.debug("发送换行符")
|
||||||
|
yield json.dumps({"stock_code": stock_code, "ai_analysis_chunk": "\n"})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"处理AI流式响应时出错: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.exception(e)
|
||||||
|
yield json.dumps({"stock_code": stock_code, "error": error_msg})
|
||||||
|
|
||||||
|
|
||||||
|
def get_recommendation(self, score):
|
||||||
|
"""根据得分给出建议"""
|
||||||
|
logger.debug(f"根据评分 {score} 生成投资建议")
|
||||||
|
if score >= 80:
|
||||||
|
return '强烈推荐买入'
|
||||||
|
elif score >= 60:
|
||||||
|
return '建议买入'
|
||||||
|
elif score >= 40:
|
||||||
|
return '观望'
|
||||||
|
elif score >= 20:
|
||||||
|
return '建议卖出'
|
||||||
|
else:
|
||||||
|
return '强烈建议卖出'
|
||||||
|
|
||||||
|
def analyze_stock(self, stock_code, market_type='A', stream=False):
|
||||||
|
"""分析单个股票"""
|
||||||
|
try:
|
||||||
|
logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}, 流式模式: {stream}")
|
||||||
|
|
||||||
|
# 获取股票数据
|
||||||
|
logger.debug(f"获取股票 {stock_code} 数据")
|
||||||
|
df = self.get_stock_data(stock_code, market_type)
|
||||||
|
|
||||||
|
# 计算技术指标
|
||||||
|
logger.debug(f"计算股票 {stock_code} 技术指标")
|
||||||
|
df = self.calculate_indicators(df)
|
||||||
|
|
||||||
|
# 评分系统
|
||||||
|
logger.debug(f"计算股票 {stock_code} 评分")
|
||||||
|
score = self.calculate_score(df)
|
||||||
|
logger.info(f"股票 {stock_code} 评分结果: {score}")
|
||||||
|
|
||||||
|
# 获取最新数据
|
||||||
|
latest = df.iloc[-1]
|
||||||
|
prev = df.iloc[-2]
|
||||||
|
|
||||||
|
# 生成报告(保持原有格式)
|
||||||
|
report = {
|
||||||
|
'stock_code': stock_code,
|
||||||
|
'analysis_date': datetime.now().strftime('%Y-%m-%d'),
|
||||||
|
'score': score,
|
||||||
|
'price': latest['close'],
|
||||||
|
'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
|
||||||
|
'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
|
||||||
|
'rsi': latest['RSI'],
|
||||||
|
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
|
||||||
|
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
|
||||||
|
'recommendation': self.get_recommendation(score)
|
||||||
|
}
|
||||||
|
logger.debug(f"生成股票 {stock_code} 基础报告: {json.dumps(report)[:100]}...")
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
logger.info(f"以流式模式返回股票 {stock_code} 分析结果")
|
||||||
|
# 先返回基本报告结构
|
||||||
|
base_report = dict(report)
|
||||||
|
base_report['ai_analysis'] = ''
|
||||||
|
base_report_json = json.dumps(base_report)
|
||||||
|
logger.debug(f"基础报告JSON: {base_report_json[:100]}...")
|
||||||
|
stream_logger.info(f"发送基础报告: {base_report_json}")
|
||||||
|
yield base_report_json
|
||||||
|
|
||||||
|
# 然后流式返回AI分析部分
|
||||||
|
logger.debug(f"开始获取股票 {stock_code} 的流式AI分析")
|
||||||
|
ai_chunks_count = 0
|
||||||
|
for ai_chunk in self.get_ai_analysis(df, stock_code, stream=True):
|
||||||
|
ai_chunks_count += 1
|
||||||
|
stream_logger.info(f"股票 {stock_code} 流式块 #{ai_chunks_count}: {ai_chunk}")
|
||||||
|
yield ai_chunk
|
||||||
|
logger.info(f"股票 {stock_code} 流式AI分析完成,共发送 {ai_chunks_count} 个块")
|
||||||
|
else:
|
||||||
|
logger.info(f"以非流式模式返回股票 {stock_code} 分析结果")
|
||||||
|
logger.debug(f"开始获取股票 {stock_code} 的AI分析")
|
||||||
|
report['ai_analysis'] = self.get_ai_analysis(df, stock_code)
|
||||||
|
logger.debug(f"AI分析结果长度: {len(report['ai_analysis'])}")
|
||||||
|
return report
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.exception(e)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
|
||||||
|
stream_logger.info(f"流式错误输出: {error_json}")
|
||||||
|
yield error_json
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def scan_market(self, stock_list, min_score=60, market_type='A', stream=False):
|
||||||
|
"""扫描市场,寻找符合条件的股票"""
|
||||||
|
logger.info(f"开始扫描市场,股票数量: {len(stock_list)}, 最低分数: {min_score}, 市场: {market_type}, 流式模式: {stream}")
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
recommendations = []
|
||||||
|
|
||||||
|
for stock_code in stock_list:
|
||||||
|
try:
|
||||||
|
logger.debug(f"分析股票: {stock_code}")
|
||||||
|
report = self.analyze_stock(stock_code, market_type)
|
||||||
|
if report['score'] >= min_score:
|
||||||
|
logger.info(f"股票 {stock_code} 评分 {report['score']} >= {min_score},添加到推荐列表")
|
||||||
|
recommendations.append(report)
|
||||||
|
else:
|
||||||
|
logger.debug(f"股票 {stock_code} 评分 {report['score']} < {min_score},不添加到推荐列表")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"分析股票 {stock_code} 时出错: {str(e)}")
|
||||||
|
logger.exception(e)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 按得分排序
|
||||||
|
recommendations.sort(key=lambda x: x['score'], reverse=True)
|
||||||
|
logger.info(f"扫描完成,找到 {len(recommendations)} 个推荐股票")
|
||||||
|
return recommendations
|
||||||
|
else:
|
||||||
|
# 流式处理每个股票
|
||||||
|
logger.info(f"开始流式扫描 {len(stock_list)} 只股票")
|
||||||
|
stock_count = 0
|
||||||
|
for stock_code in stock_list:
|
||||||
|
stock_count += 1
|
||||||
|
logger.debug(f"流式分析股票 {stock_code} ({stock_count}/{len(stock_list)})")
|
||||||
|
try:
|
||||||
|
# 分析单只股票并获取流式结果
|
||||||
|
chunk_count = 0
|
||||||
|
for chunk in self.analyze_stock(stock_code, market_type, stream=True):
|
||||||
|
chunk_count += 1
|
||||||
|
stream_logger.info(f"股票 {stock_code} 流式块 #{chunk_count}: {chunk}")
|
||||||
|
yield chunk
|
||||||
|
logger.debug(f"股票 {stock_code} 流式分析完成,共 {chunk_count} 个块")
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.exception(e)
|
||||||
|
error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
|
||||||
|
stream_logger.info(f"流式错误输出: {error_json}")
|
||||||
|
yield error_json
|
||||||
|
logger.info(f"流式扫描完成,处理了 {stock_count} 只股票")
|
||||||
|
|||||||
1323
templates/index.html
1323
templates/index.html
File diff suppressed because it is too large
Load Diff
137
tests/test_stream.py
Normal file
137
tests/test_stream.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
from logger import get_logger, get_stream_logger
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 获取日志器
|
||||||
|
logger = get_logger()
|
||||||
|
stream_logger = get_stream_logger()
|
||||||
|
|
||||||
|
def test_api_stream():
|
||||||
|
"""
|
||||||
|
测试API流式响应功能
|
||||||
|
"""
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 获取API配置
|
||||||
|
api_url = os.getenv('API_URL')
|
||||||
|
api_key = os.getenv('API_KEY')
|
||||||
|
api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||||
|
|
||||||
|
logger.info(f"开始测试API流式响应,API URL: {api_url}, MODEL: {api_model}")
|
||||||
|
|
||||||
|
# 检查API配置
|
||||||
|
if not api_url:
|
||||||
|
logger.error("API URL未配置,无法进行测试")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
logger.error("API Key未配置,无法进行测试")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 标准化API URL
|
||||||
|
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
|
||||||
|
if api_url.endswith('/v1'):
|
||||||
|
api_url = f"{api_url}/chat/completions"
|
||||||
|
elif api_url.endswith('/'):
|
||||||
|
api_url = f"{api_url}v1/chat/completions"
|
||||||
|
else:
|
||||||
|
api_url = f"{api_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
logger.debug(f"标准化后的API URL: {api_url}")
|
||||||
|
|
||||||
|
# 构建简单的测试提示
|
||||||
|
prompt = "这是一个API流式响应测试。请给出一个简短的股票分析样例。"
|
||||||
|
|
||||||
|
# 构建请求头和请求体
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": api_model,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"stream": True # 明确设置stream参数为True
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"发起流式API请求: {api_url}")
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
api_url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=60,
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"API流式响应状态码: {response.status_code}")
|
||||||
|
logger.debug(f"响应头: {response.headers}")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.info("成功获取API流式响应,开始处理")
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
line_str = line.decode('utf-8')
|
||||||
|
logger.info(f"原始流式行: {line_str}")
|
||||||
|
|
||||||
|
# 跳过保持连接的空行
|
||||||
|
if line_str.strip() == '':
|
||||||
|
logger.debug("跳过空行")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 数据行通常以"data: "开头
|
||||||
|
if line_str.startswith('data: '):
|
||||||
|
data_content = line_str[6:].strip() # 移除 "data: " 前缀并去除前后空格
|
||||||
|
logger.info(f"数据内容: {data_content}")
|
||||||
|
|
||||||
|
# 检查是否为流的结束
|
||||||
|
if data_content == '[DONE]':
|
||||||
|
logger.info("收到流结束标记 [DONE]")
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析JSON数据
|
||||||
|
json_data = json.loads(data_content)
|
||||||
|
logger.debug(f"JSON结构: {json.dumps(json_data, indent=2)}")
|
||||||
|
|
||||||
|
if 'choices' in json_data:
|
||||||
|
delta = json_data['choices'][0].get('delta', {})
|
||||||
|
content = delta.get('content', '')
|
||||||
|
|
||||||
|
if content:
|
||||||
|
chunk_count += 1
|
||||||
|
buffer += content
|
||||||
|
logger.info(f"内容片段 #{chunk_count}: {content}")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"JSON解析错误: {e}, 内容: {data_content}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"收到非'data:'开头的行: {line_str}")
|
||||||
|
|
||||||
|
logger.info(f"流式处理完成,共收到 {chunk_count} 个内容片段")
|
||||||
|
logger.info(f"完整内容:\n{buffer}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
error_response = response.json()
|
||||||
|
error_text = json.dumps(error_response, indent=2)
|
||||||
|
except:
|
||||||
|
error_text = response.text[:500] if response.text else "无响应内容"
|
||||||
|
|
||||||
|
logger.error(f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"测试过程中发生异常: {str(e)}")
|
||||||
|
logger.exception(e)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_api_stream()
|
||||||
158
web_server.py
158
web_server.py
@@ -1,8 +1,15 @@
|
|||||||
from flask import Flask, render_template, request, jsonify
|
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
|
||||||
from stock_analyzer import StockAnalyzer
|
from stock_analyzer import StockAnalyzer
|
||||||
from us_stock_service import USStockService
|
from us_stock_service import USStockService
|
||||||
import threading
|
import threading
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
|
import requests
|
||||||
|
from logger import get_logger, get_stream_logger
|
||||||
|
|
||||||
|
# 获取日志器
|
||||||
|
logger = get_logger()
|
||||||
|
stream_logger = get_stream_logger()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
analyzer = StockAnalyzer()
|
analyzer = StockAnalyzer()
|
||||||
@@ -11,27 +18,86 @@ us_stock_service = USStockService()
|
|||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
|
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
|
||||||
return render_template('index.html', announcement=announcement)
|
# 获取默认API配置信息
|
||||||
|
default_api_url = os.getenv('API_URL', '')
|
||||||
|
default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||||
|
# 不传递API_KEY到前端,出于安全考虑
|
||||||
|
return render_template('index.html',
|
||||||
|
announcement=announcement,
|
||||||
|
default_api_url=default_api_url,
|
||||||
|
default_api_model=default_api_model)
|
||||||
|
|
||||||
@app.route('/analyze', methods=['POST'])
|
@app.route('/analyze', methods=['POST'])
|
||||||
def analyze():
|
def analyze():
|
||||||
try:
|
try:
|
||||||
|
logger.info("开始处理分析请求")
|
||||||
data = request.json
|
data = request.json
|
||||||
stock_codes = data.get('stock_codes', [])
|
stock_codes = data.get('stock_codes', [])
|
||||||
market_type = data.get('market_type', 'A')
|
market_type = data.get('market_type', 'A')
|
||||||
|
|
||||||
|
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
|
||||||
|
|
||||||
|
# 获取自定义API配置
|
||||||
|
custom_api_url = data.get('api_url')
|
||||||
|
custom_api_key = data.get('api_key')
|
||||||
|
custom_api_model = data.get('api_model')
|
||||||
|
custom_api_timeout = data.get('api_timeout', 60)
|
||||||
|
|
||||||
|
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}")
|
||||||
|
|
||||||
|
# 创建新的分析器实例,使用自定义配置
|
||||||
|
custom_analyzer = StockAnalyzer(
|
||||||
|
custom_api_url=custom_api_url,
|
||||||
|
custom_api_key=custom_api_key,
|
||||||
|
custom_api_model=custom_api_model,
|
||||||
|
custom_api_timeout= custom_api_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
if not stock_codes:
|
if not stock_codes:
|
||||||
|
logger.warning("未提供股票代码")
|
||||||
return jsonify({'error': '请输入代码'}), 400
|
return jsonify({'error': '请输入代码'}), 400
|
||||||
|
|
||||||
|
# 使用流式响应
|
||||||
|
def generate():
|
||||||
|
if len(stock_codes) == 1:
|
||||||
|
# 单个股票分析流式处理
|
||||||
|
stock_code = stock_codes[0].strip()
|
||||||
|
logger.info(f"开始单股流式分析: {stock_code}")
|
||||||
|
|
||||||
|
stream_logger.info(f"初始化单股分析流: {stock_code}")
|
||||||
|
init_message = f'{{"stream_type": "single", "stock_code": "{stock_code}"}}\n'
|
||||||
|
stream_logger.info(f"发送初始化消息: {init_message}")
|
||||||
|
yield init_message
|
||||||
|
|
||||||
|
for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
|
||||||
|
stream_logger.info(f"流式输出块: {chunk}")
|
||||||
|
yield chunk + '\n'
|
||||||
|
else:
|
||||||
|
# 批量分析流式处理
|
||||||
|
logger.info(f"开始批量流式分析: {stock_codes}")
|
||||||
|
|
||||||
|
stream_logger.info(f"初始化批量分析流: {stock_codes}")
|
||||||
|
init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes}}}\n'
|
||||||
|
stream_logger.info(f"发送初始化消息: {init_message}")
|
||||||
|
yield init_message
|
||||||
|
|
||||||
|
for chunk in custom_analyzer.scan_market(
|
||||||
|
[code.strip() for code in stock_codes],
|
||||||
|
min_score=0,
|
||||||
|
market_type=market_type,
|
||||||
|
stream=True
|
||||||
|
):
|
||||||
|
stream_logger.info(f"流式输出块: {chunk}")
|
||||||
|
yield chunk + '\n'
|
||||||
|
|
||||||
|
logger.info("成功创建流式响应生成器")
|
||||||
|
return Response(stream_with_context(generate()), mimetype='application/json')
|
||||||
|
|
||||||
results = []
|
|
||||||
for stock_code in stock_codes:
|
|
||||||
result = analyzer.analyze_stock(stock_code.strip(), market_type)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return jsonify({'results': results})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"分析股票时出错: {str(e)}")
|
error_msg = f"分析股票时出错: {str(e)}"
|
||||||
return jsonify({'error': str(e)}), 500
|
logger.error(error_msg)
|
||||||
|
logger.exception(e)
|
||||||
|
return jsonify({'error': error_msg}), 500
|
||||||
|
|
||||||
@app.route('/search_us_stocks', methods=['GET'])
|
@app.route('/search_us_stocks', methods=['GET'])
|
||||||
def search_us_stocks():
|
def search_us_stocks():
|
||||||
@@ -47,8 +113,72 @@ def search_us_stocks():
|
|||||||
print(f"搜索美股代码时出错: {str(e)}")
|
print(f"搜索美股代码时出错: {str(e)}")
|
||||||
return jsonify({'error': str(e)}), 500
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/test_api_connection', methods=['POST'])
|
||||||
|
def test_api_connection():
|
||||||
|
"""测试API连接"""
|
||||||
|
try:
|
||||||
|
logger.info("开始测试API连接")
|
||||||
|
data = request.json
|
||||||
|
api_url = data.get('api_url')
|
||||||
|
api_key = data.get('api_key')
|
||||||
|
api_model = data.get('api_model')
|
||||||
|
|
||||||
|
logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}")
|
||||||
|
|
||||||
|
if not api_url:
|
||||||
|
logger.warning("未提供API URL")
|
||||||
|
return jsonify({'error': '请提供API URL'}), 400
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
logger.warning("未提供API Key")
|
||||||
|
return jsonify({'error': '请提供API Key'}), 400
|
||||||
|
|
||||||
|
# 构建API URL
|
||||||
|
test_url = api_url
|
||||||
|
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
|
||||||
|
if api_url.endswith('/v1'):
|
||||||
|
test_url = f"{api_url}/chat/completions"
|
||||||
|
elif api_url.endswith('/'):
|
||||||
|
test_url = f"{api_url}chat/completions"
|
||||||
|
else:
|
||||||
|
test_url = f"{api_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
logger.debug(f"完整API测试URL: {test_url}")
|
||||||
|
|
||||||
|
# 发送测试请求
|
||||||
|
response = requests.post(
|
||||||
|
test_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": api_model or "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
|
||||||
|
],
|
||||||
|
"max_tokens": 20
|
||||||
|
},
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查响应
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.info(f"API连接测试成功: {response.status_code}")
|
||||||
|
return jsonify({'success': True, 'message': '连接成功'})
|
||||||
|
else:
|
||||||
|
error_message = response.json().get('error', {}).get('message', '未知错误')
|
||||||
|
logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
|
||||||
|
return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"API连接请求错误: {str(e)}")
|
||||||
|
return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"测试API连接时出错: {str(e)}")
|
||||||
|
logger.exception(e)
|
||||||
|
return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(host='0.0.0.0', port=8888, debug=True)
|
logger.info("股票分析系统启动")
|
||||||
|
app.run(host='0.0.0.0', port=8888, debug=True)
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user