diff --git a/internal/impl/io/output_http_server.go b/internal/impl/io/output_http_server.go index 8639bea51..f4e65b61d 100644 --- a/internal/impl/io/output_http_server.go +++ b/internal/impl/io/output_http_server.go @@ -42,6 +42,9 @@ const ( hsoFieldCORS = "cors" hsoFieldCORSEnabled = "enabled" hsoFieldCORSAllowedOrigins = "allowed_origins" + hsoFieldWriteWait = "write_wait" + hsoFieldPongWait = "pong_wait" + hsoFieldPingPeriod = "ping_period" ) type hsoConfig struct { @@ -54,6 +57,9 @@ type hsoConfig struct { CertFile string KeyFile string CORS httpserver.CORSConfig + WriteWait time.Duration + PongWait time.Duration + PingPeriod time.Duration } func hsoConfigFromParsed(pConf *service.ParsedConfig) (conf hsoConfig, err error) { @@ -95,6 +101,15 @@ func hsoConfigFromParsed(pConf *service.ParsedConfig) (conf hsoConfig, err error if conf.CORS, err = corsConfigFromParsed(pConf.Namespace(hsoFieldCORS)); err != nil { return } + if conf.WriteWait, err = pConf.FieldDuration(hsoFieldWriteWait); err != nil { + return + } + if conf.PongWait, err = pConf.FieldDuration(hsoFieldPongWait); err != nil { + return + } + if conf.PingPeriod, err = pConf.FieldDuration(hsoFieldPingPeriod); err != nil { + return + } return } @@ -145,6 +160,18 @@ Please note, messages are considered delivered as soon as the data is written to Advanced(). Default(""), service.NewInternalField(corsSpec), + service.NewDurationField(hsoFieldWriteWait). + Description("The time allowed to write a message to the websocket."). + Default("10s"). + Advanced(), + service.NewDurationField(hsoFieldPongWait). + Description("The time allowed to read the next pong message from the client."). + Default("60s"). + Advanced(), + service.NewDurationField(hsoFieldPingPeriod). + Description("Send pings to client with this period. Must be less than pong wait."). + Default("54s"). + Advanced(), ) } @@ -393,50 +420,93 @@ func (h *httpServerOutput) wsHandler(w http.ResponseWriter, r *http.Request) { defer func() { if err != nil { http.Error(w, "Bad request", http.StatusBadRequest) - h.log.Warn("Websocket request failed: %v\n", err) + h.log.Warn("WebSocket request failed: %v", err) return } }() upgrader := websocket.Upgrader{} - var ws *websocket.Conn - if ws, err = upgrader.Upgrade(w, r, nil); err != nil { + // Upgrade the HTTP connection to a WebSocket connection + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.log.Warn("WebSocket upgrade failed: %v", err) return } defer ws.Close() - ctx, done := h.shutSig.SoftStopCtx(r.Context()) - defer done() + ws.SetReadLimit(512) + if err := ws.SetReadDeadline(time.Now().Add(h.conf.PongWait)); err != nil { + h.log.Warn("Failed to set read deadline: %v", err) + return + } - for !h.shutSig.IsSoftStopSignalled() { - var ts message.Transaction - var open bool + ws.SetPongHandler(func(string) error { + return ws.SetReadDeadline(time.Now().Add(h.conf.PongWait)) + }) + // Start a goroutine to read messages (to process control frames) + done := make(chan struct{}) + go func() { + defer close(done) + for { + _, _, err := ws.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + h.log.Warn("WebSocket read error: %v", err) + } + break + } + } + }() + + // Start ticker to send ping messages to the client periodically + ticker := time.NewTicker(h.conf.PingPeriod) + defer ticker.Stop() + + ctx, doneCtx := h.shutSig.SoftStopCtx(r.Context()) + defer doneCtx() + + for !h.shutSig.IsSoftStopSignalled() { select { - case ts, open = <-h.transactions: + case ts, open := <-h.transactions: if !open { + // If the transactions channel is closed, trigger server shutdown go h.TriggerCloseNow() return } - case <-r.Context().Done(): + // Write messages to the client + var writeErr error + for _, msg := range message.GetAllBytes(ts.Payload) { + _ = ws.SetWriteDeadline(time.Now().Add(h.conf.WriteWait)) + if writeErr = ws.WriteMessage(websocket.BinaryMessage, msg); writeErr != nil { + break + } + h.mWSBatchSent.Incr(1) + h.mWSSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload))) + } + if writeErr != nil { + h.mWSError.Incr(1) + _ = ts.Ack(ctx, writeErr) + return // Exit the loop on write error + } + _ = ts.Ack(ctx, nil) + case <-ticker.C: + // Send a ping message to the client + //nolint:errcheck // this function does not actually return an error + ws.SetWriteDeadline(time.Now().Add(h.conf.WriteWait)) + if err := ws.WriteMessage(websocket.PingMessage, nil); err != nil { + h.log.Warn("WebSocket ping error: %v", err) + return + } + case <-done: + // The read goroutine has exited, indicating the client has disconnected + h.log.Debug("WebSocket client disconnected") return - case <-h.shutSig.SoftStopChan(): + case <-ctx.Done(): + // The context has been canceled (e.g., server is shutting down) return } - - var werr error - for _, msg := range message.GetAllBytes(ts.Payload) { - if werr = ws.WriteMessage(websocket.BinaryMessage, msg); werr != nil { - break - } - h.mWSBatchSent.Incr(1) - h.mWSSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload))) - } - if werr != nil { - h.mWSError.Incr(1) - } - _ = ts.Ack(ctx, werr) } } diff --git a/website/docs/components/outputs/http_server.md b/website/docs/components/outputs/http_server.md index 1150c1704..24840bbcb 100644 --- a/website/docs/components/outputs/http_server.md +++ b/website/docs/components/outputs/http_server.md @@ -58,6 +58,9 @@ output: cors: enabled: false allowed_origins: [] + write_wait: 10s + pong_wait: 60s + ping_period: 54s ``` @@ -172,4 +175,28 @@ An explicit list of origins that are allowed for CORS requests. Type: `array` Default: `[]` +### `write_wait` + +The time allowed to write a message to the websocket. + + +Type: `string` +Default: `"10s"` + +### `pong_wait` + +The time allowed to read the next pong message from the client. + + +Type: `string` +Default: `"60s"` + +### `ping_period` + +Send pings to client with this period. Must be less than pong wait. + + +Type: `string` +Default: `"54s"` +