Skip to content

Commit

Permalink
httpu: add context.Context and related interface
Browse files Browse the repository at this point in the history
This adds a new interface for httpu that supports a Context, and uses
that context to set a deadline/timeout and also cancel the request if
the context is canceled. Additionally, add a new method to the SSDP
package that takes a ClientInterfaceCtx.

Updates huin#55

Signed-off-by: Andrew Dunham <[email protected]>
  • Loading branch information
andrew-d committed Aug 24, 2023
1 parent 8ca2329 commit b0a0064
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 8 deletions.
5 changes: 4 additions & 1 deletion goupnp.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ func DiscoverDevicesCtx(ctx context.Context, searchTarget string) ([]MaybeRootDe
return nil, err
}
defer hcCleanup()
responses, err := ssdp.SSDPRawSearchCtx(ctx, hc, string(searchTarget), 2, 3)

searchCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
responses, err := ssdp.RawSearch(searchCtx, hc, string(searchTarget), 3)
if err != nil {
return nil, err
}
Expand Down
56 changes: 54 additions & 2 deletions httpu/httpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package httpu
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"log"
Expand All @@ -26,6 +27,22 @@ type ClientInterface interface {
) ([]*http.Response, error)
}

// ClientInterfaceCtx is the equivalent of ClientInterface, except with methods
// taking a context.Context parameter.
type ClientInterfaceCtx interface {
// DoWithContext performs a request. If the input request has a
// deadline, then that value will be used as the timeout for how long
// to wait before returning the responses that were received. If the
// request's context is canceled, this method will return immediately.
//
// An error is only returned for failing to send the request. Failures
// in receipt simply do not add to the resulting responses.
DoWithContext(
req *http.Request,
numSends int,
) ([]*http.Response, error)
}

// HTTPUClient is a client for dealing with HTTPU (HTTP over UDP). Its typical
// function is for HTTPMU, and particularly SSDP.
type HTTPUClient struct {
Expand All @@ -34,6 +51,7 @@ type HTTPUClient struct {
}

var _ ClientInterface = &HTTPUClient{}
var _ ClientInterfaceCtx = &HTTPUClient{}

// NewHTTPUClient creates a new HTTPUClient, opening up a new UDP socket for the
// purpose.
Expand Down Expand Up @@ -75,6 +93,22 @@ func (httpu *HTTPUClient) Do(
req *http.Request,
timeout time.Duration,
numSends int,
) ([]*http.Response, error) {
ctx := req.Context()
if timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
req = req.WithContext(ctx)
}

return httpu.DoWithContext(req, numSends)
}

// DoWithContext implements ClientInterfaceCtx.DoWithContext.
func (httpu *HTTPUClient) DoWithContext(
req *http.Request,
numSends int,
) ([]*http.Response, error) {
httpu.connLock.Lock()
defer httpu.connLock.Unlock()
Expand All @@ -101,10 +135,28 @@ func (httpu *HTTPUClient) Do(
if err != nil {
return nil, err
}
if err = httpu.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, err

// Handle context deadline/timeout
ctx := req.Context()
deadline, ok := ctx.Deadline()
if ok {
if err = httpu.conn.SetDeadline(deadline); err != nil {
return nil, err
}
}

// Handle context cancelation
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-ctx.Done():
// if context is cancelled, stop any connections by setting time in the past.
httpu.conn.SetDeadline(time.Now().Add(-time.Second))
case <-done:
}
}()

// Send request.
for i := 0; i < numSends; i++ {
if n, err := httpu.conn.WriteTo(requestBuf.Bytes(), destAddr); err != nil {
Expand Down
66 changes: 64 additions & 2 deletions httpu/multiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ func (mc *MultiClient) Do(
}

func (mc *MultiClient) sendRequests(
results chan<-[]*http.Response,
results chan<- []*http.Response,
req *http.Request,
timeout time.Duration,
numSends int,
) error {
tasks := &errgroup.Group{}
for _, d := range mc.delegates {
d := d // copy for closure
d := d // copy for closure
tasks.Go(func() error {
responses, err := d.Do(req, timeout, numSends)
if err != nil {
Expand All @@ -68,3 +68,65 @@ func (mc *MultiClient) sendRequests(
}
return tasks.Wait()
}

// MultiClientCtx dispatches requests out to all the delegated clients.
type MultiClientCtx struct {
// The HTTPU clients to delegate to.
delegates []ClientInterfaceCtx
}

var _ ClientInterfaceCtx = &MultiClientCtx{}

// NewMultiClient creates a new MultiClient that delegates to all the given
// clients.
func NewMultiClientCtx(delegates []ClientInterfaceCtx) *MultiClientCtx {
return &MultiClientCtx{
delegates: delegates,
}
}

// DoWithContext implements ClientInterfaceCtx.DoWithContext.
func (mc *MultiClientCtx) DoWithContext(
req *http.Request,
numSends int,
) ([]*http.Response, error) {
tasks, ctx := errgroup.WithContext(req.Context())
req = req.WithContext(ctx) // so we cancel if the errgroup errors
results := make(chan []*http.Response)

// For each client, send the request to it and collect results.
tasks.Go(func() error {
defer close(results)
return mc.sendRequestsCtx(results, req, numSends)
})

var responses []*http.Response
tasks.Go(func() error {
for rs := range results {
responses = append(responses, rs...)
}
return nil
})

return responses, tasks.Wait()
}

func (mc *MultiClientCtx) sendRequestsCtx(
results chan<- []*http.Response,
req *http.Request,
numSends int,
) error {
tasks := &errgroup.Group{}
for _, d := range mc.delegates {
d := d // copy for closure
tasks.Go(func() error {
responses, err := d.DoWithContext(req, numSends)
if err != nil {
return err
}
results <- responses
return nil
})
}
return tasks.Wait()
}
6 changes: 3 additions & 3 deletions network.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import (
// httpuClient creates a HTTPU client that multiplexes to all multicast-capable
// IPv4 addresses on the host. Returns a function to clean up once the client is
// no longer required.
func httpuClient() (httpu.ClientInterface, func(), error) {
func httpuClient() (httpu.ClientInterfaceCtx, func(), error) {
addrs, err := localIPv4MCastAddrs()
if err != nil {
return nil, nil, ctxError(err, "requesting host IPv4 addresses")
}

closers := make([]io.Closer, 0, len(addrs))
delegates := make([]httpu.ClientInterface, 0, len(addrs))
delegates := make([]httpu.ClientInterfaceCtx, 0, len(addrs))
for _, addr := range addrs {
c, err := httpu.NewHTTPUClientAddr(addr)
if err != nil {
Expand All @@ -34,7 +34,7 @@ func httpuClient() (httpu.ClientInterface, func(), error) {
}
}

return httpu.NewMultiClient(delegates), closer, nil
return httpu.NewMultiClientCtx(delegates), closer, nil
}

// localIPv2MCastAddrs returns the set of IPv4 addresses on multicast-able
Expand Down
68 changes: 68 additions & 0 deletions ssdp/ssdp.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ type HTTPUClient interface {
) ([]*http.Response, error)
}

// HTTPUClientCtx is an optional interface that will be used to perform
// HTTP-over-UDP requests if the client implements it.
type HTTPUClientCtx interface {
DoWithContext(
req *http.Request,
numSends int,
) ([]*http.Response, error)
}

// SSDPRawSearchCtx performs a fairly raw SSDP search request, and returns the
// unique response(s) that it receives. Each response has the requested
// searchTarget, a USN, and a valid location. maxWaitSeconds states how long to
Expand Down Expand Up @@ -71,7 +80,66 @@ func SSDPRawSearchCtx(
if err != nil {
return nil, err
}
return processSSDPResponses(searchTarget, allResponses)
}

// RawSearch performs a fairly raw SSDP search request, and returns the
// unique response(s) that it receives. Each response has the requested
// searchTarget, a USN, and a valid location. If the provided context times out
// or is canceled, the search will be aborted. numSends is the number of
// requests to send - 3 is a reasonable value for this.
func RawSearch(
ctx context.Context,
httpu HTTPUClientCtx,
searchTarget string,
numSends int,
) ([]*http.Response, error) {
// We need a timeout value to include in the SSDP request; get it by
// checking the deadline on the context.
var maxWaitSeconds int
if deadline, ok := ctx.Deadline(); ok {
maxWaitSeconds = int(deadline.Sub(time.Now()) / time.Second)
} else {
// Pick a default timeout of 3 seconds if none was provided.
maxWaitSeconds = 3

var cancel func()
ctx, cancel = context.WithTimeout(ctx, time.Duration(maxWaitSeconds)*time.Second)
defer cancel()
}

// Check the timeout on the context (if any); if the context would time
// out in less than 1 second, then abort.
if maxWaitSeconds < 1 {
return nil, errors.New("ssdp: context expiry must be at least 1s in the future")
}

req := (&http.Request{
Method: methodSearch,
// TODO: Support both IPv4 and IPv6.
Host: ssdpUDP4Addr,
URL: &url.URL{Opaque: "*"},
Header: http.Header{
// Putting headers in here avoids them being title-cased.
// (The UPnP discovery protocol uses case-sensitive headers)
"HOST": []string{ssdpUDP4Addr},
"MX": []string{strconv.FormatInt(int64(maxWaitSeconds), 10)},
"MAN": []string{ssdpDiscover},
"ST": []string{searchTarget},
},
}).WithContext(ctx)

allResponses, err := httpu.DoWithContext(req, numSends)
if err != nil {
return nil, err
}
return processSSDPResponses(searchTarget, allResponses)
}

func processSSDPResponses(
searchTarget string,
allResponses []*http.Response,
) ([]*http.Response, error) {
isExactSearch := searchTarget != SSDPAll && searchTarget != UPNPRootDevice

seenIDs := make(map[string]bool)
Expand Down

0 comments on commit b0a0064

Please sign in to comment.