diff --git a/requirements.txt b/requirements.txt index c77a82b..f758152 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,11 +10,14 @@ scipy==1.15.1 akshare==1.16.22 tqdm==4.67.1 +# Web框架与异步处理 +fastapi==0.115.11 +uvicorn[standard]==0.34.0 +pydantic==2.10.6 +httpx==0.28.1 -# 网络和API请求 -requests==2.32.3 +# 环境配置 python-dotenv==1.0.1 -flask==3.1.0 # 日志和系统工具 loguru==0.7.2 diff --git a/templates/index.html b/templates/index.html deleted file mode 100644 index 0f159c7..0000000 --- a/templates/index.html +++ /dev/null @@ -1,1287 +0,0 @@ - - - - - - 股票分析系统 - - - - - - {% if announcement %} -
-
-
-
- - - -
-
-

-

-
- -
-
-
- - - {% endif %} - -
-

股票分析系统

- - -
-
-
- -
-

当前时间

-

-
- - -
-

A股市场

-

-

-
- - -
-

港股市场

-

-

-
- - -
-

美股市场

-

-

-
-
-
-
- - - -
- -
-

股票批量分析

- - -
-
-

API配置

- -
- - -
- - -
- - -
- - - -
- - - -
- - -
- - - - -
- - -
-
-

分析结果

- -
-
-
- - - - - - - - - - - - - \ No newline at end of file diff --git a/utils/api_control.py b/utils/api_control.py new file mode 100644 index 0000000..ba10342 --- /dev/null +++ b/utils/api_control.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Optional, Any +from fastapi import Response +from pydantic import BaseModel + + +class ResponseModel(BaseModel): + """ + 统一返回模型 + """ + code: int = 200 + msg: str = "Success" + data: Optional[Any] = None + + +class ApiResponse: + @staticmethod + def __response(code: int, msg: str, data: Optional[Any] = None) -> ResponseModel: + return ResponseModel(code=code, msg=msg, data=data) + + @classmethod + def success(cls, *, code: int = 200, msg: str = 'Success', data: Optional[Any] = None) -> Response: + response_model = cls.__response(code=code, msg=msg, data=data) + return cls(content=response_model.model_dump()) + + @classmethod + def fail(cls, *, code: int = 400, msg: str = 'Bad Request', data: Optional[Any] = None) -> Response: + response_model = cls.__response(code=code, msg=msg, data=data) + return cls(content=response_model.model_dump()) + +response_api = ApiResponse() + +""" 示例 +@app.get("/example-success") +async def example_success(): + return response_api.success(data={"key": "value"}) + +@app.get("/example-fail") +async def example_fail(): + return response_api.fail(msg="Something went wrong", data={"error": "details"}) +""" diff --git a/web_server.py b/web_server.py index a14ef1d..c761334 100644 --- a/web_server.py +++ b/web_server.py @@ -1,51 +1,80 @@ -from flask import Flask, render_template, request, jsonify, Response, stream_with_context, send_from_directory +from fastapi import FastAPI, Request, Response, Depends, HTTPException, BackgroundTasks +from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from typing import List, Optional, Generator from stock_analyzer import StockAnalyzer from us_stock_service import USStockService -from fund_service import FundService # 新增导入 +from fund_service import FundService +import asyncio import threading import os import traceback -import requests +import httpx from logger import get_logger from utils.api_utils import APIUtils # 加载环境变量 from dotenv import load_dotenv +import uvicorn load_dotenv() # 获取日志器 logger = get_logger() -app = Flask(__name__, - static_folder='frontend/dist', - static_url_path='/') +app = FastAPI( + title="Stock Scanner API", + description="股票分析API", + version="1.0.0" +) + +# 添加CORS中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 开发环境允许所有来源,生产环境应该限制 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 设置静态文件 +frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist') +if os.path.exists(frontend_dist): + app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend") analyzer = StockAnalyzer() us_stock_service = USStockService() -fund_service = FundService() # 新增服务实例 +fund_service = FundService() -@app.route('/') -def index(): +# 定义请求和响应模型 +class AnalyzeRequest(BaseModel): + stock_codes: List[str] + market_type: str = "A" + api_url: Optional[str] = None + api_key: Optional[str] = None + api_model: Optional[str] = None + api_timeout: Optional[str] = None + +class TestAPIRequest(BaseModel): + api_url: str + api_key: str + api_model: Optional[str] = None + api_timeout: Optional[int] = 10 + +@app.get("/") +async def index(request: Request): # 检查是否使用前端构建版本 - frontend_dist = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'frontend', 'dist') if os.path.exists(frontend_dist): - return send_from_directory(frontend_dist, 'index.html') + index_file = os.path.join(frontend_dist, 'index.html') + return FileResponse(index_file) else: - # 传统模板渲染,用于兼容旧版本 - announcement = os.getenv('ANNOUNCEMENT_TEXT') or None - # 获取默认API配置信息 - default_api_url = os.getenv('API_URL', '') - default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo') - default_api_timeout = os.getenv('API_TIMEOUT', '60') - # 不传递API_KEY到前端,出于安全考虑 - return render_template('index.html', - announcement=announcement, - default_api_url=default_api_url, - default_api_model=default_api_model, - default_api_timeout=default_api_timeout) + # 不再使用模板渲染,而是重定向到API文档页面 + logger.warning("前端构建目录不存在,重定向到API文档页面") + return RedirectResponse(url="/docs") -@app.route('/config') -def get_config(): +@app.get("/config") +async def get_config(): """返回系统配置信息""" config = { 'announcement': os.getenv('ANNOUNCEMENT_TEXT') or '', @@ -53,23 +82,22 @@ def get_config(): 'default_api_model': os.getenv('API_MODEL', 'gpt-3.5-turbo'), 'default_api_timeout': os.getenv('API_TIMEOUT', '60') } - return jsonify(config) + return config -@app.route('/analyze', methods=['POST']) -def analyze(): +@app.post("/analyze") +async def analyze(request: AnalyzeRequest): try: logger.info("开始处理分析请求") - data = request.json - stock_codes = data.get('stock_codes', []) - market_type = data.get('market_type', 'A') + stock_codes = request.stock_codes + market_type = request.market_type logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}") # 获取自定义API配置 - custom_api_url = data.get('api_url') - custom_api_key = data.get('api_key') - custom_api_model = data.get('api_model') - custom_api_timeout = data.get('api_timeout') + custom_api_url = request.api_url + custom_api_key = request.api_key + custom_api_model = request.api_model + custom_api_timeout = request.api_timeout logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}, Timeout={custom_api_timeout}") @@ -83,10 +111,10 @@ def analyze(): if not stock_codes: logger.warning("未提供股票代码") - return jsonify({'error': '请输入代码'}), 400 + raise HTTPException(status_code=400, detail="请输入代码") - # 使用流式响应 - def generate(): + # 定义流式生成器 + async def generate_stream(): if len(stock_codes) == 1: # 单个股票分析流式处理 stock_code = stock_codes[0].strip() @@ -97,7 +125,16 @@ def analyze(): logger.debug(f"开始处理股票 {stock_code} 的流式响应") chunk_count = 0 - for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True): + + # 使用线程池执行同步分析 + def run_analysis(): + return list(custom_analyzer.analyze_stock(stock_code, market_type, stream=True)) + + # 在线程中执行同步操作 + loop = asyncio.get_event_loop() + chunks = await loop.run_in_executor(None, run_analysis) + + for chunk in chunks: chunk_count += 1 yield chunk + '\n' logger.info(f"股票 {stock_code} 流式分析完成,共发送 {chunk_count} 个块") @@ -110,114 +147,132 @@ def analyze(): logger.debug(f"开始处理批量股票的流式响应") chunk_count = 0 - for chunk in custom_analyzer.scan_stocks( - [code.strip() for code in stock_codes], - min_score=0, - market_type=market_type, - stream=True - ): + + # 使用线程池执行同步分析 + def run_batch_analysis(): + return list(custom_analyzer.scan_stocks( + [code.strip() for code in stock_codes], + min_score=0, + market_type=market_type, + stream=True + )) + + # 在线程中执行同步操作 + loop = asyncio.get_event_loop() + chunks = await loop.run_in_executor(None, run_batch_analysis) + + for chunk in chunks: chunk_count += 1 yield chunk + '\n' logger.info(f"批量流式分析完成,共发送 {chunk_count} 个块") logger.info("成功创建流式响应生成器") - return Response(stream_with_context(generate()), mimetype='application/json') + return StreamingResponse(generate_stream(), media_type='application/json') except Exception as e: error_msg = f"分析时出错: {str(e)}" logger.error(error_msg) logger.exception(e) - return jsonify({'error': error_msg}), 500 + raise HTTPException(status_code=500, detail=error_msg) -@app.route('/search_us_stocks', methods=['GET']) -def search_us_stocks(): +@app.get("/search_us_stocks") +async def search_us_stocks(keyword: str = ""): try: - keyword = request.args.get('keyword', '') if not keyword: - return jsonify({'error': '请输入搜索关键词'}), 400 - - results = us_stock_service.search_us_stocks(keyword) - return jsonify({'results': results}) + raise HTTPException(status_code=400, detail="请输入搜索关键词") + + # 在异步上下文中运行同步的搜索函数 + loop = asyncio.get_event_loop() + results = await loop.run_in_executor(None, us_stock_service.search_us_stocks, keyword) + return {"results": results} except Exception as e: - print(f"搜索美股代码时出错: {str(e)}") - return jsonify({'error': str(e)}), 500 + logger.error(f"搜索美股代码时出错: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) -# 添加基金搜索路由 -@app.route('/search_funds', methods=['GET']) -def search_funds(): +@app.get("/search_funds") +async def search_funds(keyword: str = "", market_type: str = ""): try: - keyword = request.args.get('keyword', '') - market_type = request.args.get('market_type', '') if not keyword: - return jsonify({'error': '请输入搜索关键词'}), 400 - - results = fund_service.search_funds(keyword, market_type) - return jsonify({'results': results}) + raise HTTPException(status_code=400, detail="请输入搜索关键词") + + # 在异步上下文中运行同步的搜索函数 + loop = asyncio.get_event_loop() + results = await loop.run_in_executor(None, lambda: fund_service.search_funds(keyword, market_type)) + return {"results": results} except Exception as e: logger.error(f"搜索基金代码时出错: {str(e)}") - return jsonify({'error': str(e)}), 500 + raise HTTPException(status_code=500, detail=str(e)) -@app.route('/test_api_connection', methods=['POST']) -def test_api_connection(): +@app.post("/test_api_connection") +async def test_api_connection(request: TestAPIRequest): """测试API连接""" try: logger.info("开始测试API连接") - data = request.json - api_url = data.get('api_url') - api_key = data.get('api_key') - api_model = data.get('api_model') - api_timeout = data.get('api_timeout', 10) # 默认测试连接超时为10秒 + api_url = request.api_url + api_key = request.api_key + api_model = request.api_model + api_timeout = request.api_timeout logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}, Timeout={api_timeout}") if not api_url: logger.warning("未提供API URL") - return jsonify({'error': '请提供API URL'}), 400 + raise HTTPException(status_code=400, detail="请提供API URL") if not api_key: logger.warning("未提供API Key") - return jsonify({'error': '请提供API Key'}), 400 + raise HTTPException(status_code=400, detail="请提供API Key") # 构建API URL test_url = APIUtils.format_api_url(api_url) logger.debug(f"完整API测试URL: {test_url}") - # 发送测试请求 - response = requests.post( - test_url, - headers={ - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json" - }, - json={ - "model": api_model or "gpt-3.5-turbo", - "messages": [ - {"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."} - ], - "max_tokens": 20 - }, - timeout=int(api_timeout) - ) + # 使用异步HTTP客户端发送测试请求 + async with httpx.AsyncClient(timeout=float(api_timeout)) as client: + response = await client.post( + test_url, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + }, + json={ + "model": api_model or "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."} + ], + "max_tokens": 20 + } + ) # 检查响应 if response.status_code == 200: logger.info(f"API 连接测试成功: {response.status_code}") - return jsonify({'success': True, 'message': 'API 连接测试成功'}) + return {"success": True, "message": "API 连接测试成功"} else: - error_message = response.json().get('error', {}).get('message', '未知错误') + error_data = response.json() + error_message = error_data.get('error', {}).get('message', '未知错误') logger.warning(f"API连接测试失败: {response.status_code} - {error_message}") - return jsonify({'success': False, 'message': f'API 连接测试失败: {error_message}', 'status_code': response.status_code}), 400 + return JSONResponse( + status_code=400, + content={"success": False, "message": f"API 连接测试失败: {error_message}", "status_code": response.status_code} + ) - except requests.exceptions.RequestException as e: + except httpx.RequestError as e: logger.error(f"API 连接请求错误: {str(e)}") - return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400 + return JSONResponse( + status_code=400, + content={"success": False, "message": f"请求错误: {str(e)}"} + ) except Exception as e: logger.error(f"测试 API 连接时出错: {str(e)}") logger.exception(e) - return jsonify({'success': False, 'message': f'API 测试连接时出错: {str(e)}'}), 500 + return JSONResponse( + status_code=500, + content={"success": False, "message": f"API 测试连接时出错: {str(e)}"} + ) if __name__ == '__main__': logger.info("股票分析系统启动") - app.run(host='127.0.0.1', port=8888, debug=True) \ No newline at end of file + uvicorn.run("web_server:app", host="127.0.0.1", port=8888, reload=True) \ No newline at end of file