9
9
"io"
10
10
"net"
11
11
"net/http"
12
- "net/netip"
13
12
"net/url"
14
13
"os"
15
14
"strconv"
@@ -27,14 +26,11 @@ var (
27
26
// PortForward obtains a VPN server side port forwarded from PIA.
28
27
func (p * Provider ) PortForward (ctx context.Context ,
29
28
objects utils.PortForwardObjects ) (port uint16 , err error ) {
30
- switch {
31
- case objects .ServerName == "" :
32
- panic ("server name cannot be empty" )
33
- case ! objects .Gateway .IsValid ():
34
- panic ("gateway is not set" )
29
+ if objects .ServerHostname == "" {
30
+ panic ("server hostname cannot be empty" )
35
31
}
36
32
37
- serverName := objects .ServerName
33
+ serverName := objects .ServerHostname
38
34
39
35
logger := objects .Logger
40
36
@@ -43,11 +39,6 @@ func (p *Provider) PortForward(ctx context.Context,
43
39
return 0 , nil
44
40
}
45
41
46
- privateIPClient , err := newHTTPClient (serverName )
47
- if err != nil {
48
- return 0 , fmt .Errorf ("creating custom HTTP client: %w" , err )
49
- }
50
-
51
42
data , err := readPIAPortForwardData (p .portForwardPath )
52
43
if err != nil {
53
44
return 0 , fmt .Errorf ("reading saved port forwarded data: %w" , err )
@@ -66,8 +57,7 @@ func (p *Provider) PortForward(ctx context.Context,
66
57
}
67
58
68
59
if ! dataFound || expired {
69
- client := objects .Client
70
- data , err = refreshPIAPortForwardData (ctx , client , privateIPClient , objects .Gateway ,
60
+ data , err = refreshPIAPortForwardData (ctx , objects .Client , objects .ServerHostname ,
71
61
p .portForwardPath , p .authFilePath )
72
62
if err != nil {
73
63
return 0 , fmt .Errorf ("refreshing port forward data: %w" , err )
@@ -77,7 +67,7 @@ func (p *Provider) PortForward(ctx context.Context,
77
67
logger .Info ("Port forwarded data expires in " + format .FriendlyDuration (durationToExpiration ))
78
68
79
69
// First time binding
80
- if err := bindPort (ctx , privateIPClient , objects .Gateway , data ); err != nil {
70
+ if err := bindPort (ctx , objects . Client , objects .ServerHostname , data ); err != nil {
81
71
return 0 , fmt .Errorf ("binding port: %w" , err )
82
72
}
83
73
90
80
91
81
func (p * Provider ) KeepPortForward (ctx context.Context ,
92
82
objects utils.PortForwardObjects ) (err error ) {
93
- switch {
94
- case objects .ServerName == "" :
95
- panic ("server name cannot be empty" )
96
- case ! objects .Gateway .IsValid ():
97
- panic ("gateway is not set" )
98
- }
99
-
100
- privateIPClient , err := newHTTPClient (objects .ServerName )
101
- if err != nil {
102
- return fmt .Errorf ("creating custom HTTP client: %w" , err )
83
+ if objects .ServerHostname == "" {
84
+ panic ("server hostname cannot be empty" )
103
85
}
104
86
105
87
data , err := readPIAPortForwardData (p .portForwardPath )
@@ -124,7 +106,7 @@ func (p *Provider) KeepPortForward(ctx context.Context,
124
106
}
125
107
return ctx .Err ()
126
108
case <- keepAliveTimer .C :
127
- err = bindPort (ctx , privateIPClient , objects .Gateway , data )
109
+ err = bindPort (ctx , objects . Client , objects .ServerHostname , data )
128
110
if err != nil {
129
111
return fmt .Errorf ("binding port: %w" , err )
130
112
}
@@ -136,14 +118,14 @@ func (p *Provider) KeepPortForward(ctx context.Context,
136
118
}
137
119
}
138
120
139
- func refreshPIAPortForwardData (ctx context.Context , client , privateIPClient * http.Client ,
140
- gateway netip. Addr , portForwardPath , authFilePath string ) (data piaPortForwardData , err error ) {
121
+ func refreshPIAPortForwardData (ctx context.Context , client * http.Client ,
122
+ serverHostname , portForwardPath , authFilePath string ) (data piaPortForwardData , err error ) {
141
123
data .Token , err = fetchToken (ctx , client , authFilePath )
142
124
if err != nil {
143
125
return data , fmt .Errorf ("fetching token: %w" , err )
144
126
}
145
127
146
- data .Port , data .Signature , data .Expiration , err = fetchPortForwardData (ctx , privateIPClient , gateway , data .Token )
128
+ data .Port , data .Signature , data .Expiration , err = fetchPortForwardData (ctx , client , serverHostname , data .Token )
147
129
if err != nil {
148
130
return data , fmt .Errorf ("fetching port forwarding data: %w" , err )
149
131
}
@@ -319,15 +301,15 @@ func getOpenvpnCredentials(authFilePath string) (
319
301
return username , password , nil
320
302
}
321
303
322
- func fetchPortForwardData (ctx context.Context , client * http.Client , gateway netip. Addr , token string ) (
304
+ func fetchPortForwardData (ctx context.Context , client * http.Client , serverHostname , token string ) (
323
305
port uint16 , signature string , expiration time.Time , err error ) {
324
306
errSubstitutions := map [string ]string {url .QueryEscape (token ): "<token>" }
325
307
326
308
queryParams := make (url.Values )
327
309
queryParams .Add ("token" , token )
328
310
url := url.URL {
329
311
Scheme : "https" ,
330
- Host : net .JoinHostPort (gateway . String () , "19999" ),
312
+ Host : net .JoinHostPort (serverHostname , "19999" ),
331
313
Path : "/getSignature" ,
332
314
RawQuery : queryParams .Encode (),
333
315
}
@@ -373,7 +355,7 @@ var (
373
355
ErrBadResponse = errors .New ("bad response received" )
374
356
)
375
357
376
- func bindPort (ctx context.Context , client * http.Client , gateway netip. Addr , data piaPortForwardData ) (err error ) {
358
+ func bindPort (ctx context.Context , client * http.Client , serverHostname string , data piaPortForwardData ) (err error ) {
377
359
payload , err := packPayload (data .Port , data .Token , data .Expiration )
378
360
if err != nil {
379
361
return fmt .Errorf ("serializing payload: %w" , err )
@@ -384,7 +366,7 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
384
366
queryParams .Add ("signature" , data .Signature )
385
367
bindPortURL := url.URL {
386
368
Scheme : "https" ,
387
- Host : net .JoinHostPort (gateway . String () , "19999" ),
369
+ Host : net .JoinHostPort (serverHostname , "19999" ),
388
370
Path : "/bindPort" ,
389
371
RawQuery : queryParams .Encode (),
390
372
}
0 commit comments