From 1e53d16b3aa655f778e62324d3cfba1bf0252c3f Mon Sep 17 00:00:00 2001 From: Cassianvale Date: Thu, 6 Mar 2025 17:11:15 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- services/__init__.py | 2 + services/ai_analyzer.py | 269 +++++++++++++++++++++++++++++ services/fund_service_async.py | 225 ++++++++++++++++++++++++ services/stock_analyzer_service.py | 178 +++++++++++++++++++ services/stock_data_provider.py | 184 ++++++++++++++++++++ services/stock_scorer.py | 128 ++++++++++++++ services/technical_indicator.py | 187 ++++++++++++++++++++ services/us_stock_service_async.py | 151 ++++++++++++++++ web_server.py | 94 ++++++---- 9 files changed, 1380 insertions(+), 38 deletions(-) create mode 100644 services/__init__.py create mode 100644 services/ai_analyzer.py create mode 100644 services/fund_service_async.py create mode 100644 services/stock_analyzer_service.py create mode 100644 services/stock_data_provider.py create mode 100644 services/stock_scorer.py create mode 100644 services/technical_indicator.py create mode 100644 services/us_stock_service_async.py diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..7267b70 --- /dev/null +++ b/services/__init__.py @@ -0,0 +1,2 @@ +# services包初始化文件 +# 用于组织股票分析服务的各个模块 \ No newline at end of file diff --git a/services/ai_analyzer.py b/services/ai_analyzer.py new file mode 100644 index 0000000..93c063f --- /dev/null +++ b/services/ai_analyzer.py @@ -0,0 +1,269 @@ +import pandas as pd +import numpy as np +import os +import json +import asyncio +import httpx +from typing import Dict, List, Optional, Any, Generator, AsyncGenerator +from dotenv import load_dotenv +from logger import get_logger +from utils.api_utils import APIUtils + +# 获取日志器 +logger = get_logger() + +class AIAnalyzer: + """ + 异步AI分析服务 + 负责调用AI API对股票数据进行分析 + """ + + def 
__init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None): + """ + 初始化AI分析服务 + + Args: + custom_api_url: 自定义API URL + custom_api_key: 自定义API密钥 + custom_api_model: 自定义API模型 + custom_api_timeout: 自定义API超时时间 + """ + # 加载环境变量 + load_dotenv() + + # 设置API配置 + self.API_URL = custom_api_url or os.getenv('API_URL') + self.API_KEY = custom_api_key or os.getenv('API_KEY') + self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo') + self.API_TIMEOUT = int(custom_api_timeout or os.getenv('API_TIMEOUT', 60)) + + logger.debug(f"初始化AIAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}, API_TIMEOUT={self.API_TIMEOUT}") + + async def get_ai_analysis(self, df: pd.DataFrame, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]: + """ + 对股票数据进行AI分析 + + Args: + df: 包含技术指标的DataFrame + stock_code: 股票代码 + market_type: 市场类型,默认为'A'股 + stream: 是否使用流式响应 + + Returns: + 异步生成器,生成分析结果字符串 + """ + try: + logger.info(f"开始AI分析 {stock_code}, 流式模式: {stream}") + recent_data = df.tail(14).to_dict('records') + + technical_summary = { + 'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward', + 'volatility': f"{df.iloc[-1]['Volatility']:.2f}%", + 'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing', + 'rsi_level': df.iloc[-1]['RSI'] + } + + # 根据市场类型调整分析提示 + if market_type in ['ETF', 'LOF']: + prompt = f""" + 分析基金 {stock_code}: + + 技术指标概要: + {technical_summary} + + 近14日交易数据: + {recent_data} + + 请分析该基金的技术面状况,包括: + 1. 趋势分析:判断基金当前的趋势方向 + 2. 动量分析:基于RSI和交易量评估基金动量 + 3. 支撑与阻力位:确定关键价格位 + 4. 技术面总结 + 5. 投资建议 + + 将分析结果格式化为JSON,像这样: + {{ + "trend_analysis": "趋势分析结果...", + "momentum_analysis": "动量分析结果...", + "support_resistance": "支撑阻力位分析...", + "technical_summary": "技术面总结...", + "investment_advice": "投资建议..." 
+ }} + """ + else: + prompt = f""" + 分析股票 {stock_code}: + + 技术指标概要: + {technical_summary} + + 近14日交易数据: + {recent_data} + + 请分析该股票的技术面状况,包括: + 1. 趋势分析:当前趋势方向及强度 + 2. 动量分析:基于MACD、RSI等指标 + 3. 支撑与阻力位:关键价格位分析 + 4. 成交量分析:交易量的变化及意义 + 5. 波动性评估:ATR和波动率分析 + 6. 技术面总结 + 7. 投资建议:根据技术分析给出操作建议 + + 将分析结果格式化为JSON,像这样: + {{ + "trend_analysis": "趋势分析结果...", + "momentum_analysis": "动量分析结果...", + "support_resistance": "支撑阻力位分析...", + "volume_analysis": "成交量分析...", + "volatility_assessment": "波动性评估...", + "technical_summary": "技术面总结...", + "investment_advice": "投资建议..." + }} + """ + + # 格式化API URL + api_url = APIUtils.format_api_url(self.API_URL) + + # 准备请求数据 + request_data = { + "model": self.API_MODEL, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.7, + "stream": stream + } + + # 准备请求头 + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.API_KEY}" + } + + # 异步请求API + async with httpx.AsyncClient(timeout=self.API_TIMEOUT) as client: + # 记录请求 + logger.debug(f"发送AI请求: URL={api_url}, MODEL={self.API_MODEL}, STREAM={stream}") + + if stream: + # 流式响应处理 + async with client.stream("POST", api_url, json=request_data, headers=headers) as response: + if response.status_code != 200: + error_text = await response.aread() + error_data = json.loads(error_text) + error_message = error_data.get('error', {}).get('message', '未知错误') + logger.error(f"AI API请求失败: {response.status_code} - {error_message}") + yield json.dumps({"error": f"API请求失败: {error_message}"}) + return + + # 处理流式响应 + buffer = "" + collected_messages = [] + + async for chunk in response.aiter_text(): + if chunk: + chunk_str = chunk.strip() + if chunk_str.startswith("data: "): + chunk_str = chunk_str[6:] # 去除"data: "前缀 + + if chunk_str == "[DONE]": + continue + + try: + # 解析数据块 + chunk_data = json.loads(chunk_str) + delta = chunk_data.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content", "") + + if content: + buffer += content + # 尝试提取完整的JSON + if 
buffer.strip().startswith("{") and buffer.strip().endswith("}"): + try: + result_json = json.loads(buffer) + yield json.dumps({ + "stock_code": stock_code, + "analysis": result_json + }) + buffer = "" # 重置缓冲区 + except json.JSONDecodeError: + # JSON不完整,继续收集 + pass + + # 达到一定长度就输出 + if len(buffer) > 100: + yield json.dumps({ + "stock_code": stock_code, + "partial_content": buffer + }) + collected_messages.append(buffer) + buffer = "" + except json.JSONDecodeError: + # 忽略无法解析的块 + continue + + # 处理最后的缓冲区 + if buffer: + yield json.dumps({ + "stock_code": stock_code, + "partial_content": buffer + }) + collected_messages.append(buffer) + + # 尝试从整个内容中提取JSON + full_content = "".join(collected_messages) + + # 如果没有成功解析JSON,返回原始内容 + if not full_content.strip().startswith("{"): + yield json.dumps({ + "stock_code": stock_code, + "raw_analysis": full_content + }) + else: + # 非流式响应处理 + response = await client.post(api_url, json=request_data, headers=headers) + + if response.status_code != 200: + error_data = response.json() + error_message = error_data.get('error', {}).get('message', '未知错误') + logger.error(f"AI API请求失败: {response.status_code} - {error_message}") + yield json.dumps({"error": f"API请求失败: {error_message}"}) + return + + response_data = response.json() + analysis_text = response_data.get("choices", [{}])[0].get("message", {}).get("content", "") + + try: + # 尝试解析JSON + analysis_json = json.loads(analysis_text) + yield json.dumps({ + "stock_code": stock_code, + "analysis": analysis_json + }) + except json.JSONDecodeError: + # 返回原始文本 + yield json.dumps({ + "stock_code": stock_code, + "raw_analysis": analysis_text + }) + + logger.info(f"完成对 {stock_code} 的AI分析") + + except Exception as e: + logger.error(f"AI分析 {stock_code} 时出错: {str(e)}") + logger.exception(e) + yield json.dumps({"error": f"分析出错: {str(e)}"}) + + def _truncate_json_for_logging(self, json_obj, max_length=500): + """ + 截断JSON对象以便记录日志 + + Args: + json_obj: JSON对象 + max_length: 最大长度 + + Returns: + 截断后的字符串 + """ 
class FundServiceAsync:
    """
    Asynchronous fund (ETF / LOF) lookup service.

    Wraps akshare's synchronous spot-quote endpoints: the blocking calls run in
    a worker thread via ``asyncio.to_thread`` and each market's table is cached
    in memory for 30 minutes.
    """

    # akshare's Chinese column headers -> English names used by this service.
    _COLUMN_MAP = {
        "代码": "symbol",
        "名称": "name",
        "最新价": "price",
        "涨跌额": "price_change",
        "涨跌幅": "price_change_percent",
        "成交量": "volume",
        "流通市值": "market_value",
        "总市值": "total_value",
        "基金折价率": "discount_rate",
    }

    def __init__(self):
        """Create the service with empty per-market caches."""
        logger.debug("初始化FundServiceAsync")

        self._etf_cache = None
        self._lof_cache = None
        # One timestamp per market ('ETF' / 'LOF'). A single shared timestamp
        # would let a refresh of one market extend the life of the other
        # market's stale table.
        self._cache_timestamps = {}
        self._cache_duration = timedelta(minutes=30)

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a cell to float, mapping missing values (None/NaN) to 0.0."""
        return float(value) if pd.notna(value) else 0.0

    @staticmethod
    def _percent_to_fraction(value: Any) -> float:
        """
        Convert a percentage cell to a fraction (``"1.5%"`` or ``1.5`` -> ``0.015``).

        Accepts both text with a ``%`` sign and plain numbers, so the caller
        does not depend on the dtype the data source happens to return.
        Missing values and empty strings map to 0.0.
        """
        if isinstance(value, str):
            text = value.strip().strip('%')
            return float(text) / 100 if text else 0.0
        if pd.isna(value):
            return 0.0
        return float(value) / 100

    @staticmethod
    def _match_mask(df: pd.DataFrame, keyword: str) -> pd.Series:
        """
        Return a boolean mask of rows whose name or symbol contains ``keyword``.

        Matching is case-insensitive and literal (``regex=False``): the keyword
        is user input, and characters such as ``(`` or ``*`` must not be
        interpreted as a regular expression.
        """
        by_name = df['name'].str.contains(keyword, case=False, na=False, regex=False)
        by_symbol = df['symbol'].str.contains(keyword, case=False, na=False, regex=False)
        return by_name | by_symbol

    def _fresh_cache(self, key: str, now: datetime) -> Optional[pd.DataFrame]:
        """
        Return the cached table for ``key`` ('ETF' or 'LOF') if it is still fresh.

        Returns None when nothing is cached for that market or when the entry
        is at least ``_cache_duration`` old.
        """
        cached = self._etf_cache if key == 'ETF' else self._lof_cache
        stamped_at = self._cache_timestamps.get(key)
        if cached is None or stamped_at is None:
            return None
        if (now - stamped_at) >= self._cache_duration:
            return None
        return cached

    async def search_funds(self, keyword: str, market_type: str = 'ETF') -> List[Dict[str, Any]]:
        """
        Search funds whose name or symbol contains ``keyword``.

        Args:
            keyword: Search text, matched literally and case-insensitively.
            market_type: 'ETF' or 'LOF'.

        Returns:
            A list of dicts with name, symbol, price, volume, market_value and
            total_value; missing numeric cells are reported as 0.0.

        Raises:
            Exception: Wraps any failure while loading or filtering the data.
        """
        try:
            logger.info(f"异步搜索基金: {keyword}, 类型: {market_type}")

            df = await self._get_funds_data(market_type)
            results = df[self._match_mask(df, keyword)]

            formatted_results = [
                {
                    'name': row['name'] if pd.notna(row['name']) else '',
                    'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                    'price': self._to_float(row['price']),
                    'volume': self._to_float(row['volume']),
                    'market_value': self._to_float(row['market_value']),
                    'total_value': self._to_float(row['total_value']),
                }
                for _, row in results.iterrows()
            ]

            logger.info(f"基金搜索完成,找到 {len(formatted_results)} 个匹配项")
            return formatted_results

        except Exception as e:
            error_msg = f"搜索基金代码失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e

    async def _get_funds_data(self, market_type: str = 'ETF') -> pd.DataFrame:
        """
        Return the spot table for ``market_type``, using the cache when fresh.

        Any value other than 'ETF' is treated as 'LOF', as before.

        Raises:
            Exception: Re-raises whatever the underlying fetch raised.
        """
        key = 'ETF' if market_type == 'ETF' else 'LOF'
        now = datetime.now()

        cached = self._fresh_cache(key, now)
        if cached is not None:
            logger.debug(f"使用{key}缓存数据")
            return cached

        try:
            logger.debug(f"从API获取{market_type}数据")

            # akshare is synchronous; run it off the event loop.
            if key == 'ETF':
                df = await asyncio.to_thread(self._get_etf_data)
                self._etf_cache = df
            else:
                df = await asyncio.to_thread(self._get_lof_data)
                self._lof_cache = df

            self._cache_timestamps[key] = now
            return df

        except Exception as e:
            logger.error(f"获取{market_type}数据失败: {str(e)}")
            logger.exception(e)
            raise

    def _load_spot_table(self, loader, label: str) -> pd.DataFrame:
        """
        Call ``loader`` and rename its columns using ``_COLUMN_MAP``.

        Args:
            loader: Zero-argument callable returning the raw DataFrame.
            label: Market label used in log and error messages.

        Raises:
            Exception: Wraps any failure from ``loader``, chained to the cause.
        """
        try:
            return loader().rename(columns=self._COLUMN_MAP)
        except Exception as e:
            logger.error(f"获取{label}数据失败: {str(e)}")
            logger.exception(e)
            raise Exception(f"获取{label}数据失败: {str(e)}") from e

    def _get_etf_data(self) -> pd.DataFrame:
        """Fetch the ETF spot table (synchronous; called from a worker thread)."""
        import akshare as ak

        return self._load_spot_table(ak.fund_etf_spot_em, "ETF")

    def _get_lof_data(self) -> pd.DataFrame:
        """Fetch the LOF spot table (synchronous; called from a worker thread)."""
        import akshare as ak

        return self._load_spot_table(ak.fund_lof_spot_em, "LOF")

    async def get_fund_detail(self, symbol: str, market_type: str = 'ETF') -> Dict[str, Any]:
        """
        Return the detail record of a single fund.

        Args:
            symbol: Fund code, matched exactly against the ``symbol`` column.
            market_type: 'ETF' or 'LOF'.

        Returns:
            A dict with name, symbol, price, price_change,
            price_change_percent (fraction), volume, market_value, total_value
            and discount_rate (fraction).

        Raises:
            Exception: If the symbol is not found or the data cannot be loaded.
        """
        try:
            logger.info(f"获取{market_type}基金详情: {symbol}")

            df = await self._get_funds_data(market_type)
            result = df[df['symbol'] == symbol]

            if len(result) == 0:
                raise Exception(f"未找到基金代码: {symbol}")

            row = result.iloc[0]

            fund_detail = {
                'name': row['name'] if pd.notna(row['name']) else '',
                'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                'price': self._to_float(row['price']),
                'price_change': self._to_float(row['price_change']),
                'price_change_percent': self._percent_to_fraction(row['price_change_percent']),
                'volume': self._to_float(row['volume']),
                'market_value': self._to_float(row['market_value']),
                'total_value': self._to_float(row['total_value']),
                # .get(): the discount column is only renamed when the source
                # table actually provides it.
                'discount_rate': self._percent_to_fraction(row.get('discount_rate')),
            }

            logger.info(f"获取基金详情成功: {symbol}")
            return fund_detail

        except Exception as e:
            error_msg = f"获取基金详情失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e
class StockAnalyzerService:
    """
    Facade over the stock-analysis pipeline.

    Coordinates the data provider, the technical-indicator calculator, the
    scorer and the AI analyzer, and exposes the results as async generators of
    JSON strings.
    """

    def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
        """
        Build the pipeline components.

        Args:
            custom_api_url: Optional AI API URL, forwarded to AIAnalyzer.
            custom_api_key: Optional AI API key, forwarded to AIAnalyzer.
            custom_api_model: Optional AI model name, forwarded to AIAnalyzer.
            custom_api_timeout: Optional AI API timeout, forwarded to AIAnalyzer.
        """
        ai_options = dict(
            custom_api_url=custom_api_url,
            custom_api_key=custom_api_key,
            custom_api_model=custom_api_model,
            custom_api_timeout=custom_api_timeout,
        )

        self.data_provider = StockDataProvider()
        self.indicator = TechnicalIndicator()
        self.scorer = StockScorer()
        self.ai_analyzer = AIAnalyzer(**ai_options)

        logger.info("初始化StockAnalyzerService完成")

    async def analyze_stock(self, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
        """
        Analyze one stock.

        Yields the basic score summary first, then every chunk produced by the
        AI analyzer. On failure a single ``{"error": ...}`` JSON string is
        yielded instead of raising.

        Args:
            stock_code: Stock code to analyze.
            market_type: Market type, 'A' by default.
            stream: Passed through to the AI analyzer's streaming switch.

        Yields:
            JSON-encoded strings.
        """
        try:
            logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}")

            raw_frame = await self.data_provider.get_stock_data(stock_code, market_type)
            enriched = self.indicator.calculate_indicators(raw_frame)

            score = self.scorer.calculate_score(enriched)
            summary = json.dumps({
                "stock_code": stock_code,
                "score": score,
                "recommendation": self.scorer.get_recommendation(score),
                "data_point_count": len(raw_frame),
                "market_type": market_type
            })

            logger.info(f"基本分析结果: {summary}")
            yield summary

            ai_chunks = self.ai_analyzer.get_ai_analysis(enriched, stock_code, market_type, stream)
            async for piece in ai_chunks:
                yield piece

            logger.info(f"完成股票分析: {stock_code}")

        except Exception as e:
            error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            yield json.dumps({"error": error_msg})

    async def scan_stocks(self, stock_codes: List[str], market_type: str = 'A', min_score: int = 0, stream: bool = False) -> AsyncGenerator[str, None]:
        """
        Scan a batch of stocks.

        Yields a "scanning" status, then the filtered score table, then (only
        when ``stream`` is true) AI analysis for the three best-scoring
        matches, and finally a "completed" status. On failure a single
        ``{"error": ...}`` JSON string is yielded instead of raising.

        Args:
            stock_codes: Stock codes to scan.
            market_type: Market type.
            min_score: Minimum score a stock needs to be reported.
            stream: Enables the follow-up AI analysis and is passed to it.

        Yields:
            JSON-encoded strings.
        """
        try:
            total = len(stock_codes)
            logger.info(f"开始批量扫描 {total} 只股票, 市场: {market_type}")

            yield json.dumps({
                "status": "scanning",
                "total_stocks": total,
                "market_type": market_type,
                "min_score": min_score
            })

            frames = await self.data_provider.get_multiple_stocks_data(stock_codes, market_type)

            # Indicator failures are logged per stock and that stock is skipped.
            stock_with_indicators = {}
            for code, frame in frames.items():
                try:
                    stock_with_indicators[code] = self.indicator.calculate_indicators(frame)
                except Exception as e:
                    logger.error(f"计算 {code} 技术指标时出错: {str(e)}")

            results = self.scorer.batch_score_stocks(stock_with_indicators)
            filtered_results = [entry for entry in results if entry[1] >= min_score]

            scan_rows = []
            for code, score, rec in filtered_results:
                scan_rows.append({
                    "stock_code": code,
                    "score": score,
                    "recommendation": rec
                })

            yield json.dumps({
                "scan_results": scan_rows,
                "total_matched": len(filtered_results),
                "total_scanned": len(results)
            })

            if stream and filtered_results:
                # Only the first three entries of the scorer's ordered result
                # get the (expensive) AI analysis.
                for stock_code, score, _ in filtered_results[:3]:
                    frame = stock_with_indicators.get(stock_code)
                    if frame is None:
                        continue

                    yield json.dumps({
                        "analyzing": stock_code,
                        "score": score
                    })

                    async for piece in self.ai_analyzer.get_ai_analysis(frame, stock_code, market_type, stream):
                        yield piece

            yield json.dumps({
                "status": "completed",
                "total_scanned": len(results),
                "total_matched": len(filtered_results)
            })

            logger.info(f"完成批量扫描 {len(stock_codes)} 只股票, 符合条件: {len(filtered_results)}")

        except Exception as e:
            error_msg = f"批量扫描股票时出错: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            yield json.dumps({"error": error_msg})
# 验证股票代码格式 + if market_type == 'A': + # 上海证券交易所股票代码以6开头 + # 深圳证券交易所股票代码以0或3开头 + # 科创板股票代码以688开头 + # 北京证券交易所股票代码以8开头 + valid_prefixes = ['0', '3', '6', '688', '8'] + valid_format = False + + for prefix in valid_prefixes: + if stock_code.startswith(prefix): + valid_format = True + break + + if not valid_format: + error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头" + logger.error(f"[股票代码格式错误] {error_msg}") + raise ValueError(error_msg) + + logger.debug(f"获取A股数据: {stock_code}") + df = ak.stock_zh_a_hist( + symbol=stock_code, + start_date=start_date, + end_date=end_date, + adjust="qfq" + ) + + elif market_type in ['HK']: + logger.debug(f"获取港股数据: {stock_code}") + df = ak.stock_hk_daily( + symbol=stock_code, + start_date=start_date, + end_date=end_date, + adjust="qfq" + ) + + elif market_type in ['US']: + logger.debug(f"获取美股数据: {stock_code}") + df = ak.stock_us_daily( + symbol=stock_code, + adjust="qfq" + ) + # 过滤日期 + df = df[(df.index >= start_date) & (df.index <= end_date)] + + elif market_type in ['ETF', 'LOF']: + logger.debug(f"获取{market_type}基金数据: {stock_code}") + df = ak.fund_etf_hist_sina( + symbol=stock_code, + start_date=start_date.replace('-', ''), + end_date=end_date.replace('-', '') + ) + + else: + error_msg = f"不支持的市场类型: {market_type}" + logger.error(f"[市场类型错误] {error_msg}") + raise ValueError(error_msg) + + # 标准化列名 + if market_type == 'A': + # 根据实际数据结构调整列名映射 + # 实际数据列:['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率'] + df.columns = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover'] + elif market_type in ['HK', 'US']: + # 根据实际情况调整 + df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount'] + elif market_type in ['ETF', 'LOF']: + # 基金数据可能有不同的列 + df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount'] + + # 确保日期列是日期类型 + if 'Date' in df.columns: + df['Date'] = pd.to_datetime(df['Date']) + df.set_index('Date', inplace=True) + 
class StockScorer:
    """
    Stock scoring service.

    Turns the latest row of an indicator-enriched DataFrame into a 0-100
    score (moving averages 25, RSI 25, MACD 20, volume 30) and maps scores to
    recommendation labels.
    """

    # (minimum score, label), checked from the highest threshold down.
    _RECOMMENDATION_LEVELS = (
        (80, "强烈推荐"),
        (70, "推荐"),
        (60, "谨慎推荐"),
        (40, "观望"),
        (20, "不推荐"),
    )
    _LOWEST_RECOMMENDATION = "强烈不推荐"

    def __init__(self):
        """Initialize the scoring service."""
        logger.debug("初始化StockScorer")

    @staticmethod
    def _moving_average_points(latest) -> int:
        """Score the moving-average alignment of one row (max 25)."""
        if latest['MA5'] > latest['MA20'] > latest['MA60']:
            # Short, medium and long averages in bullish order.
            return 25
        if latest['MA5'] > latest['MA20']:
            return 15
        if latest['Close'] > latest['MA20']:
            return 10
        return 0

    @staticmethod
    def _rsi_points(rsi) -> int:
        """
        Score an RSI value (max 25).

        A NaN RSI matches no band and scores 0.
        """
        if 55 < rsi < 70:
            # Strong but not yet overbought.
            return 25
        if 45 <= rsi <= 55:
            return 15
        if 30 < rsi < 45:
            return 10
        if rsi >= 70:
            # Overbought.
            return 5
        if rsi <= 30:
            # Oversold.
            return 15
        return 0

    @staticmethod
    def _macd_points(latest) -> int:
        """Score MACD against its signal line (20 or 0)."""
        return 20 if latest['MACD'] > latest['Signal'] else 0

    @staticmethod
    def _volume_points(latest) -> int:
        """Score the volume ratio (max 30)."""
        ratio = latest['Volume_Ratio']
        if ratio > 1.5:
            return 30
        if ratio > 1:
            return 15
        return 0

    def calculate_score(self, df: pd.DataFrame) -> int:
        """
        Score a stock from its most recent data point (0-100).

        Args:
            df: DataFrame with Close, MA5, MA20, MA60, RSI, MACD, Signal and
                Volume_Ratio columns; only the last row is read.

        Returns:
            Integer score between 0 and 100.

        Raises:
            ValueError: If ``df`` is None or has no rows.
            Exception: Any error raised while reading the row is logged and
                re-raised unchanged.
        """
        # Fail with a clear message instead of an out-of-bounds IndexError
        # from iloc[-1].
        if df is None or len(df) == 0:
            raise ValueError("无法计算评分: 数据为空")

        try:
            latest = df.iloc[-1]
            return (
                self._moving_average_points(latest)
                + self._rsi_points(latest['RSI'])
                + self._macd_points(latest)
                + self._volume_points(latest)
            )

        except Exception as e:
            logger.error(f"计算评分时出错: {str(e)}")
            logger.exception(e)
            raise

    def get_recommendation(self, score: int) -> str:
        """
        Map a score to a recommendation label.

        Args:
            score: Stock score (0-100).

        Returns:
            The label of the highest threshold that ``score`` reaches, or the
            lowest label when it reaches none.
        """
        for threshold, label in self._RECOMMENDATION_LEVELS:
            if score >= threshold:
                return label
        return self._LOWEST_RECOMMENDATION

    def batch_score_stocks(self, stock_dfs: Dict[str, pd.DataFrame]) -> List[Tuple[str, int, str]]:
        """
        Score several stocks at once.

        Stocks whose scoring fails are logged and left out of the result
        (best-effort by design).

        Args:
            stock_dfs: Mapping of stock code to indicator-enriched DataFrame.

        Returns:
            List of (stock code, score, recommendation) sorted by score,
            highest first.
        """
        results = []

        for stock_code, df in stock_dfs.items():
            try:
                score = self.calculate_score(df)
                results.append((stock_code, score, self.get_recommendation(score)))
            except Exception as e:
                logger.error(f"评分股票 {stock_code} 时出错: {str(e)}")

        results.sort(key=lambda item: item[1], reverse=True)

        return results
class TechnicalIndicator:
    """
    Technical-indicator calculation service.

    Computes moving averages, RSI, MACD, Bollinger bands, volume ratio, ATR
    and volatility on top of an OHLCV DataFrame.
    """

    # Defaults used for every key the caller does not override.
    DEFAULT_PARAMS = {
        'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
        'rsi_period': 14,
        'bollinger_period': 20,
        'bollinger_std': 2,
        'volume_ma_period': 20,
        'atr_period': 14,
        'volatility_period': 20,
    }

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """
        Initialize the service.

        Args:
            params: Optional overrides for ``DEFAULT_PARAMS``. Keys that are
                not supplied keep their default, so a partial dict is valid.
        """
        self.params = self._resolve_params(params)

        logger.debug(f"初始化TechnicalIndicator,参数: {self.params}")

    @classmethod
    def _resolve_params(cls, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Merge caller overrides over a private copy of the defaults.

        The merge is shallow: an override for ``ma_periods`` replaces the
        whole mapping. The result never aliases ``DEFAULT_PARAMS`` or the
        caller's dict, so later mutation of either cannot leak in.
        """
        merged = {
            key: (dict(value) if isinstance(value, dict) else value)
            for key, value in cls.DEFAULT_PARAMS.items()
        }
        if params:
            merged.update(params)
        return merged

    def calculate_ema(self, series: pd.Series, period: int) -> pd.Series:
        """
        Compute the exponential moving average.

        Args:
            series: Price series.
            period: EMA span.

        Returns:
            EMA series (recursive form, ``adjust=False``).
        """
        return series.ewm(span=period, adjust=False).mean()

    def calculate_rsi(self, series: pd.Series, period: int) -> pd.Series:
        """
        Compute the relative strength index using simple rolling means.

        Args:
            series: Price series.
            period: Look-back window.

        Returns:
            RSI series. Windows with no losses evaluate to 100; windows with
            neither gains nor losses evaluate to NaN (0/0).
        """
        delta = series.diff()
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)

        avg_gain = gain.rolling(window=period).mean()
        avg_loss = loss.rolling(window=period).mean()

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    def calculate_macd(self, series: pd.Series) -> tuple:
        """
        Compute MACD with the fixed 12/26/9 configuration.

        Args:
            series: Price series.

        Returns:
            Tuple of (MACD line, signal line, histogram).
        """
        macd = self.calculate_ema(series, 12) - self.calculate_ema(series, 26)
        signal = self.calculate_ema(macd, 9)

        return macd, signal, macd - signal

    def calculate_bollinger_bands(self, series: pd.Series, period: int, std_dev: float) -> tuple:
        """
        Compute Bollinger bands.

        Args:
            series: Price series.
            period: Rolling window.
            std_dev: Band width in standard deviations.

        Returns:
            Tuple of (middle band, upper band, lower band).
        """
        window = series.rolling(window=period)
        middle = window.mean()
        offset = std_dev * window.std()

        return middle, middle + offset, middle - offset

    def calculate_atr(self, df: pd.DataFrame, period: int) -> pd.Series:
        """
        Compute the average true range as a simple rolling mean.

        Args:
            df: DataFrame with High, Low and Close columns.
            period: Rolling window.

        Returns:
            ATR series.
        """
        high = df['High']
        low = df['Low']
        prev_close = df['Close'].shift()

        # True range: the largest of the three candidate ranges per row.
        candidates = pd.concat(
            [high - low, abs(high - prev_close), abs(low - prev_close)],
            axis=1,
        )
        return candidates.max(axis=1).rolling(window=period).mean()

    def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Compute every indicator and return them on a copy of ``df``.

        Args:
            df: Price data with High, Low, Close and Volume columns.

        Returns:
            A new DataFrame with MA<period>, RSI, MACD, Signal, Histogram,
            BB_Middle, BB_Upper, BB_Lower, Volume_MA, Volume_Ratio, ATR and
            Volatility columns added; the input is not modified.

        Raises:
            Exception: Any calculation error is logged and re-raised.
        """
        try:
            result_df = df.copy()
            close = result_df['Close']

            # One MA column per configured period; the period names are only
            # labels in the config and are not used for column names.
            for period in self.params['ma_periods'].values():
                result_df[f'MA{period}'] = close.rolling(window=period).mean()

            result_df['RSI'] = self.calculate_rsi(close, self.params['rsi_period'])

            macd, signal, histogram = self.calculate_macd(close)
            result_df['MACD'] = macd
            result_df['Signal'] = signal
            result_df['Histogram'] = histogram

            middle, upper, lower = self.calculate_bollinger_bands(
                close,
                self.params['bollinger_period'],
                self.params['bollinger_std']
            )
            result_df['BB_Middle'] = middle
            result_df['BB_Upper'] = upper
            result_df['BB_Lower'] = lower

            result_df['Volume_MA'] = result_df['Volume'].rolling(window=self.params['volume_ma_period']).mean()
            result_df['Volume_Ratio'] = result_df['Volume'] / result_df['Volume_MA']

            result_df['ATR'] = self.calculate_atr(result_df, self.params['atr_period'])

            # Volatility: rolling std of Close relative to its rolling mean, in
            # percent. .get() keeps fully custom param dicts from older callers
            # working without the new key.
            volatility_window = close.rolling(window=self.params.get('volatility_period', 20))
            result_df['Volatility'] = volatility_window.std() / volatility_window.mean() * 100

            return result_df

        except Exception as e:
            logger.error(f"计算技术指标时出错: {str(e)}")
            logger.exception(e)
            raise
class USStockServiceAsync:
    """
    Asynchronous US-stock lookup service.

    Wraps akshare's synchronous US spot-quote endpoint, running it in a worker
    thread via ``asyncio.to_thread``.
    """

    # akshare's Chinese column headers -> English names used by this service.
    _COLUMN_MAP = {
        "序号": "index",
        "名称": "name",
        "最新价": "price",
        "涨跌额": "price_change",
        "涨跌幅": "price_change_percent",
        "开盘价": "open",
        "最高价": "high",
        "最低价": "low",
        "昨收价": "pre_close",
        "总市值": "market_value",
        "市盈率": "pe_ratio",
        "成交量": "volume",
        "成交额": "turnover",
        "振幅": "amplitude",
        "换手率": "turnover_rate",
        "代码": "symbol",
    }

    def __init__(self):
        """Initialize the service."""
        logger.debug("初始化USStockServiceAsync")

        # Reserved for an optional cache; no method reads or writes these.
        self._cache = None
        self._cache_timestamp = None

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a cell to float, mapping missing values (None/NaN) to 0.0."""
        return float(value) if pd.notna(value) else 0.0

    @staticmethod
    def _percent_to_fraction(value: Any) -> float:
        """
        Convert a percentage cell to a fraction (``"1.5%"`` or ``1.5`` -> ``0.015``).

        Accepts both text with a ``%`` sign and plain numbers, so the caller
        does not depend on the dtype the data source happens to return.
        Missing values and empty strings map to 0.0.
        """
        if isinstance(value, str):
            text = value.strip().strip('%')
            return float(text) / 100 if text else 0.0
        if pd.isna(value):
            return 0.0
        return float(value) / 100

    @staticmethod
    def _match_mask(df: pd.DataFrame, keyword: str) -> pd.Series:
        """
        Return a boolean mask of rows whose name or symbol contains ``keyword``.

        Matching is case-insensitive and literal (``regex=False``): the keyword
        is user input, and characters such as ``(`` or ``*`` must not be
        interpreted as a regular expression. The symbol column is included so
        that searching by ticker works as well as searching by name.
        """
        by_name = df['name'].str.contains(keyword, case=False, na=False, regex=False)
        by_symbol = df['symbol'].str.contains(keyword, case=False, na=False, regex=False)
        return by_name | by_symbol

    async def search_us_stocks(self, keyword: str) -> List[Dict[str, Any]]:
        """
        Search US stocks whose name or symbol contains ``keyword``.

        Args:
            keyword: Search text, matched literally and case-insensitively.

        Returns:
            A list of dicts with name, symbol, price and market_value;
            missing numeric cells are reported as 0.0.

        Raises:
            Exception: Wraps any failure while loading or filtering the data.
        """
        try:
            logger.info(f"异步搜索美股: {keyword}")

            # akshare is synchronous; run it off the event loop.
            df = await asyncio.to_thread(self._get_us_stocks_data)
            results = df[self._match_mask(df, keyword)]

            formatted_results = [
                {
                    'name': row['name'] if pd.notna(row['name']) else '',
                    'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                    'price': self._to_float(row['price']),
                    'market_value': self._to_float(row['market_value'])
                }
                for _, row in results.iterrows()
            ]

            logger.info(f"美股搜索完成,找到 {len(formatted_results)} 个匹配项")
            return formatted_results

        except Exception as e:
            error_msg = f"搜索美股代码失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e

    def _get_us_stocks_data(self) -> pd.DataFrame:
        """
        Fetch the US spot table (synchronous; called from a worker thread).

        Returns:
            DataFrame with columns renamed using ``_COLUMN_MAP``.

        Raises:
            Exception: Wraps any failure from akshare, chained to the cause.
        """
        import akshare as ak

        try:
            return ak.stock_us_spot_em().rename(columns=self._COLUMN_MAP)

        except Exception as e:
            logger.error(f"获取美股数据失败: {str(e)}")
            logger.exception(e)
            raise Exception(f"获取美股数据失败: {str(e)}") from e

    async def get_us_stock_detail(self, symbol: str) -> Dict[str, Any]:
        """
        Return the detail record of a single US stock.

        Args:
            symbol: Stock code, matched exactly against the ``symbol`` column.

        Returns:
            A dict with name, symbol, price, price_change,
            price_change_percent (fraction), open, high, low, pre_close,
            market_value, pe_ratio, volume and turnover.

        Raises:
            Exception: If the symbol is not found or the data cannot be loaded.
        """
        try:
            logger.info(f"获取美股详情: {symbol}")

            df = await asyncio.to_thread(self._get_us_stocks_data)
            result = df[df['symbol'] == symbol]

            if len(result) == 0:
                raise Exception(f"未找到股票代码: {symbol}")

            row = result.iloc[0]

            stock_detail = {
                'name': row['name'] if pd.notna(row['name']) else '',
                'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                'price': self._to_float(row['price']),
                'price_change': self._to_float(row['price_change']),
                'price_change_percent': self._percent_to_fraction(row['price_change_percent']),
                'open': self._to_float(row['open']),
                'high': self._to_float(row['high']),
                'low': self._to_float(row['low']),
                'pre_close': self._to_float(row['pre_close']),
                'market_value': self._to_float(row['market_value']),
                'pe_ratio': self._to_float(row['pe_ratio']),
                'volume': self._to_float(row['volume']),
                'turnover': self._to_float(row['turnover'])
            }

            logger.info(f"获取美股详情成功: {symbol}")
            return stock_detail

        except Exception as e:
            error_msg = f"获取美股详情失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e
StreamingResponse, FileResponse, Red from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field -from typing import List, Optional, Generator -from stock_analyzer import StockAnalyzer -from us_stock_service import USStockService -from fund_service import FundService +from typing import List, Optional, Dict, Any, Generator +from services.stock_analyzer_service import StockAnalyzerService +# 导入新的异步服务 +from services.us_stock_service_async import USStockServiceAsync +from services.fund_service_async import FundServiceAsync import asyncio import threading import os @@ -25,7 +26,7 @@ logger = get_logger() app = FastAPI( title="Stock Scanner API", - description="股票分析API", + description="异步股票分析API", version="1.0.0" ) @@ -43,9 +44,10 @@ frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fronte if os.path.exists(frontend_dist): app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend") -analyzer = StockAnalyzer() -us_stock_service = USStockService() -fund_service = FundService() +# 初始化异步服务 +# StockAnalyzerService 不需要全局初始化,在 /analyze 接口中按需创建 +us_stock_service = USStockServiceAsync() +fund_service = FundServiceAsync() # 定义请求和响应模型 class AnalyzeRequest(BaseModel): @@ -102,7 +104,7 @@ async def analyze(request: AnalyzeRequest): logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}") # 创建新的分析器实例,使用自定义配置 - custom_analyzer = StockAnalyzer( + custom_analyzer = StockAnalyzerService( custom_api_url=custom_api_url, custom_api_key=custom_api_key, custom_api_model=custom_api_model, @@ -126,17 +128,11 @@ async def analyze(request: AnalyzeRequest): logger.debug(f"开始处理股票 {stock_code} 的流式响应") chunk_count = 0 - # 使用线程池执行同步分析 - def run_analysis(): - return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True)) - - # 在线程中执行同步操作 - loop = asyncio.get_event_loop() - chunks 
= await loop.run_in_executor(None, run_analysis) - - for chunk in chunks: + # 使用异步生成器 + async for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True): chunk_count += 1 yield chunk + '\n' + logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块") else: # 批量分析流式处理 @@ -148,22 +144,16 @@ async def analyze(request: AnalyzeRequest): logger.debug(f"开始处理批量股票的流式响应") chunk_count = 0 - # 使用线程池执行同步分析 - def run_batch_analysis(): - return list(custom_analyzer.scan_stocks( - [code.strip() for code in stock_codes], - min_score=0, - market_type=market_type, - stream=True - )) - - # 在线程中执行同步操作 - loop = asyncio.get_event_loop() - chunks = await loop.run_in_executor(None, run_batch_analysis) - - for chunk in chunks: + # 使用异步生成器 + async for chunk in custom_analyzer.scan_stocks( + [code.strip() for code in stock_codes], + min_score=0, + market_type=market_type, + stream=True + ): chunk_count += 1 yield chunk + '\n' + logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块") logger.info("成功创建流式响应生成器") @@ -181,9 +171,8 @@ async def search_us_stocks(keyword: str = ""): if not keyword: raise HTTPException(status_code=400, detail="请输入搜索关键词") - # 在异步上下文中运行同步的搜索函数 - loop = asyncio.get_event_loop() - results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword) + # 直接使用异步服务的异步方法 + results = await us_stock_service.search_us_stocks(keyword) return {"results": results} except Exception as e: @@ -196,9 +185,8 @@ async def search_funds(keyword: str = "", market_type: str = ""): if not keyword: raise HTTPException(status_code=400, detail="请输入搜索关键词") - # 在异步上下文中运行同步的搜索函数 - loop = asyncio.get_event_loop() - results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type)) + # 直接使用异步服务的异步方法 + results = await fund_service.search_funds(keyword, market_type) return {"results": results} except Exception as e: @@ -273,6 +261,36 @@ async def test_api_connection(request: TestAPIRequest): content={"success": False, "message": f"API 
测试连接时出错: {str(e)}"} ) +# 新增 API 端点:获取美股详情 +@app.get("/us_stock_detail/{symbol}") +async def get_us_stock_detail(symbol: str): + try: + if not symbol: + raise HTTPException(status_code=400, detail="请提供股票代码") + + # 使用异步服务获取详情 + detail = await us_stock_service.get_us_stock_detail(symbol) + return detail + + except Exception as e: + logger.error(f"获取美股详情时出错: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + +# 新增 API 端点:获取基金详情 +@app.get("/fund_detail/{symbol}") +async def get_fund_detail(symbol: str, market_type: str = "ETF"): + try: + if not symbol: + raise HTTPException(status_code=400, detail="请提供基金代码") + + # 使用异步服务获取详情 + detail = await fund_service.get_fund_detail(symbol, market_type) + return detail + + except Exception as e: + logger.error(f"获取基金详情时出错: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + if __name__ == '__main__': logger.info("股票分析系统启动") uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True) \ No newline at end of file