Skip to content

Commit da39fda

Browse files
committed
fix(pia): port forward using server hostname instead of gateway ip
1 parent 19a9ac9 commit da39fda

File tree

13 files changed

+60
-183
lines changed

13 files changed

+60
-183
lines changed

internal/models/connection.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,19 @@ type Connection struct {
1515
Protocol string `json:"protocol"`
1616
// Hostname is used for IPVanish, IVPN, Privado
1717
// and Windscribe for TLS verification.
18+
// It is used for PIA for port forwarding.
1819
Hostname string `json:"hostname"`
1920
// PubKey is the public key of the VPN server,
2021
// used only for Wireguard.
2122
PubKey string `json:"pubkey"`
22-
// ServerName is used for PIA for port forwarding
23-
ServerName string `json:"server_name,omitempty"`
2423
// PortForward is used for PIA for port forwarding
2524
PortForward bool `json:"port_forward"`
2625
}
2726

2827
func (c *Connection) Equal(other Connection) bool {
2928
return c.IP.Compare(other.IP) == 0 && c.Port == other.Port &&
3029
c.Protocol == other.Protocol && c.Hostname == other.Hostname &&
31-
c.PubKey == other.PubKey && c.ServerName == other.ServerName &&
32-
c.PortForward == other.PortForward
30+
c.PubKey == other.PubKey && c.PortForward == other.PortForward
3331
}
3432

3533
// UpdateEmptyWith updates each field of the connection where the

internal/portforward/service/settings.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type Settings struct {
1313
PortForwarder PortForwarder
1414
Filepath string
1515
Interface string // needed for PIA and ProtonVPN, tun0 for example
16-
ServerName string // needed for PIA
16+
ServerHostname string // needed for PIA
1717
CanPortForward bool // needed for PIA
1818
ListeningPort uint16
1919
}
@@ -23,7 +23,7 @@ func (s Settings) Copy() (copied Settings) {
2323
copied.PortForwarder = s.PortForwarder
2424
copied.Filepath = s.Filepath
2525
copied.Interface = s.Interface
26-
copied.ServerName = s.ServerName
26+
copied.ServerHostname = s.ServerHostname
2727
copied.CanPortForward = s.CanPortForward
2828
copied.ListeningPort = s.ListeningPort
2929
return copied
@@ -34,16 +34,16 @@ func (s *Settings) OverrideWith(update Settings) {
3434
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
3535
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
3636
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
37-
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
37+
s.ServerHostname = gosettings.OverrideWithComparable(s.ServerHostname, update.ServerHostname)
3838
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
3939
s.ListeningPort = gosettings.OverrideWithComparable(s.ListeningPort, update.ListeningPort)
4040
}
4141

4242
var (
43-
ErrPortForwarderNotSet = errors.New("port forwarder not set")
44-
ErrServerNameNotSet = errors.New("server name not set")
45-
ErrFilepathNotSet = errors.New("file path not set")
46-
ErrInterfaceNotSet = errors.New("interface not set")
43+
ErrPortForwarderNotSet = errors.New("port forwarder not set")
44+
ErrServerHostnameNotSet = errors.New("server hostname not set")
45+
ErrFilepathNotSet = errors.New("file path not set")
46+
ErrInterfaceNotSet = errors.New("interface not set")
4747
)
4848

4949
func (s *Settings) Validate(forStartup bool) (err error) {
@@ -64,8 +64,8 @@ func (s *Settings) Validate(forStartup bool) (err error) {
6464
return fmt.Errorf("%w", ErrPortForwarderNotSet)
6565
case s.Interface == "":
6666
return fmt.Errorf("%w", ErrInterfaceNotSet)
67-
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerName == "":
68-
return fmt.Errorf("%w", ErrServerNameNotSet)
67+
case s.PortForwarder.Name() == providers.PrivateInternetAccess && s.ServerHostname == "":
68+
return fmt.Errorf("%w", ErrServerHostnameNotSet)
6969
}
7070
return nil
7171
}

internal/portforward/service/start.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
2626
Logger: s.logger,
2727
Gateway: gateway,
2828
Client: s.client,
29-
ServerName: s.settings.ServerName,
29+
ServerHostname: s.settings.ServerHostname,
3030
CanPortForward: s.settings.CanPortForward,
3131
}
3232
port, err := s.settings.PortForwarder.PortForward(ctx, obj)

internal/provider/privateinternetaccess/httpclient.go

-50
This file was deleted.

internal/provider/privateinternetaccess/httpclient_test.go

-51
This file was deleted.

internal/provider/privateinternetaccess/portforward.go

+15-33
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"io"
1010
"net"
1111
"net/http"
12-
"net/netip"
1312
"net/url"
1413
"os"
1514
"strconv"
@@ -27,14 +26,11 @@ var (
2726
// PortForward obtains a VPN server side port forwarded from PIA.
2827
func (p *Provider) PortForward(ctx context.Context,
2928
objects utils.PortForwardObjects) (port uint16, err error) {
30-
switch {
31-
case objects.ServerName == "":
32-
panic("server name cannot be empty")
33-
case !objects.Gateway.IsValid():
34-
panic("gateway is not set")
29+
if objects.ServerHostname == "" {
30+
panic("server hostname cannot be empty")
3531
}
3632

37-
serverName := objects.ServerName
33+
serverName := objects.ServerHostname
3834

3935
logger := objects.Logger
4036

@@ -43,11 +39,6 @@ func (p *Provider) PortForward(ctx context.Context,
4339
return 0, nil
4440
}
4541

46-
privateIPClient, err := newHTTPClient(serverName)
47-
if err != nil {
48-
return 0, fmt.Errorf("creating custom HTTP client: %w", err)
49-
}
50-
5142
data, err := readPIAPortForwardData(p.portForwardPath)
5243
if err != nil {
5344
return 0, fmt.Errorf("reading saved port forwarded data: %w", err)
@@ -66,8 +57,7 @@ func (p *Provider) PortForward(ctx context.Context,
6657
}
6758

6859
if !dataFound || expired {
69-
client := objects.Client
70-
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, objects.Gateway,
60+
data, err = refreshPIAPortForwardData(ctx, objects.Client, objects.ServerHostname,
7161
p.portForwardPath, p.authFilePath)
7262
if err != nil {
7363
return 0, fmt.Errorf("refreshing port forward data: %w", err)
@@ -77,7 +67,7 @@ func (p *Provider) PortForward(ctx context.Context,
7767
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
7868

7969
// First time binding
80-
if err := bindPort(ctx, privateIPClient, objects.Gateway, data); err != nil {
70+
if err := bindPort(ctx, objects.Client, objects.ServerHostname, data); err != nil {
8171
return 0, fmt.Errorf("binding port: %w", err)
8272
}
8373

@@ -90,16 +80,8 @@ var (
9080

9181
func (p *Provider) KeepPortForward(ctx context.Context,
9282
objects utils.PortForwardObjects) (err error) {
93-
switch {
94-
case objects.ServerName == "":
95-
panic("server name cannot be empty")
96-
case !objects.Gateway.IsValid():
97-
panic("gateway is not set")
98-
}
99-
100-
privateIPClient, err := newHTTPClient(objects.ServerName)
101-
if err != nil {
102-
return fmt.Errorf("creating custom HTTP client: %w", err)
83+
if objects.ServerHostname == "" {
84+
panic("server hostname cannot be empty")
10385
}
10486

10587
data, err := readPIAPortForwardData(p.portForwardPath)
@@ -124,7 +106,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
124106
}
125107
return ctx.Err()
126108
case <-keepAliveTimer.C:
127-
err = bindPort(ctx, privateIPClient, objects.Gateway, data)
109+
err = bindPort(ctx, objects.Client, objects.ServerHostname, data)
128110
if err != nil {
129111
return fmt.Errorf("binding port: %w", err)
130112
}
@@ -136,14 +118,14 @@ func (p *Provider) KeepPortForward(ctx context.Context,
136118
}
137119
}
138120

139-
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
140-
gateway netip.Addr, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
121+
func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
122+
serverHostname, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
141123
data.Token, err = fetchToken(ctx, client, authFilePath)
142124
if err != nil {
143125
return data, fmt.Errorf("fetching token: %w", err)
144126
}
145127

146-
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, privateIPClient, gateway, data.Token)
128+
data.Port, data.Signature, data.Expiration, err = fetchPortForwardData(ctx, client, serverHostname, data.Token)
147129
if err != nil {
148130
return data, fmt.Errorf("fetching port forwarding data: %w", err)
149131
}
@@ -319,15 +301,15 @@ func getOpenvpnCredentials(authFilePath string) (
319301
return username, password, nil
320302
}
321303

322-
func fetchPortForwardData(ctx context.Context, client *http.Client, gateway netip.Addr, token string) (
304+
func fetchPortForwardData(ctx context.Context, client *http.Client, serverHostname, token string) (
323305
port uint16, signature string, expiration time.Time, err error) {
324306
errSubstitutions := map[string]string{url.QueryEscape(token): "<token>"}
325307

326308
queryParams := make(url.Values)
327309
queryParams.Add("token", token)
328310
url := url.URL{
329311
Scheme: "https",
330-
Host: net.JoinHostPort(gateway.String(), "19999"),
312+
Host: net.JoinHostPort(serverHostname, "19999"),
331313
Path: "/getSignature",
332314
RawQuery: queryParams.Encode(),
333315
}
@@ -373,7 +355,7 @@ var (
373355
ErrBadResponse = errors.New("bad response received")
374356
)
375357

376-
func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data piaPortForwardData) (err error) {
358+
func bindPort(ctx context.Context, client *http.Client, serverHostname string, data piaPortForwardData) (err error) {
377359
payload, err := packPayload(data.Port, data.Token, data.Expiration)
378360
if err != nil {
379361
return fmt.Errorf("serializing payload: %w", err)
@@ -384,7 +366,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
384366
queryParams.Add("signature", data.Signature)
385367
bindPortURL := url.URL{
386368
Scheme: "https",
387-
Host: net.JoinHostPort(gateway.String(), "19999"),
369+
Host: net.JoinHostPort(serverHostname, "19999"),
388370
Path: "/bindPort",
389371
RawQuery: queryParams.Encode(),
390372
}

internal/provider/utils/connection.go

-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ func GetConnection(provider string,
6565
Port: port,
6666
Protocol: protocol,
6767
Hostname: hostname,
68-
ServerName: server.ServerName,
6968
PortForward: server.PortForward,
7069
PubKey: server.WgPubKey, // Wireguard
7170
}

internal/provider/utils/portforward.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@ import (
1010
type PortForwardObjects struct {
1111
// Logger is a logger, used by both Private Internet Access and ProtonVPN.
1212
Logger Logger
13-
// Gateway is the VPN gateway IP address, used by Private Internet Access
14-
// and ProtonVPN.
13+
// Gateway is the VPN gateway IP address, used by ProtonVPN.
1514
Gateway netip.Addr
1615
// Client is used to query the VPN gateway for Private Internet Access.
1716
Client *http.Client
18-
// ServerName is used by Private Internet Access for port forwarding.
19-
ServerName string
17+
// ServerHostname is used by Private Internet Access for port forwarding.
18+
ServerHostname string
2019
// CanPortForward is used by Private Internet Access for port forwarding.
2120
CanPortForward bool
2221
}

0 commit comments

Comments
 (0)