Skip to content

Commit e2bc46c

Browse files
Improve polling updater.stop() call when long polling (#152)
1 parent 954c160 commit e2bc46c

File tree

5 files changed

+53
-24
lines changed

5 files changed

+53
-24
lines changed

ext/botmapping.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,15 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
163163
w.WriteHeader(http.StatusNotFound)
164164
return
165165
}
166+
166167
b.updateWriterControl.Add(1)
167168
defer b.updateWriterControl.Done()
168169

170+
if b.shouldStopUpdates() {
171+
w.WriteHeader(http.StatusServiceUnavailable)
172+
return
173+
}
174+
169175
headerSecret := r.Header.Get("X-Telegram-Bot-Api-Secret-Token")
170176
if b.webhookSecret != "" && b.webhookSecret != headerSecret {
171177
// Drop any updates from invalid secret tokens.
@@ -184,10 +190,6 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
184190
return
185191
}
186192

187-
if b.isUpdateChannelStopped() {
188-
return
189-
}
190-
191193
b.updateChan <- bytes
192194
}
193195
}
@@ -213,7 +215,7 @@ func (b *botData) stop() {
213215
close(b.updateChan)
214216
}
215217

216-
func (b *botData) isUpdateChannelStopped() bool {
218+
func (b *botData) shouldStopUpdates() bool {
217219
select {
218220
case <-b.stopUpdates:
219221
// if anything comes in on the closing channel, we know the channel is closed.

ext/botmapping_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ func Test_botData_isUpdateChannelStopped(t *testing.T) {
8989
t.Errorf("bot with token %s should not have failed to be added", b.Token)
9090
return
9191
}
92-
if bData.isUpdateChannelStopped() {
92+
if bData.shouldStopUpdates() {
9393
t.Errorf("bot with token %s should not be stopped yet", b.Token)
9494
return
9595
}
9696

9797
bData.stop()
98-
if !bData.isUpdateChannelStopped() {
98+
if !bData.shouldStopUpdates() {
9999
t.Errorf("bot with token %s should be stopped", b.Token)
100100
return
101101
}

ext/updater.go

+8-4
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s
173173
defer bData.updateWriterControl.Done()
174174

175175
for {
176+
// Check if updater loop has been terminated.
177+
if bData.shouldStopUpdates() {
178+
return
179+
}
180+
176181
// Manually craft the getUpdate calls to improve memory management, reduce json parsing overheads, and
177182
// unnecessary reallocation of url.Values in the polling loop.
178183
r, err := bData.bot.Request("getUpdates", v, nil, opts)
@@ -219,10 +224,6 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s
219224

220225
v["offset"] = strconv.FormatInt(lastUpdate.UpdateId+1, 10)
221226

222-
if bData.isUpdateChannelStopped() {
223-
return
224-
}
225-
226227
for _, updData := range rawUpdates {
227228
temp := updData // use new mem address to avoid loop conflicts
228229
bData.updateChan <- temp
@@ -240,6 +241,9 @@ func (u *Updater) Idle() {
240241
}
241242

242243
// Stop stops the current updater and dispatcher instances.
244+
//
245+
// When using long polling, Stop() will wait for the getUpdates call to return, which may cause a delay due to the
246+
// request timeout.
243247
func (u *Updater) Stop() error {
244248
// Stop any running servers.
245249
if u.webhookServer != nil {

ext/updater_test.go

+24-8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strconv"
1111
"strings"
1212
"sync"
13+
"sync/atomic"
1314
"testing"
1415
"time"
1516

@@ -98,8 +99,14 @@ func concurrentTest(t *testing.T) {
9899
t.Parallel()
99100

100101
delay := time.Second
101-
server := basicTestServer(t, map[string]testEndpoint{
102-
"getUpdates": {delay: delay, reply: `{"ok": true, "result": [{"message": {"text": "stop"}}]}`},
102+
server := basicTestServer(t, map[string]*testEndpoint{
103+
"getUpdates": {
104+
delay: delay,
105+
replies: []string{
106+
`{"ok": true, "result": [{"message": {"text": "stop"}}]}`,
107+
},
108+
reply: `{"ok": true, "result": []}`,
109+
},
103110
"deleteWebhook": {reply: `{"ok": true, "result": true}`},
104111
})
105112
defer server.Close()
@@ -290,7 +297,7 @@ func TestUpdater_GetHandlerFunc(t *testing.T) {
290297
}
291298

292299
func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
293-
server := basicTestServer(t, map[string]testEndpoint{
300+
server := basicTestServer(t, map[string]*testEndpoint{
294301
"getUpdates": {reply: `{"ok": true}`},
295302
"deleteWebhook": {reply: `{"ok": true, "result": true}`},
296303
})
@@ -329,7 +336,7 @@ func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
329336
}
330337

331338
func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
332-
server := basicTestServer(t, map[string]testEndpoint{
339+
server := basicTestServer(t, map[string]*testEndpoint{
333340
"getUpdates": {reply: `{"ok": true, "result": []}`},
334341
})
335342
defer server.Close()
@@ -384,7 +391,7 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
384391
}
385392

386393
func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
387-
server := basicTestServer(t, map[string]testEndpoint{
394+
server := basicTestServer(t, map[string]*testEndpoint{
388395
"getUpdates": {reply: `{"ok": true, "result": []}`},
389396
})
390397
defer server.Close()
@@ -432,7 +439,7 @@ func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
432439
}
433440

434441
func TestUpdaterSupportsLongPollReAdding(t *testing.T) {
435-
server := basicTestServer(t, map[string]testEndpoint{
442+
server := basicTestServer(t, map[string]*testEndpoint{
436443
"getUpdates": {reply: `{"ok": true, "result": []}`},
437444
})
438445
defer server.Close()
@@ -484,10 +491,14 @@ func TestUpdaterSupportsLongPollReAdding(t *testing.T) {
484491

485492
type testEndpoint struct {
486493
delay time.Duration
494+
// Will reply these until we run out of replies, at which point we repeat "reply"
495+
replies []string
496+
idx atomic.Int32
497+
// default reply
487498
reply string
488499
}
489500

490-
func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Server {
501+
func basicTestServer(t *testing.T, methods map[string]*testEndpoint) *httptest.Server {
491502
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
492503
pathItems := strings.Split(r.URL.Path, "/")
493504
lastItem := pathItems[len(pathItems)-1]
@@ -498,7 +509,12 @@ func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Se
498509
if out.delay != 0 {
499510
time.Sleep(out.delay)
500511
}
501-
fmt.Fprint(w, out.reply)
512+
count := int(out.idx.Add(1) - 1)
513+
if len(out.replies) != 0 && len(out.replies) > count {
514+
fmt.Fprint(w, out.replies[count])
515+
} else {
516+
fmt.Fprint(w, out.reply)
517+
}
502518
return
503519
}
504520

samples/echoMultiBot/main.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ func main() {
8686

8787
// If we get here, the updater.Idle() has ended.
8888
// This means that updater.Stop() has been called, stopping all bots gracefully.
89-
log.Println("Updater is no longer idling; all bots have been stopped gracefully.")
89+
log.Println("Updater is no longer idling; all bots have been stopped gracefully. Exiting in 1s.")
90+
91+
// We sleep one last second to allow for the "stopall" goroutine to send the shutdown message.
92+
time.Sleep(time.Second)
9093
}
9194

9295
// startLongPollingBots demonstrates how to start multiple bots with long-polling.
@@ -159,11 +162,14 @@ func stop(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error {
159162
return fmt.Errorf("failed to echo message: %w", err)
160163
}
161164

162-
if !updater.StopBot(b.Token) {
163-
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil)
164-
return nil
165-
}
165+
go func() {
166+
if !updater.StopBot(b.Token) {
167+
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil)
168+
return
169+
}
166170

171+
ctx.EffectiveMessage.Reply(b, "Stopped @"+b.Username, nil)
172+
}()
167173
return nil
168174
}
169175

@@ -181,6 +187,7 @@ func stopAll(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error {
181187
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Failed to stop updater: %s", err.Error()), nil)
182188
return
183189
}
190+
ctx.EffectiveMessage.Reply(b, "All bots have been stopped.", nil)
184191
}()
185192

186193
return nil

0 commit comments

Comments
 (0)