perf: ask user

This commit is contained in:
voocel
2026-03-13 01:15:00 +08:00
parent 7488198461
commit 25e219e934
10 changed files with 677 additions and 28 deletions

View File

@@ -17,11 +17,13 @@ func Run(cfg app.Config, refs tools.References, prompts app.Prompts, styles map[
if err != nil {
return err
}
bridge := newAskUserBridge()
rt.AskUser().SetHandler(bridge.handler)
restoreLog := redirectLogger(rt.Dir())
defer restoreLog()
defer rt.Close()
m := NewModel(rt)
m := NewModel(rt, bridge)
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
_, err = p.Run()
return err

278
tui/ask_user.go Normal file
View File

@@ -0,0 +1,278 @@
package tui
import (
"context"
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
"github.com/voocel/ainovel-cli/tools"
)
type askUserRequest struct {
questions []tools.Question
resultCh chan askUserResult
}
type askUserResult struct {
resp *tools.AskUserResponse
err error
}
type askUserBridge struct {
requests chan askUserRequest
}
func newAskUserBridge() *askUserBridge {
return &askUserBridge{
requests: make(chan askUserRequest),
}
}
func (b *askUserBridge) handler(ctx context.Context, questions []tools.Question) (*tools.AskUserResponse, error) {
req := askUserRequest{
questions: questions,
resultCh: make(chan askUserResult, 1),
}
select {
case b.requests <- req:
case <-ctx.Done():
return nil, ctx.Err()
}
select {
case result := <-req.resultCh:
return result.resp, result.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
type askUserState struct {
request askUserRequest
index int
cursor int
typing bool
input string
selected map[int]bool
answers map[string]string
notes map[string]string
}
func newAskUserState(req askUserRequest) *askUserState {
return &askUserState{
request: req,
selected: make(map[int]bool),
answers: make(map[string]string),
notes: make(map[string]string),
}
}
func (s *askUserState) currentQuestion() tools.Question {
return s.request.questions[s.index]
}
func (s *askUserState) optionCount() int {
return len(s.currentQuestion().Options) + 1
}
func (s *askUserState) choiceLabel(idx int) string {
q := s.currentQuestion()
if idx < len(q.Options) {
return q.Options[idx].Label
}
return "自由输入"
}
func (s *askUserState) choiceDescription(idx int) string {
q := s.currentQuestion()
if idx < len(q.Options) {
return q.Options[idx].Description
}
return "以上都不合适,自己补充"
}
func (s *askUserState) moveCursor(delta int) {
total := s.optionCount()
if total == 0 {
s.cursor = 0
return
}
s.cursor = (s.cursor + delta + total) % total
}
func (s *askUserState) toggleSelection() {
if s.selected[s.cursor] {
delete(s.selected, s.cursor)
return
}
s.selected[s.cursor] = true
}
func (s *askUserState) finishCurrentAnswer() bool {
q := s.currentQuestion()
if s.typing {
text := strings.TrimSpace(s.input)
if text == "" {
return false
}
s.answers[q.Question] = text
s.notes[q.Question] = text
return s.advance()
}
if q.MultiSelect {
var values []string
var custom string
for idx := 0; idx < s.optionCount(); idx++ {
if !s.selected[idx] {
continue
}
if idx < len(q.Options) {
values = append(values, q.Options[idx].Label)
continue
}
custom = strings.TrimSpace(s.input)
}
if custom != "" {
values = append(values, custom)
s.notes[q.Question] = custom
}
if len(values) == 0 {
return false
}
s.answers[q.Question] = strings.Join(values, "、")
return s.advance()
}
if s.cursor >= len(q.Options) {
s.typing = true
s.input = ""
return false
}
s.answers[q.Question] = q.Options[s.cursor].Label
return s.advance()
}
func (s *askUserState) advance() bool {
s.index++
if s.index >= len(s.request.questions) {
return true
}
s.cursor = 0
s.typing = false
s.input = ""
s.selected = make(map[int]bool)
return false
}
func (s *askUserState) submit() {
s.request.resultCh <- askUserResult{
resp: &tools.AskUserResponse{
Answers: s.answers,
Notes: s.notes,
},
}
}
func (s *askUserState) cancelCurrentTyping() {
s.typing = false
s.input = ""
}
func renderAskUserModal(width, height int, state *askUserState) string {
if state == nil {
return ""
}
q := state.currentQuestion()
boxW := minInt(maxInt(width*60/100, 52), width-4)
boxH := minInt(maxInt(height*60/100, 16), height-4)
if boxW < 40 {
boxW = maxInt(width-2, 20)
}
if boxH < 10 {
boxH = maxInt(height-2, 8)
}
var b strings.Builder
title := fmt.Sprintf("需要补充信息 %d/%d", state.index+1, len(state.request.questions))
b.WriteString(lipgloss.NewStyle().Foreground(colorAccent).Bold(true).Render(title))
b.WriteString("\n\n")
if q.Header != "" {
b.WriteString(highlightValueStyle.Render(q.Header))
b.WriteString("\n")
}
b.WriteString(cardContentStyle.Render(q.Question))
b.WriteString("\n\n")
for idx := 0; idx < state.optionCount(); idx++ {
prefix := " "
if state.cursor == idx {
prefix = lipgloss.NewStyle().Foreground(colorAccent).Bold(true).Render(" ")
}
label := state.choiceLabel(idx)
if q.MultiSelect {
marker := "[ ]"
if state.selected[idx] {
marker = "[x]"
}
label = marker + " " + label
}
b.WriteString(prefix + cardContentStyle.Render(label))
b.WriteString("\n")
b.WriteString(" " + lipgloss.NewStyle().Foreground(colorDim).Render(state.choiceDescription(idx)))
b.WriteString("\n")
}
if state.typing || (q.MultiSelect && state.selected[len(q.Options)]) {
b.WriteString("\n")
b.WriteString(panelTitleStyle.Render("补充内容"))
b.WriteString("\n")
content := state.input
if content == "" {
content = "请输入..."
}
style := lipgloss.NewStyle().
Width(boxW-8).
Border(baseBorder).
BorderForeground(colorDim).
Padding(0, 1)
b.WriteString(style.Render(content))
b.WriteString("\n")
}
hint := "↑↓ 选择 · Enter 确认"
if q.MultiSelect {
hint = "↑↓ 选择 · Space 勾选 · Enter 提交"
}
if state.typing {
hint = "输入补充内容 · Enter 确认 · Esc 返回选项"
}
b.WriteString("\n")
b.WriteString(lipgloss.NewStyle().Foreground(colorDim).Render(hint))
box := lipgloss.NewStyle().
Width(boxW).
Height(boxH).
Border(baseBorder).
BorderForeground(colorAccent).
Padding(1, 2).
Background(lipgloss.Color("#1b1712")).
Render(b.String())
return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, box)
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}

View File

@@ -12,6 +12,7 @@ type (
eventMsg app.UIEvent
snapshotMsg app.UISnapshot
doneMsg struct{}
askUserMsg askUserRequest
startResultMsg struct{ err error }
steerResultMsg struct{}
spinnerTickMsg time.Time
@@ -102,3 +103,13 @@ func listenStreamClear(rt *app.Runtime) tea.Cmd {
return streamClearMsg{}
}
}
func listenAskUser(bridge *askUserBridge) tea.Cmd {
return func() tea.Msg {
req, ok := <-bridge.requests
if !ok {
return nil
}
return askUserMsg(req)
}
}

View File

@@ -3,6 +3,7 @@ package tui
import (
"strings"
"time"
"unicode/utf8"
"github.com/charmbracelet/bubbles/textarea"
"github.com/charmbracelet/bubbles/viewport"
@@ -35,12 +36,15 @@ var spinnerFrames = []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
// Model 是 TUI 的顶层状态。
type Model struct {
runtime *app.Runtime
askBridge *askUserBridge
askState *askUserState
snapshot app.UISnapshot
events []app.UIEvent
viewport viewport.Model // 事件流 viewport
streamVP viewport.Model // 流式输出 viewport
detailVP viewport.Model // 右侧详情 viewport
streamBuf *strings.Builder // 流式文本累积缓冲
streamRounds []string
textarea textarea.Model
width int
height int
@@ -56,7 +60,7 @@ type Model struct {
}
// NewModel 创建 TUI Model。
func NewModel(rt *app.Runtime) Model {
func NewModel(rt *app.Runtime, bridge *askUserBridge) Model {
ta := textarea.New()
ta.Placeholder = "输入小说需求例如写一部12章都市悬疑小说"
ta.CharLimit = 500
@@ -79,6 +83,7 @@ func NewModel(rt *app.Runtime) Model {
return Model{
runtime: rt,
askBridge: bridge,
autoScroll: true,
streamScroll: true,
mode: modeNew,
@@ -94,6 +99,7 @@ func (m Model) Init() tea.Cmd {
return tea.Batch(
textarea.Blink,
listenEvents(m.runtime),
listenAskUser(m.askBridge),
listenDone(m.runtime),
listenStream(m.runtime),
listenStreamClear(m.runtime),
@@ -116,6 +122,9 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
case tea.KeyMsg:
if m.askState != nil {
return m.handleAskUserKey(msg)
}
switch msg.Type {
case tea.KeyCtrlC:
return m, tea.Quit
@@ -127,6 +136,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.viewport.SetContent("")
m.viewport.GotoTop()
m.streamBuf.Reset()
m.streamRounds = nil
m.streamVP.SetContent("")
m.streamVP.GotoTop()
m.streamRound = 0
@@ -233,6 +243,15 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.refreshEventViewport()
return m, listenEvents(m.runtime)
case askUserMsg:
m.askState = newAskUserState(askUserRequest(msg))
m.textarea.Blur()
m.events = append(m.events, app.UIEvent{
Time: time.Now(), Category: "SYSTEM", Summary: "等待用户补充关键信息", Level: "info",
})
m.refreshEventViewport()
return m, listenAskUser(m.askBridge)
case snapshotMsg:
m.snapshot = app.UISnapshot(msg)
m.refreshDetailViewport()
@@ -270,22 +289,25 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tickSpinner()
case streamDeltaMsg:
m.streamBuf.WriteString(string(msg))
m.streamVP.SetContent(m.streamBuf.String())
if len(m.streamRounds) == 0 {
m.streamRounds = append(m.streamRounds, "")
}
m.streamRounds[len(m.streamRounds)-1] += string(msg)
m.streamVP.SetContent(renderStreamContent(m.streamRounds, m.streamVP.Width))
if m.streamScroll {
m.streamVP.GotoBottom()
}
return m, listenStream(m.runtime)
case streamClearMsg:
// 新一轮输出:保留历史内容,用分隔线标记新段落
m.streamRound++
if m.streamBuf.Len() > 0 {
m.streamBuf.WriteString("\n")
m.streamBuf.WriteString(renderStreamSeparator(m.streamRound, m.streamVP.Width))
m.streamBuf.WriteString("\n")
// 新一轮输出:按轮次分块显示,避免长文本和分隔线直接拼接导致错乱。
if len(m.streamRounds) == 0 {
m.streamRounds = append(m.streamRounds, "")
} else if strings.TrimSpace(m.streamRounds[len(m.streamRounds)-1]) != "" {
m.streamRounds = append(m.streamRounds, "")
}
m.streamVP.SetContent(m.streamBuf.String())
m.streamRound = len(m.streamRounds)
m.streamVP.SetContent(renderStreamContent(m.streamRounds, m.streamVP.Width))
if m.streamScroll {
m.streamVP.GotoBottom()
}
@@ -486,5 +508,80 @@ func (m Model) View() string {
body = lipgloss.JoinHorizontal(lipgloss.Top, left, center, right)
}
return lipgloss.JoinVertical(lipgloss.Left, topBar, body, inputBox)
view := lipgloss.JoinVertical(lipgloss.Left, topBar, body, inputBox)
if m.askState != nil {
return renderAskUserModal(m.width, m.height, m.askState)
}
return view
}
func (m Model) handleAskUserKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
if m.askState == nil {
return m, nil
}
state := m.askState
q := state.currentQuestion()
if state.typing {
switch msg.Type {
case tea.KeyEsc:
state.cancelCurrentTyping()
return m, nil
case tea.KeyEnter:
if state.finishCurrentAnswer() {
state.submit()
m.askState = nil
if m.mode != modeDone {
m.textarea.Focus()
}
}
return m, nil
case tea.KeyBackspace, tea.KeyCtrlH:
if state.input != "" {
_, size := utf8.DecodeLastRuneInString(state.input)
state.input = state.input[:len(state.input)-size]
}
return m, nil
default:
if msg.Type == tea.KeyRunes {
state.input += string(msg.Runes)
}
return m, nil
}
}
switch msg.Type {
case tea.KeyUp:
state.moveCursor(-1)
case tea.KeyDown:
state.moveCursor(1)
case tea.KeySpace:
if q.MultiSelect {
state.toggleSelection()
if state.cursor == len(q.Options) && !state.selected[state.cursor] {
state.input = ""
}
}
case tea.KeyEnter:
if q.MultiSelect {
if state.cursor == len(q.Options) {
state.toggleSelection()
if state.selected[state.cursor] {
state.typing = true
}
return m, nil
}
if len(state.selected) == 0 {
state.toggleSelection()
}
}
if state.finishCurrentAnswer() {
state.submit()
m.askState = nil
if m.mode != modeDone {
m.textarea.Focus()
}
}
}
return m, nil
}

View File

@@ -1,6 +1,7 @@
package tui
import (
"encoding/json"
"fmt"
"strings"
@@ -228,17 +229,183 @@ func renderStreamPanel(vp viewport.Model, width, height int, focused bool) strin
return header + "\n" + vpStyle.Render(vp.View())
}
// renderStreamSeparator 渲染流式面板中的轮次分隔线
func renderStreamSeparator(round, width int) string {
label := fmt.Sprintf(" #%d ", round)
lineW := (width - lipgloss.Width(label)) / 2
if lineW < 1 {
lineW = 1
// renderStreamContent 将流式输出按轮次渲染为分块内容,避免长段直接拼接导致错乱
func renderStreamContent(rounds []string, width int) string {
if width < 24 {
width = 24
}
line := strings.Repeat("─", lineW)
dimLine := lipgloss.NewStyle().Foreground(colorDim).Render(line)
dimLabel := lipgloss.NewStyle().Foreground(colorDim).Render(label)
return dimLine + dimLabel + dimLine
var blocks []string
displayIndex := 0
for i, round := range rounds {
text := strings.TrimSpace(round)
if text == "" {
continue
}
displayIndex++
blocks = append(blocks, renderStreamBlock(displayIndex, text, width, i == len(rounds)-1))
}
return strings.Join(blocks, "\n\n")
}
func renderStreamBlock(index int, text string, width int, active bool) string {
headerStyle := lipgloss.NewStyle().Foreground(colorDim)
bodyStyle := lipgloss.NewStyle().Foreground(colorText)
dividerColor := colorDim
if active {
headerStyle = lipgloss.NewStyle().Foreground(colorAccent).Bold(true)
dividerColor = colorAccent
}
header := headerStyle.Render(fmt.Sprintf("◆ 第 %d 段", index))
divider := lipgloss.NewStyle().Foreground(dividerColor).Render(strings.Repeat("─", max(8, width)))
lines := wrapStreamText(text, max(16, width-4))
var b strings.Builder
b.WriteString(header)
b.WriteString("\n")
b.WriteString(divider)
b.WriteString("\n")
for i, line := range lines {
if i > 0 {
b.WriteString("\n")
}
b.WriteString(bodyStyle.Render(line))
}
return b.String()
}
func wrapStreamText(text string, width int) []string {
if width < 8 {
return []string{text}
}
var out []string
for _, raw := range strings.Split(strings.ReplaceAll(text, "\r\n", "\n"), "\n") {
if strings.TrimSpace(raw) == "" {
out = append(out, "")
continue
}
if compact, ok := compactJSONLine(raw, width); ok {
out = append(out, compact)
continue
}
prefix, rest, nextPrefix := parseWrapPrefix(raw)
wrapped := wrapRunes(rest, max(4, width-lipgloss.Width(prefix)))
for i, line := range wrapped {
if i == 0 {
out = append(out, prefix+line)
continue
}
out = append(out, nextPrefix+line)
}
}
return out
}
func compactJSONLine(line string, width int) (string, bool) {
trimmed := strings.TrimSpace(line)
if trimmed == "" {
return "", false
}
if !(strings.HasPrefix(trimmed, "{") || strings.HasPrefix(trimmed, "[")) {
return "", false
}
var value any
if err := json.Unmarshal([]byte(trimmed), &value); err != nil {
return "", false
}
compact, err := json.Marshal(value)
if err != nil {
return "", false
}
text := string(compact)
limit := max(24, width-2)
if lipgloss.Width(text) > limit {
text = truncate(text, limit-1)
}
return lipgloss.NewStyle().Foreground(colorDim).Render("JSON: ") +
lipgloss.NewStyle().Foreground(lipgloss.Color("#8fb7c9")).Render(text), true
}
func parseWrapPrefix(line string) (prefix, content, nextPrefix string) {
indent := line[:len(line)-len(strings.TrimLeft(line, " \t"))]
trimmed := strings.TrimSpace(line)
switch {
case strings.HasPrefix(trimmed, "- "), strings.HasPrefix(trimmed, "* "), strings.HasPrefix(trimmed, "• "):
prefix = indent + trimmed[:2]
content = strings.TrimSpace(trimmed[2:])
nextPrefix = indent + " "
return prefix, content, nextPrefix
case orderedListPrefix(trimmed) != "":
marker := orderedListPrefix(trimmed)
prefix = indent + marker
content = strings.TrimSpace(strings.TrimPrefix(trimmed, marker))
nextPrefix = indent + strings.Repeat(" ", lipgloss.Width(marker))
return prefix, content, nextPrefix
case strings.HasPrefix(trimmed, "```"):
return indent, trimmed, indent
default:
return indent, trimmed, indent
}
}
func orderedListPrefix(line string) string {
end := strings.Index(line, ". ")
if end <= 0 {
return ""
}
for _, r := range line[:end] {
if r < '0' || r > '9' {
return ""
}
}
return line[:end+2]
}
func wrapRunes(text string, width int) []string {
if text == "" {
return []string{""}
}
if width < 2 {
return []string{text}
}
var lines []string
var current strings.Builder
currentWidth := 0
for _, r := range text {
rw := lipgloss.Width(string(r))
if currentWidth > 0 && currentWidth+rw > width {
lines = append(lines, strings.TrimRight(current.String(), " "))
current.Reset()
currentWidth = 0
if r == ' ' {
continue
}
}
current.WriteRune(r)
currentWidth += rw
}
if current.Len() > 0 {
lines = append(lines, strings.TrimRight(current.String(), " "))
}
if len(lines) == 0 {
return []string{""}
}
return lines
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
// renderDetailContent 构建右侧详情面板内容。