Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delete the inverted key in connection_states eBPF map on tcp termination #34586

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ static __always_inline void protocol_dispatcher_entrypoint(struct __sk_buff *skb

if (tcp_termination) {
bpf_map_delete_elem(&connection_states, &skb_tup);
// remove also the inverted key
conn_tuple_t flip_tup = {0};
bpf_memcpy(&flip_tup, &skb_tup, sizeof(conn_tuple_t));
flip_tuple(&flip_tup);
bpf_map_delete_elem(&connection_states, &flip_tup);
}

protocol_stack_t *stack = get_protocol_stack_if_exists(&skb_tup);
Expand Down
156 changes: 156 additions & 0 deletions pkg/network/usm/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@ import (
"math/rand"
"net"
nethttp "net/http"
"net/netip"
"net/url"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"unsafe"

netebpf "github.com/DataDog/datadog-agent/pkg/network/ebpf"
"github.com/DataDog/datadog-agent/pkg/process/util"
manager "github.com/DataDog/ebpf-manager"
"github.com/cilium/ebpf"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -808,3 +813,154 @@ func TestCleanProtocolMaps(t *testing.T) {
})
}
}

const (
connectionStatesMapName = "connection_states"
serverHost = "127.0.0.1"
clientHost = "127.0.0.1"
)

// TestTCPConnectionStatesMap verifies that the map 'connection_states' is cleared after the TCP connection closes.
func TestTCPConnectionStatesMap(t *testing.T) {
if kv < usmconfig.MinimumKernelVersion {
t.Skipf("USM is not supported on %v", kv)
}
cfg := utils.NewUSMEmptyConfig()
cfg.EnableHTTPMonitoring = true
monitor := setupUSMTLSMonitor(t, cfg, useExistingConsumer)
require.NotNil(t, monitor)

statesMap, _, _ := monitor.ebpfProgram.Manager.Manager.GetMap(connectionStatesMapName)
require.NotNil(t, statesMap)
require.Equal(t, statesMap.KeySize(), uint32(unsafe.Sizeof(netebpf.ConnTuple{})), "wrong key size")

serverPort := uint16(4949)
runTCPServer(t, serverPort)

testTCPConnectionStatesFIN(t, monitor, statesMap, serverPort)
testTCPConnectionStatesRST(t, monitor, statesMap, serverPort)
}

// testTCPConnectionStatesFIN establishes a TCP connection, verifies the presence of connection entries in the map,
// then confirms the entries are removed after FIN segment.
func testTCPConnectionStatesFIN(t *testing.T, monitor *Monitor, statesMap *ebpf.Map, serverPort uint16) {
conn, err := net.Dial("tcp", net.JoinHostPort(serverHost, strconv.Itoa(int(serverPort))))
require.NoError(t, err)
defer conn.Close()

_, err = conn.Write([]byte("Hello"))
require.NoError(t, err)

clientPort := uint16(conn.LocalAddr().(*net.TCPAddr).Port)
tuples := makeKeys(t, clientHost, serverHost, clientPort, serverPort)
checkKeysInMap(t, monitor, statesMap, tuples, true)

tcpConn, ok := conn.(*net.TCPConn)
require.True(t, ok, "unexpected *net.TCPConn %T", conn)

// tells the TCP stack that the client will no longer send data, the TCP stack then sends a FIN segment.
err = tcpConn.CloseWrite()
require.NoError(t, err)

checkKeysInMap(t, monitor, statesMap, tuples, false)
}

// testTCPConnectionStatesRST establishes a TCP connection, verifies the presence of connection entries in the map,
// confirms that entries are removed after RST segment.
func testTCPConnectionStatesRST(t *testing.T, monitor *Monitor, statesMap *ebpf.Map, serverPort uint16) {
conn, err := net.Dial("tcp", net.JoinHostPort(serverHost, strconv.Itoa(int(serverPort))))
require.NoError(t, err)

_, err = conn.Write([]byte("Hello"))
require.NoError(t, err)

clientPort := uint16(conn.LocalAddr().(*net.TCPAddr).Port)
tuples := makeKeys(t, clientHost, serverHost, clientPort, serverPort)
checkKeysInMap(t, monitor, statesMap, tuples, true)

// explicitly closing an active connection triggers a RST segment
err = conn.Close()
require.NoError(t, err)

checkKeysInMap(t, monitor, statesMap, tuples, false)
}

// checkKeysInMap checks that all keys are present in the map ('exist'==true) or keys are missing from the map ('exist'==false).
func checkKeysInMap(t *testing.T, monitor *Monitor, m *ebpf.Map, tuples []netebpf.ConnTuple, exist bool) {
require.Equal(t, 2, len(tuples))
require.Eventually(t, func() bool {
if exist {
return findAllKeysInMap(m, tuples)
}
return !findAllKeysInMap(m, tuples)
}, 500*time.Millisecond, 50*time.Millisecond)
if t.Failed() {
t.Logf("failed search keys %v, %v exist: %t", tuples[0], tuples[1], exist)
ebpftest.DumpMapsTestHelper(t, monitor.DumpMaps, connectionStatesMapName)
t.FailNow()
}
}

// findAllKeysInMap returns true if all specified keys are present in the map.
func findAllKeysInMap(m *ebpf.Map, keys []netebpf.ConnTuple) bool {
set := make(map[netebpf.ConnTuple]struct{}, len(keys))

var key netebpf.ConnTuple
value := make([]byte, m.ValueSize())
iter := m.Iterate()
for iter.Next(unsafe.Pointer(&key), unsafe.Pointer(&value)) {
set[key] = struct{}{}
}
for _, k := range keys {
if _, exists := set[k]; !exists {
return false
}
}
return true
}

// makeKeys makes keys (connection tuples) for searching in the map.
func makeKeys(t *testing.T, src, dst string, srcPort, dstPort uint16) []netebpf.ConnTuple {
srcAddr, err := netip.ParseAddr(src)
require.NoError(t, err)
srcLow, srcHigh := util.ToLowHighIP(srcAddr)

dstAddr, err := netip.ParseAddr(dst)
require.NoError(t, err)
dstLow, dstHigh := util.ToLowHighIP(dstAddr)

tuples := []netebpf.ConnTuple{
{
Saddr_h: srcHigh,
Saddr_l: srcLow,
Daddr_h: dstHigh,
Daddr_l: dstLow,
Sport: srcPort,
Dport: dstPort,
Metadata: uint32(netebpf.TCP),
},
{
Saddr_h: dstHigh,
Saddr_l: dstLow,
Daddr_h: srcHigh,
Daddr_l: srcLow,
Sport: dstPort,
Dport: srcPort,
Metadata: uint32(netebpf.TCP),
},
}
return tuples
}

func runTCPServer(t *testing.T, serverPort uint16) {
serverPath := net.JoinHostPort(serverHost, strconv.Itoa(int(serverPort)))
server := testutil.NewTCPServer(serverPath, func(c net.Conn) {
_, _ = io.Copy(c, c)
}, false)
require.NotNil(t, server)
require.Equal(t, serverPath, server.Address())

done := make(chan struct{})
server.Run(done)
t.Cleanup(func() { close(done) })
}