diff --git a/stock_analyzer.py b/stock_analyzer.py index 369faef..4b8185a 100644 --- a/stock_analyzer.py +++ b/stock_analyzer.py @@ -12,7 +12,7 @@ from logger import get_logger logger = get_logger() class StockAnalyzer: - def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None): + def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=None): # 加载环境变量 load_dotenv() @@ -21,8 +21,9 @@ class StockAnalyzer: self.API_URL = custom_api_url or os.getenv('API_URL') self.API_KEY = custom_api_key or os.getenv('API_KEY') self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo') + self.API_TIMEOUT = int(custom_api_timeout or os.getenv('API_TIMEOUT', 60)) - logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}") + logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}, API_TIMEOUT={self.API_TIMEOUT}") # 配置参数 self.params = { @@ -294,7 +295,7 @@ class StockAnalyzer: api_url, headers=headers, json=payload, - timeout=60, # 增加超时时间 + timeout=self.API_TIMEOUT, # 增加超时时间 stream=True ) @@ -328,7 +329,7 @@ class StockAnalyzer: api_url, headers=headers, json=payload, - timeout=60 + timeout=self.API_TIMEOUT ) logger.debug(f"API非流式响应状态码: {response.status_code}") diff --git a/templates/index.html b/templates/index.html index 387952d..979f6bc 100644 --- a/templates/index.html +++ b/templates/index.html @@ -64,12 +64,22 @@ value="{{ default_api_model }}"> -
- - -

如不填写,将使用系统默认配置

+
+
+ + +

如不填写,将使用系统默认配置

+
+
+ + +

请求超时时间,默认60秒

+
@@ -349,6 +359,7 @@ const apiUrl = document.getElementById('apiUrl').value.trim(); const apiKey = document.getElementById('apiKey').value.trim(); const apiModel = document.getElementById('apiModel').value.trim(); + const apiTimeout = document.getElementById('apiTimeout').value.trim(); if (!stockInput) { alert('请输入代码'); @@ -385,7 +396,8 @@ market_type: marketType, api_url: apiUrl, api_key: apiKey, - api_model: apiModel + api_model: apiModel, + api_timeout: apiTimeout }) }); @@ -718,6 +730,7 @@ const apiUrl = document.getElementById('apiUrl'); const apiKey = document.getElementById('apiKey'); const apiModel = document.getElementById('apiModel'); + const apiTimeout = document.getElementById('apiTimeout'); const saveApiConfig = document.getElementById('saveApiConfig'); const resetApiConfig = document.getElementById('resetApiConfig'); const testApiConfig = document.getElementById('testApiConfig'); @@ -743,6 +756,7 @@ apiUrl.value = '{{ default_api_url }}'; apiKey.value = ''; apiModel.value = '{{ default_api_model }}'; + apiTimeout.value = '{{ default_api_timeout }}'; saveApiConfig.checked = false; // 清除localStorage中的配置 @@ -756,6 +770,7 @@ const url = apiUrl.value.trim(); const key = apiKey.value.trim(); const model = apiModel.value.trim(); + const timeout = apiTimeout.value.trim(); if (!url) { alert('请输入API URL'); @@ -779,7 +794,8 @@ body: JSON.stringify({ api_url: url, api_key: key, - api_model: model + api_model: model, + api_timeout: timeout }) }); @@ -804,7 +820,7 @@ }); // 监听输入变化,自动保存配置 - [apiUrl, apiKey, apiModel].forEach(input => { + [apiUrl, apiKey, apiModel, apiTimeout].forEach(input => { input.addEventListener('change', function() { if (saveApiConfig.checked) { saveApiConfigToLocalStorage(); @@ -828,6 +844,7 @@ url: document.getElementById('apiUrl').value.trim(), model: document.getElementById('apiModel').value.trim(), key: document.getElementById('apiKey').value.trim(), + timeout: document.getElementById('apiTimeout').value.trim(), saveEnabled: true }; @@ -845,6 +862,7 @@ if (config.url) document.getElementById('apiUrl').value = config.url; if (config.model) document.getElementById('apiModel').value = config.model; if (config.key) document.getElementById('apiKey').value = config.key; + if (config.timeout) document.getElementById('apiTimeout').value = config.timeout; document.getElementById('saveApiConfig').checked = config.saveEnabled || false; } catch (error) { diff --git a/tests/test_stream.py b/tests/test_stream.py index 30cd508..7d01ada 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -85,7 +85,7 @@ def test_api_stream(): api_url, headers=headers, json=payload, - timeout=60, + timeout=int(os.getenv('API_TIMEOUT', 60)), stream=True ) diff --git a/web_server.py b/web_server.py index 36f0d06..b30aaf2 100644 --- a/web_server.py +++ b/web_server.py @@ -20,11 +20,13 @@ def index(): # 获取默认API配置信息 default_api_url = os.getenv('API_URL', '') default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo') + default_api_timeout = os.getenv('API_TIMEOUT', '60') # 不传递API_KEY到前端,出于安全考虑 return render_template('index.html', announcement=announcement, default_api_url=default_api_url, - default_api_model=default_api_model) + default_api_model=default_api_model, + default_api_timeout=default_api_timeout) @app.route('/analyze', methods=['POST']) def analyze(): @@ -40,14 +42,16 @@ def analyze(): custom_api_url = data.get('api_url') custom_api_key = data.get('api_key') custom_api_model = data.get('api_model') + custom_api_timeout = data.get('api_timeout') - logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}") + logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}") # 创建新的分析器实例,使用自定义配置 custom_analyzer = StockAnalyzer( custom_api_url=custom_api_url, custom_api_key=custom_api_key, - custom_api_model=custom_api_model + custom_api_model=custom_api_model, + custom_api_timeout=custom_api_timeout ) if not stock_codes: @@ -121,8 +125,9 @@ def test_api_connection(): api_url = data.get('api_url') api_key = data.get('api_key') api_model = data.get('api_model') + api_timeout = data.get('api_timeout', 10) # 默认测试连接超时为10秒 - logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}") + logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}, Timeout={api_timeout}") if not api_url: logger.warning("未提供API URL") @@ -158,7 +163,7 @@ def test_api_connection(): ], "max_tokens": 20 }, - timeout=10 + timeout=int(api_timeout) ) # 检查响应