diff --git a/Dockerfile b/Dockerfile
index ed99415..1c331fb 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-# 使用 Python 3.9 作为基础镜像
+# 使用 Python 3.10 作为基础镜像
FROM python:3.10-slim
# 设置工作目录
@@ -15,7 +15,6 @@ COPY . /app/
# 安装 Python 依赖
RUN pip install --no-cache-dir -r requirements.txt
-RUN pip install akshare --upgrade -i https://pypi.org/simple
# 设置环境变量
ENV PYTHONPATH=/app
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000..8313c52
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,58 @@
+from loguru import logger
+import sys
+import os
+from datetime import datetime
+
+# 获取当前时间作为日志文件名的一部分
+current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+# 创建日志目录
+log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
+os.makedirs(log_dir, exist_ok=True)
+
+# 配置日志
+logger.remove() # 移除默认的处理器
+
+# 添加标准输出处理器(控制台)
+logger.add(
+ sys.stdout,
+ format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
+ level="DEBUG"
+)
+
+# 添加文件处理器(debug级别)
+logger.add(
+ os.path.join(log_dir, f"debug_{current_time}.log"),
+ format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
+ level="DEBUG",
+ rotation="100 MB",
+ retention="1 week"
+)
+
+# 添加文件处理器(error级别)
+logger.add(
+ os.path.join(log_dir, f"error_{current_time}.log"),
+ format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
+ level="ERROR",
+ rotation="100 MB",
+ retention="1 month"
+)
+
+# 添加流处理器(用于记录流式输出)
+logger.add(
+ os.path.join(log_dir, f"stream_{current_time}.log"),
+ format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {message}",
+ filter=lambda record: "STREAM" in record["extra"],
+ level="INFO"
+)
+
+# 创建专用于流式输出的日志器
+stream_logger = logger.bind(STREAM=True)
+
+def get_logger():
+ """获取通用日志器"""
+ return logger
+
+def get_stream_logger():
+ """获取流式输出专用日志器"""
+ return stream_logger
diff --git a/requirements.txt b/requirements.txt
index 55a489d..7799faf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,12 @@
+--index-url https://pypi.tuna.tsinghua.edu.cn/simple
+
# 基础科学计算和数据处理库
numpy==2.1.2
pandas==2.2.2
scipy==1.15.1
# 数据获取和分析库
-akshare
+akshare==1.16.22
tqdm==4.67.1
diff --git a/stock_analyzer.py b/stock_analyzer.py
index 066185b..ea35826 100644
--- a/stock_analyzer.py
+++ b/stock_analyzer.py
@@ -1,340 +1,583 @@
-import pandas as pd
-import numpy as np
-from datetime import datetime, timedelta
-import os
-import requests
-from typing import Dict, List, Optional, Tuple
-from dotenv import load_dotenv
-
-class StockAnalyzer:
- def __init__(self, initial_cash=1000000):
-
- # 加载环境变量
- load_dotenv()
-
- # 设置 API
- self.API_URL = os.getenv('API_URL')
- self.API_KEY = os.getenv('API_KEY')
- self.API_TIMEOUT = int(os.getenv('API_TIMEOUT', '60'))
-
- # 配置参数
- self.params = {
- 'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
- 'rsi_period': 14,
- 'bollinger_period': 20,
- 'bollinger_std': 2,
- 'volume_ma_period': 20,
- 'atr_period': 14
- }
-
-
- def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
- """获取股票数据"""
- import akshare as ak
-
- if start_date is None:
- start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
- if end_date is None:
- end_date = datetime.now().strftime('%Y%m%d')
-
- try:
- # 根据市场类型获取数据
- if market_type == 'A':
- df = ak.stock_zh_a_hist(
- symbol=stock_code,
- start_date=start_date,
- end_date=end_date,
- adjust="qfq"
- )
- # A股数据列名映射
- elif market_type == 'HK':
- df = ak.stock_hk_daily(
- symbol=stock_code,
- adjust="qfq"
- )
- elif market_type == 'US':
- df = ak.stock_us_hist(
- symbol=stock_code,
- start_date=start_date,
- end_date=end_date,
- adjust="qfq"
- )
- # elif market_type == 'CRYPTO':
- # df = ak.crypto_js_spot(
- # symbol=stock_code
- # )
- else:
- raise ValueError(f"不支持的市场类型: {market_type}")
-
- # 重命名列名以匹配分析需求
- df = df.rename(columns={
- "日期": "date",
- "开盘": "open",
- "收盘": "close",
- "最高": "high",
- "最低": "low",
- "成交量": "volume"
- })
-
- # 确保日期格式正确
- df['date'] = pd.to_datetime(df['date'])
-
- # 数据类型转换
- numeric_columns = ['open', 'close', 'high', 'low', 'volume']
- df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
-
- # 删除空值
- df = df.dropna()
-
- return df.sort_values('date')
-
- except Exception as e:
- raise Exception(f"获取股票数据失败: {str(e)}")
-
- def calculate_ema(self, series, period):
- """计算指数移动平均线"""
- return series.ewm(span=period, adjust=False).mean()
-
- def calculate_rsi(self, series, period):
- """计算RSI指标"""
- delta = series.diff()
- gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
- loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
- rs = gain / loss
- return 100 - (100 / (1 + rs))
-
- def calculate_macd(self, series):
- """计算MACD指标"""
- exp1 = series.ewm(span=12, adjust=False).mean()
- exp2 = series.ewm(span=26, adjust=False).mean()
- macd = exp1 - exp2
- signal = macd.ewm(span=9, adjust=False).mean()
- hist = macd - signal
- return macd, signal, hist
-
- def calculate_bollinger_bands(self, series, period, std_dev):
- """计算布林带"""
- middle = series.rolling(window=period).mean()
- std = series.rolling(window=period).std()
- upper = middle + (std * std_dev)
- lower = middle - (std * std_dev)
- return upper, middle, lower
-
- def calculate_atr(self, df, period):
- """计算ATR指标"""
- high = df['high']
- low = df['low']
- close = df['close'].shift(1)
-
- tr1 = high - low
- tr2 = abs(high - close)
- tr3 = abs(low - close)
-
- tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
- return tr.rolling(window=period).mean()
-
- def calculate_indicators(self, df):
- """计算技术指标"""
- try:
- # 计算移动平均线
- df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
- df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
- df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
-
- # 计算RSI
- df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
-
- # 计算MACD
- df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
-
- # 计算布林带
- df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
- df['close'],
- self.params['bollinger_period'],
- self.params['bollinger_std']
- )
-
- # 成交量分析
- df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
- df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
-
- # 计算ATR和波动率
- df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
- df['Volatility'] = df['ATR'] / df['close'] * 100
-
- # 动量指标
- df['ROC'] = df['close'].pct_change(periods=10) * 100
-
- return df
-
- except Exception as e:
- print(f"计算技术指标时出错: {str(e)}")
- raise
-
- def calculate_score(self, df):
- """计算股票评分"""
- try:
- score = 0
- latest = df.iloc[-1]
-
- # 趋势得分 (30分)
- if latest['MA5'] > latest['MA20']:
- score += 15
- if latest['MA20'] > latest['MA60']:
- score += 15
-
- # RSI得分 (20分)
- if 30 <= latest['RSI'] <= 70:
- score += 20
- elif latest['RSI'] < 30: # 超卖
- score += 15
-
- # MACD得分 (20分)
- if latest['MACD'] > latest['Signal']:
- score += 20
-
- # 成交量得分 (30分)
- if latest['Volume_Ratio'] > 1.5:
- score += 30
- elif latest['Volume_Ratio'] > 1:
- score += 15
-
- return score
-
- except Exception as e:
- print(f"计算评分时出错: {str(e)}")
- raise
-
- def get_ai_analysis(self, df, stock_code):
- """使用 Gemini 进行 AI 分析"""
- try:
- recent_data = df.tail(14).to_dict('records')
-
- technical_summary = {
- 'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
- 'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
- 'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
- 'rsi_level': df.iloc[-1]['RSI']
- }
-
- prompt = f"""
- 分析股票 {stock_code}:
-
- 技术指标概要:
- {technical_summary}
-
- 近14日交易数据:
- {recent_data}
-
- 请提供:
- 1. 趋势分析(包含支撑位和压力位)
- 2. 成交量分析及其含义
- 3. 风险评估(包含波动率分析)
- 4. 短期和中期目标价位
- 5. 关键技术位分析
- 6. 具体交易建议(包含止损位)
-
- 请基于技术指标和市场动态进行分析,给出具体数据支持。
- """
-
- headers = {
- "Authorization": f"Bearer {self.API_KEY}",
- "Content-Type": "application/json"
- }
-
- data = {
- "model": os.getenv('API_MODEL'),
- "messages": [{"role": "user", "content": prompt}]
- }
-
- if self.API_URL.endswith('/'):
- api_url = f"{self.API_URL}chat/completions"
- else:
- api_url = f"{self.API_URL}/v1/chat/completions"
-
- response = requests.post(
- api_url,
- headers=headers,
- json=data,
- timeout=self.API_TIMEOUT
- )
-
- print(api_url)
- print(data)
- print(response.json())
-
- if response.status_code == 200:
- return response.json()['choices'][0]['message']['content']
- else:
- return "AI 分析暂时无法使用"
-
- except Exception as e:
- print(f"AI 分析发生错误: {str(e)}")
- return "AI 分析过程中发生错误"
-
- def get_recommendation(self, score):
- """根据得分给出建议"""
- if score >= 80:
- return '强烈推荐买入'
- elif score >= 60:
- return '建议买入'
- elif score >= 40:
- return '观望'
- elif score >= 20:
- return '建议卖出'
- else:
- return '强烈建议卖出'
-
- def analyze_stock(self, stock_code, market_type='A'):
- """分析单个股票"""
- try:
- # 获取股票数据
- df = self.get_stock_data(stock_code, market_type)
-
- # 计算技术指标
- df = self.calculate_indicators(df)
-
- # 评分系统
- score = self.calculate_score(df)
-
- # 获取最新数据
- latest = df.iloc[-1]
- prev = df.iloc[-2]
-
- # 生成报告(保持原有格式)
- report = {
- 'stock_code': stock_code,
- 'analysis_date': datetime.now().strftime('%Y-%m-%d'),
- 'score': score,
- 'price': latest['close'],
- 'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
- 'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
- 'rsi': latest['RSI'],
- 'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
- 'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
- 'recommendation': self.get_recommendation(score),
- 'ai_analysis': self.get_ai_analysis(df, stock_code)
- }
-
- return report
-
- except Exception as e:
- print(f"分析股票时出错: {str(e)}")
- raise
-
- def scan_market(self, stock_list, min_score=60, market_type='A'):
- """扫描市场,寻找符合条件的股票"""
- recommendations = []
-
- for stock_code in stock_list:
- try:
- report = self.analyze_stock(stock_code, market_type)
- if report['score'] >= min_score:
- recommendations.append(report)
- except Exception as e:
- print(f"分析股票 {stock_code} 时出错: {str(e)}")
- continue
-
- # 按得分排序
- recommendations.sort(key=lambda x: x['score'], reverse=True)
- return recommendations
+import pandas as pd
+import numpy as np
+from datetime import datetime, timedelta
+import os
+import requests
+from typing import Dict, List, Optional, Tuple, Generator
+from dotenv import load_dotenv
+import json
+from logger import get_logger, get_stream_logger
+
+# 获取日志器
+logger = get_logger()
+stream_logger = get_stream_logger()
+
+class StockAnalyzer:
+ def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=60):
+
+ # 加载环境变量
+ load_dotenv()
+
+ # 设置 API 配置,优先使用自定义配置,否则使用环境变量
+ self.API_URL = custom_api_url or os.getenv('API_URL')
+ self.API_KEY = custom_api_key or os.getenv('API_KEY')
+ self.API_TIMEOUT = custom_api_timeout or int(os.getenv('API_TIMEOUT', '60'))
+ self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
+
+ logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}")
+
+ # 配置参数
+ self.params = {
+ 'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
+ 'rsi_period': 14,
+ 'bollinger_period': 20,
+ 'bollinger_std': 2,
+ 'volume_ma_period': 20,
+ 'atr_period': 14
+ }
+
+
+ def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
+ """获取股票数据"""
+ import akshare as ak
+
+ if start_date is None:
+ start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
+ if end_date is None:
+ end_date = datetime.now().strftime('%Y%m%d')
+
+ try:
+ # 根据市场类型获取数据
+ if market_type == 'A':
+ df = ak.stock_zh_a_hist(
+ symbol=stock_code,
+ start_date=start_date,
+ end_date=end_date,
+ adjust="qfq"
+ )
+ # A股数据列名映射
+ elif market_type == 'HK':
+ df = ak.stock_hk_daily(
+ symbol=stock_code,
+ adjust="qfq"
+ )
+ elif market_type == 'US':
+ df = ak.stock_us_hist(
+ symbol=stock_code,
+ start_date=start_date,
+ end_date=end_date,
+ adjust="qfq"
+ )
+ # elif market_type == 'CRYPTO':
+ # df = ak.crypto_js_spot(
+ # symbol=stock_code
+ # )
+ else:
+ raise ValueError(f"不支持的市场类型: {market_type}")
+
+ # 重命名列名以匹配分析需求
+ df = df.rename(columns={
+ "日期": "date",
+ "开盘": "open",
+ "收盘": "close",
+ "最高": "high",
+ "最低": "low",
+ "成交量": "volume"
+ })
+
+ # 确保日期格式正确
+ df['date'] = pd.to_datetime(df['date'])
+
+ # 数据类型转换
+ numeric_columns = ['open', 'close', 'high', 'low', 'volume']
+ df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
+
+ # 删除空值
+ df = df.dropna()
+
+ return df.sort_values('date')
+
+ except Exception as e:
+ raise Exception(f"获取股票数据失败: {str(e)}")
+
+ def calculate_ema(self, series, period):
+ """计算指数移动平均线"""
+ return series.ewm(span=period, adjust=False).mean()
+
+ def calculate_rsi(self, series, period):
+ """计算RSI指标"""
+ delta = series.diff()
+ gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
+ loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
+ rs = gain / loss
+ return 100 - (100 / (1 + rs))
+
+ def calculate_macd(self, series):
+ """计算MACD指标"""
+ exp1 = series.ewm(span=12, adjust=False).mean()
+ exp2 = series.ewm(span=26, adjust=False).mean()
+ macd = exp1 - exp2
+ signal = macd.ewm(span=9, adjust=False).mean()
+ hist = macd - signal
+ return macd, signal, hist
+
+ def calculate_bollinger_bands(self, series, period, std_dev):
+ """计算布林带"""
+ middle = series.rolling(window=period).mean()
+ std = series.rolling(window=period).std()
+ upper = middle + (std * std_dev)
+ lower = middle - (std * std_dev)
+ return upper, middle, lower
+
+ def calculate_atr(self, df, period):
+ """计算ATR指标"""
+ high = df['high']
+ low = df['low']
+ close = df['close'].shift(1)
+
+ tr1 = high - low
+ tr2 = abs(high - close)
+ tr3 = abs(low - close)
+
+ tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
+ return tr.rolling(window=period).mean()
+
+ def calculate_indicators(self, df):
+ """计算技术指标"""
+ try:
+ # 计算移动平均线
+ df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
+ df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
+ df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
+
+ # 计算RSI
+ df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
+
+ # 计算MACD
+ df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
+
+ # 计算布林带
+ df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
+ df['close'],
+ self.params['bollinger_period'],
+ self.params['bollinger_std']
+ )
+
+ # 成交量分析
+ df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
+ df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
+
+ # 计算ATR和波动率
+ df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
+ df['Volatility'] = df['ATR'] / df['close'] * 100
+
+ # 动量指标
+ df['ROC'] = df['close'].pct_change(periods=10) * 100
+
+ return df
+
+ except Exception as e:
+ print(f"计算技术指标时出错: {str(e)}")
+ raise
+
+ def calculate_score(self, df):
+ """计算股票评分"""
+ try:
+ score = 0
+ latest = df.iloc[-1]
+
+ # 趋势得分 (30分)
+ if latest['MA5'] > latest['MA20']:
+ score += 15
+ if latest['MA20'] > latest['MA60']:
+ score += 15
+
+ # RSI得分 (20分)
+ if 30 <= latest['RSI'] <= 70:
+ score += 20
+ elif latest['RSI'] < 30: # 超卖
+ score += 15
+
+ # MACD得分 (20分)
+ if latest['MACD'] > latest['Signal']:
+ score += 20
+
+ # 成交量得分 (30分)
+ if latest['Volume_Ratio'] > 1.5:
+ score += 30
+ elif latest['Volume_Ratio'] > 1:
+ score += 15
+
+ return score
+
+ except Exception as e:
+ print(f"计算评分时出错: {str(e)}")
+ raise
+
+ def get_ai_analysis(self, df, stock_code, stream=False):
+ """使用 OpenAI 进行 AI 分析"""
+ try:
+ logger.info(f"开始AI分析股票 {stock_code}, 流式模式: {stream}")
+ recent_data = df.tail(14).to_dict('records')
+
+ technical_summary = {
+ 'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
+ 'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
+ 'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
+ 'rsi_level': df.iloc[-1]['RSI']
+ }
+
+ prompt = f"""
+ 分析股票 {stock_code}:
+
+ 技术指标概要:
+ {technical_summary}
+
+ 近14日交易数据:
+ {recent_data}
+
+ 请提供:
+ 1. 趋势分析(包含支撑位和压力位)
+ 2. 成交量分析及其含义
+ 3. 风险评估(包含波动率分析)
+ 4. 短期和中期目标价位
+ 5. 关键技术位分析
+ 6. 具体交易建议(包含止损位)
+
+ 请基于技术指标和市场动态进行分析,给出具体数据支持。
+ """
+
+ logger.debug(f"生成的AI分析提示词: {prompt[:100]}...")
+
+ # 检查API配置
+ if not self.API_URL:
+ error_msg = "API URL未配置,无法进行AI分析"
+ logger.error(error_msg)
+ return error_msg if not stream else (yield json.dumps({"error": error_msg}))
+
+ if not self.API_KEY:
+ error_msg = "API Key未配置,无法进行AI分析"
+ logger.error(error_msg)
+ return error_msg if not stream else (yield json.dumps({"error": error_msg}))
+
+ # 标准化API URL
+ if self.API_URL.endswith('/'):
+ api_url = f"{self.API_URL}chat/completions"
+ else:
+ api_url = f"{self.API_URL}/v1/chat/completions"
+ # 标准化API URL
+ # api_url = self.API_URL
+ # if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
+ # if api_url.endswith('/v1'):
+ # api_url = f"{api_url}/chat/completions"
+ # elif api_url.endswith('/'):
+ # api_url = f"{api_url}v1/chat/completions"
+ # else:
+ # api_url = f"{api_url}/v1/chat/completions"
+
+ logger.debug(f"标准化后的API URL: {api_url}")
+
+ # 构建请求头和请求体
+ headers = {
+ "Authorization": f"Bearer {self.API_KEY}",
+ "Content-Type": "application/json"
+ }
+
+ payload = {
+ "model": self.API_MODEL,
+ "messages": [{"role": "user", "content": prompt}]
+ }
+
+ # 流式处理设置
+ if stream:
+ logger.debug(f"配置流式参数,使用API URL: {api_url}")
+ payload["stream"] = True # 明确设置stream参数为True
+
+ try:
+ logger.debug(f"发起流式API请求: {api_url}")
+ logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
+
+ response = requests.post(
+ api_url,
+ headers=headers,
+ json=payload,
+ timeout=60, # 增加超时时间
+ stream=True
+ )
+
+ logger.debug(f"API流式响应状态码: {response.status_code}")
+
+ if response.status_code == 200:
+ logger.info(f"成功获取API流式响应,开始处理")
+ yield from self._process_ai_stream(response, stock_code)
+ else:
+ try:
+ error_response = response.json()
+ error_text = json.dumps(error_response, indent=2)
+ except:
+ error_text = response.text[:500] if response.text else "无响应内容"
+
+ error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
+ logger.error(error_msg)
+ yield json.dumps({"stock_code": stock_code, "error": error_msg})
+
+ except Exception as e:
+ error_msg = f"流式API请求异常: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+ yield json.dumps({"stock_code": stock_code, "error": error_msg})
+ else:
+ # 非流式处理
+ logger.debug(f"发起非流式API请求: {api_url}")
+
+ try:
+ response = requests.post(
+ api_url,
+ headers=headers,
+ json=payload,
+ timeout=60
+ )
+
+ logger.debug(f"API非流式响应状态码: {response.status_code}")
+
+ if response.status_code == 200:
+ api_response = response.json()
+ content = api_response['choices'][0]['message']['content']
+ logger.info(f"成功获取AI分析结果,长度: {len(content)}")
+ logger.debug(f"AI分析结果前100字符: {content[:100]}...")
+ return content
+ else:
+ try:
+ error_response = response.json()
+ error_text = json.dumps(error_response, indent=2)
+ except:
+ error_text = response.text[:500] if response.text else "无响应内容"
+
+ error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
+ logger.error(error_msg)
+ return error_msg
+
+ except Exception as e:
+ error_msg = f"非流式API请求异常: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+ return error_msg
+
+ except Exception as e:
+ error_msg = f"AI 分析过程中发生错误: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+
+ if stream:
+ logger.debug("在流式模式下返回异常信息")
+ error_json = json.dumps({"stock_code": stock_code, "error": error_msg})
+ stream_logger.info(f"流式异常输出: {error_json}")
+ yield error_json
+ else:
+ return error_msg
+
+ def _process_ai_stream(self, response, stock_code) -> Generator[str, None, None]:
+ """处理AI流式响应"""
+ logger.info(f"开始处理股票 {stock_code} 的AI流式响应")
+ buffer = ""
+ chunk_count = 0
+
+ try:
+ for line in response.iter_lines():
+ if line:
+ line = line.decode('utf-8')
+ stream_logger.info(f"原始流式行: {line}")
+
+ # 跳过保持连接的空行
+ if line.strip() == '':
+ logger.debug("跳过空行")
+ continue
+
+ # 数据行通常以"data: "开头
+ if line.startswith('data: '):
+ data_content = line[6:] # 移除 "data: " 前缀
+ stream_logger.info(f"数据内容: {data_content}")
+
+ # 检查是否为流的结束
+ if data_content.strip() == '[DONE]':
+ logger.debug("收到流结束标记 [DONE]")
+ break
+
+ try:
+ json_data = json.loads(data_content)
+ logger.debug(f"解析的JSON数据: {json.dumps(json_data)[:100]}...")
+
+ if 'choices' in json_data:
+ delta = json_data['choices'][0].get('delta', {})
+ content = delta.get('content', '')
+
+ if content:
+ chunk_count += 1
+ buffer += content
+ logger.debug(f"收到内容片段 #{chunk_count}: {content}")
+ stream_logger.info(f"发送内容片段: {content}")
+
+ # 创建包含AI分析片段的JSON
+ chunk_json = json.dumps({
+ "stock_code": stock_code,
+ "ai_analysis_chunk": content
+ })
+ stream_logger.info(f"流式输出JSON: {chunk_json}")
+ yield chunk_json
+ except json.JSONDecodeError as e:
+ logger.error(f"JSON解析错误: {str(e)}, 行内容: {data_content}")
+ # 忽略无法解析的JSON
+ pass
+ else:
+ logger.warning(f"收到非'data:'开头的行: {line}")
+
+ logger.info(f"AI流式处理完成,共收到 {chunk_count} 个内容片段,总长度: {len(buffer)}")
+
+ # 如果buffer不为空,最后一次发送完整内容
+ if buffer and not buffer.endswith('\n'):
+ logger.debug("发送换行符")
+ yield json.dumps({"stock_code": stock_code, "ai_analysis_chunk": "\n"})
+
+ except Exception as e:
+ error_msg = f"处理AI流式响应时出错: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+ yield json.dumps({"stock_code": stock_code, "error": error_msg})
+
+
+ def get_recommendation(self, score):
+ """根据得分给出建议"""
+ logger.debug(f"根据评分 {score} 生成投资建议")
+ if score >= 80:
+ return '强烈推荐买入'
+ elif score >= 60:
+ return '建议买入'
+ elif score >= 40:
+ return '观望'
+ elif score >= 20:
+ return '建议卖出'
+ else:
+ return '强烈建议卖出'
+
+ def analyze_stock(self, stock_code, market_type='A', stream=False):
+ """分析单个股票"""
+ try:
+ logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}, 流式模式: {stream}")
+
+ # 获取股票数据
+ logger.debug(f"获取股票 {stock_code} 数据")
+ df = self.get_stock_data(stock_code, market_type)
+
+ # 计算技术指标
+ logger.debug(f"计算股票 {stock_code} 技术指标")
+ df = self.calculate_indicators(df)
+
+ # 评分系统
+ logger.debug(f"计算股票 {stock_code} 评分")
+ score = self.calculate_score(df)
+ logger.info(f"股票 {stock_code} 评分结果: {score}")
+
+ # 获取最新数据
+ latest = df.iloc[-1]
+ prev = df.iloc[-2]
+
+ # 生成报告(保持原有格式)
+ report = {
+ 'stock_code': stock_code,
+ 'analysis_date': datetime.now().strftime('%Y-%m-%d'),
+ 'score': score,
+ 'price': latest['close'],
+ 'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
+ 'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
+ 'rsi': latest['RSI'],
+ 'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
+ 'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
+ 'recommendation': self.get_recommendation(score)
+ }
+ logger.debug(f"生成股票 {stock_code} 基础报告: {json.dumps(report)[:100]}...")
+
+ if stream:
+ logger.info(f"以流式模式返回股票 {stock_code} 分析结果")
+ # 先返回基本报告结构
+ base_report = dict(report)
+ base_report['ai_analysis'] = ''
+ base_report_json = json.dumps(base_report)
+ logger.debug(f"基础报告JSON: {base_report_json[:100]}...")
+ stream_logger.info(f"发送基础报告: {base_report_json}")
+ yield base_report_json
+
+ # 然后流式返回AI分析部分
+ logger.debug(f"开始获取股票 {stock_code} 的流式AI分析")
+ ai_chunks_count = 0
+ for ai_chunk in self.get_ai_analysis(df, stock_code, stream=True):
+ ai_chunks_count += 1
+ stream_logger.info(f"股票 {stock_code} 流式块 #{ai_chunks_count}: {ai_chunk}")
+ yield ai_chunk
+ logger.info(f"股票 {stock_code} 流式AI分析完成,共发送 {ai_chunks_count} 个块")
+ else:
+ logger.info(f"以非流式模式返回股票 {stock_code} 分析结果")
+ logger.debug(f"开始获取股票 {stock_code} 的AI分析")
+ report['ai_analysis'] = self.get_ai_analysis(df, stock_code)
+ logger.debug(f"AI分析结果长度: {len(report['ai_analysis'])}")
+ return report
+
+ except Exception as e:
+ error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+
+ if stream:
+ error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
+ stream_logger.info(f"流式错误输出: {error_json}")
+ yield error_json
+ else:
+ raise
+
+ def scan_market(self, stock_list, min_score=60, market_type='A', stream=False):
+ """扫描市场,寻找符合条件的股票"""
+ logger.info(f"开始扫描市场,股票数量: {len(stock_list)}, 最低分数: {min_score}, 市场: {market_type}, 流式模式: {stream}")
+
+ if not stream:
+ recommendations = []
+
+ for stock_code in stock_list:
+ try:
+ logger.debug(f"分析股票: {stock_code}")
+ report = self.analyze_stock(stock_code, market_type)
+ if report['score'] >= min_score:
+ logger.info(f"股票 {stock_code} 评分 {report['score']} >= {min_score},添加到推荐列表")
+ recommendations.append(report)
+ else:
+ logger.debug(f"股票 {stock_code} 评分 {report['score']} < {min_score},不添加到推荐列表")
+ except Exception as e:
+ logger.error(f"分析股票 {stock_code} 时出错: {str(e)}")
+ logger.exception(e)
+ continue
+
+ # 按得分排序
+ recommendations.sort(key=lambda x: x['score'], reverse=True)
+ logger.info(f"扫描完成,找到 {len(recommendations)} 个推荐股票")
+ return recommendations
+ else:
+ # 流式处理每个股票
+ logger.info(f"开始流式扫描 {len(stock_list)} 只股票")
+ stock_count = 0
+ for stock_code in stock_list:
+ stock_count += 1
+ logger.debug(f"流式分析股票 {stock_code} ({stock_count}/{len(stock_list)})")
+ try:
+ # 分析单只股票并获取流式结果
+ chunk_count = 0
+ for chunk in self.analyze_stock(stock_code, market_type, stream=True):
+ chunk_count += 1
+ stream_logger.info(f"股票 {stock_code} 流式块 #{chunk_count}: {chunk}")
+ yield chunk
+ logger.debug(f"股票 {stock_code} 流式分析完成,共 {chunk_count} 个块")
+ except Exception as e:
+ error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+ error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
+ stream_logger.info(f"流式错误输出: {error_json}")
+ yield error_json
+ logger.info(f"流式扫描完成,处理了 {stock_count} 只股票")
diff --git a/templates/index.html b/templates/index.html
index aee282d..11045cf 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -1,451 +1,874 @@
-
-
-
-
-
- 股票分析系统
-
-
-
-
-
股票分析系统
-
-
- {% if announcement %}
-
-
-
-
-
-
- {{ announcement }}
-
-
-
-
-
- {% endif %}
-
-
-
-
-
股票批量分析
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
分析结果
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
股票分析系统
+
+
+
+
+
股票分析系统
+
+
+ {% if announcement %}
+
+
+
+
+
+
+ {{ announcement }}
+
+
+
+
+
+ {% endif %}
+
+
+
+
+
股票批量分析
+
+
+
+
+
API配置
+
+
+
+
+
+
+
+
+
如不填写,将使用系统默认配置
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
分析结果
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_stream.py b/tests/test_stream.py
new file mode 100644
index 0000000..e501081
--- /dev/null
+++ b/tests/test_stream.py
@@ -0,0 +1,137 @@
+import os
+import requests
+import json
+from logger import get_logger, get_stream_logger
+from dotenv import load_dotenv
+
+# 获取日志器
+logger = get_logger()
+stream_logger = get_stream_logger()
+
+def test_api_stream():
+ """
+ 测试API流式响应功能
+ """
+ # 加载环境变量
+ load_dotenv()
+
+ # 获取API配置
+ api_url = os.getenv('API_URL')
+ api_key = os.getenv('API_KEY')
+ api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
+
+ logger.info(f"开始测试API流式响应,API URL: {api_url}, MODEL: {api_model}")
+
+ # 检查API配置
+ if not api_url:
+ logger.error("API URL未配置,无法进行测试")
+ return
+
+ if not api_key:
+ logger.error("API Key未配置,无法进行测试")
+ return
+
+ # 标准化API URL
+ if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
+ if api_url.endswith('/v1'):
+ api_url = f"{api_url}/chat/completions"
+ elif api_url.endswith('/'):
+ api_url = f"{api_url}v1/chat/completions"
+ else:
+ api_url = f"{api_url}/v1/chat/completions"
+
+ logger.debug(f"标准化后的API URL: {api_url}")
+
+ # 构建简单的测试提示
+ prompt = "这是一个API流式响应测试。请给出一个简短的股票分析样例。"
+
+ # 构建请求头和请求体
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json"
+ }
+
+ payload = {
+ "model": api_model,
+ "messages": [{"role": "user", "content": prompt}],
+ "stream": True # 明确设置stream参数为True
+ }
+
+ logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
+
+ try:
+ logger.info(f"发起流式API请求: {api_url}")
+
+ response = requests.post(
+ api_url,
+ headers=headers,
+ json=payload,
+ timeout=60,
+ stream=True
+ )
+
+ logger.info(f"API流式响应状态码: {response.status_code}")
+ logger.debug(f"响应头: {response.headers}")
+
+ if response.status_code == 200:
+ logger.info("成功获取API流式响应,开始处理")
+
+ buffer = ""
+ chunk_count = 0
+
+ for line in response.iter_lines():
+ if line:
+ line_str = line.decode('utf-8')
+ logger.info(f"原始流式行: {line_str}")
+
+ # 跳过保持连接的空行
+ if line_str.strip() == '':
+ logger.debug("跳过空行")
+ continue
+
+ # 数据行通常以"data: "开头
+ if line_str.startswith('data: '):
+ data_content = line_str[6:].strip() # 移除 "data: " 前缀并去除前后空格
+ logger.info(f"数据内容: {data_content}")
+
+ # 检查是否为流的结束
+ if data_content == '[DONE]':
+ logger.info("收到流结束标记 [DONE]")
+ break
+
+ try:
+ # 解析JSON数据
+ json_data = json.loads(data_content)
+ logger.debug(f"JSON结构: {json.dumps(json_data, indent=2)}")
+
+ if 'choices' in json_data:
+ delta = json_data['choices'][0].get('delta', {})
+ content = delta.get('content', '')
+
+ if content:
+ chunk_count += 1
+ buffer += content
+ logger.info(f"内容片段 #{chunk_count}: {content}")
+ except json.JSONDecodeError as e:
+ logger.error(f"JSON解析错误: {e}, 内容: {data_content}")
+ else:
+ logger.warning(f"收到非'data:'开头的行: {line_str}")
+
+ logger.info(f"流式处理完成,共收到 {chunk_count} 个内容片段")
+ logger.info(f"完整内容:\n{buffer}")
+
+ else:
+ try:
+ error_response = response.json()
+ error_text = json.dumps(error_response, indent=2)
+ except:
+ error_text = response.text[:500] if response.text else "无响应内容"
+
+ logger.error(f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}")
+
+ except Exception as e:
+ logger.error(f"测试过程中发生异常: {str(e)}")
+ logger.exception(e)
+
+if __name__ == "__main__":
+ test_api_stream()
diff --git a/web_server.py b/web_server.py
index a62105f..2c77137 100644
--- a/web_server.py
+++ b/web_server.py
@@ -1,8 +1,15 @@
-from flask import Flask, render_template, request, jsonify
+from flask import Flask, render_template, request, jsonify, Response, stream_with_context
from stock_analyzer import StockAnalyzer
from us_stock_service import USStockService
import threading
import os
+import traceback
+import requests
+from logger import get_logger, get_stream_logger
+
+# 获取日志器
+logger = get_logger()
+stream_logger = get_stream_logger()
app = Flask(__name__)
analyzer = StockAnalyzer()
@@ -11,27 +18,86 @@ us_stock_service = USStockService()
@app.route('/')
def index():
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
- return render_template('index.html', announcement=announcement)
+ # 获取默认API配置信息
+ default_api_url = os.getenv('API_URL', '')
+ default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
+ # 不传递API_KEY到前端,出于安全考虑
+ return render_template('index.html',
+ announcement=announcement,
+ default_api_url=default_api_url,
+ default_api_model=default_api_model)
@app.route('/analyze', methods=['POST'])
def analyze():
try:
+ logger.info("开始处理分析请求")
data = request.json
stock_codes = data.get('stock_codes', [])
market_type = data.get('market_type', 'A')
+ logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
+
+ # 获取自定义API配置
+ custom_api_url = data.get('api_url')
+ custom_api_key = data.get('api_key')
+ custom_api_model = data.get('api_model')
+ custom_api_timeout = data.get('api_timeout', 60)
+
+ logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}")
+
+ # 创建新的分析器实例,使用自定义配置
+ custom_analyzer = StockAnalyzer(
+ custom_api_url=custom_api_url,
+ custom_api_key=custom_api_key,
+ custom_api_model=custom_api_model,
+ custom_api_timeout= custom_api_timeout,
+ )
+
if not stock_codes:
+ logger.warning("未提供股票代码")
return jsonify({'error': '请输入代码'}), 400
+
+ # 使用流式响应
+ def generate():
+ if len(stock_codes) == 1:
+ # 单个股票分析流式处理
+ stock_code = stock_codes[0].strip()
+ logger.info(f"开始单股流式分析: {stock_code}")
+
+ stream_logger.info(f"初始化单股分析流: {stock_code}")
+ init_message = f'{{"stream_type": "single", "stock_code": "{stock_code}"}}\n'
+ stream_logger.info(f"发送初始化消息: {init_message}")
+ yield init_message
+
+ for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
+ stream_logger.info(f"流式输出块: {chunk}")
+ yield chunk + '\n'
+ else:
+ # 批量分析流式处理
+ logger.info(f"开始批量流式分析: {stock_codes}")
+
+ stream_logger.info(f"初始化批量分析流: {stock_codes}")
+ init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes}}}\n'
+ stream_logger.info(f"发送初始化消息: {init_message}")
+ yield init_message
+
+ for chunk in custom_analyzer.scan_market(
+ [code.strip() for code in stock_codes],
+ min_score=0,
+ market_type=market_type,
+ stream=True
+ ):
+ stream_logger.info(f"流式输出块: {chunk}")
+ yield chunk + '\n'
+
+ logger.info("成功创建流式响应生成器")
+ return Response(stream_with_context(generate()), mimetype='application/json')
- results = []
- for stock_code in stock_codes:
- result = analyzer.analyze_stock(stock_code.strip(), market_type)
- results.append(result)
-
- return jsonify({'results': results})
except Exception as e:
- print(f"分析股票时出错: {str(e)}")
- return jsonify({'error': str(e)}), 500
+ error_msg = f"分析股票时出错: {str(e)}"
+ logger.error(error_msg)
+ logger.exception(e)
+ return jsonify({'error': error_msg}), 500
@app.route('/search_us_stocks', methods=['GET'])
def search_us_stocks():
@@ -47,8 +113,72 @@ def search_us_stocks():
print(f"搜索美股代码时出错: {str(e)}")
return jsonify({'error': str(e)}), 500
+@app.route('/test_api_connection', methods=['POST'])
+def test_api_connection():
+ """测试API连接"""
+ try:
+ logger.info("开始测试API连接")
+ data = request.json
+ api_url = data.get('api_url')
+ api_key = data.get('api_key')
+ api_model = data.get('api_model')
+
+ logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}")
+
+ if not api_url:
+ logger.warning("未提供API URL")
+ return jsonify({'error': '请提供API URL'}), 400
+
+ if not api_key:
+ logger.warning("未提供API Key")
+ return jsonify({'error': '请提供API Key'}), 400
+
+ # 构建API URL
+ test_url = api_url
+ if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
+ if api_url.endswith('/v1'):
+ test_url = f"{api_url}/chat/completions"
+ elif api_url.endswith('/'):
+ test_url = f"{api_url}chat/completions"
+ else:
+ test_url = f"{api_url}/v1/chat/completions"
+
+ logger.debug(f"完整API测试URL: {test_url}")
+
+ # 发送测试请求
+ response = requests.post(
+ test_url,
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json"
+ },
+ json={
+ "model": api_model or "gpt-3.5-turbo",
+ "messages": [
+ {"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
+ ],
+ "max_tokens": 20
+ },
+ timeout=10
+ )
+
+ # 检查响应
+ if response.status_code == 200:
+ logger.info(f"API连接测试成功: {response.status_code}")
+ return jsonify({'success': True, 'message': '连接成功'})
+ else:
+ error_message = response.json().get('error', {}).get('message', '未知错误')
+ logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
+ return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400
+
+ except requests.exceptions.RequestException as e:
+ logger.error(f"API连接请求错误: {str(e)}")
+ return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
+ except Exception as e:
+ logger.error(f"测试API连接时出错: {str(e)}")
+ logger.exception(e)
+ return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500
+
if __name__ == '__main__':
- app.run(host='0.0.0.0', port=8888, debug=True)
-
-
-
\ No newline at end of file
+ logger.info("股票分析系统启动")
+ app.run(host='0.0.0.0', port=8888, debug=True)
\ No newline at end of file