Added MAX_TOKENS and MAX_RESPONSE_SEGMENTS to .env.example and documentation. Updated constants to read from environment variables, allowing configuration of token and response segment limits.
413 lines
14 KiB
TypeScript
413 lines
14 KiB
TypeScript
import type { ActionFunctionArgs } from '@remix-run/node';
|
|
import {
|
|
consumeStream,
|
|
createUIMessageStream,
|
|
createUIMessageStreamResponse,
|
|
generateId,
|
|
type UIMessageStreamWriter,
|
|
} from 'ai';
|
|
import { upsertChat } from '~/lib/.server/chat';
|
|
import { ChatUsageStatus, recordUsage, updateUsageStatus } from '~/lib/.server/chatUsage';
|
|
import { chatStreamText } from '~/lib/.server/llm/chat-stream-text';
|
|
import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants';
|
|
import { createSummary } from '~/lib/.server/llm/create-summary';
|
|
import { selectContext } from '~/lib/.server/llm/select-context';
|
|
import { structuredPageSnapshot } from '~/lib/.server/llm/structured-page-snapshot';
|
|
import { createScopedLogger } from '~/lib/.server/logger';
|
|
import { getHistoryChatMessages, saveChatMessages, updateDiscardedMessage } from '~/lib/.server/message';
|
|
import { getPageByMessageId } from '~/lib/.server/page';
|
|
import { CONTINUE_PROMPT } from '~/lib/common/prompts/prompts';
|
|
import { DEFAULT_MODEL, DEFAULT_MODEL_DETAILS, getModel, MINOR_MODEL } from '~/lib/modules/constants';
|
|
import type { Page } from '~/types/actions';
|
|
import type { UPageUIMessage } from '~/types/message';
|
|
import { approximateUsageFromContent } from '~/utils/token';
|
|
|
|
const logger = createScopedLogger('api.chat.chat');
|
|
|
|
/**
 * Serializable description of a single markup element.
 *
 * Only `tagName` is required; every other field is optional.
 */
export type ElementInfo = {
  /** Tag name of the element. */
  tagName: string;
  /** Value of the element's class attribute, when provided. */
  className?: string;
  /** Value of the element's id attribute, when provided. */
  id?: string;
  /** Inner markup of the element, when provided. */
  innerHTML?: string;
  /** Outer markup of the element (including the element itself), when provided. */
  outerHTML?: string;
};
|
|
|
|
/**
 * JSON payload expected in the body of a chat action request.
 */
export type ChatActionParams = {
  /** ID of the current chat session. */
  chatId: string;
  /** ID of the message to rewind the conversation to. */
  rewindTo: string;
  /** The latest message, usually the user's message. */
  message: UPageUIMessage;
  /** Information about the element the user chose to edit, if one was specified. */
  elementInfo: ElementInfo;
};
|
|
|
|
/**
 * Arguments accepted by `chatAction`: the standard Remix action arguments
 * extended with the ID of the user issuing the request.
 */
export type ChatActionArgs = ActionFunctionArgs & {
  /** ID of the user on whose behalf the chat is processed. */
  userId: string;
};
|
|
|
|
/**
 * Handles a chat request and streams the assistant response back as a UI message stream.
 *
 * Flow:
 * 1. Upserts the chat and creates two PENDING usage records (main model and minor model).
 * 2. Loads the chat history up to `rewindTo`.
 * 3. When history exists, asks the minor model for a conversation summary and, if a page is
 *    stored for the last history message, for the relevant context and a page snapshot.
 * 4. Streams the main model response. When a segment ends with finish reason `length`, a
 *    continuation segment is requested until `MAX_RESPONSE_SEGMENTS` is reached.
 * 5. When the stream finishes, updates both usage records and, unless the request was aborted,
 *    persists the messages.
 *
 * @param args.request - Incoming request whose JSON body matches `ChatActionParams`.
 * @param args.userId - ID of the user issuing the request.
 * @returns The streaming response produced by `createUIMessageStreamResponse`.
 */
export async function chatAction({ request, userId }: ChatActionArgs) {
  const { rewindTo, chatId, message } = await request.json<ChatActionParams>();
  const chat = await upsertChat({
    id: chatId,
    userId,
  });

  const elementInfo = message.metadata?.elementInfo;
  const messageId = message.id;
  const messageContent = message.parts.find((part) => part.type === 'text')?.text;
  const initialUsageRecord = await recordUsage({
    userId,
    chatId: chat.id,
    messageId,
    status: ChatUsageStatus.PENDING,
    prompt: messageContent || '',
    modelName: DEFAULT_MODEL,
  });

  const minorModelInitialUsageRecord = await recordUsage({
    userId,
    chatId: chat.id,
    messageId,
    status: ChatUsageStatus.PENDING,
    prompt: messageContent || '',
    modelName: MINOR_MODEL,
  });

  type UsageTotals = {
    inputTokens: number;
    outputTokens: number;
    totalTokens: number;
    reasoningTokens: number;
    cachedInputTokens: number;
  };

  const createUsageTotals = (): UsageTotals => ({
    inputTokens: 0,
    outputTokens: 0,
    totalTokens: 0,
    reasoningTokens: 0,
    cachedInputTokens: 0,
  });

  // Number of continuation segments started so far.
  let streamSwitches = 0;
  // Monotonic counter used as the `order` of every progress chunk.
  let progressCounter = 1;
  // Set by onError so that onFinish does not overwrite a FAILED status with SUCCESS.
  let hasFailed = false;
  const cumulativeUsage = createUsageTotals();
  const minorModelCumulativeUsage = createUsageTotals();

  // Adds one usage report to a running total; missing counters count as 0.
  const addUsage = (target: UsageTotals, usage: Partial<UsageTotals>) => {
    target.inputTokens += usage.inputTokens || 0;
    target.outputTokens += usage.outputTokens || 0;
    target.totalTokens += usage.totalTokens || 0;
    target.reasoningTokens += usage.reasoningTokens || 0;
    target.cachedInputTokens += usage.cachedInputTokens || 0;
  };

  // Persists the main model's accumulated token usage together with the given status.
  // Failures are logged and never thrown.
  const calculateTokenUsage = async (status: ChatUsageStatus) => {
    try {
      await updateUsageStatus(initialUsageRecord.id, status, {
        inputTokens: cumulativeUsage.inputTokens,
        outputTokens: cumulativeUsage.outputTokens,
        reasoningTokens: cumulativeUsage.reasoningTokens,
        cachedTokens: cumulativeUsage.cachedInputTokens,
        totalTokens: cumulativeUsage.totalTokens,
      });
      logger.debug(`用户 ${userId} 的聊天: ${chat.id} 总使用量为: ${JSON.stringify(cumulativeUsage)}`);
      logger.debug(`用户 ${userId} 的聊天: ${chat.id} 使用状态已更新为 ${status}`);
    } catch (error) {
      logger.error(`更新用户 ${userId} 的使用状态时出错:`, error);
    }
  };

  // Persists the minor model's accumulated token usage together with the given status.
  // Failures are logged and never thrown, so they do not interrupt the flow.
  const calculateMinorModelTokenUsage = async (status: ChatUsageStatus) => {
    try {
      await updateUsageStatus(minorModelInitialUsageRecord.id, status, {
        inputTokens: minorModelCumulativeUsage.inputTokens,
        outputTokens: minorModelCumulativeUsage.outputTokens,
        reasoningTokens: minorModelCumulativeUsage.reasoningTokens,
        cachedTokens: minorModelCumulativeUsage.cachedInputTokens,
        totalTokens: minorModelCumulativeUsage.totalTokens,
      });
      logger.debug(`用户 ${userId} 的聊天: ${chat.id} 辅助模型使用状态已更新为 ${status}`);
    } catch (error) {
      logger.error(`更新用户 ${userId} 的辅助模型使用状态时出错:`, error);
    }
  };

  const progressId = generateId();
  // Fetch every message from the first one up to the current message.
  const previousMessages = await getHistoryChatMessages({
    chatId,
    rewindTo,
  });
  const messages = [...previousMessages, message];

  const streamExecutor = async ({ writer }: { writer: UIMessageStreamWriter<UPageUIMessage> }) => {
    // Writes one transient progress chunk; all chunks share `progressId` and get an increasing order.
    const writeProgress = (
      label: 'summary' | 'context' | 'response',
      status: 'in-progress' | 'complete' | 'stopped',
      progressMessage: string,
    ) => {
      writer.write({
        type: 'data-progress',
        id: progressId,
        data: {
          label,
          status,
          order: progressCounter++,
          message: progressMessage,
        },
        transient: true,
      });
    };

    // Send a fixed chunk first to mark the beginning of the message.
    writer.write({
      type: 'start',
      messageId: generateId(),
    });

    // Data produced by the minor model, consumed by the main model call below.
    const minorModelData: { summary: string; context: Record<string, string[]>; pageSummary: string } = {
      summary: '',
      context: {},
      pageSummary: '',
    };

    // The minor model is only needed when there is history; the first call skips it.
    if (previousMessages.length > 0) {
      writeProgress('summary', 'in-progress', '正在分析请求...');

      // Let the AI summarize the conversation to clarify the user's next intent.
      const { text: summary, totalUsage: createSummaryUsage } = await createSummary({
        messages,
        model: getModel(MINOR_MODEL),
        abortSignal: request.signal,
      });
      minorModelData.summary = summary;
      addUsage(minorModelCumulativeUsage, createSummaryUsage);
      writer.write({
        type: 'data-summary',
        data: {
          summary,
          chatId: chat.id,
        },
      });
      writeProgress('summary', 'complete', '分析完成');

      // Look up the page that belongs to the last history message.
      const lastMessage = previousMessages[previousMessages.length - 1];
      const pageData = await getPageByMessageId(lastMessage.id);
      if (pageData) {
        const pages = pageData.pages as unknown as Page[];

        // Based on the summary, the messages and all page data, let the AI select the pages
        // and sections that should be modified.
        writeProgress('context', 'in-progress', '正在对页面进行分析...');
        const { context, totalUsage: selectContextUsage } = await selectContext({
          messages,
          summary,
          pages,
          model: getModel(MINOR_MODEL),
          abortSignal: request.signal,
        });
        minorModelData.context = context;
        addUsage(minorModelCumulativeUsage, selectContextUsage);

        // NOTE(review): the original comment says only the pages named in `context` should be
        // summarized (falling back to all pages), but the original expression passed every page
        // in both branches. That behavior is kept here; filtering needs the Page field that
        // matches the context keys, which is not visible in this file.
        const { text: pageSummary, totalUsage: structuredPageSnapshotUsage } = await structuredPageSnapshot({
          pages,
          model: getModel(MINOR_MODEL),
          abortSignal: request.signal,
        });
        minorModelData.pageSummary = pageSummary;
        addUsage(minorModelCumulativeUsage, structuredPageSnapshotUsage);
        writeProgress('context', 'complete', '页面分析完成');
      }
    }

    writeProgress('response', 'in-progress', '正在生成响应');

    // Streams one response segment from the main model and merges it into the UI stream.
    // `isContinue` is true for continuation segments, for which reasoning is not re-sent.
    const executeStreamText = async (messages: UPageUIMessage[], isContinue: boolean = false) => {
      // Appends the truncated assistant text plus a "continue" prompt, then starts the next segment.
      const continueMessage = async (text: string) => {
        logger.info(
          `达到最大 token 限制 (${DEFAULT_MODEL_DETAILS?.maxTokenAllowed || MAX_TOKENS}): 继续消息, 还可以响应 (${MAX_RESPONSE_SEGMENTS - streamSwitches} 个分段)`,
        );
        messages.push({
          id: generateId(),
          role: 'assistant',
          parts: [
            {
              type: 'text',
              text,
            },
          ],
        });
        messages.push({
          id: generateId(),
          role: 'user',
          parts: [
            {
              type: 'text',
              text: CONTINUE_PROMPT,
            },
          ],
        });

        // Count the segment before starting it so the limit check in the next onFinish
        // can never observe a stale counter.
        streamSwitches++;
        await executeStreamText(messages, true);
      };

      const result = await chatStreamText({
        messages,
        elementInfo,
        summary: minorModelData.summary,
        pageSummary: minorModelData.pageSummary,
        context: minorModelData.context,
        maxTokens: DEFAULT_MODEL_DETAILS?.maxTokenAllowed,
        model: getModel(DEFAULT_MODEL),
        abortSignal: request.signal,
        onFinish: async ({ totalUsage, finishReason, text }) => {
          addUsage(cumulativeUsage, totalUsage);

          if (finishReason === 'length') {
            if (streamSwitches >= MAX_RESPONSE_SEGMENTS) {
              writeProgress('response', 'stopped', '无法继续生成消息:已达到最大分段数');
              writer.write({
                type: 'finish',
              });
              return;
            }
            await continueMessage(text);
          }

          // NOTE(review): finish reasons other than 'length' and 'stop' emit no 'finish' chunk here.
          if (finishReason === 'stop') {
            writeProgress('response', 'complete', '响应生成完成');
            writer.write({
              type: 'finish',
            });
          }
        },
        onAbort: async ({ totalUsage }) => {
          addUsage(cumulativeUsage, totalUsage);
        },
      });

      writer.merge(
        result.toUIMessageStream({
          sendReasoning: !isContinue,
          sendFinish: false,
          sendStart: false,
        }),
      );
    };

    await executeStreamText([message], false);
  };

  const stream = createUIMessageStream<UPageUIMessage>({
    execute: streamExecutor,
    originalMessages: messages,
    onFinish: async ({ messages, isAborted }) => {
      if (isAborted) {
        // The AI SDK offers no way to obtain token usage for an aborted stream, so it is
        // approximated manually from the content produced so far.
        // https://github.com/vercel/ai/pull/8701
        // Search from the end: the most recent assistant message is the one being generated.
        const lastAssistantMessage = [...messages].reverse().find((item) => item.role === 'assistant');
        if (lastAssistantMessage) {
          const approximateTokens = approximateUsageFromContent(lastAssistantMessage.parts);
          cumulativeUsage.outputTokens += approximateTokens;
          cumulativeUsage.totalTokens += approximateTokens;
        }
      }

      // Pick the final status: aborted, failed (reported earlier through onError) or success.
      let status = ChatUsageStatus.SUCCESS;
      if (isAborted) {
        status = ChatUsageStatus.ABORTED;
      } else if (hasFailed) {
        status = ChatUsageStatus.FAILED;
      }

      // Both helpers catch their own errors, so awaiting them cannot reject.
      await Promise.all([calculateTokenUsage(status), calculateMinorModelTokenUsage(status)]);

      if (isAborted) {
        logger.info(`用户 ${userId} 的聊天: ${chatId} 中止处理完成`);
        return;
      }

      // Persist the messages.
      if (rewindTo) {
        await updateDiscardedMessage(chatId, rewindTo);
      }
      await saveChatMessages(chatId, messages);
    },
    onError: (error) => {
      hasFailed = true;
      logger.error(`用户 ${userId} 的聊天: ${chatId} 处理过程中发生错误 ===> `, error);
      // Fire-and-forget: this callback must synchronously return the client-facing message.
      void calculateTokenUsage(ChatUsageStatus.FAILED);
      void calculateMinorModelTokenUsage(ChatUsageStatus.FAILED);
      return '内部服务器错误,请稍后重试';
    },
  });

  return createUIMessageStreamResponse({ stream, consumeSseStream: consumeStream });
}
|