Skip to content

Commit

Permalink
Only clone FlightServer::Builder class
Browse files Browse the repository at this point in the history
Signed-off-by: Andriy Redko <[email protected]>
  • Loading branch information
reta committed Feb 18, 2025
1 parent 6a507df commit 578c923
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 208 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import javax.net.ssl.SSLException;

import org.apache.arrow.flight.auth.ServerAuthHandler;
import org.apache.arrow.flight.auth.ServerAuthInterceptor;
import org.apache.arrow.flight.auth2.Auth2Constants;
Expand All @@ -44,136 +45,32 @@
import org.apache.arrow.flight.grpc.ServerInterceptorAdapter.KeyFactory;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.util.VisibleForTesting;

/**
* Clone of {@link org.apache.arrow.flight.FlightServer} to support setting SslContext. It can be discarded once FlightServer.Builder supports setting SslContext directly.
* <p>
* It changes {@link org.apache.arrow.flight.FlightServer.Builder} to non-final for overriding purposes and adds an overridable method {@link OSFlightServer.Builder#configureBuilder(NettyServerBuilder)}
* to allow hook to configure the NettyServerBuilder.
* <p>
* Note: This file needs to be cloned with version upgrade of arrow flight-core with above changes.
* It changes {@link org.apache.arrow.flight.FlightServer.Builder} to allow hook to configure the NettyServerBuilder.
*/
public class OSFlightServer implements AutoCloseable {

private static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(OSFlightServer.class);

private final Location location;
private final Server server;
// The executor used by the gRPC server. We don't use it here, but we do need to clean it up with
// the server.
// May be null, if a user-supplied executor was provided (as we do not want to clean that up)
@VisibleForTesting final ExecutorService grpcExecutor;

public class OSFlightServer {
/** The maximum size of an individual gRPC message. This effectively disables the limit. */
static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE;

/** The default number of bytes that can be queued on an output stream before blocking. */
public static final int DEFAULT_BACKPRESSURE_THRESHOLD = 10 * 1024 * 1024; // 10MB

/** Create a new instance from a gRPC server. For internal use only. */
private OSFlightServer(Location location, Server server, ExecutorService grpcExecutor) {
this.location = location;
this.server = server;
this.grpcExecutor = grpcExecutor;
}

/** Start the server. */
public OSFlightServer start() throws IOException {
server.start();
return this;
}

/** Get the port the server is running on (if applicable). */
public int getPort() {
return server.getPort();
}

/** Get the location for this server. */
public Location getLocation() {
if (location.getUri().getPort() == 0) {
// If the server was bound to port 0, replace the port in the location with the real port.
final URI uri = location.getUri();
try {
return new Location(
new URI(
uri.getScheme(),
uri.getUserInfo(),
uri.getHost(),
getPort(),
uri.getPath(),
uri.getQuery(),
uri.getFragment()));
} catch (URISyntaxException e) {
// We don't expect this to happen
throw new RuntimeException(e);
}
}
return location;
}

/** Block until the server shuts down. */
public void awaitTermination() throws InterruptedException {
server.awaitTermination();
}

/** Request that the server shut down. */
public void shutdown() {
server.shutdown();
if (grpcExecutor != null) {
grpcExecutor.shutdown();
}
}

/**
* Wait for the server to shut down with a timeout.
*
* @return true if the server shut down successfully.
*/
public boolean awaitTermination(final long timeout, final TimeUnit unit)
throws InterruptedException {
return server.awaitTermination(timeout, unit);
}

/** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */
@Override
public void close() throws InterruptedException {
shutdown();
final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS);
if (terminated) {
logger.debug("Server was terminated within 3s");
return;
}
static final int DEFAULT_BACKPRESSURE_THRESHOLD = 10 * 1024 * 1024; // 10MB

// get more aggressive in termination.
server.shutdownNow();
private static final MethodHandle FLIGHT_SERVER_CTOR_MH;

int count = 0;
while (!server.isTerminated() && count < 30) {
count++;
logger.debug("Waiting for termination");
Thread.sleep(100);
}

if (!server.isTerminated()) {
logger.warn("Couldn't shutdown server, resources likely will be leaked.");
static {
try {
FLIGHT_SERVER_CTOR_MH = MethodHandles
.privateLookupIn(FlightServer.class, MethodHandles.lookup())
.findConstructor(FlightServer.class, MethodType.methodType(void.class, Location.class, Server.class, ExecutorService.class));
} catch (final NoSuchMethodException | IllegalAccessException ex) {
throw new IllegalStateException("Unable to find the FlightServer constructor to invoke", ex);
}
}

/** Create a builder for a Flight server. */
public static Builder builder() {
return new Builder();
}

/** Create a builder for a Flight server. */
public static Builder builder(
BufferAllocator allocator, Location location, FlightProducer producer) {
return new Builder(allocator, location, producer);
}

/** A builder for Flight servers. */
public static class Builder {
public final static class Builder {
private BufferAllocator allocator;
private Location location;
private FlightProducer producer;
Expand All @@ -191,7 +88,7 @@ public static class Builder {
private final List<KeyFactory<?>> interceptors;
// Keep track of inserted interceptors
private final Set<String> interceptorKeys;

Builder() {
builderOptions = new HashMap<>();
interceptors = new ArrayList<>();
Expand All @@ -207,7 +104,7 @@ public static class Builder {

/** Create the server for this builder. */
@SuppressWarnings("unchecked")
public OSFlightServer build() {
public FlightServer build() {
// Add the auth middleware if applicable.
if (headerAuthenticator != CallHeaderAuthenticator.NO_OP) {
this.middleware(
Expand Down Expand Up @@ -277,7 +174,7 @@ public OSFlightServer build() {
"Scheme is not supported: " + location.getUri().getScheme());
}

if (certChain != null) {
if (certChain != null && sslContext == null) {
SslContextBuilder sslContextBuilder = GrpcSslContexts.forServer(certChain, key);

if (mTlsCACert != null) {
Expand All @@ -293,6 +190,8 @@ public OSFlightServer build() {
closeKey();
}

builder.sslContext(sslContext);
} else if (sslContext != null) {
builder.sslContext(sslContext);
}

Expand Down Expand Up @@ -354,10 +253,29 @@ public OSFlightServer build() {
});

builder.intercept(new ServerInterceptorAdapter(interceptors));
configureBuilder(builder);
return new OSFlightServer(location, builder.build(), grpcExecutor);

try {
return (FlightServer)FLIGHT_SERVER_CTOR_MH.invoke(location, builder.build(), grpcExecutor);
} catch (final Throwable ex) {
throw new IllegalStateException("Unable to instantiate FlightServer", ex);
}
}

public Builder channelType(Class<? extends io.netty.channel.Channel> channelType) {
builderOptions.put("netty.channelType", channelType);
return this;
}

public Builder workerEventLoopGroup(EventLoopGroup workerELG) {
builderOptions.put("netty.workerEventLoopGroup", workerELG);
return this;
}

public Builder bossEventLoopGroup(EventLoopGroup bossELG) {
builderOptions.put("netty.bossEventLoopGroup", bossELG);
return this;
}

public Builder setMaxHeaderListSize(int maxHeaderListSize) {
this.maxHeaderListSize = maxHeaderListSize;
return this;
Expand Down Expand Up @@ -543,10 +461,13 @@ public Builder producer(FlightProducer producer) {
return this;
}

/**
* Hook to allow custom configuration of the NettyServerBuilder.
*/
public void configureBuilder(NettyServerBuilder builder) {
public Builder sslContext(SslContext sslContext) {
this.sslContext = sslContext;
return this;
}
}

public static Builder builder() {
return new Builder();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
package org.opensearch.arrow.flight.bootstrap;

import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.OSFlightServer;
import org.apache.arrow.flight.OSFlightServerBuilder;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -100,7 +100,7 @@ final class ServerComponents implements AutoCloseable {
private final String[] publishHosts;
private volatile BoundTransportAddress boundAddress;

private OSFlightServer server;
private FlightServer server;
private BufferAllocator allocator;
ClusterService clusterService;
private NetworkService networkService;
Expand Down Expand Up @@ -147,17 +147,17 @@ void setFlightProducer(FlightProducer flightProducer) {
this.flightProducer = Objects.requireNonNull(flightProducer);
}

private OSFlightServer buildAndStartServer(Location location, FlightProducer producer) throws IOException {
OSFlightServer server = OSFlightServerBuilder.builder(
allocator,
location,
producer,
sslContextProvider != null ? sslContextProvider.getServerSslContext() : null,
ServerConfig.serverChannelType(),
bossEventLoopGroup,
workerEventLoopGroup,
serverExecutor
).build();
private FlightServer buildAndStartServer(Location location, FlightProducer producer) throws IOException {
FlightServer server = OSFlightServer.builder()
.allocator(allocator)
.location(location)
.producer(producer)
.sslContext(sslContextProvider != null ? sslContextProvider.getServerSslContext() : null)
.channelType(ServerConfig.serverChannelType())
.bossEventLoopGroup(bossEventLoopGroup)
.workerEventLoopGroup(workerEventLoopGroup)
.executor(serverExecutor)
.build();
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
try {
server.start();
Expand Down

0 comments on commit 578c923

Please sign in to comment.