Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auto-completion #7

Merged
merged 15 commits into from
Aug 1, 2024
106 changes: 101 additions & 5 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ type Command struct {
Middleware MiddlewareFunc
Handler HandlerFunc
HelpHandler HandlerFunc
// CompletionHandler is called when the command is run in completion
// mode. If nil, only the default completion handler is used.
//
// Flag and option parsing is best-effort in this mode, so even if an Option
// is "required" it may not be set.
CompletionHandler CompletionHandlerFunc
}

// AddSubcommands adds the given subcommands, setting their
Expand Down Expand Up @@ -193,15 +199,22 @@ type Invocation struct {
ctx context.Context
Command *Command
parsedFlags *pflag.FlagSet
Args []string

// Args is reduced into the remaining arguments after parsing flags
// during Run.
Args []string

// Environ is a list of environment variables. Use EnvsWithPrefix to parse
// os.Environ.
Environ Environ
Stdout io.Writer
Stderr io.Writer
Stdin io.Reader
Logger slog.Logger
Net Net

// Deprecated
Logger slog.Logger
// Deprecated
Net Net

// testing
signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc)
Expand Down Expand Up @@ -282,6 +295,17 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet {
return fs2
}

func (inv *Invocation) CurWords() (prev string, cur string) {
if len(inv.Args) == 1 {
cur = inv.Args[0]
prev = ""
} else {
cur = inv.Args[len(inv.Args)-1]
prev = inv.Args[len(inv.Args)-2]
}
return
}

// run recursively executes the command and its children.
// allArgs is wired through the stack so that global flags can be accepted
// anywhere in the command invocation.
Expand Down Expand Up @@ -378,8 +402,19 @@ func (inv *Invocation) run(state *runState) error {
}
}

// Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already.
// We only look at the current word to figure out handler to run, or what directory to inspect.
if inv.IsCompletionMode() {
for _, e := range inv.complete() {
fmt.Fprintln(inv.Stdout, e)
}
return nil
}

ignoreFlagParseErrors := inv.Command.RawArgs

// Flag parse errors are irrelevant for raw args commands.
if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf(
"parsing flags (%v) for %q: %w",
state.allArgs,
Expand All @@ -396,7 +431,7 @@ func (inv *Invocation) run(state *runState) error {
}
}
// Don't error for missing flags if `--help` was supplied.
if len(missing) > 0 && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", "))
}

Expand Down Expand Up @@ -553,6 +588,65 @@ func (inv *Invocation) with(fn func(*Invocation)) *Invocation {
return &i2
}

func (inv *Invocation) complete() []string {
prev, cur := inv.CurWords()

// If the current word is a flag
if strings.HasPrefix(cur, "--") {
flagParts := strings.Split(cur, "=")
flagName := flagParts[0][2:]
// If it's an equals flag
if len(flagParts) == 2 {
if out := inv.completeFlag(flagName); out != nil {
for i, o := range out {
out[i] = fmt.Sprintf("--%s=%s", flagName, o)
}
return out
}
} else if out := inv.Command.Options.ByFlag(flagName); out != nil {
// If the current word is a valid flag, auto-complete it so the
// shell moves the cursor
return []string{cur}
}
}
// If the previous word is a flag, then we're writing it's value
// and we should check it's handler
if strings.HasPrefix(prev, "--") {
word := prev[2:]
if out := inv.completeFlag(word); out != nil {
return out
}
}
// If the current word is the command, move the shell cursor
if inv.Command.Name() == cur {
return []string{inv.Command.Name()}
}
var completions []string

if inv.Command.CompletionHandler != nil {
completions = append(completions, inv.Command.CompletionHandler(inv)...)
}

completions = append(completions, DefaultCompletionHandler(inv)...)

return completions
}

func (inv *Invocation) completeFlag(word string) []string {
opt := inv.Command.Options.ByFlag(word)
if opt == nil {
return nil
}
if opt.CompletionHandler != nil {
return opt.CompletionHandler(inv)
}
val, ok := opt.Value.(*Enum)
if ok {
return val.Choices
}
return nil
}

// MiddlewareFunc returns the next handler in the chain,
// or nil if there are no more.
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
Expand Down Expand Up @@ -637,3 +731,5 @@ func RequireRangeArgs(start, end int) MiddlewareFunc {

// HandlerFunc handles an Invocation of a command.
type HandlerFunc func(i *Invocation) error

type CompletionHandlerFunc func(i *Invocation) []string
216 changes: 132 additions & 84 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/xerrors"

serpent "github.com/coder/serpent"
"github.com/coder/serpent/completion"
)

// ioBufs is the standard input, output, and error for a command.
Expand All @@ -30,100 +31,147 @@ func fakeIO(i *serpent.Invocation) *ioBufs {
return &b
}

func TestCommand(t *testing.T) {
t.Parallel()

cmd := func() *serpent.Command {
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
)
return &serpent.Command{
Use: "root [subcommand]",
Options: serpent.OptionSet{
serpent.Option{
Name: "verbose",
Flag: "verbose",
Value: serpent.BoolOf(&verbose),
},
serpent.Option{
Name: "prefix",
Flag: "prefix",
Value: serpent.StringOf(&prefix),
},
func sampleCommand(t *testing.T) *serpent.Command {
t.Helper()
var (
verbose bool
lower bool
prefix string
reqBool bool
reqStr string
reqArr []string
fileArr []string
enumStr string
)
enumChoices := []string{"foo", "bar", "qux"}
return &serpent.Command{
Use: "root [subcommand]",
Options: serpent.OptionSet{
serpent.Option{
Name: "verbose",
Flag: "verbose",
Value: serpent.BoolOf(&verbose),
},
Children: []*serpent.Command{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: serpent.OptionSet{
serpent.Option{
Name: "req-bool",
Flag: "req-bool",
Value: serpent.BoolOf(&reqBool),
Required: true,
},
serpent.Option{
Name: "req-string",
Flag: "req-string",
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
serpent.Option{
Name: "prefix",
Flag: "prefix",
Value: serpent.StringOf(&prefix),
},
},
Children: []*serpent.Command{
{
Use: "required-flag --req-bool=true --req-string=foo",
Short: "Example with required flags",
Options: serpent.OptionSet{
serpent.Option{
Name: "req-bool",
Flag: "req-bool",
FlagShorthand: "b",
Value: serpent.BoolOf(&reqBool),
Required: true,
},
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
return nil
serpent.Option{
Name: "req-string",
Flag: "req-string",
FlagShorthand: "s",
Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error {
ok := strings.Contains(value.String(), " ")
if !ok {
return xerrors.Errorf("string must contain a space")
}
return nil
}),
Required: true,
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
serpent.Option{
Name: "req-enum",
Flag: "req-enum",
Value: serpent.EnumOf(&enumStr, enumChoices...),
},
serpent.Option{
Name: "req-array",
Flag: "req-array",
FlagShorthand: "a",
Value: serpent.StringArrayOf(&reqArr),
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: serpent.OptionSet{
serpent.Option{
Name: "lower",
Flag: "lower",
Value: serpent.BoolOf(&lower),
},
HelpHandler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte("help text.png"))
return nil
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool)))
return nil
},
},
{
Use: "toupper [word]",
Short: "Converts a word to upper case",
Middleware: serpent.Chain(
serpent.RequireNArgs(1),
),
Aliases: []string{"up"},
Options: serpent.OptionSet{
serpent.Option{
Name: "lower",
Flag: "lower",
Value: serpent.BoolOf(&lower),
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
_, _ = i.Stdout.Write([]byte("!!!"))
}
return nil
},
Handler: func(i *serpent.Invocation) error {
_, _ = i.Stdout.Write([]byte(prefix))
w := i.Args[0]
if lower {
w = strings.ToLower(w)
} else {
w = strings.ToUpper(w)
}
_, _ = i.Stdout.Write(
[]byte(
w,
),
)
if verbose {
_, _ = i.Stdout.Write([]byte("!!!"))
}
return nil
},
},
{
Use: "file <file>",
Handler: func(inv *serpent.Invocation) error {
return nil
},
CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool {
return true
}),
Middleware: serpent.RequireNArgs(1),
},
{
Use: "altfile",
Handler: func(inv *serpent.Invocation) error {
return nil
},
Options: serpent.OptionSet{
{
Name: "extra",
Flag: "extra",
Description: "Extra files.",
Value: serpent.StringArrayOf(&fileArr),
},
},
CompletionHandler: func(i *serpent.Invocation) []string {
return []string{"doesntexist.go"}
},
},
}
},
}
}

func TestCommand(t *testing.T) {
t.Parallel()

cmd := func() *serpent.Command { return sampleCommand(t) }

t.Run("SimpleOK", func(t *testing.T) {
t.Parallel()
Expand Down
Loading
Loading