feat: 优化前端显示&修复若干bug
This commit is contained in:
@@ -231,35 +231,83 @@ class AIAnalyzer:
|
||||
|
||||
async for chunk in response.aiter_text():
|
||||
if chunk:
|
||||
chunk_str = chunk.strip()
|
||||
if chunk_str.startswith("data: "):
|
||||
chunk_str = chunk_str[6:] # 去除"data: "前缀
|
||||
|
||||
if chunk_str == "[DONE]":
|
||||
logger.debug("收到流结束标记 [DONE]")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 解析数据块
|
||||
chunk_data = json.loads(chunk_str)
|
||||
delta = chunk_data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
|
||||
if content:
|
||||
chunk_count += 1
|
||||
buffer += content
|
||||
collected_messages.append(content)
|
||||
# 分割多行响应(处理某些API可能在一个chunk中返回多行)
|
||||
lines = chunk.strip().split('\n')
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 直接发送每个内容片段,不累积
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"ai_analysis_chunk": content,
|
||||
"status": "analyzing"
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
# 忽略无法解析的块
|
||||
logger.error(f"JSON解析错误,块内容: {chunk_str[:100]}...")
|
||||
continue
|
||||
# 处理以data:开头的行
|
||||
if line.startswith("data: "):
|
||||
line = line[6:] # 去除"data: "前缀
|
||||
|
||||
if line == "[DONE]":
|
||||
logger.debug("收到流结束标记 [DONE]")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 处理特殊错误情况
|
||||
if "error" in line.lower():
|
||||
error_msg = line
|
||||
try:
|
||||
error_data = json.loads(line)
|
||||
error_msg = error_data.get("error", line)
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(f"流式响应中收到错误: {error_msg}")
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"error": f"流式响应错误: {error_msg}",
|
||||
"status": "error"
|
||||
})
|
||||
continue
|
||||
|
||||
# 尝试解析JSON
|
||||
chunk_data = json.loads(line)
|
||||
|
||||
# 检查是否有finish_reason
|
||||
finish_reason = chunk_data.get("choices", [{}])[0].get("finish_reason")
|
||||
if finish_reason == "stop":
|
||||
logger.debug("收到finish_reason=stop,流结束")
|
||||
continue
|
||||
|
||||
# 获取delta内容
|
||||
delta = chunk_data.get("choices", [{}])[0].get("delta", {})
|
||||
|
||||
# 检查delta是否为空对象
|
||||
if not delta or delta == {}:
|
||||
logger.debug("收到空的delta对象,跳过")
|
||||
continue
|
||||
|
||||
content = delta.get("content", "")
|
||||
|
||||
if content:
|
||||
chunk_count += 1
|
||||
buffer += content
|
||||
collected_messages.append(content)
|
||||
|
||||
# 直接发送每个内容片段,不累积
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"ai_analysis_chunk": content,
|
||||
"status": "analyzing"
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
# 记录解析错误并尝试恢复
|
||||
logger.error(f"JSON解析错误,块内容: {line}")
|
||||
|
||||
# 如果是特定错误模式,处理它
|
||||
if "streaming failed after retries" in line.lower():
|
||||
logger.error("检测到流式传输失败")
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"error": "流式传输失败,请稍后重试",
|
||||
"status": "error"
|
||||
})
|
||||
return
|
||||
continue
|
||||
|
||||
logger.info(f"AI流式处理完成,共收到 {chunk_count} 个内容片段,总长度: {len(buffer)}")
|
||||
|
||||
|
||||
@@ -56,8 +56,11 @@ class FundServiceAsync:
|
||||
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0,
|
||||
'total_value': float(row['total_value']) if pd.notna(row['total_value']) else 0.0,
|
||||
})
|
||||
# 限制只返回前10个结果
|
||||
if len(formatted_results) >= 10:
|
||||
break
|
||||
|
||||
logger.info(f"基金搜索完成,找到 {len(formatted_results)} 个匹配项")
|
||||
logger.info(f"基金搜索完成,找到 {len(formatted_results)} 个匹配项(限制显示前10个)")
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -60,6 +60,30 @@ class StockAnalyzerService:
|
||||
# 获取股票数据
|
||||
df = await self.data_provider.get_stock_data(stock_code, market_type)
|
||||
|
||||
# 检查是否有错误
|
||||
if hasattr(df, 'error'):
|
||||
error_msg = df.error
|
||||
logger.error(f"获取股票数据时出错: {error_msg}")
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"market_type": market_type,
|
||||
"error": error_msg,
|
||||
"status": "error"
|
||||
})
|
||||
return
|
||||
|
||||
# 检查数据是否为空
|
||||
if df.empty:
|
||||
error_msg = f"获取到的股票 {stock_code} 数据为空"
|
||||
logger.error(error_msg)
|
||||
yield json.dumps({
|
||||
"stock_code": stock_code,
|
||||
"market_type": market_type,
|
||||
"error": error_msg,
|
||||
"status": "error"
|
||||
})
|
||||
return
|
||||
|
||||
# 计算技术指标
|
||||
df_with_indicators = self.indicator.calculate_indicators(df)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from logger import get_logger
|
||||
import re
|
||||
|
||||
# 获取日志器
|
||||
logger = get_logger()
|
||||
@@ -57,27 +58,16 @@ class StockDataProvider:
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
# 确保日期格式统一(移除可能的'-'符号)
|
||||
if isinstance(start_date, str) and '-' in start_date:
|
||||
start_date = start_date.replace('-', '')
|
||||
if isinstance(end_date, str) and '-' in end_date:
|
||||
end_date = end_date.replace('-', '')
|
||||
|
||||
try:
|
||||
# 验证股票代码格式
|
||||
if market_type == 'A':
|
||||
# 上海证券交易所股票代码以6开头
|
||||
# 深圳证券交易所股票代码以0或3开头
|
||||
# 科创板股票代码以688开头
|
||||
# 北京证券交易所股票代码以8开头
|
||||
valid_prefixes = ['0', '3', '6', '688', '8']
|
||||
valid_format = False
|
||||
|
||||
for prefix in valid_prefixes:
|
||||
if stock_code.startswith(prefix):
|
||||
valid_format = True
|
||||
break
|
||||
|
||||
if not valid_format:
|
||||
error_msg = f"无效的A股股票代码格式: {stock_code}。A股代码应以0、3、6、688或8开头"
|
||||
logger.error(f"[股票代码格式错误] {error_msg}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.debug(f"获取A股数据: {stock_code}")
|
||||
|
||||
df = ak.stock_zh_a_hist(
|
||||
symbol=stock_code,
|
||||
start_date=start_date,
|
||||
@@ -96,13 +86,72 @@ class StockDataProvider:
|
||||
|
||||
elif market_type in ['US']:
|
||||
logger.debug(f"获取美股数据: {stock_code}")
|
||||
df = ak.stock_us_daily(
|
||||
symbol=stock_code,
|
||||
adjust="qfq"
|
||||
)
|
||||
# 过滤日期
|
||||
df = df[(df.index >= start_date) & (df.index <= end_date)]
|
||||
try:
|
||||
df = ak.stock_us_daily(
|
||||
symbol=stock_code,
|
||||
adjust="qfq"
|
||||
)
|
||||
logger.debug(f"美股数据原始列: {df.columns.tolist()}")
|
||||
logger.debug(f"美股数据形状: {df.shape}")
|
||||
|
||||
# 确保索引是日期时间类型
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
# 如果存在命名为'date'的列,将其设为索引
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df.set_index('date', inplace=True)
|
||||
logger.debug("已将'date'列设置为索引")
|
||||
else:
|
||||
# 否则将当前索引转换为日期类型
|
||||
df.index = pd.to_datetime(df.index)
|
||||
logger.debug("已将索引转换为DatetimeIndex")
|
||||
|
||||
# 计算美股的成交额(Amount)= 成交量(Volume)× 收盘价(Close)
|
||||
volume_col = next((col for col in df.columns if col.lower() == 'volume'), None)
|
||||
close_col = next((col for col in df.columns if col.lower() == 'close'), None)
|
||||
|
||||
if volume_col and close_col:
|
||||
df['amount'] = df[volume_col] * df[close_col]
|
||||
logger.debug("已为美股数据计算成交额(amount)字段")
|
||||
else:
|
||||
logger.warning(f"美股数据缺少volume或close列,无法计算amount。当前列: {df.columns.tolist()}")
|
||||
# 添加空的amount列,避免后续处理错误
|
||||
df['amount'] = 0.0
|
||||
|
||||
# 将所有列名转为小写以进行统一处理
|
||||
df.columns = [col.lower() for col in df.columns]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取美股数据失败 {stock_code}: {str(e)}")
|
||||
raise ValueError(f"获取美股数据失败 {stock_code}: {str(e)}")
|
||||
|
||||
# 将字符串日期转换为日期时间对象进行比较
|
||||
try:
|
||||
# 尝试多种格式解析日期
|
||||
# 如果日期是数字格式(20220101),使用适当的格式
|
||||
if start_date.isdigit() and len(start_date) == 8:
|
||||
start_date_dt = pd.to_datetime(start_date, format='%Y%m%d')
|
||||
else:
|
||||
# 否则让pandas自动推断格式
|
||||
start_date_dt = pd.to_datetime(start_date)
|
||||
|
||||
if end_date.isdigit() and len(end_date) == 8:
|
||||
end_date_dt = pd.to_datetime(end_date, format='%Y%m%d')
|
||||
else:
|
||||
end_date_dt = pd.to_datetime(end_date)
|
||||
except Exception as e:
|
||||
logger.warning(f"日期转换出错: {str(e)},使用默认值")
|
||||
# 如果转换失败,使用合理的默认值
|
||||
start_date_dt = pd.to_datetime('20000101', format='%Y%m%d')
|
||||
end_date_dt = pd.to_datetime(datetime.now().strftime('%Y%m%d'), format='%Y%m%d')
|
||||
|
||||
# 过滤日期
|
||||
try:
|
||||
df = df[(df.index >= start_date_dt) & (df.index <= end_date_dt)]
|
||||
logger.debug(f"日期过滤后数据点数: {len(df)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"日期过滤出错: {str(e)},返回原始数据")
|
||||
|
||||
elif market_type in ['ETF', 'LOF']:
|
||||
logger.debug(f"获取{market_type}基金数据: {stock_code}")
|
||||
df = ak.fund_etf_hist_sina(
|
||||
@@ -122,8 +171,31 @@ class StockDataProvider:
|
||||
# 实际数据列:['日期', '股票代码', '开盘', '收盘', '最高', '最低', '成交量', '成交额', '振幅', '涨跌幅', '涨跌额', '换手率']
|
||||
df.columns = ['Date', 'Code', 'Open', 'Close', 'High', 'Low', 'Volume', 'Amount', 'Amplitude', 'Change_pct', 'Change', 'Turnover']
|
||||
elif market_type in ['HK', 'US']:
|
||||
# 根据实际情况调整
|
||||
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
|
||||
# 美股数据列可能不同,需要通过映射处理
|
||||
columns_mapping = {
|
||||
'open': 'Open',
|
||||
'high': 'High',
|
||||
'low': 'Low',
|
||||
'close': 'Close',
|
||||
'volume': 'Volume',
|
||||
'amount': 'Amount'
|
||||
}
|
||||
|
||||
# 创建新的DataFrame以确保列顺序和存在性
|
||||
new_df = pd.DataFrame(index=df.index)
|
||||
|
||||
# 遍历映射,填充新DataFrame
|
||||
for orig_col, new_col in columns_mapping.items():
|
||||
if orig_col in df.columns:
|
||||
new_df[new_col] = df[orig_col]
|
||||
else:
|
||||
# 如果原始列不存在,创建一个填充0的列
|
||||
logger.warning(f"数据中缺少{orig_col}列,使用0值填充")
|
||||
new_df[new_col] = 0.0
|
||||
|
||||
# 替换原始df
|
||||
df = new_df
|
||||
|
||||
elif market_type in ['ETF', 'LOF']:
|
||||
# 基金数据可能有不同的列
|
||||
df.columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
|
||||
@@ -143,7 +215,11 @@ class StockDataProvider:
|
||||
error_msg = f"获取{market_type}数据失败 {stock_code}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.exception(e)
|
||||
raise Exception(error_msg)
|
||||
# 使用空的DataFrame并添加错误信息,而不是抛出异常
|
||||
# 这样上层调用者可以检查是否有错误并适当处理
|
||||
df = pd.DataFrame()
|
||||
df.error = error_msg # 添加错误属性
|
||||
return df
|
||||
|
||||
async def get_multiple_stocks_data(self, stock_codes: List[str],
|
||||
market_type: str = 'A',
|
||||
@@ -181,4 +257,4 @@ class StockDataProvider:
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# 构建结果字典,过滤掉失败的请求
|
||||
return {code: df for code, df in results if df is not None}
|
||||
return {code: df for code, df in results if df is not None}
|
||||
@@ -49,8 +49,11 @@ class USStockServiceAsync:
|
||||
'price': float(row['price']) if pd.notna(row['price']) else 0.0,
|
||||
'market_value': float(row['market_value']) if pd.notna(row['market_value']) else 0.0
|
||||
})
|
||||
# 限制只返回前10个结果
|
||||
if len(formatted_results) >= 10:
|
||||
break
|
||||
|
||||
logger.info(f"美股搜索完成,找到 {len(formatted_results)} 个匹配项")
|
||||
logger.info(f"美股搜索完成,找到 {len(formatted_results)} 个匹配项(限制显示前10个)")
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user