feat: 支持自定义API

This commit is contained in:
Cassianvale
2025-03-04 13:01:38 +08:00
parent 5ccab7ab43
commit 17ed403c3e
3 changed files with 1001 additions and 799 deletions

View File

@@ -1,343 +1,344 @@
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import os
import requests
from typing import Dict, List, Optional, Tuple
from dotenv import load_dotenv
class StockAnalyzer:
def __init__(self, initial_cash=1000000):
# 加载环境变量
load_dotenv()
# 设置 Gemini API
self.API_URL = os.getenv('API_URL')
self.API_KEY = os.getenv('API_KEY')
# 配置参数
self.params = {
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
'rsi_period': 14,
'bollinger_period': 20,
'bollinger_std': 2,
'volume_ma_period': 20,
'atr_period': 14
}
def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
"""获取股票数据"""
import akshare as ak
if start_date is None:
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
# 根据市场类型获取数据
if market_type == 'A':
df = ak.stock_zh_a_hist(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
# A股数据列名映射
elif market_type == 'HK':
df = ak.stock_hk_daily(
symbol=stock_code,
adjust="qfq"
)
elif market_type == 'US':
df = ak.stock_us_hist(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
# elif market_type == 'CRYPTO':
# df = ak.crypto_js_spot(
# symbol=stock_code
# )
else:
raise ValueError(f"不支持的市场类型: {market_type}")
# 重命名列名以匹配分析需求
df = df.rename(columns={
"日期": "date",
"开盘": "open",
"": "close",
"最高": "high",
"": "low",
"成交量": "volume"
})
# 确保日期格式正确
df['date'] = pd.to_datetime(df['date'])
# 数据类型转换
numeric_columns = ['open', 'close', 'high', 'low', 'volume']
df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
# 删除空值
df = df.dropna()
return df.sort_values('date')
except Exception as e:
raise Exception(f"获取股票数据失败: {str(e)}")
def calculate_ema(self, series, period):
"""计算指数移动平均线"""
return series.ewm(span=period, adjust=False).mean()
def calculate_rsi(self, series, period):
"""计算RSI指标"""
delta = series.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
return 100 - (100 / (1 + rs))
def calculate_macd(self, series):
"""计算MACD指标"""
exp1 = series.ewm(span=12, adjust=False).mean()
exp2 = series.ewm(span=26, adjust=False).mean()
macd = exp1 - exp2
signal = macd.ewm(span=9, adjust=False).mean()
hist = macd - signal
return macd, signal, hist
def calculate_bollinger_bands(self, series, period, std_dev):
"""计算布林带"""
middle = series.rolling(window=period).mean()
std = series.rolling(window=period).std()
upper = middle + (std * std_dev)
lower = middle - (std * std_dev)
return upper, middle, lower
def calculate_atr(self, df, period):
"""计算ATR指标"""
high = df['high']
low = df['low']
close = df['close'].shift(1)
tr1 = high - low
tr2 = abs(high - close)
tr3 = abs(low - close)
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
return tr.rolling(window=period).mean()
def calculate_indicators(self, df):
"""计算技术指标"""
try:
# 计算移动平均线
df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
# 计算RSI
df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
# 计算MACD
df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
# 计算布林带
df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
df['close'],
self.params['bollinger_period'],
self.params['bollinger_std']
)
# 成交量分析
df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
# 计算ATR和波动率
df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
df['Volatility'] = df['ATR'] / df['close'] * 100
# 动量指标
df['ROC'] = df['close'].pct_change(periods=10) * 100
return df
except Exception as e:
print(f"计算技术指标时出错: {str(e)}")
raise
def calculate_score(self, df):
"""计算股票评分"""
try:
score = 0
latest = df.iloc[-1]
# 趋势得分 (30分)
if latest['MA5'] > latest['MA20']:
score += 15
if latest['MA20'] > latest['MA60']:
score += 15
# RSI得分 (20分)
if 30 <= latest['RSI'] <= 70:
score += 20
elif latest['RSI'] < 30: # 超卖
score += 15
# MACD得分 (20分)
if latest['MACD'] > latest['Signal']:
score += 20
# 成交量得分 (30分)
if latest['Volume_Ratio'] > 1.5:
score += 30
elif latest['Volume_Ratio'] > 1:
score += 15
return score
except Exception as e:
print(f"计算评分时出错: {str(e)}")
raise
def get_ai_analysis(self, df, stock_code):
"""使用 OpenAI 进行 AI 分析"""
try:
recent_data = df.tail(14).to_dict('records')
technical_summary = {
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
'rsi_level': df.iloc[-1]['RSI']
}
prompt = f"""
分析股票 {stock_code}
技术指标概要:
{technical_summary}
近14日交易数据
{recent_data}
请提供:
1. 趋势分析(包含支撑位和压力位)
2. 成交量分析及其含义
3. 风险评估(包含波动率分析)
4. 短期和中期目标价位
5. 关键技术位分析
6. 具体交易建议(包含止损位)
请基于技术指标和市场动态进行分析,给出具体数据支持。
"""
# OpenAI API 调用
api_urls = [
f"{self.API_URL}/chat/completions",
f"{self.API_URL}/v1/chat/completions"
]
last_error = None
for api_url in api_urls:
try:
response = requests.post(
api_url,
headers={
"Authorization": f"Bearer {self.API_KEY}",
"Content-Type": "application/json"
},
json={
"model": os.getenv('API_MODEL', 'gpt-3.5-turbo'),
"messages": [{"role": "user", "content": prompt}]
},
timeout=30
)
if response.status_code == 200:
return response.json()['choices'][0]['message']['content']
else:
last_error = f"API 错误: {response.status_code} - {response.text}"
continue
except Exception as e:
last_error = str(e)
continue
print(f"AI 分析暂时无法使用: {last_error}")
return f"AI 分析暂时无法使用: {last_error}"
except Exception as e:
print(f"AI 分析发生错误: {str(e)}")
return f"AI 分析过程中发生错误: {str(e)}"
def get_recommendation(self, score):
"""根据得分给出建议"""
if score >= 80:
return '强烈推荐买入'
elif score >= 60:
return '建议买入'
elif score >= 40:
return '观望'
elif score >= 20:
return '建议卖出'
else:
return '强烈建议卖出'
def analyze_stock(self, stock_code, market_type='A'):
"""分析单个股票"""
try:
# 获取股票数据
df = self.get_stock_data(stock_code, market_type)
# 计算技术指标
df = self.calculate_indicators(df)
# 评分系统
score = self.calculate_score(df)
# 获取最新数据
latest = df.iloc[-1]
prev = df.iloc[-2]
# 生成报告(保持原有格式)
report = {
'stock_code': stock_code,
'analysis_date': datetime.now().strftime('%Y-%m-%d'),
'score': score,
'price': latest['close'],
'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
'rsi': latest['RSI'],
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
'recommendation': self.get_recommendation(score),
'ai_analysis': self.get_ai_analysis(df, stock_code)
}
return report
except Exception as e:
print(f"分析股票时出错: {str(e)}")
raise
def scan_market(self, stock_list, min_score=60, market_type='A'):
"""扫描市场,寻找符合条件的股票"""
recommendations = []
for stock_code in stock_list:
try:
report = self.analyze_stock(stock_code, market_type)
if report['score'] >= min_score:
recommendations.append(report)
except Exception as e:
print(f"分析股票 {stock_code} 时出错: {str(e)}")
continue
# 按得分排序
recommendations.sort(key=lambda x: x['score'], reverse=True)
return recommendations
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import os
import requests
from typing import Dict, List, Optional, Tuple
from dotenv import load_dotenv
class StockAnalyzer:
def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None):
# 加载环境变量
load_dotenv()
# 设置 API 配置,优先使用自定义配置,否则使用环境变量
self.API_URL = custom_api_url or os.getenv('API_URL')
self.API_KEY = custom_api_key or os.getenv('API_KEY')
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
# 配置参数
self.params = {
'ma_periods': {'short': 5, 'medium': 20, 'long': 60},
'rsi_period': 14,
'bollinger_period': 20,
'bollinger_std': 2,
'volume_ma_period': 20,
'atr_period': 14
}
def get_stock_data(self, stock_code, market_type='A', start_date=None, end_date=None, ):
"""获取股票数据"""
import akshare as ak
if start_date is None:
start_date = (datetime.now() - timedelta(days=365)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
# 根据市场类型获取数据
if market_type == 'A':
df = ak.stock_zh_a_hist(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
# A股数据列名映射
elif market_type == 'HK':
df = ak.stock_hk_daily(
symbol=stock_code,
adjust="qfq"
)
elif market_type == 'US':
df = ak.stock_us_hist(
symbol=stock_code,
start_date=start_date,
end_date=end_date,
adjust="qfq"
)
# elif market_type == 'CRYPTO':
# df = ak.crypto_js_spot(
# symbol=stock_code
# )
else:
raise ValueError(f"不支持的市场类型: {market_type}")
# 重命名列名以匹配分析需求
df = df.rename(columns={
"日期": "date",
"": "open",
"收盘": "close",
"": "high",
"最低": "low",
"成交量": "volume"
})
# 确保日期格式正确
df['date'] = pd.to_datetime(df['date'])
# 数据类型转换
numeric_columns = ['open', 'close', 'high', 'low', 'volume']
df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
# 删除空值
df = df.dropna()
return df.sort_values('date')
except Exception as e:
raise Exception(f"获取股票数据失败: {str(e)}")
def calculate_ema(self, series, period):
"""计算指数移动平均线"""
return series.ewm(span=period, adjust=False).mean()
def calculate_rsi(self, series, period):
"""计算RSI指标"""
delta = series.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
return 100 - (100 / (1 + rs))
def calculate_macd(self, series):
"""计算MACD指标"""
exp1 = series.ewm(span=12, adjust=False).mean()
exp2 = series.ewm(span=26, adjust=False).mean()
macd = exp1 - exp2
signal = macd.ewm(span=9, adjust=False).mean()
hist = macd - signal
return macd, signal, hist
def calculate_bollinger_bands(self, series, period, std_dev):
"""计算布林带"""
middle = series.rolling(window=period).mean()
std = series.rolling(window=period).std()
upper = middle + (std * std_dev)
lower = middle - (std * std_dev)
return upper, middle, lower
def calculate_atr(self, df, period):
"""计算ATR指标"""
high = df['high']
low = df['low']
close = df['close'].shift(1)
tr1 = high - low
tr2 = abs(high - close)
tr3 = abs(low - close)
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
return tr.rolling(window=period).mean()
def calculate_indicators(self, df):
"""计算技术指标"""
try:
# 计算移动平均线
df['MA5'] = self.calculate_ema(df['close'], self.params['ma_periods']['short'])
df['MA20'] = self.calculate_ema(df['close'], self.params['ma_periods']['medium'])
df['MA60'] = self.calculate_ema(df['close'], self.params['ma_periods']['long'])
# 计算RSI
df['RSI'] = self.calculate_rsi(df['close'], self.params['rsi_period'])
# 计算MACD
df['MACD'], df['Signal'], df['MACD_hist'] = self.calculate_macd(df['close'])
# 计算布林带
df['BB_upper'], df['BB_middle'], df['BB_lower'] = self.calculate_bollinger_bands(
df['close'],
self.params['bollinger_period'],
self.params['bollinger_std']
)
# 成交量分析
df['Volume_MA'] = df['volume'].rolling(window=self.params['volume_ma_period']).mean()
df['Volume_Ratio'] = df['volume'] / df['Volume_MA']
# 计算ATR和波动率
df['ATR'] = self.calculate_atr(df, self.params['atr_period'])
df['Volatility'] = df['ATR'] / df['close'] * 100
# 动量指标
df['ROC'] = df['close'].pct_change(periods=10) * 100
return df
except Exception as e:
print(f"计算技术指标时出错: {str(e)}")
raise
def calculate_score(self, df):
"""计算股票评分"""
try:
score = 0
latest = df.iloc[-1]
# 趋势得分 (30分)
if latest['MA5'] > latest['MA20']:
score += 15
if latest['MA20'] > latest['MA60']:
score += 15
# RSI得分 (20分)
if 30 <= latest['RSI'] <= 70:
score += 20
elif latest['RSI'] < 30: # 超卖
score += 15
# MACD得分 (20分)
if latest['MACD'] > latest['Signal']:
score += 20
# 成交量得分 (30分)
if latest['Volume_Ratio'] > 1.5:
score += 30
elif latest['Volume_Ratio'] > 1:
score += 15
return score
except Exception as e:
print(f"计算评分时出错: {str(e)}")
raise
def get_ai_analysis(self, df, stock_code):
"""使用 OpenAI 进行 AI 分析"""
try:
recent_data = df.tail(14).to_dict('records')
technical_summary = {
'trend': 'upward' if df.iloc[-1]['MA5'] > df.iloc[-1]['MA20'] else 'downward',
'volatility': f"{df.iloc[-1]['Volatility']:.2f}%",
'volume_trend': 'increasing' if df.iloc[-1]['Volume_Ratio'] > 1 else 'decreasing',
'rsi_level': df.iloc[-1]['RSI']
}
prompt = f"""
分析股票 {stock_code}
技术指标概要:
{technical_summary}
近14日交易数据
{recent_data}
请提供:
1. 趋势分析(包含支撑位和压力位)
2. 成交量分析及其含义
3. 风险评估(包含波动率分析)
4. 短期和中期目标价位
5. 关键技术位分析
6. 具体交易建议(包含止损位)
请基于技术指标和市场动态进行分析,给出具体数据支持。
"""
# OpenAI API 调用
api_urls = [
f"{self.API_URL}/chat/completions",
f"{self.API_URL}/v1/chat/completions"
]
last_error = None
for api_url in api_urls:
try:
response = requests.post(
api_url,
headers={
"Authorization": f"Bearer {self.API_KEY}",
"Content-Type": "application/json"
},
json={
"model": self.API_MODEL,
"messages": [{"role": "user", "content": prompt}]
},
timeout=30
)
if response.status_code == 200:
return response.json()['choices'][0]['message']['content']
else:
last_error = f"API 错误: {response.status_code} - {response.text}"
continue
except Exception as e:
last_error = str(e)
continue
print(f"AI 分析暂时无法使用: {last_error}")
return f"AI 分析暂时无法使用: {last_error}"
except Exception as e:
print(f"AI 分析发生错误: {str(e)}")
return f"AI 分析过程中发生错误: {str(e)}"
def get_recommendation(self, score):
"""根据得分给出建议"""
if score >= 80:
return '强烈推荐买入'
elif score >= 60:
return '建议买入'
elif score >= 40:
return '观望'
elif score >= 20:
return '建议卖出'
else:
return '强烈建议卖出'
def analyze_stock(self, stock_code, market_type='A'):
"""分析单个股票"""
try:
# 获取股票数据
df = self.get_stock_data(stock_code, market_type)
# 计算技术指标
df = self.calculate_indicators(df)
# 评分系统
score = self.calculate_score(df)
# 获取最新数据
latest = df.iloc[-1]
prev = df.iloc[-2]
# 生成报告(保持原有格式)
report = {
'stock_code': stock_code,
'analysis_date': datetime.now().strftime('%Y-%m-%d'),
'score': score,
'price': latest['close'],
'price_change': (latest['close'] - prev['close']) / prev['close'] * 100,
'ma_trend': 'UP' if latest['MA5'] > latest['MA20'] else 'DOWN',
'rsi': latest['RSI'],
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
'recommendation': self.get_recommendation(score),
'ai_analysis': self.get_ai_analysis(df, stock_code)
}
return report
except Exception as e:
print(f"分析股票时出错: {str(e)}")
raise
def scan_market(self, stock_list, min_score=60, market_type='A'):
"""扫描市场,寻找符合条件的股票"""
recommendations = []
for stock_code in stock_list:
try:
report = self.analyze_stock(stock_code, market_type)
if report['score'] >= min_score:
recommendations.append(report)
except Exception as e:
print(f"分析股票 {stock_code} 时出错: {str(e)}")
continue
# 按得分排序
recommendations.sort(key=lambda x: x['score'], reverse=True)
return recommendations

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,8 @@ from stock_analyzer import StockAnalyzer
from us_stock_service import USStockService
import threading
import os
import traceback
import requests
app = Flask(__name__)
analyzer = StockAnalyzer()
@@ -11,7 +13,14 @@ us_stock_service = USStockService()
@app.route('/')
def index():
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
return render_template('index.html', announcement=announcement)
# 获取默认API配置信息
default_api_url = os.getenv('API_URL', '')
default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
# 不传递API_KEY到前端出于安全考虑
return render_template('index.html',
announcement=announcement,
default_api_url=default_api_url,
default_api_model=default_api_model)
@app.route('/analyze', methods=['POST'])
def analyze():
@@ -20,13 +29,26 @@ def analyze():
stock_codes = data.get('stock_codes', [])
market_type = data.get('market_type', 'A')
# 获取自定义API配置
custom_api_url = data.get('api_url')
custom_api_key = data.get('api_key')
custom_api_model = data.get('api_model')
# 创建新的分析器实例,使用自定义配置
custom_analyzer = StockAnalyzer(
custom_api_url=custom_api_url,
custom_api_key=custom_api_key,
custom_api_model=custom_api_model
)
if not stock_codes:
return jsonify({'error': '请输入代码'}), 400
results = []
for stock_code in stock_codes:
try:
result = analyzer.analyze_stock(stock_code.strip(), market_type)
# 使用自定义配置的分析器
result = custom_analyzer.analyze_stock(stock_code.strip(), market_type)
results.append(result)
except Exception as e:
print(f"分析股票 {stock_code} 失败: {str(e)}")
@@ -55,8 +77,59 @@ def search_us_stocks():
print(f"搜索美股代码时出错: {str(e)}")
return jsonify({'error': str(e)}), 500
@app.route('/test_api_connection', methods=['POST'])
def test_api_connection():
"""测试API连接"""
try:
data = request.json
api_url = data.get('api_url')
api_key = data.get('api_key')
api_model = data.get('api_model')
if not api_url:
return jsonify({'error': '请提供API URL'}), 400
if not api_key:
return jsonify({'error': '请提供API Key'}), 400
# 构建API URL
test_url = api_url
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
if api_url.endswith('/v1'):
test_url = f"{api_url}/chat/completions"
elif api_url.endswith('/'):
test_url = f"{api_url}chat/completions"
else:
test_url = f"{api_url}/v1/chat/completions"
# 发送测试请求
response = requests.post(
test_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
json={
"model": api_model or "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
],
"max_tokens": 20
},
timeout=10
)
# 检查响应
if response.status_code == 200:
return jsonify({'success': True, 'message': '连接成功'})
else:
error_message = response.json().get('error', {}).get('message', '未知错误')
return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400
except requests.exceptions.RequestException as e:
return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
except Exception as e:
return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500
if __name__ == '__main__':
app.run(host='0.0.0.0', port=8888, debug=True)
app.run(host='0.0.0.0', port=8888, debug=True)