refactor: flask后端改为异步fastapi
This commit is contained in:
@@ -10,11 +10,14 @@ scipy==1.15.1
|
|||||||
akshare==1.16.22
|
akshare==1.16.22
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
|
|
||||||
|
# Web框架与异步处理
|
||||||
|
fastapi==0.115.11
|
||||||
|
uvicorn[standard]==0.34.0
|
||||||
|
pydantic==2.10.6
|
||||||
|
httpx==0.28.1
|
||||||
|
|
||||||
# 网络和API请求
|
# 环境配置
|
||||||
requests==2.32.3
|
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
flask==3.1.0
|
|
||||||
|
|
||||||
# 日志和系统工具
|
# 日志和系统工具
|
||||||
loguru==0.7.2
|
loguru==0.7.2
|
||||||
|
|||||||
1287
templates/index.html
1287
templates/index.html
File diff suppressed because it is too large
Load Diff
43
utils/api_control.py
Normal file
43
utils/api_control.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- encoding: utf-8 -*-
|
||||||
|
|
||||||
|
from typing import Optional, Any
|
||||||
|
from fastapi import Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseModel(BaseModel):
|
||||||
|
"""
|
||||||
|
统一返回模型
|
||||||
|
"""
|
||||||
|
code: int = 200
|
||||||
|
msg: str = "Success"
|
||||||
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ApiResponse:
|
||||||
|
@staticmethod
|
||||||
|
def __response(code: int, msg: str, data: Optional[Any] = None) -> ResponseModel:
|
||||||
|
return ResponseModel(code=code, msg=msg, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def success(cls, *, code: int = 200, msg: str = 'Success', data: Optional[Any] = None) -> Response:
|
||||||
|
response_model = cls.__response(code=code, msg=msg, data=data)
|
||||||
|
return cls(content=response_model.model_dump())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fail(cls, *, code: int = 400, msg: str = 'Bad Request', data: Optional[Any] = None) -> Response:
|
||||||
|
response_model = cls.__response(code=code, msg=msg, data=data)
|
||||||
|
return cls(content=response_model.model_dump())
|
||||||
|
|
||||||
|
response_api = ApiResponse()
|
||||||
|
|
||||||
|
""" 示例
|
||||||
|
@app.get("/example-success")
|
||||||
|
async def example_success():
|
||||||
|
return response_api.success(data={"key": "value"})
|
||||||
|
|
||||||
|
@app.get("/example-fail")
|
||||||
|
async def example_fail():
|
||||||
|
return response_api.fail(msg="Something went wrong", data={"error": "details"})
|
||||||
|
"""
|
||||||
251
web_server.py
251
web_server.py
@@ -1,51 +1,80 @@
|
|||||||
from flask import Flask, render_template, request, jsonify, Response, stream_with_context, send_from_directory
|
from fastapi import FastAPI, Request, Response, Depends, HTTPException, BackgroundTasks
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional, Generator
|
||||||
from stock_analyzer import StockAnalyzer
|
from stock_analyzer import StockAnalyzer
|
||||||
from us_stock_service import USStockService
|
from us_stock_service import USStockService
|
||||||
from fund_service import FundService # 新增导入
|
from fund_service import FundService
|
||||||
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import httpx
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from utils.api_utils import APIUtils
|
from utils.api_utils import APIUtils
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# 获取日志器
|
# 获取日志器
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
app = Flask(__name__,
|
app = FastAPI(
|
||||||
static_folder='frontend/dist',
|
title="Stock Scanner API",
|
||||||
static_url_path='/')
|
description="股票分析API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加CORS中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 开发环境允许所有来源,生产环境应该限制
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置静态文件
|
||||||
|
frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
|
||||||
|
if os.path.exists(frontend_dist):
|
||||||
|
app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
|
||||||
|
|
||||||
analyzer = StockAnalyzer()
|
analyzer = StockAnalyzer()
|
||||||
us_stock_service = USStockService()
|
us_stock_service = USStockService()
|
||||||
fund_service = FundService() # 新增服务实例
|
fund_service = FundService()
|
||||||
|
|
||||||
@app.route('/')
|
# 定义请求和响应模型
|
||||||
def index():
|
class AnalyzeRequest(BaseModel):
|
||||||
|
stock_codes: List[str]
|
||||||
|
market_type: str = "A"
|
||||||
|
api_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
api_model: Optional[str] = None
|
||||||
|
api_timeout: Optional[str] = None
|
||||||
|
|
||||||
|
class TestAPIRequest(BaseModel):
|
||||||
|
api_url: str
|
||||||
|
api_key: str
|
||||||
|
api_model: Optional[str] = None
|
||||||
|
api_timeout: Optional[int] = 10
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def index(request: Request):
|
||||||
# 检查是否使用前端构建版本
|
# 检查是否使用前端构建版本
|
||||||
frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
|
|
||||||
if os.path.exists(frontend_dist):
|
if os.path.exists(frontend_dist):
|
||||||
return send_from_directory(frontend_dist, 'index.html')
|
index_file = os.path.join(frontend_dist, 'index.html')
|
||||||
|
return FileResponse(index_file)
|
||||||
else:
|
else:
|
||||||
# 传统模板渲染,用于兼容旧版本
|
# 不再使用模板渲染,而是重定向到API文档页面
|
||||||
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
|
logger.warning("前端构建目录不存在,重定向到API文档页面")
|
||||||
# 获取默认API配置信息
|
return RedirectResponse(url="/docs")
|
||||||
default_api_url = os.getenv('API_URL', '')
|
|
||||||
default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
|
|
||||||
default_api_timeout = os.getenv('API_TIMEOUT', '60')
|
|
||||||
# 不传递API_KEY到前端,出于安全考虑
|
|
||||||
return render_template('index.html',
|
|
||||||
announcement=announcement,
|
|
||||||
default_api_url=default_api_url,
|
|
||||||
default_api_model=default_api_model,
|
|
||||||
default_api_timeout=default_api_timeout)
|
|
||||||
|
|
||||||
@app.route('/config')
|
@app.get("/config")
|
||||||
def get_config():
|
async def get_config():
|
||||||
"""返回系统配置信息"""
|
"""返回系统配置信息"""
|
||||||
config = {
|
config = {
|
||||||
'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '',
|
'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '',
|
||||||
@@ -53,23 +82,22 @@ def get_config():
|
|||||||
'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'),
|
'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'),
|
||||||
'default_api_timeout': os.getenv('API_TIMEOUT', '60')
|
'default_api_timeout': os.getenv('API_TIMEOUT', '60')
|
||||||
}
|
}
|
||||||
return jsonify(config)
|
return config
|
||||||
|
|
||||||
@app.route('/analyze', methods=['POST'])
|
@app.post("/analyze")
|
||||||
def analyze():
|
async def analyze(request: AnalyzeRequest):
|
||||||
try:
|
try:
|
||||||
logger.info("开始处理分析请求")
|
logger.info("开始处理分析请求")
|
||||||
data = request.json
|
stock_codes = request.stock_codes
|
||||||
stock_codes = data.get('stock_codes', [])
|
market_type = request.market_type
|
||||||
market_type = data.get('market_type', 'A')
|
|
||||||
|
|
||||||
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
|
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
|
||||||
|
|
||||||
# 获取自定义API配置
|
# 获取自定义API配置
|
||||||
custom_api_url = data.get('api_url')
|
custom_api_url = request.api_url
|
||||||
custom_api_key = data.get('api_key')
|
custom_api_key = request.api_key
|
||||||
custom_api_model = data.get('api_model')
|
custom_api_model = request.api_model
|
||||||
custom_api_timeout = data.get('api_timeout')
|
custom_api_timeout = request.api_timeout
|
||||||
|
|
||||||
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}")
|
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}")
|
||||||
|
|
||||||
@@ -83,10 +111,10 @@ def analyze():
|
|||||||
|
|
||||||
if not stock_codes:
|
if not stock_codes:
|
||||||
logger.warning("未提供股票代码")
|
logger.warning("未提供股票代码")
|
||||||
return jsonify({'error': '请输入代码'}), 400
|
raise HTTPException(status_code=400, detail="请输入代码")
|
||||||
|
|
||||||
# 使用流式响应
|
# 定义流式生成器
|
||||||
def generate():
|
async def generate_stream():
|
||||||
if len(stock_codes) == 1:
|
if len(stock_codes) == 1:
|
||||||
# 单个股票分析流式处理
|
# 单个股票分析流式处理
|
||||||
stock_code = stock_codes[0].strip()
|
stock_code = stock_codes[0].strip()
|
||||||
@@ -97,7 +125,16 @@ def analyze():
|
|||||||
|
|
||||||
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
|
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
|
|
||||||
|
# 使用线程池执行同步分析
|
||||||
|
def run_analysis():
|
||||||
|
return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True))
|
||||||
|
|
||||||
|
# 在线程中执行同步操作
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
chunks = await loop.run_in_executor(None, run_analysis)
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
yield chunk + '\n'
|
yield chunk + '\n'
|
||||||
logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块")
|
logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块")
|
||||||
@@ -110,114 +147,132 @@ def analyze():
|
|||||||
|
|
||||||
logger.debug(f"开始处理批量股票的流式响应")
|
logger.debug(f"开始处理批量股票的流式响应")
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
for chunk in custom_analyzer.scan_stocks(
|
|
||||||
[code.strip() for code in stock_codes],
|
# 使用线程池执行同步分析
|
||||||
min_score=0,
|
def run_batch_analysis():
|
||||||
market_type=market_type,
|
return list(custom_analyzer.scan_stocks(
|
||||||
stream=True
|
[code.strip() for code in stock_codes],
|
||||||
):
|
min_score=0,
|
||||||
|
market_type=market_type,
|
||||||
|
stream=True
|
||||||
|
))
|
||||||
|
|
||||||
|
# 在线程中执行同步操作
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
chunks = await loop.run_in_executor(None, run_batch_analysis)
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
yield chunk + '\n'
|
yield chunk + '\n'
|
||||||
logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块")
|
logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块")
|
||||||
|
|
||||||
logger.info("成功创建流式响应生成器")
|
logger.info("成功创建流式响应生成器")
|
||||||
return Response(stream_with_context(generate()), mimetype='application/json')
|
return StreamingResponse(generate_stream(), media_type='application/json')
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"分析时出错: {str(e)}"
|
error_msg = f"分析时出错: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return jsonify({'error': error_msg}), 500
|
raise HTTPException(status_code=500, detail=error_msg)
|
||||||
|
|
||||||
@app.route('/search_us_stocks', methods=['GET'])
|
@app.get("/search_us_stocks")
|
||||||
def search_us_stocks():
|
async def search_us_stocks(keyword: str = ""):
|
||||||
try:
|
try:
|
||||||
keyword = request.args.get('keyword', '')
|
|
||||||
if not keyword:
|
if not keyword:
|
||||||
return jsonify({'error': '请输入搜索关键词'}), 400
|
raise HTTPException(status_code=400, detail="请输入搜索关键词")
|
||||||
|
|
||||||
results = us_stock_service.search_us_stocks(keyword)
|
# 在异步上下文中运行同步的搜索函数
|
||||||
return jsonify({'results': results})
|
loop = asyncio.get_event_loop()
|
||||||
|
results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword)
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"搜索美股代码时出错: {str(e)}")
|
logger.error(f"搜索美股代码时出错: {str(e)}")
|
||||||
return jsonify({'error': str(e)}), 500
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
# 添加基金搜索路由
|
@app.get("/search_funds")
|
||||||
@app.route('/search_funds', methods=['GET'])
|
async def search_funds(keyword: str = "", market_type: str = ""):
|
||||||
def search_funds():
|
|
||||||
try:
|
try:
|
||||||
keyword = request.args.get('keyword', '')
|
|
||||||
market_type = request.args.get('market_type', '')
|
|
||||||
if not keyword:
|
if not keyword:
|
||||||
return jsonify({'error': '请输入搜索关键词'}), 400
|
raise HTTPException(status_code=400, detail="请输入搜索关键词")
|
||||||
|
|
||||||
results = fund_service.search_funds(keyword, market_type)
|
# 在异步上下文中运行同步的搜索函数
|
||||||
return jsonify({'results': results})
|
loop = asyncio.get_event_loop()
|
||||||
|
results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type))
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"搜索基金代码时出错: {str(e)}")
|
logger.error(f"搜索基金代码时出错: {str(e)}")
|
||||||
return jsonify({'error': str(e)}), 500
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.route('/test_api_connection', methods=['POST'])
|
@app.post("/test_api_connection")
|
||||||
def test_api_connection():
|
async def test_api_connection(request: TestAPIRequest):
|
||||||
"""测试API连接"""
|
"""测试API连接"""
|
||||||
try:
|
try:
|
||||||
logger.info("开始测试API连接")
|
logger.info("开始测试API连接")
|
||||||
data = request.json
|
api_url = request.api_url
|
||||||
api_url = data.get('api_url')
|
api_key = request.api_key
|
||||||
api_key = data.get('api_key')
|
api_model = request.api_model
|
||||||
api_model = data.get('api_model')
|
api_timeout = request.api_timeout
|
||||||
api_timeout = data.get('api_timeout', 10) # 默认测试连接超时为10秒
|
|
||||||
|
|
||||||
logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}, Timeout={api_timeout}")
|
logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}, Timeout={api_timeout}")
|
||||||
|
|
||||||
if not api_url:
|
if not api_url:
|
||||||
logger.warning("未提供API URL")
|
logger.warning("未提供API URL")
|
||||||
return jsonify({'error': '请提供API URL'}), 400
|
raise HTTPException(status_code=400, detail="请提供API URL")
|
||||||
|
|
||||||
if not api_key:
|
if not api_key:
|
||||||
logger.warning("未提供API Key")
|
logger.warning("未提供API Key")
|
||||||
return jsonify({'error': '请提供API Key'}), 400
|
raise HTTPException(status_code=400, detail="请提供API Key")
|
||||||
|
|
||||||
# 构建API URL
|
# 构建API URL
|
||||||
test_url = APIUtils.format_api_url(api_url)
|
test_url = APIUtils.format_api_url(api_url)
|
||||||
logger.debug(f"完整API测试URL: {test_url}")
|
logger.debug(f"完整API测试URL: {test_url}")
|
||||||
|
|
||||||
# 发送测试请求
|
# 使用异步HTTP客户端发送测试请求
|
||||||
response = requests.post(
|
async with httpx.AsyncClient(timeout=float(api_timeout)) as client:
|
||||||
test_url,
|
response = await client.post(
|
||||||
headers={
|
test_url,
|
||||||
"Authorization": f"Bearer {api_key}",
|
headers={
|
||||||
"Content-Type": "application/json"
|
"Authorization": f"Bearer {api_key}",
|
||||||
},
|
"Content-Type": "application/json"
|
||||||
json={
|
},
|
||||||
"model": api_model or "gpt-3.5-turbo",
|
json={
|
||||||
"messages": [
|
"model": api_model or "gpt-3.5-turbo",
|
||||||
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
|
"messages": [
|
||||||
],
|
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
|
||||||
"max_tokens": 20
|
],
|
||||||
},
|
"max_tokens": 20
|
||||||
timeout=int(api_timeout)
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查响应
|
# 检查响应
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
logger.info(f"API 连接测试成功: {response.status_code}")
|
logger.info(f"API 连接测试成功: {response.status_code}")
|
||||||
return jsonify({'success': True, 'message': 'API 连接测试成功'})
|
return {"success": True, "message": "API 连接测试成功"}
|
||||||
else:
|
else:
|
||||||
error_message = response.json().get('error', {}).get('message', '未知错误')
|
error_data = response.json()
|
||||||
|
error_message = error_data.get('error', {}).get('message', '未知错误')
|
||||||
logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
|
logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
|
||||||
return jsonify({'success': False, 'message': f'API 连接测试失败: {error_message}', 'status_code': response.status_code}), 400
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"success": False, "message": f"API 连接测试失败: {error_message}", "status_code": response.status_code}
|
||||||
|
)
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except httpx.RequestError as e:
|
||||||
logger.error(f"API 连接请求错误: {str(e)}")
|
logger.error(f"API 连接请求错误: {str(e)}")
|
||||||
return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"success": False, "message": f"请求错误: {str(e)}"}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"测试 API 连接时出错: {str(e)}")
|
logger.error(f"测试 API 连接时出错: {str(e)}")
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
return jsonify({'success': False, 'message': f'API 测试连接时出错: {str(e)}'}), 500
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
logger.info("股票分析系统启动")
|
logger.info("股票分析系统启动")
|
||||||
app.run(host='127.0.0.1', port=8888, debug=True)
|
uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)
|
||||||
Reference in New Issue
Block a user