diff --git a/cmd/system-probe/modules/network_tracer.go b/cmd/system-probe/modules/network_tracer.go index 2fbbfeb2349c9f..11d8ac0cce838b 100644 --- a/cmd/system-probe/modules/network_tracer.go +++ b/cmd/system-probe/modules/network_tracer.go @@ -88,12 +88,13 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { httpMux.HandleFunc("/connections", utils.WithConcurrencyLimit(utils.DefaultMaxConcurrentRequests, func(w http.ResponseWriter, req *http.Request) { start := time.Now() id := getClientID(req) - cs, err := nt.tracer.GetActiveConnections(id) + cs, cleanup, err := nt.tracer.GetActiveConnections(id) if err != nil { log.Errorf("unable to retrieve connections: %s", err) w.WriteHeader(500) return } + defer cleanup() contentType := req.Header.Get("Accept") marshaler := marshal.GetMarshaler(contentType) writeConnections(w, marshaler, cs) @@ -157,12 +158,13 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { return } id := getClientID(req) - cs, err := nt.tracer.GetActiveConnections(id) + cs, cleanup, err := nt.tracer.GetActiveConnections(id) if err != nil { log.Errorf("unable to retrieve connections: %s", err) w.WriteHeader(500) return } + defer cleanup() utils.WriteAsJSON(w, httpdebugging.HTTP(cs.HTTP, cs.DNS)) }) @@ -173,12 +175,13 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { return } id := getClientID(req) - cs, err := nt.tracer.GetActiveConnections(id) + cs, cleanup, err := nt.tracer.GetActiveConnections(id) if err != nil { log.Errorf("unable to retrieve connections: %s", err) w.WriteHeader(500) return } + defer cleanup() utils.WriteAsJSON(w, kafkadebugging.Kafka(cs.Kafka)) }) @@ -189,12 +192,13 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { return } id := getClientID(req) - cs, err := nt.tracer.GetActiveConnections(id) + cs, cleanup, err := nt.tracer.GetActiveConnections(id) if err != nil { log.Errorf("unable to retrieve connections: %s", err) w.WriteHeader(500) return } + defer 
cleanup() utils.WriteAsJSON(w, postgresdebugging.Postgres(cs.Postgres)) }) @@ -205,12 +209,13 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { return } id := getClientID(req) - cs, err := nt.tracer.GetActiveConnections(id) + cs, cleanup, err := nt.tracer.GetActiveConnections(id) if err != nil { log.Errorf("unable to retrieve connections: %s", err) w.WriteHeader(500) return } + defer cleanup() utils.WriteAsJSON(w, redisdebugging.Redis(cs.Redis)) }) @@ -221,12 +226,13 @@ func (nt *networkTracer) Register(httpMux *module.Router) error { return } id := getClientID(req) - cs, err := nt.tracer.GetActiveConnections(id) + cs, cleanup, err := nt.tracer.GetActiveConnections(id) if err != nil { log.Errorf("unable to retrieve connections: %s", err) w.WriteHeader(500) return } + defer cleanup() utils.WriteAsJSON(w, httpdebugging.HTTP(cs.HTTP2, cs.DNS)) }) diff --git a/pkg/network/nettop/main.go b/pkg/network/nettop/main.go index 66f4fb589cea8b..a386bf59536669 100644 --- a/pkg/network/nettop/main.go +++ b/pkg/network/nettop/main.go @@ -51,10 +51,13 @@ func main() { printConns := func(now time.Time) { fmt.Printf("-- %s --\n", now) - cs, err := t.GetActiveConnections(fmt.Sprintf("%d", os.Getpid())) + cs, cleanup, err := t.GetActiveConnections(fmt.Sprintf("%d", os.Getpid())) if err != nil { fmt.Println(err) + // On error GetActiveConnections returns a nil result and a nil cleanup func; stop here instead of deferring or dereferencing them. + return } + defer cleanup() for _, c := range cs.Conns { fmt.Println(network.ConnectionSummary(&c, cs.DNS)) } diff --git a/pkg/network/protocols/http2/testutils.go b/pkg/network/protocols/http2/testutils.go new file mode 100644 index 00000000000000..de4ac013dff5e4 --- /dev/null +++ b/pkg/network/protocols/http2/testutils.go @@ -0,0 +1,64 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2025-present Datadog, Inc. 
+ +//go:build linux && test + +package http2 + +import ( + "context" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + + "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" +) + +// StartH2CServer starts an HTTP/2 test server listening on address (served over TLS when isTLS is true, cleartext otherwise) that replies with the status derived from the request path (200 when none) and echoes the request body; it returns a function that shuts the server down. +func StartH2CServer(t *testing.T, address string, isTLS bool) func() { + srv := &http.Server{ + Addr: address, + Handler: h2c.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + statusCode := testutil.StatusFromPath(r.URL.Path) + if statusCode == 0 { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(int(statusCode)) + } + defer func() { _ = r.Body.Close() }() + _, _ = io.Copy(w, r.Body) + }), &http2.Server{}), + IdleTimeout: 2 * time.Second, + } + + require.NoError(t, http2.ConfigureServer(srv, nil), "could not configure server") + + l, err := net.Listen("tcp", address) + require.NoError(t, err, "could not listen") + + if isTLS { + cert, key, err := testutil.GetCertsPaths() + require.NoError(t, err, "could not get certs paths") + go func() { + if err := srv.ServeTLS(l, cert, key); err != http.ErrServerClosed { + t.Errorf("could not serve TLS: %v", err) // Errorf, not require: FailNow may only be called from the test goroutine + } + }() + } else { + go func() { + if err := srv.Serve(l); err != http.ErrServerClosed { + t.Errorf("could not serve: %v", err) // Errorf, not require: FailNow may only be called from the test goroutine + } + }() + } + + return func() { _ = srv.Shutdown(context.Background()) } +} diff --git a/pkg/network/tracer/tracer.go b/pkg/network/tracer/tracer.go index 37885ffe9354ec..129367fb2924d1 100644 --- a/pkg/network/tracer/tracer.go +++ b/pkg/network/tracer/tracer.go @@ -419,7 +419,7 @@ func (t *Tracer) Stop() { } // GetActiveConnections returns the delta for connection info from the last time it was called with the same clientID -func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, error) { +func (t *Tracer) 
GetActiveConnections(clientID string) (*network.Connections, func(), error) { t.bufferLock.Lock() defer t.bufferLock.Unlock() if log.ShouldLog(log.TraceLvl) { @@ -430,11 +430,10 @@ func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, er buffer := network.ClientPool.Get(clientID) latestTime, active, err := t.getConnections(buffer.ConnectionBuffer) if err != nil { - return nil, fmt.Errorf("error retrieving connections: %s", err) + return nil, nil, fmt.Errorf("error retrieving connections: %s", err) } - usmStats, cleaners := t.usmMonitor.GetProtocolStats() - defer cleaners() + usmStats, cleanup := t.usmMonitor.GetProtocolStats() delta := t.state.GetDelta(clientID, latestTime, active, t.reverseDNS.GetDNSStats(), usmStats) ips := make(map[util.Address]struct{}, len(delta.Conns)/2) @@ -469,7 +468,7 @@ func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, er conns.PrebuiltAssets = netebpf.GetModulesInUse() t.lastCheck.Store(time.Now().Unix()) - return conns, nil + return conns, cleanup, nil } // RegisterClient registers a clientID with the tracer diff --git a/pkg/network/tracer/tracer_linux_test.go b/pkg/network/tracer/tracer_linux_test.go index a751c20c2fdd20..eccebb50ecbbbb 100644 --- a/pkg/network/tracer/tracer_linux_test.go +++ b/pkg/network/tracer/tracer_linux_test.go @@ -113,7 +113,9 @@ func (s *TracerSuite) TestTCPRemoveEntries() { defer c2.Close() assert.EventuallyWithT(t, func(ct *assert.CollectT) { - conn, ok := findConnection(c2.LocalAddr(), c2.RemoteAddr(), getConnections(ct, tr)) + conns, cleanup := getConnections(ct, tr) + defer cleanup() + conn, ok := findConnection(c2.LocalAddr(), c2.RemoteAddr(), conns) if !assert.True(ct, ok) { return } @@ -128,7 +130,9 @@ func (s *TracerSuite) TestTCPRemoveEntries() { // Make sure the first connection got cleaned up assert.EventuallyWithT(t, func(ct *assert.CollectT) { - _, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), getConnections(ct, tr)) + connections, cleanup 
:= getConnections(ct, tr) + defer cleanup() + _, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.False(ct, ok) }, 5*time.Second, 100*time.Millisecond) @@ -178,7 +182,8 @@ func (s *TracerSuite) TestTCPRetransmit() { var conn *network.ConnectionStats require.EventuallyWithT(t, func(ct *assert.CollectT) { // Iterate through active connections until we find connection created above, and confirm send + recv counts - connections := getConnections(ct, tr) + connections, cleanup := getConnections(ct, tr) + defer cleanup() conn, _ = findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.NotNil(ct, conn) @@ -248,7 +253,8 @@ func (s *TracerSuite) TestTCPRetransmitSharedSocket() { t.Logf("local=%s remote=%s", c.LocalAddr(), c.RemoteAddr()) // Fetch all connections matching source and target address - allConnections := getConnections(t, tr) + allConnections, cleanup := getConnections(t, tr) + defer cleanup() conns := network.FilterConnections(allConnections, network.ByTuple(c.LocalAddr(), c.RemoteAddr())) require.Len(t, conns, numProcesses) @@ -309,7 +315,8 @@ func (s *TracerSuite) TestTCPRTT() { require.EventuallyWithT(t, func(ct *assert.CollectT) { // Fetch connection matching source and target address - allConnections := getConnections(ct, tr) + allConnections, cleanup := getConnections(ct, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), allConnections) if !assert.True(ct, ok) { return @@ -375,7 +382,9 @@ func (s *TracerSuite) TestTCPMiscount() { server.Shutdown() - conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), getConnections(t, tr)) + allConnections, cleanup := getConnections(t, tr) + defer cleanup() + conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), allConnections) if assert.True(t, ok) { // TODO this should not happen but is expected for now // we don't have the correct count since retries happened @@ -409,7 +418,8 @@ func (s *TracerSuite) TestConnectionExpirationRegression() { // 
Fetch connection matching source and target address // This will make sure to populate the state for this particular client - allConnections := getConnections(t, tr) + allConnections, cleanup := getConnections(t, tr) + defer cleanup() connectionStats, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), allConnections) require.True(t, ok) assert.Equal(t, uint64(len(payload)), connectionStats.Last.SentBytes) @@ -422,13 +432,15 @@ func (s *TracerSuite) TestConnectionExpirationRegression() { tr.ebpfTracer.Remove(connectionStats) // Since no bytes were send or received after we obtained the connectionStats, we should have 0 LastBytesSent - allConnections = getConnections(t, tr) + allConnections, cleanup2 := getConnections(t, tr) + defer cleanup2() connectionStats, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), allConnections) require.True(t, ok) assert.Equal(t, uint64(0), connectionStats.Last.SentBytes) // Finally, this connection should have been expired from the state - allConnections = getConnections(t, tr) + allConnections, cleanup3 := getConnections(t, tr) + defer cleanup3() _, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), allConnections) require.False(t, ok) } @@ -479,7 +491,8 @@ func (s *TracerSuite) TestConntrackExpiration() { return } - connections := getConnections(collect, tr) + connections, cleanup := getConnections(collect, tr) + defer cleanup() t.Log(connections) // for debugging failures var ok bool conn, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), connections) @@ -493,7 +506,8 @@ func (s *TracerSuite) TestConntrackExpiration() { // conntrack should still have the connection information since the connection is still // alive tr.config.TCPConnTimeout = time.Duration(-1) - _ = getConnections(t, tr) + _, cleanup1 := getConnections(t, tr) + defer cleanup1() assert.NotNil(t, tr.conntracker.GetTranslationForConn(&conn.ConnectionTuple), "translation should not have been deleted") @@ -501,7 +515,8 @@ func (s *TracerSuite) TestConntrackExpiration() 
{ cmd := exec.Command("conntrack", "-D", "-s", c.LocalAddr().(*net.TCPAddr).IP.String(), "-d", c.RemoteAddr().(*net.TCPAddr).IP.String(), "-p", "tcp") out, err := cmd.CombinedOutput() require.NoError(t, err, "conntrack delete failed, output: %s", out) - _ = getConnections(t, tr) + _, cleanup2 := getConnections(t, tr) + defer cleanup2() assert.Nil(t, tr.conntracker.GetTranslationForConn(&conn.ConnectionTuple), "translation should have been deleted") @@ -544,7 +559,8 @@ func (s *TracerSuite) TestConntrackDelays() { require.NoError(t, err) require.EventuallyWithT(t, func(collect *assert.CollectT) { - connections := getConnections(collect, tr) + connections, cleanup := getConnections(collect, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.True(collect, ok) require.NotNil(collect, tr.conntracker.GetTranslationForConn(&conn.ConnectionTuple)) @@ -600,7 +616,8 @@ func (s *TracerSuite) TestTranslationBindingRegression() { }, 3*time.Second, 100*time.Millisecond, "timed out waiting for conntrack update") // Assert that the connection to 2.2.2.2 has an IPTranslation object bound to it - connections := getConnections(t, tr) + connections, cleanup := getConnections(t, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.True(t, ok) require.NotNil(t, conn.IPTranslation, "missing translation for connection") @@ -631,7 +648,8 @@ func (s *TracerSuite) TestUnconnectedUDPSendIPv6() { require.NoError(t, err) require.EventuallyWithT(t, func(ct *assert.CollectT) { - connections := getConnections(ct, tr) + connections, cleanup := getConnections(ct, tr) + defer cleanup() outgoing := network.FilterConnections(connections, func(cs network.ConnectionStats) bool { if cs.Type != network.UDP { return false @@ -738,7 +756,9 @@ func (s *TracerSuite) TestGatewayLookupEnabled() { var conn *network.ConnectionStats require.EventuallyWithT(t, func(ct *assert.CollectT) { var ok bool - conn, ok = 
findConnection(dnsClientAddr, dnsServerAddr, getConnections(ct, tr)) + connections, cleanup := getConnections(ct, tr) + defer cleanup() + conn, ok = findConnection(dnsClientAddr, dnsServerAddr, connections) require.True(ct, ok, "connection not found") }, 3*time.Second, 100*time.Millisecond) @@ -794,7 +814,9 @@ func (s *TracerSuite) TestGatewayLookupSubnetLookupError() { var c *network.ConnectionStats require.EventuallyWithT(t, func(ct *assert.CollectT) { var ok bool - c, ok = findConnection(dnsClientAddr, dnsServerAddr, getConnections(ct, tr)) + connections, cleanup := getConnections(ct, tr) + defer cleanup() + c, ok = findConnection(dnsClientAddr, dnsServerAddr, connections) require.True(ct, ok, "connection not found") }, 3*time.Second, 100*time.Millisecond, "connection not found") require.Nil(t, c.Via) @@ -807,7 +829,9 @@ func (s *TracerSuite) TestGatewayLookupSubnetLookupError() { dnsClientAddr = &net.UDPAddr{IP: net.ParseIP(clientIP), Port: clientPort} require.EventuallyWithT(t, func(ct *assert.CollectT) { var ok bool - c, ok = findConnection(dnsClientAddr, dnsServerAddr, getConnections(ct, tr)) + connections, cleanup := getConnections(ct, tr) + defer cleanup() + c, ok = findConnection(dnsClientAddr, dnsServerAddr, connections) require.True(ct, ok, "connection not found") }, 3*time.Second, 100*time.Millisecond, "connection not found") require.Nil(t, c.Via) @@ -910,7 +934,8 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() { require.EventuallyWithT(t, func(collect *assert.CollectT) { var ok bool - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() t.Log(conns) conn, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), conns) require.True(collect, ok) @@ -944,7 +969,8 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() { var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { var ok bool - conns := getConnections(collect, tr) + conns, cleanup := 
getConnections(collect, tr) + defer cleanup() t.Log(conns) conn, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), conns) require.True(collect, ok) @@ -975,7 +1001,9 @@ func (s *TracerSuite) TestGatewayLookupCrossNamespace() { require.EventuallyWithT(t, func(collect *assert.CollectT) { var ok bool - conn, ok = findConnection(dnsClientAddr, dnsServerAddr, getConnections(collect, tr)) + conns, cleanup := getConnections(collect, tr) + defer cleanup() + conn, ok = findConnection(dnsClientAddr, dnsServerAddr, conns) require.True(collect, ok) require.Equal(collect, network.OUTGOING, conn.Direction) }, 3*time.Second, 100*time.Millisecond) @@ -1018,7 +1046,8 @@ func (s *TracerSuite) TestConnectionAssured() { } require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), conns) require.True(collect, ok) require.Positive(collect, conn.Monotonic.SentBytes) @@ -1052,7 +1081,8 @@ func (s *TracerSuite) TestConnectionNotAssured() { require.NoError(t, err) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), conns) require.True(collect, ok) require.Positive(collect, conn.Monotonic.SentBytes) @@ -1122,7 +1152,8 @@ func (s *TracerSuite) TestDNATIntraHostIntegration() { _, err = conn.Read(bs) require.NoError(collect, err) - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() t.Log(conns) outgoing, _ = findConnection(conn.LocalAddr(), conn.RemoteAddr(), conns) @@ -1177,7 +1208,9 @@ func (s *TracerSuite) TestSelfConnect() { t.Logf("port is %d", port) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := network.FilterConnections(getConnections(collect, tr), func(cs network.ConnectionStats) 
bool { + allConnections, cleanup := getConnections(collect, tr) + defer cleanup() + conns := network.FilterConnections(allConnections, func(cs network.ConnectionStats) bool { return cs.SPort == uint16(port) && cs.DPort == uint16(port) && cs.Source.IsLoopback() && cs.Dest.IsLoopback() }) @@ -1279,7 +1312,8 @@ func testUDPPeekCount(t *testing.T, udpnet, ip string) { var incoming *network.ConnectionStats var outgoing *network.ConnectionStats require.EventuallyWithTf(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() newOutgoing, _ := findConnection(c.LocalAddr(), c.RemoteAddr(), conns) if newOutgoing != nil { outgoing = newOutgoing @@ -1347,7 +1381,8 @@ func testUDPPacketSumming(t *testing.T, udpnet, ip string) { var incoming *network.ConnectionStats var outgoing *network.ConnectionStats require.EventuallyWithTf(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() newOutgoing, _ := findConnection(c.LocalAddr(), c.RemoteAddr(), conns) if newOutgoing != nil { outgoing = newOutgoing @@ -1414,7 +1449,9 @@ func (s *TracerSuite) TestUDPPythonReusePort() { conns := map[network.ConnectionTuple]network.ConnectionStats{} require.EventuallyWithT(t, func(collect *assert.CollectT) { - _conns := network.FilterConnections(getConnections(collect, tr), func(cs network.ConnectionStats) bool { + allConnections, cleanup := getConnections(collect, tr) + defer cleanup() + _conns := network.FilterConnections(allConnections, func(cs network.ConnectionStats) bool { return cs.Type == network.UDP && cs.Source.IsLoopback() && cs.Dest.IsLoopback() && @@ -1533,7 +1570,8 @@ func testUDPReusePort(t *testing.T, udpnet string, ip string) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { // use t instead of ct because getConnections uses require (not assert), and we get a better error message that way - connections := 
getConnections(ct, tr) + connections, cleanup := getConnections(ct, tr) + defer cleanup() incoming, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), connections) if assert.True(ct, ok, "unable to find incoming connection") { @@ -1556,7 +1594,8 @@ func testUDPReusePort(t *testing.T, udpnet string, ip string) { }, 3*time.Second, 100*time.Millisecond) // log the connections at the end in case the test failed - connections := getConnections(t, tr) + connections, cleanup := getConnections(t, tr) + defer cleanup() for _, c := range connections.Conns { t.Log(c) } @@ -1665,7 +1704,8 @@ func (s *TracerSuite) TestSendfileRegression() { t.Logf("looking for connections %+v <-> %+v", c.LocalAddr(), c.RemoteAddr()) var outConn, inConn *network.ConnectionStats assert.EventuallyWithT(t, func(ct *assert.CollectT) { - conns := getConnections(ct, tr) + conns, cleanup := getConnections(ct, tr) + defer cleanup() t.Log(conns) newOutConn := network.FirstConnection(conns, network.ByType(connType), network.ByFamily(family), network.ByTuple(c.LocalAddr(), c.RemoteAddr())) if newOutConn != nil { @@ -1792,7 +1832,8 @@ func (s *TracerSuite) TestSendfileError() { c.Close() require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), conns) require.True(collect, ok) require.Equalf(collect, int64(0), int64(conn.Monotonic.SentBytes), "sendfile data wasn't properly traced") @@ -1902,7 +1943,8 @@ func (s *TracerSuite) TestShortWrite() { unix.Close(sk) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), conns) require.True(collect, ok) @@ -1991,7 +2033,9 @@ func (s *TracerSuite) TestBlockingReadCounts() { }, 10*time.Second, 100*time.Millisecond, "failed to get required 
bytes") require.EventuallyWithT(t, func(collect *assert.CollectT) { - conn, found := findConnection(c.(*net.TCPConn).LocalAddr(), c.(*net.TCPConn).RemoteAddr(), getConnections(collect, tr)) + connections, cleanup := getConnections(collect, tr) + defer cleanup() + conn, found := findConnection(c.(*net.TCPConn).LocalAddr(), c.(*net.TCPConn).RemoteAddr(), connections) require.True(collect, found) require.Equal(collect, uint64(read), conn.Monotonic.RecvBytes) }, 3*time.Second, 100*time.Millisecond) @@ -2036,7 +2080,8 @@ func (s *TracerSuite) TestPreexistingConnectionDirection() { var incoming, outgoing *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { - connections := getConnections(collect, tr) + connections, cleanup := getConnections(collect, tr) + defer cleanup() newOutgoing, _ := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) if newOutgoing != nil { outgoing = newOutgoing @@ -2130,7 +2175,8 @@ func testPreexistingEmptyIncomingConnectionDirection(t *testing.T, config *confi time.Sleep(250 * time.Millisecond) - conns := getConnections(t, tr) + conns, cleanup := getConnections(t, tr) + defer cleanup() _, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), conns) require.False(t, ok, "expected connection to not be found") } @@ -2172,7 +2218,8 @@ func (s *TracerSuite) TestUDPIncomingDirectionFix() { require.NoError(t, err) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() conn, _ := findConnection(net.UDPAddrFromAddrPort(ap), raddr, conns) require.NotNil(collect, conn) require.Equal(collect, network.OUTGOING, conn.Direction) @@ -2450,7 +2497,8 @@ LOOP: // not be in the closed state, so duration will the // timestamp of when it was created require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() conn, 
found := findConnection(c.LocalAddr(), srv.Addr(), conns) require.True(collect, found, "could not find connection") // all we can do is verify it is > 0 @@ -2459,7 +2507,9 @@ LOOP: require.NoError(t, c.Close(), "error closing client connection") require.EventuallyWithT(t, func(collect *assert.CollectT) { - conn, found := findConnection(c.LocalAddr(), srv.Addr(), getConnections(collect, tr)) + conns, cleanup := getConnections(collect, tr) + defer cleanup() + conn, found := findConnection(c.LocalAddr(), srv.Addr(), conns) require.True(collect, found, "could not find connection") require.True(collect, conn.IsClosed, "connection should be closed") // after closing the client connection, the duration should be @@ -2533,7 +2583,8 @@ func (s *TracerSuite) TestTCPFailureConnectionTimeout() { // Check if the connection was recorded as failed due to timeout require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() // 110 is the errno for ETIMEDOUT conn := findFailedConnection(t, localAddr, srvAddr, conns, 110) require.NotNil(collect, conn) @@ -2589,9 +2640,11 @@ func (s *TracerSuite) TestTCPFailureConnectionResetWithDNAT() { // Check if the connection was recorded as reset var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { + conns, cleanup := getConnections(collect, tr) + defer cleanup() // 104 is the errno for ECONNRESET // findFailedConnection gets `t` as it needs to log, it does not assert so no conversion is needed. 
- conn = findFailedConnection(t, c.LocalAddr().String(), serverAddr, getConnections(collect, tr), 104) + conn = findFailedConnection(t, c.LocalAddr().String(), serverAddr, conns, 104) require.NotNil(collect, conn) }, 3*time.Second, 100*time.Millisecond, "Failed connection not recorded properly") @@ -2730,7 +2783,8 @@ func (s *TracerSuite) TestTLSClassification() { validation: func(t *testing.T, tr *Tracer, port uint16, _ uint16) { // Verify that no TLS tags are set for this connection require.EventuallyWithT(t, func(ct *assert.CollectT) { - payload := getConnections(ct, tr) + payload, cleanup := getConnections(ct, tr) + defer cleanup() for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { t.Log("Unexpected TLS protocol detected for invalid handshake") @@ -2761,7 +2815,8 @@ func (s *TracerSuite) TestTLSClassification() { } func validateTLSTags(t *assert.CollectT, tr *Tracer, port uint16, scenario uint16) bool { - payload := getConnections(t, tr) + payload, cleanup := getConnections(t, tr) + defer cleanup() for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) && !c.TLSTags.IsEmpty() { tlsTags := c.TLSTags.GetDynamicTags() @@ -2840,7 +2895,8 @@ func waitForTracer(t *testing.T, tr *Tracer, srvAddr string) { require.NoError(collect, err) defer client.Close() - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() _, found := findConnection(client.LocalAddr(), client.RemoteAddr(), conns) require.True(collect, found) }, time.Second*15, time.Second) @@ -2889,7 +2945,8 @@ func (s *TracerSuite) TestRawTLSClient() { sendMessage(t, conn, handshake) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() c, found := findConnection(conn.LocalAddr(), conn.RemoteAddr(), conns) require.True(collect, found) assert.True(collect, 
c.ProtocolStack.Contains(protocols.TLS), "expected TLS protocol") @@ -2899,7 +2956,8 @@ func (s *TracerSuite) TestRawTLSClient() { sendMessage(t, conn, []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() c, found := findConnection(conn.LocalAddr(), conn.RemoteAddr(), conns) require.True(collect, found) assert.True(collect, c.ProtocolStack.Contains(protocols.TLS), "expected TLS protocol") @@ -2919,7 +2977,8 @@ func (s *TracerSuite) TestRawTLSClient() { sendMessage(t, conn, []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() c, found := findConnection(conn.LocalAddr(), conn.RemoteAddr(), conns) require.True(collect, found) assert.False(collect, c.ProtocolStack.Contains(protocols.TLS), "not expected TLS protocol") @@ -2929,7 +2988,8 @@ func (s *TracerSuite) TestRawTLSClient() { sendMessage(t, conn, handshake) require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() c, found := findConnection(conn.LocalAddr(), conn.RemoteAddr(), conns) require.True(collect, found) assert.False(collect, c.ProtocolStack.Contains(protocols.TLS), "not expected TLS protocol") diff --git a/pkg/network/tracer/tracer_test.go b/pkg/network/tracer/tracer_test.go index 099221bad97ce3..8e1d5d809c4eff 100644 --- a/pkg/network/tracer/tracer_test.go +++ b/pkg/network/tracer/tracer_test.go @@ -207,7 +207,8 @@ func (s *TracerSuite) TestTCPSendAndReceive() { var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { // Iterate through active connections until we find connection created above, and confirm send + recv counts - connections 
:= getConnections(collect, tr) + connections, cleanup := getConnections(collect, tr) + defer cleanup() var ok bool conn, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.True(collect, ok) @@ -256,7 +257,9 @@ func (s *TracerSuite) TestTCPShortLived() { var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { var ok bool - conn, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), getConnections(collect, tr)) + connections, cleanup := getConnections(collect, tr) + defer cleanup() + conn, ok = findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.True(collect, ok) }, 3*time.Second, 100*time.Millisecond, "connection not found") @@ -275,7 +278,9 @@ func (s *TracerSuite) TestTCPShortLived() { assert.Equal(t, uint16(1), m.TCPEstablished) assert.Equal(t, uint16(1), m.TCPClosed) - _, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), getConnections(t, tr)) + connections, cleanup := getConnections(t, tr) + defer cleanup() + _, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) assert.False(t, ok) } @@ -320,7 +325,8 @@ func (s *TracerSuite) TestTCPOverIPv6() { r := bufio.NewReader(c) r.ReadBytes(byte('\n')) - connections := getConnections(t, tr) + connections, cleanup := getConnections(t, tr) + defer cleanup() conn, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.True(t, ok) @@ -374,7 +380,8 @@ func (s *TracerSuite) TestTCPCollectionDisabled() { r := bufio.NewReader(c) r.ReadBytes(byte('\n')) - connections := getConnections(t, tr) + connections, cleanup := getConnections(t, tr) + defer cleanup() // Confirm that we could not find connection created above _, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) @@ -409,7 +416,8 @@ func (s *TracerSuite) TestTCPConnsReported() { // for ebpfless, it takes time for the packet capture to arrive, so poll require.EventuallyWithT(t, func(collect *assert.CollectT) { // Test - connections := getConnections(collect, 
tr) + connections, cleanup := getConnections(collect, tr) + defer cleanup() // Server-side newForward, _ := findConnection(c.RemoteAddr(), c.LocalAddr(), connections) @@ -489,7 +497,8 @@ func testUDPSendAndReceive(t *testing.T, tr *Tracer, addr string) { // Iterate through active connections until we find connection created above, and confirm send + recv counts require.EventuallyWithT(t, func(ct *assert.CollectT) { // use t instead of ct because getConnections uses require (not assert), and we get a better error message - connections := getConnections(ct, tr) + connections, cleanup := getConnections(ct, tr) + defer cleanup() incoming, ok := findConnection(c.RemoteAddr(), c.LocalAddr(), connections) if assert.True(ct, ok, "unable to find incoming connection") { assert.Equal(ct, network.INCOMING, incoming.Direction) @@ -546,7 +555,8 @@ func (s *TracerSuite) TestUDPDisabled() { c.Read(make([]byte, serverMessageSize)) // Iterate through active connections until we find connection created above, and confirm send + recv counts - connections := getConnections(t, tr) + connections, cleanup := getConnections(t, tr) + defer cleanup() _, ok := findConnection(c.LocalAddr(), c.RemoteAddr(), connections) require.False(t, ok) @@ -572,7 +582,9 @@ func (s *TracerSuite) TestLocalDNSCollectionDisabled() { assert.NoError(t, err) // Iterate through active connections making sure there are no local DNS calls - for _, c := range getConnections(t, tr).Conns { + connections, cleanup := getConnections(t, tr) + defer cleanup() + for _, c := range connections.Conns { assert.False(t, isLocalDNS(c)) } } @@ -599,7 +611,9 @@ func (s *TracerSuite) TestLocalDNSCollectionEnabled() { // Iterate through active connections making sure theres at least one connection require.EventuallyWithT(t, func(collect *assert.CollectT) { - for _, c := range getConnections(collect, tr).Conns { + connections, cleanup := getConnections(collect, tr) + defer cleanup() + for _, c := range connections.Conns { if 
isLocalDNS(c) { return } @@ -636,7 +650,8 @@ func (s *TracerSuite) TestShouldSkipExcludedConnection() { require.EventuallyWithT(t, func(collect *assert.CollectT) { // Make sure we're not picking up 127.0.0.1:80 - cxs := getConnections(collect, tr) + cxs, cleanup := getConnections(collect, tr) + defer cleanup() for _, c := range cxs.Conns { assert.False(collect, c.Source.String() == "127.0.0.1" && c.SPort == 80, "connection %s should be excluded", c) assert.False(collect, c.Dest.String() == "127.0.0.1" && c.DPort == 80 && c.Type == network.TCP, "connection %s should be excluded", c) @@ -674,7 +689,8 @@ func (s *TracerSuite) TestShouldExcludeEmptyStatsConnection() { var zeroConn network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { - cxs := getConnections(collect, tr) + cxs, cleanup := getConnections(collect, tr) + defer cleanup() for _, c := range cxs.Conns { if c.Dest.String() == "127.0.0.1" && c.DPort == 80 { zeroConn = c @@ -685,7 +701,8 @@ func (s *TracerSuite) TestShouldExcludeEmptyStatsConnection() { }, 2*time.Second, 100*time.Millisecond) // next call should not have the same connection - cxs := getConnections(t, tr) + cxs, cleanup := getConnections(t, tr) + defer cleanup() found := false for _, c := range cxs.Conns { if c.Source == zeroConn.Source && c.SPort == zeroConn.SPort && @@ -987,11 +1004,11 @@ func initTracerState(t testing.TB, tr *Tracer) { require.NoError(t, err) } -func getConnections(t require.TestingT, tr *Tracer) *network.Connections { +func getConnections(t require.TestingT, tr *Tracer) (*network.Connections, func()) { // Iterate through active connections until we find connection created above, and confirm send + recv counts - connections, err := tr.GetActiveConnections(clientID) + connections, cleanup, err := tr.GetActiveConnections(clientID) require.NoError(t, err) - return connections + return connections, cleanup } func testDNSStats(t *testing.T, tr *Tracer, domain string, success, failure, timeout int, 
serverIP string) { @@ -1029,7 +1046,8 @@ func testDNSStats(t *testing.T, tr *Tracer, domain string, success, failure, tim } // Iterate through active connections until we find connection created above, and confirm send + recv counts - connections := getConnections(c, tr) + connections, cleanup := getConnections(c, tr) + defer cleanup() conn, ok := findConnection(dnsClientAddr, dnsServerAddr, connections) if passed := assert.True(c, ok); !passed { return @@ -1116,7 +1134,9 @@ func (s *TracerSuite) TestTCPEstablished() { // for ebpfless, wait for the packet capture to appear require.EventuallyWithT(t, func(collect *assert.CollectT) { - conn, ok = findConnection(laddr, raddr, getConnections(collect, tr)) + connections, cleanup := getConnections(collect, tr) + defer cleanup() + conn, ok = findConnection(laddr, raddr, connections) require.True(collect, ok) }, 3*time.Second, 100*time.Millisecond, "couldn't find connection") @@ -1129,7 +1149,9 @@ func (s *TracerSuite) TestTCPEstablished() { // Wait for the connection to be sent from the perf buffer require.EventuallyWithT(t, func(collect *assert.CollectT) { var ok bool - conn, ok = findConnection(laddr, raddr, getConnections(collect, tr)) + connections, cleanup := getConnections(collect, tr) + defer cleanup() + conn, ok = findConnection(laddr, raddr, connections) require.True(collect, ok) }, 3*time.Second, 100*time.Millisecond, "couldn't find connection") @@ -1164,7 +1186,9 @@ func (s *TracerSuite) TestTCPEstablishedPreExistingConn() { var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { var ok bool - conn, ok = findConnection(laddr, raddr, getConnections(collect, tr)) + connections, cleanup := getConnections(collect, tr) + defer cleanup() + conn, ok = findConnection(laddr, raddr, connections) require.True(collect, ok) }, 3*time.Second, 100*time.Millisecond, "couldn't find connection") @@ -1189,7 +1213,8 @@ func (s *TracerSuite) TestUnconnectedUDPSendIPv4() { require.NoError(t, 
err) require.EventuallyWithT(t, func(ct *assert.CollectT) { - connections := getConnections(ct, tr) + connections, cleanup := getConnections(ct, tr) + defer cleanup() outgoing := network.FilterConnections(connections, func(cs network.ConnectionStats) bool { return cs.DPort == uint16(remotePort) }) @@ -1220,7 +1245,8 @@ func (s *TracerSuite) TestConnectedUDPSendIPv6() { var outgoing []network.ConnectionStats require.EventuallyWithT(t, func(ct *assert.CollectT) { - connections := getConnections(ct, tr) + connections, cleanup := getConnections(ct, tr) + defer cleanup() outgoing = network.FilterConnections(connections, func(cs network.ConnectionStats) bool { return cs.DPort == uint16(remotePort) }) @@ -1270,7 +1296,8 @@ func (s *TracerSuite) TestTCPDirection() { var outgoingConns []network.ConnectionStats var incomingConns []network.ConnectionStats require.EventuallyWithTf(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() if len(outgoingConns) == 0 { outgoingConns = network.FilterConnections(conns, func(cs network.ConnectionStats) bool { return fmt.Sprintf("%s:%d", cs.Dest, cs.DPort) == serverAddr @@ -1314,7 +1341,8 @@ func (s *TracerSuite) TestTCPFailureConnectionRefused() { // Check if the connection was recorded as refused var foundConn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() // Check for the refusal record foundConn = findFailedConnectionByRemoteAddr(srvAddr, conns, 111) require.NotNil(collect, foundConn) @@ -1366,9 +1394,11 @@ func (s *TracerSuite) TestTCPFailureConnectionResetWithData() { // Check if the connection was recorded as reset var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { + connections, cleanup := getConnections(collect, tr) + defer cleanup() // 104 is the errno for 
ECONNRESET // findFailedConnection needs `t` for logging, hence no need to pass `collect`. - conn = findFailedConnection(t, c.LocalAddr().String(), serverAddr, getConnections(collect, tr), 104) + conn = findFailedConnection(t, c.LocalAddr().String(), serverAddr, connections, 104) require.NotNil(collect, conn) }, 3*time.Second, 100*time.Millisecond, "Failed connection not recorded properly") @@ -1416,7 +1446,8 @@ func (s *TracerSuite) TestTCPFailureConnectionResetNoData() { // Check if the connection was recorded as reset var conn *network.ConnectionStats require.EventuallyWithT(t, func(collect *assert.CollectT) { - conns := getConnections(collect, tr) + conns, cleanup := getConnections(collect, tr) + defer cleanup() // 104 is the errno for ECONNRESET // findFailedConnection needs `t` for logging, hence no need to pass `collect`. conn = findFailedConnection(t, c.LocalAddr().String(), serverAddr, conns, 104) @@ -1480,7 +1511,7 @@ func BenchmarkGetActiveConnections(b *testing.B) { require.NoError(b, err) laddr, raddr := c.LocalAddr(), c.RemoteAddr() c.Write([]byte("hello")) - connections := getConnections(b, tr) + connections, _ := getConnections(b, tr) conn, ok := findConnection(laddr, raddr, connections) require.True(b, ok) @@ -1491,7 +1522,8 @@ func BenchmarkGetActiveConnections(b *testing.B) { // Wait for the connection to be sent from the perf buffer require.Eventually(b, func() bool { var ok bool - conn, ok = findConnection(laddr, raddr, getConnections(b, tr)) + connections, _ := getConnections(b, tr) + conn, ok = findConnection(laddr, raddr, connections) return ok }, 3*time.Second, 10*time.Millisecond, "couldn't find connection") diff --git a/pkg/network/tracer/tracer_unsupported.go b/pkg/network/tracer/tracer_unsupported.go index a2f68f0b69c72d..f13ae3570fb09a 100644 --- a/pkg/network/tracer/tracer_unsupported.go +++ b/pkg/network/tracer/tracer_unsupported.go @@ -32,8 +32,8 @@ func NewTracer(_ *config.Config, _ telemetry.Component, _ statsd.ClientInterface 
func (t *Tracer) Stop() {} // GetActiveConnections is not implemented on this OS for Tracer -func (t *Tracer) GetActiveConnections(_ string) (*network.Connections, error) { - return nil, ebpf.ErrNotImplemented +func (t *Tracer) GetActiveConnections(_ string) (*network.Connections, func(), error) { + return nil, nil, ebpf.ErrNotImplemented } // GetNetworkID is not implemented on this OS for Tracer diff --git a/pkg/network/tracer/tracer_windows.go b/pkg/network/tracer/tracer_windows.go index 2f6695fb3fb7a3..2260b1d2183577 100644 --- a/pkg/network/tracer/tracer_windows.go +++ b/pkg/network/tracer/tracer_windows.go @@ -197,7 +197,7 @@ func (t *Tracer) Stop() { } // GetActiveConnections returns all active connections -func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, error) { +func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, func(), error) { t.connLock.Lock() defer t.connLock.Unlock() @@ -210,13 +210,13 @@ func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, er return !t.shouldSkipConnection(c) }) if err != nil { - return nil, fmt.Errorf("error retrieving open connections from driver: %w", err) + return nil, nil, fmt.Errorf("error retrieving open connections from driver: %w", err) } _, err = t.driverInterface.GetClosedConnectionStats(t.closedBuffer, func(c *network.ConnectionStats) bool { return !t.shouldSkipConnection(c) }) if err != nil { - return nil, fmt.Errorf("error retrieving closed connections from driver: %w", err) + return nil, nil, fmt.Errorf("error retrieving closed connections from driver: %w", err) } activeConnStats := buffer.Connections() closedConnStats := t.closedBuffer.Connections() @@ -250,7 +250,7 @@ func (t *Tracer) GetActiveConnections(clientID string) (*network.Connections, er conns.DNS = t.reverseDNS.Resolve(ips) conns.ConnTelemetry = t.state.GetTelemetryDelta(clientID, t.getConnTelemetry()) conns.HTTP = delta.HTTP - return conns, nil + return conns, func() {}, nil 
} // RegisterClient registers the client diff --git a/pkg/network/usm/tests/tracer_classification_test.go b/pkg/network/usm/tests/tracer_classification_test.go index 627759a943e1a2..0e023134232235 100644 --- a/pkg/network/usm/tests/tracer_classification_test.go +++ b/pkg/network/usm/tests/tracer_classification_test.go @@ -66,11 +66,11 @@ func setupTracer(t testing.TB, cfg *config.Config) *tracer.Tracer { return tr } -func getConnections(t require.TestingT, tr *tracer.Tracer) *network.Connections { +func getConnections(t require.TestingT, tr *tracer.Tracer) (*network.Connections, func()) { // Iterate through active connections until we find connection created above, and confirm send + recv counts - connections, err := tr.GetActiveConnections(clientID) + connections, cleanup, err := tr.GetActiveConnections(clientID) require.NoError(t, err) - return connections + return connections, cleanup } // testContext shares the context of a given test. @@ -339,7 +339,8 @@ func waitForConnectionsWithProtocol(t *testing.T, tr *tracer.Tracer, targetAddr, t.Logf("looking for server addr %s", serverAddr) var outgoing, incoming *network.ConnectionStats failed := !assert.Eventually(t, func() bool { - conns := getConnections(t, tr) + conns, cleanup := getConnections(t, tr) + defer cleanup() if outgoing == nil { for _, c := range network.FilterConnections(conns, func(cs network.ConnectionStats) bool { return cs.Direction == network.OUTGOING && cs.Type == network.TCP && fmt.Sprintf("%s:%d", cs.Dest, cs.DPort) == targetAddr diff --git a/pkg/network/usm/tests/tracer_usm_linux_test.go b/pkg/network/usm/tests/tracer_usm_linux_test.go index 5c3bc133c2ec0a..1b1aa0d24bc359 100644 --- a/pkg/network/usm/tests/tracer_usm_linux_test.go +++ b/pkg/network/usm/tests/tracer_usm_linux_test.go @@ -16,6 +16,7 @@ import ( "net" nethttp "net/http" "net/netip" + neturl "net/url" "os" "os/exec" "strconv" @@ -47,6 +48,7 @@ import ( netlink "github.com/DataDog/datadog-agent/pkg/network/netlink/testutil" 
"github.com/DataDog/datadog-agent/pkg/network/protocols" "github.com/DataDog/datadog-agent/pkg/network/protocols/amqp" + "github.com/DataDog/datadog-agent/pkg/network/protocols/http" "github.com/DataDog/datadog-agent/pkg/network/protocols/http/testutil" usmhttp2 "github.com/DataDog/datadog-agent/pkg/network/protocols/http2" "github.com/DataDog/datadog-agent/pkg/network/protocols/kafka" @@ -431,7 +433,8 @@ func (s *USMSuite) TestIgnoreTLSClassificationIfApplicationProtocolWasDetected() // Perform the TLS handshake require.NoError(t, tlsConn.Handshake()) require.EventuallyWithT(t, func(collect *assert.CollectT) { - payload := getConnections(collect, tr) + payload, cleanup := getConnections(collect, tr) + defer cleanup() for _, c := range payload.Conns { if c.DPort == srvPortU16 || c.SPort == srvPortU16 { require.Equal(collect, c.ProtocolStack.Contains(protocols.TLS), tt.shouldBeTLS) @@ -504,7 +507,8 @@ func (s *USMSuite) TestTLSClassification() { validation: func(t *testing.T, tr *tracer.Tracer) { // Iterate through active connections until we find connection created above require.EventuallyWithTf(t, func(collect *assert.CollectT) { - payload := getConnections(collect, tr) + payload, cleanup := getConnections(collect, tr) + defer cleanup() for _, c := range payload.Conns { if c.DPort == port && c.ProtocolStack.Contains(protocols.TLS) { return @@ -576,7 +580,8 @@ func (s *USMSuite) TestTLSClassificationAlreadyRunning() { // Iterate through active connections until we find connection created above var foundIncoming, foundOutgoing bool require.EventuallyWithTf(t, func(collect *assert.CollectT) { - payload := getConnections(collect, tr) + payload, cleanup := getConnections(collect, tr) + defer cleanup() for _, c := range payload.Conns { if !foundIncoming && c.DPort == uint16(portAsValue) && c.ProtocolStack.Contains(protocols.TLS) { @@ -2538,3 +2543,303 @@ func goTLSDetachPID(t *testing.T, pid int) { return !utils.IsProgramTraced(consts.USMModuleName, 
usm.GoTLSAttacherName, pid) }, 5*time.Second, 100*time.Millisecond, "process %v is still traced by Go-TLS after detaching", pid) } + +func testHTTPLikeSketches(t *testing.T, tr *tracer.Tracer, client *nethttp.Client, url string, isHTTP2 bool) { + parsedURL, err := neturl.Parse(url) + require.NoError(t, err) + + getReq, err := nethttp.NewRequest("GET", url, nil) + require.NoError(t, err) + + getResp, err := client.Do(getReq) + require.NoError(t, err) + defer getResp.Body.Close() + + postReq1, err := nethttp.NewRequest("POST", url, nil) + require.NoError(t, err) + + postResp1, err := client.Do(postReq1) + require.NoError(t, err) + defer postResp1.Body.Close() + + postReq2, err := nethttp.NewRequest("POST", url, nil) + require.NoError(t, err) + + postResp2, err := client.Do(postReq2) + require.NoError(t, err) + defer postResp2.Body.Close() + + var getRequestStats, postRequestsStats *http.RequestStats + require.EventuallyWithT(t, func(ct *assert.CollectT) { + conns, cleanup := getConnections(ct, tr) + defer cleanup() + + requests := conns.HTTP + if isHTTP2 { + requests = conns.HTTP2 + } + if getRequestStats == nil || postRequestsStats == nil { + require.True(ct, len(requests) > 0, "no requests") + } + + for key, stats := range requests { + if getRequestStats != nil && postRequestsStats != nil { + break + } + if key.Path.Content.Get() != parsedURL.Path { + continue + } + if key.Method.String() == "GET" { + getRequestStats = stats + continue + } + if key.Method.String() == "POST" { + postRequestsStats = stats + continue + } + } + + require.NotNil(ct, getRequestStats) + require.Len(ct, getRequestStats.Data, 1) + require.NotNil(ct, getRequestStats.Data[nethttp.StatusOK]) + require.Equal(ct, 1, getRequestStats.Data[nethttp.StatusOK].Count) + require.Nil(ct, getRequestStats.Data[nethttp.StatusOK].Latencies) + require.NotZero(ct, getRequestStats.Data[nethttp.StatusOK].FirstLatencySample) + + require.NotNil(ct, postRequestsStats) + require.Len(ct, postRequestsStats.Data, 1) + 
require.NotNil(ct, postRequestsStats.Data[nethttp.StatusOK]) + require.Equal(ct, 2, postRequestsStats.Data[nethttp.StatusOK].Count) + require.NotNil(ct, postRequestsStats.Data[nethttp.StatusOK].Latencies) + require.NotZero(ct, postRequestsStats.Data[nethttp.StatusOK].FirstLatencySample) + require.Equal(ct, float64(2), postRequestsStats.Data[nethttp.StatusOK].Latencies.GetCount()) + }, 10*time.Second, 1*time.Second) +} + +const ( + httpServerAddr = "127.0.0.1:8080" +) + +var ( + httpURL = "http://" + httpServerAddr + "/200/request-0" +) + +func skipIfKernelIsNotSupported(t *testing.T, minimalKernelVersion kernel.Version) { + if kv < minimalKernelVersion { + t.Skipf("skipping test, kernel version %s is not supported", kv) + } +} + +func testHTTPSketches(t *testing.T, tr *tracer.Tracer) { + srvDoneFn := testutil.HTTPServer(t, httpServerAddr, testutil.Options{ + EnableKeepAlive: true, + }) + t.Cleanup(srvDoneFn) + + client := new(nethttp.Client) + transport := nethttp.DefaultTransport.(*nethttp.Transport).Clone() + transport.ForceAttemptHTTP2 = false + transport.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) nethttp.RoundTripper) + + client.Transport = transport + + testHTTPLikeSketches(t, tr, client, httpURL, false) +} + +func testHTTP2Sketches(t *testing.T, tr *tracer.Tracer) { + skipIfKernelIsNotSupported(t, usmhttp2.MinimumKernelVersion) + srvDoneFn := usmhttp2.StartH2CServer(t, httpServerAddr, false) + t.Cleanup(srvDoneFn) + + client := &nethttp.Client{ + Transport: &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(_ context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + return net.Dial(network, addr) + }, + }, + } + + testHTTPLikeSketches(t, tr, client, httpURL, true) +} + +const ( + localhost = "127.0.0.1" +) + +func testKafkaSketches(t *testing.T, tr *tracer.Tracer) { + serverAddress := net.JoinHostPort(localhost, kafkaPort) + require.NoError(t, kafka.RunServer(t, localhost, kafkaPort)) + + topicName1 := 
fmt.Sprintf("test-topic-1-%d", time.Now().UnixNano()) + topicName2 := fmt.Sprintf("test-topic-2-%d", time.Now().UnixNano()) + + version := kversion.V3_4_0() + version.SetMaxKeyVersion(produceAPIKey, 10) + version.SetMaxKeyVersion(fetchAPIKey, 10) + client, err := kafka.NewClient(kafka.Options{ + ServerAddress: serverAddress, + CustomOptions: []kgo.Opt{kgo.MaxVersions(version)}, + }) + require.NoError(t, err) + + defer client.Client.Close() + + require.NoError(t, client.CreateTopic(topicName1)) + require.NoError(t, client.CreateTopic(topicName2)) + + record1 := &kgo.Record{Topic: topicName1, Value: []byte("Hello Kafka!")} + ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + require.NoError(t, client.Client.ProduceSync(ctxTimeout, record1).FirstErr(), "record had a produce error while synchronously producing") + require.NoError(t, client.Client.ProduceSync(ctxTimeout, record1).FirstErr(), "record had a produce error while synchronously producing") + + record2 := &kgo.Record{Topic: topicName2, Value: []byte("Hello Kafka!")} + ctxTimeout, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + require.NoError(t, client.Client.ProduceSync(ctxTimeout, record2).FirstErr(), "record had a produce error while synchronously producing") + + client.Client.AddConsumeTopics(topicName2) + fetches := client.Client.PollFetches(context.Background()) + require.Empty(t, fetches.Errors()) + require.Len(t, fetches.Records(), 1) + + var fetchRequestStats, produceRequestsStats *kafka.RequestStats + require.EventuallyWithT(t, func(ct *assert.CollectT) { + conns, cleanup := getConnections(ct, tr) + defer cleanup() + + requests := conns.Kafka + if fetchRequestStats == nil || produceRequestsStats == nil { + require.True(ct, len(requests) > 0, "no requests") + } + + for key, stats := range requests { + if fetchRequestStats != nil && produceRequestsStats != nil { + break + } + + if key.TopicName.Get() == topicName2 && 
key.RequestAPIKey == kafka.FetchAPIKey { + fetchRequestStats = stats + continue + } + if key.TopicName.Get() == topicName1 && key.RequestAPIKey == kafka.ProduceAPIKey { + produceRequestsStats = stats + continue + } + } + + require.NotNil(ct, fetchRequestStats) + require.Len(ct, fetchRequestStats.ErrorCodeToStat, 1) + require.NotNil(ct, fetchRequestStats.ErrorCodeToStat[0]) + require.Equal(ct, 1, fetchRequestStats.ErrorCodeToStat[0].Count) + require.Nil(ct, fetchRequestStats.ErrorCodeToStat[0].Latencies) + require.NotZero(ct, fetchRequestStats.ErrorCodeToStat[0].FirstLatencySample) + + require.NotNil(ct, produceRequestsStats) + require.Len(ct, produceRequestsStats.ErrorCodeToStat, 1) + require.NotNil(ct, produceRequestsStats.ErrorCodeToStat[0]) + require.Equal(ct, 2, produceRequestsStats.ErrorCodeToStat[0].Count) + require.NotNil(ct, produceRequestsStats.ErrorCodeToStat[0].Latencies) + require.NotZero(ct, produceRequestsStats.ErrorCodeToStat[0].FirstLatencySample) + require.Equal(ct, float64(2), produceRequestsStats.ErrorCodeToStat[0].Latencies.GetCount()) + }, 10*time.Second, 1*time.Second) +} + +func testPostgresSketches(t *testing.T, tr *tracer.Tracer) { + serverAddress := net.JoinHostPort(localhost, postgresPort) + require.NoError(t, pgutils.RunServer(t, localhost, postgresPort, false)) + // Verifies that the postgres server is up and running. + // It tries to connect to the server until it succeeds or the timeout is reached. + // We need that function (and cannot rely on the RunServer method) as the target regex is being logged a couple of + // milliseconds before the server is actually ready to accept connections.
+ waitForPostgresServer(t, serverAddress, false) + + pg := pgutils.NewPGClient(pgutils.ConnectionOptions{ + ServerAddress: serverAddress, + }) + require.NoError(t, pg.RunCreateQuery()) + require.NoError(t, pg.RunInsertQuery(1)) + require.NoError(t, pg.RunInsertQuery(2)) + require.NoError(t, pg.RunSelectQuery()) + + var insertRequestStats, selectRequestsStats *pgutils.RequestStat + require.EventuallyWithT(t, func(ct *assert.CollectT) { + conns, cleanup := getConnections(ct, tr) + defer cleanup() + + requests := conns.Postgres + if insertRequestStats == nil || selectRequestsStats == nil { + require.True(ct, len(requests) > 0, "no requests") + } + + for key, stats := range requests { + if selectRequestsStats != nil && insertRequestStats != nil { + break + } + + if key.Parameters == "dummy" && key.Operation == pgutils.SelectOP { + selectRequestsStats = stats + continue + } + if key.Parameters == "dummy" && key.Operation == pgutils.InsertOP { + insertRequestStats = stats + continue + } + } + + require.NotNil(ct, selectRequestsStats) + require.Equal(ct, 1, selectRequestsStats.Count) + require.Nil(ct, selectRequestsStats.Latencies) + require.NotZero(ct, selectRequestsStats.FirstLatencySample) + + require.NotNil(ct, insertRequestStats) + require.Equal(ct, 2, insertRequestStats.Count) + require.NotNil(ct, insertRequestStats.Latencies) + require.NotZero(ct, insertRequestStats.FirstLatencySample) + require.Equal(ct, float64(2), insertRequestStats.Latencies.GetCount()) + }, 10*time.Second, 1*time.Second) +} + +func (s *USMSuite) TestVerifySketches() { + t := s.T() + skipIfKernelIsNotSupported(t, usmconfig.MinimumKernelVersion) + + cfg := utils.NewUSMEmptyConfig() + cfg.EnableHTTPMonitoring = true + cfg.EnableHTTP2Monitoring = kv >= usmhttp2.MinimumKernelVersion + cfg.EnableKafkaMonitoring = true + cfg.EnablePostgresMonitoring = true + + tr, err := tracer.NewTracer(cfg, nil, nil) + require.NoError(t, err) + t.Cleanup(tr.Stop) + require.NoError(t, tr.RegisterClient(clientID)) + 
+ tests := []struct { + name string + testFunc func(t *testing.T, tr *tracer.Tracer) + }{ + { + name: "http", + testFunc: testHTTPSketches, + }, + { + name: "http2", + testFunc: testHTTP2Sketches, + }, + { + name: "kafka", + testFunc: testKafkaSketches, + }, + { + name: "postgres", + testFunc: testPostgresSketches, + }, + } + for _, tt := range tests { + s.Run(tt.name, func() { + tt.testFunc(s.T(), tr) + }) + } +} diff --git a/pkg/network/usm/usm_http2_monitor_test.go b/pkg/network/usm/usm_http2_monitor_test.go index 2c449f93233c7c..b073cb8e0883e3 100644 --- a/pkg/network/usm/usm_http2_monitor_test.go +++ b/pkg/network/usm/usm_http2_monitor_test.go @@ -31,7 +31,6 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "golang.org/x/net/http2/hpack" ddebpf "github.com/DataDog/datadog-agent/pkg/ebpf" @@ -139,7 +138,7 @@ func (s *usmHTTP2Suite) TestHTTP2DynamicTableCleanup() { cfg.HTTP2DynamicTableMapCleanerInterval = 5 * time.Second // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -201,7 +200,7 @@ func (s *usmHTTP2Suite) TestSimpleHTTP2() { cfg := s.getCfg() // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -396,7 +395,7 @@ func (s *usmHTTP2Suite) TestHTTP2KernelTelemetry() { cfg := s.getCfg() // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. 
proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -526,7 +525,7 @@ func (s *usmHTTP2Suite) TestHTTP2ManyDifferentPaths() { cfg := s.getCfg() // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -593,7 +592,7 @@ func (s *usmHTTP2Suite) TestRawTraffic() { usmMonitor := setupUSMTLSMonitor(t, cfg, useExistingConsumer) // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -1355,7 +1354,7 @@ func (s *usmHTTP2Suite) TestDynamicTable() { cfg := s.getCfg() // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -1436,7 +1435,7 @@ func (s *usmHTTP2Suite) TestIncompleteFrameTable() { cfg := s.getCfg() // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -1510,7 +1509,7 @@ func (s *usmHTTP2Suite) TestRawHuffmanEncoding() { cfg := s.getCfg() // Start local server and register its cleanup. - t.Cleanup(startH2CServer(t, authority, s.isTLS)) + t.Cleanup(usmhttp2.StartH2CServer(t, authority, s.isTLS)) // Start the proxy server. 
proxyProcess, cancel := proxy.NewExternalUnixTransparentProxyServer(t, unixPath, authority, s.isTLS) @@ -1895,46 +1894,6 @@ func getExpectedOutcomeForPathWithRepeatedChars() map[usmhttp.Key]captureRange { return expected } -func startH2CServer(t *testing.T, address string, isTLS bool) func() { - srv := &http.Server{ - Addr: authority, - Handler: h2c.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - statusCode := testutil.StatusFromPath(r.URL.Path) - if statusCode == 0 { - w.WriteHeader(http.StatusOK) - } else { - w.WriteHeader(int(statusCode)) - } - defer func() { _ = r.Body.Close() }() - _, _ = io.Copy(w, r.Body) - }), &http2.Server{}), - IdleTimeout: 2 * time.Second, - } - - require.NoError(t, http2.ConfigureServer(srv, nil), "could not configure server") - - l, err := net.Listen("tcp", address) - require.NoError(t, err, "could not listen") - - if isTLS { - cert, key, err := testutil.GetCertsPaths() - require.NoError(t, err, "could not get certs paths") - go func() { - if err := srv.ServeTLS(l, cert, key); err != http.ErrServerClosed { - require.NoError(t, err, "could not serve TLS") - } - }() - } else { - go func() { - if err := srv.Serve(l); err != http.ErrServerClosed { - require.NoError(t, err, "could not serve") - } - }() - } - - return func() { _ = srv.Shutdown(context.Background()) } -} - func getClientsIndex(index, totalCount int) int { return index % totalCount }