diff --git a/spanner/client_test.go b/spanner/client_test.go index c986b4debe32..c69523d15960 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -4482,6 +4482,8 @@ func TestClient_WithCustomBatchTimeout(t *testing.T) { } } +var makeMockServer = NewMockedSpannerInMemTestServer + func TestClient_WithoutCustomBatchTimeout(t *testing.T) { t.Parallel() diff --git a/spanner/regression_test.go b/spanner/regression_test.go new file mode 100644 index 000000000000..acbfae4408b6 --- /dev/null +++ b/spanner/regression_test.go @@ -0,0 +1,210 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "errors" + "fmt" + "maps" + "slices" + "sort" + "testing" + + sppb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/structpb" + + "cloud.google.com/go/spanner/internal/testutil" +) + +type methodAndMetadata struct { + method string + md metadata.MD +} + +type ourInterceptor struct { + unaryHeaders []*methodAndMetadata + streamHeaders []*methodAndMetadata +} + +func (oi *ourInterceptor) interceptStream(srv any, ss grpc.ServerStream, ssi *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + md, ok := metadata.FromIncomingContext(ss.Context()) + if !ok { + return errors.New("missing metadata in stream") + } + oi.streamHeaders = append(oi.streamHeaders, &methodAndMetadata{ssi.FullMethod, md}) + return handler(srv, ss) +} + +func (oi *ourInterceptor) interceptUnary(ctx context.Context, req any, usi *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("missing metadata in unary") + } + oi.unaryHeaders = append(oi.unaryHeaders, &methodAndMetadata{usi.FullMethod, md}) + return handler(ctx, req) +} + +// This is a regression test to assert that all the expected headers are propagated +// along to the final gRPC server avoiding scenarios where headers got dropped from a +// destructive context augmentation call. +// Please see https://github.com/googleapis/google-cloud-go/issues/11656 +func TestAllHeadersForwardedAppropriately(t *testing.T) { + // 0. Turn off session multiplexing per #11308. + t.Setenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "0") + + // 1. Set up the server interceptor that'll record and collect + // all the headers that are received by the server. + oint := new(ourInterceptor) + sopts := []grpc.ServerOption{ + grpc.UnaryInterceptor(oint.interceptUnary), grpc.StreamInterceptor(oint.interceptStream), + } + mockedServer, clientOpts, teardown := makeMockServer(t, sopts...) + defer teardown() + + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + EnableEndToEndTracing: true, + DisableRouteToLeader: false, + } + formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + sc, err := NewClientWithConfig(context.Background(), formattedDatabase, clientConfig, clientOpts...) + if err != nil { + t.Fatal(err) + } + defer sc.Close() + + // 2. Perform a simple "SELECT 1" to trigger both unary and streaming gRPC calls. + sqlSELECT1 := "SELECT 1" + resultSet := &sppb.ResultSet{ + Rows: []*structpb.ListValue{ + {Values: []*structpb.Value{ + {Kind: &structpb.Value_StringValue{StringValue: "1"}}, + }}, + }, + Metadata: &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + }, + }, + }, + } + result := &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + } + mockedServer.TestSpanner.PutStatementResult(sqlSELECT1, result) + + txn := sc.ReadOnlyTransaction() + defer txn.Close() + + ctx := context.Background() + stmt := NewStatement(sqlSELECT1) + rowIter := txn.Query(ctx, stmt) + defer rowIter.Stop() + var got []int64 + if err := SelectAll(rowIter, &got); err != nil { + t.Fatal(err) + } + want := []int64{1} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("Results expectation mismatches: got - want +\n%s", diff) + } + + // 3. Now perform the assertions of expected headers. + type headerExpectation struct { + MethodName string + WantHeaders []string + } + + wantUnaryExpectations := []*headerExpectation{ + { + "/google.spanner.v1.Spanner/BatchCreateSessions", + []string{ + ":authority", "content-type", "google-cloud-resource-prefix", + "grpc-accept-encoding", "user-agent", "x-goog-api-client", + "x-goog-request-params", "x-goog-spanner-end-to-end-tracing", + "x-goog-spanner-request-id", "x-goog-spanner-route-to-leader", + }, + }, + { + "/google.spanner.v1.Spanner/BeginTransaction", + []string{ + ":authority", "content-type", "google-cloud-resource-prefix", + "grpc-accept-encoding", "user-agent", "x-goog-api-client", + "x-goog-request-params", "x-goog-spanner-end-to-end-tracing", + "x-goog-spanner-request-id", + }, + }, + } + + wantStreamingExpectations := []*headerExpectation{ + { + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + []string{ + ":authority", "content-type", "google-cloud-resource-prefix", + "grpc-accept-encoding", "user-agent", "x-goog-api-client", + "x-goog-request-params", "x-goog-spanner-end-to-end-tracing", + "x-goog-spanner-request-id", + }, + }, + } + + var gotUnaryExpectations []*headerExpectation + for _, mdp := range oint.unaryHeaders { + gotHeaderKeys := slices.Collect(maps.Keys(mdp.md)) + gotUnaryExpectations = append(gotUnaryExpectations, &headerExpectation{mdp.method, gotHeaderKeys}) + } + + var gotStreamingExpectations []*headerExpectation + for _, mdp := range oint.streamHeaders { + gotHeaderKeys := slices.Collect(maps.Keys(mdp.md)) + gotStreamingExpectations = append(gotStreamingExpectations, &headerExpectation{mdp.method, gotHeaderKeys}) + } + + sortHeaderExpectations := func(expectations []*headerExpectation) { + // Firstly sort by method name. + sort.Slice(expectations, func(i, j int) bool { + return expectations[i].MethodName < expectations[j].MethodName + }) + + // 2. Within each expectation, also then sort the header keys. + for i := range expectations { + exp := expectations[i] + sort.Strings(exp.WantHeaders) + } + } + + sortHeaderExpectations(gotUnaryExpectations) + sortHeaderExpectations(wantUnaryExpectations) + if diff := cmp.Diff(gotUnaryExpectations, wantUnaryExpectations); diff != "" { + t.Fatalf("Unary headers mismatch: got - want +\n%s", diff) + } + + sortHeaderExpectations(gotStreamingExpectations) + sortHeaderExpectations(wantStreamingExpectations) + if diff := cmp.Diff(gotStreamingExpectations, wantStreamingExpectations); diff != "" { + t.Fatalf("Streaming headers mismatch: got - want +\n%s", diff) + } +} diff --git a/spanner/request_id_header.go b/spanner/request_id_header.go index ec367260a8e2..a10289bfa20f 100644 --- a/spanner/request_id_header.go +++ b/spanner/request_id_header.go @@ -138,9 +138,9 @@ func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (md metadata.MD, reqID r func (wr *requestIDHeaderInjector) interceptUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { // It is imperative to search for the requestID before the call // because gRPC's internals will consume the headers. - metadataWithRequestID, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + _, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) if foundRequestID { - ctx = metadata.NewOutgoingContext(ctx, metadataWithRequestID) + ctx = metadata.AppendToOutgoingContext(ctx, xSpannerRequestIDHeader, string(reqID)) } err := invoker(ctx, method, req, reply, cc, opts...) @@ -179,9 +179,9 @@ type requestIDHeaderInjector int func (wr *requestIDHeaderInjector) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { // It is imperative to search for the requestID before the call // because gRPC's internals will consume the headers. - metadataWithRequestID, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + _, reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) if foundRequestID { - ctx = metadata.NewOutgoingContext(ctx, metadataWithRequestID) + ctx = metadata.AppendToOutgoingContext(ctx, xSpannerRequestIDHeader, string(reqID)) } cs, err := streamer(ctx, desc, cc, method, opts...)