Files
stock-scanner/services/stock_data_provider.py
2025-03-06 17:11:15 +08:00

184 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import asyncio
import os
from typing import Dict, List, Optional, Tuple, Any
from logger import get_logger
# 获取日志器
logger = get_logger()
class StockDataProvider:
"""
异步股票数据提供服务
负责获取股票、基金等金融产品的历史数据
"""
def __init__(self):
"""初始化数据提供者服务"""
logger.debug("初始化StockDataProvider")
async def get_stock_data(self, stock_code: str, market_type: str = 'A',
start_date: Optional[str] = None,
end_date: Optional[str] = None) -> pd.DataFrame:
"""
异步获取股票或基金数据
Args:
stock_code: 股票代码
market_type: 市场类型,默认为'A'
start_date: 开始日期格式YYYYMMDD默认为一年前
end_date: 结束日期格式YYYYMMDD默认为今天
Returns:
包含历史数据的DataFrame
"""
# 使用线程池执行同步的akshare调用
return await asyncio.to_thread(
self._get_stock_data_sync,
stock_code,
market_type,
start_date,
end_date
)
def _get_stock_data_sync(self, stock_code: str, market_type: str = 'A',
start_date: Optional[str] = None,
end_date: Optional[str] = None) -> pd.DataFrame:
"""
同步获取股票数据的实现
将被异步方法调用
"""
import akshare as ak
if start_date is None:
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
# 验证股票代码格式
if market_type == 'A':
# 上海证券交易所股票代码以6开头
# 深圳证券交易所股票代码以0或3开头
# 科创板股票代码以688开头
# 北京证券交易所股票代码以8开头
valid_prefixes = ['0', '3', '6', '688', '8']
valid_format = False
for prefix in valid_prefixes:
if stock_code.startswith(prefix):
valid_format = True
break
if not valid_format:
error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头"
logger.error(f"[股票代码格式错误] {error_msg}")
raise ValueError(error_msg)
logger.debug(f"获取A股数据: {stock_code}")
df = ak.stock_zh_a_hist(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
elif market_type in ['HK']:
logger.debug(f"获取港股数据: {stock_code}")
df = ak.stock_hk_daily(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
elif market_type in ['US']:
logger.debug(f"获取美股数据: {stock_code}")
df = ak.stock_us_daily(
symbol=stock_code,
adjust="qfq"
)
# 过滤日期
df = df[(df.index >= start_date) & (df.index <= end_date)]
elif market_type in ['ETF', 'LOF']:
logger.debug(f"获取{market_type}基金数据: {stock_code}")
df = ak.fund_etf_hist_sina(
symbol=stock_code,
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', '')
)
else:
error_msg = f"不支持的市场类型: {market_type}"
logger.error(f"[市场类型错误] {error_msg}")
raise ValueError(error_msg)
# 标准化列名
if market_type == 'A':
# 根据实际数据结构调整列名映射
# 实际数据列:['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']
df.columns = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover']
elif market_type in ['HK', 'US']:
# 根据实际情况调整
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
elif market_type in ['ETF', 'LOF']:
# 基金数据可能有不同的列
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
# 确保日期列是日期类型
if 'Date' in df.columns:
df['Date'] = pd.to_datetime(df['Date'])
df.set_index('Date', inplace=True)
# 确保按日期升序排序
df.sort_index(inplace=True)
logger.info(f"成功获取{market_type}数据 {stock_code}, 数据点数: {len(df)}")
return df
except Exception as e:
error_msg = f"获取{market_type}数据失败 {stock_code}: {str(e)}"
logger.error(error_msg)
logger.exception(e)
raise Exception(error_msg)
async def get_multiple_stocks_data(self, stock_codes: List[str],
market_type: str = 'A',
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_concurrency: int = 5) -> Dict[str, pd.DataFrame]:
"""
异步批量获取多只股票数据
Args:
stock_codes: 股票代码列表
market_type: 市场类型,默认为'A'
start_date: 开始日期格式YYYYMMDD
end_date: 结束日期格式YYYYMMDD
max_concurrency: 最大并发数默认为5
Returns:
字典键为股票代码值为对应的DataFrame
"""
# 使用信号量控制并发数
semaphore = asyncio.Semaphore(max_concurrency)
async def get_with_semaphore(code):
async with semaphore:
try:
return code, await self.get_stock_data(code, market_type, start_date, end_date)
except Exception as e:
logger.error(f"获取股票 {code} 数据时出错: {str(e)}")
return code, None
# 创建异步任务
tasks = [get_with_semaphore(code) for code in stock_codes]
# 等待所有任务完成
results = await asyncio.gather(*tasks)
# 构建结果字典,过滤掉失败的请求
return {code: df for code, df in results if df is not None}