Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cmd.CustomFileSystemFunc function #130

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions cmd/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import (
"context"
"errors"
"fmt"
"os"

"github.com/cloudspannerecosystem/wrench/internal/fs"
"github.com/cloudspannerecosystem/wrench/pkg/spanner"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -57,7 +56,7 @@ func apply(c *cobra.Command, _ []string) error {
return errors.New("cannot specify DDL and DML at same time")
}

ddl, err := os.ReadFile(ddlFile)
ddl, err := fs.ReadFile(ctx, ddlFile)
if err != nil {
return &Error{
err: err,
Expand All @@ -81,7 +80,7 @@ func apply(c *cobra.Command, _ []string) error {
}

// apply dml
dml, err := os.ReadFile(dmlFile)
dml, err := fs.ReadFile(ctx, dmlFile)
if err != nil {
return &Error{
err: err,
Expand Down
5 changes: 2 additions & 3 deletions cmd/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ package cmd

import (
"context"
"os"

"github.com/cloudspannerecosystem/wrench/internal/fs"
"github.com/spf13/cobra"
)

Expand All @@ -43,7 +42,7 @@ func create(c *cobra.Command, _ []string) error {
defer client.Close()

filename := schemaFilePath(c)
ddl, err := os.ReadFile(filename)
ddl, err := fs.ReadFile(ctx, filename)
if err != nil {
return &Error{
err: err,
Expand Down
14 changes: 14 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ package cmd

import (
"context"
"io/fs"
"os"
"runtime/debug"
"time"

wrenchfs "github.com/cloudspannerecosystem/wrench/internal/fs"
"github.com/spf13/cobra"
)

Expand All @@ -44,8 +46,20 @@ var (
timeout time.Duration
)

// CustomFileSystemFunc is a function that returns a custom fs.FS.
// This variable allows customizing what kind of fs.FS should be use in wrench CLI execution.
// e.g. embed.FS for use.
var CustomFileSystemFunc func() fs.FS

var rootCmd = &cobra.Command{
Use: "wrench",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
if CustomFileSystemFunc != nil {
ctx := cmd.Context()
ctx = wrenchfs.WithContext(ctx, CustomFileSystemFunc())
cmd.SetContext(ctx)
}
},
}

func Execute(ctx context.Context) error {
Expand Down
31 changes: 31 additions & 0 deletions internal/fs/fs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package fs

import (
"context"
"io/fs"
"os"
)

type contextFSKey struct{}

func WithContext(ctx context.Context, fsys fs.FS) context.Context {
return context.WithValue(ctx, contextFSKey{}, fsys)
}

func FromContext(ctx context.Context) fs.FS {
fsys, ok := ctx.Value(contextFSKey{}).(fs.FS)
if ok && fsys != nil {
return fsys
}
return os.DirFS(".")
}

func ReadFile(ctx context.Context, path string) ([]byte, error) {
fsys := FromContext(ctx)
return fs.ReadFile(fsys, path)
}

func ReadDir(ctx context.Context, path string) ([]fs.DirEntry, error) {
fsys := FromContext(ctx)
return fs.ReadDir(fsys, path)
}
2 changes: 1 addition & 1 deletion pkg/spanner/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestExecuteMigrations(t *testing.T) {
t.Fatalf("failed to apply mutation: %v", err)
}

migrations, err := LoadMigrations("testdata/migrations")
migrations, err := ReadMigrations(ctx, "testdata/migrations")
if err != nil {
t.Fatalf("failed to load migrations: %v", err)
}
Expand Down
16 changes: 11 additions & 5 deletions pkg/spanner/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
package spanner

import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"

"github.com/cloudspannerecosystem/wrench/internal/fs"
)

var (
Expand Down Expand Up @@ -76,8 +78,8 @@ func (ms Migrations) Less(i, j int) bool {
return ms[i].Version < ms[j].Version
}

func LoadMigrations(dir string) (Migrations, error) {
files, err := os.ReadDir(dir)
func ReadMigrations(ctx context.Context, dir string) (Migrations, error) {
files, err := fs.ReadDir(ctx, dir)
if err != nil {
return nil, err
}
Expand All @@ -103,7 +105,7 @@ func LoadMigrations(dir string) (Migrations, error) {
continue
}

file, err := os.ReadFile(filepath.Join(dir, filename))
file, err := fs.ReadFile(ctx, filepath.Join(dir, filename))
if err != nil {
continue
}
Expand Down Expand Up @@ -138,6 +140,11 @@ func LoadMigrations(dir string) (Migrations, error) {
return migrations, nil
}

// Deprecated: use ReadMigrations instead.
func LoadMigrations(dir string) (Migrations, error) {
return ReadMigrations(context.Background(), dir)
}

func ddlToStatements(filename string, data []byte) ([]string, error) {
return toStatements(filename, data)
}
Expand Down Expand Up @@ -174,4 +181,3 @@ func inspectStatementsKind(statements []string) (statementKind, error) {
return "", errors.New("DDL, DML (INSERT), and partitioned DML (UPDATE or DELETE) must not be combined in the same migration file")
}
}

5 changes: 4 additions & 1 deletion pkg/spanner/migration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
package spanner_test

import (
"context"
"path/filepath"
"testing"

"github.com/cloudspannerecosystem/wrench/pkg/spanner"
)

func TestLoadMigrations(t *testing.T) {
ms, err := spanner.LoadMigrations(filepath.Join("testdata", "migrations"))
ctx := context.Background()

ms, err := spanner.ReadMigrations(ctx, filepath.Join("testdata", "migrations"))
if err != nil {
t.Fatal(err)
}
Expand Down
Loading