feat: 支持自定义API
This commit is contained in:
@@ -7,14 +7,15 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
class StockAnalyzer:
|
class StockAnalyzer:
|
||||||
def __init__(self, initial_cash=1000000):
|
def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None):
|
||||||
|
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# 设置 Gemini API
|
# 设置 API 配置,优先使用自定义配置,否则使用环境变量
|
||||||
self.API_URL = os.getenv('API_URL')
|
self.API_URL = custom_api_url or os.getenv('API_URL')
|
||||||
self.API_KEY = os.getenv('API_KEY')
|
self.API_KEY = custom_api_key or os.getenv('API_KEY')
|
||||||
|
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||||
|
|
||||||
# 配置参数
|
# 配置参数
|
||||||
self.params = {
|
self.params = {
|
||||||
@@ -252,7 +253,7 @@ class StockAnalyzer:
|
|||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
},
|
},
|
||||||
json={
|
json={
|
||||||
"model": os.getenv('API_MODEL', 'gpt-3.5-turbo'),
|
"model": self.API_MODEL,
|
||||||
"messages": [{"role": "user", "content": prompt}]
|
"messages": [{"role": "user", "content": prompt}]
|
||||||
},
|
},
|
||||||
timeout=30
|
timeout=30
|
||||||
|
|||||||
@@ -30,11 +30,58 @@
|
|||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
<div class="max-w-4xl mx-auto"> <!-- 将 max-w-2xl 改为 max-w-4xl -->
|
<div class="max-w-4xl mx-auto">
|
||||||
<!-- 批量分析 -->
|
<!-- 批量分析 -->
|
||||||
<div class="bg-white p-6 rounded-lg shadow-md">
|
<div class="bg-white p-6 rounded-lg shadow-md">
|
||||||
<h2 class="text-xl font-semibold mb-4">股票批量分析</h2>
|
<h2 class="text-xl font-semibold mb-4">股票批量分析</h2>
|
||||||
|
|
||||||
|
<!-- API配置部分 -->
|
||||||
|
<div class="mb-6 border-b pb-6">
|
||||||
|
<div class="flex items-center justify-between mb-4">
|
||||||
|
<h3 class="text-lg font-medium text-gray-700">API配置</h3>
|
||||||
|
<button id="toggleApiConfig" class="text-blue-600 hover:text-blue-800 text-sm flex items-center">
|
||||||
|
<span id="toggleApiConfigText">显示配置</span>
|
||||||
|
<svg id="toggleApiConfigIcon" class="w-4 h-4 ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||||
|
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7"></path>
|
||||||
|
</svg>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="apiConfigPanel" class="hidden space-y-4">
|
||||||
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||||
|
<div>
|
||||||
|
<label for="apiUrl" class="block text-sm font-medium text-gray-700 mb-1">API URL</label>
|
||||||
|
<input type="text" id="apiUrl"
|
||||||
|
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
||||||
|
placeholder="例如: https://api.openai.com"
|
||||||
|
value="{{ default_api_url }}">
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="apiModel" class="block text-sm font-medium text-gray-700 mb-1">API 模型</label>
|
||||||
|
<input type="text" id="apiModel"
|
||||||
|
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
||||||
|
placeholder="例如: gpt-3.5-turbo"
|
||||||
|
value="{{ default_api_model }}">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label for="apiKey" class="block text-sm font-medium text-gray-700 mb-1">API Key</label>
|
||||||
|
<input type="password" id="apiKey"
|
||||||
|
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
|
||||||
|
placeholder="输入您的API Key">
|
||||||
|
<p class="mt-1 text-sm text-gray-500">如不填写,将使用系统默认配置</p>
|
||||||
|
</div>
|
||||||
|
<div class="flex justify-end">
|
||||||
|
<button id="resetApiConfig" class="text-gray-600 hover:text-gray-800 text-sm mr-3">
|
||||||
|
重置为默认
|
||||||
|
</button>
|
||||||
|
<button id="testApiConfig" class="bg-blue-100 text-blue-700 px-3 py-1 rounded hover:bg-blue-200 text-sm">
|
||||||
|
测试连接
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 添加市场类型选择 -->
|
<!-- 添加市场类型选择 -->
|
||||||
<div class="mb-4">
|
<div class="mb-4">
|
||||||
<label for="marketType" class="block text-sm font-medium text-gray-700 mb-2">
|
<label for="marketType" class="block text-sm font-medium text-gray-700 mb-2">
|
||||||
@@ -290,6 +337,11 @@
|
|||||||
const analyzeBtn = document.getElementById('analyzeBtn');
|
const analyzeBtn = document.getElementById('analyzeBtn');
|
||||||
const loadingSpinner = document.getElementById('loadingSpinner');
|
const loadingSpinner = document.getElementById('loadingSpinner');
|
||||||
|
|
||||||
|
// 获取API配置
|
||||||
|
const apiUrl = document.getElementById('apiUrl').value.trim();
|
||||||
|
const apiKey = document.getElementById('apiKey').value.trim();
|
||||||
|
const apiModel = document.getElementById('apiModel').value.trim();
|
||||||
|
|
||||||
if (!stockInput) {
|
if (!stockInput) {
|
||||||
alert('请输入代码');
|
alert('请输入代码');
|
||||||
return;
|
return;
|
||||||
@@ -312,7 +364,10 @@
|
|||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
stock_codes: stockCodes,
|
stock_codes: stockCodes,
|
||||||
market_type: marketType // 添加市场类型参数
|
market_type: marketType, // 添加市场类型参数
|
||||||
|
api_url: apiUrl,
|
||||||
|
api_key: apiKey,
|
||||||
|
api_model: apiModel
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -447,5 +502,78 @@
|
|||||||
</script>
|
</script>
|
||||||
<!-- 添加 marked.js 用于解析 Markdown -->
|
<!-- 添加 marked.js 用于解析 Markdown -->
|
||||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||||
|
|
||||||
|
<!-- API配置相关脚本 -->
|
||||||
|
<script>
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
// API配置面板切换
|
||||||
|
const toggleBtn = document.getElementById('toggleApiConfig');
|
||||||
|
const configPanel = document.getElementById('apiConfigPanel');
|
||||||
|
const toggleText = document.getElementById('toggleApiConfigText');
|
||||||
|
const toggleIcon = document.getElementById('toggleApiConfigIcon');
|
||||||
|
|
||||||
|
toggleBtn.addEventListener('click', function() {
|
||||||
|
const isHidden = configPanel.classList.contains('hidden');
|
||||||
|
configPanel.classList.toggle('hidden', !isHidden);
|
||||||
|
toggleText.textContent = isHidden ? '隐藏配置' : '显示配置';
|
||||||
|
toggleIcon.style.transform = isHidden ? 'rotate(180deg)' : '';
|
||||||
|
});
|
||||||
|
|
||||||
|
// 重置API配置
|
||||||
|
document.getElementById('resetApiConfig').addEventListener('click', function() {
|
||||||
|
document.getElementById('apiUrl').value = '{{ default_api_url }}';
|
||||||
|
document.getElementById('apiModel').value = '{{ default_api_model }}';
|
||||||
|
document.getElementById('apiKey').value = '';
|
||||||
|
});
|
||||||
|
|
||||||
|
// 测试API连接
|
||||||
|
document.getElementById('testApiConfig').addEventListener('click', async function() {
|
||||||
|
const apiUrl = document.getElementById('apiUrl').value.trim();
|
||||||
|
const apiKey = document.getElementById('apiKey').value.trim();
|
||||||
|
const apiModel = document.getElementById('apiModel').value.trim();
|
||||||
|
|
||||||
|
if (!apiUrl) {
|
||||||
|
alert('请输入API URL');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!apiKey) {
|
||||||
|
alert('请输入API Key');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.textContent = '测试中...';
|
||||||
|
this.disabled = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 使用后端代理进行API测试
|
||||||
|
const response = await fetch('/test_api_connection', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
api_url: apiUrl,
|
||||||
|
api_key: apiKey,
|
||||||
|
api_model: apiModel
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (response.ok && data.success) {
|
||||||
|
alert('API连接成功!');
|
||||||
|
} else {
|
||||||
|
alert(`API连接失败: ${data.message || '未知错误'}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
alert(`API连接测试失败: ${error.message}`);
|
||||||
|
} finally {
|
||||||
|
this.textContent = '测试连接';
|
||||||
|
this.disabled = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
@@ -3,6 +3,8 @@ from stock_analyzer import StockAnalyzer
|
|||||||
from us_stock_service import USStockService
|
from us_stock_service import USStockService
|
||||||
import threading
|
import threading
|
||||||
import os
|
import os
|
||||||
|
import traceback
|
||||||
|
import requests
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
analyzer = StockAnalyzer()
|
analyzer = StockAnalyzer()
|
||||||
@@ -11,7 +13,14 @@ us_stock_service = USStockService()
|
|||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
|
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
|
||||||
return render_template('index.html', announcement=announcement)
|
# 获取默认API配置信息
|
||||||
|
default_api_url = os.getenv('API_URL', '')
|
||||||
|
default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
||||||
|
# 不传递API_KEY到前端,出于安全考虑
|
||||||
|
return render_template('index.html',
|
||||||
|
announcement=announcement,
|
||||||
|
default_api_url=default_api_url,
|
||||||
|
default_api_model=default_api_model)
|
||||||
|
|
||||||
@app.route('/analyze', methods=['POST'])
|
@app.route('/analyze', methods=['POST'])
|
||||||
def analyze():
|
def analyze():
|
||||||
@@ -20,13 +29,26 @@ def analyze():
|
|||||||
stock_codes = data.get('stock_codes', [])
|
stock_codes = data.get('stock_codes', [])
|
||||||
market_type = data.get('market_type', 'A')
|
market_type = data.get('market_type', 'A')
|
||||||
|
|
||||||
|
# 获取自定义API配置
|
||||||
|
custom_api_url = data.get('api_url')
|
||||||
|
custom_api_key = data.get('api_key')
|
||||||
|
custom_api_model = data.get('api_model')
|
||||||
|
|
||||||
|
# 创建新的分析器实例,使用自定义配置
|
||||||
|
custom_analyzer = StockAnalyzer(
|
||||||
|
custom_api_url=custom_api_url,
|
||||||
|
custom_api_key=custom_api_key,
|
||||||
|
custom_api_model=custom_api_model
|
||||||
|
)
|
||||||
|
|
||||||
if not stock_codes:
|
if not stock_codes:
|
||||||
return jsonify({'error': '请输入代码'}), 400
|
return jsonify({'error': '请输入代码'}), 400
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for stock_code in stock_codes:
|
for stock_code in stock_codes:
|
||||||
try:
|
try:
|
||||||
result = analyzer.analyze_stock(stock_code.strip(), market_type)
|
# 使用自定义配置的分析器
|
||||||
|
result = custom_analyzer.analyze_stock(stock_code.strip(), market_type)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"分析股票 {stock_code} 失败: {str(e)}")
|
print(f"分析股票 {stock_code} 失败: {str(e)}")
|
||||||
@@ -55,8 +77,59 @@ def search_us_stocks():
|
|||||||
print(f"搜索美股代码时出错: {str(e)}")
|
print(f"搜索美股代码时出错: {str(e)}")
|
||||||
return jsonify({'error': str(e)}), 500
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/test_api_connection', methods=['POST'])
|
||||||
|
def test_api_connection():
|
||||||
|
"""测试API连接"""
|
||||||
|
try:
|
||||||
|
data = request.json
|
||||||
|
api_url = data.get('api_url')
|
||||||
|
api_key = data.get('api_key')
|
||||||
|
api_model = data.get('api_model')
|
||||||
|
|
||||||
|
if not api_url:
|
||||||
|
return jsonify({'error': '请提供API URL'}), 400
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
return jsonify({'error': '请提供API Key'}), 400
|
||||||
|
|
||||||
|
# 构建API URL
|
||||||
|
test_url = api_url
|
||||||
|
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
|
||||||
|
if api_url.endswith('/v1'):
|
||||||
|
test_url = f"{api_url}/chat/completions"
|
||||||
|
elif api_url.endswith('/'):
|
||||||
|
test_url = f"{api_url}chat/completions"
|
||||||
|
else:
|
||||||
|
test_url = f"{api_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
# 发送测试请求
|
||||||
|
response = requests.post(
|
||||||
|
test_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"model": api_model or "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
|
||||||
|
],
|
||||||
|
"max_tokens": 20
|
||||||
|
},
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查响应
|
||||||
|
if response.status_code == 200:
|
||||||
|
return jsonify({'success': True, 'message': '连接成功'})
|
||||||
|
else:
|
||||||
|
error_message = response.json().get('error', {}).get('message', '未知错误')
|
||||||
|
return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(host='0.0.0.0', port=8888, debug=True)
|
app.run(host='0.0.0.0', port=8888, debug=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user