Skip to content

Commit

Permalink
Block until stop completes
Browse files Browse the repository at this point in the history
Currently the disconnect/stop bridge call will complete before all the
loops have returned. This switches them all to use a shared cancelable
context and wait group to block on stop until all loops exit.
  • Loading branch information
Fizzadar committed Feb 21, 2025
1 parent 79985bf commit 3867aa6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
5 changes: 4 additions & 1 deletion pkg/signalmeow/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ type Client struct {

AuthedWS *web.SignalWebsocket
UnauthedWS *web.SignalWebsocket
WSCancel context.CancelFunc
lastConnectionStatus SignalConnectionStatus

ctx context.Context
cancel context.CancelFunc
wg *sync.WaitGroup

EventHandler func(events.SignalEvent)

storageAuthLock sync.Mutex
Expand Down
16 changes: 10 additions & 6 deletions pkg/signalmeow/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,13 @@ func (cli *Client) CheckAndUploadNewPreKeys(ctx context.Context, pks store.PreKe
return nil
}

func (cli *Client) StartKeyCheckLoop(ctx context.Context) {
log := zerolog.Ctx(ctx).With().Str("action", "start key check loop").Logger()
func (cli *Client) StartKeyCheckLoop() {
log := zerolog.Ctx(cli.ctx).With().Str("action", "start key check loop").Logger()

cli.wg.Add(1)
go func() {
defer cli.wg.Done()

// Do the initial check in 5-10 minutes after starting the loop
window_start := 0
window_size := 1
Expand All @@ -611,22 +615,22 @@ func (cli *Client) StartKeyCheckLoop(ctx context.Context) {
log.Debug().Dur("check_time", check_time).Msg("Waiting to check for new prekeys")

select {
case <-ctx.Done():
case <-cli.ctx.Done():
return
case <-time.After(check_time):
err := cli.CheckAndUploadNewPreKeys(ctx, cli.Store.ACIPreKeyStore)
err := cli.CheckAndUploadNewPreKeys(cli.ctx, cli.Store.ACIPreKeyStore)
if err != nil {
log.Err(err).Msg("Error checking and uploading new prekeys for ACI identity")
// Retry within half an hour
window_start = 5
window_size = 25
continue
}
err = cli.CheckAndUploadNewPreKeys(ctx, cli.Store.PNIPreKeyStore)
err = cli.CheckAndUploadNewPreKeys(cli.ctx, cli.Store.PNIPreKeyStore)
if err != nil {
if errors.Is(err, errPrekeyUpload422) {
log.Err(err).Msg("Got 422 error while uploading PNI prekeys, deleting session")
disconnectErr := cli.ClearKeysAndDisconnect(ctx)
disconnectErr := cli.ClearKeysAndDisconnect(cli.ctx)
if disconnectErr != nil {
log.Err(disconnectErr).Msg("ClearKeysAndDisconnect error")
}
Expand Down
40 changes: 21 additions & 19 deletions pkg/signalmeow/receiving.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,26 @@ type SignalConnectionStatus struct {
}

func (cli *Client) StartWebsockets(ctx context.Context) (authChan, unauthChan chan web.SignalWebsocketConnectionStatus, err error) {
authChan, unauthChan, _, _, err = cli.startWebsocketsInternal(ctx)
authChan, unauthChan, err = cli.startWebsocketsInternal(ctx)
return
}

func (cli *Client) startWebsocketsInternal(
ctx context.Context,
) (
authChan, unauthChan chan web.SignalWebsocketConnectionStatus,
cancelCtx context.Context,
cancelFunc context.CancelFunc,
err error,
) {
cancelCtx, cancelFunc = context.WithCancel(ctx)
cli.WSCancel = cancelFunc
unauthChan, err = cli.connectUnauthedWS(cancelCtx)
cli.ctx, cli.cancel = context.WithCancel(ctx)
unauthChan, err = cli.connectUnauthedWS(cli.ctx)
if err != nil {
cancelFunc()
cli.cancel()
return
}
zerolog.Ctx(ctx).Info().Msg("Unauthed websocket connecting")
authChan, err = cli.connectAuthedWS(cancelCtx, cli.incomingRequestHandler)
authChan, err = cli.connectAuthedWS(cli.ctx, cli.incomingRequestHandler)
if err != nil {
cancelFunc()
cli.cancel()
return
}
zerolog.Ctx(ctx).Info().Msg("Authed websocket connecting")
Expand All @@ -100,8 +97,8 @@ func (cli *Client) startWebsocketsInternal(

func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnectionStatus, error) {
log := zerolog.Ctx(ctx).With().Str("action", "start receive loops").Logger()
ctx = log.WithContext(ctx)
authChan, unauthChan, ctx, cancel, err := cli.startWebsocketsInternal(log.WithContext(ctx))

authChan, unauthChan, err := cli.startWebsocketsInternal(log.WithContext(ctx))
if err != nil {
return nil, err
}
Expand All @@ -110,13 +107,15 @@ func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnection
initialConnectChan := make(chan struct{})

// Combine both websocket status channels into a single, more generic "Signal" connection status channel
cli.wg.Add(1)
go func() {
defer cli.wg.Done()
defer close(statusChan)
defer cancel()
defer cli.cancel()
var currentStatus, lastAuthStatus, lastUnauthStatus web.SignalWebsocketConnectionStatus
for {
select {
case <-ctx.Done():
case <-cli.ctx.Done():
log.Info().Msg("Context done, exiting websocket status loop")
return
case status := <-authChan:
Expand Down Expand Up @@ -201,27 +200,29 @@ func (cli *Client) StartReceiveLoops(ctx context.Context) (chan SignalConnection
}()

// Send sync message once both websockets are connected
cli.wg.Add(1)
go func() {
defer cli.wg.Done()
for {
select {
case <-ctx.Done():
case <-cli.ctx.Done():
return
case <-initialConnectChan:
log.Info().Msg("Both websockets connected, sending contacts sync request")
// TODO hacky
if cli.SyncContactsOnConnect {
cli.SendContactSyncRequest(ctx)
cli.SendContactSyncRequest(cli.ctx)
}
if cli.Store.MasterKey == nil {
cli.SendStorageMasterKeyRequest(ctx)
cli.SendStorageMasterKeyRequest(cli.ctx)
}
return
}
}
}()

// Start loop to check for and upload more prekeys
cli.StartKeyCheckLoop(ctx)
cli.StartKeyCheckLoop()

return statusChan, nil
}
Expand All @@ -233,8 +234,9 @@ func (cli *Client) StopReceiveLoops() error {
}()
authErr := cli.AuthedWS.Close()
unauthErr := cli.UnauthedWS.Close()
if cli.WSCancel != nil {
cli.WSCancel()
if cli.cancel != nil {
cli.cancel()
cli.wg.Wait()
}
if authErr != nil {
return authErr
Expand Down

0 comments on commit 3867aa6

Please sign in to comment.