Merge branch 'pr-1' into dev

# Conflicts:
#	stock_analyzer.py
This commit is contained in:
兰志宏
2025-03-04 16:13:50 +08:00
7 changed files with 1799 additions and 807 deletions

View File

@@ -1,4 +1,4 @@
# 使用 Python 3.9 作为基础镜像
# 使用 Python 3.10 作为基础镜像
FROM python:3.10-slim
# 设置工作目录
@@ -15,7 +15,6 @@ COPY . /app/
# 安装 Python 依赖
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install akshare --upgrade -i https://pypi.org/simple
# 设置环境变量
ENV PYTHONPATH=/app

58
logger.py Normal file
View File

@@ -0,0 +1,58 @@
from loguru import logger
import sys
import os
from datetime import datetime
# 获取当前时间作为日志文件名的一部分
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
# 创建日志目录
log_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
os.makedirs(log_dir, exist_ok=True)
# 配置日志
logger.remove() # 移除默认的处理器
# 添加标准输出处理器(控制台)
logger.add(
sys.stdout,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
level="DEBUG"
)
# 添加文件处理器debug级别
logger.add(
os.path.join(log_dir, f"debug_{current_time}.log"),
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
level="DEBUG",
rotation="100 MB",
retention="1 week"
)
# 添加文件处理器error级别
logger.add(
os.path.join(log_dir, f"error_{current_time}.log"),
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{line} - {message}",
level="ERROR",
rotation="100 MB",
retention="1 month"
)
# 添加流处理器(用于记录流式输出)
logger.add(
os.path.join(log_dir, f"stream_{current_time}.log"),
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {message}",
filter=lambda record: "STREAM" in record["extra"],
level="INFO"
)
# 创建专用于流式输出的日志器
stream_logger = logger.bind(STREAM=True)
def get_logger():
"""获取通用日志器"""
return logger
def get_stream_logger():
"""获取流式输出专用日志器"""
return stream_logger

View File

@@ -1,10 +1,12 @@
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
# 基础科学计算和数据处理库
numpy==2.1.2
pandas==2.2.2
scipy==1.15.1
# 数据获取和分析库
akshare
akshare==1.16.22
tqdm==4.67.1

View File

@@ -3,19 +3,28 @@ import numpy as np
from datetime import datetime, timedelta
import os
import requests
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Generator
from dotenv import load_dotenv
import json
from logger import get_logger, get_stream_logger
# 获取日志器
logger = get_logger()
stream_logger = get_stream_logger()
class StockAnalyzer:
def __init__(self, initial_cash=1000000):
def __init__(self, initial_cash=1000000, custom_api_url=None, custom_api_key=None, custom_api_model=None, custom_api_timeout=60):
# 加载环境变量
load_dotenv()
# 设置 API
self.API_URL = os.getenv('API_URL')
self.API_KEY = os.getenv('API_KEY')
self.API_TIMEOUT = int(os.getenv('API_TIMEOUT', '60'))
# 设置 API 配置,优先使用自定义配置,否则使用环境变量
self.API_URL = custom_api_url or os.getenv('API_URL')
self.API_KEY = custom_api_key or os.getenv('API_KEY')
self.API_TIMEOUT = custom_api_timeout or int(os.getenv('API_TIMEOUT', '60'))
self.API_MODEL = custom_api_model or os.getenv('API_MODEL', 'gpt-3.5-turbo')
logger.debug(f"初始化StockAnalyzer: API_URL={self.API_URL}, API_MODEL={self.API_MODEL}, API_KEY={'已提供' if self.API_KEY else '未提供'}")
# 配置参数
self.params = {
@@ -205,9 +214,10 @@ class StockAnalyzer:
print(f"计算评分时出错: {str(e)}")
raise
def get_ai_analysis(self, df, stock_code):
"""使用 Gemini 进行 AI 分析"""
def get_ai_analysis(self, df, stock_code, stream=False):
"""使用 OpenAI 进行 AI 分析"""
try:
logger.info(f"开始AI分析股票 {stock_code}, 流式模式: {stream}")
recent_data = df.tail(14).to_dict('records')
technical_summary = {
@@ -237,43 +247,207 @@ class StockAnalyzer:
请基于技术指标和市场动态进行分析,给出具体数据支持。
"""
logger.debug(f"生成的AI分析提示词: {prompt[:100]}...")
# 检查API配置
if not self.API_URL:
error_msg = "API URL未配置无法进行AI分析"
logger.error(error_msg)
return error_msg if not stream else (yield json.dumps({"error": error_msg}))
if not self.API_KEY:
error_msg = "API Key未配置无法进行AI分析"
logger.error(error_msg)
return error_msg if not stream else (yield json.dumps({"error": error_msg}))
# 标准化API URL
if self.API_URL.endswith('/'):
api_url = f"{self.API_URL}chat/completions"
else:
api_url = f"{self.API_URL}/v1/chat/completions"
# 标准化API URL
# api_url = self.API_URL
# if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
# if api_url.endswith('/v1'):
# api_url = f"{api_url}/chat/completions"
# elif api_url.endswith('/'):
# api_url = f"{api_url}v1/chat/completions"
# else:
# api_url = f"{api_url}/v1/chat/completions"
logger.debug(f"标准化后的API URL: {api_url}")
# 构建请求头和请求体
headers = {
"Authorization": f"Bearer {self.API_KEY}",
"Content-Type": "application/json"
}
data = {
"model": os.getenv('API_MODEL'),
payload = {
"model": self.API_MODEL,
"messages": [{"role": "user", "content": prompt}]
}
if self.API_URL.endswith('/'):
api_url = f"{self.API_URL}chat/completions"
else:
api_url = f"{self.API_URL}/v1/chat/completions"
# 流式处理设置
if stream:
logger.debug(f"配置流式参数使用API URL: {api_url}")
payload["stream"] = True # 明确设置stream参数为True
try:
logger.debug(f"发起流式API请求: {api_url}")
logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
response = requests.post(
api_url,
headers=headers,
json=data,
timeout=self.API_TIMEOUT
json=payload,
timeout=60, # 增加超时时间
stream=True
)
print(api_url)
print(data)
print(response.json())
logger.debug(f"API流式响应状态码: {response.status_code}")
if response.status_code == 200:
return response.json()['choices'][0]['message']['content']
logger.info(f"成功获取API流式响应开始处理")
yield from self._process_ai_stream(response, stock_code)
else:
return "AI 分析暂时无法使用"
try:
error_response = response.json()
error_text = json.dumps(error_response, indent=2)
except:
error_text = response.text[:500] if response.text else "无响应内容"
error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
logger.error(error_msg)
yield json.dumps({"stock_code": stock_code, "error": error_msg})
except Exception as e:
print(f"AI 分析发生错误: {str(e)}")
return "AI 分析过程中发生错误"
error_msg = f"流式API请求异常: {str(e)}"
logger.error(error_msg)
logger.exception(e)
yield json.dumps({"stock_code": stock_code, "error": error_msg})
else:
# 非流式处理
logger.debug(f"发起非流式API请求: {api_url}")
try:
response = requests.post(
api_url,
headers=headers,
json=payload,
timeout=60
)
logger.debug(f"API非流式响应状态码: {response.status_code}")
if response.status_code == 200:
api_response = response.json()
content = api_response['choices'][0]['message']['content']
logger.info(f"成功获取AI分析结果长度: {len(content)}")
logger.debug(f"AI分析结果前100字符: {content[:100]}...")
return content
else:
try:
error_response = response.json()
error_text = json.dumps(error_response, indent=2)
except:
error_text = response.text[:500] if response.text else "无响应内容"
error_msg = f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}"
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"非流式API请求异常: {str(e)}"
logger.error(error_msg)
logger.exception(e)
return error_msg
except Exception as e:
error_msg = f"AI 分析过程中发生错误: {str(e)}"
logger.error(error_msg)
logger.exception(e)
if stream:
logger.debug("在流式模式下返回异常信息")
error_json = json.dumps({"stock_code": stock_code, "error": error_msg})
stream_logger.info(f"流式异常输出: {error_json}")
yield error_json
else:
return error_msg
def _process_ai_stream(self, response, stock_code) -> Generator[str, None, None]:
"""处理AI流式响应"""
logger.info(f"开始处理股票 {stock_code} 的AI流式响应")
buffer = ""
chunk_count = 0
try:
for line in response.iter_lines():
if line:
line = line.decode('utf-8')
stream_logger.info(f"原始流式行: {line}")
# 跳过保持连接的空行
if line.strip() == '':
logger.debug("跳过空行")
continue
# 数据行通常以"data: "开头
if line.startswith('data: '):
data_content = line[6:] # 移除 "data: " 前缀
stream_logger.info(f"数据内容: {data_content}")
# 检查是否为流的结束
if data_content.strip() == '[DONE]':
logger.debug("收到流结束标记 [DONE]")
break
try:
json_data = json.loads(data_content)
logger.debug(f"解析的JSON数据: {json.dumps(json_data)[:100]}...")
if 'choices' in json_data:
delta = json_data['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
chunk_count += 1
buffer += content
logger.debug(f"收到内容片段 #{chunk_count}: {content}")
stream_logger.info(f"发送内容片段: {content}")
# 创建包含AI分析片段的JSON
chunk_json = json.dumps({
"stock_code": stock_code,
"ai_analysis_chunk": content
})
stream_logger.info(f"流式输出JSON: {chunk_json}")
yield chunk_json
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {str(e)}, 行内容: {data_content}")
# 忽略无法解析的JSON
pass
else:
logger.warning(f"收到非'data:'开头的行: {line}")
logger.info(f"AI流式处理完成共收到 {chunk_count} 个内容片段,总长度: {len(buffer)}")
# 如果buffer不为空最后一次发送完整内容
if buffer and not buffer.endswith('\n'):
logger.debug("发送换行符")
yield json.dumps({"stock_code": stock_code, "ai_analysis_chunk": "\n"})
except Exception as e:
error_msg = f"处理AI流式响应时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
yield json.dumps({"stock_code": stock_code, "error": error_msg})
def get_recommendation(self, score):
"""根据得分给出建议"""
logger.debug(f"根据评分 {score} 生成投资建议")
if score >= 80:
return '强烈推荐买入'
elif score >= 60:
@@ -285,17 +459,23 @@ class StockAnalyzer:
else:
return '强烈建议卖出'
def analyze_stock(self, stock_code, market_type='A'):
def analyze_stock(self, stock_code, market_type='A', stream=False):
"""分析单个股票"""
try:
logger.info(f"开始分析股票: {stock_code}, 市场: {market_type}, 流式模式: {stream}")
# 获取股票数据
logger.debug(f"获取股票 {stock_code} 数据")
df = self.get_stock_data(stock_code, market_type)
# 计算技术指标
logger.debug(f"计算股票 {stock_code} 技术指标")
df = self.calculate_indicators(df)
# 评分系统
logger.debug(f"计算股票 {stock_code} 评分")
score = self.calculate_score(df)
logger.info(f"股票 {stock_code} 评分结果: {score}")
# 获取最新数据
latest = df.iloc[-1]
@@ -312,29 +492,92 @@ class StockAnalyzer:
'rsi': latest['RSI'],
'macd_signal': 'BUY' if latest['MACD'] > latest['Signal'] else 'SELL',
'volume_status': 'HIGH' if latest['Volume_Ratio'] > 1.5 else 'NORMAL',
'recommendation': self.get_recommendation(score),
'ai_analysis': self.get_ai_analysis(df, stock_code)
'recommendation': self.get_recommendation(score)
}
logger.debug(f"生成股票 {stock_code} 基础报告: {json.dumps(report)[:100]}...")
if stream:
logger.info(f"以流式模式返回股票 {stock_code} 分析结果")
# 先返回基本报告结构
base_report = dict(report)
base_report['ai_analysis'] = ''
base_report_json = json.dumps(base_report)
logger.debug(f"基础报告JSON: {base_report_json[:100]}...")
stream_logger.info(f"发送基础报告: {base_report_json}")
yield base_report_json
# 然后流式返回AI分析部分
logger.debug(f"开始获取股票 {stock_code} 的流式AI分析")
ai_chunks_count = 0
for ai_chunk in self.get_ai_analysis(df, stock_code, stream=True):
ai_chunks_count += 1
stream_logger.info(f"股票 {stock_code} 流式块 #{ai_chunks_count}: {ai_chunk}")
yield ai_chunk
logger.info(f"股票 {stock_code} 流式AI分析完成共发送 {ai_chunks_count} 个块")
else:
logger.info(f"以非流式模式返回股票 {stock_code} 分析结果")
logger.debug(f"开始获取股票 {stock_code} 的AI分析")
report['ai_analysis'] = self.get_ai_analysis(df, stock_code)
logger.debug(f"AI分析结果长度: {len(report['ai_analysis'])}")
return report
except Exception as e:
print(f"分析股票时出错: {str(e)}")
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
if stream:
error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
stream_logger.info(f"流式错误输出: {error_json}")
yield error_json
else:
raise
def scan_market(self, stock_list, min_score=60, market_type='A'):
def scan_market(self, stock_list, min_score=60, market_type='A', stream=False):
"""扫描市场,寻找符合条件的股票"""
logger.info(f"开始扫描市场,股票数量: {len(stock_list)}, 最低分数: {min_score}, 市场: {market_type}, 流式模式: {stream}")
if not stream:
recommendations = []
for stock_code in stock_list:
try:
logger.debug(f"分析股票: {stock_code}")
report = self.analyze_stock(stock_code, market_type)
if report['score'] >= min_score:
logger.info(f"股票 {stock_code} 评分 {report['score']} >= {min_score},添加到推荐列表")
recommendations.append(report)
else:
logger.debug(f"股票 {stock_code} 评分 {report['score']} < {min_score},不添加到推荐列表")
except Exception as e:
print(f"分析股票 {stock_code} 时出错: {str(e)}")
logger.error(f"分析股票 {stock_code} 时出错: {str(e)}")
logger.exception(e)
continue
# 按得分排序
recommendations.sort(key=lambda x: x['score'], reverse=True)
logger.info(f"扫描完成,找到 {len(recommendations)} 个推荐股票")
return recommendations
else:
# 流式处理每个股票
logger.info(f"开始流式扫描 {len(stock_list)} 只股票")
stock_count = 0
for stock_code in stock_list:
stock_count += 1
logger.debug(f"流式分析股票 {stock_code} ({stock_count}/{len(stock_list)})")
try:
# 分析单只股票并获取流式结果
chunk_count = 0
for chunk in self.analyze_stock(stock_code, market_type, stream=True):
chunk_count += 1
stream_logger.info(f"股票 {stock_code} 流式块 #{chunk_count}: {chunk}")
yield chunk
logger.debug(f"股票 {stock_code} 流式分析完成,共 {chunk_count} 个块")
except Exception as e:
error_msg = f"分析股票 {stock_code} 时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
error_json = json.dumps({'stock_code': stock_code, 'error': error_msg})
stream_logger.info(f"流式错误输出: {error_json}")
yield error_json
logger.info(f"流式扫描完成,处理了 {stock_count} 只股票")

View File

@@ -30,11 +30,65 @@
</div>
{% endif %}
<div class="max-w-4xl mx-auto"> <!-- 将 max-w-2xl 改为 max-w-4xl -->
<div class="max-w-4xl mx-auto">
<!-- 批量分析 -->
<div class="bg-white p-6 rounded-lg shadow-md">
<h2 class="text-xl font-semibold mb-4">股票批量分析</h2>
<!-- API配置部分 -->
<div class="mb-6 border-b pb-6">
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-medium text-gray-700">API配置</h3>
<button id="toggleApiConfig" class="text-blue-600 hover:text-blue-800 text-sm flex items-center">
<span id="toggleApiConfigText">显示配置</span>
<svg id="toggleApiConfigIcon" class="w-4 h-4 ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7"></path>
</svg>
</button>
</div>
<div id="apiConfigPanel" class="hidden space-y-4">
<div class="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>
<label for="apiUrl" class="block text-sm font-medium text-gray-700 mb-1">API URL</label>
<input type="text" id="apiUrl"
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
placeholder="例如: https://api.openai.com"
value="{{ default_api_url }}">
</div>
<div>
<label for="apiModel" class="block text-sm font-medium text-gray-700 mb-1">API 模型</label>
<input type="text" id="apiModel"
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
placeholder="例如: gpt-3.5-turbo"
value="{{ default_api_model }}">
</div>
<div>
<label for="apiTimeout" class="block text-sm font-medium text-gray-700 mb-1">API 超时时间</label>
<input type="text" id="apiTimeout"
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
placeholder="例如: 60"
value="{{ default_api_timeout }}">
</div>
</div>
<div>
<label for="apiKey" class="block text-sm font-medium text-gray-700 mb-1">API Key</label>
<input type="password" id="apiKey"
class="w-full p-2 border rounded bg-white focus:ring-2 focus:ring-blue-500 focus:border-blue-500"
placeholder="输入您的API Key">
<p class="mt-1 text-sm text-gray-500">如不填写,将使用系统默认配置</p>
</div>
<div class="flex justify-end">
<button id="resetApiConfig" class="text-gray-600 hover:text-gray-800 text-sm mr-3">
重置为默认
</button>
<button id="testApiConfig" class="bg-blue-100 text-blue-700 px-3 py-1 rounded hover:bg-blue-200 text-sm">
测试连接
</button>
</div>
</div>
</div>
<!-- 添加市场类型选择 -->
<div class="mb-4">
<label for="marketType" class="block text-sm font-medium text-gray-700 mb-2">
@@ -281,6 +335,7 @@
</script>
<script>
let isAnalyzing = false;
let stockAnalysisData = {}; // 存储股票分析数据的对象
async function analyzeStocks() {
if (isAnalyzing) return; // 防止重复点击
@@ -289,6 +344,12 @@
const marketType = document.getElementById('marketType').value;
const analyzeBtn = document.getElementById('analyzeBtn');
const loadingSpinner = document.getElementById('loadingSpinner');
const resultContent = document.getElementById('resultContent');
// 获取API配置
const apiUrl = document.getElementById('apiUrl').value.trim();
const apiKey = document.getElementById('apiKey').value.trim();
const apiModel = document.getElementById('apiModel').value.trim();
if (!stockInput) {
alert('请输入代码');
@@ -305,6 +366,16 @@
loadingSpinner.classList.remove('hidden');
analyzeBtn.querySelector('span').textContent = '分析中...';
// 清空现有结果并初始化分析数据
resultContent.innerHTML = '';
stockAnalysisData = {};
// 创建结果容器
const resultsContainer = document.createElement('div');
resultsContainer.className = 'space-y-6';
resultContent.appendChild(resultsContainer);
// 使用fetch流式API
const response = await fetch('/analyze', {
method: 'POST',
headers: {
@@ -312,21 +383,61 @@
},
body: JSON.stringify({
stock_codes: stockCodes,
market_type: marketType // 添加市场类型参数
market_type: marketType,
api_url: apiUrl,
api_key: apiKey,
api_model: apiModel
api_timeout: apiTimeout
})
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.error || '分析失败');
const errorData = await response.json();
throw new Error(errorData.error || '分析失败');
}
// 设置流式处理
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
// 持续读取数据流
while (true) {
const { done, value } = await reader.read();
if (done) break;
// 解码接收到的数据并添加到缓冲区
buffer += decoder.decode(value, { stream: true });
// 处理完整的行
const lines = buffer.split('\n');
buffer = lines.pop() || ''; // 最后一行可能不完整,保留到下一次处理
for (const line of lines) {
if (line.trim() === '') continue;
try {
const chunk = JSON.parse(line);
handleStreamChunk(chunk, resultsContainer, marketType);
} catch (e) {
console.error('解析流数据出错:', e, line);
}
}
}
// 处理可能遗留在缓冲区的最后一行
if (buffer.trim()) {
try {
const chunk = JSON.parse(buffer);
handleStreamChunk(chunk, resultsContainer, marketType);
} catch (e) {
console.error('解析最后一行数据出错:', e, buffer);
}
}
const results = Array.isArray(data.results) ? data.results : [data];
displayResults(results);
} catch (error) {
alert('请求失败: ' + error.message);
document.getElementById('resultContent').innerHTML = `
console.error('请求失败:', error);
resultContent.innerHTML = `
<div class="p-4 bg-red-50 text-red-600 rounded">
分析出错:${error.message}
</div>
@@ -339,18 +450,72 @@
}
}
function displayResults(results) {
const resultContent = document.getElementById('resultContent');
if (!results || results.length === 0) {
resultContent.innerHTML = '<div class="p-6 bg-yellow-50 text-yellow-600 rounded-lg text-center">没有分析结果</div>';
// 处理流式数据的函数
function handleStreamChunk(chunk, container, marketType) {
// 处理初始化信息
if (chunk.stream_type) {
console.log('开始流式分析:', chunk);
return;
}
let html = '';
results.forEach(result => {
// 获取股票代码
const stockCode = chunk.stock_code;
// 如果是错误信息
if (chunk.error) {
// 添加或更新显示错误的卡片
let errorCard = document.getElementById(`error-${stockCode}`);
if (!errorCard) {
errorCard = document.createElement('div');
errorCard.id = `error-${stockCode}`;
errorCard.className = 'bg-red-50 p-4 rounded-lg text-red-600';
errorCard.innerHTML = `分析股票 ${stockCode} 出错: ${chunk.error}`;
container.appendChild(errorCard);
} else {
errorCard.innerHTML = `分析股票 ${stockCode} 出错: ${chunk.error}`;
}
return;
}
// 如果是基本报告结构
if (!chunk.ai_analysis_chunk) {
// 存储基本报告数据
stockAnalysisData[stockCode] = {
...chunk,
ai_analysis: ''
};
// 添加或更新股票卡片
createStockCard(stockCode, container, marketType);
return;
}
// 如果是AI分析片段
if (chunk.ai_analysis_chunk) {
// 确保该股票的数据存在
if (!stockAnalysisData[stockCode]) {
stockAnalysisData[stockCode] = {
stock_code: stockCode,
ai_analysis: ''
};
}
// 累加AI分析内容
stockAnalysisData[stockCode].ai_analysis += chunk.ai_analysis_chunk;
// 更新AI分析显示
updateAIAnalysisDisplay(stockCode);
}
}
// 创建股票卡片
function createStockCard(stockCode, container, marketType) {
const result = stockAnalysisData[stockCode];
if (!result) return;
// 根据市场类型设置货币符号
const currencySymbol = (() => {
switch(document.getElementById('marketType').value) {
switch(marketType) {
case 'US':
return '$';
case 'HK':
@@ -361,8 +526,17 @@
}
})();
html += `
<div class="bg-white rounded-lg shadow-lg overflow-hidden">
// 检查是否已存在该股票的卡片
let stockCard = document.getElementById(`stock-card-${stockCode}`);
if (!stockCard) {
stockCard = document.createElement('div');
stockCard.id = `stock-card-${stockCode}`;
stockCard.className = 'bg-white rounded-lg shadow-lg overflow-hidden';
container.appendChild(stockCard);
}
stockCard.innerHTML = `
<!-- 头部信息 -->
<div class="bg-gradient-to-r from-blue-600 to-blue-700 px-6 py-4">
<h3 class="text-xl font-bold text-white">
@@ -408,8 +582,18 @@
<!-- AI分析部分 -->
<div class="mt-6">
<h4 class="text-lg font-semibold text-gray-800 mb-3">AI分析</h4>
<div class="prose prose-blue max-w-none bg-gray-50 p-4 rounded-lg">
${marked.parse(result.ai_analysis)}
<div id="ai-analysis-${stockCode}" class="prose prose-blue max-w-none bg-gray-50 p-4 rounded-lg relative">
<!-- 加载动画 -->
<div class="ai-analysis-loading flex flex-col items-center justify-center py-8">
<div class="typing-animation mb-3">
<span></span>
<span></span>
<span></span>
</div>
<p class="text-gray-500 text-sm">AI正在思考分析中...</p>
</div>
<!-- 实际内容容器 -->
<div class="ai-analysis-content hidden"></div>
</div>
</div>
@@ -421,14 +605,189 @@
</div>
</div>
</div>
</div>
`;
}
// 更新AI分析显示
function updateAIAnalysisDisplay(stockCode) {
const analysisElement = document.getElementById(`ai-analysis-${stockCode}`);
if (analysisElement && stockAnalysisData[stockCode]) {
const loadingElement = analysisElement.querySelector('.ai-analysis-loading');
const contentElement = analysisElement.querySelector('.ai-analysis-content');
// 如果有AI分析内容
if (stockAnalysisData[stockCode].ai_analysis) {
// 解析Markdown
const parsedContent = marked.parse(stockAnalysisData[stockCode].ai_analysis);
// 检查是否是第一次添加内容
const isFirstUpdate = contentElement.classList.contains('hidden');
// 如果是第一次更新,显示内容区域并隐藏加载动画
if (isFirstUpdate) {
contentElement.innerHTML = parsedContent;
contentElement.classList.remove('hidden');
contentElement.classList.add('fade-in');
// 延迟隐藏加载动画,使过渡更平滑
setTimeout(() => {
loadingElement.style.display = 'none';
}, 300);
} else {
// 获取当前内容长度,用于确定新增内容
const currentLength = contentElement.textContent.length;
// 更新内容
contentElement.innerHTML = parsedContent;
// 尝试高亮新增的文本(通过比较长度)
const allTextNodes = getAllTextNodes(contentElement);
let totalLength = 0;
for (const node of allTextNodes) {
totalLength += node.textContent.length;
if (totalLength > currentLength) {
// 这个节点包含新内容将其包装在高亮span中
const newTextSpan = document.createElement('span');
newTextSpan.className = 'new-text';
node.parentNode.insertBefore(newTextSpan, node);
newTextSpan.appendChild(node);
break;
}
}
}
}
}
}
// 辅助函数:获取元素内的所有文本节点
function getAllTextNodes(element) {
const textNodes = [];
const walker = document.createTreeWalker(element, NodeFilter.SHOW_TEXT, null, false);
let node;
while (node = walker.nextNode()) {
textNodes.push(node);
}
return textNodes;
}
// 旧的displayResults函数保留用于兼容
function displayResults(results) {
const resultContent = document.getElementById('resultContent');
// 清空现有结果
resultContent.innerHTML = '';
stockAnalysisData = {};
// 创建结果容器
const resultsContainer = document.createElement('div');
resultsContainer.className = 'space-y-6';
resultContent.appendChild(resultsContainer);
if (!results || results.length === 0) {
resultsContainer.innerHTML = '<div class="p-6 bg-yellow-50 text-yellow-600 rounded-lg text-center">没有分析结果</div>';
return;
}
// 获取市场类型
const marketType = document.getElementById('marketType').value;
// 处理每个结果
results.forEach(result => {
stockAnalysisData[result.stock_code] = result;
createStockCard(result.stock_code, resultsContainer, marketType);
updateAIAnalysisDisplay(result.stock_code);
});
resultContent.innerHTML = html;
// 添加 Markdown 样式
addMarkdownStyles();
}
</script>
<!-- 添加 marked.js 用于解析 Markdown -->
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<!-- API配置相关脚本 -->
<script>
document.addEventListener('DOMContentLoaded', function() {
// API配置面板切换
const toggleBtn = document.getElementById('toggleApiConfig');
const configPanel = document.getElementById('apiConfigPanel');
const toggleText = document.getElementById('toggleApiConfigText');
const toggleIcon = document.getElementById('toggleApiConfigIcon');
toggleBtn.addEventListener('click', function() {
const isHidden = configPanel.classList.contains('hidden');
configPanel.classList.toggle('hidden', !isHidden);
toggleText.textContent = isHidden ? '隐藏配置' : '显示配置';
toggleIcon.style.transform = isHidden ? 'rotate(180deg)' : '';
});
// 重置API配置
document.getElementById('resetApiConfig').addEventListener('click', function() {
document.getElementById('apiUrl').value = '{{ default_api_url }}';
document.getElementById('apiModel').value = '{{ default_api_model }}';
document.getElementById('apiKey').value = '';
});
// 测试API连接
document.getElementById('testApiConfig').addEventListener('click', async function() {
const apiUrl = document.getElementById('apiUrl').value.trim();
const apiKey = document.getElementById('apiKey').value.trim();
const apiModel = document.getElementById('apiModel').value.trim();
if (!apiUrl) {
alert('请输入API URL');
return;
}
if (!apiKey) {
alert('请输入API Key');
return;
}
this.textContent = '测试中...';
this.disabled = true;
try {
// 使用后端代理进行API测试
const response = await fetch('/test_api_connection', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
api_url: apiUrl,
api_key: apiKey,
api_model: apiModel
})
});
const data = await response.json();
if (response.ok && data.success) {
alert('API连接成功');
} else {
alert(`API连接失败: ${data.message || '未知错误'}`);
}
} catch (error) {
alert(`API连接测试失败: ${error.message}`);
} finally {
this.textContent = '测试连接';
this.disabled = false;
}
});
});
</script>
<script>
// 添加 Markdown 样式
function addMarkdownStyles() {
// 检查是否已经添加了样式
if (!document.getElementById('markdown-styles')) {
const style = document.createElement('style');
style.id = 'markdown-styles';
style.textContent = `
.prose h1 { font-size: 1.5em; margin-top: 1em; margin-bottom: 0.5em; font-weight: bold; }
.prose h2 { font-size: 1.3em; margin-top: 1em; margin-bottom: 0.5em; font-weight: bold; }
@@ -441,11 +800,75 @@
.prose em { font-style: italic; }
.prose blockquote { border-left: 4px solid #e5e7eb; padding-left: 1em; margin: 1em 0; color: #4b5563; }
.prose code { background-color: #f3f4f6; padding: 0.2em 0.4em; border-radius: 0.25em; font-size: 0.9em; }
/* 打字机动画样式 */
.typing-animation {
display: flex;
align-items: center;
}
.typing-animation span {
height: 10px;
width: 10px;
margin: 0 2px;
background-color: #3b82f6;
border-radius: 50%;
display: inline-block;
animation: typing 1.5s infinite ease-in-out;
}
.typing-animation span:nth-child(1) {
animation-delay: 0s;
}
.typing-animation span:nth-child(2) {
animation-delay: 0.3s;
}
.typing-animation span:nth-child(3) {
animation-delay: 0.6s;
}
@keyframes typing {
0% { transform: scale(1); opacity: 0.7; }
50% { transform: scale(1.5); opacity: 1; }
100% { transform: scale(1); opacity: 0.7; }
}
/* 内容淡入效果 */
.ai-analysis-content {
transition: opacity 0.3s ease;
}
.ai-analysis-content.fade-in {
opacity: 0;
animation: fadeIn 0.5s forwards;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
/* 高亮新增文本效果 */
.new-text {
background-color: rgba(59, 130, 246, 0.1);
animation: highlightFade 2s forwards;
}
@keyframes highlightFade {
from { background-color: rgba(59, 130, 246, 0.1); }
to { background-color: transparent; }
}
`;
document.head.appendChild(style);
}
}
// 页面加载时添加样式
document.addEventListener('DOMContentLoaded', function() {
addMarkdownStyles();
});
</script>
<!-- 添加 marked.js 用于解析 Markdown -->
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
</body>
</html>

137
tests/test_stream.py Normal file
View File

@@ -0,0 +1,137 @@
import os
import requests
import json
from logger import get_logger, get_stream_logger
from dotenv import load_dotenv
# 获取日志器
logger = get_logger()
stream_logger = get_stream_logger()
def test_api_stream():
"""
测试API流式响应功能
"""
# 加载环境变量
load_dotenv()
# 获取API配置
api_url = os.getenv('API_URL')
api_key = os.getenv('API_KEY')
api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
logger.info(f"开始测试API流式响应API URL: {api_url}, MODEL: {api_model}")
# 检查API配置
if not api_url:
logger.error("API URL未配置无法进行测试")
return
if not api_key:
logger.error("API Key未配置无法进行测试")
return
# 标准化API URL
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
if api_url.endswith('/v1'):
api_url = f"{api_url}/chat/completions"
elif api_url.endswith('/'):
api_url = f"{api_url}v1/chat/completions"
else:
api_url = f"{api_url}/v1/chat/completions"
logger.debug(f"标准化后的API URL: {api_url}")
# 构建简单的测试提示
prompt = "这是一个API流式响应测试。请给出一个简短的股票分析样例。"
# 构建请求头和请求体
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = {
"model": api_model,
"messages": [{"role": "user", "content": prompt}],
"stream": True # 明确设置stream参数为True
}
logger.debug(f"请求载荷: {json.dumps(payload, indent=2)}")
try:
logger.info(f"发起流式API请求: {api_url}")
response = requests.post(
api_url,
headers=headers,
json=payload,
timeout=60,
stream=True
)
logger.info(f"API流式响应状态码: {response.status_code}")
logger.debug(f"响应头: {response.headers}")
if response.status_code == 200:
logger.info("成功获取API流式响应开始处理")
buffer = ""
chunk_count = 0
for line in response.iter_lines():
if line:
line_str = line.decode('utf-8')
logger.info(f"原始流式行: {line_str}")
# 跳过保持连接的空行
if line_str.strip() == '':
logger.debug("跳过空行")
continue
# 数据行通常以"data: "开头
if line_str.startswith('data: '):
data_content = line_str[6:].strip() # 移除 "data: " 前缀并去除前后空格
logger.info(f"数据内容: {data_content}")
# 检查是否为流的结束
if data_content == '[DONE]':
logger.info("收到流结束标记 [DONE]")
break
try:
# 解析JSON数据
json_data = json.loads(data_content)
logger.debug(f"JSON结构: {json.dumps(json_data, indent=2)}")
if 'choices' in json_data:
delta = json_data['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
chunk_count += 1
buffer += content
logger.info(f"内容片段 #{chunk_count}: {content}")
except json.JSONDecodeError as e:
logger.error(f"JSON解析错误: {e}, 内容: {data_content}")
else:
logger.warning(f"收到非'data:'开头的行: {line_str}")
logger.info(f"流式处理完成,共收到 {chunk_count} 个内容片段")
logger.info(f"完整内容:\n{buffer}")
else:
try:
error_response = response.json()
error_text = json.dumps(error_response, indent=2)
except:
error_text = response.text[:500] if response.text else "无响应内容"
logger.error(f"API请求失败: 状态码 {response.status_code}, 响应: {error_text}")
except Exception as e:
logger.error(f"测试过程中发生异常: {str(e)}")
logger.exception(e)
if __name__ == "__main__":
test_api_stream()

View File

@@ -1,8 +1,15 @@
from flask import Flask, render_template, request, jsonify
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
from stock_analyzer import StockAnalyzer
from us_stock_service import USStockService
import threading
import os
import traceback
import requests
from logger import get_logger, get_stream_logger
# 获取日志器
logger = get_logger()
stream_logger = get_stream_logger()
app = Flask(__name__)
analyzer = StockAnalyzer()
@@ -11,27 +18,86 @@ us_stock_service = USStockService()
@app.route('/')
def index():
announcement = os.getenv('ANNOUNCEMENT_TEXT') or None
return render_template('index.html', announcement=announcement)
# 获取默认API配置信息
default_api_url = os.getenv('API_URL', '')
default_api_model = os.getenv('API_MODEL', 'gpt-3.5-turbo')
# 不传递API_KEY到前端出于安全考虑
return render_template('index.html',
announcement=announcement,
default_api_url=default_api_url,
default_api_model=default_api_model)
@app.route('/analyze', methods=['POST'])
def analyze():
try:
logger.info("开始处理分析请求")
data = request.json
stock_codes = data.get('stock_codes', [])
market_type = data.get('market_type', 'A')
logger.debug(f"接收到分析请求: stock_codes={stock_codes}, market_type={market_type}")
# 获取自定义API配置
custom_api_url = data.get('api_url')
custom_api_key = data.get('api_key')
custom_api_model = data.get('api_model')
custom_api_timeout = data.get('api_timeout', 60)
logger.debug(f"自定义API配置: URL={custom_api_url}, 模型={custom_api_model}, API Key={'已提供' if custom_api_key else '未提供'}")
# 创建新的分析器实例,使用自定义配置
custom_analyzer = StockAnalyzer(
custom_api_url=custom_api_url,
custom_api_key=custom_api_key,
custom_api_model=custom_api_model,
custom_api_timeout= custom_api_timeout,
)
if not stock_codes:
logger.warning("未提供股票代码")
return jsonify({'error': '请输入代码'}), 400
results = []
for stock_code in stock_codes:
result = analyzer.analyze_stock(stock_code.strip(), market_type)
results.append(result)
# 使用流式响应
def generate():
if len(stock_codes) == 1:
# 单个股票分析流式处理
stock_code = stock_codes[0].strip()
logger.info(f"开始单股流式分析: {stock_code}")
stream_logger.info(f"初始化单股分析流: {stock_code}")
init_message = f'{{"stream_type": "single", "stock_code": "{stock_code}"}}\n'
stream_logger.info(f"发送初始化消息: {init_message}")
yield init_message
for chunk in custom_analyzer.analyze_stock(stock_code, market_type, stream=True):
stream_logger.info(f"流式输出块: {chunk}")
yield chunk + '\n'
else:
# 批量分析流式处理
logger.info(f"开始批量流式分析: {stock_codes}")
stream_logger.info(f"初始化批量分析流: {stock_codes}")
init_message = f'{{"stream_type": "batch", "stock_codes": {stock_codes}}}\n'
stream_logger.info(f"发送初始化消息: {init_message}")
yield init_message
for chunk in custom_analyzer.scan_market(
[code.strip() for code in stock_codes],
min_score=0,
market_type=market_type,
stream=True
):
stream_logger.info(f"流式输出块: {chunk}")
yield chunk + '\n'
logger.info("成功创建流式响应生成器")
return Response(stream_with_context(generate()), mimetype='application/json')
return jsonify({'results': results})
except Exception as e:
print(f"分析股票时出错: {str(e)}")
return jsonify({'error': str(e)}), 500
error_msg = f"分析股票时出错: {str(e)}"
logger.error(error_msg)
logger.exception(e)
return jsonify({'error': error_msg}), 500
@app.route('/search_us_stocks', methods=['GET'])
def search_us_stocks():
@@ -47,8 +113,72 @@ def search_us_stocks():
print(f"搜索美股代码时出错: {str(e)}")
return jsonify({'error': str(e)}), 500
@app.route('/test_api_connection', methods=['POST'])
def test_api_connection():
"""测试API连接"""
try:
logger.info("开始测试API连接")
data = request.json
api_url = data.get('api_url')
api_key = data.get('api_key')
api_model = data.get('api_model')
logger.debug(f"测试API连接: URL={api_url}, 模型={api_model}, API Key={'已提供' if api_key else '未提供'}")
if not api_url:
logger.warning("未提供API URL")
return jsonify({'error': '请提供API URL'}), 400
if not api_key:
logger.warning("未提供API Key")
return jsonify({'error': '请提供API Key'}), 400
# 构建API URL
test_url = api_url
if not (api_url.endswith('/chat/completions') or api_url.endswith('/v1/chat/completions')):
if api_url.endswith('/v1'):
test_url = f"{api_url}/chat/completions"
elif api_url.endswith('/'):
test_url = f"{api_url}chat/completions"
else:
test_url = f"{api_url}/v1/chat/completions"
logger.debug(f"完整API测试URL: {test_url}")
# 发送测试请求
response = requests.post(
test_url,
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
},
json={
"model": api_model or "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, this is a test message. Please respond with 'API connection successful'."}
],
"max_tokens": 20
},
timeout=10
)
# 检查响应
if response.status_code == 200:
logger.info(f"API连接测试成功: {response.status_code}")
return jsonify({'success': True, 'message': '连接成功'})
else:
error_message = response.json().get('error', {}).get('message', '未知错误')
logger.warning(f"API连接测试失败: {response.status_code} - {error_message}")
return jsonify({'success': False, 'message': f'连接失败: {error_message}', 'status_code': response.status_code}), 400
except requests.exceptions.RequestException as e:
logger.error(f"API连接请求错误: {str(e)}")
return jsonify({'success': False, 'message': f'请求错误: {str(e)}'}), 400
except Exception as e:
logger.error(f"测试API连接时出错: {str(e)}")
logger.exception(e)
return jsonify({'success': False, 'message': f'测试连接时出错: {str(e)}'}), 500
if __name__ == '__main__':
logger.info("股票分析系统启动")
app.run(host='0.0.0.0', port=8888, debug=True)