feat: add ask user question

This commit is contained in:
voocel
2026-03-09 19:52:43 +08:00
parent 75bdda1fe3
commit ef55c89e9d
13 changed files with 868 additions and 83 deletions

154
tools/ask_user.go Normal file
View File

@@ -0,0 +1,154 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"unicode/utf8"
"github.com/voocel/agentcore/schema"
)
// AskUserResponse 用户回答结果。
type AskUserResponse struct {
Answers map[string]string // question text → 用户选择的答案
Notes map[string]string // question text → 自定义输入(选"其他"时)
}
// AskUserHandler 阻塞等待用户回答,由 CLI 或 TUI 注入具体实现。
type AskUserHandler func(ctx context.Context, questions []Question) (*AskUserResponse, error)
// Question 单个问题。
type Question struct {
Question string `json:"question"`
Header string `json:"header"`
Options []Option `json:"options"`
MultiSelect bool `json:"multiSelect"`
}
// Option 可选项。
type Option struct {
Label string `json:"label"`
Description string `json:"description"`
}
// AskUserTool 让 LLM 向用户提出结构化问题。
type AskUserTool struct {
mu sync.RWMutex
handler AskUserHandler
}
func NewAskUserTool() *AskUserTool {
return &AskUserTool{}
}
// SetHandler 注入 UI 回调CLI 和 TUI 各自实现。
func (t *AskUserTool) SetHandler(h AskUserHandler) {
t.mu.Lock()
t.handler = h
t.mu.Unlock()
}
func (t *AskUserTool) Name() string { return "ask_user" }
func (t *AskUserTool) Label() string { return "询问用户" }
func (t *AskUserTool) Description() string {
return "向用户提出结构化问题,用于需要用户确认方向、澄清需求或做出选择时。用户可以从预设选项中选择,也可以自由输入。"
}
func (t *AskUserTool) Schema() map[string]any {
option := schema.Object(
schema.Property("label", schema.String("选项显示文本1-5个词")).Required(),
schema.Property("description", schema.String("选项含义说明")).Required(),
)
question := schema.Object(
schema.Property("question", schema.String("完整的问题文本")).Required(),
schema.Property("header", schema.String("短标签最多12字符")).Required(),
schema.Property("options", schema.Array("2-4个可选项", option)).Required(),
schema.Property("multiSelect", schema.Bool("是否允许多选")),
)
return schema.Object(
schema.Property("questions", schema.Array("1-4个问题", question)).Required(),
)
}
type askUserArgs struct {
Questions []Question `json:"questions"`
}
func (t *AskUserTool) Execute(ctx context.Context, args json.RawMessage) (json.RawMessage, error) {
var a askUserArgs
if err := json.Unmarshal(args, &a); err != nil {
return nil, fmt.Errorf("invalid args: %w", err)
}
if err := validateQuestions(a.Questions); err != nil {
return json.Marshal(fmt.Sprintf("参数校验失败: %s", err))
}
t.mu.RLock()
h := t.handler
t.mu.RUnlock()
if h == nil {
return json.Marshal("当前环境不支持交互式询问,请根据你的判断自行决策并继续。")
}
resp, err := h(ctx, a.Questions)
if err != nil {
return json.Marshal(fmt.Sprintf("用户交互失败: %s。请根据你的判断自行决策并继续。", err))
}
return json.Marshal(formatAnswers(a.Questions, resp))
}
func validateQuestions(questions []Question) error {
if len(questions) == 0 {
return fmt.Errorf("至少需要一个问题")
}
if len(questions) > 4 {
return fmt.Errorf("最多4个问题当前 %d 个", len(questions))
}
for i, q := range questions {
if q.Question == "" {
return fmt.Errorf("问题 %d: 问题文本不能为空", i+1)
}
if q.Header == "" {
return fmt.Errorf("问题 %d: header 不能为空", i+1)
}
if utf8.RuneCountInString(q.Header) > 12 {
return fmt.Errorf("问题 %d: header %q 超过12字符", i+1, q.Header)
}
if len(q.Options) < 2 || len(q.Options) > 4 {
return fmt.Errorf("问题 %d: 需要2-4个选项当前 %d 个", i+1, len(q.Options))
}
for j, opt := range q.Options {
if opt.Label == "" {
return fmt.Errorf("问题 %d 选项 %d: label 不能为空", i+1, j+1)
}
if opt.Description == "" {
return fmt.Errorf("问题 %d 选项 %d: description 不能为空", i+1, j+1)
}
}
}
return nil
}
func formatAnswers(questions []Question, resp *AskUserResponse) string {
if resp == nil || len(resp.Answers) == 0 {
return "用户未提供回答,请根据你的判断自行决策并继续。"
}
var parts []string
for _, q := range questions {
answer, ok := resp.Answers[q.Question]
if !ok {
continue
}
entry := fmt.Sprintf("[%s] %s", q.Header, answer)
if note, hasNote := resp.Notes[q.Question]; hasNote {
entry += "(补充:" + note + ""
}
parts = append(parts, entry)
}
return fmt.Sprintf("用户回答:%s", strings.Join(parts, ""))
}

View File

@@ -58,6 +58,8 @@ func (t *SaveFoundationTool) Execute(_ context.Context, args json.RawMessage) (j
return nil, fmt.Errorf("save outline: %w", err)
}
_ = t.store.UpdatePhase(domain.PhaseOutline)
// 根据大纲长度自动设定总章节数
_ = t.store.SetTotalChapters(len(entries))
return json.Marshal(map[string]any{"saved": true, "type": "outline", "chapters": len(entries)})
case "characters":