🎉 first commit
This commit is contained in:
412
app/routes/api.chat/chat.server.ts
Normal file
412
app/routes/api.chat/chat.server.ts
Normal file
@@ -0,0 +1,412 @@
|
||||
import type { ActionFunctionArgs } from '@remix-run/node';
|
||||
import {
|
||||
consumeStream,
|
||||
createUIMessageStream,
|
||||
createUIMessageStreamResponse,
|
||||
generateId,
|
||||
type UIMessageStreamWriter,
|
||||
} from 'ai';
|
||||
import { upsertChat } from '~/lib/.server/chat';
|
||||
import { ChatUsageStatus, recordUsage, updateUsageStatus } from '~/lib/.server/chatUsage';
|
||||
import { chatStreamText } from '~/lib/.server/llm/chat-stream-text';
|
||||
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
|
||||
import { createSummary } from '~/lib/.server/llm/create-summary';
|
||||
import { selectContext } from '~/lib/.server/llm/select-context';
|
||||
import { structuredPageSnapshot } from '~/lib/.server/llm/structured-page-snapshot';
|
||||
import { createScopedLogger } from '~/lib/.server/logger';
|
||||
import { getHistoryChatMessages, saveChatMessages, updateDiscardedMessage } from '~/lib/.server/message';
|
||||
import { getPageByMessageId } from '~/lib/.server/page';
|
||||
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
|
||||
import type { Page } from '~/types/actions';
|
||||
import type { UPageUIMessage } from '~/types/message';
|
||||
import { DEFAULT_MODEL, DEFAULT_MODEL_DETAILS, getModel, MINOR_MODEL } from '~/utils/constants';
|
||||
import { approximateUsageFromContent } from '~/utils/token';
|
||||
|
||||
const logger = createScopedLogger('api.chat.chat');
|
||||
|
||||
export type ElementInfo = {
|
||||
tagName: string;
|
||||
className?: string;
|
||||
id?: string;
|
||||
innerHTML?: string;
|
||||
outerHTML?: string;
|
||||
};
|
||||
|
||||
export type ChatActionParams = {
|
||||
// 当前会话 ID
|
||||
chatId: string;
|
||||
// 回退到指定消息 ID
|
||||
rewindTo: string;
|
||||
// 最后一条消息,通常是用户消息。
|
||||
message: UPageUIMessage;
|
||||
// 如果用户指定编辑的元素,则需要传递该元素的信息。
|
||||
elementInfo: ElementInfo;
|
||||
};
|
||||
|
||||
export type ChatActionArgs = ActionFunctionArgs & {
|
||||
userId: string;
|
||||
};
|
||||
|
||||
export async function chatAction({ request, userId }: ChatActionArgs) {
|
||||
const { rewindTo, chatId, message } = await request.json<ChatActionParams>();
|
||||
const chat = await upsertChat({
|
||||
id: chatId,
|
||||
userId,
|
||||
});
|
||||
|
||||
const elementInfo = message.metadata?.elementInfo;
|
||||
const messageId = message.id;
|
||||
const messageContent = message.parts.find((part) => part.type === 'text')?.text;
|
||||
const initialUsageRecord = await recordUsage({
|
||||
userId,
|
||||
chatId: chat.id,
|
||||
messageId,
|
||||
status: ChatUsageStatus.PENDING,
|
||||
prompt: messageContent || '',
|
||||
modelName: DEFAULT_MODEL,
|
||||
});
|
||||
|
||||
const minorModelInitialUsageRecord = await recordUsage({
|
||||
userId,
|
||||
chatId: chat.id,
|
||||
messageId,
|
||||
status: ChatUsageStatus.PENDING,
|
||||
prompt: messageContent || '',
|
||||
modelName: MINOR_MODEL,
|
||||
});
|
||||
|
||||
let streamSwitches = 0;
|
||||
let progressCounter: number = 1;
|
||||
const cumulativeUsage = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0,
|
||||
};
|
||||
const minorModelCumulativeUsage = {
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalTokens: 0,
|
||||
reasoningTokens: 0,
|
||||
cachedInputTokens: 0,
|
||||
};
|
||||
|
||||
// 辅助函数:更新辅助模型使用量
|
||||
const updateMinorModelUsage = (usage: {
|
||||
inputTokens?: number;
|
||||
outputTokens?: number;
|
||||
totalTokens?: number;
|
||||
reasoningTokens?: number;
|
||||
cachedInputTokens?: number;
|
||||
}) => {
|
||||
minorModelCumulativeUsage.inputTokens += usage.inputTokens || 0;
|
||||
minorModelCumulativeUsage.outputTokens += usage.outputTokens || 0;
|
||||
minorModelCumulativeUsage.totalTokens += usage.totalTokens || 0;
|
||||
minorModelCumulativeUsage.reasoningTokens += usage.reasoningTokens || 0;
|
||||
minorModelCumulativeUsage.cachedInputTokens += usage.cachedInputTokens || 0;
|
||||
};
|
||||
|
||||
// 计算用户 token 消耗
|
||||
const calculateTokenUsage = async (status: ChatUsageStatus) => {
|
||||
try {
|
||||
await updateUsageStatus(initialUsageRecord.id, status, {
|
||||
inputTokens: cumulativeUsage.inputTokens,
|
||||
outputTokens: cumulativeUsage.outputTokens,
|
||||
reasoningTokens: cumulativeUsage.reasoningTokens,
|
||||
cachedTokens: cumulativeUsage.cachedInputTokens,
|
||||
totalTokens: cumulativeUsage.totalTokens,
|
||||
});
|
||||
logger.debug(`用户 ${userId} 的聊天: ${chat.id} 总使用量为: ${JSON.stringify(cumulativeUsage)}`);
|
||||
logger.debug(`用户 ${userId} 的聊天: ${chat.id} 使用状态已更新为 ${status}`);
|
||||
} catch (error) {
|
||||
logger.error(`更新用户 ${userId} 的使用状态时出错:`, error);
|
||||
}
|
||||
};
|
||||
|
||||
// 计算用户 token 消耗
|
||||
const calculateMinorModelTokenUsage = async (status: ChatUsageStatus) => {
|
||||
try {
|
||||
await updateUsageStatus(minorModelInitialUsageRecord.id, status, {
|
||||
inputTokens: minorModelCumulativeUsage.inputTokens,
|
||||
outputTokens: minorModelCumulativeUsage.outputTokens,
|
||||
reasoningTokens: minorModelCumulativeUsage.reasoningTokens,
|
||||
cachedTokens: minorModelCumulativeUsage.cachedInputTokens,
|
||||
totalTokens: minorModelCumulativeUsage.totalTokens,
|
||||
});
|
||||
logger.debug(`用户 ${userId} 的聊天: ${chat.id} 辅助模型使用状态已更新为 ${status}`);
|
||||
} catch (error) {
|
||||
logger.error(`更新用户 ${userId} 的辅助模型使用状态时出错:`, error);
|
||||
// 记录错误但不中断流程
|
||||
}
|
||||
};
|
||||
|
||||
const progressId = generateId();
|
||||
// 获取从第一条到当前消息之间的所有消息
|
||||
const previousMessages = await getHistoryChatMessages({
|
||||
chatId,
|
||||
rewindTo,
|
||||
});
|
||||
const messages = [...previousMessages, message];
|
||||
|
||||
const streamExecutor = async ({ writer }: { writer: UIMessageStreamWriter<UPageUIMessage> }) => {
|
||||
// 在消息的开头发送一个固定的消息,用于标识消息的开始。
|
||||
writer.write({
|
||||
type: 'start',
|
||||
messageId: generateId(),
|
||||
});
|
||||
|
||||
// 辅助 model 所获取的数据,用于后续的模型调用。
|
||||
const minorModelData: { summary: string; context: Record<string, string[]>; pageSummary: string } = {
|
||||
summary: '',
|
||||
context: {},
|
||||
pageSummary: '',
|
||||
};
|
||||
|
||||
// 仅当有历史消息时,才调用辅助模型,首次调用无需调用。
|
||||
if (previousMessages.length > 0) {
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'summary',
|
||||
status: 'in-progress',
|
||||
order: progressCounter++,
|
||||
message: '正在分析请求...',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
// 让 AI 分析用户消息摘要,明确用户下一步的意图。
|
||||
const { text: summary, totalUsage: createSummaryUsage } = await createSummary({
|
||||
messages,
|
||||
model: getModel(MINOR_MODEL),
|
||||
abortSignal: request.signal,
|
||||
});
|
||||
minorModelData.summary = summary;
|
||||
updateMinorModelUsage(createSummaryUsage);
|
||||
writer.write({
|
||||
type: 'data-summary',
|
||||
data: {
|
||||
summary,
|
||||
chatId: chat.id,
|
||||
},
|
||||
});
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'summary',
|
||||
status: 'complete',
|
||||
order: progressCounter++,
|
||||
message: '分析完成',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
|
||||
// 获取最后一条历史消息所对应的 page
|
||||
const lastMessage = previousMessages[previousMessages.length - 1];
|
||||
const pageData = await getPageByMessageId(lastMessage.id);
|
||||
if (pageData) {
|
||||
const pages = pageData.pages as unknown as Page[];
|
||||
// 根据用户摘要和所有的页面数据,让 AI 根据摘要、用户消息、页面数据,选择一部分待修改的页面和待修改的 section。
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'context',
|
||||
status: 'in-progress',
|
||||
order: progressCounter++,
|
||||
message: '正在对页面进行分析...',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
const { context, totalUsage: selectContextUsage } = await selectContext({
|
||||
messages,
|
||||
summary,
|
||||
pages,
|
||||
model: getModel(MINOR_MODEL),
|
||||
abortSignal: request.signal,
|
||||
});
|
||||
minorModelData.context = context;
|
||||
updateMinorModelUsage(selectContextUsage);
|
||||
|
||||
// 调用辅助 model 对 context 中的页面做摘要,如果没有,则对所有页面做摘要。
|
||||
const selectPageNames = Object.keys(context);
|
||||
const selectedPages = selectPageNames.length > 0 ? pages : pages.map((page) => page);
|
||||
const { text: pageSummary, totalUsage: structuredPageSnapshotUsage } = await structuredPageSnapshot({
|
||||
pages: selectedPages,
|
||||
model: getModel(MINOR_MODEL),
|
||||
abortSignal: request.signal,
|
||||
});
|
||||
minorModelData.pageSummary = pageSummary;
|
||||
updateMinorModelUsage(structuredPageSnapshotUsage);
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'context',
|
||||
status: 'complete',
|
||||
order: progressCounter++,
|
||||
message: '页面分析完成',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'response',
|
||||
status: 'in-progress',
|
||||
order: progressCounter++,
|
||||
message: '正在生成响应',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
const executeStreamText = async (messages: UPageUIMessage[], isContinue: boolean = false) => {
|
||||
const result = await chatStreamText({
|
||||
messages,
|
||||
elementInfo,
|
||||
summary: minorModelData.summary,
|
||||
pageSummary: minorModelData.pageSummary,
|
||||
context: minorModelData.context,
|
||||
maxTokens: DEFAULT_MODEL_DETAILS?.maxTokenAllowed,
|
||||
model: getModel(DEFAULT_MODEL),
|
||||
abortSignal: request.signal,
|
||||
onFinish: async ({ totalUsage, finishReason, text }) => {
|
||||
cumulativeUsage.inputTokens += totalUsage.inputTokens || 0;
|
||||
cumulativeUsage.outputTokens += totalUsage.outputTokens || 0;
|
||||
cumulativeUsage.totalTokens += totalUsage.totalTokens || 0;
|
||||
cumulativeUsage.reasoningTokens += totalUsage.reasoningTokens || 0;
|
||||
cumulativeUsage.cachedInputTokens += totalUsage.cachedInputTokens || 0;
|
||||
|
||||
if (finishReason === 'length') {
|
||||
if (streamSwitches >= MAX_RESPONSE_SEGMENTS) {
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'response',
|
||||
status: 'stopped',
|
||||
order: progressCounter++,
|
||||
message: '无法继续生成消息:已达到最大分段数',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
writer.write({
|
||||
type: 'finish',
|
||||
});
|
||||
return;
|
||||
}
|
||||
await continueMessage(text);
|
||||
}
|
||||
|
||||
if (finishReason === 'stop') {
|
||||
writer.write({
|
||||
type: 'data-progress',
|
||||
id: progressId,
|
||||
data: {
|
||||
label: 'response',
|
||||
status: 'complete',
|
||||
order: progressCounter++,
|
||||
message: '响应生成完成',
|
||||
},
|
||||
transient: true,
|
||||
});
|
||||
writer.write({
|
||||
type: 'finish',
|
||||
});
|
||||
}
|
||||
},
|
||||
onAbort: async ({ totalUsage }) => {
|
||||
cumulativeUsage.inputTokens += totalUsage.inputTokens || 0;
|
||||
cumulativeUsage.outputTokens += totalUsage.outputTokens || 0;
|
||||
cumulativeUsage.totalTokens += totalUsage.totalTokens || 0;
|
||||
cumulativeUsage.reasoningTokens += totalUsage.reasoningTokens || 0;
|
||||
cumulativeUsage.cachedInputTokens += totalUsage.cachedInputTokens || 0;
|
||||
},
|
||||
});
|
||||
|
||||
const continueMessage = async (text: string) => {
|
||||
logger.info(
|
||||
`达到最大 token 限制 (${MAX_TOKENS}): 继续消息, 还可以响应 (${MAX_RESPONSE_SEGMENTS - streamSwitches} 个分段)`,
|
||||
);
|
||||
messages.push({
|
||||
id: generateId(),
|
||||
role: 'assistant',
|
||||
parts: [
|
||||
{
|
||||
type: 'text',
|
||||
text,
|
||||
},
|
||||
],
|
||||
});
|
||||
messages.push({
|
||||
id: generateId(),
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
type: 'text',
|
||||
text: CONTINUE_PROMPT,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
await executeStreamText(messages, true);
|
||||
streamSwitches++;
|
||||
};
|
||||
|
||||
writer.merge(
|
||||
result.toUIMessageStream({
|
||||
sendReasoning: !isContinue,
|
||||
sendFinish: false,
|
||||
sendStart: false,
|
||||
}),
|
||||
);
|
||||
};
|
||||
await executeStreamText([message], false);
|
||||
};
|
||||
|
||||
const stream = createUIMessageStream<UPageUIMessage>({
|
||||
execute: streamExecutor,
|
||||
originalMessages: messages,
|
||||
onFinish: async ({ messages, isAborted }) => {
|
||||
if (isAborted) {
|
||||
// 由于 AI SDK 没有提供在 onAbort 中计算 Token 消耗的方法。所以这里手动计算。
|
||||
// https://github.com/vercel/ai/pull/8701
|
||||
const lastAssistantMessage = messages.find((message) => message.role === 'assistant');
|
||||
if (lastAssistantMessage) {
|
||||
cumulativeUsage.outputTokens += approximateUsageFromContent(lastAssistantMessage.parts);
|
||||
cumulativeUsage.totalTokens += approximateUsageFromContent(lastAssistantMessage.parts);
|
||||
}
|
||||
}
|
||||
|
||||
// 根据是否中止设置正确的状态
|
||||
// TODO: 在错误情况下,现在还是会被设置为 SUCCESS。
|
||||
const status = isAborted ? ChatUsageStatus.ABORTED : ChatUsageStatus.SUCCESS;
|
||||
calculateTokenUsage(status);
|
||||
calculateMinorModelTokenUsage(status);
|
||||
|
||||
if (isAborted) {
|
||||
logger.info(`用户 ${userId} 的聊天: ${chatId} 中止处理完成`);
|
||||
return;
|
||||
}
|
||||
|
||||
// 保存消息到数据库
|
||||
if (rewindTo) {
|
||||
await updateDiscardedMessage(chatId, rewindTo);
|
||||
}
|
||||
saveChatMessages(chatId, messages);
|
||||
},
|
||||
onError: (error) => {
|
||||
logger.error(`用户 ${userId} 的聊天: ${chatId} 处理过程中发生错误 ===> `, error);
|
||||
calculateTokenUsage(ChatUsageStatus.FAILED);
|
||||
calculateMinorModelTokenUsage(ChatUsageStatus.FAILED);
|
||||
return '内部服务器错误,请稍后重试';
|
||||
},
|
||||
});
|
||||
|
||||
return createUIMessageStreamResponse({ stream, consumeSseStream: consumeStream });
|
||||
}
|
||||
61
app/routes/api.chat/mock-chat.server.ts
Normal file
61
app/routes/api.chat/mock-chat.server.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import type { ActionFunctionArgs } from '@remix-run/node';
|
||||
import { generateId } from 'ai';
|
||||
import { readFile } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { createScopedLogger } from '~/lib/.server/logger';
|
||||
|
||||
const logger = createScopedLogger('api.chat.mock-chat');
|
||||
|
||||
/**
|
||||
* 处理 mock 数据的流式输出,通常用于开发环境下使用。
|
||||
*/
|
||||
export async function mockChat(_args: ActionFunctionArgs, filePath: string = 'mock_stream_text_1.txt') {
|
||||
try {
|
||||
const id = generateId();
|
||||
// 读取 mock 数据文件
|
||||
const mockFilePath = join(process.cwd(), 'mock', filePath);
|
||||
const fileContent = await readFile(mockFilePath, 'utf-8');
|
||||
const lines = fileContent.split('\n').map((line) => {
|
||||
// 替换 messageId 为生成 id,data: {"type":"start","messageId":"uoLyIATGAm28y7rP"}
|
||||
if (line.includes('messageId')) {
|
||||
const startData = JSON.parse(line.replace('data: ', ''));
|
||||
startData.messageId = id;
|
||||
return `data: ${JSON.stringify(startData)}`;
|
||||
}
|
||||
return line;
|
||||
});
|
||||
|
||||
// 创建一个 ReadableStream 来按行输出内容
|
||||
const encoder = new TextEncoder();
|
||||
const stream = new ReadableStream({
|
||||
async start(controller) {
|
||||
// 按行输出内容,每行之间添加延迟
|
||||
for (const line of lines) {
|
||||
if (line.trim() !== '') {
|
||||
controller.enqueue(encoder.encode(`${line}\n\n`));
|
||||
// 添加小延迟,模拟真实的流式输出
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
}
|
||||
}
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
|
||||
// 返回 Response 对象
|
||||
return new Response(stream, {
|
||||
status: 200,
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream; charset=utf-8',
|
||||
Connection: 'keep-alive',
|
||||
'Cache-Control': 'no-cache, no-transform',
|
||||
'X-Accel-Buffering': 'no',
|
||||
},
|
||||
});
|
||||
} catch (error: any) {
|
||||
logger.error(`Mock 数据流式输出错误: ${error}`);
|
||||
throw new Response(`Mock 数据流式输出错误: ${error.message}`, {
|
||||
status: 500,
|
||||
statusText: 'Internal Server Error',
|
||||
});
|
||||
}
|
||||
}
|
||||
25
app/routes/api.chat/route.tsx
Normal file
25
app/routes/api.chat/route.tsx
Normal file
@@ -0,0 +1,25 @@
|
||||
import { type ActionFunctionArgs } from '@remix-run/node';
|
||||
import { requireAuth } from '~/lib/.server/auth';
|
||||
import { errorResponse } from '~/utils/api-response';
|
||||
import { chatAction } from './chat.server';
|
||||
import { mockChat } from './mock-chat.server';
|
||||
|
||||
export async function action(args: ActionFunctionArgs) {
|
||||
const authResult = await requireAuth(args.request, { isApi: true });
|
||||
|
||||
if (authResult instanceof Response) {
|
||||
return authResult;
|
||||
}
|
||||
|
||||
const userId = authResult.userInfo?.sub;
|
||||
if (!userId) {
|
||||
return errorResponse(401, '用户未登录');
|
||||
}
|
||||
|
||||
const useMock = false;
|
||||
if (useMock) {
|
||||
return mockChat(args, 'mock_stream_text_1.txt');
|
||||
}
|
||||
|
||||
return chatAction({ ...args, userId });
|
||||
}
|
||||
Reference in New Issue
Block a user