refactor: 重构代码结构
This commit is contained in:
2
services/__init__.py
Normal file
2
services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# services包初始化文件
|
||||
# 用于组织股票分析服务的各个模块
|
||||
269
services/ai_analyzer.py
Normal file
269
services/ai_analyzer.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import httpx
|
||||
from typing import Dict, List, Optional, Any, Generator, AsyncGenerator
|
||||
from dotenv import load_dotenv
|
||||
from logger import get_logger
|
||||
from utils.api_utils import APIUtils
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class AIAnalyzer:
|
||||
"""
|
||||
异步AI分析服务
|
||||
负责调用AI API对股票数据进行分析
|
||||
"""
|
||||
|
||||
def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
|
||||
"""
|
||||
初始化AI分析服务
|
||||
|
||||
Args:
|
||||
custom_api_url: 自定义API URL
|
||||
custom_api_key: 自定义API密钥
|
||||
custom_api_model: 自定义API模型
|
||||
custom_api_timeout: 自定义API超时时间
|
||||
"""
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 设置API配置
|
||||
self.API_URL = custom_api_url or os.getenv('API_URL')
|
||||
self.API_KEY = custom_api_key or os.getenv('API_KEY')
|
||||
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||
self.API_TIMEOUT = int(custom_api_timeout or os.getenv('API_TIMEOUT', 60))
|
||||
|
||||
logger.debug(f"初始化AIAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}, API_TIMEOUT={self.API_TIMEOUT}")
|
||||
|
||||
async def get_ai_analysis(self, df: pd.DataFrame, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
对股票数据进行AI分析
|
||||
|
||||
Args:
|
||||
df: 包含技术指标的DataFrame
|
||||
stock_code: 股票代码
|
||||
market_type: 市场类型,默认为'A'股
|
||||
stream: 是否使用流式响应
|
||||
|
||||
Returns:
|
||||
异步生成器,生成分析结果字符串
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始AI分析 {stock_code}, 流式模式: {stream}")
|
||||
recent_data = df.tail(14).to_dict('records')
|
||||
|
||||
technical_summary = {
|
||||
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
|
||||
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
|
||||
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
|
||||
'rsi_level': df.iloc[-1]['RSI']
|
||||
}
|
||||
|
||||
# 根据市场类型调整分析提示
|
||||
if market_type in ['ETF', 'LOF']:
|
||||
prompt = f"""
|
||||
分析基金 {stock_code}:
|
||||
|
||||
技术指标概要:
|
||||
{technical_summary}
|
||||
|
||||
近14日交易数据:
|
||||
{recent_data}
|
||||
|
||||
请分析该基金的技术面状况,包括:
|
||||
1. 趋势分析:判断基金当前的趋势方向
|
||||
2. 动量分析:基于RSI和交易量评估基金动量
|
||||
3. 支撑与阻力位:确定关键价格位
|
||||
4. 技术面总结
|
||||
5. 投资建议
|
||||
|
||||
将分析结果格式化为JSON,像这样:
|
||||
{{
|
||||
"trend_analysis": "趋势分析结果...",
|
||||
"momentum_analysis": "动量分析结果...",
|
||||
"support_resistance": "支撑阻力位分析...",
|
||||
"technical_summary": "技术面总结...",
|
||||
"investment_advice": "投资建议..."
|
||||
}}
|
||||
"""
|
||||
else:
|
||||
prompt = f"""
|
||||
分析股票 {stock_code}:
|
||||
|
||||
技术指标概要:
|
||||
{technical_summary}
|
||||
|
||||
近14日交易数据:
|
||||
{recent_data}
|
||||
|
||||
请分析该股票的技术面状况,包括:
|
||||
1. 趋势分析:当前趋势方向及强度
|
||||
2. 动量分析:基于MACD、RSI等指标
|
||||
3. 支撑与阻力位:关键价格位分析
|
||||
4. 成交量分析:交易量的变化及意义
|
||||
5. 波动性评估:ATR和波动率分析
|
||||
6. 技术面总结
|
||||
7. 投资建议:根据技术分析给出操作建议
|
||||
|
||||
将分析结果格式化为JSON,像这样:
|
||||
{{
|
||||
"trend_analysis": "趋势分析结果...",
|
||||
"momentum_analysis": "动量分析结果...",
|
||||
"support_resistance": "支撑阻力位分析...",
|
||||
"volume_analysis": "成交量分析...",
|
||||
"volatility_assessment": "波动性评估...",
|
||||
"technical_summary": "技术面总结...",
|
||||
"investment_advice": "投资建议..."
|
||||
}}
|
||||
"""
|
||||
|
||||
# 格式化API URL
|
||||
api_url = APIUtils.format_api_url(self.API_URL)
|
||||
|
||||
# 准备请求数据
|
||||
request_data = {
|
||||
"model": self.API_MODEL,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.7,
|
||||
"stream": stream
|
||||
}
|
||||
|
||||
# 准备请求头
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.API_KEY}"
|
||||
}
|
||||
|
||||
# 异步请求API
|
||||
async with httpx.AsyncClient(timeout=self.API_TIMEOUT) as client:
|
||||
# 记录请求
|
||||
logger.debug(f"发送AI请求: URL={api_url}, MODEL={self.API_MODEL}, STREAM={stream}")
|
||||
|
||||
if stream:
|
||||
# 流式响应处理
|
||||
async with client.stream("POST", api_url, json=request_data, headers=headers) as response:
|
||||
if response.status_code != 200:
|
||||
error_text = await response.aread()
|
||||
error_data = json.loads(error_text)
|
||||
error_message = error_data.get('error', {}).get('message', '未知错误')
|
||||
logger.error(f"AI API请求失败: {response.status_code} - {error_message}")
|
||||
yield json.dumps({"error": f"API请求失败: {error_message}"})
|
||||
return
|
||||
|
||||
# 处理流式响应
|
||||
buffer = ""
|
||||
collected_messages = []
|
||||
|
||||
async for chunk in response.aiter_text():
|
||||
if chunk:
|
||||
chunk_str = chunk.strip()
|
||||
if chunk_str.startswith("data: "):
|
||||
chunk_str = chunk_str[6:] # 去除"data: "前缀
|
||||
|
||||
if chunk_str == "[DONE]":
|
||||
continue
|
||||
|
||||
try:
|
||||
# 解析数据块
|
||||
chunk_data = json.loads(chunk_str)
|
||||
delta = chunk_data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
|
||||
if content:
|
||||
buffer += content
|
||||
# 尝试提取完整的JSON
|
||||
if buffer.strip().startswith("{") and buffer.strip().endswith("}"):
|
||||
try:
|
||||
result_json = json.loads(buffer)
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"analysis": result_json
|
||||
})
|
||||
buffer = "" # 重置缓冲区
|
||||
except json.JSONDecodeError:
|
||||
# JSON不完整,继续收集
|
||||
pass
|
||||
|
||||
# 达到一定长度就输出
|
||||
if len(buffer) > 100:
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"partial_content": buffer
|
||||
})
|
||||
collected_messages.append(buffer)
|
||||
buffer = ""
|
||||
except json.JSONDecodeError:
|
||||
# 忽略无法解析的块
|
||||
continue
|
||||
|
||||
# 处理最后的缓冲区
|
||||
if buffer:
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"partial_content": buffer
|
||||
})
|
||||
collected_messages.append(buffer)
|
||||
|
||||
# 尝试从整个内容中提取JSON
|
||||
full_content = "".join(collected_messages)
|
||||
|
||||
# 如果没有成功解析JSON,返回原始内容
|
||||
if not full_content.strip().startswith("{"):
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"raw_analysis": full_content
|
||||
})
|
||||
else:
|
||||
# 非流式响应处理
|
||||
response = await client.post(api_url, json=request_data, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_data = response.json()
|
||||
error_message = error_data.get('error', {}).get('message', '未知错误')
|
||||
logger.error(f"AI API请求失败: {response.status_code} - {error_message}")
|
||||
yield json.dumps({"error": f"API请求失败: {error_message}"})
|
||||
return
|
||||
|
||||
response_data = response.json()
|
||||
analysis_text = response_data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
|
||||
try:
|
||||
# 尝试解析JSON
|
||||
analysis_json = json.loads(analysis_text)
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"analysis": analysis_json
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
# 返回原始文本
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"raw_analysis": analysis_text
|
||||
})
|
||||
|
||||
logger.info(f"完成对 {stock_code} 的AI分析")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI分析 {stock_code} 时出错: {str(e)}")
|
||||
logger.exception(e)
|
||||
yield json.dumps({"error": f"分析出错: {str(e)}"})
|
||||
|
||||
def _truncate_json_for_logging(self, json_obj, max_length=500):
|
||||
"""
|
||||
截断JSON对象以便记录日志
|
||||
|
||||
Args:
|
||||
json_obj: JSON对象
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
截断后的字符串
|
||||
"""
|
||||
json_str = json.dumps(json_obj, ensure_ascii=False)
|
||||
if len(json_str) <= max_length:
|
||||
return json_str
|
||||
return json_str[:max_length] + "..."
|
||||
225
services/fund_service_async.py
Normal file
225
services/fund_service_async.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any, Optional
|
||||
from logger import get_logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class FundServiceAsync:
|
||||
"""
|
||||
异步基金服务
|
||||
提供基金数据的异步搜索和获取功能
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化异步基金服务"""
|
||||
logger.debug("初始化FundServiceAsync")
|
||||
|
||||
# 添加缓存
|
||||
self._etf_cache = None
|
||||
self._lof_cache = None
|
||||
self._cache_timestamp = None
|
||||
self._cache_duration = timedelta(minutes=30) # 缓存30分钟
|
||||
|
||||
async def search_funds(self, keyword: str, market_type: str = 'ETF') -> List[Dict[str, Any]]:
|
||||
"""
|
||||
异步搜索基金代码
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
market_type: 市场类型,'ETF'或'LOF'
|
||||
|
||||
Returns:
|
||||
匹配的基金列表
|
||||
"""
|
||||
try:
|
||||
logger.info(f"异步搜索基金: {keyword}, 类型: {market_type}")
|
||||
|
||||
# 获取基金数据
|
||||
df = await self._get_funds_data(market_type)
|
||||
|
||||
# 模糊匹配搜索(同时匹配代码和名称)
|
||||
mask = (df['name'].str.contains(keyword, case=False, na=False) |
|
||||
df['symbol'].str.contains(keyword, case=False, na=False))
|
||||
results = df[mask]
|
||||
|
||||
# 格式化返回结果并处理 NaN 值
|
||||
formatted_results = []
|
||||
for _, row in results.iterrows():
|
||||
formatted_results.append({
|
||||
'name': row['name'] if pd.notna(row['name']) else '',
|
||||
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
|
||||
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
|
||||
'volume': float(row['volume']) if pd.notna(row['volume']) else 0.0,
|
||||
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
|
||||
'total_value': float(row['total_value']) if pd.notna(row['total_value']) else 0.0,
|
||||
})
|
||||
|
||||
logger.info(f"基金搜索完成,找到 {len(formatted_results)} 个匹配项")
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"搜索基金代码失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def _get_funds_data(self, market_type: str = 'ETF') -> pd.DataFrame:
|
||||
"""
|
||||
异步获取基金数据,支持缓存
|
||||
|
||||
Args:
|
||||
market_type: 市场类型,'ETF'或'LOF'
|
||||
|
||||
Returns:
|
||||
包含基金数据的DataFrame
|
||||
"""
|
||||
# 检查缓存是否有效
|
||||
now = datetime.now()
|
||||
cache_valid = (
|
||||
self._cache_timestamp is not None and
|
||||
(now - self._cache_timestamp) < self._cache_duration
|
||||
)
|
||||
|
||||
if market_type == 'ETF' and cache_valid and self._etf_cache is not None:
|
||||
logger.debug("使用ETF缓存数据")
|
||||
return self._etf_cache
|
||||
elif market_type == 'LOF' and cache_valid and self._lof_cache is not None:
|
||||
logger.debug("使用LOF缓存数据")
|
||||
return self._lof_cache
|
||||
|
||||
# 缓存无效,重新获取数据
|
||||
try:
|
||||
logger.debug(f"从API获取{market_type}数据")
|
||||
|
||||
# 使用线程池执行同步的akshare调用
|
||||
if market_type == 'ETF':
|
||||
df = await asyncio.to_thread(self._get_etf_data)
|
||||
self._etf_cache = df
|
||||
else:
|
||||
df = await asyncio.to_thread(self._get_lof_data)
|
||||
self._lof_cache = df
|
||||
|
||||
self._cache_timestamp = now
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取{market_type}数据失败: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
def _get_etf_data(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取ETF数据(同步方法,将被异步方法调用)
|
||||
|
||||
Returns:
|
||||
包含ETF数据的DataFrame
|
||||
"""
|
||||
import akshare as ak
|
||||
|
||||
try:
|
||||
# 获取ETF基金数据
|
||||
df = ak.fund_etf_spot_em()
|
||||
|
||||
# 转换列名
|
||||
df = df.rename(columns={
|
||||
"代码": "symbol",
|
||||
"名称": "name",
|
||||
"最新价": "price",
|
||||
"涨跌额": "price_change",
|
||||
"涨跌幅": "price_change_percent",
|
||||
"成交量": "volume",
|
||||
"流通市值": "market_value",
|
||||
"总市值": "total_value",
|
||||
"基金折价率": "discount_rate",
|
||||
})
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取ETF数据失败: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise Exception(f"获取ETF数据失败: {str(e)}")
|
||||
|
||||
def _get_lof_data(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取LOF数据(同步方法,将被异步方法调用)
|
||||
|
||||
Returns:
|
||||
包含LOF数据的DataFrame
|
||||
"""
|
||||
import akshare as ak
|
||||
|
||||
try:
|
||||
# 获取LOF基金数据
|
||||
df = ak.fund_lof_spot_em()
|
||||
|
||||
# 转换列名
|
||||
df = df.rename(columns={
|
||||
"代码": "symbol",
|
||||
"名称": "name",
|
||||
"最新价": "price",
|
||||
"涨跌额": "price_change",
|
||||
"涨跌幅": "price_change_percent",
|
||||
"成交量": "volume",
|
||||
"流通市值": "market_value",
|
||||
"总市值": "total_value",
|
||||
"基金折价率": "discount_rate",
|
||||
})
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取LOF数据失败: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise Exception(f"获取LOF数据失败: {str(e)}")
|
||||
|
||||
async def get_fund_detail(self, symbol: str, market_type: str = 'ETF') -> Dict[str, Any]:
|
||||
"""
|
||||
异步获取单个基金详细信息
|
||||
|
||||
Args:
|
||||
symbol: 基金代码
|
||||
market_type: 市场类型,'ETF'或'LOF'
|
||||
|
||||
Returns:
|
||||
基金详细信息
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取{market_type}基金详情: {symbol}")
|
||||
|
||||
# 获取基金数据
|
||||
df = await self._get_funds_data(market_type)
|
||||
|
||||
# 精确匹配基金代码
|
||||
result = df[df['symbol'] == symbol]
|
||||
|
||||
if len(result) == 0:
|
||||
raise Exception(f"未找到基金代码: {symbol}")
|
||||
|
||||
# 获取第一行数据
|
||||
row = result.iloc[0]
|
||||
|
||||
# 格式化为字典
|
||||
fund_detail = {
|
||||
'name': row['name'] if pd.notna(row['name']) else '',
|
||||
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
|
||||
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
|
||||
'price_change': float(row['price_change']) if pd.notna(row['price_change']) else 0.0,
|
||||
'price_change_percent': float(row['price_change_percent'].strip('%'))/100 if pd.notna(row['price_change_percent']) else 0.0,
|
||||
'volume': float(row['volume']) if pd.notna(row['volume']) else 0.0,
|
||||
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
|
||||
'total_value': float(row['total_value']) if pd.notna(row['total_value']) else 0.0,
|
||||
'discount_rate': float(row['discount_rate'].strip('%'))/100 if pd.notna(row['discount_rate']) else 0.0
|
||||
}
|
||||
|
||||
logger.info(f"获取基金详情成功: {symbol}")
|
||||
return fund_detail
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取基金详情失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
178
services/stock_analyzer_service.py
Normal file
178
services/stock_analyzer_service.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, Any, AsyncGenerator
|
||||
from logger import get_logger
|
||||
from services.stock_data_provider import StockDataProvider
|
||||
from services.technical_indicator import TechnicalIndicator
|
||||
from services.stock_scorer import StockScorer
|
||||
from services.ai_analyzer import AIAnalyzer
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class StockAnalyzerService:
|
||||
"""
|
||||
股票分析服务
|
||||
作为门面类协调数据提供、指标计算、评分和AI分析等组件
|
||||
"""
|
||||
|
||||
def __init__(self, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None):
|
||||
"""
|
||||
初始化股票分析服务
|
||||
|
||||
Args:
|
||||
custom_api_url: 自定义API URL
|
||||
custom_api_key: 自定义API密钥
|
||||
custom_api_model: 自定义API模型
|
||||
custom_api_timeout: 自定义API超时时间
|
||||
"""
|
||||
# 初始化各个组件
|
||||
self.data_provider = StockDataProvider()
|
||||
self.indicator = TechnicalIndicator()
|
||||
self.scorer = StockScorer()
|
||||
self.ai_analyzer = AIAnalyzer(
|
||||
custom_api_url=custom_api_url,
|
||||
custom_api_key=custom_api_key,
|
||||
custom_api_model=custom_api_model,
|
||||
custom_api_timeout=custom_api_timeout
|
||||
)
|
||||
|
||||
logger.info("初始化StockAnalyzerService完成")
|
||||
|
||||
async def analyze_stock(self, stock_code: str, market_type: str = 'A', stream: bool = False) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
分析单只股票
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
market_type: 市场类型,默认为'A'股
|
||||
stream: 是否使用流式响应
|
||||
|
||||
Returns:
|
||||
异步生成器,生成分析结果的JSON字符串
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}")
|
||||
|
||||
# 获取股票数据
|
||||
df = await self.data_provider.get_stock_data(stock_code, market_type)
|
||||
|
||||
# 计算技术指标
|
||||
df_with_indicators = self.indicator.calculate_indicators(df)
|
||||
|
||||
# 计算评分
|
||||
score = self.scorer.calculate_score(df_with_indicators)
|
||||
recommendation = self.scorer.get_recommendation(score)
|
||||
|
||||
# 生成基本分析结果
|
||||
basic_result = {
|
||||
"stock_code": stock_code,
|
||||
"score": score,
|
||||
"recommendation": recommendation,
|
||||
"data_point_count": len(df),
|
||||
"market_type": market_type
|
||||
}
|
||||
|
||||
# 输出基本分析结果
|
||||
logger.info(f"基本分析结果: {json.dumps(basic_result)}")
|
||||
yield json.dumps(basic_result)
|
||||
|
||||
# 使用AI进行深入分析
|
||||
async for analysis_chunk in self.ai_analyzer.get_ai_analysis(df_with_indicators, stock_code, market_type, stream):
|
||||
yield analysis_chunk
|
||||
|
||||
logger.info(f"完成股票分析: {stock_code}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
yield json.dumps({"error": error_msg})
|
||||
|
||||
async def scan_stocks(self, stock_codes: List[str], market_type: str = 'A', min_score: int = 0, stream: bool = False) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
批量扫描股票
|
||||
|
||||
Args:
|
||||
stock_codes: 股票代码列表
|
||||
market_type: 市场类型
|
||||
min_score: 最低评分阈值
|
||||
stream: 是否使用流式响应
|
||||
|
||||
Returns:
|
||||
异步生成器,生成扫描结果的JSON字符串
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始批量扫描 {len(stock_codes)} 只股票, 市场: {market_type}")
|
||||
|
||||
# 输出初始状态
|
||||
yield json.dumps({
|
||||
"status": "scanning",
|
||||
"total_stocks": len(stock_codes),
|
||||
"market_type": market_type,
|
||||
"min_score": min_score
|
||||
})
|
||||
|
||||
# 批量获取股票数据
|
||||
stock_data_dict = await self.data_provider.get_multiple_stocks_data(stock_codes, market_type)
|
||||
|
||||
# 计算技术指标
|
||||
stock_with_indicators = {}
|
||||
for code, df in stock_data_dict.items():
|
||||
try:
|
||||
stock_with_indicators[code] = self.indicator.calculate_indicators(df)
|
||||
except Exception as e:
|
||||
logger.error(f"计算 {code} 技术指标时出错: {str(e)}")
|
||||
|
||||
# 评分股票
|
||||
results = self.scorer.batch_score_stocks(stock_with_indicators)
|
||||
|
||||
# 过滤低于最低评分的股票
|
||||
filtered_results = [r for r in results if r[1] >= min_score]
|
||||
|
||||
# 输出评分结果
|
||||
yield json.dumps({
|
||||
"scan_results": [
|
||||
{
|
||||
"stock_code": code,
|
||||
"score": score,
|
||||
"recommendation": rec
|
||||
} for code, score, rec in filtered_results
|
||||
],
|
||||
"total_matched": len(filtered_results),
|
||||
"total_scanned": len(results)
|
||||
})
|
||||
|
||||
# 如果需要进一步分析,对评分较高的股票进行AI分析
|
||||
if stream and filtered_results:
|
||||
top_stocks = filtered_results[:3] # 只分析前3只评分最高的股票
|
||||
|
||||
for stock_code, score, _ in top_stocks:
|
||||
df = stock_with_indicators.get(stock_code)
|
||||
if df is not None:
|
||||
# 输出正在分析的股票信息
|
||||
yield json.dumps({
|
||||
"analyzing": stock_code,
|
||||
"score": score
|
||||
})
|
||||
|
||||
# AI分析
|
||||
async for analysis_chunk in self.ai_analyzer.get_ai_analysis(df, stock_code, market_type, stream):
|
||||
yield analysis_chunk
|
||||
|
||||
# 输出扫描完成信息
|
||||
yield json.dumps({
|
||||
"status": "completed",
|
||||
"total_scanned": len(results),
|
||||
"total_matched": len(filtered_results)
|
||||
})
|
||||
|
||||
logger.info(f"完成批量扫描 {len(stock_codes)} 只股票, 符合条件: {len(filtered_results)}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"批量扫描股票时出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
yield json.dumps({"error": error_msg})
|
||||
184
services/stock_data_provider.py
Normal file
184
services/stock_data_provider.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from logger import get_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class StockDataProvider:
|
||||
"""
|
||||
异步股票数据提供服务
|
||||
负责获取股票、基金等金融产品的历史数据
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化数据提供者服务"""
|
||||
logger.debug("初始化StockDataProvider")
|
||||
|
||||
async def get_stock_data(self, stock_code: str, market_type: str = 'A',
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None) -> pd.DataFrame:
|
||||
"""
|
||||
异步获取股票或基金数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
market_type: 市场类型,默认为'A'股
|
||||
start_date: 开始日期,格式YYYYMMDD,默认为一年前
|
||||
end_date: 结束日期,格式YYYYMMDD,默认为今天
|
||||
|
||||
Returns:
|
||||
包含历史数据的DataFrame
|
||||
"""
|
||||
# 使用线程池执行同步的akshare调用
|
||||
return await asyncio.to_thread(
|
||||
self._get_stock_data_sync,
|
||||
stock_code,
|
||||
market_type,
|
||||
start_date,
|
||||
end_date
|
||||
)
|
||||
|
||||
def _get_stock_data_sync(self, stock_code: str, market_type: str = 'A',
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None) -> pd.DataFrame:
|
||||
"""
|
||||
同步获取股票数据的实现
|
||||
将被异步方法调用
|
||||
"""
|
||||
import akshare as ak
|
||||
|
||||
if start_date is None:
|
||||
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
try:
|
||||
# 验证股票代码格式
|
||||
if market_type == 'A':
|
||||
# 上海证券交易所股票代码以6开头
|
||||
# 深圳证券交易所股票代码以0或3开头
|
||||
# 科创板股票代码以688开头
|
||||
# 北京证券交易所股票代码以8开头
|
||||
valid_prefixes = ['0', '3', '6', '688', '8']
|
||||
valid_format = False
|
||||
|
||||
for prefix in valid_prefixes:
|
||||
if stock_code.startswith(prefix):
|
||||
valid_format = True
|
||||
break
|
||||
|
||||
if not valid_format:
|
||||
error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头"
|
||||
logger.error(f"[股票代码格式错误] {error_msg}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.debug(f"获取A股数据: {stock_code}")
|
||||
df = ak.stock_zh_a_hist(
|
||||
symbol=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
adjust="qfq"
|
||||
)
|
||||
|
||||
elif market_type in ['HK']:
|
||||
logger.debug(f"获取港股数据: {stock_code}")
|
||||
df = ak.stock_hk_daily(
|
||||
symbol=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
adjust="qfq"
|
||||
)
|
||||
|
||||
elif market_type in ['US']:
|
||||
logger.debug(f"获取美股数据: {stock_code}")
|
||||
df = ak.stock_us_daily(
|
||||
symbol=stock_code,
|
||||
adjust="qfq"
|
||||
)
|
||||
# 过滤日期
|
||||
df = df[(df.index >= start_date) & (df.index <= end_date)]
|
||||
|
||||
elif market_type in ['ETF', 'LOF']:
|
||||
logger.debug(f"获取{market_type}基金数据: {stock_code}")
|
||||
df = ak.fund_etf_hist_sina(
|
||||
symbol=stock_code,
|
||||
start_date=start_date.replace('-', ''),
|
||||
end_date=end_date.replace('-', '')
|
||||
)
|
||||
|
||||
else:
|
||||
error_msg = f"不支持的市场类型: {market_type}"
|
||||
logger.error(f"[市场类型错误] {error_msg}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# 标准化列名
|
||||
if market_type == 'A':
|
||||
# 根据实际数据结构调整列名映射
|
||||
# 实际数据列:['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']
|
||||
df.columns = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover']
|
||||
elif market_type in ['HK', 'US']:
|
||||
# 根据实际情况调整
|
||||
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
|
||||
elif market_type in ['ETF', 'LOF']:
|
||||
# 基金数据可能有不同的列
|
||||
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
|
||||
|
||||
# 确保日期列是日期类型
|
||||
if 'Date' in df.columns:
|
||||
df['Date'] = pd.to_datetime(df['Date'])
|
||||
df.set_index('Date', inplace=True)
|
||||
|
||||
# 确保按日期升序排序
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
logger.info(f"成功获取{market_type}数据 {stock_code}, 数据点数: {len(df)}")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取{market_type}数据失败 {stock_code}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def get_multiple_stocks_data(self, stock_codes: List[str],
|
||||
market_type: str = 'A',
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_concurrency: int = 5) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
异步批量获取多只股票数据
|
||||
|
||||
Args:
|
||||
stock_codes: 股票代码列表
|
||||
market_type: 市场类型,默认为'A'股
|
||||
start_date: 开始日期,格式YYYYMMDD
|
||||
end_date: 结束日期,格式YYYYMMDD
|
||||
max_concurrency: 最大并发数,默认为5
|
||||
|
||||
Returns:
|
||||
字典,键为股票代码,值为对应的DataFrame
|
||||
"""
|
||||
# 使用信号量控制并发数
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def get_with_semaphore(code):
|
||||
async with semaphore:
|
||||
try:
|
||||
return code, await self.get_stock_data(code, market_type, start_date, end_date)
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票 {code} 数据时出错: {str(e)}")
|
||||
return code, None
|
||||
|
||||
# 创建异步任务
|
||||
tasks = [get_with_semaphore(code) for code in stock_codes]
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# 构建结果字典,过滤掉失败的请求
|
||||
return {code: df for code, df in results if df is not None}
|
||||
128
services/stock_scorer.py
Normal file
128
services/stock_scorer.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, Optional, Any, List, Tuple
|
||||
from logger import get_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class StockScorer:
|
||||
"""
|
||||
股票评分服务
|
||||
负责根据技术指标计算股票的综合评分
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化股票评分服务"""
|
||||
logger.debug("初始化StockScorer")
|
||||
|
||||
def calculate_score(self, df: pd.DataFrame) -> int:
|
||||
"""
|
||||
计算股票评分(满分100分)
|
||||
|
||||
Args:
|
||||
df: 包含技术指标的DataFrame
|
||||
|
||||
Returns:
|
||||
股票评分(0-100的整数)
|
||||
"""
|
||||
try:
|
||||
# 使用最新的数据点进行评分
|
||||
latest = df.iloc[-1]
|
||||
|
||||
# 初始得分为0
|
||||
score = 0
|
||||
|
||||
# 移动平均线评分(25分)
|
||||
if latest['MA5'] > latest['MA20'] > latest['MA60']:
|
||||
# 短期、中期和长期均线呈多头排列
|
||||
score += 25
|
||||
elif latest['MA5'] > latest['MA20']:
|
||||
# 短期均线在中期均线之上
|
||||
score += 15
|
||||
elif latest['Close'] > latest['MA20']:
|
||||
# 股价在中期均线之上
|
||||
score += 10
|
||||
|
||||
# RSI评分(25分)
|
||||
rsi = latest['RSI']
|
||||
if 45 <= rsi <= 55:
|
||||
# RSI在中间区域,可能即将爆发
|
||||
score += 15
|
||||
elif 55 < rsi < 70:
|
||||
# RSI在强势区域但未超买
|
||||
score += 25
|
||||
elif 30 < rsi < 45:
|
||||
# RSI在弱势区域但未超卖
|
||||
score += 10
|
||||
elif rsi >= 70:
|
||||
# RSI超买
|
||||
score += 5
|
||||
elif rsi <= 30:
|
||||
# RSI超卖
|
||||
score += 15
|
||||
|
||||
# MACD得分(20分)
|
||||
if latest['MACD'] > latest['Signal']:
|
||||
score += 20
|
||||
|
||||
# 成交量得分(30分)
|
||||
if latest['Volume_Ratio'] > 1.5:
|
||||
score += 30
|
||||
elif latest['Volume_Ratio'] > 1:
|
||||
score += 15
|
||||
|
||||
return score
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算评分时出错: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise
|
||||
|
||||
def get_recommendation(self, score: int) -> str:
|
||||
"""
|
||||
根据评分获取投资建议
|
||||
|
||||
Args:
|
||||
score: 股票评分(0-100)
|
||||
|
||||
Returns:
|
||||
投资建议文本
|
||||
"""
|
||||
if score >= 80:
|
||||
return "强烈推荐"
|
||||
elif score >= 70:
|
||||
return "推荐"
|
||||
elif score >= 60:
|
||||
return "谨慎推荐"
|
||||
elif score >= 40:
|
||||
return "观望"
|
||||
elif score >= 20:
|
||||
return "不推荐"
|
||||
else:
|
||||
return "强烈不推荐"
|
||||
|
||||
def batch_score_stocks(self, stock_dfs: Dict[str, pd.DataFrame]) -> List[Tuple[str, int, str]]:
|
||||
"""
|
||||
批量评分多只股票
|
||||
|
||||
Args:
|
||||
stock_dfs: 字典,键为股票代码,值为DataFrame
|
||||
|
||||
Returns:
|
||||
评分结果列表,每项为(股票代码, 评分, 推荐)的三元组
|
||||
"""
|
||||
results = []
|
||||
|
||||
for stock_code, df in stock_dfs.items():
|
||||
try:
|
||||
score = self.calculate_score(df)
|
||||
recommendation = self.get_recommendation(score)
|
||||
results.append((stock_code, score, recommendation))
|
||||
except Exception as e:
|
||||
logger.error(f"评分股票 {stock_code} 时出错: {str(e)}")
|
||||
|
||||
# 按评分降序排序
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return results
|
||||
187
services/technical_indicator.py
Normal file
187
services/technical_indicator.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, Optional, Any
|
||||
from logger import get_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class TechnicalIndicator:
|
||||
"""
|
||||
技术指标计算服务
|
||||
负责计算常见的股票技术指标
|
||||
"""
|
||||
|
||||
def __init__(self, params: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
初始化技术指标计算服务
|
||||
|
||||
Args:
|
||||
params: 技术指标参数配置
|
||||
"""
|
||||
# 默认参数设置
|
||||
self.params = params or {
|
||||
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
|
||||
'rsi_period': 14,
|
||||
'bollinger_period': 20,
|
||||
'bollinger_std': 2,
|
||||
'volume_ma_period': 20,
|
||||
'atr_period': 14
|
||||
}
|
||||
|
||||
logger.debug(f"初始化TechnicalIndicator,参数: {self.params}")
|
||||
|
||||
def calculate_ema(self, series: pd.Series, period: int) -> pd.Series:
|
||||
"""
|
||||
计算指数移动平均线
|
||||
|
||||
Args:
|
||||
series: 价格序列
|
||||
period: 周期
|
||||
|
||||
Returns:
|
||||
EMA序列
|
||||
"""
|
||||
return series.ewm(span=period, adjust=False).mean()
|
||||
|
||||
def calculate_rsi(self, series: pd.Series, period: int) -> pd.Series:
|
||||
"""
|
||||
计算相对强弱指标(RSI)
|
||||
|
||||
Args:
|
||||
series: 价格序列
|
||||
period: 周期
|
||||
|
||||
Returns:
|
||||
RSI序列
|
||||
"""
|
||||
delta = series.diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = -delta.where(delta < 0, 0)
|
||||
|
||||
avg_gain = gain.rolling(window=period).mean()
|
||||
avg_loss = loss.rolling(window=period).mean()
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
|
||||
return rsi
|
||||
|
||||
def calculate_macd(self, series: pd.Series) -> tuple:
|
||||
"""
|
||||
计算MACD指标
|
||||
|
||||
Args:
|
||||
series: 价格序列
|
||||
|
||||
Returns:
|
||||
(MACD线, 信号线, 柱状图)的元组
|
||||
"""
|
||||
ema12 = self.calculate_ema(series, 12)
|
||||
ema26 = self.calculate_ema(series, 26)
|
||||
|
||||
macd = ema12 - ema26
|
||||
signal = self.calculate_ema(macd, 9)
|
||||
histogram = macd - signal
|
||||
|
||||
return macd, signal, histogram
|
||||
|
||||
def calculate_bollinger_bands(self, series: pd.Series, period: int, std_dev: float) -> tuple:
|
||||
"""
|
||||
计算布林带
|
||||
|
||||
Args:
|
||||
series: 价格序列
|
||||
period: 周期
|
||||
std_dev: 标准差倍数
|
||||
|
||||
Returns:
|
||||
(中轨, 上轨, 下轨)的元组
|
||||
"""
|
||||
middle = series.rolling(window=period).mean()
|
||||
std = series.rolling(window=period).std()
|
||||
|
||||
upper = middle + std_dev * std
|
||||
lower = middle - std_dev * std
|
||||
|
||||
return middle, upper, lower
|
||||
|
||||
def calculate_atr(self, df: pd.DataFrame, period: int) -> pd.Series:
|
||||
"""
|
||||
计算平均真实波幅(ATR)
|
||||
|
||||
Args:
|
||||
df: 包含High, Low, Close列的DataFrame
|
||||
period: 周期
|
||||
|
||||
Returns:
|
||||
ATR序列
|
||||
"""
|
||||
high = df['High']
|
||||
low = df['Low']
|
||||
close = df['Close']
|
||||
|
||||
tr1 = high - low
|
||||
tr2 = abs(high - close.shift())
|
||||
tr3 = abs(low - close.shift())
|
||||
|
||||
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
|
||||
atr = tr.rolling(window=period).mean()
|
||||
|
||||
return atr
|
||||
|
||||
def calculate_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算所有技术指标
|
||||
|
||||
Args:
|
||||
df: 原始价格数据,包含Open, High, Low, Close, Volume列
|
||||
|
||||
Returns:
|
||||
添加了技术指标的DataFrame
|
||||
"""
|
||||
try:
|
||||
# 复制数据框
|
||||
result_df = df.copy()
|
||||
|
||||
# 移动平均线
|
||||
for name, period in self.params['ma_periods'].items():
|
||||
result_df[f'MA{period}'] = result_df['Close'].rolling(window=period).mean()
|
||||
|
||||
# RSI
|
||||
result_df['RSI'] = self.calculate_rsi(result_df['Close'], self.params['rsi_period'])
|
||||
|
||||
# MACD
|
||||
macd, signal, histogram = self.calculate_macd(result_df['Close'])
|
||||
result_df['MACD'] = macd
|
||||
result_df['Signal'] = signal
|
||||
result_df['Histogram'] = histogram
|
||||
|
||||
# 布林带
|
||||
middle, upper, lower = self.calculate_bollinger_bands(
|
||||
result_df['Close'],
|
||||
self.params['bollinger_period'],
|
||||
self.params['bollinger_std']
|
||||
)
|
||||
result_df['BB_Middle'] = middle
|
||||
result_df['BB_Upper'] = upper
|
||||
result_df['BB_Lower'] = lower
|
||||
|
||||
# 成交量移动平均
|
||||
result_df['Volume_MA'] = result_df['Volume'].rolling(window=self.params['volume_ma_period']).mean()
|
||||
|
||||
# 成交量比率
|
||||
result_df['Volume_Ratio'] = result_df['Volume'] / result_df['Volume_MA']
|
||||
|
||||
# ATR
|
||||
result_df['ATR'] = self.calculate_atr(result_df, self.params['atr_period'])
|
||||
|
||||
# 波动率 (过去20天收盘价的标准差/均值)
|
||||
result_df['Volatility'] = result_df['Close'].rolling(window=20).std() / result_df['Close'].rolling(window=20).mean() * 100
|
||||
|
||||
return result_df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算技术指标时出错: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise
|
||||
151
services/us_stock_service_async.py
Normal file
151
services/us_stock_service_async.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any, Optional
|
||||
from logger import get_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class USStockServiceAsync:
|
||||
"""
|
||||
异步美股服务
|
||||
提供美股数据的异步搜索和获取功能
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化异步美股服务"""
|
||||
logger.debug("初始化USStockServiceAsync")
|
||||
|
||||
# 可选:添加缓存以减少频繁请求
|
||||
self._cache = None
|
||||
self._cache_timestamp = None
|
||||
|
||||
async def search_us_stocks(self, keyword: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
异步搜索美股代码
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
|
||||
Returns:
|
||||
匹配的股票列表
|
||||
"""
|
||||
try:
|
||||
logger.info(f"异步搜索美股: {keyword}")
|
||||
|
||||
# 使用线程池执行同步的akshare调用
|
||||
df = await asyncio.to_thread(self._get_us_stocks_data)
|
||||
|
||||
# 模糊匹配搜索
|
||||
mask = df['name'].str.contains(keyword, case=False, na=False)
|
||||
results = df[mask]
|
||||
|
||||
# 格式化返回结果并处理 NaN 值
|
||||
formatted_results = []
|
||||
for _, row in results.iterrows():
|
||||
formatted_results.append({
|
||||
'name': row['name'] if pd.notna(row['name']) else '',
|
||||
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
|
||||
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
|
||||
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0
|
||||
})
|
||||
|
||||
logger.info(f"美股搜索完成,找到 {len(formatted_results)} 个匹配项")
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"搜索美股代码失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
|
||||
def _get_us_stocks_data(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取美股数据(同步方法,将被异步方法调用)
|
||||
|
||||
Returns:
|
||||
包含美股数据的DataFrame
|
||||
"""
|
||||
import akshare as ak
|
||||
|
||||
try:
|
||||
# 获取美股数据
|
||||
df = ak.stock_us_spot_em()
|
||||
|
||||
# 转换列名
|
||||
df = df.rename(columns={
|
||||
"序号": "index",
|
||||
"名称": "name",
|
||||
"最新价": "price",
|
||||
"涨跌额": "price_change",
|
||||
"涨跌幅": "price_change_percent",
|
||||
"开盘价": "open",
|
||||
"最高价": "high",
|
||||
"最低价": "low",
|
||||
"昨收价": "pre_close",
|
||||
"总市值": "market_value",
|
||||
"市盈率": "pe_ratio",
|
||||
"成交量": "volume",
|
||||
"成交额": "turnover",
|
||||
"振幅": "amplitude",
|
||||
"换手率": "turnover_rate",
|
||||
"代码": "symbol"
|
||||
})
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股数据失败: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise Exception(f"获取美股数据失败: {str(e)}")
|
||||
|
||||
async def get_us_stock_detail(self, symbol: str) -> Dict[str, Any]:
|
||||
"""
|
||||
异步获取单个美股详细信息
|
||||
|
||||
Args:
|
||||
symbol: 股票代码
|
||||
|
||||
Returns:
|
||||
股票详细信息
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取美股详情: {symbol}")
|
||||
|
||||
# 使用线程池执行同步的akshare调用
|
||||
df = await asyncio.to_thread(self._get_us_stocks_data)
|
||||
|
||||
# 精确匹配股票代码
|
||||
result = df[df['symbol'] == symbol]
|
||||
|
||||
if len(result) == 0:
|
||||
raise Exception(f"未找到股票代码: {symbol}")
|
||||
|
||||
# 获取第一行数据
|
||||
row = result.iloc[0]
|
||||
|
||||
# 格式化为字典
|
||||
stock_detail = {
|
||||
'name': row['name'] if pd.notna(row['name']) else '',
|
||||
'symbol': str(row['symbol']) if pd.notna(row['symbol']) else '',
|
||||
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
|
||||
'price_change': float(row['price_change']) if pd.notna(row['price_change']) else 0.0,
|
||||
'price_change_percent': float(row['price_change_percent'].strip('%'))/100 if pd.notna(row['price_change_percent']) else 0.0,
|
||||
'open': float(row['open']) if pd.notna(row['open']) else 0.0,
|
||||
'high': float(row['high']) if pd.notna(row['high']) else 0.0,
|
||||
'low': float(row['low']) if pd.notna(row['low']) else 0.0,
|
||||
'pre_close': float(row['pre_close']) if pd.notna(row['pre_close']) else 0.0,
|
||||
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
|
||||
'pe_ratio': float(row['pe_ratio']) if pd.notna(row['pe_ratio']) else 0.0,
|
||||
'volume': float(row['volume']) if pd.notna(row['volume']) else 0.0,
|
||||
'turnover': float(row['turnover']) if pd.notna(row['turnover']) else 0.0
|
||||
}
|
||||
|
||||
logger.info(f"获取美股详情成功: {symbol}")
|
||||
return stock_detail
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取美股详情失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
Reference in New Issue
Block a user