feat: 优化前端显示&修复若干bug

This commit is contained in:
CaasianVale
2025-03-07 03:33:18 +08:00
parent ff5b820a57
commit 4c115cf325
29 changed files with 3726 additions and 1209 deletions

View File

@@ -2,28 +2,40 @@ from fastapi import FastAPI, Request, Response, Depends, HTTPException, Backgrou
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Generator
from services.stock_analyzer_service import StockAnalyzerService
# 导入新的异步服务
from services.us_stock_service_async import USStockServiceAsync
from services.fund_service_async import FundServiceAsync
import asyncio
import threading
import os
import traceback
import httpx
from logger import get_logger
from utils.api_utils import APIUtils
# 加载环境变量
from dotenv import load_dotenv
from dotenv import load_dotenv, dotenv_values
import uvicorn
import json
import secrets
from datetime import datetime, timedelta
from jose import JWTError, jwt
load_dotenv()
# 获取日志器
logger = get_logger()
# JWT相关配置
SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_hex(32))
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 10080 # Token过期时间一周
LOGIN_PASSWORD = os.getenv("LOGIN_PASSWORD", "")
print(LOGIN_PASSWORD)
# 是否需要登录
REQUIRE_LOGIN = bool(LOGIN_PASSWORD.strip())
app = FastAPI(
title="Stock Scanner API",
description="异步股票分析API",
@@ -42,7 +54,7 @@ app.add_middleware(
# 设置静态文件
frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist')
if os.path.exists(frontend_dist):
app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dist, "assets")), name="assets")
# 初始化异步服务
# StockAnalyzerService 不需要全局初始化,在 /analyze 接口中按需创建
@@ -64,35 +76,121 @@ class TestAPIRequest(BaseModel):
api_model: Optional[str] = None
api_timeout: Optional[int] = 10
@app.get("/")
async def index(request: Request):
# 检查是否使用前端构建版本
if os.path.exists(frontend_dist):
index_file = os.path.join(frontend_dist, 'index.html')
return FileResponse(index_file)
else:
# 不再使用模板渲染而是重定向到API文档页面
logger.warning("前端构建目录不存在重定向到API文档页面")
return RedirectResponse(url="/docs")
class LoginRequest(BaseModel):
password: str
class Token(BaseModel):
access_token: str
token_type: str
# 自定义依赖项在REQUIRE_LOGIN=False时不要求token
class OptionalOAuth2PasswordBearer(OAuth2PasswordBearer):
async def __call__(self, request: Request) -> Optional[str]:
if not REQUIRE_LOGIN:
return None
try:
return await super().__call__(request)
except HTTPException:
if not REQUIRE_LOGIN:
return None
raise
# 使用自定义的依赖项
optional_oauth2_scheme = OptionalOAuth2PasswordBearer(tokenUrl="login")
# 创建访问令牌
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# 验证令牌
async def verify_token(token: Optional[str] = Depends(optional_oauth2_scheme)):
# 如果未设置密码,则不需要验证
if not REQUIRE_LOGIN:
return "guest"
# 如果没有token且不需要登录返回guest
if token is None and not REQUIRE_LOGIN:
return "guest"
credentials_exception = HTTPException(
status_code=401,
detail="无效的认证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
# 如果需要登录但没有token抛出异常
if token is None:
raise credentials_exception
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
return username
except JWTError:
raise credentials_exception
# 用户登录接口
@app.post("/login")
async def login(request: LoginRequest):
"""用户登录接口"""
# 如果未设置密码,表示不需要登录
if not REQUIRE_LOGIN:
access_token = create_access_token(data={"sub": "guest"})
return {"access_token": access_token, "token_type": "bearer"}
if request.password != LOGIN_PASSWORD:
logger.warning("登录失败:密码错误")
raise HTTPException(status_code=401, detail="密码错误")
# 创建访问令牌
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": "user"}, expires_delta=access_token_expires
)
logger.info("用户登录成功")
return {"access_token": access_token, "token_type": "bearer"}
# 检查用户认证状态
@app.get("/check_auth")
async def check_auth(username: str = Depends(verify_token)):
"""检查用户认证状态"""
return {"authenticated": True, "username": username}
# 获取系统配置
@app.get("/config")
async def get_config():
"""返回系统配置信息"""
config = {
'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '',
'default_api_url': os.getenv('API_URL', ''),
'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'),
'default_api_model': os.getenv('API_MODEL', ''),
'default_api_timeout': os.getenv('API_TIMEOUT', '60')
}
return config
# AI分析股票
@app.post("/analyze")
async def analyze(request: AnalyzeRequest):
async def analyze(request: AnalyzeRequest, username: str = Depends(verify_token)):
try:
logger.info("开始处理分析请求")
stock_codes = request.stock_codes
market_type = request.market_type
# 后端再次去重,确保安全
original_count = len(stock_codes)
stock_codes = list(dict.fromkeys(stock_codes)) # 保持原有顺序的去重方法
if len(stock_codes) < original_count:
logger.info(f"后端去重: 从{original_count}个代码中移除了{original_count - len(stock_codes)}个重复项")
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
# 获取自定义API配置
@@ -122,7 +220,8 @@ async def analyze(request: AnalyzeRequest):
stock_code = stock_codes[0].strip()
logger.info(f"开始单股流式分析: {stock_code}")
init_message = f'{{"stream_type": "single", "stock_code": "{stock_code}"}}\n'
stock_code_json = json.dumps(stock_code)
init_message = f'{{"stream_type": "single", "stock_code": {stock_code_json}}}\n'
yield init_message
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
@@ -138,7 +237,8 @@ async def analyze(request: AnalyzeRequest):
# 批量分析流式处理
logger.info(f"开始批量流式分析: {stock_codes}")
init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes}}}\n'
stock_codes_json = json.dumps(stock_codes)
init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes_json}}}\n'
yield init_message
logger.debug(f"开始处理批量股票的流式响应")
@@ -165,8 +265,9 @@ async def analyze(request: AnalyzeRequest):
logger.exception(e)
raise HTTPException(status_code=500, detail=error_msg)
# 搜索美股代码
@app.get("/search_us_stocks")
async def search_us_stocks(keyword: str = ""):
async def search_us_stocks(keyword: str = "", username: str = Depends(verify_token)):
try:
if not keyword:
raise HTTPException(status_code=400, detail="请输入搜索关键词")
@@ -179,8 +280,9 @@ async def search_us_stocks(keyword: str = ""):
logger.error(f"搜索美股代码时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 搜索基金代码
@app.get("/search_funds")
async def search_funds(keyword: str = "", market_type: str = ""):
async def search_funds(keyword: str = "", market_type: str = "", username: str = Depends(verify_token)):
try:
if not keyword:
raise HTTPException(status_code=400, detail="请输入搜索关键词")
@@ -193,8 +295,39 @@ async def search_funds(keyword: str = "", market_type: str = ""):
logger.error(f"搜索基金代码时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 获取美股详情
@app.get("/us_stock_detail/{symbol}")
async def get_us_stock_detail(symbol: str, username: str = Depends(verify_token)):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供股票代码")
# 使用异步服务获取详情
detail = await us_stock_service.get_us_stock_detail(symbol)
return detail
except Exception as e:
logger.error(f"获取美股详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 获取基金详情
@app.get("/fund_detail/{symbol}")
async def get_fund_detail(symbol: str, market_type: str = "ETF", username: str = Depends(verify_token)):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供基金代码")
# 使用异步服务获取详情
detail = await fund_service.get_fund_detail(symbol, market_type)
return detail
except Exception as e:
logger.error(f"获取基金详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 测试API连接
@app.post("/test_api_connection")
async def test_api_connection(request: TestAPIRequest):
async def test_api_connection(request: TestAPIRequest, username: str = Depends(verify_token)):
"""测试API连接"""
try:
logger.info("开始测试API连接")
@@ -226,7 +359,7 @@ async def test_api_connection(request: TestAPIRequest):
"Content-Type": "application/json"
},
json={
"model": api_model or "gpt-3.5-turbo",
"model": api_model or "",
"messages": [
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
],
@@ -261,36 +394,34 @@ async def test_api_connection(request: TestAPIRequest):
content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
)
# 新增 API 端点:获取美股详情
@app.get("/us_stock_detail/{symbol}")
async def get_us_stock_detail(symbol: str):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供股票代码")
# 使用异步服务获取详情
detail = await us_stock_service.get_us_stock_detail(symbol)
return detail
except Exception as e:
logger.error(f"获取美股详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 检查是否需要登录
@app.get("/need_login")
async def need_login():
"""检查是否需要登录"""
return {"require_login": REQUIRE_LOGIN}
# 新增 API 端点:获取基金详情
@app.get("/fund_detail/{symbol}")
async def get_fund_detail(symbol: str, market_type: str = "ETF"):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供基金代码")
# 使用异步服务获取详情
detail = await fund_service.get_fund_detail(symbol, market_type)
return detail
except Exception as e:
logger.error(f"获取基金详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 前端路由处理必须放在所有API路由之后
@app.get("/{full_path:path}")
async def serve_frontend(full_path: str, request: Request):
"""处理所有前端路由请求返回index.html"""
# 排除API路径和静态资源
if full_path.startswith(("api/", "assets/", "docs", "openapi.json")) or \
full_path in ["check_auth", "config", "analyze",
"search_us_stocks", "search_funds",
"test_api_connection", "us_stock_detail",
"fund_detail"]:
# 对于API路径让FastAPI继续处理
raise HTTPException(status_code=404, detail="API路径不存在")
# 检查是否使用前端构建版本
if os.path.exists(frontend_dist):
index_file = os.path.join(frontend_dist, 'index.html')
return FileResponse(index_file)
else:
# 不再使用模板渲染而是重定向到API文档页面
logger.warning("前端构建目录不存在重定向到API文档页面")
return RedirectResponse(url="/docs")
if __name__ == '__main__':
logger.info("股票分析系统启动")
logger.info("股票AI分析系统启动")
uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)