refactor: 重构代码结构

This commit is contained in:
Cassianvale
2025-03-06 17:11:15 +08:00
parent bcf64f0041
commit 1e53d16b3a
9 changed files with 1380 additions and 38 deletions

View File

@@ -3,10 +3,11 @@ from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, Red
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Generator
from stock_analyzer import StockAnalyzer
from us_stock_service import USStockService
from fund_service import FundService
from typing import List, Optional, Dict, Any, Generator
from services.stock_analyzer_service import StockAnalyzerService
# 导入新的异步服务
from services.us_stock_service_async import USStockServiceAsync
from services.fund_service_async import FundServiceAsync
import asyncio
import threading
import os
@@ -25,7 +26,7 @@ logger = get_logger()
app = FastAPI(
title="Stock Scanner API",
description="股票分析API",
description="异步股票分析API",
version="1.0.0"
)
@@ -43,9 +44,10 @@ frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fronte
if os.path.exists(frontend_dist):
app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
analyzer = StockAnalyzer()
us_stock_service = USStockService()
fund_service = FundService()
# 初始化异步服务
# StockAnalyzerService 不需要全局初始化,在 /analyze 接口中按需创建
us_stock_service = USStockServiceAsync()
fund_service = FundServiceAsync()
# 定义请求和响应模型
class AnalyzeRequest(BaseModel):
@@ -102,7 +104,7 @@ async def analyze(request: AnalyzeRequest):
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}")
# 创建新的分析器实例,使用自定义配置
custom_analyzer = StockAnalyzer(
custom_analyzer = StockAnalyzerService(
custom_api_url=custom_api_url,
custom_api_key=custom_api_key,
custom_api_model=custom_api_model,
@@ -126,17 +128,11 @@ async def analyze(request: AnalyzeRequest):
logger.debug(f"开始处理股票 {stock_code} 的流式响应")
chunk_count = 0
# 使用线程池执行同步分析
def run_analysis():
return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True))
# 在线程中执行同步操作
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, run_analysis)
for chunk in chunks:
# 使用异步生成器
async for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
chunk_count += 1
yield chunk + '\n'
logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块")
else:
# 批量分析流式处理
@@ -148,22 +144,16 @@ async def analyze(request: AnalyzeRequest):
logger.debug(f"开始处理批量股票的流式响应")
chunk_count = 0
# 使用线程池执行同步分析
def run_batch_analysis():
return list(custom_analyzer.scan_stocks(
[code.strip() for code in stock_codes],
min_score=0,
market_type=market_type,
stream=True
))
# 在线程中执行同步操作
loop = asyncio.get_event_loop()
chunks = await loop.run_in_executor(None, run_batch_analysis)
for chunk in chunks:
# 使用异步生成器
async for chunk in custom_analyzer.scan_stocks(
[code.strip() for code in stock_codes],
min_score=0,
market_type=market_type,
stream=True
):
chunk_count += 1
yield chunk + '\n'
logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块")
logger.info("成功创建流式响应生成器")
@@ -181,9 +171,8 @@ async def search_us_stocks(keyword: str = ""):
if not keyword:
raise HTTPException(status_code=400, detail="请输入搜索关键词")
# 在异步上下文中运行同步的搜索函数
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword)
# 直接使用异步服务的异步方法
results = await us_stock_service.search_us_stocks(keyword)
return {"results": results}
except Exception as e:
@@ -196,9 +185,8 @@ async def search_funds(keyword: str = "", market_type: str = ""):
if not keyword:
raise HTTPException(status_code=400, detail="请输入搜索关键词")
# 在异步上下文中运行同步的搜索函数
loop = asyncio.get_event_loop()
results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type))
# 直接使用异步服务的异步方法
results = await fund_service.search_funds(keyword, market_type)
return {"results": results}
except Exception as e:
@@ -273,6 +261,36 @@ async def test_api_connection(request: TestAPIRequest):
content={"success": False, "message": f"API 测试连接时出错: {str(e)}"}
)
# 新增 API 端点:获取美股详情
@app.get("/us_stock_detail/{symbol}")
async def get_us_stock_detail(symbol: str):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供股票代码")
# 使用异步服务获取详情
detail = await us_stock_service.get_us_stock_detail(symbol)
return detail
except Exception as e:
logger.error(f"获取美股详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# 新增 API 端点:获取基金详情
@app.get("/fund_detail/{symbol}")
async def get_fund_detail(symbol: str, market_type: str = "ETF"):
try:
if not symbol:
raise HTTPException(status_code=400, detail="请提供基金代码")
# 使用异步服务获取详情
detail = await fund_service.get_fund_detail(symbol, market_type)
return detail
except Exception as e:
logger.error(f"获取基金详情时出错: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == '__main__':
logger.info("股票分析系统启动")
uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True)