refactor: 重构代码结构

This commit is contained in:
Cassianvale
2025-03-06 17:11:15 +08:00
parent bcf64f0041
commit 1e53d16b3a
9 changed files with 1380 additions and 38 deletions

2
services/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# services包初始化文件
# 用于组织股票分析服务的各个模块

269
services/ai_analyzer.py Normal file
View File

@@ -0,0 +1,269 @@
import pandas as pd
import numpy as np
import os
import json
import asyncio
import httpx
from typing import Dict, List, Optional, Any, Generator, AsyncGenerator
from dotenv import load_dotenv
from logger import get_logger
from utils.api_utils import APIUtils
# 获取日志器
logger = get_logger()
class AIAnalyzer:
"""
异步AI分析服务
负责调用AI API对股票数据进行分析
"""
def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
"""
初始化AI分析服务
Args:
custom_api_url: 自定义API URL
custom_api_key: 自定义API密钥
custom_api_model: 自定义API模型
custom_api_timeout: 自定义API超时时间
"""
# 加载环境变量
load_dotenv()
# 设置API配置
self.API_URL = custom_api_url or os.getenv('API_URL')
self.API_KEY = custom_api_key or os.getenv('API_KEY')
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
self.API_TIMEOUT = int(custom_api_timeout or os.getenv('API_TIMEOUT', 60))
logger.debug(f"初始化AIAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}, API_TIMEOUT={self.API_TIMEOUT}")
async def get_ai_analysis(self, df: pd.DataFrame, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
"""
对股票数据进行AI分析
Args:
df: 包含技术指标的DataFrame
stock_code: 股票代码
market_type: 市场类型,默认为'A'
stream: 是否使用流式响应
Returns:
异步生成器,生成分析结果字符串
"""
try:
logger.info(f"开始AI分析 {stock_code}, 流式模式: {stream}")
recent_data = df.tail(14).to_dict('records')
technical_summary = {
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
'rsi_level': df.iloc[-1]['RSI']
}
# 根据市场类型调整分析提示
if market_type in ['ETF', 'LOF']:
prompt = f"""
分析基金 {stock_code}
技术指标概要:
{technical_summary}
近14日交易数据
{recent_data}
请分析该基金的技术面状况,包括:
1. 趋势分析:判断基金当前的趋势方向
2. 动量分析基于RSI和交易量评估基金动量
3. 支撑与阻力位:确定关键价格位
4. 技术面总结
5. 投资建议
将分析结果格式化为JSON像这样
{{
"trend_analysis": "趋势分析结果...",
"momentum_analysis": "动量分析结果...",
"support_resistance": "支撑阻力位分析...",
"technical_summary": "技术面总结...",
"investment_advice": "投资建议..."
}}
"""
else:
prompt = f"""
分析股票 {stock_code}
技术指标概要:
{technical_summary}
近14日交易数据
{recent_data}
请分析该股票的技术面状况,包括:
1. 趋势分析:当前趋势方向及强度
2. 动量分析基于MACD、RSI等指标
3. 支撑与阻力位:关键价格位分析
4. 成交量分析:交易量的变化及意义
5. 波动性评估ATR和波动率分析
6. 技术面总结
7. 投资建议:根据技术分析给出操作建议
将分析结果格式化为JSON像这样
{{
"trend_analysis": "趋势分析结果...",
"momentum_analysis": "动量分析结果...",
"support_resistance": "支撑阻力位分析...",
"volume_analysis": "成交量分析...",
"volatility_assessment": "波动性评估...",
"technical_summary": "技术面总结...",
"investment_advice": "投资建议..."
}}
"""
# 格式化API URL
api_url = APIUtils.format_api_url(self.API_URL)
# 准备请求数据
request_data = {
"model": self.API_MODEL,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"stream": stream
}
# 准备请求头
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.API_KEY}"
}
# 异步请求API
async with httpx.AsyncClient(timeout=self.API_TIMEOUT) as client:
# 记录请求
logger.debug(f"发送AI请求: URL={api_url}, MODEL={self.API_MODEL}, STREAM={stream}")
if stream:
# 流式响应处理
async with client.stream("POST", api_url, json=request_data, headers=headers) as response:
if response.status_code != 200:
error_text = await response.aread()
error_data = json.loads(error_text)
error_message = error_data.get('error', {}).get('message', '未知错误')
logger.error(f"AI API请求失败: {response.status_code} - {error_message}")
yield json.dumps({"error": f"API请求失败: {error_message}"})
return
# 处理流式响应
buffer = ""
collected_messages = []
async for chunk in response.aiter_text():
if chunk:
chunk_str = chunk.strip()
if chunk_str.startswith("data: "):
chunk_str = chunk_str[6:] # 去除"data: "前缀
if chunk_str == "[DONE]":
continue
try:
# 解析数据块
chunk_data = json.loads(chunk_str)
delta = chunk_data.get("choices", [{}])[0].get("delta", {})
content = delta.get("content", "")
if content:
buffer += content
# 尝试提取完整的JSON
if buffer.strip().startswith("{") and buffer.strip().endswith("}"):
try:
result_json = json.loads(buffer)
yield json.dumps({
"stock_code": stock_code,
"analysis": result_json
})
buffer = "" # 重置缓冲区
except json.JSONDecodeError:
# JSON不完整继续收集
pass
# 达到一定长度就输出
if len(buffer) > 100:
yield json.dumps({
"stock_code": stock_code,
"partial_content": buffer
})
collected_messages.append(buffer)
buffer = ""
except json.JSONDecodeError:
# 忽略无法解析的块
continue
# 处理最后的缓冲区
if buffer:
yield json.dumps({
"stock_code": stock_code,
"partial_content": buffer
})
collected_messages.append(buffer)
# 尝试从整个内容中提取JSON
full_content = "".join(collected_messages)
# 如果没有成功解析JSON返回原始内容
if not full_content.strip().startswith("{"):
yield json.dumps({
"stock_code": stock_code,
"raw_analysis": full_content
})
else:
# 非流式响应处理
response = await client.post(api_url, json=request_data, headers=headers)
if response.status_code != 200:
error_data = response.json()
error_message = error_data.get('error', {}).get('message', '未知错误')
logger.error(f"AI API请求失败: {response.status_code} - {error_message}")
yield json.dumps({"error": f"API请求失败: {error_message}"})
return
response_data = response.json()
analysis_text = response_data.get("choices", [{}])[0].get("message", {}).get("content", "")
try:
# 尝试解析JSON
analysis_json = json.loads(analysis_text)
yield json.dumps({
"stock_code": stock_code,
"analysis": analysis_json
})
except json.JSONDecodeError:
# 返回原始文本
yield json.dumps({
"stock_code": stock_code,
"raw_analysis": analysis_text
})
logger.info(f"完成对 {stock_code} 的AI分析")
except Exception as e:
logger.error(f"AI分析 {stock_code} 时出错: {str(e)}")
logger.exception(e)
yield json.dumps({"error": f"分析出错: {str(e)}"})
def _truncate_json_for_logging(self, json_obj, max_length=500):
"""
截断JSON对象以便记录日志
Args:
json_obj: JSON对象
max_length: 最大长度
Returns:
截断后的字符串
"""
json_str = json.dumps(json_obj, ensure_ascii=False)
if len(json_str) <= max_length:
return json_str
return json_str[:max_length] + "..."

View File

@@ -0,0 +1,225 @@
import asyncio
import pandas as pd
from typing import List, Dict, Any, Optional
from logger import get_logger
from datetime import datetime, timedelta
# 获取日志器
logger = get_logger()
class FundServiceAsync:
"""
异步基金服务
提供基金数据的异步搜索和获取功能
"""
def __init__(self):
"""初始化异步基金服务"""
logger.debug("初始化FundServiceAsync")
# 添加缓存
self._etf_cache = None
self._lof_cache = None
self._cache_timestamp = None
self._cache_duration = timedelta(minutes=30) # 缓存30分钟
async def search_funds(self, keyword: str, market_type: str = 'ETF') -> List[Dict[str, Any]]:
"""
异步搜索基金代码
Args:
keyword: 搜索关键词
market_type: 市场类型,'ETF''LOF'
Returns:
匹配的基金列表
"""
try:
logger.info(f"异步搜索基金: {keyword}, 类型: {market_type}")
# 获取基金数据
df = await self._get_funds_data(market_type)
# 模糊匹配搜索(同时匹配代码和名称)
mask = (df['name'].str.contains(keyword, case=False, na=False) |
df['symbol'].str.contains(keyword, case=False, na=False))
results = df[mask]
# 格式化返回结果并处理 NaN 值
formatted_results = []
for _, row in results.iterrows():
formatted_results.append({
'name': row['name'] if pd.notna(row['name']) else '',
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
'volume': float(row['volume']) if pd.notna(row['volume']) else 0.0,
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
'total_value': float(row['total_value']) if pd.notna(row['total_value']) else 0.0,
})
logger.info(f"基金搜索完成,找到 {len(formatted_results)} 个匹配项")
return formatted_results
except Exception as e:
error_msg = f"搜索基金代码失败: {str(e)}"
logger.error(error_msg)
logger.exception(e)
raise Exception(error_msg)
async def _get_funds_data(self, market_type: str = 'ETF') -> pd.DataFrame:
"""
异步获取基金数据,支持缓存
Args:
market_type: 市场类型,'ETF''LOF'
Returns:
包含基金数据的DataFrame
"""
# 检查缓存是否有效
now = datetime.now()
cache_valid = (
self._cache_timestamp is not None and
(now - self._cache_timestamp) < self._cache_duration
)
if market_type == 'ETF' and cache_valid and self._etf_cache is not None:
logger.debug("使用ETF缓存数据")
return self._etf_cache
elif market_type == 'LOF' and cache_valid and self._lof_cache is not None:
logger.debug("使用LOF缓存数据")
return self._lof_cache
# 缓存无效,重新获取数据
try:
logger.debug(f"从API获取{market_type}数据")
# 使用线程池执行同步的akshare调用
if market_type == 'ETF':
df = await asyncio.to_thread(self._get_etf_data)
self._etf_cache = df
else:
df = await asyncio.to_thread(self._get_lof_data)
self._lof_cache = df
self._cache_timestamp = now
return df
except Exception as e:
logger.error(f"获取{market_type}数据失败: {str(e)}")
logger.exception(e)
raise
def _get_etf_data(self) -> pd.DataFrame:
"""
获取ETF数据同步方法将被异步方法调用
Returns:
包含ETF数据的DataFrame
"""
import akshare as ak
try:
# 获取ETF基金数据
df = ak.fund_etf_spot_em()
# 转换列名
df = df.rename(columns={
"代码": "symbol",
"名称": "name",
"最新价": "price",
"涨跌额": "price_change",
"涨跌幅": "price_change_percent",
"成交量": "volume",
"流通市值": "market_value",
"总市值": "total_value",
"基金折价率": "discount_rate",
})
return df
except Exception as e:
logger.error(f"获取ETF数据失败: {str(e)}")
logger.exception(e)
raise Exception(f"获取ETF数据失败: {str(e)}")
def _get_lof_data(self) -> pd.DataFrame:
"""
获取LOF数据同步方法将被异步方法调用
Returns:
包含LOF数据的DataFrame
"""
import akshare as ak
try:
# 获取LOF基金数据
df = ak.fund_lof_spot_em()
# 转换列名
df = df.rename(columns={
"代码": "symbol",
"名称": "name",
"最新价": "price",
"涨跌额": "price_change",
"涨跌幅": "price_change_percent",
"成交量": "volume",
"流通市值": "market_value",
"总市值": "total_value",
"基金折价率": "discount_rate",
})
return df
except Exception as e:
logger.error(f"获取LOF数据失败: {str(e)}")
logger.exception(e)
raise Exception(f"获取LOF数据失败: {str(e)}")
async def get_fund_detail(self, symbol: str, market_type: str = 'ETF') -> Dict[str, Any]:
"""
异步获取单个基金详细信息
Args:
symbol: 基金代码
market_type: 市场类型,'ETF''LOF'
Returns:
基金详细信息
"""
try:
logger.info(f"获取{market_type}基金详情: {symbol}")
# 获取基金数据
df = await self._get_funds_data(market_type)
# 精确匹配基金代码
result = df[df['symbol'] == symbol]
if len(result) == 0:
raise Exception(f"未找到基金代码: {symbol}")
# 获取第一行数据
row = result.iloc[0]
# 格式化为字典
fund_detail = {
'name': row['name'] if pd.notna(row['name']) else '',
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
'price_change': float(row['price_change']) if pd.notna(row['price_change']) else 0.0,
'price_change_percent': float(row['price_change_percent'].strip('%'))/100 if pd.notna(row['price_change_percent']) else 0.0,
'volume': float(row['volume']) if pd.notna(row['volume']) else 0.0,
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
'total_value': float(row['total_value']) if pd.notna(row['total_value']) else 0.0,
'discount_rate': float(row['discount_rate'].strip('%'))/100 if pd.notna(row['discount_rate']) else 0.0
}
logger.info(f"获取基金详情成功: {symbol}")
return fund_detail
except Exception as e:
error_msg = f"获取基金详情失败: {str(e)}"
logger.error(error_msg)
logger.exception(e)
raise Exception(error_msg)

View File

@@ -0,0 +1,178 @@
import pandas as pd
import numpy as np
import asyncio
import json
from typing import Dict, List, Optional, Tuple, Any, AsyncGenerator
from logger import get_logger
from services.stock_data_provider import StockDataProvider
from services.technical_indicator import TechnicalIndicator
from services.stock_scorer import StockScorer
from services.ai_analyzer import AIAnalyzer
# 获取日志器
logger = get_logger()
class StockAnalyzerService:
"""
股票分析服务
作为门面类协调数据提供、指标计算、评分和AI分析等组件
"""
def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
"""
初始化股票分析服务
Args:
custom_api_url: 自定义API URL
custom_api_key: 自定义API密钥
custom_api_model: 自定义API模型
custom_api_timeout: 自定义API超时时间
"""
# 初始化各个组件
self.data_provider = StockDataProvider()
self.indicator = TechnicalIndicator()
self.scorer = StockScorer()
self.ai_analyzer = AIAnalyzer(
custom_api_url=custom_api_url,
custom_api_key=custom_api_key,
custom_api_model=custom_api_model,
custom_api_timeout=custom_api_timeout
)
logger.info("初始化StockAnalyzerService完成")
async def analyze_stock(self, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
"""
分析单只股票
Args:
stock_code: 股票代码
market_type: 市场类型,默认为'A'
stream: 是否使用流式响应
Returns:
异步生成器生成分析结果的JSON字符串
"""
try:
logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}")
# 获取股票数据
df = await self.data_provider.get_stock_data(stock_code, market_type)
# 计算技术指标
df_with_indicators = self.indicator.calculate_indicators(df)
# 计算评分
score = self.scorer.calculate_score(df_with_indicators)
recommendation = self.scorer.get_recommendation(score)
# 生成基本分析结果
basic_result = {
"stock_code": stock_code,
"score": score,
"recommendation": recommendation,
"data_point_count": len(df),
"market_type": market_type
}
# 输出基本分析结果
logger.info(f"基本分析结果: {json.dumps(basic_result)}")
yield json.dumps(basic_result)
# 使用AI进行深入分析
async for analysis_chunk in self.ai_analyzer.get_ai_analysis(df_with_indicators, stock_code, market_type, stream):
yield analysis_chunk
logger.info(f"完成股票分析: {stock_code}")
except Exception as e:
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
yield json.dumps({"error": error_msg})
async def scan_stocks(self, stock_codes: List[str], market_type: str = 'A', min_score: int = 0, stream: bool = False) -> AsyncGenerator[str, None]:
"""
批量扫描股票
Args:
stock_codes: 股票代码列表
market_type: 市场类型
min_score: 最低评分阈值
stream: 是否使用流式响应
Returns:
异步生成器生成扫描结果的JSON字符串
"""
try:
logger.info(f"开始批量扫描 {len(stock_codes)} 只股票, 市场: {market_type}")
# 输出初始状态
yield json.dumps({
"status": "scanning",
"total_stocks": len(stock_codes),
"market_type": market_type,
"min_score": min_score
})
# 批量获取股票数据
stock_data_dict = await self.data_provider.get_multiple_stocks_data(stock_codes, market_type)
# 计算技术指标
stock_with_indicators = {}
for code, df in stock_data_dict.items():
try:
stock_with_indicators[code] = self.indicator.calculate_indicators(df)
except Exception as e:
logger.error(f"计算 {code} 技术指标时出错: {str(e)}")
# 评分股票
results = self.scorer.batch_score_stocks(stock_with_indicators)
# 过滤低于最低评分的股票
filtered_results = [r for r in results if r[1] >= min_score]
# 输出评分结果
yield json.dumps({
"scan_results": [
{
"stock_code": code,
"score": score,
"recommendation": rec
} for code, score, rec in filtered_results
],
"total_matched": len(filtered_results),
"total_scanned": len(results)
})
# 如果需要进一步分析对评分较高的股票进行AI分析
if stream and filtered_results:
top_stocks = filtered_results[:3] # 只分析前3只评分最高的股票
for stock_code, score, _ in top_stocks:
df = stock_with_indicators.get(stock_code)
if df is not None:
# 输出正在分析的股票信息
yield json.dumps({
"analyzing": stock_code,
"score": score
})
# AI分析
async for analysis_chunk in self.ai_analyzer.get_ai_analysis(df, stock_code, market_type, stream):
yield analysis_chunk
# 输出扫描完成信息
yield json.dumps({
"status": "completed",
"total_scanned": len(results),
"total_matched": len(filtered_results)
})
logger.info(f"完成批量扫描 {len(stock_codes)} 只股票, 符合条件: {len(filtered_results)}")
except Exception as e:
error_msg = f"批量扫描股票时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
yield json.dumps({"error": error_msg})

View File

@@ -0,0 +1,184 @@
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import asyncio
import os
from typing import Dict, List, Optional, Tuple, Any
from logger import get_logger
# 获取日志器
logger = get_logger()
class StockDataProvider:
"""
异步股票数据提供服务
负责获取股票、基金等金融产品的历史数据
"""
def __init__(self):
"""初始化数据提供者服务"""
logger.debug("初始化StockDataProvider")
async def get_stock_data(self, stock_code: str, market_type: str = 'A',
start_date: Optional[str] = None,
end_date: Optional[str] = None) -> pd.DataFrame:
"""
异步获取股票或基金数据
Args:
stock_code: 股票代码
market_type: 市场类型,默认为'A'
start_date: 开始日期格式YYYYMMDD默认为一年前
end_date: 结束日期格式YYYYMMDD默认为今天
Returns:
包含历史数据的DataFrame
"""
# 使用线程池执行同步的akshare调用
return await asyncio.to_thread(
self._get_stock_data_sync,
stock_code,
market_type,
start_date,
end_date
)
def _get_stock_data_sync(self, stock_code: str, market_type: str = 'A',
start_date: Optional[str] = None,
end_date: Optional[str] = None) -> pd.DataFrame:
"""
同步获取股票数据的实现
将被异步方法调用
"""
import akshare as ak
if start_date is None:
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
# 验证股票代码格式
if market_type == 'A':
# 上海证券交易所股票代码以6开头
# 深圳证券交易所股票代码以0或3开头
# 科创板股票代码以688开头
# 北京证券交易所股票代码以8开头
valid_prefixes = ['0', '3', '6', '688', '8']
valid_format = False
for prefix in valid_prefixes:
if stock_code.startswith(prefix):
valid_format = True
break
if not valid_format:
error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头"
logger.error(f"[股票代码格式错误] {error_msg}")
raise ValueError(error_msg)
logger.debug(f"获取A股数据: {stock_code}")
df = ak.stock_zh_a_hist(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
elif market_type in ['HK']:
logger.debug(f"获取港股数据: {stock_code}")
df = ak.stock_hk_daily(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
elif market_type in ['US']:
logger.debug(f"获取美股数据: {stock_code}")
df = ak.stock_us_daily(
symbol=stock_code,
adjust="qfq"
)
# 过滤日期
df = df[(df.index >= start_date) & (df.index <= end_date)]
elif market_type in ['ETF', 'LOF']:
logger.debug(f"获取{market_type}基金数据: {stock_code}")
df = ak.fund_etf_hist_sina(
symbol=stock_code,
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', '')
)
else:
error_msg = f"不支持的市场类型: {market_type}"
logger.error(f"[市场类型错误] {error_msg}")
raise ValueError(error_msg)
# 标准化列名
if market_type == 'A':
# 根据实际数据结构调整列名映射
# 实际数据列:['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']
df.columns = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover']
elif market_type in ['HK', 'US']:
# 根据实际情况调整
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
elif market_type in ['ETF', 'LOF']:
# 基金数据可能有不同的列
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
# 确保日期列是日期类型
if 'Date' in df.columns:
df['Date'] = pd.to_datetime(df['Date'])
df.set_index('Date', inplace=True)
# 确保按日期升序排序
df.sort_index(inplace=True)
logger.info(f"成功获取{market_type}数据 {stock_code}, 数据点数: {len(df)}")
return df
except Exception as e:
error_msg = f"获取{market_type}数据失败 {stock_code}: {str(e)}"
logger.error(error_msg)
logger.exception(e)
raise Exception(error_msg)
async def get_multiple_stocks_data(self, stock_codes: List[str],
market_type: str = 'A',
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_concurrency: int = 5) -> Dict[str, pd.DataFrame]:
"""
异步批量获取多只股票数据
Args:
stock_codes: 股票代码列表
market_type: 市场类型,默认为'A'
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
max_concurrency: 最大并发数默认为5
Returns:
字典键为股票代码值为对应的DataFrame
"""
# 使用信号量控制并发数
semaphore = asyncio.Semaphore(max_concurrency)
async def get_with_semaphore(code):
async with semaphore:
try:
return code, await self.get_stock_data(code, market_type, start_date, end_date)
except Exception as e:
logger.error(f"获取股票 {code} 数据时出错: {str(e)}")
return code, None
# 创建异步任务
tasks = [get_with_semaphore(code) for code in stock_codes]
# 等待所有任务完成
results = await asyncio.gather(*tasks)
# 构建结果字典,过滤掉失败的请求
return {code: df for code, df in results if df is not None}

128
services/stock_scorer.py Normal file
View File

@@ -0,0 +1,128 @@
import pandas as pd
import numpy as np
from typing import Dict, Optional, Any, List, Tuple
from logger import get_logger
# 获取日志器
logger = get_logger()
class StockScorer:
"""
股票评分服务
负责根据技术指标计算股票的综合评分
"""
def __init__(self):
"""初始化股票评分服务"""
logger.debug("初始化StockScorer")
def calculate_score(self, df: pd.DataFrame) -> int:
"""
计算股票评分满分100分
Args:
df: 包含技术指标的DataFrame
Returns:
股票评分0-100的整数
"""
try:
# 使用最新的数据点进行评分
latest = df.iloc[-1]
# 初始得分为0
score = 0
# 移动平均线评分25分
if latest['MA5'] > latest['MA20'] > latest['MA60']:
# 短期、中期和长期均线呈多头排列
score += 25
elif latest['MA5'] > latest['MA20']:
# 短期均线在中期均线之上
score += 15
elif latest['Close'] > latest['MA20']:
# 股价在中期均线之上
score += 10
# RSI评分25分
rsi = latest['RSI']
if 45 <= rsi <= 55:
# RSI在中间区域可能即将爆发
score += 15
elif 55 < rsi < 70:
# RSI在强势区域但未超买
score += 25
elif 30 < rsi < 45:
# RSI在弱势区域但未超卖
score += 10
elif rsi >= 70:
# RSI超买
score += 5
elif rsi <= 30:
# RSI超卖
score += 15
# MACD得分20分
if latest['MACD'] > latest['Signal']:
score += 20
# 成交量得分30分
if latest['Volume_Ratio'] > 1.5:
score += 30
elif latest['Volume_Ratio'] > 1:
score += 15
return score
except Exception as e:
logger.error(f"计算评分时出错: {str(e)}")
logger.exception(e)
raise
def get_recommendation(self, score: int) -> str:
"""
根据评分获取投资建议
Args:
score: 股票评分0-100
Returns:
投资建议文本
"""
if score >= 80:
return "强烈推荐"
elif score >= 70:
return "推荐"
elif score >= 60:
return "谨慎推荐"
elif score >= 40:
return "观望"
elif score >= 20:
return "不推荐"
else:
return "强烈不推荐"
def batch_score_stocks(self, stock_dfs: Dict[str, pd.DataFrame]) -> List[Tuple[str, int, str]]:
"""
批量评分多只股票
Args:
stock_dfs: 字典键为股票代码值为DataFrame
Returns:
评分结果列表,每项为(股票代码, 评分, 推荐)的三元组
"""
results = []
for stock_code, df in stock_dfs.items():
try:
score = self.calculate_score(df)
recommendation = self.get_recommendation(score)
results.append((stock_code, score, recommendation))
except Exception as e:
logger.error(f"评分股票 {stock_code} 时出错: {str(e)}")
# 按评分降序排序
results.sort(key=lambda x: x[1], reverse=True)
return results

View File

@@ -0,0 +1,187 @@
import pandas as pd
import numpy as np
from typing import Dict, Optional, Any
from logger import get_logger
# 获取日志器
logger = get_logger()
class TechnicalIndicator:
"""
技术指标计算服务
负责计算常见的股票技术指标
"""
def __init__(self, params: Optional[Dict[str, Any]] = None):
"""
初始化技术指标计算服务
Args:
params: 技术指标参数配置
"""
# 默认参数设置
self.params = params or {
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
'rsi_period': 14,
'bollinger_period': 20,
'bollinger_std': 2,
'volume_ma_period': 20,
'atr_period': 14
}
logger.debug(f"初始化TechnicalIndicator参数: {self.params}")
def calculate_ema(self, series: pd.Series, period: int) -> pd.Series:
"""
计算指数移动平均线
Args:
series: 价格序列
period: 周期
Returns:
EMA序列
"""
return series.ewm(span=period, adjust=False).mean()
def calculate_rsi(self, series: pd.Series, period: int) -> pd.Series:
"""
计算相对强弱指标(RSI)
Args:
series: 价格序列
period: 周期
Returns:
RSI序列
"""
delta = series.diff()
gain = delta.where(delta > 0, 0)
loss = -delta.where(delta < 0, 0)
avg_gain = gain.rolling(window=period).mean()
avg_loss = loss.rolling(window=period).mean()
rs = avg_gain / avg_loss
rsi = 100 - (100 / (1 + rs))
return rsi
def calculate_macd(self, series: pd.Series) -> tuple:
"""
计算MACD指标
Args:
series: 价格序列
Returns:
(MACD线, 信号线, 柱状图)的元组
"""
ema12 = self.calculate_ema(series, 12)
ema26 = self.calculate_ema(series, 26)
macd = ema12 - ema26
signal = self.calculate_ema(macd, 9)
histogram = macd - signal
return macd, signal, histogram
def calculate_bollinger_bands(self, series: pd.Series, period: int, std_dev: float) -> tuple:
"""
计算布林带
Args:
series: 价格序列
period: 周期
std_dev: 标准差倍数
Returns:
(中轨, 上轨, 下轨)的元组
"""
middle = series.rolling(window=period).mean()
std = series.rolling(window=period).std()
upper = middle + std_dev * std
lower = middle - std_dev * std
return middle, upper, lower
def calculate_atr(self, df: pd.DataFrame, period: int) -> pd.Series:
"""
计算平均真实波幅(ATR)
Args:
df: 包含High, Low, Close列的DataFrame
period: 周期
Returns:
ATR序列
"""
high = df['High']
low = df['Low']
close = df['Close']
tr1 = high - low
tr2 = abs(high - close.shift())
tr3 = abs(low - close.shift())
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
atr = tr.rolling(window=period).mean()
return atr
def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""
计算所有技术指标
Args:
df: 原始价格数据包含Open, High, Low, Close, Volume列
Returns:
添加了技术指标的DataFrame
"""
try:
# 复制数据框
result_df = df.copy()
# 移动平均线
for name, period in self.params['ma_periods'].items():
result_df[f'MA{period}'] = result_df['Close'].rolling(window=period).mean()
# RSI
result_df['RSI'] = self.calculate_rsi(result_df['Close'], self.params['rsi_period'])
# MACD
macd, signal, histogram = self.calculate_macd(result_df['Close'])
result_df['MACD'] = macd
result_df['Signal'] = signal
result_df['Histogram'] = histogram
# 布林带
middle, upper, lower = self.calculate_bollinger_bands(
result_df['Close'],
self.params['bollinger_period'],
self.params['bollinger_std']
)
result_df['BB_Middle'] = middle
result_df['BB_Upper'] = upper
result_df['BB_Lower'] = lower
# 成交量移动平均
result_df['Volume_MA'] = result_df['Volume'].rolling(window=self.params['volume_ma_period']).mean()
# 成交量比率
result_df['Volume_Ratio'] = result_df['Volume'] / result_df['Volume_MA']
# ATR
result_df['ATR'] = self.calculate_atr(result_df, self.params['atr_period'])
# 波动率 (过去20天收盘价的标准差/均值)
result_df['Volatility'] = result_df['Close'].rolling(window=20).std() / result_df['Close'].rolling(window=20).mean() * 100
return result_df
except Exception as e:
logger.error(f"计算技术指标时出错: {str(e)}")
logger.exception(e)
raise

View File

@@ -0,0 +1,151 @@
import asyncio
import pandas as pd
from typing import List, Dict, Any, Optional
from logger import get_logger
# 获取日志器
logger = get_logger()
class USStockServiceAsync:
"""
异步美股服务
提供美股数据的异步搜索和获取功能
"""
def __init__(self):
"""初始化异步美股服务"""
logger.debug("初始化USStockServiceAsync")
# 可选:添加缓存以减少频繁请求
self._cache = None
self._cache_timestamp = None
async def search_us_stocks(self, keyword: str) -> List[Dict[str, Any]]:
"""
异步搜索美股代码
Args:
keyword: 搜索关键词
Returns:
匹配的股票列表
"""
try:
logger.info(f"异步搜索美股: {keyword}")
# 使用线程池执行同步的akshare调用
df = await asyncio.to_thread(self._get_us_stocks_data)
# 模糊匹配搜索
mask = df['name'].str.contains(keyword, case=False, na=False)
results = df[mask]
# 格式化返回结果并处理 NaN 值
formatted_results = []
for _, row in results.iterrows():
formatted_results.append({
'name': row['name'] if pd.notna(row['name']) else '',
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0
})
logger.info(f"美股搜索完成,找到 {len(formatted_results)} 个匹配项")
return formatted_results
except Exception as e:
error_msg = f"搜索美股代码失败: {str(e)}"
logger.error(error_msg)
logger.exception(e)
raise Exception(error_msg)
def _get_us_stocks_data(self) -> pd.DataFrame:
"""
获取美股数据(同步方法,将被异步方法调用)
Returns:
包含美股数据的DataFrame
"""
import akshare as ak
try:
# 获取美股数据
df = ak.stock_us_spot_em()
# 转换列名
df = df.rename(columns={
"序号": "index",
"名称": "name",
"最新价": "price",
"涨跌额": "price_change",
"涨跌幅": "price_change_percent",
"开盘价": "open",
"最高价": "high",
"最低价": "low",
"昨收价": "pre_close",
"总市值": "market_value",
"市盈率": "pe_ratio",
"成交量": "volume",
"成交额": "turnover",
"振幅": "amplitude",
"换手率": "turnover_rate",
"代码": "symbol"
})
return df
except Exception as e:
logger.error(f"获取美股数据失败: {str(e)}")
logger.exception(e)
raise Exception(f"获取美股数据失败: {str(e)}")
async def get_us_stock_detail(self, symbol: str) -> Dict[str, Any]:
"""
异步获取单个美股详细信息
Args:
symbol: 股票代码
Returns:
股票详细信息
"""
try:
logger.info(f"获取美股详情: {symbol}")
# 使用线程池执行同步的akshare调用
df = await asyncio.to_thread(self._get_us_stocks_data)
# 精确匹配股票代码
result = df[df['symbol'] == symbol]
if len(result) == 0:
raise Exception(f"未找到股票代码: {symbol}")
# 获取第一行数据
row = result.iloc[0]
# 格式化为字典
stock_detail = {
'name': row['name'] if pd.notna(row['name']) else '',
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
'price_change': float(row['price_change']) if pd.notna(row['price_change']) else 0.0,
'price_change_percent': float(row['price_change_percent'].strip('%'))/100 if pd.notna(row['price_change_percent']) else 0.0,
'open': float(row['open']) if pd.notna(row['open']) else 0.0,
'high': float(row['high']) if pd.notna(row['high']) else 0.0,
'low': float(row['low']) if pd.notna(row['low']) else 0.0,
'pre_close': float(row['pre_close']) if pd.notna(row['pre_close']) else 0.0,
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
'pe_ratio': float(row['pe_ratio']) if pd.notna(row['pe_ratio']) else 0.0,
'volume': float(row['volume']) if pd.notna(row['volume']) else 0.0,
'turnover': float(row['turnover']) if pd.notna(row['turnover']) else 0.0
}
logger.info(f"获取美股详情成功: {symbol}")
return stock_detail
except Exception as e:
error_msg = f"获取美股详情失败: {str(e)}"
logger.error(error_msg)
logger.exception(e)
raise Exception(error_msg)

View File

@@ -3,10 +3,11 @@ from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, Red
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional, Generator from typing import List, Optional, Dict, Any, Generator
from stock_analyzer import StockAnalyzer from services.stock_analyzer_service import StockAnalyzerService
from us_stock_service import USStockService # 导入新的异步服务
from fund_service import FundService from services.us_stock_service_async import USStockServiceAsync
from services.fund_service_async import FundServiceAsync
import asyncio import asyncio
import threading import threading
import os import os
@@ -25,7 +26,7 @@ logger = get_logger()
app = FastAPI( app = FastAPI(
title="Stock Scanner API", title="Stock Scanner API",
description="股票分析API", description="异步股票分析API",
version="1.0.0" version="1.0.0"
) )
@@ -43,9 +44,10 @@ frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fronte
if os.path.exists(frontend_dist): if os.path.exists(frontend_dist):
app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend") app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
analyzer = StockAnalyzer() # 初始化异步服务
us_stock_service = USStockService() # StockAnalyzerService 不需要全局初始化,在 /analyze 接口中按需创建
fund_service = FundService() us_stock_service = USStockServiceAsync()
fund_service = FundServiceAsync()
# 定义请求和响应模型 # 定义请求和响应模型
class AnalyzeRequest(BaseModel): class AnalyzeRequest(BaseModel):
@@ -102,7 +104,7 @@ async def analyze(request: AnalyzeRequest):
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}") logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}")
# 创建新的分析器实例,使用自定义配置 # 创建新的分析器实例,使用自定义配置
custom_analyzer = StockAnalyzer( custom_analyzer = StockAnalyzerService(
custom_api_url=custom_api_url, custom_api_url=custom_api_url,
custom_api_key=custom_api_key, custom_api_key=custom_api_key,
custom_api_model=custom_api_model, custom_api_model=custom_api_model,
@@ -126,17 +128,11 @@ async def analyze(request: AnalyzeRequest):
logger.debug(f"开始处理股票 {stock_code} 的流式响应") logger.debug(f"开始处理股票 {stock_code} 的流式响应")
chunk_count = 0 chunk_count = 0
# 使用线程池执行同步分析 # 使用异步生成器
def run_analysis(): async for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True))
# 在线程中执行同步操作
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, run_analysis)
for chunk in chunks:
chunk_count += 1 chunk_count += 1
yield chunk + '\n' yield chunk + '\n'
logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块") logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块")
else: else:
# 批量分析流式处理 # 批量分析流式处理
@@ -148,22 +144,16 @@ async def analyze(request: AnalyzeRequest):
logger.debug(f"开始处理批量股票的流式响应") logger.debug(f"开始处理批量股票的流式响应")
chunk_count = 0 chunk_count = 0
# 使用线程池执行同步分析 # 使用异步生成器
def run_batch_analysis(): async for chunk in custom_analyzer.scan_stocks(
return list(custom_analyzer.scan_stocks( [code.strip() for code in stock_codes],
[code.strip() for code in stock_codes], min_score=0,
min_score=0, market_type=market_type,
market_type=market_type, stream=True
stream=True ):
))
# 在线程中执行同步操作
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, run_batch_analysis)
for chunk in chunks:
chunk_count += 1 chunk_count += 1
yield chunk + '\n' yield chunk + '\n'
logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块") logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块")
logger.info("成功创建流式响应生成器") logger.info("成功创建流式响应生成器")
@@ -181,9 +171,8 @@ async def search_us_stocks(keyword: str = ""):
if not keyword: if not keyword:
raise HTTPException(status_code=400, detail="请输入搜索关键词") raise HTTPException(status_code=400, detail="请输入搜索关键词")
# 在异步上下文中运行同步的搜索函数 # 直接使用异步服务的异步方法
loop = asyncio.get_event_loop() results = await us_stock_service.search_us_stocks(keyword)
results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword)
return {"results": results} return {"results": results}
except Exception as e: except Exception as e:
@@ -196,9 +185,8 @@ async def search_funds(keyword: str = "", market_type: str = ""):
if not keyword: if not keyword:
raise HTTPException(status_code=400, detail="请输入搜索关键词") raise HTTPException(status_code=400, detail="请输入搜索关键词")
# 在异步上下文中运行同步的搜索函数 # 直接使用异步服务的异步方法
loop = asyncio.get_event_loop() results = await fund_service.search_funds(keyword, market_type)
results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type))
return {"results": results} return {"results": results}
except Exception as e: except Exception as e:
@@ -273,6 +261,36 @@ async def test_api_connection(request: TestAPIRequest):
content={"success": False, "message": f"API 测试连接时出错: {str(e)}"} content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
) )
# 新增 API 端点:获取美股详情
@app.get("/us_stock_detail/{symbol}")
async def get_us_stock_detail(symbol: str):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供股票代码")
# 使用异步服务获取详情
detail = await us_stock_service.get_us_stock_detail(symbol)
return detail
except Exception as e:
logger.error(f"获取美股详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 新增 API 端点:获取基金详情
@app.get("/fund_detail/{symbol}")
async def get_fund_detail(symbol: str, market_type: str = "ETF"):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供基金代码")
# 使用异步服务获取详情
detail = await fund_service.get_fund_detail(symbol, market_type)
return detail
except Exception as e:
logger.error(f"获取基金详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == '__main__': if __name__ == '__main__':
logger.info("股票分析系统启动") logger.info("股票分析系统启动")
uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True) uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)