Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support --filter-mark mark/[/mask] #471

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bpf/kprobe_pwru.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ struct {
struct config {
u32 netns;
u32 mark;
u32 mask;
u32 ifindex;
u8 output_meta: 1;
u8 output_tuple: 1;
Expand Down Expand Up @@ -218,7 +219,7 @@ filter_meta(struct sk_buff *skb) {
if (cfg->netns && get_netns(skb) != cfg->netns) {
return false;
}
if (cfg->mark && BPF_CORE_READ(skb, mark) != cfg->mark) {
if (cfg->mark && cfg->mask && (BPF_CORE_READ(skb, mark) & cfg->mask) != cfg->mark) {
return false;
}
if (cfg->ifindex != 0 && BPF_CORE_READ(skb, dev, ifindex) != cfg->ifindex) {
Expand Down
10 changes: 6 additions & 4 deletions internal/pwru/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ const (
var Version string = "version unknown"

type FilterCfg struct {
FilterNetns uint32
FilterMark uint32
FilterIfindex uint32
FilterNetns uint32
FilterMark uint32
FilterMarkMask uint32
FilterIfindex uint32

OutputFlags uint8
FilterFlags uint8
Expand All @@ -49,7 +50,8 @@ type FilterCfg struct {

func GetConfig(flags *Flags) (cfg FilterCfg, err error) {
cfg = FilterCfg{
FilterMark: flags.FilterMark,
FilterMark: flags.FilterMark,
FilterMarkMask: flags.FilterMarkMask,
}
cfg.FilterFlags |= IsSetMask
if flags.OutputSkb {
Expand Down
59 changes: 58 additions & 1 deletion internal/pwru/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package pwru
import (
"fmt"
"os"
"strconv"
"strings"

flag "github.com/spf13/pflag"
Expand All @@ -27,6 +28,7 @@ type Flags struct {

FilterNetns string
FilterMark uint32
FilterMarkMask uint32
FilterFunc string
FilterNonSkbFuncs []string
FilterTrackSkb bool
Expand Down Expand Up @@ -67,7 +69,7 @@ func (f *Flags) SetFlags() {
flag.StringVar(&f.FilterFunc, "filter-func", "", "filter kernel functions to be probed by name (exact match, supports RE2 regular expression)")
flag.StringSliceVar(&f.FilterNonSkbFuncs, "filter-non-skb-funcs", nil, "filter non-skb kernel functions to be probed (--filter-track-skb-by-stackid will be enabled)")
flag.StringVar(&f.FilterNetns, "filter-netns", "", "filter netns (\"/proc/<pid>/ns/net\", \"inode:<inode>\")")
flag.Uint32Var(&f.FilterMark, "filter-mark", 0, "filter skb mark")
flag.Var(newMarkFlagValue(&f.FilterMark, &f.FilterMarkMask), "filter-mark", "filter skb mark (format: mark[/mask], e.g., 0xa00/0xf00)")
flag.BoolVar(&f.FilterTrackSkb, "filter-track-skb", false, "trace a packet even if it does not match given filters (e.g., after NAT or tunnel decapsulation)")
flag.BoolVar(&f.FilterTrackSkbByStackid, "filter-track-skb-by-stackid", false, "trace a packet even after it is kfreed (e.g., traffic going through bridge)")
flag.BoolVar(&f.FilterTraceTc, "filter-trace-tc", false, "trace TC bpf progs")
Expand Down Expand Up @@ -155,3 +157,58 @@ type Event struct {
ParamThird uint64
CPU uint32
}

type markFlagValue struct {
mark *uint32
mask *uint32
}

func newMarkFlagValue(mark, mask *uint32) *markFlagValue {
return &markFlagValue{mark: mark, mask: mask}
}

func (f *markFlagValue) String() string {
if *f.mask == 0 {
return fmt.Sprintf("0x%x", *f.mark)
}
return fmt.Sprintf("0x%x/0x%x", *f.mark, *f.mask)
}

func (f *markFlagValue) Set(value string) error {
parts := strings.Split(value, "/")

mark, err := parseUint32HexOrDecimal(parts[0])
if err != nil {
return fmt.Errorf("invalid mark value: %v", err)
}
*f.mark = mark
*f.mask = 0xffffffff

if len(parts) > 1 {
mask, err := parseUint32HexOrDecimal(parts[1])
if err != nil {
return fmt.Errorf("invalid mask value: %v", err)
}
*f.mask = mask
}

return nil
}

func (f *markFlagValue) Type() string {
return "mark[/mask]"
}

func parseUint32HexOrDecimal(s string) (uint32, error) {
base := 10
if strings.HasPrefix(strings.ToLower(s), "0x") {
s = s[2:]
base = 16
}

val, err := strconv.ParseUint(s, base, 32)
if err != nil {
return 0, err
}
return uint32(val), nil
}
Loading