From dd33c170ff9faea0996bbee8abd0b83290a79883 Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 11:38:42 +0800 Subject: [PATCH 1/8] Update requirements.txt --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a79b0b4..7799faf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ +--index-url https://pypi.tuna.tsinghua.edu.cn/simple + # 基础科学计算和数据处理库 numpy==2.1.2 pandas==2.2.2 scipy==1.15.1 # 数据获取和分析库 -akshare==1.15.87 +akshare==1.16.22 tqdm==4.67.1 From f2450d0d61af511a8c77f70613ecf8cf62ec1d13 Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 11:38:45 +0800 Subject: [PATCH 2/8] Update web_server.py --- web_server.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/web_server.py b/web_server.py index 31a6299..cf6beef 100644 --- a/web_server.py +++ b/web_server.py @@ -37,11 +37,21 @@ def analyze(): results = [] for stock_code in stock_codes: - result = analyzer.analyze_stock(stock_code.strip(), market_type) - results.append(result) + try: + result = analyzer.analyze_stock(stock_code.strip(), market_type) + results.append(result) + except Exception as e: + app.logger.error(f"分析股票 {stock_code} 失败: {str(e)}") + app.logger.error(f"详细错误: {traceback.format_exc()}") + results.append({ + 'code': stock_code, + 'error': f"分析失败: {str(e)}" + }) return jsonify({'results': results}) except Exception as e: + app.logger.error(f"处理请求失败: {str(e)}") + app.logger.error(f"详细错误: {traceback.format_exc()}") return jsonify({'error': str(e)}), 500 @app.route('/search_us_stocks', methods=['GET']) From 40787aa85f11103d842ee43c46c8908c35668b3e Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 11:47:28 +0800 Subject: [PATCH 3/8] Update stock_analyzer.py --- stock_analyzer.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/stock_analyzer.py b/stock_analyzer.py index 815f4b2..cbbe1e0 100644 --- a/stock_analyzer.py +++ b/stock_analyzer.py @@ -211,7 +211,7 @@ class StockAnalyzer: raise def get_ai_analysis(self, df, stock_code): - """使用 Gemini 进行 AI 分析""" + """使用 OpenAI 进行 AI 分析""" try: recent_data = df.tail(14).to_dict('records') @@ -242,34 +242,29 @@ class StockAnalyzer: 请基于技术指标和市场动态进行分析,给出具体数据支持。 """ - headers = { - "Authorization": f"Bearer {self.API_KEY}", - "Content-Type": "application/json" - } - - data = { - "model": os.getenv('API_MODEL'), - "messages": [{"role": "user", "content": prompt}] - } - + # OpenAI API 调用 response = requests.post( - f"{self.API_URL}/v1/chat/completions", - headers=headers, - json=data, + f"{self.API_URL}/chat/completions", + headers={ + "Authorization": f"Bearer {self.API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": os.getenv('API_MODEL', 'gpt-3.5-turbo'), + "messages": [{"role": "user", "content": prompt}] + }, timeout=30 ) - print(headers) - print(data) - print(response.json()) - + if response.status_code == 200: return response.json()['choices'][0]['message']['content'] else: - return "AI 分析暂时无法使用" + self.logger.error(f"API 错误: {response.status_code} - {response.text}") + return f"AI 分析暂时无法使用 (HTTP {response.status_code})" except Exception as e: self.logger.error(f"AI 分析发生错误: {str(e)}") - return "AI 分析过程中发生错误" + return f"AI 分析过程中发生错误: {str(e)}" def get_recommendation(self, score): """根据得分给出建议""" From a3474bc201acdd71baada7b6deff16e166484852 Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 12:01:55 +0800 Subject: [PATCH 4/8] Update Dockerfile --- Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index ed99415..1c331fb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# 使用 Python 3.9 作为基础镜像 +# 使用 Python 3.10 作为基础镜像 FROM python:3.10-slim # 设置工作目录 @@ -15,7 +15,6 @@ COPY . /app/ # 安装 Python 依赖 RUN pip install --no-cache-dir -r requirements.txt -RUN pip install akshare --upgrade -i https://pypi.org/simple # 设置环境变量 ENV PYTHONPATH=/app From 476568f328f404c9173a0fb32b1f1986df1f94d8 Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 12:01:58 +0800 Subject: [PATCH 5/8] Update stock_analyzer.py --- stock_analyzer.py | 49 +++++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/stock_analyzer.py b/stock_analyzer.py index 5e86cb9..a88139f 100644 --- a/stock_analyzer.py +++ b/stock_analyzer.py @@ -237,25 +237,40 @@ class StockAnalyzer: """ # OpenAI API 调用 - response = requests.post( + api_urls = [ f"{self.API_URL}/chat/completions", - headers={ - "Authorization": f"Bearer {self.API_KEY}", - "Content-Type": "application/json" - }, - json={ - "model": os.getenv('API_MODEL', 'gpt-3.5-turbo'), - "messages": [{"role": "user", "content": prompt}] - }, - timeout=30 - ) + f"{self.API_URL}/v1/chat/completions" + ] + + last_error = None + for api_url in api_urls: + try: + response = requests.post( + api_url, + headers={ + "Authorization": f"Bearer {self.API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": os.getenv('API_MODEL', 'gpt-3.5-turbo'), + "messages": [{"role": "user", "content": prompt}] + }, + timeout=30 + ) + + if response.status_code == 200: + return response.json()['choices'][0]['message']['content'] + else: + last_error = f"API 错误: {response.status_code} - {response.text}" + continue + + except Exception as e: + last_error = str(e) + continue + + print(f"AI 分析暂时无法使用: {last_error}") + return f"AI 分析暂时无法使用: {last_error}" - if response.status_code == 200: - return response.json()['choices'][0]['message']['content'] - else: - print(f"API 错误: {response.status_code} - {response.text}") - return f"AI 分析暂时无法使用 (HTTP {response.status_code})" - except Exception as e: print(f"AI 分析发生错误: {str(e)}") return f"AI 分析过程中发生错误: {str(e)}" From 5ccab7ab4375bca5b24c4e86776da878a8534eec Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 12:08:19 +0800 Subject: [PATCH 6/8] Update web_server.py --- web_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web_server.py b/web_server.py index ae7aaa2..b6152fe 100644 --- a/web_server.py +++ b/web_server.py @@ -29,8 +29,8 @@ def analyze(): result = analyzer.analyze_stock(stock_code.strip(), market_type) results.append(result) except Exception as e: - app.logger.error(f"分析股票 {stock_code} 失败: {str(e)}") - app.logger.error(f"详细错误: {traceback.format_exc()}") + print(f"分析股票 {stock_code} 失败: {str(e)}") + print(f"详细错误: {traceback.format_exc()}") results.append({ 'code': stock_code, 'error': f"分析失败: {str(e)}" From 17ed403c3e96938092918dd8d16d24b958386e1a Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 13:01:38 +0800 Subject: [PATCH 7/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- stock_analyzer.py | 687 ++++++++++++++-------------- templates/index.html | 1028 ++++++++++++++++++++++++------------------ web_server.py | 85 +++- 3 files changed, 1001 insertions(+), 799 deletions(-) diff --git a/stock_analyzer.py b/stock_analyzer.py index a88139f..105278d 100644 --- a/stock_analyzer.py +++ b/stock_analyzer.py @@ -1,343 +1,344 @@ -import pandas as pd -import numpy as np -from datetime import datetime, timedelta -import os -import requests -from typing import Dict, List, Optional, Tuple -from dotenv import load_dotenv - -class StockAnalyzer: - def __init__(self, initial_cash=1000000): - - # 加载环境变量 - load_dotenv() - - # 设置 Gemini API - self.API_URL = os.getenv('API_URL') - self.API_KEY = os.getenv('API_KEY') - - # 配置参数 - self.params = { - 'ma_periods': {'short': 5, 'medium': 20, 'long': 60}, - 'rsi_period': 14, - 'bollinger_period': 20, - 'bollinger_std': 2, - 'volume_ma_period': 20, - 'atr_period': 14 - } - - - def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ): - """获取股票数据""" - import akshare as ak - - if start_date is None: - start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d') - if end_date is None: - end_date = datetime.now().strftime('%Y%m%d') - - try: - # 根据市场类型获取数据 - if market_type == 'A': - df = ak.stock_zh_a_hist( - symbol=stock_code, - start_date=start_date, - end_date=end_date, - adjust="qfq" - ) - # A股数据列名映射 - elif market_type == 'HK': - df = ak.stock_hk_daily( - symbol=stock_code, - adjust="qfq" - ) - elif market_type == 'US': - df = ak.stock_us_hist( - symbol=stock_code, - start_date=start_date, - end_date=end_date, - adjust="qfq" - ) - # elif market_type == 'CRYPTO': - # df = ak.crypto_js_spot( - # symbol=stock_code - # ) - else: - raise ValueError(f"不支持的市场类型: {market_type}") - - # 重命名列名以匹配分析需求 - df = df.rename(columns={ - "日期": "date", - "开盘": "open", - "收盘": "close", - "最高": "high", - "最低": "low", - "成交量": "volume" - }) - - # 确保日期格式正确 - df['date'] = pd.to_datetime(df['date']) - - # 数据类型转换 - numeric_columns = ['open', 'close', 'high', 'low', 'volume'] - df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce') - - # 删除空值 - df = df.dropna() - - return df.sort_values('date') - - except Exception as e: - raise Exception(f"获取股票数据失败: {str(e)}") - - def calculate_ema(self, series, period): - """计算指数移动平均线""" - return series.ewm(span=period, adjust=False).mean() - - def calculate_rsi(self, series, period): - """计算RSI指标""" - delta = series.diff() - gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() - loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() - rs = gain / loss - return 100 - (100 / (1 + rs)) - - def calculate_macd(self, series): - """计算MACD指标""" - exp1 = series.ewm(span=12, adjust=False).mean() - exp2 = series.ewm(span=26, adjust=False).mean() - macd = exp1 - exp2 - signal = macd.ewm(span=9, adjust=False).mean() - hist = macd - signal - return macd, signal, hist - - def calculate_bollinger_bands(self, series, period, std_dev): - """计算布林带""" - middle = series.rolling(window=period).mean() - std = series.rolling(window=period).std() - upper = middle + (std * std_dev) - lower = middle - (std * std_dev) - return upper, middle, lower - - def calculate_atr(self, df, period): - """计算ATR指标""" - high = df['high'] - low = df['low'] - close = df['close'].shift(1) - - tr1 = high - low - tr2 = abs(high - close) - tr3 = abs(low - close) - - tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) - return tr.rolling(window=period).mean() - - def calculate_indicators(self, df): - """计算技术指标""" - try: - # 计算移动平均线 - df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short']) - df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium']) - df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long']) - - # 计算RSI - df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period']) - - # 计算MACD - df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close']) - - # 计算布林带 - df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands( - df['close'], - self.params['bollinger_period'], - self.params['bollinger_std'] - ) - - # 成交量分析 - df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean() - df['Volume_Ratio'] = df['volume'] / df['Volume_MA'] - - # 计算ATR和波动率 - df['ATR'] = self.calculate_atr(df, self.params['atr_period']) - df['Volatility'] = df['ATR'] / df['close'] * 100 - - # 动量指标 - df['ROC'] = df['close'].pct_change(periods=10) * 100 - - return df - - except Exception as e: - print(f"计算技术指标时出错: {str(e)}") - raise - - def calculate_score(self, df): - """计算股票评分""" - try: - score = 0 - latest = df.iloc[-1] - - # 趋势得分 (30分) - if latest['MA5'] > latest['MA20']: - score += 15 - if latest['MA20'] > latest['MA60']: - score += 15 - - # RSI得分 (20分) - if 30 <= latest['RSI'] <= 70: - score += 20 - elif latest['RSI'] < 30: # 超卖 - score += 15 - - # MACD得分 (20分) - if latest['MACD'] > latest['Signal']: - score += 20 - - # 成交量得分 (30分) - if latest['Volume_Ratio'] > 1.5: - score += 30 - elif latest['Volume_Ratio'] > 1: - score += 15 - - return score - - except Exception as e: - print(f"计算评分时出错: {str(e)}") - raise - - def get_ai_analysis(self, df, stock_code): - """使用 OpenAI 进行 AI 分析""" - try: - recent_data = df.tail(14).to_dict('records') - - technical_summary = { - 'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward', - 'volatility': f"{df.iloc[-1]['Volatility']:.2f}%", - 'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing', - 'rsi_level': df.iloc[-1]['RSI'] - } - - prompt = f""" - 分析股票 {stock_code}: - - 技术指标概要: - {technical_summary} - - 近14日交易数据: - {recent_data} - - 请提供: - 1. 趋势分析(包含支撑位和压力位) - 2. 成交量分析及其含义 - 3. 风险评估(包含波动率分析) - 4. 短期和中期目标价位 - 5. 关键技术位分析 - 6. 具体交易建议(包含止损位) - - 请基于技术指标和市场动态进行分析,给出具体数据支持。 - """ - - # OpenAI API 调用 - api_urls = [ - f"{self.API_URL}/chat/completions", - f"{self.API_URL}/v1/chat/completions" - ] - - last_error = None - for api_url in api_urls: - try: - response = requests.post( - api_url, - headers={ - "Authorization": f"Bearer {self.API_KEY}", - "Content-Type": "application/json" - }, - json={ - "model": os.getenv('API_MODEL', 'gpt-3.5-turbo'), - "messages": [{"role": "user", "content": prompt}] - }, - timeout=30 - ) - - if response.status_code == 200: - return response.json()['choices'][0]['message']['content'] - else: - last_error = f"API 错误: {response.status_code} - {response.text}" - continue - - except Exception as e: - last_error = str(e) - continue - - print(f"AI 分析暂时无法使用: {last_error}") - return f"AI 分析暂时无法使用: {last_error}" - - except Exception as e: - print(f"AI 分析发生错误: {str(e)}") - return f"AI 分析过程中发生错误: {str(e)}" - - def get_recommendation(self, score): - """根据得分给出建议""" - if score >= 80: - return '强烈推荐买入' - elif score >= 60: - return '建议买入' - elif score >= 40: - return '观望' - elif score >= 20: - return '建议卖出' - else: - return '强烈建议卖出' - - def analyze_stock(self, stock_code, market_type='A'): - """分析单个股票""" - try: - # 获取股票数据 - df = self.get_stock_data(stock_code, market_type) - - # 计算技术指标 - df = self.calculate_indicators(df) - - # 评分系统 - score = self.calculate_score(df) - - # 获取最新数据 - latest = df.iloc[-1] - prev = df.iloc[-2] - - # 生成报告(保持原有格式) - report = { - 'stock_code': stock_code, - 'analysis_date': datetime.now().strftime('%Y-%m-%d'), - 'score': score, - 'price': latest['close'], - 'price_change': (latest['close'] - prev['close']) / prev['close'] * 100, - 'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN', - 'rsi': latest['RSI'], - 'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL', - 'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL', - 'recommendation': self.get_recommendation(score), - 'ai_analysis': self.get_ai_analysis(df, stock_code) - } - - return report - - except Exception as e: - print(f"分析股票时出错: {str(e)}") - raise - - def scan_market(self, stock_list, min_score=60, market_type='A'): - """扫描市场,寻找符合条件的股票""" - recommendations = [] - - for stock_code in stock_list: - try: - report = self.analyze_stock(stock_code, market_type) - if report['score'] >= min_score: - recommendations.append(report) - except Exception as e: - print(f"分析股票 {stock_code} 时出错: {str(e)}") - continue - - # 按得分排序 - recommendations.sort(key=lambda x: x['score'], reverse=True) - return recommendations +import pandas as pd +import numpy as np +from datetime import datetime, timedelta +import os +import requests +from typing import Dict, List, Optional, Tuple +from dotenv import load_dotenv + +class StockAnalyzer: + def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None): + + # 加载环境变量 + load_dotenv() + + # 设置 API 配置,优先使用自定义配置,否则使用环境变量 + self.API_URL = custom_api_url or os.getenv('API_URL') + self.API_KEY = custom_api_key or os.getenv('API_KEY') + self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo') + + # 配置参数 + self.params = { + 'ma_periods': {'short': 5, 'medium': 20, 'long': 60}, + 'rsi_period': 14, + 'bollinger_period': 20, + 'bollinger_std': 2, + 'volume_ma_period': 20, + 'atr_period': 14 + } + + + def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ): + """获取股票数据""" + import akshare as ak + + if start_date is None: + start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d') + if end_date is None: + end_date = datetime.now().strftime('%Y%m%d') + + try: + # 根据市场类型获取数据 + if market_type == 'A': + df = ak.stock_zh_a_hist( + symbol=stock_code, + start_date=start_date, + end_date=end_date, + adjust="qfq" + ) + # A股数据列名映射 + elif market_type == 'HK': + df = ak.stock_hk_daily( + symbol=stock_code, + adjust="qfq" + ) + elif market_type == 'US': + df = ak.stock_us_hist( + symbol=stock_code, + start_date=start_date, + end_date=end_date, + adjust="qfq" + ) + # elif market_type == 'CRYPTO': + # df = ak.crypto_js_spot( + # symbol=stock_code + # ) + else: + raise ValueError(f"不支持的市场类型: {market_type}") + + # 重命名列名以匹配分析需求 + df = df.rename(columns={ + "日期": "date", + "开盘": "open", + "收盘": "close", + "最高": "high", + "最低": "low", + "成交量": "volume" + }) + + # 确保日期格式正确 + df['date'] = pd.to_datetime(df['date']) + + # 数据类型转换 + numeric_columns = ['open', 'close', 'high', 'low', 'volume'] + df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce') + + # 删除空值 + df = df.dropna() + + return df.sort_values('date') + + except Exception as e: + raise Exception(f"获取股票数据失败: {str(e)}") + + def calculate_ema(self, series, period): + """计算指数移动平均线""" + return series.ewm(span=period, adjust=False).mean() + + def calculate_rsi(self, series, period): + """计算RSI指标""" + delta = series.diff() + gain = (delta.where(delta > 0, 0)).rolling(window=period).mean() + loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean() + rs = gain / loss + return 100 - (100 / (1 + rs)) + + def calculate_macd(self, series): + """计算MACD指标""" + exp1 = series.ewm(span=12, adjust=False).mean() + exp2 = series.ewm(span=26, adjust=False).mean() + macd = exp1 - exp2 + signal = macd.ewm(span=9, adjust=False).mean() + hist = macd - signal + return macd, signal, hist + + def calculate_bollinger_bands(self, series, period, std_dev): + """计算布林带""" + middle = series.rolling(window=period).mean() + std = series.rolling(window=period).std() + upper = middle + (std * std_dev) + lower = middle - (std * std_dev) + return upper, middle, lower + + def calculate_atr(self, df, period): + """计算ATR指标""" + high = df['high'] + low = df['low'] + close = df['close'].shift(1) + + tr1 = high - low + tr2 = abs(high - close) + tr3 = abs(low - close) + + tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) + return tr.rolling(window=period).mean() + + def calculate_indicators(self, df): + """计算技术指标""" + try: + # 计算移动平均线 + df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short']) + df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium']) + df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long']) + + # 计算RSI + df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period']) + + # 计算MACD + df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close']) + + # 计算布林带 + df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands( + df['close'], + self.params['bollinger_period'], + self.params['bollinger_std'] + ) + + # 成交量分析 + df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean() + df['Volume_Ratio'] = df['volume'] / df['Volume_MA'] + + # 计算ATR和波动率 + df['ATR'] = self.calculate_atr(df, self.params['atr_period']) + df['Volatility'] = df['ATR'] / df['close'] * 100 + + # 动量指标 + df['ROC'] = df['close'].pct_change(periods=10) * 100 + + return df + + except Exception as e: + print(f"计算技术指标时出错: {str(e)}") + raise + + def calculate_score(self, df): + """计算股票评分""" + try: + score = 0 + latest = df.iloc[-1] + + # 趋势得分 (30分) + if latest['MA5'] > latest['MA20']: + score += 15 + if latest['MA20'] > latest['MA60']: + score += 15 + + # RSI得分 (20分) + if 30 <= latest['RSI'] <= 70: + score += 20 + elif latest['RSI'] < 30: # 超卖 + score += 15 + + # MACD得分 (20分) + if latest['MACD'] > latest['Signal']: + score += 20 + + # 成交量得分 (30分) + if latest['Volume_Ratio'] > 1.5: + score += 30 + elif latest['Volume_Ratio'] > 1: + score += 15 + + return score + + except Exception as e: + print(f"计算评分时出错: {str(e)}") + raise + + def get_ai_analysis(self, df, stock_code): + """使用 OpenAI 进行 AI 分析""" + try: + recent_data = df.tail(14).to_dict('records') + + technical_summary = { + 'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward', + 'volatility': f"{df.iloc[-1]['Volatility']:.2f}%", + 'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing', + 'rsi_level': df.iloc[-1]['RSI'] + } + + prompt = f""" + 分析股票 {stock_code}: + + 技术指标概要: + {technical_summary} + + 近14日交易数据: + {recent_data} + + 请提供: + 1. 趋势分析(包含支撑位和压力位) + 2. 成交量分析及其含义 + 3. 风险评估(包含波动率分析) + 4. 短期和中期目标价位 + 5. 关键技术位分析 + 6. 具体交易建议(包含止损位) + + 请基于技术指标和市场动态进行分析,给出具体数据支持。 + """ + + # OpenAI API 调用 + api_urls = [ + f"{self.API_URL}/chat/completions", + f"{self.API_URL}/v1/chat/completions" + ] + + last_error = None + for api_url in api_urls: + try: + response = requests.post( + api_url, + headers={ + "Authorization": f"Bearer {self.API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": self.API_MODEL, + "messages": [{"role": "user", "content": prompt}] + }, + timeout=30 + ) + + if response.status_code == 200: + return response.json()['choices'][0]['message']['content'] + else: + last_error = f"API 错误: {response.status_code} - {response.text}" + continue + + except Exception as e: + last_error = str(e) + continue + + print(f"AI 分析暂时无法使用: {last_error}") + return f"AI 分析暂时无法使用: {last_error}" + + except Exception as e: + print(f"AI 分析发生错误: {str(e)}") + return f"AI 分析过程中发生错误: {str(e)}" + + def get_recommendation(self, score): + """根据得分给出建议""" + if score >= 80: + return '强烈推荐买入' + elif score >= 60: + return '建议买入' + elif score >= 40: + return '观望' + elif score >= 20: + return '建议卖出' + else: + return '强烈建议卖出' + + def analyze_stock(self, stock_code, market_type='A'): + """分析单个股票""" + try: + # 获取股票数据 + df = self.get_stock_data(stock_code, market_type) + + # 计算技术指标 + df = self.calculate_indicators(df) + + # 评分系统 + score = self.calculate_score(df) + + # 获取最新数据 + latest = df.iloc[-1] + prev = df.iloc[-2] + + # 生成报告(保持原有格式) + report = { + 'stock_code': stock_code, + 'analysis_date': datetime.now().strftime('%Y-%m-%d'), + 'score': score, + 'price': latest['close'], + 'price_change': (latest['close'] - prev['close']) / prev['close'] * 100, + 'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN', + 'rsi': latest['RSI'], + 'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL', + 'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL', + 'recommendation': self.get_recommendation(score), + 'ai_analysis': self.get_ai_analysis(df, stock_code) + } + + return report + + except Exception as e: + print(f"分析股票时出错: {str(e)}") + raise + + def scan_market(self, stock_list, min_score=60, market_type='A'): + """扫描市场,寻找符合条件的股票""" + recommendations = [] + + for stock_code in stock_list: + try: + report = self.analyze_stock(stock_code, market_type) + if report['score'] >= min_score: + recommendations.append(report) + except Exception as e: + print(f"分析股票 {stock_code} 时出错: {str(e)}") + continue + + # 按得分排序 + recommendations.sort(key=lambda x: x['score'], reverse=True) + return recommendations diff --git a/templates/index.html b/templates/index.html index aee282d..3074210 100644 --- a/templates/index.html +++ b/templates/index.html @@ -1,451 +1,579 @@ - - - - - - 股票分析系统 - - - -
-

股票分析系统

- - - {% if announcement %} -
-
-
-
- - - -
-
-

- {{ announcement }} -

-
-
-
-
- {% endif %} - -
- -
-

股票批量分析

- - -
- - -
- - - - - - -
- - -
- - - - -
- - -
-
-

分析结果

- -
-
-
- - - - - - + + + + + + 股票分析系统 + + + +
+

股票分析系统

+ + + {% if announcement %} +
+
+
+
+ + + +
+
+

+ {{ announcement }} +

+
+
+
+
+ {% endif %} + +
+ +
+

股票批量分析

+ + +
+
+

API配置

+ +
+ + +
+ + +
+ + +
+ + + + + + +
+ + +
+ + + + +
+ + +
+
+

分析结果

+ +
+
+
+ + + + + + + + + \ No newline at end of file diff --git a/web_server.py b/web_server.py index b6152fe..dc54012 100644 --- a/web_server.py +++ b/web_server.py @@ -3,6 +3,8 @@ from stock_analyzer import StockAnalyzer from us_stock_service import USStockService import threading import os +import traceback +import requests app = Flask(__name__) analyzer = StockAnalyzer() @@ -11,7 +13,14 @@ us_stock_service = USStockService() @app.route('/') def index(): announcement = os.getenv('ANNOUNCEMENT_TEXT') or None - return render_template('index.html', announcement=announcement) + # 获取默认API配置信息 + default_api_url = os.getenv('API_URL', '') + default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo') + # 不传递API_KEY到前端,出于安全考虑 + return render_template('index.html', + announcement=announcement, + default_api_url=default_api_url, + default_api_model=default_api_model) @app.route('/analyze', methods=['POST']) def analyze(): @@ -20,13 +29,26 @@ def analyze(): stock_codes = data.get('stock_codes', []) market_type = data.get('market_type', 'A') + # 获取自定义API配置 + custom_api_url = data.get('api_url') + custom_api_key = data.get('api_key') + custom_api_model = data.get('api_model') + + # 创建新的分析器实例,使用自定义配置 + custom_analyzer = StockAnalyzer( + custom_api_url=custom_api_url, + custom_api_key=custom_api_key, + custom_api_model=custom_api_model + ) + if not stock_codes: return jsonify({'error': '请输入代码'}), 400 results = [] for stock_code in stock_codes: try: - result = analyzer.analyze_stock(stock_code.strip(), market_type) + # 使用自定义配置的分析器 + result = custom_analyzer.analyze_stock(stock_code.strip(), market_type) results.append(result) except Exception as e: print(f"分析股票 {stock_code} 失败: {str(e)}") @@ -55,8 +77,59 @@ def search_us_stocks(): print(f"搜索美股代码时出错: {str(e)}") return jsonify({'error': str(e)}), 500 +@app.route('/test_api_connection', methods=['POST']) +def test_api_connection(): + """测试API连接""" + try: + data = request.json + api_url = data.get('api_url') + api_key = data.get('api_key') + api_model = data.get('api_model') + + if not api_url: + return jsonify({'error': '请提供API URL'}), 400 + + if not api_key: + return jsonify({'error': '请提供API Key'}), 400 + + # 构建API URL + test_url = api_url + if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')): + if api_url.endswith('/v1'): + test_url = f"{api_url}/chat/completions" + elif api_url.endswith('/'): + test_url = f"{api_url}chat/completions" + else: + test_url = f"{api_url}/v1/chat/completions" + + # 发送测试请求 + response = requests.post( + test_url, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + }, + json={ + "model": api_model or "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."} + ], + "max_tokens": 20 + }, + timeout=10 + ) + + # 检查响应 + if response.status_code == 200: + return jsonify({'success': True, 'message': '连接成功'}) + else: + error_message = response.json().get('error', {}).get('message', '未知错误') + return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400 + + except requests.exceptions.RequestException as e: + return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400 + except Exception as e: + return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500 + if __name__ == '__main__': - app.run(host='0.0.0.0', port=8888, debug=True) - - - \ No newline at end of file + app.run(host='0.0.0.0', port=8888, debug=True) \ No newline at end of file From 6df78314d6e069cf8344b4a2b17e3c381ead13d3 Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Tue, 4 Mar 2025 15:03:08 +0800 Subject: [PATCH 8/8] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- logger.py | 58 ++++++ stock_analyzer.py | 331 ++++++++++++++++++++++++----- templates/index.html | 485 ++++++++++++++++++++++++++++++++++--------- tests/test_stream.py | 137 ++++++++++++ web_server.py | 83 ++++++-- 5 files changed, 928 insertions(+), 166 deletions(-) create mode 100644 logger.py create mode 100644 tests/test_stream.py diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..8313c52 --- /dev/null +++ b/logger.py @@ -0,0 +1,58 @@ +from loguru import logger +import sys +import os +from datetime import datetime + +# 获取当前时间作为日志文件名的一部分 +current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + +# 创建日志目录 +log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs") +os.makedirs(log_dir, exist_ok=True) + +# 配置日志 +logger.remove() # 移除默认的处理器 + +# 添加标准输出处理器(控制台) +logger.add( + sys.stdout, + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}", + level="DEBUG" +) + +# 添加文件处理器(debug级别) +logger.add( + os.path.join(log_dir, f"debug_{current_time}.log"), + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}", + level="DEBUG", + rotation="100 MB", + retention="1 week" +) + +# 添加文件处理器(error级别) +logger.add( + os.path.join(log_dir, f"error_{current_time}.log"), + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}", + level="ERROR", + rotation="100 MB", + retention="1 month" +) + +# 添加流处理器(用于记录流式输出) +logger.add( + os.path.join(log_dir, f"stream_{current_time}.log"), + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {message}", + filter=lambda record: "STREAM" in record["extra"], + level="INFO" +) + +# 创建专用于流式输出的日志器 +stream_logger = logger.bind(STREAM=True) + +def get_logger(): + """获取通用日志器""" + return logger + +def get_stream_logger(): + """获取流式输出专用日志器""" + return stream_logger diff --git a/stock_analyzer.py b/stock_analyzer.py index 105278d..7d040ca 100644 --- a/stock_analyzer.py +++ b/stock_analyzer.py @@ -3,8 +3,14 @@ import numpy as np from datetime import datetime, timedelta import os import requests -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Generator from dotenv import load_dotenv +import json +from logger import get_logger, get_stream_logger + +# 获取日志器 +logger = get_logger() +stream_logger = get_stream_logger() class StockAnalyzer: def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None): @@ -17,6 +23,8 @@ class StockAnalyzer: self.API_KEY = custom_api_key or os.getenv('API_KEY') self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo') + logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}") + # 配置参数 self.params = { 'ma_periods': {'short': 5, 'medium': 20, 'long': 60}, @@ -205,9 +213,10 @@ class StockAnalyzer: print(f"计算评分时出错: {str(e)}") raise - def get_ai_analysis(self, df, stock_code): + def get_ai_analysis(self, df, stock_code, stream=False): """使用 OpenAI 进行 AI 分析""" try: + logger.info(f"开始AI分析股票 {stock_code}, 流式模式: {stream}") recent_data = df.tail(14).to_dict('records') technical_summary = { @@ -237,47 +246,202 @@ class StockAnalyzer: 请基于技术指标和市场动态进行分析,给出具体数据支持。 """ - # OpenAI API 调用 - api_urls = [ - f"{self.API_URL}/chat/completions", - f"{self.API_URL}/v1/chat/completions" - ] + logger.debug(f"生成的AI分析提示词: {prompt[:100]}...") - last_error = None - for api_url in api_urls: + # 检查API配置 + if not self.API_URL: + error_msg = "API URL未配置,无法进行AI分析" + logger.error(error_msg) + return error_msg if not stream else (yield json.dumps({"error": error_msg})) + + if not self.API_KEY: + error_msg = "API Key未配置,无法进行AI分析" + logger.error(error_msg) + return error_msg if not stream else (yield json.dumps({"error": error_msg})) + + # 标准化API URL + api_url = self.API_URL + if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')): + if api_url.endswith('/v1'): + api_url = f"{api_url}/chat/completions" + elif api_url.endswith('/'): + api_url = f"{api_url}v1/chat/completions" + else: + api_url = f"{api_url}/v1/chat/completions" + + logger.debug(f"标准化后的API URL: {api_url}") + + # 构建请求头和请求体 + headers = { + "Authorization": f"Bearer {self.API_KEY}", + "Content-Type": "application/json" + } + + payload = { + "model": self.API_MODEL, + "messages": [{"role": "user", "content": prompt}] + } + + # 流式处理设置 + if stream: + logger.debug(f"配置流式参数,使用API URL: {api_url}") + payload["stream"] = True # 明确设置stream参数为True + + try: + logger.debug(f"发起流式API请求: {api_url}") + logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}") + + response = requests.post( + api_url, + headers=headers, + json=payload, + timeout=60, # 增加超时时间 + stream=True + ) + + logger.debug(f"API流式响应状态码: {response.status_code}") + + if response.status_code == 200: + logger.info(f"成功获取API流式响应,开始处理") + yield from self._process_ai_stream(response, stock_code) + else: + try: + error_response = response.json() + error_text = json.dumps(error_response, indent=2) + except: + error_text = response.text[:500] if response.text else "无响应内容" + + error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}" + logger.error(error_msg) + yield json.dumps({"stock_code": stock_code, "error": error_msg}) + + except Exception as e: + error_msg = f"流式API请求异常: {str(e)}" + logger.error(error_msg) + logger.exception(e) + yield json.dumps({"stock_code": stock_code, "error": error_msg}) + else: + # 非流式处理 + logger.debug(f"发起非流式API请求: {api_url}") + try: response = requests.post( api_url, - headers={ - "Authorization": f"Bearer {self.API_KEY}", - "Content-Type": "application/json" - }, - json={ - "model": self.API_MODEL, - "messages": [{"role": "user", "content": prompt}] - }, - timeout=30 + headers=headers, + json=payload, + timeout=60 ) + logger.debug(f"API非流式响应状态码: {response.status_code}") + if response.status_code == 200: - return response.json()['choices'][0]['message']['content'] + api_response = response.json() + content = api_response['choices'][0]['message']['content'] + logger.info(f"成功获取AI分析结果,长度: {len(content)}") + logger.debug(f"AI分析结果前100字符: {content[:100]}...") + return content else: - last_error = f"API 错误: {response.status_code} - {response.text}" - continue + try: + error_response = response.json() + error_text = json.dumps(error_response, indent=2) + except: + error_text = response.text[:500] if response.text else "无响应内容" + + error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}" + logger.error(error_msg) + return error_msg except Exception as e: - last_error = str(e) - continue - - print(f"AI 分析暂时无法使用: {last_error}") - return f"AI 分析暂时无法使用: {last_error}" + error_msg = f"非流式API请求异常: {str(e)}" + logger.error(error_msg) + logger.exception(e) + return error_msg except Exception as e: - print(f"AI 分析发生错误: {str(e)}") - return f"AI 分析过程中发生错误: {str(e)}" + error_msg = f"AI 分析过程中发生错误: {str(e)}" + logger.error(error_msg) + logger.exception(e) + if stream: + logger.debug("在流式模式下返回异常信息") + error_json = json.dumps({"stock_code": stock_code, "error": error_msg}) + stream_logger.info(f"流式异常输出: {error_json}") + yield error_json + else: + return error_msg + + def _process_ai_stream(self, response, stock_code) -> Generator[str, None, None]: + """处理AI流式响应""" + logger.info(f"开始处理股票 {stock_code} 的AI流式响应") + buffer = "" + chunk_count = 0 + + try: + for line in response.iter_lines(): + if line: + line = line.decode('utf-8') + stream_logger.info(f"原始流式行: {line}") + + # 跳过保持连接的空行 + if line.strip() == '': + logger.debug("跳过空行") + continue + + # 数据行通常以"data: "开头 + if line.startswith('data: '): + data_content = line[6:] # 移除 "data: " 前缀 + stream_logger.info(f"数据内容: {data_content}") + + # 检查是否为流的结束 + if data_content.strip() == '[DONE]': + logger.debug("收到流结束标记 [DONE]") + break + + try: + json_data = json.loads(data_content) + logger.debug(f"解析的JSON数据: {json.dumps(json_data)[:100]}...") + + if 'choices' in json_data: + delta = json_data['choices'][0].get('delta', {}) + content = delta.get('content', '') + + if content: + chunk_count += 1 + buffer += content + logger.debug(f"收到内容片段 #{chunk_count}: {content}") + stream_logger.info(f"发送内容片段: {content}") + + # 创建包含AI分析片段的JSON + chunk_json = json.dumps({ + "stock_code": stock_code, + "ai_analysis_chunk": content + }) + stream_logger.info(f"流式输出JSON: {chunk_json}") + yield chunk_json + except json.JSONDecodeError as e: + logger.error(f"JSON解析错误: {str(e)}, 行内容: {data_content}") + # 忽略无法解析的JSON + pass + else: + logger.warning(f"收到非'data:'开头的行: {line}") + + logger.info(f"AI流式处理完成,共收到 {chunk_count} 个内容片段,总长度: {len(buffer)}") + + # 如果buffer不为空,最后一次发送完整内容 + if buffer and not buffer.endswith('\n'): + logger.debug("发送换行符") + yield json.dumps({"stock_code": stock_code, "ai_analysis_chunk": "\n"}) + + except Exception as e: + error_msg = f"处理AI流式响应时出错: {str(e)}" + logger.error(error_msg) + logger.exception(e) + yield json.dumps({"stock_code": stock_code, "error": error_msg}) + + def get_recommendation(self, score): """根据得分给出建议""" + logger.debug(f"根据评分 {score} 生成投资建议") if score >= 80: return '强烈推荐买入' elif score >= 60: @@ -288,18 +452,24 @@ class StockAnalyzer: return '建议卖出' else: return '强烈建议卖出' - - def analyze_stock(self, stock_code, market_type='A'): + + def analyze_stock(self, stock_code, market_type='A', stream=False): """分析单个股票""" try: + logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}, 流式模式: {stream}") + # 获取股票数据 + logger.debug(f"获取股票 {stock_code} 数据") df = self.get_stock_data(stock_code, market_type) # 计算技术指标 + logger.debug(f"计算股票 {stock_code} 技术指标") df = self.calculate_indicators(df) # 评分系统 + logger.debug(f"计算股票 {stock_code} 评分") score = self.calculate_score(df) + logger.info(f"股票 {stock_code} 评分结果: {score}") # 获取最新数据 latest = df.iloc[-1] @@ -316,29 +486,92 @@ class StockAnalyzer: 'rsi': latest['RSI'], 'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL', 'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL', - 'recommendation': self.get_recommendation(score), - 'ai_analysis': self.get_ai_analysis(df, stock_code) + 'recommendation': self.get_recommendation(score) } + logger.debug(f"生成股票 {stock_code} 基础报告: {json.dumps(report)[:100]}...") - return report + if stream: + logger.info(f"以流式模式返回股票 {stock_code} 分析结果") + # 先返回基本报告结构 + base_report = dict(report) + base_report['ai_analysis'] = '' + base_report_json = json.dumps(base_report) + logger.debug(f"基础报告JSON: {base_report_json[:100]}...") + stream_logger.info(f"发送基础报告: {base_report_json}") + yield base_report_json + + # 然后流式返回AI分析部分 + logger.debug(f"开始获取股票 {stock_code} 的流式AI分析") + ai_chunks_count = 0 + for ai_chunk in self.get_ai_analysis(df, stock_code, stream=True): + ai_chunks_count += 1 + stream_logger.info(f"股票 {stock_code} 流式块 #{ai_chunks_count}: {ai_chunk}") + yield ai_chunk + logger.info(f"股票 {stock_code} 流式AI分析完成,共发送 {ai_chunks_count} 个块") + else: + logger.info(f"以非流式模式返回股票 {stock_code} 分析结果") + logger.debug(f"开始获取股票 {stock_code} 的AI分析") + report['ai_analysis'] = self.get_ai_analysis(df, stock_code) + logger.debug(f"AI分析结果长度: {len(report['ai_analysis'])}") + return report except Exception as e: - print(f"分析股票时出错: {str(e)}") - raise + error_msg = f"分析股票 {stock_code} 时出错: {str(e)}" + logger.error(error_msg) + logger.exception(e) - def scan_market(self, stock_list, min_score=60, market_type='A'): + if stream: + error_json = json.dumps({'stock_code': stock_code, 'error': error_msg}) + stream_logger.info(f"流式错误输出: {error_json}") + yield error_json + else: + raise + + def scan_market(self, stock_list, min_score=60, market_type='A', stream=False): """扫描市场,寻找符合条件的股票""" - recommendations = [] + logger.info(f"开始扫描市场,股票数量: {len(stock_list)}, 最低分数: {min_score}, 市场: {market_type}, 流式模式: {stream}") - for stock_code in stock_list: - try: - report = self.analyze_stock(stock_code, market_type) - if report['score'] >= min_score: - recommendations.append(report) - except Exception as e: - print(f"分析股票 {stock_code} 时出错: {str(e)}") - continue - - # 按得分排序 - recommendations.sort(key=lambda x: x['score'], reverse=True) - return recommendations + if not stream: + recommendations = [] + + for stock_code in stock_list: + try: + logger.debug(f"分析股票: {stock_code}") + report = self.analyze_stock(stock_code, market_type) + if report['score'] >= min_score: + logger.info(f"股票 {stock_code} 评分 {report['score']} >= {min_score},添加到推荐列表") + recommendations.append(report) + else: + logger.debug(f"股票 {stock_code} 评分 {report['score']} < {min_score},不添加到推荐列表") + except Exception as e: + logger.error(f"分析股票 {stock_code} 时出错: {str(e)}") + logger.exception(e) + continue + + # 按得分排序 + recommendations.sort(key=lambda x: x['score'], reverse=True) + logger.info(f"扫描完成,找到 {len(recommendations)} 个推荐股票") + return recommendations + else: + # 流式处理每个股票 + logger.info(f"开始流式扫描 {len(stock_list)} 只股票") + stock_count = 0 + for stock_code in stock_list: + stock_count += 1 + logger.debug(f"流式分析股票 {stock_code} ({stock_count}/{len(stock_list)})") + try: + # 分析单只股票并获取流式结果 + chunk_count = 0 + for chunk in self.analyze_stock(stock_code, market_type, stream=True): + chunk_count += 1 + stream_logger.info(f"股票 {stock_code} 流式块 #{chunk_count}: {chunk}") + yield chunk + logger.debug(f"股票 {stock_code} 流式分析完成,共 {chunk_count} 个块") + except Exception as e: + error_msg = f"分析股票 {stock_code} 时出错: {str(e)}" + logger.error(error_msg) + logger.exception(e) + error_json = json.dumps({'stock_code': stock_code, 'error': error_msg}) + stream_logger.info(f"流式错误输出: {error_json}") + yield error_json + logger.info(f"流式扫描完成,处理了 {stock_count} 只股票") diff --git a/templates/index.html b/templates/index.html index 3074210..74f9a6b 100644 --- a/templates/index.html +++ b/templates/index.html @@ -328,6 +328,7 @@ @@ -575,5 +773,94 @@ }); }); + \ No newline at end of file diff --git a/tests/test_stream.py b/tests/test_stream.py new file mode 100644 index 0000000..e501081 --- /dev/null +++ b/tests/test_stream.py @@ -0,0 +1,137 @@ +import os +import requests +import json +from logger import get_logger, get_stream_logger +from dotenv import load_dotenv + +# 获取日志器 +logger = get_logger() +stream_logger = get_stream_logger() + +def test_api_stream(): + """ + 测试API流式响应功能 + """ + # 加载环境变量 + load_dotenv() + + # 获取API配置 + api_url = os.getenv('API_URL') + api_key = os.getenv('API_KEY') + api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo') + + logger.info(f"开始测试API流式响应,API URL: {api_url}, MODEL: {api_model}") + + # 检查API配置 + if not api_url: + logger.error("API URL未配置,无法进行测试") + return + + if not api_key: + logger.error("API Key未配置,无法进行测试") + return + + # 标准化API URL + if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')): + if api_url.endswith('/v1'): + api_url = f"{api_url}/chat/completions" + elif api_url.endswith('/'): + api_url = f"{api_url}v1/chat/completions" + else: + api_url = f"{api_url}/v1/chat/completions" + + logger.debug(f"标准化后的API URL: {api_url}") + + # 构建简单的测试提示 + prompt = "这是一个API流式响应测试。请给出一个简短的股票分析样例。" + + # 构建请求头和请求体 + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + payload = { + "model": api_model, + "messages": [{"role": "user", "content": prompt}], + "stream": True # 明确设置stream参数为True + } + + logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}") + + try: + logger.info(f"发起流式API请求: {api_url}") + + response = requests.post( + api_url, + headers=headers, + json=payload, + timeout=60, + stream=True + ) + + logger.info(f"API流式响应状态码: {response.status_code}") + logger.debug(f"响应头: {response.headers}") + + if response.status_code == 200: + logger.info("成功获取API流式响应,开始处理") + + buffer = "" + chunk_count = 0 + + for line in response.iter_lines(): + if line: + line_str = line.decode('utf-8') + logger.info(f"原始流式行: {line_str}") + + # 跳过保持连接的空行 + if line_str.strip() == '': + logger.debug("跳过空行") + continue + + # 数据行通常以"data: "开头 + if line_str.startswith('data: '): + data_content = line_str[6:].strip() # 移除 "data: " 前缀并去除前后空格 + logger.info(f"数据内容: {data_content}") + + # 检查是否为流的结束 + if data_content == '[DONE]': + logger.info("收到流结束标记 [DONE]") + break + + try: + # 解析JSON数据 + json_data = json.loads(data_content) + logger.debug(f"JSON结构: {json.dumps(json_data, indent=2)}") + + if 'choices' in json_data: + delta = json_data['choices'][0].get('delta', {}) + content = delta.get('content', '') + + if content: + chunk_count += 1 + buffer += content + logger.info(f"内容片段 #{chunk_count}: {content}") + except json.JSONDecodeError as e: + logger.error(f"JSON解析错误: {e}, 内容: {data_content}") + else: + logger.warning(f"收到非'data:'开头的行: {line_str}") + + logger.info(f"流式处理完成,共收到 {chunk_count} 个内容片段") + logger.info(f"完整内容:\n{buffer}") + + else: + try: + error_response = response.json() + error_text = json.dumps(error_response, indent=2) + except: + error_text = response.text[:500] if response.text else "无响应内容" + + logger.error(f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}") + + except Exception as e: + logger.error(f"测试过程中发生异常: {str(e)}") + logger.exception(e) + +if __name__ == "__main__": + test_api_stream() diff --git a/web_server.py b/web_server.py index dc54012..f1c9f89 100644 --- a/web_server.py +++ b/web_server.py @@ -1,10 +1,15 @@ -from flask import Flask, render_template, request, jsonify +from flask import Flask, render_template, request, jsonify, Response, stream_with_context from stock_analyzer import StockAnalyzer from us_stock_service import USStockService import threading import os import traceback import requests +from logger import get_logger, get_stream_logger + +# 获取日志器 +logger = get_logger() +stream_logger = get_stream_logger() app = Flask(__name__) analyzer = StockAnalyzer() @@ -25,15 +30,20 @@ def index(): @app.route('/analyze', methods=['POST']) def analyze(): try: + logger.info("开始处理分析请求") data = request.json stock_codes = data.get('stock_codes', []) market_type = data.get('market_type', 'A') + logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}") + # 获取自定义API配置 custom_api_url = data.get('api_url') custom_api_key = data.get('api_key') custom_api_model = data.get('api_model') + logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}") + # 创建新的分析器实例,使用自定义配置 custom_analyzer = StockAnalyzer( custom_api_url=custom_api_url, @@ -42,26 +52,50 @@ def analyze(): ) if not stock_codes: + logger.warning("未提供股票代码") return jsonify({'error': '请输入代码'}), 400 + + # 使用流式响应 + def generate(): + if len(stock_codes) == 1: + # 单个股票分析流式处理 + stock_code = stock_codes[0].strip() + logger.info(f"开始单股流式分析: {stock_code}") + + stream_logger.info(f"初始化单股分析流: {stock_code}") + init_message = f'{{"stream_type": "single", "stock_code": "{stock_code}"}}\n' + stream_logger.info(f"发送初始化消息: {init_message}") + yield init_message + + for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True): + stream_logger.info(f"流式输出块: {chunk}") + yield chunk + '\n' + else: + # 批量分析流式处理 + logger.info(f"开始批量流式分析: {stock_codes}") + + stream_logger.info(f"初始化批量分析流: {stock_codes}") + init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes}}}\n' + stream_logger.info(f"发送初始化消息: {init_message}") + yield init_message + + for chunk in custom_analyzer.scan_market( + [code.strip() for code in stock_codes], + min_score=0, + market_type=market_type, + stream=True + ): + stream_logger.info(f"流式输出块: {chunk}") + yield chunk + '\n' + + logger.info("成功创建流式响应生成器") + return Response(stream_with_context(generate()), mimetype='application/json') - results = [] - for stock_code in stock_codes: - try: - # 使用自定义配置的分析器 - result = custom_analyzer.analyze_stock(stock_code.strip(), market_type) - results.append(result) - except Exception as e: - print(f"分析股票 {stock_code} 失败: {str(e)}") - print(f"详细错误: {traceback.format_exc()}") - results.append({ - 'code': stock_code, - 'error': f"分析失败: {str(e)}" - }) - - return jsonify({'results': results}) except Exception as e: - print(f"分析股票时出错: {str(e)}") - return jsonify({'error': str(e)}), 500 + error_msg = f"分析股票时出错: {str(e)}" + logger.error(error_msg) + logger.exception(e) + return jsonify({'error': error_msg}), 500 @app.route('/search_us_stocks', methods=['GET']) def search_us_stocks(): @@ -81,15 +115,20 @@ def search_us_stocks(): def test_api_connection(): """测试API连接""" try: + logger.info("开始测试API连接") data = request.json api_url = data.get('api_url') api_key = data.get('api_key') api_model = data.get('api_model') + logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}") + if not api_url: + logger.warning("未提供API URL") return jsonify({'error': '请提供API URL'}), 400 if not api_key: + logger.warning("未提供API Key") return jsonify({'error': '请提供API Key'}), 400 # 构建API URL @@ -101,6 +140,8 @@ def test_api_connection(): test_url = f"{api_url}chat/completions" else: test_url = f"{api_url}/v1/chat/completions" + + logger.debug(f"完整API测试URL: {test_url}") # 发送测试请求 response = requests.post( @@ -121,15 +162,21 @@ def test_api_connection(): # 检查响应 if response.status_code == 200: + logger.info(f"API连接测试成功: {response.status_code}") return jsonify({'success': True, 'message': '连接成功'}) else: error_message = response.json().get('error', {}).get('message', '未知错误') + logger.warning(f"API连接测试失败: {response.status_code} - {error_message}") return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400 except requests.exceptions.RequestException as e: + logger.error(f"API连接请求错误: {str(e)}") return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400 except Exception as e: + logger.error(f"测试API连接时出错: {str(e)}") + logger.exception(e) return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500 if __name__ == '__main__': + logger.info("股票分析系统启动") app.run(host='0.0.0.0', port=8888, debug=True) \ No newline at end of file