feat: 优化前端显示&修复若干bug
This commit is contained in:
239
web_server.py
239
web_server.py
@@ -2,28 +2,40 @@ from fastapi import FastAPI, Request, Response, Depends, HTTPException, Backgrou
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any, Generator
|
||||
from services.stock_analyzer_service import StockAnalyzerService
|
||||
# 导入新的异步服务
|
||||
from services.us_stock_service_async import USStockServiceAsync
|
||||
from services.fund_service_async import FundServiceAsync
|
||||
import asyncio
|
||||
import threading
|
||||
import os
|
||||
import traceback
|
||||
import httpx
|
||||
from logger import get_logger
|
||||
from utils.api_utils import APIUtils
|
||||
# 加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
from dotenv import load_dotenv, dotenv_values
|
||||
import uvicorn
|
||||
import json
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from jose import JWTError, jwt
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
|
||||
# JWT相关配置
|
||||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_hex(32))
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 10080 # Token过期时间一周
|
||||
|
||||
LOGIN_PASSWORD = os.getenv("LOGIN_PASSWORD", "")
|
||||
print(LOGIN_PASSWORD)
|
||||
|
||||
# 是否需要登录
|
||||
REQUIRE_LOGIN = bool(LOGIN_PASSWORD.strip())
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Stock Scanner API",
|
||||
description="异步股票分析API",
|
||||
@@ -42,7 +54,7 @@ app.add_middleware(
|
||||
# 设置静态文件
|
||||
frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
|
||||
if os.path.exists(frontend_dist):
|
||||
app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
|
||||
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dist, "assets")), name="assets")
|
||||
|
||||
# 初始化异步服务
|
||||
# StockAnalyzerService 不需要全局初始化,在 /analyze 接口中按需创建
|
||||
@@ -64,35 +76,121 @@ class TestAPIRequest(BaseModel):
|
||||
api_model: Optional[str] = None
|
||||
api_timeout: Optional[int] = 10
|
||||
|
||||
@app.get("/")
|
||||
async def index(request: Request):
|
||||
# 检查是否使用前端构建版本
|
||||
if os.path.exists(frontend_dist):
|
||||
index_file = os.path.join(frontend_dist, 'index.html')
|
||||
return FileResponse(index_file)
|
||||
else:
|
||||
# 不再使用模板渲染,而是重定向到API文档页面
|
||||
logger.warning("前端构建目录不存在,重定向到API文档页面")
|
||||
return RedirectResponse(url="/docs")
|
||||
class LoginRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
# 自定义依赖项,在REQUIRE_LOGIN=False时不要求token
|
||||
class OptionalOAuth2PasswordBearer(OAuth2PasswordBearer):
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
if not REQUIRE_LOGIN:
|
||||
return None
|
||||
try:
|
||||
return await super().__call__(request)
|
||||
except HTTPException:
|
||||
if not REQUIRE_LOGIN:
|
||||
return None
|
||||
raise
|
||||
|
||||
# 使用自定义的依赖项
|
||||
optional_oauth2_scheme = OptionalOAuth2PasswordBearer(tokenUrl="login")
|
||||
|
||||
# 创建访问令牌
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
# 验证令牌
|
||||
async def verify_token(token: Optional[str] = Depends(optional_oauth2_scheme)):
|
||||
# 如果未设置密码,则不需要验证
|
||||
if not REQUIRE_LOGIN:
|
||||
return "guest"
|
||||
|
||||
# 如果没有token且不需要登录,返回guest
|
||||
if token is None and not REQUIRE_LOGIN:
|
||||
return "guest"
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的认证凭据",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# 如果需要登录但没有token,抛出异常
|
||||
if token is None:
|
||||
raise credentials_exception
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
return username
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
|
||||
# 用户登录接口
|
||||
@app.post("/login")
|
||||
async def login(request: LoginRequest):
|
||||
"""用户登录接口"""
|
||||
# 如果未设置密码,表示不需要登录
|
||||
if not REQUIRE_LOGIN:
|
||||
access_token = create_access_token(data={"sub": "guest"})
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
if request.password != LOGIN_PASSWORD:
|
||||
logger.warning("登录失败:密码错误")
|
||||
raise HTTPException(status_code=401, detail="密码错误")
|
||||
|
||||
# 创建访问令牌
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": "user"}, expires_delta=access_token_expires
|
||||
)
|
||||
logger.info("用户登录成功")
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
# 检查用户认证状态
|
||||
@app.get("/check_auth")
|
||||
async def check_auth(username: str = Depends(verify_token)):
|
||||
"""检查用户认证状态"""
|
||||
return {"authenticated": True, "username": username}
|
||||
|
||||
# 获取系统配置
|
||||
@app.get("/config")
|
||||
async def get_config():
|
||||
"""返回系统配置信息"""
|
||||
config = {
|
||||
'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '',
|
||||
'default_api_url': os.getenv('API_URL', ''),
|
||||
'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'),
|
||||
'default_api_model': os.getenv('API_MODEL', ''),
|
||||
'default_api_timeout': os.getenv('API_TIMEOUT', '60')
|
||||
}
|
||||
return config
|
||||
|
||||
# AI分析股票
|
||||
@app.post("/analyze")
|
||||
async def analyze(request: AnalyzeRequest):
|
||||
async def analyze(request: AnalyzeRequest, username: str = Depends(verify_token)):
|
||||
try:
|
||||
logger.info("开始处理分析请求")
|
||||
stock_codes = request.stock_codes
|
||||
market_type = request.market_type
|
||||
|
||||
# 后端再次去重,确保安全
|
||||
original_count = len(stock_codes)
|
||||
stock_codes = list(dict.fromkeys(stock_codes)) # 保持原有顺序的去重方法
|
||||
if len(stock_codes) < original_count:
|
||||
logger.info(f"后端去重: 从{original_count}个代码中移除了{original_count - len(stock_codes)}个重复项")
|
||||
|
||||
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
|
||||
|
||||
# 获取自定义API配置
|
||||
@@ -122,7 +220,8 @@ async def analyze(request: AnalyzeRequest):
|
||||
stock_code = stock_codes[0].strip()
|
||||
logger.info(f"开始单股流式分析: {stock_code}")
|
||||
|
||||
init_message = f'{{"stream_type": "single", "stock_code": "{stock_code}"}}\n'
|
||||
stock_code_json = json.dumps(stock_code)
|
||||
init_message = f'{{"stream_type": "single", "stock_code": {stock_code_json}}}\n'
|
||||
yield init_message
|
||||
|
||||
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
|
||||
@@ -138,7 +237,8 @@ async def analyze(request: AnalyzeRequest):
|
||||
# 批量分析流式处理
|
||||
logger.info(f"开始批量流式分析: {stock_codes}")
|
||||
|
||||
init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes}}}\n'
|
||||
stock_codes_json = json.dumps(stock_codes)
|
||||
init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes_json}}}\n'
|
||||
yield init_message
|
||||
|
||||
logger.debug(f"开始处理批量股票的流式响应")
|
||||
@@ -165,8 +265,9 @@ async def analyze(request: AnalyzeRequest):
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 搜索美股代码
|
||||
@app.get("/search_us_stocks")
|
||||
async def search_us_stocks(keyword: str = ""):
|
||||
async def search_us_stocks(keyword: str = "", username: str = Depends(verify_token)):
|
||||
try:
|
||||
if not keyword:
|
||||
raise HTTPException(status_code=400, detail="请输入搜索关键词")
|
||||
@@ -179,8 +280,9 @@ async def search_us_stocks(keyword: str = ""):
|
||||
logger.error(f"搜索美股代码时出错: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 搜索基金代码
|
||||
@app.get("/search_funds")
|
||||
async def search_funds(keyword: str = "", market_type: str = ""):
|
||||
async def search_funds(keyword: str = "", market_type: str = "", username: str = Depends(verify_token)):
|
||||
try:
|
||||
if not keyword:
|
||||
raise HTTPException(status_code=400, detail="请输入搜索关键词")
|
||||
@@ -193,8 +295,39 @@ async def search_funds(keyword: str = "", market_type: str = ""):
|
||||
logger.error(f"搜索基金代码时出错: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 获取美股详情
|
||||
@app.get("/us_stock_detail/{symbol}")
|
||||
async def get_us_stock_detail(symbol: str, username: str = Depends(verify_token)):
|
||||
try:
|
||||
if not symbol:
|
||||
raise HTTPException(status_code=400, detail="请提供股票代码")
|
||||
|
||||
# 使用异步服务获取详情
|
||||
detail = await us_stock_service.get_us_stock_detail(symbol)
|
||||
return detail
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股详情时出错: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 获取基金详情
|
||||
@app.get("/fund_detail/{symbol}")
|
||||
async def get_fund_detail(symbol: str, market_type: str = "ETF", username: str = Depends(verify_token)):
|
||||
try:
|
||||
if not symbol:
|
||||
raise HTTPException(status_code=400, detail="请提供基金代码")
|
||||
|
||||
# 使用异步服务获取详情
|
||||
detail = await fund_service.get_fund_detail(symbol, market_type)
|
||||
return detail
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取基金详情时出错: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 测试API连接
|
||||
@app.post("/test_api_connection")
|
||||
async def test_api_connection(request: TestAPIRequest):
|
||||
async def test_api_connection(request: TestAPIRequest, username: str = Depends(verify_token)):
|
||||
"""测试API连接"""
|
||||
try:
|
||||
logger.info("开始测试API连接")
|
||||
@@ -226,7 +359,7 @@ async def test_api_connection(request: TestAPIRequest):
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": api_model or "gpt-3.5-turbo",
|
||||
"model": api_model or "",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
|
||||
],
|
||||
@@ -261,36 +394,34 @@ async def test_api_connection(request: TestAPIRequest):
|
||||
content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
|
||||
)
|
||||
|
||||
# 新增 API 端点:获取美股详情
|
||||
@app.get("/us_stock_detail/{symbol}")
|
||||
async def get_us_stock_detail(symbol: str):
|
||||
try:
|
||||
if not symbol:
|
||||
raise HTTPException(status_code=400, detail="请提供股票代码")
|
||||
|
||||
# 使用异步服务获取详情
|
||||
detail = await us_stock_service.get_us_stock_detail(symbol)
|
||||
return detail
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股详情时出错: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
# 检查是否需要登录
|
||||
@app.get("/need_login")
|
||||
async def need_login():
|
||||
"""检查是否需要登录"""
|
||||
return {"require_login": REQUIRE_LOGIN}
|
||||
|
||||
# 新增 API 端点:获取基金详情
|
||||
@app.get("/fund_detail/{symbol}")
|
||||
async def get_fund_detail(symbol: str, market_type: str = "ETF"):
|
||||
try:
|
||||
if not symbol:
|
||||
raise HTTPException(status_code=400, detail="请提供基金代码")
|
||||
|
||||
# 使用异步服务获取详情
|
||||
detail = await fund_service.get_fund_detail(symbol, market_type)
|
||||
return detail
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取基金详情时出错: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
# 前端路由处理,必须放在所有API路由之后
|
||||
@app.get("/{full_path:path}")
|
||||
async def serve_frontend(full_path: str, request: Request):
|
||||
"""处理所有前端路由请求,返回index.html"""
|
||||
# 排除API路径和静态资源
|
||||
if full_path.startswith(("api/", "assets/", "docs", "openapi.json")) or \
|
||||
full_path in ["check_auth", "config", "analyze",
|
||||
"search_us_stocks", "search_funds",
|
||||
"test_api_connection", "us_stock_detail",
|
||||
"fund_detail"]:
|
||||
# 对于API路径,让FastAPI继续处理
|
||||
raise HTTPException(status_code=404, detail="API路径不存在")
|
||||
|
||||
# 检查是否使用前端构建版本
|
||||
if os.path.exists(frontend_dist):
|
||||
index_file = os.path.join(frontend_dist, 'index.html')
|
||||
return FileResponse(index_file)
|
||||
else:
|
||||
# 不再使用模板渲染,而是重定向到API文档页面
|
||||
logger.warning("前端构建目录不存在,重定向到API文档页面")
|
||||
return RedirectResponse(url="/docs")
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.info("股票分析系统启动")
|
||||
logger.info("股票AI分析系统启动")
|
||||
uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)
|
||||
Reference in New Issue
Block a user