refactor: flask后端改为异步fastapi

This commit is contained in:
Cassianvale
2025-03-06 15:40:26 +08:00
parent 8781eebdfa
commit bcf64f0041
4 changed files with 202 additions and 1388 deletions

View File

@@ -10,11 +10,14 @@ scipy==1.15.1
akshare==1.16.22
tqdm==4.67.1
# Web框架与异步处理
fastapi==0.115.11
uvicorn[standard]==0.34.0
pydantic==2.10.6
httpx==0.28.1
# 网络和API请求
requests==2.32.3
# 环境配置
python-dotenv==1.0.1
flask==3.1.0
# 日志和系统工具
loguru==0.7.2

File diff suppressed because it is too large Load Diff

43
utils/api_control.py Normal file
View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Any
from fastapi import Response
from pydantic import BaseModel
class ResponseModel(BaseModel):
"""
统一返回模型
"""
code: int = 200
msg: str = "Success"
data: Optional[Any] = None
class ApiResponse:
@staticmethod
def __response(code: int, msg: str, data: Optional[Any] = None) -> ResponseModel:
return ResponseModel(code=code, msg=msg, data=data)
@classmethod
def success(cls, *, code: int = 200, msg: str = 'Success', data: Optional[Any] = None) -> Response:
response_model = cls.__response(code=code, msg=msg, data=data)
return cls(content=response_model.model_dump())
@classmethod
def fail(cls, *, code: int = 400, msg: str = 'Bad Request', data: Optional[Any] = None) -> Response:
response_model = cls.__response(code=code, msg=msg, data=data)
return cls(content=response_model.model_dump())
response_api = ApiResponse()
""" 示例
@app.get("/example-success")
async def example_success():
return response_api.success(data={"key": "value"})
@app.get("/example-fail")
async def example_fail():
return response_api.fail(msg="Something went wrong", data={"error": "details"})
"""

View File

@@ -1,51 +1,80 @@
from flask import Flask, render_template, request, jsonify, Response, stream_with_context, send_from_directory
from fastapi import FastAPI, Request, Response, Depends, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Generator
from stock_analyzer import StockAnalyzer
from us_stock_service import USStockService
from fund_service import FundService # 新增导入
from fund_service import FundService
import asyncio
import threading
import os
import traceback
import requests
import httpx
from logger import get_logger
from utils.api_utils import APIUtils
# 加载环境变量
from dotenv import load_dotenv
import uvicorn
load_dotenv()
# 获取日志器
logger = get_logger()
app = Flask(__name__,
static_folder='frontend/dist',
static_url_path='/')
app = FastAPI(
title="Stock Scanner API",
description="股票分析API",
version="1.0.0"
)
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 开发环境允许所有来源,生产环境应该限制
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 设置静态文件
frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
if os.path.exists(frontend_dist):
app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
analyzer = StockAnalyzer()
us_stock_service = USStockService()
fund_service = FundService() # 新增服务实例
fund_service = FundService()
@app.route('/')
def index():
# 定义请求和响应模型
class AnalyzeRequest(BaseModel):
stock_codes: List[str]
market_type: str = "A"
api_url: Optional[str] = None
api_key: Optional[str] = None
api_model: Optional[str] = None
api_timeout: Optional[str] = None
class TestAPIRequest(BaseModel):
api_url: str
api_key: str
api_model: Optional[str] = None
api_timeout: Optional[int] = 10
@app.get("/")
async def index(request: Request):
# 检查是否使用前端构建版本
frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
if os.path.exists(frontend_dist):
return send_from_directory(frontend_dist, 'index.html')
index_file = os.path.join(frontend_dist, 'index.html')
return FileResponse(index_file)
else:
# 传统模板渲染,用于兼容旧版本
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
# 获取默认API配置信息
default_api_url = os.getenv('API_URL', '')
default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
default_api_timeout = os.getenv('API_TIMEOUT', '60')
# 不传递API_KEY到前端出于安全考虑
return render_template('index.html',
announcement=announcement,
default_api_url=default_api_url,
default_api_model=default_api_model,
default_api_timeout=default_api_timeout)
# 不再使用模板渲染,而是重定向到API文档页面
logger.warning("前端构建目录不存在重定向到API文档页面")
return RedirectResponse(url="/docs")
@app.route('/config')
def get_config():
@app.get("/config")
async def get_config():
"""返回系统配置信息"""
config = {
'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '',
@@ -53,23 +82,22 @@ def get_config():
'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'),
'default_api_timeout': os.getenv('API_TIMEOUT', '60')
}
return jsonify(config)
return config
@app.route('/analyze', methods=['POST'])
def analyze():
@app.post("/analyze")
async def analyze(request: AnalyzeRequest):
try:
logger.info("开始处理分析请求")
data = request.json
stock_codes = data.get('stock_codes', [])
market_type = data.get('market_type', 'A')
stock_codes = request.stock_codes
market_type = request.market_type
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
# 获取自定义API配置
custom_api_url = data.get('api_url')
custom_api_key = data.get('api_key')
custom_api_model = data.get('api_model')
custom_api_timeout = data.get('api_timeout')
custom_api_url = request.api_url
custom_api_key = request.api_key
custom_api_model = request.api_model
custom_api_timeout = request.api_timeout
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}")
@@ -83,10 +111,10 @@ def analyze():
if not stock_codes:
logger.warning("未提供股票代码")
return jsonify({'error': '请输入代码'}), 400
raise HTTPException(status_code=400, detail="请输入代码")
# 使用流式响应
def generate():
# 定义流式生成器
async def generate_stream():
if len(stock_codes) == 1:
# 单个股票分析流式处理
stock_code = stock_codes[0].strip()
@@ -97,7 +125,16 @@ def analyze():
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
chunk_count = 0
for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
# 使用线程池执行同步分析
def run_analysis():
return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True))
# 在线程中执行同步操作
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, run_analysis)
for chunk in chunks:
chunk_count += 1
yield chunk + '\n'
logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块")
@@ -110,114 +147,132 @@ def analyze():
logger.debug(f"开始处理批量股票的流式响应")
chunk_count = 0
for chunk in custom_analyzer.scan_stocks(
[code.strip() for code in stock_codes],
min_score=0,
market_type=market_type,
stream=True
):
# 使用线程池执行同步分析
def run_batch_analysis():
return list(custom_analyzer.scan_stocks(
[code.strip() for code in stock_codes],
min_score=0,
market_type=market_type,
stream=True
))
# 在线程中执行同步操作
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, run_batch_analysis)
for chunk in chunks:
chunk_count += 1
yield chunk + '\n'
logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块")
logger.info("成功创建流式响应生成器")
return Response(stream_with_context(generate()), mimetype='application/json')
return StreamingResponse(generate_stream(), media_type='application/json')
except Exception as e:
error_msg = f"分析时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
return jsonify({'error': error_msg}), 500
raise HTTPException(status_code=500, detail=error_msg)
@app.route('/search_us_stocks', methods=['GET'])
def search_us_stocks():
@app.get("/search_us_stocks")
async def search_us_stocks(keyword: str = ""):
try:
keyword = request.args.get('keyword', '')
if not keyword:
return jsonify({'error': '请输入搜索关键词'}), 400
results = us_stock_service.search_us_stocks(keyword)
return jsonify({'results': results})
raise HTTPException(status_code=400, detail="请输入搜索关键词")
# 在异步上下文中运行同步的搜索函数
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword)
return {"results": results}
except Exception as e:
print(f"搜索美股代码时出错: {str(e)}")
return jsonify({'error': str(e)}), 500
logger.error(f"搜索美股代码时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 添加基金搜索路由
@app.route('/search_funds', methods=['GET'])
def search_funds():
@app.get("/search_funds")
async def search_funds(keyword: str = "", market_type: str = ""):
try:
keyword = request.args.get('keyword', '')
market_type = request.args.get('market_type', '')
if not keyword:
return jsonify({'error': '请输入搜索关键词'}), 400
results = fund_service.search_funds(keyword, market_type)
return jsonify({'results': results})
raise HTTPException(status_code=400, detail="请输入搜索关键词")
# 在异步上下文中运行同步的搜索函数
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type))
return {"results": results}
except Exception as e:
logger.error(f"搜索基金代码时出错: {str(e)}")
return jsonify({'error': str(e)}), 500
raise HTTPException(status_code=500, detail=str(e))
@app.route('/test_api_connection', methods=['POST'])
def test_api_connection():
@app.post("/test_api_connection")
async def test_api_connection(request: TestAPIRequest):
"""测试API连接"""
try:
logger.info("开始测试API连接")
data = request.json
api_url = data.get('api_url')
api_key = data.get('api_key')
api_model = data.get('api_model')
api_timeout = data.get('api_timeout', 10) # 默认测试连接超时为10秒
api_url = request.api_url
api_key = request.api_key
api_model = request.api_model
api_timeout = request.api_timeout
logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}, Timeout={api_timeout}")
if not api_url:
logger.warning("未提供API URL")
return jsonify({'error': '请提供API URL'}), 400
raise HTTPException(status_code=400, detail="请提供API URL")
if not api_key:
logger.warning("未提供API Key")
return jsonify({'error': '请提供API Key'}), 400
raise HTTPException(status_code=400, detail="请提供API Key")
# 构建API URL
test_url = APIUtils.format_api_url(api_url)
logger.debug(f"完整API测试URL: {test_url}")
# 发送测试请求
response = requests.post(
test_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
json={
"model": api_model or "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
],
"max_tokens": 20
},
timeout=int(api_timeout)
)
# 使用异步HTTP客户端发送测试请求
async with httpx.AsyncClient(timeout=float(api_timeout)) as client:
response = await client.post(
test_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
json={
"model": api_model or "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
],
"max_tokens": 20
}
)
# 检查响应
if response.status_code == 200:
logger.info(f"API 连接测试成功: {response.status_code}")
return jsonify({'success': True, 'message': 'API 连接测试成功'})
return {"success": True, "message": "API 连接测试成功"}
else:
error_message = response.json().get('error', {}).get('message', '未知错误')
error_data = response.json()
error_message = error_data.get('error', {}).get('message', '未知错误')
logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
return jsonify({'success': False, 'message': f'API 连接测试失败: {error_message}', 'status_code': response.status_code}), 400
return JSONResponse(
status_code=400,
content={"success": False, "message": f"API 连接测试失败: {error_message}", "status_code": response.status_code}
)
except requests.exceptions.RequestException as e:
except httpx.RequestError as e:
logger.error(f"API 连接请求错误: {str(e)}")
return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
return JSONResponse(
status_code=400,
content={"success": False, "message": f"请求错误: {str(e)}"}
)
except Exception as e:
logger.error(f"测试 API 连接时出错: {str(e)}")
logger.exception(e)
return jsonify({'success': False, 'message': f'API 测试连接时出错: {str(e)}'}), 500
return JSONResponse(
status_code=500,
content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
)
if __name__ == '__main__':
logger.info("股票分析系统启动")
app.run(host='127.0.0.1', port=8888, debug=True)
uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)