Skip to content

Commit 91966a2

Browse files
authored
Merge pull request #7 from coder/completion
Add auto-completion
2 parents 6e88789 + c365495 commit 91966a2

13 files changed

+855
-93
lines changed

command.go

+101-5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ type Command struct {
5959
Middleware MiddlewareFunc
6060
Handler HandlerFunc
6161
HelpHandler HandlerFunc
62+
// CompletionHandler is called when the command is run in completion
63+
// mode. If nil, only the default completion handler is used.
64+
//
65+
// Flag and option parsing is best-effort in this mode, so even if an Option
66+
// is "required" it may not be set.
67+
CompletionHandler CompletionHandlerFunc
6268
}
6369

6470
// AddSubcommands adds the given subcommands, setting their
@@ -193,15 +199,22 @@ type Invocation struct {
193199
ctx context.Context
194200
Command *Command
195201
parsedFlags *pflag.FlagSet
196-
Args []string
202+
203+
// Args is reduced into the remaining arguments after parsing flags
204+
// during Run.
205+
Args []string
206+
197207
// Environ is a list of environment variables. Use EnvsWithPrefix to parse
198208
// os.Environ.
199209
Environ Environ
200210
Stdout io.Writer
201211
Stderr io.Writer
202212
Stdin io.Reader
203-
Logger slog.Logger
204-
Net Net
213+
214+
// Deprecated
215+
Logger slog.Logger
216+
// Deprecated
217+
Net Net
205218

206219
// testing
207220
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
@@ -282,6 +295,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
282295
return fs2
283296
}
284297

298+
func (inv *Invocation) CurWords() (prev string, cur string) {
299+
if len(inv.Args) == 1 {
300+
cur = inv.Args[0]
301+
prev = ""
302+
} else {
303+
cur = inv.Args[len(inv.Args)-1]
304+
prev = inv.Args[len(inv.Args)-2]
305+
}
306+
return
307+
}
308+
285309
// run recursively executes the command and its children.
286310
// allArgs is wired through the stack so that global flags can be accepted
287311
// anywhere in the command invocation.
@@ -378,8 +402,19 @@ func (inv *Invocation) run(state *runState) error {
378402
}
379403
}
380404

405+
// Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already.
406+
// We only look at the current word to figure out handler to run, or what directory to inspect.
407+
if inv.IsCompletionMode() {
408+
for _, e := range inv.complete() {
409+
fmt.Fprintln(inv.Stdout, e)
410+
}
411+
return nil
412+
}
413+
414+
ignoreFlagParseErrors := inv.Command.RawArgs
415+
381416
// Flag parse errors are irrelevant for raw args commands.
382-
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
417+
if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
383418
return xerrors.Errorf(
384419
"parsing flags (%v) for %q: %w",
385420
state.allArgs,
@@ -401,7 +436,7 @@ func (inv *Invocation) run(state *runState) error {
401436
}
402437
}
403438
// Don't error for missing flags if `--help` was supplied.
404-
if len(missing) > 0 && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
439+
if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
405440
return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", "))
406441
}
407442

@@ -558,6 +593,65 @@ func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
558593
return &i2
559594
}
560595

596+
func (inv *Invocation) complete() []string {
597+
prev, cur := inv.CurWords()
598+
599+
// If the current word is a flag
600+
if strings.HasPrefix(cur, "--") {
601+
flagParts := strings.Split(cur, "=")
602+
flagName := flagParts[0][2:]
603+
// If it's an equals flag
604+
if len(flagParts) == 2 {
605+
if out := inv.completeFlag(flagName); out != nil {
606+
for i, o := range out {
607+
out[i] = fmt.Sprintf("--%s=%s", flagName, o)
608+
}
609+
return out
610+
}
611+
} else if out := inv.Command.Options.ByFlag(flagName); out != nil {
612+
// If the current word is a valid flag, auto-complete it so the
613+
// shell moves the cursor
614+
return []string{cur}
615+
}
616+
}
617+
// If the previous word is a flag, then we're writing it's value
618+
// and we should check it's handler
619+
if strings.HasPrefix(prev, "--") {
620+
word := prev[2:]
621+
if out := inv.completeFlag(word); out != nil {
622+
return out
623+
}
624+
}
625+
// If the current word is the command, move the shell cursor
626+
if inv.Command.Name() == cur {
627+
return []string{inv.Command.Name()}
628+
}
629+
var completions []string
630+
631+
if inv.Command.CompletionHandler != nil {
632+
completions = append(completions, inv.Command.CompletionHandler(inv)...)
633+
}
634+
635+
completions = append(completions, DefaultCompletionHandler(inv)...)
636+
637+
return completions
638+
}
639+
640+
func (inv *Invocation) completeFlag(word string) []string {
641+
opt := inv.Command.Options.ByFlag(word)
642+
if opt == nil {
643+
return nil
644+
}
645+
if opt.CompletionHandler != nil {
646+
return opt.CompletionHandler(inv)
647+
}
648+
val, ok := opt.Value.(*Enum)
649+
if ok {
650+
return val.Choices
651+
}
652+
return nil
653+
}
654+
561655
// MiddlewareFunc returns the next handler in the chain,
562656
// or nil if there are no more.
563657
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
@@ -642,3 +736,5 @@ func RequireRangeArgs(start, end int) MiddlewareFunc {
642736

643737
// HandlerFunc handles an Invocation of a command.
644738
type HandlerFunc func(i *Invocation) error
739+
740+
type CompletionHandlerFunc func(i *Invocation) []string

command_test.go

+132-84
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"golang.org/x/xerrors"
1313

1414
serpent "github.com/coder/serpent"
15+
"github.com/coder/serpent/completion"
1516
)
1617

1718
// ioBufs is the standard input, output, and error for a command.
@@ -30,100 +31,147 @@ func fakeIO(i *serpent.Invocation) *ioBufs {
3031
return &b
3132
}
3233

33-
func TestCommand(t *testing.T) {
34-
t.Parallel()
35-
36-
cmd := func() *serpent.Command {
37-
var (
38-
verbose bool
39-
lower bool
40-
prefix string
41-
reqBool bool
42-
reqStr string
43-
)
44-
return &serpent.Command{
45-
Use: "root [subcommand]",
46-
Options: serpent.OptionSet{
47-
serpent.Option{
48-
Name: "verbose",
49-
Flag: "verbose",
50-
Value: serpent.BoolOf(&verbose),
51-
},
52-
serpent.Option{
53-
Name: "prefix",
54-
Flag: "prefix",
55-
Value: serpent.StringOf(&prefix),
56-
},
34+
func sampleCommand(t *testing.T) *serpent.Command {
35+
t.Helper()
36+
var (
37+
verbose bool
38+
lower bool
39+
prefix string
40+
reqBool bool
41+
reqStr string
42+
reqArr []string
43+
fileArr []string
44+
enumStr string
45+
)
46+
enumChoices := []string{"foo", "bar", "qux"}
47+
return &serpent.Command{
48+
Use: "root [subcommand]",
49+
Options: serpent.OptionSet{
50+
serpent.Option{
51+
Name: "verbose",
52+
Flag: "verbose",
53+
Value: serpent.BoolOf(&verbose),
5754
},
58-
Children: []*serpent.Command{
59-
{
60-
Use: "required-flag --req-bool=true --req-string=foo",
61-
Short: "Example with required flags",
62-
Options: serpent.OptionSet{
63-
serpent.Option{
64-
Name: "req-bool",
65-
Flag: "req-bool",
66-
Value: serpent.BoolOf(&reqBool),
67-
Required: true,
68-
},
69-
serpent.Option{
70-
Name: "req-string",
71-
Flag: "req-string",
72-
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
73-
ok := strings.Contains(value.String(), " ")
74-
if !ok {
75-
return xerrors.Errorf("string must contain a space")
76-
}
77-
return nil
78-
}),
79-
Required: true,
80-
},
55+
serpent.Option{
56+
Name: "prefix",
57+
Flag: "prefix",
58+
Value: serpent.StringOf(&prefix),
59+
},
60+
},
61+
Children: []*serpent.Command{
62+
{
63+
Use: "required-flag --req-bool=true --req-string=foo",
64+
Short: "Example with required flags",
65+
Options: serpent.OptionSet{
66+
serpent.Option{
67+
Name: "req-bool",
68+
Flag: "req-bool",
69+
FlagShorthand: "b",
70+
Value: serpent.BoolOf(&reqBool),
71+
Required: true,
8172
},
82-
HelpHandler: func(i *serpent.Invocation) error {
83-
_, _ = i.Stdout.Write([]byte("help text.png"))
84-
return nil
73+
serpent.Option{
74+
Name: "req-string",
75+
Flag: "req-string",
76+
FlagShorthand: "s",
77+
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
78+
ok := strings.Contains(value.String(), " ")
79+
if !ok {
80+
return xerrors.Errorf("string must contain a space")
81+
}
82+
return nil
83+
}),
84+
Required: true,
8585
},
86-
Handler: func(i *serpent.Invocation) error {
87-
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
88-
return nil
86+
serpent.Option{
87+
Name: "req-enum",
88+
Flag: "req-enum",
89+
Value: serpent.EnumOf(&enumStr, enumChoices...),
90+
},
91+
serpent.Option{
92+
Name: "req-array",
93+
Flag: "req-array",
94+
FlagShorthand: "a",
95+
Value: serpent.StringArrayOf(&reqArr),
8996
},
9097
},
91-
{
92-
Use: "toupper [word]",
93-
Short: "Converts a word to upper case",
94-
Middleware: serpent.Chain(
95-
serpent.RequireNArgs(1),
96-
),
97-
Aliases: []string{"up"},
98-
Options: serpent.OptionSet{
99-
serpent.Option{
100-
Name: "lower",
101-
Flag: "lower",
102-
Value: serpent.BoolOf(&lower),
103-
},
98+
HelpHandler: func(i *serpent.Invocation) error {
99+
_, _ = i.Stdout.Write([]byte("help text.png"))
100+
return nil
101+
},
102+
Handler: func(i *serpent.Invocation) error {
103+
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
104+
return nil
105+
},
106+
},
107+
{
108+
Use: "toupper [word]",
109+
Short: "Converts a word to upper case",
110+
Middleware: serpent.Chain(
111+
serpent.RequireNArgs(1),
112+
),
113+
Aliases: []string{"up"},
114+
Options: serpent.OptionSet{
115+
serpent.Option{
116+
Name: "lower",
117+
Flag: "lower",
118+
Value: serpent.BoolOf(&lower),
104119
},
105-
Handler: func(i *serpent.Invocation) error {
106-
_, _ = i.Stdout.Write([]byte(prefix))
107-
w := i.Args[0]
108-
if lower {
109-
w = strings.ToLower(w)
110-
} else {
111-
w = strings.ToUpper(w)
112-
}
113-
_, _ = i.Stdout.Write(
114-
[]byte(
115-
w,
116-
),
117-
)
118-
if verbose {
119-
_, _ = i.Stdout.Write([]byte("!!!"))
120-
}
121-
return nil
120+
},
121+
Handler: func(i *serpent.Invocation) error {
122+
_, _ = i.Stdout.Write([]byte(prefix))
123+
w := i.Args[0]
124+
if lower {
125+
w = strings.ToLower(w)
126+
} else {
127+
w = strings.ToUpper(w)
128+
}
129+
_, _ = i.Stdout.Write(
130+
[]byte(
131+
w,
132+
),
133+
)
134+
if verbose {
135+
_, _ = i.Stdout.Write([]byte("!!!"))
136+
}
137+
return nil
138+
},
139+
},
140+
{
141+
Use: "file <file>",
142+
Handler: func(inv *serpent.Invocation) error {
143+
return nil
144+
},
145+
CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool {
146+
return true
147+
}),
148+
Middleware: serpent.RequireNArgs(1),
149+
},
150+
{
151+
Use: "altfile",
152+
Handler: func(inv *serpent.Invocation) error {
153+
return nil
154+
},
155+
Options: serpent.OptionSet{
156+
{
157+
Name: "extra",
158+
Flag: "extra",
159+
Description: "Extra files.",
160+
Value: serpent.StringArrayOf(&fileArr),
122161
},
123162
},
163+
CompletionHandler: func(i *serpent.Invocation) []string {
164+
return []string{"doesntexist.go"}
165+
},
124166
},
125-
}
167+
},
126168
}
169+
}
170+
171+
func TestCommand(t *testing.T) {
172+
t.Parallel()
173+
174+
cmd := func() *serpent.Command { return sampleCommand(t) }
127175

128176
t.Run("SimpleOK", func(t *testing.T) {
129177
t.Parallel()

0 commit comments

Comments
 (0)