diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 052a21cca66b..5659beb509a9 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added -* Added function `FetcherForNextLink` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL. +* Added function `FetcherForNextLink` and `FetcherForNextLinkOptions` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL. ### Breaking Changes diff --git a/sdk/azcore/runtime/pager.go b/sdk/azcore/runtime/pager.go index f1daac50d32a..1dc4d3e49464 100644 --- a/sdk/azcore/runtime/pager.go +++ b/sdk/azcore/runtime/pager.go @@ -91,14 +91,30 @@ func (p *Pager[T]) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &p.current) } +// FetcherForNextLinkOptions contains the optional values for [FetcherForNextLink]. +type FetcherForNextLinkOptions struct { + // NextReq is the func to be called when requesting subsequent pages. + // Used for paged operations that have a custom next link operation. + NextReq func(context.Context, string) (*policy.Request, error) +} + // FetcherForNextLink is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher from a next link URL. -func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) { +// - ctx is the [context.Context] controlling the lifetime of the HTTP operation +// - pl is the [Pipeline] used to dispatch the HTTP request +// - nextLink is the URL used to fetch the next page. the empty string indicates the first page is to be requested +// - firstReq is the func to be called when creating the request for the first page +// - options contains any optional parameters, pass nil to accept the default values +func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, firstReq func(context.Context) (*policy.Request, error), options *FetcherForNextLinkOptions) (*http.Response, error) { var req *policy.Request var err error if nextLink == "" { - req, err = createReq(ctx) + req, err = firstReq(ctx) } else if nextLink, err = EncodeQueryParams(nextLink); err == nil { - req, err = NewRequest(ctx, http.MethodGet, nextLink) + if options != nil && options.NextReq != nil { + req, err = options.NextReq(ctx, nextLink) + } else { + req, err = NewRequest(ctx, http.MethodGet, nextLink) + } } if err != nil { return nil, err diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go index 9d205b297d60..6eb92414e64a 100644 --- a/sdk/azcore/runtime/pager_test.go +++ b/sdk/azcore/runtime/pager_test.go @@ -264,52 +264,69 @@ func TestFetcherForNextLink(t *testing.T) { pl := exported.NewPipeline(srv) srv.AppendResponse() - createReqCalled := false + firstReqCalled := false resp, err := FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { - createReqCalled = true + firstReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) - }) + }, nil) require.NoError(t, err) - require.True(t, createReqCalled) + require.True(t, firstReqCalled) require.NotNil(t, resp) require.EqualValues(t, http.StatusOK, resp.StatusCode) srv.AppendResponse() - createReqCalled = false + firstReqCalled = false + nextReqCalled := false resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { - createReqCalled = true + firstReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) + }, &FetcherForNextLinkOptions{ + NextReq: func(ctx context.Context, s string) (*policy.Request, error) { + nextReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }, }) require.NoError(t, err) - require.False(t, createReqCalled) + require.False(t, firstReqCalled) + require.True(t, nextReqCalled) require.NotNil(t, resp) require.EqualValues(t, http.StatusOK, resp.StatusCode) resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { return nil, errors.New("failed") + }, &FetcherForNextLinkOptions{}) + require.Error(t, err) + require.Nil(t, resp) + + resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + return nil, nil + }, &FetcherForNextLinkOptions{ + NextReq: func(ctx context.Context, s string) (*policy.Request, error) { + return nil, errors.New("failed") + }, }) require.Error(t, err) require.Nil(t, resp) srv.AppendError(errors.New("failed")) resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { - createReqCalled = true + firstReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) - }) + }, &FetcherForNextLinkOptions{}) require.Error(t, err) - require.True(t, createReqCalled) + require.True(t, firstReqCalled) require.Nil(t, resp) srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{ "error": { "code": "InvalidResource", "message": "doesn't exist" } }`))) - createReqCalled = false + firstReqCalled = false resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { - createReqCalled = true + firstReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) - }) + }, nil) require.Error(t, err) var respErr *exported.ResponseError require.ErrorAs(t, err, &respErr) require.EqualValues(t, "InvalidResource", respErr.ErrorCode) - require.False(t, createReqCalled) + require.False(t, firstReqCalled) require.Nil(t, resp) }