Files
ainovel-clients/tools/ask_user.go
2026-03-13 01:15:00 +08:00

155 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"unicode/utf8"
"github.com/voocel/agentcore/schema"
)
// AskUserResponse 用户回答结果。
type AskUserResponse struct {
Answers map[string]string // question text → 用户选择的答案
Notes map[string]string // question text → 自定义输入(选"其他"时)
}
// AskUserHandler 阻塞等待用户回答,由 CLI 或 TUI 注入具体实现。
type AskUserHandler func(ctx context.Context, questions []Question) (*AskUserResponse, error)
// Question 单个问题。
type Question struct {
Question string `json:"question"`
Header string `json:"header"`
Options []Option `json:"options"`
MultiSelect bool `json:"multiSelect"`
}
// Option 可选项。
type Option struct {
Label string `json:"label"`
Description string `json:"description"`
}
// AskUserTool 让 LLM 向用户提出结构化问题。
type AskUserTool struct {
mu sync.RWMutex
handler AskUserHandler
}
func NewAskUserTool() *AskUserTool {
return &AskUserTool{}
}
// SetHandler 注入 UI 回调CLI 和 TUI 各自实现。
func (t *AskUserTool) SetHandler(h AskUserHandler) {
t.mu.Lock()
t.handler = h
t.mu.Unlock()
}
func (t *AskUserTool) Name() string { return "ask_user" }
func (t *AskUserTool) Label() string { return "询问用户" }
func (t *AskUserTool) Description() string {
return "当需求信息不足、且缺失信息会明显影响规划方向时,向用户提出 1-4 个结构化问题。每个问题必须包含 header、question 和 2-4 个选项;用户可选预设项,也可自由补充。返回结果是可直接阅读的中文摘要,格式类似:用户回答:[篇幅] 长篇;[重心] 剧情升级(补充:不要后宫)。只有在无法稳定判断篇幅、主线重心、关键约束或明确偏好时才使用;不要把能自行合理推断的问题都抛给用户。"
}
func (t *AskUserTool) Schema() map[string]any {
option := schema.Object(
schema.Property("label", schema.String("选项显示文本1-5个词")).Required(),
schema.Property("description", schema.String("选项含义说明")).Required(),
)
question := schema.Object(
schema.Property("question", schema.String("完整的问题文本")).Required(),
schema.Property("header", schema.String("短标签最多12字符")).Required(),
schema.Property("options", schema.Array("2-4个可选项", option)).Required(),
schema.Property("multiSelect", schema.Bool("是否允许多选")),
)
return schema.Object(
schema.Property("questions", schema.Array("1-4个问题", question)).Required(),
)
}
type askUserArgs struct {
Questions []Question `json:"questions"`
}
func (t *AskUserTool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) {
var a askUserArgs
if err := json.Unmarshal(args, &a); err != nil {
return nil, fmt.Errorf("invalid args: %w", err)
}
if err := validateQuestions(a.Questions); err != nil {
return json.Marshal(fmt.Sprintf("参数校验失败: %s", err))
}
t.mu.RLock()
h := t.handler
t.mu.RUnlock()
if h == nil {
return json.Marshal("当前环境不支持交互式询问,请根据你的判断自行决策并继续。")
}
resp, err := h(ctx, a.Questions)
if err != nil {
return json.Marshal(fmt.Sprintf("用户交互失败: %s。请根据你的判断自行决策并继续。", err))
}
return json.Marshal(formatAnswers(a.Questions, resp))
}
func validateQuestions(questions []Question) error {
if len(questions) == 0 {
return fmt.Errorf("至少需要一个问题")
}
if len(questions) > 4 {
return fmt.Errorf("最多4个问题当前 %d 个", len(questions))
}
for i, q := range questions {
if q.Question == "" {
return fmt.Errorf("问题 %d: 问题文本不能为空", i+1)
}
if q.Header == "" {
return fmt.Errorf("问题 %d: header 不能为空", i+1)
}
if utf8.RuneCountInString(q.Header) > 12 {
return fmt.Errorf("问题 %d: header %q 超过12字符", i+1, q.Header)
}
if len(q.Options) < 2 || len(q.Options) > 4 {
return fmt.Errorf("问题 %d: 需要2-4个选项当前 %d 个", i+1, len(q.Options))
}
for j, opt := range q.Options {
if opt.Label == "" {
return fmt.Errorf("问题 %d 选项 %d: label 不能为空", i+1, j+1)
}
if opt.Description == "" {
return fmt.Errorf("问题 %d 选项 %d: description 不能为空", i+1, j+1)
}
}
}
return nil
}
func formatAnswers(questions []Question, resp *AskUserResponse) string {
if resp == nil || len(resp.Answers) == 0 {
return "用户未提供回答,请根据你的判断自行决策并继续。"
}
var parts []string
for _, q := range questions {
answer, ok := resp.Answers[q.Question]
if !ok {
continue
}
entry := fmt.Sprintf("[%s] %s", q.Header, answer)
if note, hasNote := resp.Notes[q.Question]; hasNote {
entry += "(补充:" + note + ""
}
parts = append(parts, entry)
}
return fmt.Sprintf("用户回答:%s", strings.Join(parts, ""))
}