Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support forwarded IPs using a configurable header #178

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ testData:
maxretry: 4
enabled: true
statuscode: "400,401,403-499"
sourceCriterion:
requestHeaderName: "CF-Connecting-IP"
ipStrategy:
depth: 1
```

Where:
Expand All @@ -116,6 +120,38 @@ enable the plugin).
- `urlregexp`: a regexp list to block / allow requests with regexps on the url
- `statuscode`: a comma separated list of status code (or range of status
codes) to consider as a failed request.
- `sourceCriterion`: (optional) configures how to determine the client IP:
- `requestHeaderName`: specifies which header to use for client IP identification
- `ipStrategy`: configures how to extract the IP from the header:
- `depth`: (optional) which IP to pick from right when header contains multiple IPs

Common header values include:
- `X-Forwarded-For` - Standard proxy header
- `X-Real-IP` - Often used by Nginx
- `CF-Connecting-IP` - Cloudflare
- `True-Client-IP` - Akamai and Cloudflare

When no sourceCriterion is specified or when the specified header is missing/invalid,
the plugin falls back to using the remote address.

Examples:
```yaml
# Use first IP from X-Forwarded-For
sourceCriterion:
requestHeaderName: "X-Forwarded-For"

# Use last IP from X-Forwarded-For
sourceCriterion:
requestHeaderName: "X-Forwarded-For"
ipStrategy:
depth: 1

# Use second to last IP from X-Forwarded-For
sourceCriterion:
requestHeaderName: "X-Forwarded-For"
ipStrategy:
depth: 2
```

#### URL Regexp
Urlregexp are used to defined witch part of your website will be either
Expand Down
8 changes: 5 additions & 3 deletions fail2ban.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ type Config struct {
func CreateConfig() *Config {
return &Config{
Rules: rules.Rules{
Bantime: "300s",
Findtime: "120s",
Enabled: true,
Bantime: "300s",
Findtime: "120s",
Enabled: true,
SourceCriterion: rules.SourceCriterion{}, // Empty SourceCriterion for default behavior
},
}
}
Expand Down Expand Up @@ -136,6 +137,7 @@ func New(_ context.Context, next http.Handler, config *Config, _ string) (http.H

c := chain.New(
next,
rules.SourceCriterion,
denyHandler,
allowHandler,
uDeny.New(rules.URLRegexpBan, f2b),
Expand Down
18 changes: 18 additions & 0 deletions fail2ban_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ func TestFail2Ban(t *testing.T) {
t.Parallel()

remoteAddr := "10.0.0.0"
testHeaderName := "X-Forwarded-For"
testDepth := 1

tests := []struct {
name string
url string
Expand All @@ -133,6 +136,12 @@ func TestFail2Ban(t *testing.T) {
Enabled: true,
Findtime: "300s",
Maxretry: 20,
SourceCriterion: rules.SourceCriterion{
RequestHeaderName: &testHeaderName,
IPStrategy: &rules.IPStrategy{
Depth: &testDepth,
},
},
},
},
newError: true,
Expand All @@ -145,6 +154,9 @@ func TestFail2Ban(t *testing.T) {
Enabled: true,
Bantime: "300s",
Maxretry: 20,
SourceCriterion: rules.SourceCriterion{
RequestHeaderName: &testHeaderName,
},
},
},
newError: true,
Expand All @@ -158,6 +170,12 @@ func TestFail2Ban(t *testing.T) {
Bantime: "300s",
Findtime: "300s",
Maxretry: 20,
SourceCriterion: rules.SourceCriterion{
RequestHeaderName: &testHeaderName,
IPStrategy: &rules.IPStrategy{
Depth: &testDepth,
},
},
},
},
newError: false,
Expand Down
23 changes: 13 additions & 10 deletions pkg/chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"

"github.com/tomMoulard/fail2ban/pkg/data"
"github.com/tomMoulard/fail2ban/pkg/rules"
)

// Status is a status that can be returned by a handler.
Expand All @@ -30,16 +31,18 @@ type Chain interface {
}

type chain struct {
handlers []ChainHandler
final http.Handler
status *http.Handler
handlers []ChainHandler
final http.Handler
status *http.Handler
sourceCriterion rules.SourceCriterion
}

// New creates a new chain.
func New(final http.Handler, handlers ...ChainHandler) Chain {
func New(final http.Handler, sourceCriterion rules.SourceCriterion, handlers ...ChainHandler) Chain {
return &chain{
handlers: handlers,
final: final,
handlers: handlers,
final: final,
sourceCriterion: sourceCriterion,
}
}

Expand All @@ -50,15 +53,15 @@ func (c *chain) WithStatus(status http.Handler) {

// ServeHTTP chains the handlers together, and calls the final handler at the end.
func (c *chain) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r, err := data.ServeHTTP(w, r)
req, err := data.ServeHTTP(w, r, c.sourceCriterion)
if err != nil {
log.Printf("data.ServeHTTP error: %v", err)

return
}

for _, handler := range c.handlers {
s, err := handler.ServeHTTP(w, r)
s, err := handler.ServeHTTP(w, req)
if err != nil {
log.Printf("handler.ServeHTTP error: %v", err)

Expand All @@ -81,10 +84,10 @@ func (c *chain) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

if c.status != nil {
(*c.status).ServeHTTP(w, r)
(*c.status).ServeHTTP(w, req)

return
}

c.final.ServeHTTP(w, r)
c.final.ServeHTTP(w, req)
}
17 changes: 12 additions & 5 deletions pkg/chain/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,15 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tomMoulard/fail2ban/pkg/data"
"github.com/tomMoulard/fail2ban/pkg/rules"
)

var testHeaderName = "X-Forwarded-For"

var testSourceCriterion = rules.SourceCriterion{
RequestHeaderName: &testHeaderName,
}

type mockHandler struct {
called int
err error
Expand Down Expand Up @@ -100,10 +107,10 @@ func TestChain(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

c := New(test.finalHandler, test.handlers...)
c := New(test.finalHandler, testSourceCriterion, test.handlers...)
recorder := &httptest.ResponseRecorder{}
req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
req, err := data.ServeHTTP(recorder, req)
req, err := data.ServeHTTP(recorder, req, testSourceCriterion)
require.NoError(t, err)

c.ServeHTTP(recorder, req)
Expand Down Expand Up @@ -142,7 +149,7 @@ func TestChainOrder(t *testing.T) {
expectedCalled: 1,
}

ch := New(final, a, b, c)
ch := New(final, testSourceCriterion, a, b, c)
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
ch.ServeHTTP(nil, r)

Expand Down Expand Up @@ -176,7 +183,7 @@ func TestChainRequestContext(t *testing.T) {
expectedCalled: 1,
}

ch := New(final, handler)
ch := New(final, testSourceCriterion, handler)
r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
ch.ServeHTTP(nil, r)

Expand All @@ -192,7 +199,7 @@ func TestChainWithStatus(t *testing.T) {
final := &mockHandler{expectedCalled: 0}
status := &mockHandler{expectedCalled: 1}

ch := New(final, handler)
ch := New(final, testSourceCriterion, handler)
ch.WithStatus(status)

r := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
Expand Down
10 changes: 9 additions & 1 deletion pkg/chain/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/tomMoulard/fail2ban/pkg/chain"
"github.com/tomMoulard/fail2ban/pkg/data"
"github.com/tomMoulard/fail2ban/pkg/rules"
)

type PongHandler struct{}
Expand All @@ -29,9 +30,16 @@ func Example() {
// This example shows how to chain handlers together.
// The final handler is called only if all the previous handlers did not
// return an error.
// Setup source criterion configuration
headerName := "X-Forwarded-For"

sourceCriterion := rules.SourceCriterion{
RequestHeaderName: &headerName,
}

// Create a new chain with a final h.
h := &Handler{}
c := chain.New(&PongHandler{}, h)
c := chain.New(&PongHandler{}, sourceCriterion, h)

// Create a new request.
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
Expand Down
94 changes: 79 additions & 15 deletions pkg/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,104 @@ package data

import (
"context"
"errors"
"fmt"
"net"
"net/http"
)

type key string
"strings"

const contextDataKey key = "data"
"github.com/tomMoulard/fail2ban/pkg/rules"
)

// Data holds request context data.
type Data struct {
RemoteIP string
}

// ServeHTTP sets data in the request context, to be extracted with GetData.
func ServeHTTP(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
type key int

const contextDataKey key = iota

// GetRemoteIP extracts the remote IP from the request using the specified source criterion.
func GetRemoteIP(r *http.Request, sourceCriterion rules.SourceCriterion) (string, error) {
// If a specific header is configured, use it as the source
if sourceCriterion.RequestHeaderName != nil && *sourceCriterion.RequestHeaderName != "" {
if headerIP := r.Header.Get(*sourceCriterion.RequestHeaderName); headerIP != "" {
var depth int

if sourceCriterion.IPStrategy != nil && sourceCriterion.IPStrategy.Depth != nil {
depth = *sourceCriterion.IPStrategy.Depth
}

return extractIPFromHeader(headerIP, depth)
}
}

// Fall back to RemoteAddr if no header is specified or no valid IP found
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return "", fmt.Errorf("failed to parse remote address: %w", err)
}

return ip, nil
}

// extractIPFromHeader extracts an IP address from a header value using the specified depth.
// If depth <= 0, returns the first IP in the list.
func extractIPFromHeader(headerValue string, depth int) (string, error) {
// Split and clean the IPs
ips := strings.Split(headerValue, ",")

if len(ips) == 0 {
return "", errors.New("no IP addresses found in header value")
}

for i, ip := range ips {
ips[i] = strings.TrimSpace(ip)
}

// Select the appropriate IP based on depth
var ip string
if depth <= 0 || depth > len(ips) {
// Use first IP if depth is invalid or too large
ip = ips[0]
} else {
// Get the IP at the specified depth (counting from right)
ip = ips[len(ips)-depth]
}

// Validate the IP
if parsedIP := net.ParseIP(ip); parsedIP == nil {
return "", fmt.Errorf("invalid IP address in header: %q", ip)
}

return ip, nil
}

// ServeHTTP adds request data to the context.
func ServeHTTP(w http.ResponseWriter, r *http.Request, sourceCriterion rules.SourceCriterion) (*http.Request, error) {
remoteIP, err := GetRemoteIP(r, sourceCriterion)
if err != nil {
return nil, fmt.Errorf("failed to split remote address %q: %w", r.RemoteAddr, err)
return nil, fmt.Errorf("failed to get remote IP: %w", err)
}

data := &Data{
RemoteIP: remoteIP,
}

fmt.Printf("data: %+v", data)

return r.WithContext(context.WithValue(r.Context(), contextDataKey, data)), nil
}

// GetData returns the data stored in the request context.
func GetData(req *http.Request) *Data {
if data, ok := req.Context().Value(contextDataKey).(*Data); ok {
return data
// GetData retrieves request data from the context.
func GetData(r *http.Request) *Data {
if r == nil {
return nil
}

data, ok := r.Context().Value(contextDataKey).(*Data)
if !ok {
return nil
}

return nil
return data
}
Loading