refactor: 重构代码结构
This commit is contained in:
184
services/stock_data_provider.py
Normal file
184
services/stock_data_provider.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from logger import get_logger
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
class StockDataProvider:
|
||||
"""
|
||||
异步股票数据提供服务
|
||||
负责获取股票、基金等金融产品的历史数据
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化数据提供者服务"""
|
||||
logger.debug("初始化StockDataProvider")
|
||||
|
||||
async def get_stock_data(self, stock_code: str, market_type: str = 'A',
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None) -> pd.DataFrame:
|
||||
"""
|
||||
异步获取股票或基金数据
|
||||
|
||||
Args:
|
||||
stock_code: 股票代码
|
||||
market_type: 市场类型,默认为'A'股
|
||||
start_date: 开始日期,格式YYYYMMDD,默认为一年前
|
||||
end_date: 结束日期,格式YYYYMMDD,默认为今天
|
||||
|
||||
Returns:
|
||||
包含历史数据的DataFrame
|
||||
"""
|
||||
# 使用线程池执行同步的akshare调用
|
||||
return await asyncio.to_thread(
|
||||
self._get_stock_data_sync,
|
||||
stock_code,
|
||||
market_type,
|
||||
start_date,
|
||||
end_date
|
||||
)
|
||||
|
||||
def _get_stock_data_sync(self, stock_code: str, market_type: str = 'A',
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None) -> pd.DataFrame:
|
||||
"""
|
||||
同步获取股票数据的实现
|
||||
将被异步方法调用
|
||||
"""
|
||||
import akshare as ak
|
||||
|
||||
if start_date is None:
|
||||
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
try:
|
||||
# 验证股票代码格式
|
||||
if market_type == 'A':
|
||||
# 上海证券交易所股票代码以6开头
|
||||
# 深圳证券交易所股票代码以0或3开头
|
||||
# 科创板股票代码以688开头
|
||||
# 北京证券交易所股票代码以8开头
|
||||
valid_prefixes = ['0', '3', '6', '688', '8']
|
||||
valid_format = False
|
||||
|
||||
for prefix in valid_prefixes:
|
||||
if stock_code.startswith(prefix):
|
||||
valid_format = True
|
||||
break
|
||||
|
||||
if not valid_format:
|
||||
error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头"
|
||||
logger.error(f"[股票代码格式错误] {error_msg}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.debug(f"获取A股数据: {stock_code}")
|
||||
df = ak.stock_zh_a_hist(
|
||||
symbol=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
adjust="qfq"
|
||||
)
|
||||
|
||||
elif market_type in ['HK']:
|
||||
logger.debug(f"获取港股数据: {stock_code}")
|
||||
df = ak.stock_hk_daily(
|
||||
symbol=stock_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
adjust="qfq"
|
||||
)
|
||||
|
||||
elif market_type in ['US']:
|
||||
logger.debug(f"获取美股数据: {stock_code}")
|
||||
df = ak.stock_us_daily(
|
||||
symbol=stock_code,
|
||||
adjust="qfq"
|
||||
)
|
||||
# 过滤日期
|
||||
df = df[(df.index >= start_date) & (df.index <= end_date)]
|
||||
|
||||
elif market_type in ['ETF', 'LOF']:
|
||||
logger.debug(f"获取{market_type}基金数据: {stock_code}")
|
||||
df = ak.fund_etf_hist_sina(
|
||||
symbol=stock_code,
|
||||
start_date=start_date.replace('-', ''),
|
||||
end_date=end_date.replace('-', '')
|
||||
)
|
||||
|
||||
else:
|
||||
error_msg = f"不支持的市场类型: {market_type}"
|
||||
logger.error(f"[市场类型错误] {error_msg}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# 标准化列名
|
||||
if market_type == 'A':
|
||||
# 根据实际数据结构调整列名映射
|
||||
# 实际数据列:['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']
|
||||
df.columns = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover']
|
||||
elif market_type in ['HK', 'US']:
|
||||
# 根据实际情况调整
|
||||
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
|
||||
elif market_type in ['ETF', 'LOF']:
|
||||
# 基金数据可能有不同的列
|
||||
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
|
||||
|
||||
# 确保日期列是日期类型
|
||||
if 'Date' in df.columns:
|
||||
df['Date'] = pd.to_datetime(df['Date'])
|
||||
df.set_index('Date', inplace=True)
|
||||
|
||||
# 确保按日期升序排序
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
logger.info(f"成功获取{market_type}数据 {stock_code}, 数据点数: {len(df)}")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"获取{market_type}数据失败 {stock_code}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
|
||||
async def get_multiple_stocks_data(self, stock_codes: List[str],
|
||||
market_type: str = 'A',
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_concurrency: int = 5) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
异步批量获取多只股票数据
|
||||
|
||||
Args:
|
||||
stock_codes: 股票代码列表
|
||||
market_type: 市场类型,默认为'A'股
|
||||
start_date: 开始日期,格式YYYYMMDD
|
||||
end_date: 结束日期,格式YYYYMMDD
|
||||
max_concurrency: 最大并发数,默认为5
|
||||
|
||||
Returns:
|
||||
字典,键为股票代码,值为对应的DataFrame
|
||||
"""
|
||||
# 使用信号量控制并发数
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def get_with_semaphore(code):
|
||||
async with semaphore:
|
||||
try:
|
||||
return code, await self.get_stock_data(code, market_type, start_date, end_date)
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票 {code} 数据时出错: {str(e)}")
|
||||
return code, None
|
||||
|
||||
# 创建异步任务
|
||||
tasks = [get_with_semaphore(code) for code in stock_codes]
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# 构建结果字典,过滤掉失败的请求
|
||||
return {code: df for code, df in results if df is not None}
|
||||
Reference in New Issue
Block a user