Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
github.com/stretchr/testify v1.11.1
github.com/urfave/cli/v2 v2.3.0
golang.org/x/crypto v0.46.0
golang.org/x/sys v0.42.0
golang.org/x/term v0.38.0
google.golang.org/grpc v1.74.2
google.golang.org/protobuf v1.36.11
Expand Down Expand Up @@ -76,7 +77,6 @@ require (
golang.org/x/net v0.47.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/time v0.14.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a // indirect
Expand Down
45 changes: 0 additions & 45 deletions internal/confirm.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,51 +39,6 @@ func FConfirmPrompt(msg string, stdin io.Reader, stdout io.Writer) bool {
}
}

// FConfirmPromptFromCh is the MCP-safe variant of FConfirmPrompt.
// It reads runes from a channel instead of stdin directly. Use this from
// inside interactiveRun when h.stdinCh is non-nil — the single stdin reader
// goroutine owns os.Stdin in that mode, so any direct read would race with
// it and lose bytes.
func FConfirmPromptFromCh(msg string, ch <-chan rune, stdout io.Writer) bool {
defer func() {
_, _ = fmt.Fprintln(stdout)
}()

// Drop any runes that may have been buffered from a previous prompt.
for {
select {
case <-ch:
default:
goto drained
}
}
drained:

for {
_, _ = fmt.Fprintf(stdout, "%s [y/n]: ", msg)

var sb strings.Builder
for {
r, ok := <-ch
if !ok {
return false
}
if r == '\n' || r == '\r' {
break
}
sb.WriteRune(r)
}

input := strings.ToLower(strings.TrimSpace(sb.String()))
if input == "y" || input == "yes" {
return true
}
if input == "n" || input == "no" {
return false
}
}
}

type Action struct {
Shortcut rune
ShortcutAliases []rune
Expand Down
15 changes: 15 additions & 0 deletions internal/terminal_bsd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//go:build darwin || freebsd || openbsd || netbsd

package internal

import "golang.org/x/sys/unix"

// FlushTerminalInput discards any bytes the kernel has buffered in the
// terminal's input queue. Called after term.MakeRaw at prompt entry so that
// keystrokes the user typed in cooked mode while the previous command was
// running don't carry over and trigger spurious actions.
func FlushTerminalInput(fd int) error {
// On BSD/Darwin, TIOCFLUSH takes a pointer to an int with bitmask
// FREAD (0x1) / FWRITE (0x2). TCIFLUSH happens to equal FREAD (0x1).
return unix.IoctlSetPointerInt(fd, unix.TIOCFLUSH, unix.TCIFLUSH)
}
13 changes: 13 additions & 0 deletions internal/terminal_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build linux

package internal

import "golang.org/x/sys/unix"

// FlushTerminalInput discards any bytes the kernel has buffered in the
// terminal's input queue. Called after term.MakeRaw at prompt entry so that
// keystrokes the user typed in cooked mode while the previous command was
// running don't carry over and trigger spurious actions.
func FlushTerminalInput(fd int) error {
return unix.IoctlSetInt(fd, unix.TCFLSH, unix.TCIFLUSH)
}
5 changes: 5 additions & 0 deletions internal/terminal_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//go:build windows

package internal

func FlushTerminalInput(fd int) error { return nil }
14 changes: 0 additions & 14 deletions trainings/files/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,11 @@ type Files struct {
stdin io.Reader
stdout io.Writer

// stdinCh, when set, takes precedence over stdin for confirm prompts.
// Callers running inside interactiveRun should pass h.stdinCh here so
// confirm prompts don't race with the MCP stdin reader goroutine for
// bytes on os.Stdin. See trainings/run.go interactiveRun for context.
stdinCh <-chan rune

deleteUnusedFiles bool
showFullDiff bool
forceOverwrite bool
}

// WithStdinCh returns a copy of f that reads confirm-prompt answers from ch
// instead of f.stdin. Use this when the caller is inside interactiveRun and
// MCP is active (h.stdinCh != nil).
func (f Files) WithStdinCh(ch <-chan rune) Files {
f.stdinCh = ch
return f
}

func NewFiles() Files {
return NewFilesWithStdOuts(os.Stdin, os.Stdout)
}
Expand Down
6 changes: 0 additions & 6 deletions trainings/files/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,7 @@ func (f Files) shouldWriteFile(fs afero.Fs, filePath string, file *genproto.File
}
}

// confirmPrompt routes to the MCP-safe channel reader when stdinCh is set,
// otherwise falls back to the plain stdin reader. This avoids races with
// the interactiveRun stdin goroutine that owns os.Stdin while MCP is active.
func (f Files) confirmPrompt(msg string) bool {
if f.stdinCh != nil {
return internal.FConfirmPromptFromCh(msg, f.stdinCh, f.stdout)
}
return internal.FConfirmPrompt(msg, f.stdin, f.stdout)
}

Expand Down
16 changes: 15 additions & 1 deletion trainings/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/spf13/afero"
"golang.org/x/term"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand All @@ -38,10 +39,14 @@ type Handlers struct {

loopState *mcppkg.LoopState // nil if MCP disabled
mcpPort int // 0 = MCP disabled
stdinCh <-chan rune // centralized stdin channel; non-nil only during interactiveRun with MCP

pendingMCPResultCh chan<- mcppkg.MCPResult // deferred result for blocking MCP commands (e.g. next exercise)

// sessionTermState is the cooked-mode terminal state captured during a
// prompt's call to enterPromptMode. Held only while a prompt is active so
// os.Exit paths can restore cooked mode before the process dies.
sessionTermState *term.State

// Fallback update-state for when MCP is disabled (loopState == nil).
// When loopState is non-nil all update accessors route through it so
// MCP tool handlers see the same values as the terminal prompt.
Expand Down Expand Up @@ -167,6 +172,15 @@ func (h *Handlers) newGitOps() *git.Ops {
return git.NewOps(trainingRoot, disabled)
}

// restoreTerminal restores the terminal to cooked mode if a prompt currently
// holds it in raw mode. Safe to call from os.Exit paths (no-op when no prompt
// is active).
func (h *Handlers) restoreTerminal() {
if h.sessionTermState != nil {
_ = term.Restore(0, h.sessionTermState)
}
}

func newTrainingRootFs(trainingRoot string) *afero.BasePathFs {
// Privacy of your files is our priority.
//
Expand Down
51 changes: 13 additions & 38 deletions trainings/next.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package trainings

import (
"bufio"
"context"
"fmt"
"os"
Expand All @@ -12,7 +11,6 @@ import (
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/spf13/afero"
"golang.org/x/term"

"github.com/ThreeDotsLabs/cli/internal"
"github.com/ThreeDotsLabs/cli/trainings/config"
Expand All @@ -22,46 +20,26 @@ import (
mcppkg "github.com/ThreeDotsLabs/cli/trainings/mcp"
)

// promptRune displays a prompt for the given actions and reads a single valid keypress.
// Uses h.stdinCh when MCP is active, otherwise reads os.Stdin directly.
// This is the MCP-safe replacement for internal.Prompt when called from within interactiveRun.
// promptRune displays a prompt for the given actions and reads a single valid
// keypress from stdin. Sets raw mode for the prompt only and spawns a scoped
// reader goroutine. Used both with and without MCP — does not select on MCP
// commands (callers that need MCP commands use waitForAction instead).
func (h *Handlers) promptRune(actions internal.Actions) rune {
defer fmt.Println()

printPrompt(actions)

termState, rawErr := term.MakeRaw(0)
if rawErr == nil {
defer term.Restore(0, termState)
}
defer h.enterPromptMode()()

if h.stdinCh == nil {
reader := bufio.NewReader(os.Stdin)
for {
ch, _, err := reader.ReadRune()
if err != nil {
return 'q'
}
if string(ch) == "\x03" {
if rawErr == nil {
term.Restore(0, termState)
}
os.Exit(0)
}
if key, ok := actions.ReadKeyFromInput(ch); ok {
return key
}
}
}

drainChannel(h.stdinCh)
done := make(chan struct{})
defer close(done)
runeCh := h.startScopedStdinReader(done)

for {
ch := <-h.stdinCh
if string(ch) == "\x03" {
if rawErr == nil {
term.Restore(0, termState)
}
ch, ok := <-runeCh
if !ok {
h.restoreTerminal()
logrus.Debug("stdin closed, exiting")
fmt.Println(color.HiBlackString("Input closed — exiting."))
os.Exit(0)
}
if key, ok := actions.ReadKeyFromInput(ch); ok {
Expand Down Expand Up @@ -168,9 +146,6 @@ func (h *Handlers) setExercise(ctx context.Context, fs *afero.BasePathFs, exerci
// Existing behavior (no git or text-only)
isEasy := exercise.TrainingDifficulty == genproto.TrainingDifficulty_EASY
f := files.NewFilesWithConfig(isEasy, isEasy)
if h.stdinCh != nil {
f = f.WithStdinCh(h.stdinCh)
}
if err := h.writeExerciseFiles(f, nextExerciseResponseToExerciseSolution(exercise), fs); err != nil {
return false, err
}
Expand Down
61 changes: 16 additions & 45 deletions trainings/run.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package trainings

import (
"bufio"
"context"
"encoding/json"
"fmt"
Expand All @@ -16,7 +15,6 @@ import (
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/spf13/afero"
"golang.org/x/term"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -147,24 +145,6 @@ func (h *Handlers) interactiveRun(ctx context.Context, trainingRootFs *afero.Bas
}
}

// Single stdin reader goroutine for the entire interactive session.
// Only needed when MCP is active (to select between stdin and MCP commands).
if h.loopState != nil {
ch := make(chan rune, 1)
go func() {
reader := bufio.NewReader(os.Stdin)
for {
r, _, err := reader.ReadRune()
if err != nil {
return
}
ch <- r
}
}()
h.stdinCh = ch
defer func() { h.stdinCh = nil }()
}

// Background poller: long `tr run` sessions (days/weeks) can outlive the
// one-shot update check that runs at CLI startup. This goroutine keeps
// checking periodically so a newer release is still surfaced mid-session.
Expand Down Expand Up @@ -345,6 +325,7 @@ func (h *Handlers) interactiveRun(ctx context.Context, trainingRootFs *afero.Bas
ctx = withMCPTriggered(ctx, fromMCP)

if chosenAction == loopActionQuit {
h.restoreTerminal()
os.Exit(0)
}
if chosenAction == loopActionUpdate {
Expand Down Expand Up @@ -743,12 +724,15 @@ func (h *Handlers) handleUpdateAction(ctx context.Context) {
}
fmt.Println()
fmt.Println("Please re-run your command.")
h.restoreTerminal()
os.Exit(0)
}

// waitForAction prints a prompt and waits for input from either stdin or the MCP command channel.
// When MCP is disabled (loopState == nil), it reads stdin synchronously (like internal.Prompt).
// When MCP is enabled, stdinCh must be a long-lived channel fed by a single goroutine in interactiveRun.
// waitForAction prints a prompt and waits for input from either stdin or the
// MCP command channel. When MCP is disabled (loopState == nil), it reads stdin
// synchronously via promptRune. When MCP is enabled, it sets raw mode for the
// prompt only and spawns a scoped reader goroutine — there is no long-lived
// stdin reader between prompts.
func (h *Handlers) waitForAction(
actions internal.Actions,
actionMap map[rune]loopAction,
Expand All @@ -763,24 +747,22 @@ func (h *Handlers) waitForAction(
return loopActionQuit, false
}

// MCP mode — need to select on both stdinCh and MCP commands.
defer fmt.Println()
printPrompt(actions)

termState, rawErr := term.MakeRaw(0)
if rawErr == nil {
defer term.Restore(0, termState)
}
defer h.enterPromptMode()()

drainChannel(h.stdinCh)
done := make(chan struct{})
defer close(done)
runeCh := h.startScopedStdinReader(done)

for {
select {
case ch := <-h.stdinCh:
if string(ch) == "\x03" {
if rawErr == nil {
term.Restore(0, termState)
}
case ch, ok := <-runeCh:
if !ok {
h.restoreTerminal()
logrus.Debug("stdin closed, exiting")
fmt.Println(color.HiBlackString("Input closed — exiting."))
os.Exit(0)
}
if key, ok := actions.ReadKeyFromInput(ch); ok {
Expand Down Expand Up @@ -818,17 +800,6 @@ func printPrompt(actions internal.Actions) {
fmt.Printf("%s", "Press "+formatActionsMessage(actionsStr)+" ")
}

// drainChannel discards any buffered values from a channel.
func drainChannel(ch <-chan rune) {
for {
select {
case <-ch:
default:
return
}
}
}

func (h *Handlers) sendPendingMCPResult(result mcppkg.MCPResult) {
if h.pendingMCPResultCh != nil {
h.pendingMCPResultCh <- result
Expand Down
Loading