diff --git a/requirements.txt b/requirements.txt
index c77a82b..f758152 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,11 +10,14 @@ scipy==1.15.1
akshare==1.16.22
tqdm==4.67.1
+# Web框架与异步处理
+fastapi==0.115.11
+uvicorn[standard]==0.34.0
+pydantic==2.10.6
+httpx==0.28.1
-# 网络和API请求
-requests==2.32.3
+# 环境配置
python-dotenv==1.0.1
-flask==3.1.0
# 日志和系统工具
loguru==0.7.2
diff --git a/templates/index.html b/templates/index.html
deleted file mode 100644
index 0f159c7..0000000
--- a/templates/index.html
+++ /dev/null
@@ -1,1287 +0,0 @@
-
-
-
-
-
股票分析系统
-
-
-
-
-
-
-
-
-
-
股票批量分析
-
-
-
-
-
API配置
-
-
-
-
-
-
-
-
-
-
如不填写,将使用系统默认配置
-
-
-
-
-
请求超时时间,默认60秒
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
分析结果
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/utils/api_control.py b/utils/api_control.py
new file mode 100644
index 0000000..ba10342
--- /dev/null
+++ b/utils/api_control.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import Optional, Any
+from fastapi import Response
+from pydantic import BaseModel
+
+
+class ResponseModel(BaseModel):
+ """
+ 统一返回模型
+ """
+ code: int = 200
+ msg: str = "Success"
+ data: Optional[Any] = None
+
+
+class ApiResponse:
+ @staticmethod
+ def __response(code: int, msg: str, data: Optional[Any] = None) -> ResponseModel:
+ return ResponseModel(code=code, msg=msg, data=data)
+
+ @classmethod
+ def success(cls, *, code: int = 200, msg: str = 'Success', data: Optional[Any] = None) -> Response:
+ response_model = cls.__response(code=code, msg=msg, data=data)
+ return cls(content=response_model.model_dump())
+
+ @classmethod
+ def fail(cls, *, code: int = 400, msg: str = 'Bad Request', data: Optional[Any] = None) -> Response:
+ response_model = cls.__response(code=code, msg=msg, data=data)
+ return cls(content=response_model.model_dump())
+
+response_api = ApiResponse()
+
+""" 示例
+@app.get("/example-success")
+async def example_success():
+ return response_api.success(data={"key": "value"})
+
+@app.get("/example-fail")
+async def example_fail():
+ return response_api.fail(msg="Something went wrong", data={"error": "details"})
+"""
diff --git a/web_server.py b/web_server.py
index a14ef1d..c761334 100644
--- a/web_server.py
+++ b/web_server.py
@@ -1,51 +1,80 @@
-from flask import Flask, render_template, request, jsonify, Response, stream_with_context, send_from_directory
+from fastapi import FastAPI, Request, Response, Depends, HTTPException, BackgroundTasks
+from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel, Field
+from typing import List, Optional, Generator
from stock_analyzer import StockAnalyzer
from us_stock_service import USStockService
-from fund_service import FundService # 新增导入
+from fund_service import FundService
+import asyncio
import threading
import os
import traceback
-import requests
+import httpx
from logger import get_logger
from utils.api_utils import APIUtils
# 加载环境变量
from dotenv import load_dotenv
+import uvicorn
load_dotenv()
# 获取日志器
logger = get_logger()
-app = Flask(__name__,
- static_folder='frontend/dist',
- static_url_path='/')
+app = FastAPI(
+ title="Stock Scanner API",
+ description="股票分析API",
+ version="1.0.0"
+)
+
+# 添加CORS中间件
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # 开发环境允许所有来源,生产环境应该限制
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+# 设置静态文件
+frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
+if os.path.exists(frontend_dist):
+ app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
analyzer = StockAnalyzer()
us_stock_service = USStockService()
-fund_service = FundService() # 新增服务实例
+fund_service = FundService()
-@app.route('/')
-def index():
+# 定义请求和响应模型
+class AnalyzeRequest(BaseModel):
+ stock_codes: List[str]
+ market_type: str = "A"
+ api_url: Optional[str] = None
+ api_key: Optional[str] = None
+ api_model: Optional[str] = None
+ api_timeout: Optional[str] = None
+
+class TestAPIRequest(BaseModel):
+ api_url: str
+ api_key: str
+ api_model: Optional[str] = None
+ api_timeout: Optional[int] = 10
+
+@app.get("/")
+async def index(request: Request):
# 检查是否使用前端构建版本
- frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
if os.path.exists(frontend_dist):
- return send_from_directory(frontend_dist, 'index.html')
+ index_file = os.path.join(frontend_dist, 'index.html')
+ return FileResponse(index_file)
else:
- # 传统模板渲染,用于兼容旧版本
- announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
- # 获取默认API配置信息
- default_api_url = os.getenv('API_URL', '')
- default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
- default_api_timeout = os.getenv('API_TIMEOUT', '60')
- # 不传递API_KEY到前端,出于安全考虑
- return render_template('index.html',
- announcement=announcement,
- default_api_url=default_api_url,
- default_api_model=default_api_model,
- default_api_timeout=default_api_timeout)
+ # 不再使用模板渲染,而是重定向到API文档页面
+ logger.warning("前端构建目录不存在,重定向到API文档页面")
+ return RedirectResponse(url="/docs")
-@app.route('/config')
-def get_config():
+@app.get("/config")
+async def get_config():
"""返回系统配置信息"""
config = {
'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '',
@@ -53,23 +82,22 @@ def get_config():
'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'),
'default_api_timeout': os.getenv('API_TIMEOUT', '60')
}
- return jsonify(config)
+ return config
-@app.route('/analyze', methods=['POST'])
-def analyze():
+@app.post("/analyze")
+async def analyze(request: AnalyzeRequest):
try:
logger.info("开始处理分析请求")
- data = request.json
- stock_codes = data.get('stock_codes', [])
- market_type = data.get('market_type', 'A')
+ stock_codes = request.stock_codes
+ market_type = request.market_type
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
# 获取自定义API配置
- custom_api_url = data.get('api_url')
- custom_api_key = data.get('api_key')
- custom_api_model = data.get('api_model')
- custom_api_timeout = data.get('api_timeout')
+ custom_api_url = request.api_url
+ custom_api_key = request.api_key
+ custom_api_model = request.api_model
+ custom_api_timeout = request.api_timeout
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}")
@@ -83,10 +111,10 @@ def analyze():
if not stock_codes:
logger.warning("未提供股票代码")
- return jsonify({'error': '请输入代码'}), 400
+ raise HTTPException(status_code=400, detail="请输入代码")
- # 使用流式响应
- def generate():
+ # 定义流式生成器
+ async def generate_stream():
if len(stock_codes) == 1:
# 单个股票分析流式处理
stock_code = stock_codes[0].strip()
@@ -97,7 +125,16 @@ def analyze():
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
chunk_count = 0
- for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
+
+ # 使用线程池执行同步分析
+ def run_analysis():
+ return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True))
+
+ # 在线程中执行同步操作
+ loop = asyncio.get_event_loop()
+ chunks = await loop.run_in_executor(None, run_analysis)
+
+ for chunk in chunks:
chunk_count += 1
yield chunk + '\n'
logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块")
@@ -110,114 +147,132 @@ def analyze():
logger.debug(f"开始处理批量股票的流式响应")
chunk_count = 0
- for chunk in custom_analyzer.scan_stocks(
- [code.strip() for code in stock_codes],
- min_score=0,
- market_type=market_type,
- stream=True
- ):
+
+ # 使用线程池执行同步分析
+ def run_batch_analysis():
+ return list(custom_analyzer.scan_stocks(
+ [code.strip() for code in stock_codes],
+ min_score=0,
+ market_type=market_type,
+ stream=True
+ ))
+
+ # 在线程中执行同步操作
+ loop = asyncio.get_event_loop()
+ chunks = await loop.run_in_executor(None, run_batch_analysis)
+
+ for chunk in chunks:
chunk_count += 1
yield chunk + '\n'
logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块")
logger.info("成功创建流式响应生成器")
- return Response(stream_with_context(generate()), mimetype='application/json')
+ return StreamingResponse(generate_stream(), media_type='application/json')
except Exception as e:
error_msg = f"分析时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
- return jsonify({'error': error_msg}), 500
+ raise HTTPException(status_code=500, detail=error_msg)
-@app.route('/search_us_stocks', methods=['GET'])
-def search_us_stocks():
+@app.get("/search_us_stocks")
+async def search_us_stocks(keyword: str = ""):
try:
- keyword = request.args.get('keyword', '')
if not keyword:
- return jsonify({'error': '请输入搜索关键词'}), 400
-
- results = us_stock_service.search_us_stocks(keyword)
- return jsonify({'results': results})
+ raise HTTPException(status_code=400, detail="请输入搜索关键词")
+
+ # 在异步上下文中运行同步的搜索函数
+ loop = asyncio.get_event_loop()
+ results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword)
+ return {"results": results}
except Exception as e:
- print(f"搜索美股代码时出错: {str(e)}")
- return jsonify({'error': str(e)}), 500
+ logger.error(f"搜索美股代码时出错: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
-# 添加基金搜索路由
-@app.route('/search_funds', methods=['GET'])
-def search_funds():
+@app.get("/search_funds")
+async def search_funds(keyword: str = "", market_type: str = ""):
try:
- keyword = request.args.get('keyword', '')
- market_type = request.args.get('market_type', '')
if not keyword:
- return jsonify({'error': '请输入搜索关键词'}), 400
-
- results = fund_service.search_funds(keyword, market_type)
- return jsonify({'results': results})
+ raise HTTPException(status_code=400, detail="请输入搜索关键词")
+
+ # 在异步上下文中运行同步的搜索函数
+ loop = asyncio.get_event_loop()
+ results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type))
+ return {"results": results}
except Exception as e:
logger.error(f"搜索基金代码时出错: {str(e)}")
- return jsonify({'error': str(e)}), 500
+ raise HTTPException(status_code=500, detail=str(e))
-@app.route('/test_api_connection', methods=['POST'])
-def test_api_connection():
+@app.post("/test_api_connection")
+async def test_api_connection(request: TestAPIRequest):
"""测试API连接"""
try:
logger.info("开始测试API连接")
- data = request.json
- api_url = data.get('api_url')
- api_key = data.get('api_key')
- api_model = data.get('api_model')
- api_timeout = data.get('api_timeout', 10) # 默认测试连接超时为10秒
+ api_url = request.api_url
+ api_key = request.api_key
+ api_model = request.api_model
+ api_timeout = request.api_timeout
logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}, Timeout={api_timeout}")
if not api_url:
logger.warning("未提供API URL")
- return jsonify({'error': '请提供API URL'}), 400
+ raise HTTPException(status_code=400, detail="请提供API URL")
if not api_key:
logger.warning("未提供API Key")
- return jsonify({'error': '请提供API Key'}), 400
+ raise HTTPException(status_code=400, detail="请提供API Key")
# 构建API URL
test_url = APIUtils.format_api_url(api_url)
logger.debug(f"完整API测试URL: {test_url}")
- # 发送测试请求
- response = requests.post(
- test_url,
- headers={
- "Authorization": f"Bearer {api_key}",
- "Content-Type": "application/json"
- },
- json={
- "model": api_model or "gpt-3.5-turbo",
- "messages": [
- {"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
- ],
- "max_tokens": 20
- },
- timeout=int(api_timeout)
- )
+ # 使用异步HTTP客户端发送测试请求
+ async with httpx.AsyncClient(timeout=float(api_timeout)) as client:
+ response = await client.post(
+ test_url,
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json"
+ },
+ json={
+ "model": api_model or "gpt-3.5-turbo",
+ "messages": [
+ {"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
+ ],
+ "max_tokens": 20
+ }
+ )
# 检查响应
if response.status_code == 200:
logger.info(f"API 连接测试成功: {response.status_code}")
- return jsonify({'success': True, 'message': 'API 连接测试成功'})
+ return {"success": True, "message": "API 连接测试成功"}
else:
- error_message = response.json().get('error', {}).get('message', '未知错误')
+ error_data = response.json()
+ error_message = error_data.get('error', {}).get('message', '未知错误')
logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
- return jsonify({'success': False, 'message': f'API 连接测试失败: {error_message}', 'status_code': response.status_code}), 400
+ return JSONResponse(
+ status_code=400,
+ content={"success": False, "message": f"API 连接测试失败: {error_message}", "status_code": response.status_code}
+ )
- except requests.exceptions.RequestException as e:
+ except httpx.RequestError as e:
logger.error(f"API 连接请求错误: {str(e)}")
- return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
+ return JSONResponse(
+ status_code=400,
+ content={"success": False, "message": f"请求错误: {str(e)}"}
+ )
except Exception as e:
logger.error(f"测试 API 连接时出错: {str(e)}")
logger.exception(e)
- return jsonify({'success': False, 'message': f'API 测试连接时出错: {str(e)}'}), 500
+ return JSONResponse(
+ status_code=500,
+ content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
+ )
if __name__ == '__main__':
logger.info("股票分析系统启动")
- app.run(host='127.0.0.1', port=8888, debug=True)
\ No newline at end of file
+ uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)
\ No newline at end of file