66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""OpenAI 兼容型 LLM 工具封装,保持与旧项目一致的接口体验。"""
|
|
|
|
import os
|
|
from dataclasses import asdict, dataclass
|
|
from typing import AsyncGenerator, Dict, List, Optional
|
|
|
|
from openai import AsyncOpenAI
|
|
|
|
|
|
@dataclass
class ChatMessage:
    """One chat message: a role and its text content."""

    # Who the message is attributed to; stored and serialized unchanged.
    role: str
    # Text body of the message.
    content: str

    def to_dict(self) -> Dict[str, str]:
        """Return a new plain dict holding this message's dataclass fields."""
        return asdict(self)
|
|
|
|
|
|
class LLMClient:
    """Async streaming chat wrapper built on the OpenAI SDK's ``AsyncOpenAI`` client."""

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Create the underlying ``AsyncOpenAI`` client.

        Args:
            api_key: API key; when empty or omitted, the ``OPENAI_API_KEY``
                environment variable is used instead.
            base_url: API endpoint; when empty or omitted, the
                ``OPENAI_API_BASE`` environment variable is used (may be unset).

        Raises:
            ValueError: If neither ``api_key`` nor ``OPENAI_API_KEY`` provides
                a non-empty key.
        """
        key = api_key or os.environ.get("OPENAI_API_KEY")
        if not key:
            raise ValueError("缺少 OPENAI_API_KEY 配置,请在数据库或环境变量中补全。")

        self._client = AsyncOpenAI(
            api_key=key,
            base_url=base_url or os.environ.get("OPENAI_API_BASE"),
        )

    async def stream_chat(
        self,
        messages: List["ChatMessage"],
        model: Optional[str] = None,
        response_format: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        timeout: int = 120,
        **kwargs,
    ) -> AsyncGenerator[Dict[str, Optional[str]], None]:
        """Stream a chat completion and yield one dict per received choice chunk.

        Args:
            messages: Conversation history; each item is serialized with
                its ``to_dict()`` method.
            model: Model name; falls back to the ``MODEL`` environment
                variable, then to ``"gpt-3.5-turbo"``.
            response_format: When truthy, sent as ``{"type": response_format}``.
            temperature: Sent only when not ``None``.
            top_p: Sent only when not ``None``.
            max_tokens: Sent only when not ``None``.
            timeout: Forwarded as the request's ``timeout`` argument.
            **kwargs: Extra arguments forwarded to
                ``chat.completions.create``. A ``stream`` entry is ignored
                because this method always streams.

        Yields:
            ``{"content": ..., "finish_reason": ...}`` taken from the first
            choice of each chunk; either value may be ``None``. Chunks
            without choices are skipped.
        """
        payload = {
            **kwargs,
            "model": model or os.environ.get("MODEL", "gpt-3.5-turbo"),
            "messages": [msg.to_dict() for msg in messages],
            "timeout": timeout,
            # Set after **kwargs on purpose: a caller-supplied stream=False
            # would make create() return a non-iterable completion and break
            # the `async for` below.
            "stream": True,
        }
        if response_format:
            payload["response_format"] = {"type": response_format}

        # Only forward sampling options that were explicitly provided, so the
        # server-side defaults apply otherwise.
        optional_params = {
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }
        payload.update(
            {name: value for name, value in optional_params.items() if value is not None}
        )

        stream = await self._client.chat.completions.create(**payload)
        async for chunk in stream:
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            yield {
                # getattr also covers a chunk whose delta is None.
                "content": getattr(choice.delta, "content", None),
                "finish_reason": choice.finish_reason,
            }
|