Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSE support #2126

Merged
merged 11 commits into from
Feb 26, 2024
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package caliban.interop.tapir

import caliban.ResponseValue.StreamValue
import caliban._
import caliban.ResponseValue.{ ObjectValue, StreamValue }
import caliban.wrappers.Caching
import sttp.capabilities.zio.ZioStreams
import sttp.capabilities.zio.ZioStreams.Pipe
import sttp.capabilities.{ Streams, WebSockets }
import sttp.model.{ headers => _, _ }
import sttp.model.sse.ServerSentEvent
import sttp.monad.MonadError
import sttp.tapir.Codec.JsonCodec
import sttp.tapir.model.ServerRequest
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.ztapir.ZioServerSentEvents
import sttp.tapir.{ headers, _ }
import zio._
import zio.stream.ZStream
Expand Down Expand Up @@ -89,6 +91,10 @@ object TapirAdapter {
oneOfVariantValueMatcher[CalibanBody.Stream[stream.BinaryStream]](
streamTextBody(stream)(CodecFormat.Json(), Some(StandardCharsets.UTF_8)).toEndpointIO
.map(Right(_)) { case Right(value) => value }
) { case Right(_) => true },
oneOfVariantValueMatcher[CalibanBody.Stream[stream.BinaryStream]](
streamBinaryBody(stream)(CodecFormat.TextEventStream()).toEndpointIO
.map(Right(_)) { case Right(value) => value }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm quite surprised this works since there is a Right(_) => true above this, does the selector take into account the Accepts headers when computing the match?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I honestly don't know if it does take the Accept header into account but it does work. If we can do it in another way that is more readable or so I'm all for it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If memory serves well, I also came across this some time ago and I was also surprised about it. I think I also concluded that it takes the Accept header into account

) { case Right(_) => true }
)

Expand All @@ -114,6 +120,8 @@ object TapirAdapter {
case _ => false
}
)
val isSSE =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could make this a def or lazy as well since we don't need to compute it if we match another case first

Copy link
Collaborator

@kyri-petrou kyri-petrou Feb 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding to this, we could probably check the Accept header just a single time (we currently check it once for acceptsGqlJson and once for isSSE)

request.header(HeaderNames.Accept).map(v => v.contains(MediaType.TextEventStream.toString())).getOrElse(false)
response match {
case resp @ GraphQLResponse(StreamValue(stream), _, _, _) =>
(
Expand Down Expand Up @@ -142,17 +150,27 @@ object TapirAdapter {
val code =
response.errors.collectFirst { case HttpRequestMethod.MutationOverGetError => StatusCode.BadRequest }
.getOrElse(StatusCode.Ok)
val cacheDirective = HttpUtils.computeCacheDirective(response.extensions)
(
MediaType.ApplicationJson,
code,
cacheDirective,
encodeSingleResponse(
resp,
keepDataOnErrors = true,
excludeExtensions = cacheDirective.map(_ => Set(Caching.DirectiveName))
)
)
val cacheDirective = computeCacheDirective(response.extensions)
isSSE match {
case true =>
(
MediaType.TextEventStream,
code,
cacheDirective,
encodeTextEventStreamResponse(resp)
)
case false =>
(
MediaType.ApplicationJson,
code,
cacheDirective,
encodeSingleResponse(
resp,
keepDataOnErrors = true,
excludeExtensions = cacheDirective.map(_ => Set(Caching.DirectiveName))
)
)
}
}
}

Expand Down Expand Up @@ -199,6 +217,37 @@ object TapirAdapter {
)
}

private def encodeTextEventStreamResponse[E, BS](
resp: GraphQLResponse[E]
)(implicit streamConstructor: StreamConstructor[BS], responseCodec: JsonCodec[ResponseValue]): CalibanBody[BS] = {
val response: ZStream[Any, Throwable, ServerSentEvent] = resp.data match {
case ObjectValue(fields) =>
fields.foldLeft(ZStream.empty: ZStream[Any, Throwable, ServerSentEvent]) { case (_, v) =>
v match {
case (fieldName, StreamValue(stream)) =>
stream.map { r =>
ServerSentEvent(
Some(
responseCodec.encode(
GraphQLResponse(
ObjectValue(List(fieldName -> r)),
resp.errors
).toResponseValue
)
),
Some("next")
)
}.orDie
case _ =>
ZStream.succeed(ServerSentEvent(Some(responseCodec.encode(resp.toResponseValue)), Some("next")))
}
}
case _ =>
ZStream.succeed(ServerSentEvent(Some(responseCodec.encode(resp.toResponseValue)), Some("next")))
}
Right(streamConstructor(ZioServerSentEvents.serialiseSSEToBytes(response)))
}

private def encodeSingleResponse[E](
response: GraphQLResponse[E],
keepDataOnErrors: Boolean,
Expand Down