346 lines
12 KiB
Python
346 lines
12 KiB
Python
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
import os
|
||
import requests
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dotenv import load_dotenv
|
||
import logging
|
||
|
||
class StockAnalyzer:
|
||
def __init__(self, initial_cash=1000000):
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO,
|
||
format='%(asctime)s - %(levelname)s - %(message)s')
|
||
self.logger = logging.getLogger(__name__)
|
||
|
||
# 加载环境变量
|
||
load_dotenv()
|
||
|
||
# 设置 Gemini API
|
||
self.gemini_api_url = os.getenv('GEMINI_API_URL')
|
||
self.gemini_api_key = os.getenv('GEMINI_API_KEY')
|
||
|
||
# 配置参数
|
||
self.params = {
|
||
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
|
||
'rsi_period': 14,
|
||
'bollinger_period': 20,
|
||
'bollinger_std': 2,
|
||
'volume_ma_period': 20,
|
||
'atr_period': 14
|
||
}
|
||
|
||
# 添加市场类型枚举
|
||
self.MARKET_TYPES = {
|
||
'A': 'A股',
|
||
'HK': '港股',
|
||
'CRYPTO': '加密货币'
|
||
}
|
||
|
||
def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
|
||
"""获取股票数据"""
|
||
import akshare as ak
|
||
|
||
if start_date is None:
|
||
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
|
||
if end_date is None:
|
||
end_date = datetime.now().strftime('%Y%m%d')
|
||
|
||
try:
|
||
# 根据市场类型获取数据
|
||
if market_type == 'A':
|
||
df = ak.stock_zh_a_hist(
|
||
symbol=stock_code,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
adjust="qfq"
|
||
)
|
||
# A股数据列名映射
|
||
elif market_type == 'HK':
|
||
df = ak.stock_hk_daily(
|
||
symbol=stock_code,
|
||
adjust="qfq"
|
||
)
|
||
elif market_type == 'US':
|
||
df = ak.stock_us_hist(
|
||
symbol=stock_code,
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
adjust="qfq"
|
||
)
|
||
elif market_type == 'CRYPTO':
|
||
df = ak.crypto_js_spot(
|
||
symbol=stock_code
|
||
)
|
||
else:
|
||
raise ValueError(f"不支持的市场类型: {market_type}")
|
||
|
||
# 重命名列名以匹配分析需求
|
||
df = df.rename(columns={
|
||
"日期": "date",
|
||
"开盘": "open",
|
||
"收盘": "close",
|
||
"最高": "high",
|
||
"最低": "low",
|
||
"成交量": "volume"
|
||
})
|
||
|
||
# 确保日期格式正确
|
||
df['date'] = pd.to_datetime(df['date'])
|
||
|
||
# 数据类型转换
|
||
numeric_columns = ['open', 'close', 'high', 'low', 'volume']
|
||
df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
|
||
|
||
# 删除空值
|
||
df = df.dropna()
|
||
|
||
return df.sort_values('date')
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"获取股票数据失败: {str(e)}")
|
||
raise Exception(f"获取股票数据失败: {str(e)}")
|
||
|
||
def calculate_ema(self, series, period):
|
||
"""计算指数移动平均线"""
|
||
return series.ewm(span=period, adjust=False).mean()
|
||
|
||
def calculate_rsi(self, series, period):
|
||
"""计算RSI指标"""
|
||
delta = series.diff()
|
||
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
||
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
||
rs = gain / loss
|
||
return 100 - (100 / (1 + rs))
|
||
|
||
def calculate_macd(self, series):
|
||
"""计算MACD指标"""
|
||
exp1 = series.ewm(span=12, adjust=False).mean()
|
||
exp2 = series.ewm(span=26, adjust=False).mean()
|
||
macd = exp1 - exp2
|
||
signal = macd.ewm(span=9, adjust=False).mean()
|
||
hist = macd - signal
|
||
return macd, signal, hist
|
||
|
||
def calculate_bollinger_bands(self, series, period, std_dev):
|
||
"""计算布林带"""
|
||
middle = series.rolling(window=period).mean()
|
||
std = series.rolling(window=period).std()
|
||
upper = middle + (std * std_dev)
|
||
lower = middle - (std * std_dev)
|
||
return upper, middle, lower
|
||
|
||
def calculate_atr(self, df, period):
|
||
"""计算ATR指标"""
|
||
high = df['high']
|
||
low = df['low']
|
||
close = df['close'].shift(1)
|
||
|
||
tr1 = high - low
|
||
tr2 = abs(high - close)
|
||
tr3 = abs(low - close)
|
||
|
||
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
|
||
return tr.rolling(window=period).mean()
|
||
|
||
def calculate_indicators(self, df):
|
||
"""计算技术指标"""
|
||
try:
|
||
# 计算移动平均线
|
||
df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
|
||
df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
|
||
df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
|
||
|
||
# 计算RSI
|
||
df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
|
||
|
||
# 计算MACD
|
||
df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
|
||
|
||
# 计算布林带
|
||
df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
|
||
df['close'],
|
||
self.params['bollinger_period'],
|
||
self.params['bollinger_std']
|
||
)
|
||
|
||
# 成交量分析
|
||
df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
|
||
df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
|
||
|
||
# 计算ATR和波动率
|
||
df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
|
||
df['Volatility'] = df['ATR'] / df['close'] * 100
|
||
|
||
# 动量指标
|
||
df['ROC'] = df['close'].pct_change(periods=10) * 100
|
||
|
||
return df
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"计算技术指标时出错: {str(e)}")
|
||
raise
|
||
|
||
def calculate_score(self, df):
|
||
"""计算股票评分"""
|
||
try:
|
||
score = 0
|
||
latest = df.iloc[-1]
|
||
|
||
# 趋势得分 (30分)
|
||
if latest['MA5'] > latest['MA20']:
|
||
score += 15
|
||
if latest['MA20'] > latest['MA60']:
|
||
score += 15
|
||
|
||
# RSI得分 (20分)
|
||
if 30 <= latest['RSI'] <= 70:
|
||
score += 20
|
||
elif latest['RSI'] < 30: # 超卖
|
||
score += 15
|
||
|
||
# MACD得分 (20分)
|
||
if latest['MACD'] > latest['Signal']:
|
||
score += 20
|
||
|
||
# 成交量得分 (30分)
|
||
if latest['Volume_Ratio'] > 1.5:
|
||
score += 30
|
||
elif latest['Volume_Ratio'] > 1:
|
||
score += 15
|
||
|
||
return score
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"计算评分时出错: {str(e)}")
|
||
raise
|
||
|
||
def get_ai_analysis(self, df, stock_code):
|
||
"""使用 Gemini 进行 AI 分析"""
|
||
try:
|
||
recent_data = df.tail(14).to_dict('records')
|
||
|
||
technical_summary = {
|
||
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
|
||
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
|
||
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
|
||
'rsi_level': df.iloc[-1]['RSI']
|
||
}
|
||
|
||
prompt = f"""
|
||
分析股票 {stock_code}:
|
||
|
||
技术指标概要:
|
||
{technical_summary}
|
||
|
||
近14日交易数据:
|
||
{recent_data}
|
||
|
||
请提供:
|
||
1. 趋势分析(包含支撑位和压力位)
|
||
2. 成交量分析及其含义
|
||
3. 风险评估(包含波动率分析)
|
||
4. 短期和中期目标价位
|
||
5. 关键技术位分析
|
||
6. 具体交易建议(包含止损位)
|
||
|
||
请基于技术指标和市场动态进行分析,给出具体数据支持。
|
||
"""
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {self.gemini_api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
data = {
|
||
"model": os.getenv('GEMINI_API_MODEL'),
|
||
"messages": [{"role": "user", "content": prompt}]
|
||
}
|
||
|
||
response = requests.post(
|
||
f"{self.gemini_api_url}/v1/chat/completions",
|
||
headers=headers,
|
||
json=data,
|
||
timeout=30
|
||
)
|
||
print(headers)
|
||
print(data)
|
||
print(response.json())
|
||
|
||
if response.status_code == 200:
|
||
return response.json()['choices'][0]['message']['content']
|
||
else:
|
||
return "AI 分析暂时无法使用"
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"AI 分析发生错误: {str(e)}")
|
||
return "AI 分析过程中发生错误"
|
||
|
||
def get_recommendation(self, score):
|
||
"""根据得分给出建议"""
|
||
if score >= 80:
|
||
return '强烈推荐买入'
|
||
elif score >= 60:
|
||
return '建议买入'
|
||
elif score >= 40:
|
||
return '观望'
|
||
elif score >= 20:
|
||
return '建议卖出'
|
||
else:
|
||
return '强烈建议卖出'
|
||
|
||
def analyze_stock(self, stock_code, market_type='A'):
|
||
"""分析单个股票"""
|
||
try:
|
||
# 获取股票数据
|
||
df = self.get_stock_data(stock_code, market_type)
|
||
|
||
# 计算技术指标
|
||
df = self.calculate_indicators(df)
|
||
|
||
# 评分系统
|
||
score = self.calculate_score(df)
|
||
|
||
# 获取最新数据
|
||
latest = df.iloc[-1]
|
||
prev = df.iloc[-2]
|
||
|
||
# 生成报告(保持原有格式)
|
||
report = {
|
||
'stock_code': stock_code,
|
||
'analysis_date': datetime.now().strftime('%Y-%m-%d'),
|
||
'score': score,
|
||
'price': latest['close'],
|
||
'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
|
||
'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
|
||
'rsi': latest['RSI'],
|
||
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
|
||
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
|
||
'recommendation': self.get_recommendation(score),
|
||
'ai_analysis': self.get_ai_analysis(df, stock_code)
|
||
}
|
||
|
||
return report
|
||
|
||
except Exception as e:
|
||
self.logger.error(f"分析股票时出错: {str(e)}")
|
||
raise
|
||
|
||
def scan_market(self, stock_list, min_score=60, market_type='A'):
|
||
"""扫描市场,寻找符合条件的股票"""
|
||
recommendations = []
|
||
|
||
for stock_code in stock_list:
|
||
try:
|
||
report = self.analyze_stock(stock_code, market_type)
|
||
if report['score'] >= min_score:
|
||
recommendations.append(report)
|
||
except Exception as e:
|
||
self.logger.error(f"分析股票 {stock_code} 时出错: {str(e)}")
|
||
continue
|
||
|
||
# 按得分排序
|
||
recommendations.sort(key=lambda x: x['score'], reverse=True)
|
||
return recommendations
|