Skip to content

Commit

Permalink
Startup script: implement retries for connection-related operations (#…
Browse files Browse the repository at this point in the history
…249)

* Startup script: implement retries for connection-related operations

* assert.Equal → assert.Contains

* Wait for at least 1,000 lines of logs

* Join slice of strings before calling assert.Contains()

* TestHostDirs: use require.Contains() instead of require.EqualValues()

* TestHostDirs: wait for at least 4 log lines
  • Loading branch information
edigaryev authored Feb 12, 2025
1 parent 4794f2a commit ee3c0f9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 32 deletions.
10 changes: 5 additions & 5 deletions internal/tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func TestSingleVM(t *testing.T) {
if err != nil {
t.Fatal(err)
}
return len(logLines) > 0
return len(logLines) >= 1000
}), "failed to wait for logs to become available")
logLines, err := devClient.VMs().Logs(context.Background(), "test-vm")
if err != nil {
Expand All @@ -78,7 +78,7 @@ func TestSingleVM(t *testing.T) {
for i := 1; i <= 1000; i++ {
expectedLogs = append(expectedLogs, strconv.Itoa(i))
}
assert.Equal(t, expectedLogs, logLines)
assert.Contains(t, strings.Join(logLines, "\n"), strings.Join(expectedLogs, "\n"))

// Ensure that the VM exists on disk before deleting it
require.True(t, hasVMByPredicate(t, func(info tart.VMInfo) bool {
Expand Down Expand Up @@ -416,17 +416,17 @@ func TestHostDirs(t *testing.T) {
logLines, err = devClient.VMs().Logs(context.Background(), vmName)
require.NoError(t, err)

return len(logLines) > 0
return len(logLines) >= 4
}), "failed to wait for logs to become available")

fmt.Println(logLines)

require.EqualValues(t, []string{
require.Contains(t, strings.Join(logLines, "\n"), strings.Join([]string{
"Read-write mount exists",
"Read-only mount exists",
"Failed to create a file in read-only mount",
"Successfully created a file in read-write mount",
}, logLines)
}, "\n"))
require.FileExists(t, filepath.Join(dirToMount, "test-rw.txt"))
require.NoFileExists(t, filepath.Join(dirToMount, "test-ro.txt"))
}
Expand Down
63 changes: 36 additions & 27 deletions internal/worker/vmmanager/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,31 +316,15 @@ func (vm *VM) shell(
env map[string]string,
consumeLine func(line string),
) error {
ip, err := vm.IP(ctx)
if err != nil {
return fmt.Errorf("%w to get IP", ErrVMFailed)
}

var netConn net.Conn

addr := ip + ":22"
var sess *ssh.Session

if err := retry.Do(func() error {
dialer := net.Dialer{}

netConn, err = dialer.DialContext(ctx, "tcp", addr)

return err
}, retry.Context(ctx)); err != nil {
return fmt.Errorf("%w to dial: %v", ErrVMFailed, err)
}

// set default user and password if not provided
// Set default user and password if not provided
if sshUser == "" && sshPassword == "" {
sshUser = "admin"
sshPassword = "admin"
}

// Configure SSH client
sshConfig := &ssh.ClientConfig{
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
Expand All @@ -351,15 +335,40 @@ func (vm *VM) shell(
},
}

sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, sshConfig)
if err != nil {
return fmt.Errorf("%w to connect via SSH: %v", ErrVMFailed, err)
}
cli := ssh.NewClient(sshConn, chans, reqs)
if err := retry.Do(func() error {
ip, err := vm.IP(ctx)
if err != nil {
return fmt.Errorf("failed to get VM's IP: %w", err)
}

sess, err := cli.NewSession()
if err != nil {
return fmt.Errorf("%w: failed to open SSH session: %v", ErrVMFailed, err)
addr := ip + ":22"

dialer := net.Dialer{
Timeout: 5 * time.Second,
}

netConn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("failed to dial %s: %w", addr, err)
}

sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, sshConfig)
if err != nil {
return fmt.Errorf("SSH handshake with %s failed: %w", addr, err)
}

sshClient := ssh.NewClient(sshConn, chans, reqs)

sess, err = sshClient.NewSession()
if err != nil {
return fmt.Errorf("failed to open an SSH session on %s: %w", addr, err)
}

return nil
}, retry.Context(ctx), retry.OnRetry(func(n uint, err error) {
consumeLine(fmt.Sprintf("attempt %d to establish SSH connection failed: %v", n, err))
})); err != nil {
return fmt.Errorf("failed to establish SSH connection: %w", err)
}

// Log output from the virtual machine
Expand Down

0 comments on commit ee3c0f9

Please sign in to comment.