refactor: 重构代码结构

This commit is contained in:
Cassianvale
2025-03-06 17:11:15 +08:00
parent bcf64f0041
commit 1e53d16b3a
9 changed files with 1380 additions and 38 deletions

2
services/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# services包初始化文件
# 用于组织股票分析服务的各个模块

269
services/ai_analyzer.py Normal file
View File

@@ -0,0 +1,269 @@
import pandas as pd
import numpy as np
import os
import json
import asyncio
import httpx
from typing import Dict, List, Optional, Any, Generator, AsyncGenerator
from dotenv import load_dotenv
from logger import get_logger
from utils.api_utils import APIUtils
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class AIAnalyzer:
    """Asynchronous AI analysis service.

    Builds a prompt from the most recent rows of an indicator DataFrame,
    posts it as a chat-style request (``model`` / ``messages`` payload with a
    bearer token) and yields the outcome as JSON-encoded strings.
    """

    def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
        """Load the API configuration.

        Explicit arguments take precedence over the ``API_URL``, ``API_KEY``,
        ``API_MODEL`` and ``API_TIMEOUT`` environment variables.

        Args:
            custom_api_url: Custom API URL.
            custom_api_key: Custom API key.
            custom_api_model: Custom model name.
            custom_api_timeout: Custom request timeout in seconds.
        """
        # Pull values from a .env file into the environment first.
        load_dotenv()
        self.API_URL = custom_api_url or os.getenv('API_URL')
        self.API_KEY = custom_api_key or os.getenv('API_KEY')
        self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
        self.API_TIMEOUT = int(custom_api_timeout or os.getenv('API_TIMEOUT', 60))
        logger.debug(f"初始化AIAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}, API_TIMEOUT={self.API_TIMEOUT}")

    @staticmethod
    def _build_prompt(stock_code: str, market_type: str, technical_summary: Dict[str, Any], recent_data: List[Dict[str, Any]]) -> str:
        """Return the prompt text; funds (ETF/LOF) get a shorter template than stocks."""
        if market_type in ['ETF', 'LOF']:
            return f"""
分析基金 {stock_code}
技术指标概要:
{technical_summary}
近14日交易数据
{recent_data}
请分析该基金的技术面状况,包括:
1. 趋势分析:判断基金当前的趋势方向
2. 动量分析基于RSI和交易量评估基金动量
3. 支撑与阻力位:确定关键价格位
4. 技术面总结
5. 投资建议
将分析结果格式化为JSON像这样
{{
"trend_analysis": "趋势分析结果...",
"momentum_analysis": "动量分析结果...",
"support_resistance": "支撑阻力位分析...",
"technical_summary": "技术面总结...",
"investment_advice": "投资建议..."
}}
"""
        return f"""
分析股票 {stock_code}
技术指标概要:
{technical_summary}
近14日交易数据
{recent_data}
请分析该股票的技术面状况,包括:
1. 趋势分析:当前趋势方向及强度
2. 动量分析基于MACD、RSI等指标
3. 支撑与阻力位:关键价格位分析
4. 成交量分析:交易量的变化及意义
5. 波动性评估ATR和波动率分析
6. 技术面总结
7. 投资建议:根据技术分析给出操作建议
将分析结果格式化为JSON像这样
{{
"trend_analysis": "趋势分析结果...",
"momentum_analysis": "动量分析结果...",
"support_resistance": "支撑阻力位分析...",
"volume_analysis": "成交量分析...",
"volatility_assessment": "波动性评估...",
"technical_summary": "技术面总结...",
"investment_advice": "投资建议..."
}}
"""

    @staticmethod
    def _extract_error_message(body: Any) -> str:
        """Pull a readable message out of an error response body.

        Accepts ``bytes`` or ``str``. Falls back to the (truncated) raw text
        when the body is not JSON, and to '未知错误' when nothing usable exists.
        """
        if isinstance(body, (bytes, bytearray)):
            body = bytes(body).decode("utf-8", errors="replace")
        text = "" if body is None else str(body)
        try:
            data = json.loads(text)
        except ValueError:
            # Gateways frequently answer with HTML or plain text.
            return text.strip()[:200] or '未知错误'
        if isinstance(data, dict):
            error = data.get('error')
            if isinstance(error, dict):
                return str(error.get('message', '未知错误'))
            if error:
                return str(error)
        return '未知错误'

    @staticmethod
    def _extract_delta_content(line: str) -> Optional[str]:
        """Return the text delta carried by one streamed line, or ``None``.

        Handles an optional ``data:`` prefix, the ``[DONE]`` sentinel, blank
        or non-JSON lines, and payloads whose ``choices`` list is empty.
        """
        text = line.strip()
        if text.startswith("data:"):
            text = text[5:].strip()
        if not text or text == "[DONE]":
            return None
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict):
            return None
        choices = data.get("choices") or []
        if not choices or not isinstance(choices[0], dict):
            return None
        delta = choices[0].get("delta") or {}
        return delta.get("content") or None

    async def get_ai_analysis(self, df: pd.DataFrame, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
        """Run the AI analysis for one security.

        Args:
            df: DataFrame with indicator columns (reads ``MA5``, ``MA20``,
                ``Volatility``, ``Volume_Ratio`` and ``RSI`` from the last row).
            stock_code: Security code.
            market_type: Market type, defaults to 'A'; 'ETF'/'LOF' select the
                fund prompt.
            stream: Whether to request a streamed response.

        Yields:
            JSON strings. Depending on the outcome they carry ``analysis``
            (parsed JSON), ``partial_content`` (streamed text fragments),
            ``raw_analysis`` (unparseable text) or ``error``.
        """
        try:
            logger.info(f"开始AI分析 {stock_code}, 流式模式: {stream}")
            recent_data = df.tail(14).to_dict('records')
            latest = df.iloc[-1]
            technical_summary = {
                'trend': 'upward' if latest['MA5'] > latest['MA20'] else 'downward',
                'volatility': f"{latest['Volatility']:.2f}%",
                'volume_trend': 'increasing' if latest['Volume_Ratio'] > 1 else 'decreasing',
                'rsi_level': latest['RSI']
            }
            prompt = self._build_prompt(stock_code, market_type, technical_summary, recent_data)
            api_url = APIUtils.format_api_url(self.API_URL)
            request_data = {
                "model": self.API_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.7,
                "stream": stream
            }
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.API_KEY}"
            }
            async with httpx.AsyncClient(timeout=self.API_TIMEOUT) as client:
                logger.debug(f"发送AI请求: URL={api_url}, MODEL={self.API_MODEL}, STREAM={stream}")
                if stream:
                    async with client.stream("POST", api_url, json=request_data, headers=headers) as response:
                        if response.status_code != 200:
                            error_message = self._extract_error_message(await response.aread())
                            logger.error(f"AI API请求失败: {response.status_code} - {error_message}")
                            yield json.dumps({"error": f"API请求失败: {error_message}"})
                            return
                        buffer = ""
                        received = []
                        analysis_emitted = False
                        # Iterate by line: a raw text chunk may hold several
                        # events or only a fragment of one.
                        async for line in response.aiter_lines():
                            content = self._extract_delta_content(line)
                            if not content:
                                continue
                            received.append(content)
                            buffer += content
                            candidate = buffer.strip()
                            if candidate.startswith("{") and candidate.endswith("}"):
                                try:
                                    result_json = json.loads(candidate)
                                except json.JSONDecodeError:
                                    # JSON still incomplete; keep collecting.
                                    pass
                                else:
                                    yield json.dumps({
                                        "stock_code": stock_code,
                                        "analysis": result_json
                                    })
                                    analysis_emitted = True
                                    buffer = ""
                                    continue
                            # Flush a fragment once enough text has accumulated.
                            if len(buffer) > 100:
                                yield json.dumps({
                                    "stock_code": stock_code,
                                    "partial_content": buffer
                                })
                                buffer = ""
                        if buffer:
                            yield json.dumps({
                                "stock_code": stock_code,
                                "partial_content": buffer
                            })
                        if not analysis_emitted:
                            # Try the complete text as JSON; otherwise hand
                            # back the raw content.
                            full_content = "".join(received)
                            payload = {"stock_code": stock_code, "raw_analysis": full_content}
                            if full_content.strip().startswith("{"):
                                try:
                                    payload = {
                                        "stock_code": stock_code,
                                        "analysis": json.loads(full_content)
                                    }
                                except json.JSONDecodeError:
                                    pass
                            yield json.dumps(payload)
                else:
                    response = await client.post(api_url, json=request_data, headers=headers)
                    if response.status_code != 200:
                        error_message = self._extract_error_message(response.text)
                        logger.error(f"AI API请求失败: {response.status_code} - {error_message}")
                        yield json.dumps({"error": f"API请求失败: {error_message}"})
                        return
                    response_data = response.json()
                    # Guard against a missing or empty "choices" list.
                    choices = response_data.get("choices") or [{}]
                    analysis_text = (choices[0].get("message") or {}).get("content") or ""
                    try:
                        analysis_json = json.loads(analysis_text)
                        yield json.dumps({
                            "stock_code": stock_code,
                            "analysis": analysis_json
                        })
                    except json.JSONDecodeError:
                        yield json.dumps({
                            "stock_code": stock_code,
                            "raw_analysis": analysis_text
                        })
            logger.info(f"完成对 {stock_code} 的AI分析")
        except Exception as e:
            logger.error(f"AI分析 {stock_code} 时出错: {str(e)}")
            logger.exception(e)
            yield json.dumps({"error": f"分析出错: {str(e)}"})

    def _truncate_json_for_logging(self, json_obj, max_length=500):
        """Serialize ``json_obj`` and cut it to ``max_length`` characters for logging.

        Args:
            json_obj: JSON-serializable object.
            max_length: Maximum number of characters kept.

        Returns:
            The serialized string, followed by "..." when it was truncated.
        """
        json_str = json.dumps(json_obj, ensure_ascii=False)
        if len(json_str) <= max_length:
            return json_str
        return json_str[:max_length] + "..."

View File

@@ -0,0 +1,225 @@
import asyncio
import pandas as pd
from typing import List, Dict, Any, Optional
from logger import get_logger
from datetime import datetime, timedelta
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class FundServiceAsync:
    """Asynchronous fund service.

    Searches and looks up ETF/LOF spot data fetched through akshare, keeping
    each market type cached in memory for a limited time.
    """

    # Source (Chinese) column names mapped to the names used by this service.
    _COLUMN_MAP = {
        "代码": "symbol",
        "名称": "name",
        "最新价": "price",
        "涨跌额": "price_change",
        "涨跌幅": "price_change_percent",
        "成交量": "volume",
        "流通市值": "market_value",
        "总市值": "total_value",
        "基金折价率": "discount_rate",
    }

    def __init__(self):
        """Set up empty caches."""
        logger.debug("初始化FundServiceAsync")
        self._etf_cache = None
        self._lof_cache = None
        # One timestamp per market type, so refreshing one cache does not
        # extend the lifetime of the other.
        self._cache_timestamps = {}
        self._cache_duration = timedelta(minutes=30)  # cache entries live 30 minutes

    @staticmethod
    def _to_float(value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to float; missing, NaN or unparseable values give ``default``."""
        if value is None:
            return default
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        # NaN is the only float that differs from itself.
        return default if number != number else number

    @classmethod
    def _percent_to_ratio(cls, value: Any) -> float:
        """Convert a percentage (``"1.5%"`` or ``1.5``) to a ratio (``0.015``)."""
        if isinstance(value, str):
            value = value.strip().rstrip('%')
        return cls._to_float(value) / 100

    @staticmethod
    def _keyword_mask(df: pd.DataFrame, keyword: str) -> pd.Series:
        """Boolean mask of rows whose name or symbol contains ``keyword``.

        The keyword is matched literally and case-insensitively, so characters
        such as ``(`` or ``.`` are not interpreted as a regular expression.
        """
        needle = str(keyword)
        in_name = df['name'].str.contains(needle, case=False, na=False, regex=False)
        in_symbol = df['symbol'].str.contains(needle, case=False, na=False, regex=False)
        return in_name | in_symbol

    async def search_funds(self, keyword: str, market_type: str = 'ETF') -> List[Dict[str, Any]]:
        """Search funds by code or name.

        Args:
            keyword: Search keyword (matched literally, case-insensitive).
            market_type: 'ETF' or 'LOF'.

        Returns:
            A list of dicts with name, symbol, price, volume, market_value
            and total_value for every match.

        Raises:
            Exception: Wraps any failure with a descriptive message.
        """
        try:
            logger.info(f"异步搜索基金: {keyword}, 类型: {market_type}")
            df = await self._get_funds_data(market_type)
            results = df[self._keyword_mask(df, keyword)]
            formatted_results = [
                {
                    'name': row['name'] if pd.notna(row['name']) else '',
                    'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                    'price': self._to_float(row.get('price')),
                    'volume': self._to_float(row.get('volume')),
                    'market_value': self._to_float(row.get('market_value')),
                    'total_value': self._to_float(row.get('total_value')),
                }
                for _, row in results.iterrows()
            ]
            logger.info(f"基金搜索完成,找到 {len(formatted_results)} 个匹配项")
            return formatted_results
        except Exception as e:
            error_msg = f"搜索基金代码失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e

    async def _get_funds_data(self, market_type: str = 'ETF') -> pd.DataFrame:
        """Return the fund table for ``market_type``, using the cache when fresh.

        Args:
            market_type: 'ETF' selects ETF data; any other value is served
                from the LOF source.

        Returns:
            DataFrame with the renamed fund columns.
        """
        kind = 'ETF' if market_type == 'ETF' else 'LOF'
        now = datetime.now()
        cached = self._etf_cache if kind == 'ETF' else self._lof_cache
        stamp = self._cache_timestamps.get(kind)
        if cached is not None and stamp is not None and (now - stamp) < self._cache_duration:
            logger.debug(f"使用{kind}缓存数据")
            return cached
        try:
            logger.debug(f"从API获取{market_type}数据")
            # The akshare calls are synchronous; run them in a worker thread.
            if kind == 'ETF':
                df = await asyncio.to_thread(self._get_etf_data)
                self._etf_cache = df
            else:
                df = await asyncio.to_thread(self._get_lof_data)
                self._lof_cache = df
            self._cache_timestamps[kind] = now
            return df
        except Exception as e:
            logger.error(f"获取{market_type}数据失败: {str(e)}")
            logger.exception(e)
            raise

    def _fetch_spot(self, fetch, label: str) -> pd.DataFrame:
        """Call ``fetch()`` and rename its columns; wrap failures with ``label`` in the message."""
        try:
            return fetch().rename(columns=self._COLUMN_MAP)
        except Exception as e:
            logger.error(f"获取{label}数据失败: {str(e)}")
            logger.exception(e)
            raise Exception(f"获取{label}数据失败: {str(e)}") from e

    def _get_etf_data(self) -> pd.DataFrame:
        """Fetch ETF spot data (synchronous; called from a worker thread).

        Returns:
            DataFrame with the renamed ETF columns.
        """
        import akshare as ak
        return self._fetch_spot(ak.fund_etf_spot_em, "ETF")

    def _get_lof_data(self) -> pd.DataFrame:
        """Fetch LOF spot data (synchronous; called from a worker thread).

        Returns:
            DataFrame with the renamed LOF columns.
        """
        import akshare as ak
        return self._fetch_spot(ak.fund_lof_spot_em, "LOF")

    async def get_fund_detail(self, symbol: str, market_type: str = 'ETF') -> Dict[str, Any]:
        """Return the details of a single fund.

        Args:
            symbol: Fund code (exact match).
            market_type: 'ETF' or 'LOF'.

        Returns:
            Dict with price, change, volume, value and discount fields.
            Percentage fields are returned as ratios (1.5% -> 0.015); missing
            or unparseable numeric values become 0.0.

        Raises:
            Exception: When the fund is not found or the data fetch fails.
        """
        try:
            logger.info(f"获取{market_type}基金详情: {symbol}")
            df = await self._get_funds_data(market_type)
            result = df[df['symbol'].astype(str) == str(symbol)]
            if len(result) == 0:
                raise Exception(f"未找到基金代码: {symbol}")
            row = result.iloc[0]
            fund_detail = {
                'name': row['name'] if pd.notna(row['name']) else '',
                'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                'price': self._to_float(row.get('price')),
                'price_change': self._to_float(row.get('price_change')),
                # The source may deliver percentages as numbers or as "x%" text.
                'price_change_percent': self._percent_to_ratio(row.get('price_change_percent')),
                'volume': self._to_float(row.get('volume')),
                'market_value': self._to_float(row.get('market_value')),
                'total_value': self._to_float(row.get('total_value')),
                'discount_rate': self._percent_to_ratio(row.get('discount_rate'))
            }
            logger.info(f"获取基金详情成功: {symbol}")
            return fund_detail
        except Exception as e:
            error_msg = f"获取基金详情失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e

View File

@@ -0,0 +1,178 @@
import pandas as pd
import numpy as np
import asyncio
import json
from typing import Dict, List, Optional, Tuple, Any, AsyncGenerator
from logger import get_logger
from services.stock_data_provider import StockDataProvider
from services.technical_indicator import TechnicalIndicator
from services.stock_scorer import StockScorer
from services.ai_analyzer import AIAnalyzer
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class StockAnalyzerService:
    """Facade over the stock analysis components.

    Wires together the data provider, the indicator calculator, the scorer
    and the AI analyzer, and exposes single-stock and batch workflows.
    """

    def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
        """Create the collaborating components.

        Args:
            custom_api_url: Custom API URL forwarded to the AI analyzer.
            custom_api_key: Custom API key forwarded to the AI analyzer.
            custom_api_model: Custom model name forwarded to the AI analyzer.
            custom_api_timeout: Custom timeout forwarded to the AI analyzer.
        """
        self.data_provider = StockDataProvider()
        self.indicator = TechnicalIndicator()
        self.scorer = StockScorer()
        ai_options = {
            "custom_api_url": custom_api_url,
            "custom_api_key": custom_api_key,
            "custom_api_model": custom_api_model,
            "custom_api_timeout": custom_api_timeout,
        }
        self.ai_analyzer = AIAnalyzer(**ai_options)
        logger.info("初始化StockAnalyzerService完成")

    async def analyze_stock(self, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
        """Analyze a single stock.

        Args:
            stock_code: Stock code.
            market_type: Market type, defaults to 'A'.
            stream: Whether the AI step should stream its response.

        Yields:
            First a JSON string with the score summary, then every chunk
            produced by the AI analyzer; a JSON ``error`` object on failure.
        """
        try:
            logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}")
            raw_frame = await self.data_provider.get_stock_data(stock_code, market_type)
            enriched = self.indicator.calculate_indicators(raw_frame)
            score = self.scorer.calculate_score(enriched)
            recommendation = self.scorer.get_recommendation(score)
            # Key order is kept stable because the dict is serialized as-is.
            summary = json.dumps({
                "stock_code": stock_code,
                "score": score,
                "recommendation": recommendation,
                "data_point_count": len(raw_frame),
                "market_type": market_type
            })
            logger.info(f"基本分析结果: {summary}")
            yield summary
            # Follow up with the AI analysis, relaying each chunk unchanged.
            ai_chunks = self.ai_analyzer.get_ai_analysis(enriched, stock_code, market_type, stream)
            async for chunk in ai_chunks:
                yield chunk
            logger.info(f"完成股票分析: {stock_code}")
        except Exception as e:
            error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            yield json.dumps({"error": error_msg})

    async def scan_stocks(self, stock_codes: List[str], market_type: str = 'A', min_score: int = 0, stream: bool = False) -> AsyncGenerator[str, None]:
        """Scan a batch of stocks.

        Args:
            stock_codes: Stock codes to scan.
            market_type: Market type.
            min_score: Minimum score a stock needs to be reported.
            stream: When true, the three best matches also get an AI analysis.

        Yields:
            JSON strings: a ``scanning`` status, the scored results, optional
            per-stock AI output, and finally a ``completed`` status; a JSON
            ``error`` object on failure.
        """
        try:
            total = len(stock_codes)
            logger.info(f"开始批量扫描 {total} 只股票, 市场: {market_type}")
            yield json.dumps({
                "status": "scanning",
                "total_stocks": total,
                "market_type": market_type,
                "min_score": min_score
            })
            frames = await self.data_provider.get_multiple_stocks_data(stock_codes, market_type)
            # A failure for one stock is logged and that stock is skipped.
            stock_with_indicators = {}
            for code in frames:
                try:
                    stock_with_indicators[code] = self.indicator.calculate_indicators(frames[code])
                except Exception as e:
                    logger.error(f"计算 {code} 技术指标时出错: {str(e)}")
            results = self.scorer.batch_score_stocks(stock_with_indicators)
            filtered_results = []
            for entry in results:
                if entry[1] >= min_score:
                    filtered_results.append(entry)
            matched = len(filtered_results)
            scan_rows = []
            for code, score, rec in filtered_results:
                scan_rows.append({
                    "stock_code": code,
                    "score": score,
                    "recommendation": rec
                })
            yield json.dumps({
                "scan_results": scan_rows,
                "total_matched": matched,
                "total_scanned": len(results)
            })
            if stream and filtered_results:
                # Only the first three entries of the filtered list are analyzed.
                for stock_code, score, _ in filtered_results[:3]:
                    frame = stock_with_indicators.get(stock_code)
                    if frame is None:
                        continue
                    yield json.dumps({
                        "analyzing": stock_code,
                        "score": score
                    })
                    async for chunk in self.ai_analyzer.get_ai_analysis(frame, stock_code, market_type, stream):
                        yield chunk
            yield json.dumps({
                "status": "completed",
                "total_scanned": len(results),
                "total_matched": matched
            })
            logger.info(f"完成批量扫描 {len(stock_codes)} 只股票, 符合条件: {len(filtered_results)}")
        except Exception as e:
            error_msg = f"批量扫描股票时出错: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            yield json.dumps({"error": error_msg})

View File

@@ -0,0 +1,184 @@
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import asyncio
import os
from typing import Dict, List, Optional, Tuple, Any
from logger import get_logger
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class StockDataProvider:
    """Asynchronous market data provider.

    Fetches historical price data for stocks and funds through akshare and
    returns it as a date-indexed DataFrame with English column names.
    """

    # Accepted leading characters for A-share codes.
    _A_SHARE_PREFIXES = ('0', '3', '6', '688', '8')
    # Positional column names used when the fetched frame has exactly this many columns.
    _A_SHARE_COLUMNS = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover']
    _OHLCV_COLUMNS = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
    # Name-based fallback used when the column count differs from the lists above.
    _COLUMN_ALIASES = {
        '日期': 'Date', '股票代码': 'Code', '开盘': 'Open', '收盘': 'Close',
        '最高': 'High', '最低': 'Low', '成交量': 'Volume', '成交额': 'Amount',
        '振幅': 'Amplitude', '涨跌幅': 'Change_pct', '涨跌额': 'Change', '换手率': 'Turnover',
        'date': 'Date', 'open': 'Open', 'high': 'High', 'low': 'Low',
        'close': 'Close', 'volume': 'Volume', 'amount': 'Amount',
    }

    def __init__(self):
        """Initialize the data provider."""
        logger.debug("初始化StockDataProvider")

    async def get_stock_data(self, stock_code: str, market_type: str = 'A',
                             start_date: Optional[str] = None,
                             end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetch historical data without blocking the event loop.

        Args:
            stock_code: Security code.
            market_type: 'A', 'HK', 'US', 'ETF' or 'LOF'; defaults to 'A'.
            start_date: Start date (YYYYMMDD); defaults to one year ago.
            end_date: End date (YYYYMMDD); defaults to today.

        Returns:
            Date-indexed DataFrame of historical data.
        """
        # The akshare calls are synchronous; run them in a worker thread.
        return await asyncio.to_thread(
            self._get_stock_data_sync,
            stock_code,
            market_type,
            start_date,
            end_date
        )

    @classmethod
    def _normalize_frame(cls, df: pd.DataFrame, market_type: str,
                         start_date: Optional[str] = None,
                         end_date: Optional[str] = None) -> pd.DataFrame:
        """Rename columns, index by date, sort ascending and clip to the date range.

        When the frame has exactly the expected number of columns they are
        renamed by position; otherwise known source names are mapped by name,
        so an extra or missing column does not abort the whole request. The
        input frame is not modified.
        """
        frame = df.copy()
        expected = cls._A_SHARE_COLUMNS if market_type == 'A' else cls._OHLCV_COLUMNS
        if len(frame.columns) == len(expected):
            frame.columns = expected
        else:
            frame = frame.rename(columns=cls._COLUMN_ALIASES)
        if 'Date' in frame.columns:
            frame['Date'] = pd.to_datetime(frame['Date'])
            frame = frame.set_index('Date')
        frame = frame.sort_index()
        # Filter only once the index really holds dates; comparing a plain
        # integer index with a date string would fail.
        if isinstance(frame.index, pd.DatetimeIndex):
            if start_date is not None:
                frame = frame[frame.index >= pd.to_datetime(start_date)]
            if end_date is not None:
                frame = frame[frame.index <= pd.to_datetime(end_date)]
        return frame

    def _get_stock_data_sync(self, stock_code: str, market_type: str = 'A',
                             start_date: Optional[str] = None,
                             end_date: Optional[str] = None) -> pd.DataFrame:
        """Synchronous implementation behind :meth:`get_stock_data`.

        Raises:
            Exception: Wraps any failure (including an invalid code or an
                unsupported market type) with a descriptive message.
        """
        import akshare as ak
        if start_date is None:
            start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')
        try:
            if market_type == 'A':
                if not stock_code.startswith(self._A_SHARE_PREFIXES):
                    error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头"
                    logger.error(f"[股票代码格式错误] {error_msg}")
                    raise ValueError(error_msg)
                logger.debug(f"获取A股数据: {stock_code}")
                df = ak.stock_zh_a_hist(
                    symbol=stock_code,
                    start_date=start_date,
                    end_date=end_date,
                    adjust="qfq"
                )
            elif market_type in ['HK']:
                logger.debug(f"获取港股数据: {stock_code}")
                # NOTE(review): assumes stock_hk_daily takes no date range -
                # confirm against the installed akshare; the range is applied
                # locally in _normalize_frame either way.
                df = ak.stock_hk_daily(
                    symbol=stock_code,
                    adjust="qfq"
                )
            elif market_type in ['US']:
                logger.debug(f"获取美股数据: {stock_code}")
                df = ak.stock_us_daily(
                    symbol=stock_code,
                    adjust="qfq"
                )
            elif market_type in ['ETF', 'LOF']:
                logger.debug(f"获取{market_type}基金数据: {stock_code}")
                # NOTE(review): same assumption as for HK - range applied locally.
                df = ak.fund_etf_hist_sina(
                    symbol=stock_code
                )
            else:
                error_msg = f"不支持的市场类型: {market_type}"
                logger.error(f"[市场类型错误] {error_msg}")
                raise ValueError(error_msg)
            df = self._normalize_frame(df, market_type, start_date, end_date)
            logger.info(f"成功获取{market_type}数据 {stock_code}, 数据点数: {len(df)}")
            return df
        except Exception as e:
            error_msg = f"获取{market_type}数据失败 {stock_code}: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e

    async def get_multiple_stocks_data(self, stock_codes: List[str],
                                       market_type: str = 'A',
                                       start_date: Optional[str] = None,
                                       end_date: Optional[str] = None,
                                       max_concurrency: int = 5) -> Dict[str, pd.DataFrame]:
        """Fetch data for several stocks concurrently.

        Args:
            stock_codes: Stock codes to fetch.
            market_type: Market type, defaults to 'A'.
            start_date: Start date (YYYYMMDD).
            end_date: End date (YYYYMMDD).
            max_concurrency: Maximum number of simultaneous requests
                (values below 1 are treated as 1).

        Returns:
            Dict mapping each successfully fetched code to its DataFrame;
            failed codes are logged and omitted.
        """
        # A semaphore of size 0 would block forever, so clamp to at least 1.
        semaphore = asyncio.Semaphore(max(1, max_concurrency))

        async def get_with_semaphore(code):
            async with semaphore:
                try:
                    return code, await self.get_stock_data(code, market_type, start_date, end_date)
                except Exception as e:
                    logger.error(f"获取股票 {code} 数据时出错: {str(e)}")
                    return code, None

        results = await asyncio.gather(*(get_with_semaphore(code) for code in stock_codes))
        return {code: df for code, df in results if df is not None}

128
services/stock_scorer.py Normal file
View File

@@ -0,0 +1,128 @@
import pandas as pd
import numpy as np
from typing import Dict, Optional, Any, List, Tuple
from logger import get_logger
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class StockScorer:
    """Stock scoring service.

    Turns the latest technical-indicator values of a stock into a composite
    score and a textual recommendation.
    """

    def __init__(self):
        """Initialize the scoring service."""
        logger.debug("初始化StockScorer")

    @staticmethod
    def _rsi_points(rsi) -> int:
        """Points awarded for an RSI value (0 when no range applies, e.g. NaN)."""
        if rsi <= 30:
            return 15  # oversold
        if rsi < 45:
            return 10  # weak zone, not oversold
        if rsi <= 55:
            return 15  # neutral zone
        if rsi < 70:
            return 25  # strong zone, not overbought
        if rsi >= 70:
            return 5   # overbought
        return 0

    def calculate_score(self, df: pd.DataFrame) -> int:
        """Score a stock from the last row of ``df`` (at most 100 points).

        Components: moving averages (up to 25), RSI (up to 25), MACD versus
        signal line (20) and volume ratio (up to 30).

        Args:
            df: DataFrame holding the indicator columns.

        Returns:
            Integer score between 0 and 100.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            row = df.iloc[-1]
            total = 0

            # Moving averages: full marks for a bullish alignment
            # (short > medium > long), fewer for partial strength.
            short_above_mid = row['MA5'] > row['MA20']
            if short_above_mid and row['MA20'] > row['MA60']:
                total += 25
            elif short_above_mid:
                total += 15
            elif row['Close'] > row['MA20']:
                total += 10

            total += self._rsi_points(row['RSI'])

            # MACD line above its signal line.
            total += 20 if row['MACD'] > row['Signal'] else 0

            # Volume relative to its moving average.
            volume_ratio = row['Volume_Ratio']
            if volume_ratio > 1.5:
                total += 30
            elif volume_ratio > 1:
                total += 15
            return total
        except Exception as e:
            logger.error(f"计算评分时出错: {str(e)}")
            logger.exception(e)
            raise

    def get_recommendation(self, score: int) -> str:
        """Map a score to a recommendation label.

        Args:
            score: Stock score (0-100).

        Returns:
            The recommendation text for the highest threshold the score reaches.
        """
        thresholds = (
            (80, "强烈推荐"),
            (70, "推荐"),
            (60, "谨慎推荐"),
            (40, "观望"),
            (20, "不推荐"),
        )
        for minimum, label in thresholds:
            if score >= minimum:
                return label
        return "强烈不推荐"

    def batch_score_stocks(self, stock_dfs: Dict[str, pd.DataFrame]) -> List[Tuple[str, int, str]]:
        """Score several stocks at once.

        Args:
            stock_dfs: Dict mapping stock code to its indicator DataFrame.

        Returns:
            List of ``(stock_code, score, recommendation)`` tuples ordered by
            score, highest first; stocks that fail to score are logged and
            left out.
        """
        scored = []
        for stock_code in stock_dfs:
            try:
                points = self.calculate_score(stock_dfs[stock_code])
                scored.append((stock_code, points, self.get_recommendation(points)))
            except Exception as e:
                logger.error(f"评分股票 {stock_code} 时出错: {str(e)}")
        return sorted(scored, key=lambda item: item[1], reverse=True)

View File

@@ -0,0 +1,187 @@
import pandas as pd
import numpy as np
from typing import Dict, Optional, Any
from logger import get_logger
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class TechnicalIndicator:
    """Technical indicator calculation service.

    Computes moving averages, RSI, MACD, Bollinger bands, volume ratio, ATR
    and volatility on a price DataFrame.
    """

    # Defaults applied to every setting the caller does not override.
    DEFAULT_PARAMS = {
        'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
        'rsi_period': 14,
        'bollinger_period': 20,
        'bollinger_std': 2,
        'volume_ma_period': 20,
        'atr_period': 14,
        'volatility_period': 20,
    }

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Store the indicator settings.

        Args:
            params: Optional overrides. Keys that are not supplied keep their
                default value, so a partial dict is accepted.
        """
        self.params = self._merge_params(params)
        logger.debug(f"初始化TechnicalIndicator参数: {self.params}")

    @classmethod
    def _merge_params(cls, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a fresh settings dict: the defaults overlaid with ``params``."""
        merged = {
            key: dict(value) if isinstance(value, dict) else value
            for key, value in cls.DEFAULT_PARAMS.items()
        }
        if params:
            merged.update(params)
        return merged

    def calculate_ema(self, series: pd.Series, period: int) -> pd.Series:
        """Compute the exponential moving average.

        Args:
            series: Price series.
            period: Span of the average.

        Returns:
            EMA series.
        """
        return series.ewm(span=period, adjust=False).mean()

    def calculate_rsi(self, series: pd.Series, period: int) -> pd.Series:
        """Compute the relative strength index from simple rolling means.

        Args:
            series: Price series.
            period: Rolling window length.

        Returns:
            RSI series. A window without any price change divides 0 by 0 and
            therefore yields NaN.
        """
        delta = series.diff()
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)
        avg_gain = gain.rolling(window=period).mean()
        avg_loss = loss.rolling(window=period).mean()
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    def calculate_macd(self, series: pd.Series, fast_period: int = 12,
                       slow_period: int = 26, signal_period: int = 9) -> tuple:
        """Compute the MACD indicator.

        Args:
            series: Price series.
            fast_period: Span of the fast EMA (default 12).
            slow_period: Span of the slow EMA (default 26).
            signal_period: Span of the signal-line EMA (default 9).

        Returns:
            Tuple of (MACD line, signal line, histogram).
        """
        macd = self.calculate_ema(series, fast_period) - self.calculate_ema(series, slow_period)
        signal = self.calculate_ema(macd, signal_period)
        return macd, signal, macd - signal

    def calculate_bollinger_bands(self, series: pd.Series, period: int, std_dev: float) -> tuple:
        """Compute Bollinger bands.

        Args:
            series: Price series.
            period: Rolling window length.
            std_dev: Number of standard deviations for the band width.

        Returns:
            Tuple of (middle band, upper band, lower band).
        """
        window = series.rolling(window=period)
        middle = window.mean()
        offset = std_dev * window.std()
        return middle, middle + offset, middle - offset

    def calculate_atr(self, df: pd.DataFrame, period: int) -> pd.Series:
        """Compute the average true range.

        Args:
            df: DataFrame with High, Low and Close columns.
            period: Rolling window length.

        Returns:
            ATR series.
        """
        high = df['High']
        low = df['Low']
        previous_close = df['Close'].shift()
        # True range: the largest of the three candidate ranges per row.
        candidates = [high - low, abs(high - previous_close), abs(low - previous_close)]
        true_range = pd.concat(candidates, axis=1).max(axis=1)
        return true_range.rolling(window=period).mean()

    def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Compute every indicator and return them on a copy of ``df``.

        Args:
            df: Price data; reads the High, Low, Close and Volume columns.

        Returns:
            Copy of ``df`` with MA<period>, RSI, MACD, Signal, Histogram,
            BB_Middle/BB_Upper/BB_Lower, Volume_MA, Volume_Ratio, ATR and
            Volatility columns added.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            result_df = df.copy()
            close = result_df['Close']
            # One MA<period> column per configured period.
            for period in self.params['ma_periods'].values():
                result_df[f'MA{period}'] = close.rolling(window=period).mean()
            result_df['RSI'] = self.calculate_rsi(close, self.params['rsi_period'])
            macd, signal, histogram = self.calculate_macd(close)
            result_df['MACD'] = macd
            result_df['Signal'] = signal
            result_df['Histogram'] = histogram
            middle, upper, lower = self.calculate_bollinger_bands(
                close,
                self.params['bollinger_period'],
                self.params['bollinger_std']
            )
            result_df['BB_Middle'] = middle
            result_df['BB_Upper'] = upper
            result_df['BB_Lower'] = lower
            result_df['Volume_MA'] = result_df['Volume'].rolling(window=self.params['volume_ma_period']).mean()
            result_df['Volume_Ratio'] = result_df['Volume'] / result_df['Volume_MA']
            result_df['ATR'] = self.calculate_atr(result_df, self.params['atr_period'])
            # Volatility: rolling std of the close divided by its rolling mean, in percent.
            volatility_window = close.rolling(window=self.params.get('volatility_period', 20))
            result_df['Volatility'] = volatility_window.std() / volatility_window.mean() * 100
            return result_df
        except Exception as e:
            logger.error(f"计算技术指标时出错: {str(e)}")
            logger.exception(e)
            raise

View File

@@ -0,0 +1,151 @@
import asyncio
import pandas as pd
from typing import List, Dict, Any, Optional
from logger import get_logger
# Module-level logger obtained from the project's logger helper.
logger = get_logger()
class USStockServiceAsync:
    """Asynchronous US stock service.

    Searches and looks up US stock spot data fetched through akshare.
    """

    def __init__(self):
        """Initialize the service."""
        logger.debug("初始化USStockServiceAsync")
        # Placeholders for an optional cache; not read by the current methods.
        self._cache = None
        self._cache_timestamp = None

    @staticmethod
    def _to_float(value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to float; missing, NaN or unparseable values give ``default``."""
        if value is None:
            return default
        try:
            number = float(value)
        except (TypeError, ValueError):
            return default
        # NaN is the only float that differs from itself.
        return default if number != number else number

    @classmethod
    def _percent_to_ratio(cls, value: Any) -> float:
        """Convert a percentage (``"1.5%"`` or ``1.5``) to a ratio (``0.015``)."""
        if isinstance(value, str):
            value = value.strip().rstrip('%')
        return cls._to_float(value) / 100

    @staticmethod
    def _keyword_mask(df: pd.DataFrame, keyword: str) -> pd.Series:
        """Boolean mask of rows whose name or symbol contains ``keyword``.

        The keyword is matched literally and case-insensitively, so input
        such as ``c++`` is not interpreted as a regular expression.
        """
        needle = str(keyword)
        in_name = df['name'].str.contains(needle, case=False, na=False, regex=False)
        in_symbol = df['symbol'].astype(str).str.contains(needle, case=False, na=False, regex=False)
        return in_name | in_symbol

    async def search_us_stocks(self, keyword: str) -> List[Dict[str, Any]]:
        """Search US stocks by name or code.

        Args:
            keyword: Search keyword (matched literally, case-insensitive).

        Returns:
            A list of dicts with name, symbol, price and market_value for
            every match.

        Raises:
            Exception: Wraps any failure with a descriptive message.
        """
        try:
            logger.info(f"异步搜索美股: {keyword}")
            # The akshare call is synchronous; run it in a worker thread.
            df = await asyncio.to_thread(self._get_us_stocks_data)
            results = df[self._keyword_mask(df, keyword)]
            formatted_results = [
                {
                    'name': row['name'] if pd.notna(row['name']) else '',
                    'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                    'price': self._to_float(row.get('price')),
                    'market_value': self._to_float(row.get('market_value'))
                }
                for _, row in results.iterrows()
            ]
            logger.info(f"美股搜索完成,找到 {len(formatted_results)} 个匹配项")
            return formatted_results
        except Exception as e:
            error_msg = f"搜索美股代码失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e

    def _get_us_stocks_data(self) -> pd.DataFrame:
        """Fetch US stock spot data (synchronous; called from a worker thread).

        Returns:
            DataFrame with the source columns renamed to English names.

        Raises:
            Exception: Wraps any fetch failure with a descriptive message.
        """
        import akshare as ak
        try:
            df = ak.stock_us_spot_em()
            df = df.rename(columns={
                "序号": "index",
                "名称": "name",
                "最新价": "price",
                "涨跌额": "price_change",
                "涨跌幅": "price_change_percent",
                "开盘价": "open",
                "最高价": "high",
                "最低价": "low",
                "昨收价": "pre_close",
                "总市值": "market_value",
                "市盈率": "pe_ratio",
                "成交量": "volume",
                "成交额": "turnover",
                "振幅": "amplitude",
                "换手率": "turnover_rate",
                "代码": "symbol"
            })
            return df
        except Exception as e:
            logger.error(f"获取美股数据失败: {str(e)}")
            logger.exception(e)
            raise Exception(f"获取美股数据失败: {str(e)}") from e

    async def get_us_stock_detail(self, symbol: str) -> Dict[str, Any]:
        """Return the details of a single US stock.

        Args:
            symbol: Stock code (exact match).

        Returns:
            Dict with price, change, OHLC, valuation and volume fields. The
            percentage change is returned as a ratio (1.5% -> 0.015); missing
            or unparseable numeric values become 0.0.

        Raises:
            Exception: When the stock is not found or the data fetch fails.
        """
        try:
            logger.info(f"获取美股详情: {symbol}")
            df = await asyncio.to_thread(self._get_us_stocks_data)
            result = df[df['symbol'].astype(str) == str(symbol)]
            if len(result) == 0:
                raise Exception(f"未找到股票代码: {symbol}")
            row = result.iloc[0]
            stock_detail = {
                'name': row['name'] if pd.notna(row['name']) else '',
                'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
                'price': self._to_float(row.get('price')),
                'price_change': self._to_float(row.get('price_change')),
                # The source may deliver the percentage as a number or as "x%" text.
                'price_change_percent': self._percent_to_ratio(row.get('price_change_percent')),
                'open': self._to_float(row.get('open')),
                'high': self._to_float(row.get('high')),
                'low': self._to_float(row.get('low')),
                'pre_close': self._to_float(row.get('pre_close')),
                'market_value': self._to_float(row.get('market_value')),
                'pe_ratio': self._to_float(row.get('pe_ratio')),
                'volume': self._to_float(row.get('volume')),
                'turnover': self._to_float(row.get('turnover'))
            }
            logger.info(f"获取美股详情成功: {symbol}")
            return stock_detail
        except Exception as e:
            error_msg = f"获取美股详情失败: {str(e)}"
            logger.error(error_msg)
            logger.exception(e)
            raise Exception(error_msg) from e