Revert "Update to Dropwizard 5"

This reverts commit 4dbd564442.
This commit is contained in:
Ravi Khadiwala
2025-11-05 15:55:50 -06:00
committed by ravi-signal
parent bb94975d74
commit faa74469ea
37 changed files with 711 additions and 698 deletions

10
pom.xml
View File

@@ -41,7 +41,7 @@
<braintree.version>3.44.0</braintree.version> <braintree.version>3.44.0</braintree.version>
<commons-csv.version>1.14.1</commons-csv.version> <commons-csv.version>1.14.1</commons-csv.version>
<commons-io.version>2.20.0</commons-io.version> <commons-io.version>2.20.0</commons-io.version>
<dropwizard.version>5.0.0</dropwizard.version> <dropwizard.version>4.0.16</dropwizard.version>
<!-- Note: when updating FoundationDB, also include a copy of `libfdb_c.so` from the FoundationDB release at <!-- Note: when updating FoundationDB, also include a copy of `libfdb_c.so` from the FoundationDB release at
src/main/jib/usr/lib/libfdb_c.so. We use x86_64 builds without AVX instructions enabled (i.e. FoundationDB versions src/main/jib/usr/lib/libfdb_c.so. We use x86_64 builds without AVX instructions enabled (i.e. FoundationDB versions
with even-numbered patch versions). Also when updating FoundationDB, make sure to update the version of FoundationDB with even-numbered patch versions). Also when updating FoundationDB, make sure to update the version of FoundationDB
@@ -251,6 +251,12 @@
<artifactId>commons-logging</artifactId> <artifactId>commons-logging</artifactId>
<version>1.3.5</version> <version>1.3.5</version>
</dependency> </dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>9.8</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>com.stripe</groupId> <groupId>com.stripe</groupId>
<artifactId>stripe-java</artifactId> <artifactId>stripe-java</artifactId>
@@ -348,7 +354,7 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.wiremock</groupId> <groupId>org.wiremock</groupId>
<artifactId>wiremock-jetty12</artifactId> <artifactId>wiremock</artifactId>
<version>3.13.1</version> <version>3.13.1</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@@ -237,15 +237,15 @@
<dependency> <dependency>
<groupId>org.eclipse.jetty.websocket</groupId> <groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-websocket-jetty-api</artifactId> <artifactId>websocket-jetty-api</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.eclipse.jetty.ee10</groupId> <groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-ee10-servlets</artifactId> <artifactId>jetty-servlets</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.eclipse.jetty.websocket</groupId> <groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-websocket-jetty-client</artifactId> <artifactId>websocket-jetty-client</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>

View File

@@ -37,12 +37,15 @@ import io.netty.resolver.ResolvedAddressTypes;
import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder; import io.netty.resolver.dns.DnsNameResolverBuilder;
import jakarta.servlet.DispatcherType; import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletRegistration;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.http.HttpClient; import java.net.http.HttpClient;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.List; import java.util.List;
@@ -57,9 +60,9 @@ import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry; import org.eclipse.jetty.websocket.core.WebSocketExtensionRegistry;
import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents; import org.eclipse.jetty.websocket.core.server.WebSocketServerComponents;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.server.ServerProperties;
import org.signal.i18n.HeaderControlledResourceBundleLookup; import org.signal.i18n.HeaderControlledResourceBundleLookup;
import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.GenericServerSecretParams;
@@ -133,7 +136,6 @@ import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.currency.FixerClient; import org.whispersystems.textsecuregcm.currency.FixerClient;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.ExternalRequestFilter; import org.whispersystems.textsecuregcm.filters.ExternalRequestFilter;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
@@ -182,7 +184,7 @@ import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.metrics.CallQualitySurveyManager; import org.whispersystems.textsecuregcm.metrics.CallQualitySurveyManager;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener; import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener;
import org.whispersystems.textsecuregcm.metrics.MetricsHttpEventHandler; import org.whispersystems.textsecuregcm.metrics.MetricsHttpChannelListener;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener; import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
@@ -595,6 +597,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ScheduledExecutorService cloudflareTurnRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "cloudflareTurnRetry").threads(1).build(); ScheduledExecutorService cloudflareTurnRetryExecutor = ScheduledExecutorServiceBuilder.of(environment, "cloudflareTurnRetry").threads(1).build();
ScheduledExecutorService messagePollExecutor = ScheduledExecutorServiceBuilder.of(environment, "messagePollExecutor").threads(1).build(); ScheduledExecutorService messagePollExecutor = ScheduledExecutorServiceBuilder.of(environment, "messagePollExecutor").threads(1).build();
ScheduledExecutorService provisioningWebsocketTimeoutExecutor = ScheduledExecutorServiceBuilder.of(environment, "provisioningWebsocketTimeout").threads(1).build(); ScheduledExecutorService provisioningWebsocketTimeoutExecutor = ScheduledExecutorServiceBuilder.of(environment, "provisioningWebsocketTimeout").threads(1).build();
ScheduledExecutorService jmxDumper = ScheduledExecutorServiceBuilder.of(environment, "jmxDumper").threads(1).build();
final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup(); final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup();
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next()) final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next())
@@ -910,6 +913,17 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(dnsResolutionEventLoopGroup); environment.lifecycle().manage(dnsResolutionEventLoopGroup);
environment.lifecycle().manage(exposedGrpcServer); environment.lifecycle().manage(exposedGrpcServer);
final List<Filter> filters = new ArrayList<>();
filters.add(remoteDeprecationFilter);
filters.add(new RemoteAddressFilter());
filters.add(new TimestampResponseFilter());
for (Filter filter : filters) {
environment.servlets()
.addFilter(filter.getClass().getSimpleName(), filter)
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
}
if (!config.getExternalRequestFilterConfiguration().paths().isEmpty()) { if (!config.getExternalRequestFilterConfiguration().paths().isEmpty()) {
environment.servlets().addFilter(ExternalRequestFilter.class.getSimpleName(), environment.servlets().addFilter(ExternalRequestFilter.class.getSimpleName(),
new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(), new ExternalRequestFilter(config.getExternalRequestFilterConfiguration().permittedInternalRanges(),
@@ -926,7 +940,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final String websocketServletPath = "/v1/websocket/"; final String websocketServletPath = "/v1/websocket/";
final String provisioningWebsocketServletPath = "/v1/websocket/provisioning/"; final String provisioningWebsocketServletPath = "/v1/websocket/provisioning/";
MetricsHttpEventHandler.configure(environment, Metrics.globalRegistry, clientReleaseManager, Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check")); final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager,
Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check"));
metricsHttpChannelListener.configure(environment);
final MessageMetrics messageMetrics = new MessageMetrics(); final MessageMetrics messageMetrics = new MessageMetrics();
final BackupMetrics backupMetrics = new BackupMetrics(); final BackupMetrics backupMetrics = new BackupMetrics();
@@ -1116,35 +1132,36 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE); webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
provisioningEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE); provisioningEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (context, container) -> {
final WebSocketExtensionRegistry extensionRegistry = WebSocketServerComponents
.getWebSocketComponents(environment.getApplicationContext().getServletContext())
.getExtensionRegistry();
if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) {
extensionRegistry.unregister("permessage-deflate");
} else if (config.getWebSocketConfiguration().isDisableCrossMessageOutgoingCompression()) {
extensionRegistry.unregister("permessage-deflate");
extensionRegistry.register("permessage-deflate", NoContextTakeoverPerMessageDeflateExtension.class);
}
});
WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); webSocketEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
WebSocketResourceProviderFactory<AuthenticatedDevice> provisioningServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<AuthenticatedDevice> provisioningServlet = new WebSocketResourceProviderFactory<>(
provisioningEnvironment, AuthenticatedDevice.class, RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); provisioningEnvironment, AuthenticatedDevice.class, config.getWebSocketConfiguration(),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
(servletContext, container) -> { ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
container.addMapping(websocketServletPath, webSocketServlet);
container.addMapping(provisioningWebsocketServletPath, provisioningServlet);
PriorityFilter.ensureFilter(servletContext, new TimestampResponseFilter()); websocket.addMapping(websocketServletPath);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); websocket.setAsyncSupported(true);
PriorityFilter.ensureFilter(servletContext, remoteDeprecationFilter);
container.setMaxBinaryMessageSize(config.getWebSocketConfiguration().getMaxBinaryMessageSize()); provisioning.addMapping(provisioningWebsocketServletPath);
container.setMaxTextMessageSize(config.getWebSocketConfiguration().getMaxTextMessageSize()); provisioning.setAsyncSupported(true);
final WebSocketExtensionRegistry extensionRegistry = WebSocketServerComponents
.getWebSocketComponents(environment.getApplicationContext())
.getExtensionRegistry();
if (config.getWebSocketConfiguration().isDisablePerMessageDeflate()) {
extensionRegistry.unregister("permessage-deflate");
} else if (config.getWebSocketConfiguration().isDisableCrossMessageOutgoingCompression()) {
extensionRegistry.unregister("permessage-deflate");
extensionRegistry.register("permessage-deflate", NoContextTakeoverPerMessageDeflateExtension.class);
}
});
environment.admin().addTask(new SetRequestLoggingEnabledTask()); environment.admin().addTask(new SetRequestLoggingEnabledTask());
} }
private void registerExceptionMappers(Environment environment, private void registerExceptionMappers(Environment environment,

View File

@@ -10,9 +10,10 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant;
import java.util.Optional; import java.util.Optional;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;

View File

@@ -1,82 +0,0 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.filters;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletContext;
import java.util.EnumSet;
import java.util.Objects;
import org.eclipse.jetty.ee10.servlet.FilterHolder;
import org.eclipse.jetty.ee10.servlet.FilterMapping;
import org.eclipse.jetty.ee10.servlet.ServletContextHandler;
import org.eclipse.jetty.ee10.servlet.ServletHandler;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.component.LifeCycle;
public class PriorityFilter {
private PriorityFilter() {}
private static FilterHolder getFilter(ServletContext servletContext, final Class<? extends Filter> filterClass) {
final ContextHandler contextHandler = Objects.requireNonNull(ServletContextHandler.getServletContextHandler(servletContext));
final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class);
return servletHandler.getFilter(filterClass.getName());
}
/**
* Ensure a filter is available on the provided ServletContext, a new filter will added if one does not already
* exist.
* <p>
* If a new filter is added, it will be added before all other filters.
* <p>
* Modeled after {@link org.eclipse.jetty.ee10.websocket.servlet.WebSocketUpgradeFilter#ensureFilter(ServletContext)},
* since its use of {@link org.eclipse.jetty.ee10.servlet.ServletHandler#prependFilter(FilterHolder)} is what makes
* this necessary.
*/
public static void ensureFilter(final ServletContext servletContext, final Filter filter) {
FilterHolder existingFilter = getFilter(servletContext, filter.getClass());
if (existingFilter != null) {
return;
}
final ContextHandler contextHandler = ServletContextHandler.getServletContextHandler(servletContext);
final ServletHandler servletHandler = contextHandler.getDescendant(ServletHandler.class);
final String pathSpec = "/*";
final FilterHolder holder = new FilterHolder(filter);
holder.setName(filter.getClass().getName());
holder.setAsyncSupported(true);
final FilterMapping mapping = new FilterMapping();
mapping.setFilterName(holder.getName());
mapping.setPathSpec(pathSpec);
mapping.setDispatcherTypes(EnumSet.of(DispatcherType.REQUEST));
// Add as the first filter in the list.
servletHandler.prependFilter(holder);
servletHandler.prependFilterMapping(mapping);
// If we create the filter we must also make sure it is removed if the context is stopped.
contextHandler.addEventListener(new LifeCycle.Listener()
{
@Override
public void lifeCycleStopping(LifeCycle event)
{
servletHandler.removeFilterHolder(holder);
servletHandler.removeFilterMapping(mapping);
contextHandler.removeEventListener(this);
}
@Override
public String toString()
{
return String.format("%sCleanupListener", filter.getClass().getSimpleName());
}
});
}
}

View File

@@ -26,6 +26,10 @@ public class RemoteAddressFilter implements Filter {
public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress"; public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress";
private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class); private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class);
public RemoteAddressFilter() {
}
@Override @Override
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
throws ServletException, IOException { throws ServletException, IOException {

View File

@@ -14,7 +14,7 @@ import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
/** /**
* Delegates request events to a listener that captures and reports request-level metrics. * Delegates request events to a listener that captures and reports request-level metrics.
* *
* @see MetricsHttpEventHandler * @see MetricsHttpChannelListener
* @see MetricsRequestEventListener * @see MetricsRequestEventListener
*/ */
public class MetricsApplicationEventListener implements ApplicationEventListener { public class MetricsApplicationEventListener implements ApplicationEventListener {
@@ -23,7 +23,7 @@ public class MetricsApplicationEventListener implements ApplicationEventListener
public MetricsApplicationEventListener(final TrafficSource trafficSource, final ClientReleaseManager clientReleaseManager) { public MetricsApplicationEventListener(final TrafficSource trafficSource, final ClientReleaseManager clientReleaseManager) {
if (trafficSource == TrafficSource.HTTP) { if (trafficSource == TrafficSource.HTTP) {
throw new IllegalArgumentException("Use " + MetricsHttpEventHandler.class.getName() + " for HTTP traffic"); throw new IllegalArgumentException("Use " + MetricsHttpChannelListener.class.getName() + " for HTTP traffic");
} }
this.metricsRequestEventListener = new MetricsRequestEventListener(trafficSource, clientReleaseManager); this.metricsRequestEventListener = new MetricsRequestEventListener(trafficSource, clientReleaseManager);
} }

View File

@@ -0,0 +1,197 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerResponseContext;
import jakarta.ws.rs.container.ContainerResponseFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.util.component.LifeCycle;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil;
/**
* Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get
* templated Jersey request paths, it implements {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give
* itself access to the template. It is limited to {@link TrafficSource#HTTP} requests.
* <p>
* It implements {@link LifeCycle.Listener} without overriding methods, so that it can be an event listener that
* Dropwizard will attach to the container&mdash;the {@link Container.Listener} implementation is where it attaches
* itself to any {@link Connector}s.
*
* @see MetricsRequestEventListener
*/
public class MetricsHttpChannelListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener,
ContainerResponseFilter {
private static final Logger logger = LoggerFactory.getLogger(MetricsHttpChannelListener.class);
private record RequestInfo(String path, String method, int statusCode, @Nullable String userAgent) {
}
private final ClientReleaseManager clientReleaseManager;
private final Set<String> servletPaths;
// Use the same counter namespace as MetricsRequestEventListener for continuity
public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME;
public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME;
@VisibleForTesting
static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String URI_INFO_PROPERTY_NAME = MetricsHttpChannelListener.class.getName() + ".uriInfo";
@VisibleForTesting
static final String PATH_TAG = "path";
@VisibleForTesting
static final String METHOD_TAG = "method";
@VisibleForTesting
static final String STATUS_CODE_TAG = "status";
@VisibleForTesting
static final String TRAFFIC_SOURCE_TAG = "trafficSource";
private final MeterRegistry meterRegistry;
public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set<String> servletPaths) {
this(Metrics.globalRegistry, clientReleaseManager, servletPaths);
}
@VisibleForTesting
MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager,
final Set<String> servletPaths) {
this.meterRegistry = meterRegistry;
this.clientReleaseManager = clientReleaseManager;
this.servletPaths = servletPaths;
}
public void configure(final Environment environment) {
// register as ContainerResponseFilter
environment.jersey().register(this);
// hook into lifecycle events, to react to the Connector being added
environment.lifecycle().addEventListener(this);
}
@Override
public void onRequestFailure(final Request request, final Throwable failure) {
if (logger.isDebugEnabled()) {
final RequestInfo requestInfo = getRequestInfo(request);
logger.debug("Request failure: {} {} ({}) [{}] ",
requestInfo.method(),
requestInfo.path(),
requestInfo.userAgent(),
requestInfo.statusCode(), failure);
}
}
@Override
public void onResponseFailure(Request request, Throwable failure) {
if (failure instanceof org.eclipse.jetty.io.EofException) {
// the client disconnected early
return;
}
final RequestInfo requestInfo = getRequestInfo(request);
logger.warn("Response failure: {} {} ({}) [{}] ",
requestInfo.method(),
requestInfo.path(),
requestInfo.userAgent(),
requestInfo.statusCode(), failure);
}
@Override
public void onComplete(final Request request) {
final RequestInfo requestInfo = getRequestInfo(request);
final List<Tag> tags = new ArrayList<>(5);
tags.add(Tag.of(PATH_TAG, requestInfo.path()));
tags.add(Tag.of(METHOD_TAG, requestInfo.method()));
tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(requestInfo.statusCode())));
tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()));
tags.addAll(UserAgentTagUtil.getLibsignalAndPlatformTags(requestInfo.userAgent()));
final Optional<Tag> maybeClientVersionTag =
UserAgentTagUtil.getClientVersionTag(requestInfo.userAgent, clientReleaseManager);
maybeClientVersionTag.ifPresent(tags::add);
meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment();
meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(request.getResponse().getContentCount());
meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(request.getContentRead());
maybeClientVersionTag.ifPresent(clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME,
Tags.of(clientVersionTag, UserAgentTagUtil.getPlatformTag(requestInfo.userAgent)))
.increment());
}
@Override
public void beanAdded(final Container parent, final Object child) {
if (child instanceof Connector connector) {
connector.addBean(this);
}
}
@Override
public void beanRemoved(final Container parent, final Object child) {
}
@Override
public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext)
throws IOException {
requestContext.setProperty(URI_INFO_PROPERTY_NAME, requestContext.getUriInfo());
}
private RequestInfo getRequestInfo(Request request) {
final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME))
.map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr))
.orElseGet(() ->
Optional.ofNullable(request.getPathInfo())
.filter(servletPaths::contains)
.orElse("unknown")
);
final String method = Optional.ofNullable(request.getMethod()).orElse("unknown");
// Response cannot be null, but its status might not always reflect an actual response status, since it gets
// initialized to 200
final int status = request.getResponse().getStatus();
@Nullable final String userAgent = request.getHeader(HttpHeaders.USER_AGENT);
return new RequestInfo(path, method, status, userAgent);
}
}

View File

@@ -1,250 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import jakarta.validation.constraints.NotNull;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.container.ContainerResponseContext;
import jakarta.ws.rs.container.ContainerResponseFilter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.EventsHandler;
import org.eclipse.jetty.util.component.LifeCycle;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil;
/**
* Gathers and reports HTTP request metrics at the Jetty container level, which sits above Jersey. In order to get
* templated Jersey request paths, it adds a {@link jakarta.ws.rs.container.ContainerResponseFilter}, in order to give
* itself access to the template. It is limited to {@link TrafficSource#HTTP} requests.
*
* @see MetricsRequestEventListener
*/
public class MetricsHttpEventHandler extends EventsHandler {
private static final Logger logger = LoggerFactory.getLogger(MetricsHttpEventHandler.class);
private final ClientReleaseManager clientReleaseManager;
private final Set<String> servletPaths;
// Use the same counter namespace as MetricsRequestEventListener for continuity
public static final String REQUEST_COUNTER_NAME = MetricsRequestEventListener.REQUEST_COUNTER_NAME;
public static final String REQUESTS_BY_VERSION_COUNTER_NAME = MetricsRequestEventListener.REQUESTS_BY_VERSION_COUNTER_NAME;
@VisibleForTesting
static final String RESPONSE_BYTES_COUNTER_NAME = MetricsRequestEventListener.RESPONSE_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String REQUEST_BYTES_COUNTER_NAME = MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME;
@VisibleForTesting
static final String REQUEST_INFO_PROPERTY_NAME = MetricsHttpEventHandler.class.getName() + ".requestInfo";
@VisibleForTesting
static final String PATH_TAG = "path";
@VisibleForTesting
static final String METHOD_TAG = "method";
@VisibleForTesting
static final String STATUS_CODE_TAG = "status";
@VisibleForTesting
static final String TRAFFIC_SOURCE_TAG = "trafficSource";
private final MeterRegistry meterRegistry;
@VisibleForTesting
MetricsHttpEventHandler(
final Handler handler,
final MeterRegistry meterRegistry,
final ClientReleaseManager clientReleaseManager,
final Set<String> servletPaths) {
super(handler);
this.meterRegistry = meterRegistry;
this.clientReleaseManager = clientReleaseManager;
this.servletPaths = servletPaths;
}
/**
* Configure a {@link MetricsHttpEventHandler}
*
* @param environment A dropwizard {@link org.eclipse.jetty.util.component.Environment}
* @param meterRegistry The meter registry to register metrics with
* @param clientReleaseManager A {@link ClientReleaseManager} that determines what tags to include with metrics
* @param servletPaths An allow-list of paths to include in metric tags for requests that are handled by above
* Jersey
*/
public static void configure(final Environment environment, final MeterRegistry meterRegistry,
final ClientReleaseManager clientReleaseManager, final Set<String> servletPaths) {
// register a filter that will set the initial request info
environment.jersey().register(new SetInfoRequestFilter());
// hook into lifecycle events, to react to the Connector being added
environment.lifecycle().addEventListener(new LifeCycle.Listener() {
@Override
public void lifeCycleStarting(LifeCycle event) {
if (event instanceof Server server) {
server.setHandler(
new MetricsHttpEventHandler(server.getHandler(), meterRegistry, clientReleaseManager, servletPaths));
}
}
});
}
private void onResponseFailure(Request request, int status, Throwable failure) {
if (failure instanceof org.eclipse.jetty.io.EofException) {
// the client disconnected early
return;
}
final RequestInfo requestInfo = getRequestInfo(request);
logger.warn("Response failure: {} {} ({}) [{}] ",
requestInfo.method,
requestInfo.path,
requestInfo.userAgent,
status,
failure);
}
@Override
public void onComplete(Request request, int status, HttpFields headers, Throwable failure) {
super.onComplete(request, status, headers, failure);
if (failure != null) {
onResponseFailure(request, status, failure);
}
final RequestInfo requestInfo = getRequestInfo(request);
final List<Tag> tags = new ArrayList<>(5);
tags.add(Tag.of(PATH_TAG, requestInfo.path));
tags.add(Tag.of(METHOD_TAG, requestInfo.method));
tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(status)));
tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()));
tags.addAll(UserAgentTagUtil.getLibsignalAndPlatformTags(requestInfo.userAgent));
final Optional<Tag> maybeClientVersionTag =
UserAgentTagUtil.getClientVersionTag(requestInfo.userAgent, clientReleaseManager);
maybeClientVersionTag.ifPresent(tags::add);
meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment();
meterRegistry.counter(RESPONSE_BYTES_COUNTER_NAME, tags).increment(requestInfo.responseBytes);
meterRegistry.counter(REQUEST_BYTES_COUNTER_NAME, tags).increment(requestInfo.requestBytes);
maybeClientVersionTag.ifPresent(clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME,
Tags.of(clientVersionTag, UserAgentTagUtil.getPlatformTag(requestInfo.userAgent)))
.increment());
}
@Override
protected void onRequestRead(final Request request, final Content.Chunk chunk) {
super.onRequestRead(request, chunk);
if (chunk != null) {
getRequestInfo(request).requestBytes += chunk.remaining();
}
}
@Override
protected void onResponseWrite(final Request request, final boolean last, final ByteBuffer content) {
super.onResponseWrite(request, last, content);
if (content != null) {
getRequestInfo(request).responseBytes += content.remaining();
}
}
private RequestInfo getRequestInfo(Request request) {
Object obj = request.getAttribute(REQUEST_INFO_PROPERTY_NAME);
if (obj != null && obj instanceof RequestInfo requestInfo) {
return requestInfo;
}
// Our ContainerResponseFilter has not run yet. It should eventually run, and will override the path we set here.
// It may not run if this is a websocket upgrade request, a request handled by jetty directly, or a higher priority
// filter aborted the request by throwing an exception, in which case we'll use this path. To avoid giving every
// incorrect path a unique tag we check against a configured list of paths that we know would skip the filter.
final RequestInfo newInfo = new RequestInfo(
Optional.ofNullable(request.getHttpURI().getPath()).filter(servletPaths::contains).orElse("unknown"),
Optional.ofNullable(request.getMethod()).orElse("unknown"),
request.getHeaders().get(HttpHeaders.USER_AGENT));
request.setAttribute(REQUEST_INFO_PROPERTY_NAME, newInfo);
return newInfo;
}
@VisibleForTesting
static class RequestInfo {
private String path;
private final String method;
private final @Nullable String userAgent;
private long requestBytes;
private long responseBytes;
RequestInfo(@NotNull String path, @NotNull String method, @Nullable String userAgent) {
this.path = path;
this.method = method;
this.userAgent = userAgent;
this.requestBytes = 0;
this.responseBytes = 0;
}
@Override
public boolean equals(final Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
RequestInfo that = (RequestInfo) o;
return requestBytes == that.requestBytes && responseBytes == that.responseBytes && Objects.equals(path, that.path)
&& Objects.equals(method, that.method) && Objects.equals(userAgent, that.userAgent);
}
@Override
public int hashCode() {
return Objects.hash(path, method, userAgent, requestBytes, responseBytes);
}
}
@VisibleForTesting
static class SetInfoRequestFilter implements ContainerResponseFilter {
@Override
public void filter(final ContainerRequestContext requestContext, final ContainerResponseContext responseContext) {
// Construct the templated URI path. If no matching path is found, this will be ""
final String path = UriInfoUtil.getPathTemplate((ExtendedUriInfo) requestContext.getUriInfo());
final Object obj = requestContext.getProperty(REQUEST_INFO_PROPERTY_NAME);
if (obj != null && obj instanceof RequestInfo requestInfo) {
requestInfo.path = path;
} else {
requestContext.setProperty(REQUEST_INFO_PROPERTY_NAME,
new RequestInfo(path, requestContext.getMethod(), requestContext.getHeaderString(HttpHeaders.USER_AGENT)));
}
}
}
}

View File

@@ -17,6 +17,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.eclipse.jetty.io.ArrayByteBufferPool;
import org.glassfish.jersey.server.ContainerResponse; import org.glassfish.jersey.server.ContainerResponse;
import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener; import org.glassfish.jersey.server.monitoring.RequestEventListener;
@@ -28,7 +29,7 @@ import org.whispersystems.websocket.WebSocketResourceProvider;
/** /**
* Gathers and reports request-level metrics for WebSocket traffic only. * Gathers and reports request-level metrics for WebSocket traffic only.
* For HTTP traffic, use {@link MetricsHttpEventHandler}. * For HTTP traffic, use {@link MetricsHttpChannelListener}.
*/ */
public class MetricsRequestEventListener implements RequestEventListener { public class MetricsRequestEventListener implements RequestEventListener {
@@ -63,7 +64,7 @@ public class MetricsRequestEventListener implements RequestEventListener {
this(trafficSource, Metrics.globalRegistry, clientReleaseManager); this(trafficSource, Metrics.globalRegistry, clientReleaseManager);
if (trafficSource == TrafficSource.HTTP) { if (trafficSource == TrafficSource.HTTP) {
logger.warn("Use {} for HTTP traffic", MetricsHttpEventHandler.class.getName()); logger.warn("Use {} for HTTP traffic", MetricsHttpChannelListener.class.getName());
} }
} }

View File

@@ -21,7 +21,7 @@ import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import org.eclipse.jetty.util.resource.PathResourceFactory; import org.eclipse.jetty.util.resource.Resource;
import org.eclipse.jetty.util.security.CertificateUtils; import org.eclipse.jetty.util.security.CertificateUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -37,7 +37,7 @@ public class TlsCertificateExpirationUtil {
final KeyStore keyStore; final KeyStore keyStore;
try { try {
keyStore = CertificateUtils.getKeyStore(new PathResourceFactory().newResource(keyStorePath), keyStoreType, keyStoreProvider, keyStore = CertificateUtils.getKeyStore(Resource.newResource(keyStorePath), keyStoreType, keyStoreProvider,
keyStorePassword); keyStorePassword);
} catch (Exception e) { } catch (Exception e) {

View File

@@ -0,0 +1,33 @@
package org.whispersystems.textsecuregcm.util;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.management.MBeanServer;
import java.lang.management.ManagementFactory;
import java.util.concurrent.ScheduledExecutorService;
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
public class JmxDumper implements Managed {
private static final Logger log = LoggerFactory.getLogger(JmxDumper.class);
private final ScheduledExecutorService executor;
public JmxDumper(final ScheduledExecutorService executor) {
this.executor = executor;
}
@Override
public void start() throws Exception {
// executor.schedule()
}
private void dump() {
MBeanServer mbs = ManagementFactory.getPlatformMBeanServer();
}
}

View File

@@ -8,14 +8,14 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader; import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.Optional;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedDevice> { public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedDevice> {
@@ -27,7 +27,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
} }
@Override @Override
public Optional<AuthenticatedDevice> authenticate(final JettyServerUpgradeRequest request) public Optional<AuthenticatedDevice> authenticate(final UpgradeRequest request)
throws InvalidCredentialsException { throws InvalidCredentialsException {
@Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); @Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);

View File

@@ -15,6 +15,7 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@@ -34,6 +35,7 @@ public class BufferingInterceptorIntegrationTest {
environment.jersey().register(testController); environment.jersey().register(testController);
environment.jersey().register(new BufferingInterceptor()); environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10)); environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-", 10));
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
} }
} }

View File

@@ -6,6 +6,7 @@ import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@@ -18,40 +19,39 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment; import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.nio.ByteBuffer;
import java.time.Duration; import java.time.Duration;
import java.util.EnumSet;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener; import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class ProvisioningTimeoutIntegrationTest { public class ProvisioningTimeoutIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
@@ -77,9 +77,9 @@ public class ProvisioningTimeoutIntegrationTest {
CompletableFuture<String> provisioningAddressFuture = new CompletableFuture<>(); CompletableFuture<String> provisioningAddressFuture = new CompletableFuture<>();
@Override @Override
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try { try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.REQUEST_MESSAGE
&& webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) { && webSocketMessage.getRequestMessage().getPath().equals("/v1/address")) {
MessageProtos.ProvisioningAddress provisioningAddress = MessageProtos.ProvisioningAddress provisioningAddress =
@@ -92,7 +92,7 @@ public class ProvisioningTimeoutIntegrationTest {
} catch (InvalidProtocolBufferException e) { } catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
super.onWebSocketBinary(payload, callback); super.onWebSocketBinary(payload, offset, length);
} }
} }
@@ -106,17 +106,21 @@ public class ProvisioningTimeoutIntegrationTest {
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration); new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5))); new ProvisioningConnectListener(mock(ProvisioningManager.class), scheduler, Duration.ofSeconds(5)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet = final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
REMOTE_ADDRESS_ATTRIBUTE_NAME); webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
container.addMapping("/websocket", webSocketServlet); final ServletRegistration.Dynamic websocketServlet = environment.servlets()
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter()); .addServlet("WebSocket", webSocketServlet);
}); websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
} }
} }

View File

@@ -9,6 +9,8 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment; import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET; import jakarta.ws.rs.GET;
import jakarta.ws.rs.PUT; import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
@@ -18,20 +20,19 @@ import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.MediaType;
import java.io.IOException; import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.util.EnumSet;
import java.util.Optional; import java.util.Optional;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync; import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@@ -40,7 +41,6 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class WebsocketResourceProviderIntegrationTest { public class WebsocketResourceProviderIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class); new DropwizardAppExtension<>(TestApplication.class);
@@ -72,6 +72,9 @@ public class WebsocketResourceProviderIntegrationTest {
new WebSocketEnvironment<>(environment, webSocketConfiguration); new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController); environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController); webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter()); webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class))); webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class)));
@@ -82,13 +85,15 @@ public class WebsocketResourceProviderIntegrationTest {
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet = final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
REMOTE_ADDRESS_ATTRIBUTE_NAME); webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
container.addMapping("/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
} }
} }

View File

@@ -15,8 +15,8 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;

View File

@@ -13,17 +13,20 @@ import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment; import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension; import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.ws.rs.GET; import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path; import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Client; import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.Context;
import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.URI; import java.net.URI;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.Principal; import java.security.Principal;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.EnumSet;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
@@ -32,15 +35,14 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.eclipse.jetty.util.HostPort; import org.eclipse.jetty.util.HostPort;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
@@ -53,7 +55,6 @@ import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFa
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class RemoteAddressFilterIntegrationTest { class RemoteAddressFilterIntegrationTest {
private static final String WEBSOCKET_PREFIX = "/websocket"; private static final String WEBSOCKET_PREFIX = "/websocket";
@@ -130,7 +131,7 @@ class RemoteAddressFilterIntegrationTest {
} }
} }
public static class ClientEndpoint implements Session.Listener.AutoDemanding { private static class ClientEndpoint implements WebSocketListener {
private final String requestPath; private final String requestPath;
private final CompletableFuture<byte[]> responseFuture; private final CompletableFuture<byte[]> responseFuture;
@@ -144,19 +145,22 @@ class RemoteAddressFilterIntegrationTest {
} }
@Override @Override
public void onWebSocketOpen(final Session session) { public void onWebSocketConnect(final Session session) {
final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath, final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath,
List.of("Accept: application/json"), List.of("Accept: application/json"),
Optional.empty()).toByteArray(); Optional.empty()).toByteArray();
try {
session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP); session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
} }
@Override @Override
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try { try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
assert 200 == webSocketMessage.getResponseMessage().getStatus(); assert 200 == webSocketMessage.getResponseMessage().getStatus();
@@ -202,6 +206,10 @@ class RemoteAddressFilterIntegrationTest {
public void run(final Configuration configuration, public void run(final Configuration configuration,
final Environment environment) throws Exception { final Environment environment) throws Exception {
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
environment.jersey().register(new TestRemoteAddressController()); environment.jersey().register(new TestRemoteAddressController());
// WebSocket set up // WebSocket set up
@@ -212,14 +220,15 @@ class RemoteAddressFilterIntegrationTest {
webSocketEnvironment.jersey().register(new TestWebSocketController()); webSocketEnvironment.jersey().register(new TestWebSocketController());
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), (servletContext, container) -> { environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
container.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH, webSocketServlet); .addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
} }
} }

View File

@@ -27,6 +27,7 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import jakarta.annotation.Priority; import jakarta.annotation.Priority;
import jakarta.servlet.DispatcherType;
import jakarta.ws.rs.GET; import jakarta.ws.rs.GET;
import jakarta.ws.rs.InternalServerErrorException; import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.NotAuthorizedException; import jakarta.ws.rs.NotAuthorizedException;
@@ -43,6 +44,7 @@ import java.io.IOException;
import java.net.URI; import java.net.URI;
import java.security.Principal; import java.security.Principal;
import java.time.Duration; import java.time.Duration;
import java.util.EnumSet;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@@ -52,28 +54,25 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import org.eclipse.jetty.ee10.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.server.Connector;
import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server; import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.server.handler.EventsHandler;
import org.eclipse.jetty.util.component.LifeCycle; import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.PriorityFilter;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.WebSocketResourceProviderFactory;
@@ -81,8 +80,7 @@ import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) class MetricsHttpChannelListenerIntegrationTest {
class MetricsHttpEventHandlerIntegrationTest {
private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP; private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP;
private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class); private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class);
@@ -92,7 +90,7 @@ class MetricsHttpEventHandlerIntegrationTest {
private static final AtomicReference<CountDownLatch> COUNT_DOWN_LATCH_FUTURE_REFERENCE = new AtomicReference<>(); private static final AtomicReference<CountDownLatch> COUNT_DOWN_LATCH_FUTURE_REFERENCE = new AtomicReference<>();
private static final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>( private static final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(
MetricsHttpEventHandlerIntegrationTest.TestApplication.class); MetricsHttpChannelListenerIntegrationTest.TestApplication.class);
@AfterEach @AfterEach
void teardown() { void teardown() {
@@ -113,9 +111,9 @@ class MetricsHttpEventHandlerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of( final Map<String, Counter> counterMap = Map.of(
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
); );
when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
@@ -149,7 +147,7 @@ class MetricsHttpEventHandlerIntegrationTest {
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment(); verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue(); final Iterable<Tag> tagIterable = tagCaptor.getValue();
@@ -160,11 +158,11 @@ class MetricsHttpEventHandlerIntegrationTest {
} }
assertEquals(6, tags.size()); assertEquals(6, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, expectedTagPath))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, expectedTagPath)));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(expectedStatus)))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(expectedStatus))));
assertTrue( assertTrue(
tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false")));
} }
@@ -196,19 +194,24 @@ class MetricsHttpEventHandlerIntegrationTest {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
final Map<String, Counter> counterMap = Map.of( final Map<String, Counter> counterMap = Map.of(
MetricsHttpEventHandler.REQUEST_COUNTER_NAME, REQUEST_COUNTER, MetricsHttpChannelListener.REQUEST_COUNTER_NAME, REQUEST_COUNTER,
MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER, MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME, RESPONSE_BYTES_COUNTER,
MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME, REQUEST_BYTES_COUNTER
); );
when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class))); .thenAnswer(a -> counterMap.getOrDefault(a.getArgument(0, String.class), mock(Counter.class)));
client.connect(new AutoClosingWebSocketSessionListener(), client.connect(new WebSocketListener() {
@Override
public void onWebSocketConnect(final Session session) {
session.close(1000, "OK");
}
},
URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest); URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest);
assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS)); assertTrue(countDownLatch.await(1000, TimeUnit.MILLISECONDS));
verify(METER_REGISTRY).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(REQUEST_COUNTER).increment(); verify(REQUEST_COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue(); final Iterable<Tag> tagIterable = tagCaptor.getValue();
@@ -219,11 +222,11 @@ class MetricsHttpEventHandlerIntegrationTest {
} }
assertEquals(6, tags.size()); assertEquals(6, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, "/v1/websocket"))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket")));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, "GET"))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(101)))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101))));
assertTrue( assertTrue(
tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false")));
} }
@@ -245,16 +248,17 @@ class MetricsHttpEventHandlerIntegrationTest {
public void run(final Configuration configuration, public void run(final Configuration configuration,
final Environment environment) throws Exception { final Environment environment) throws Exception {
MetricsHttpEventHandler.configure(environment, METER_REGISTRY, mock(ClientReleaseManager.class), Set.of("/v1/websocket")); final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(
METER_REGISTRY,
mock(ClientReleaseManager.class),
Set.of("/v1/websocket")
);
environment.lifecycle().addEventListener(new LifeCycle.Listener() { metricsHttpChannelListener.configure(environment);
@Override environment.lifecycle().addEventListener(new TestListener(COUNT_DOWN_LATCH_FUTURE_REFERENCE));
public void lifeCycleStarting(final LifeCycle event) {
if (event instanceof Server server) { environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter())
server.setHandler(new TestListener(server.getHandler(), COUNT_DOWN_LATCH_FUTURE_REFERENCE)); .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
}
}
});
environment.jersey().register(new TestResource()); environment.jersey().register(new TestResource());
environment.jersey().register(new TestAuthFilter()); environment.jersey().register(new TestAuthFilter());
@@ -267,15 +271,14 @@ class MetricsHttpEventHandlerIntegrationTest {
webSocketEnvironment.jersey().register(new TestResource()); webSocketEnvironment.jersey().register(new TestResource());
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), environment.servlets().addServlet("WebSocket", webSocketServlet)
(servletContext, container) -> { .addMapping("/v1/websocket");
container.addMapping("/v1/websocket", webSocketServlet);
PriorityFilter.ensureFilter(servletContext, new RemoteAddressFilter());
});
} }
} }
@@ -291,23 +294,36 @@ class MetricsHttpEventHandlerIntegrationTest {
} }
/** /**
* A simple listener to signal that {@link EventsHandler} has completed its work, since its onComplete() is on * A simple listener to signal that {@link HttpChannel.Listener} has completed its work, since its onComplete() is on
* a different thread from the one that sends the response, creating a race condition between the listener and the * a different thread from the one that sends the response, creating a race condition between the listener and the
* test assertions * test assertions
*/ */
static class TestListener extends EventsHandler { static class TestListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener {
private final AtomicReference<CountDownLatch> completableFutureAtomicReference; private final AtomicReference<CountDownLatch> completableFutureAtomicReference;
TestListener(final Handler handler, AtomicReference<CountDownLatch> countDownLatchReference) { TestListener(AtomicReference<CountDownLatch> countDownLatchReference) {
super(handler);
this.completableFutureAtomicReference = countDownLatchReference; this.completableFutureAtomicReference = countDownLatchReference;
} }
@Override @Override
public void onComplete(Request request, int status, HttpFields headers, Throwable failure) { public void onComplete(final Request request) {
completableFutureAtomicReference.get().countDown(); completableFutureAtomicReference.get().countDown();
} }
@Override
public void beanAdded(final Container parent, final Object child) {
if (child instanceof Connector connector) {
connector.addBean(this);
}
}
@Override
public void beanRemoved(final Container parent, final Object child) {
}
} }
@Path("/v1/test") @Path("/v1/test")
@@ -349,11 +365,4 @@ class MetricsHttpEventHandlerIntegrationTest {
} }
} }
public static class AutoClosingWebSocketSessionListener implements Session.Listener.AutoDemanding {
@Override
public void onWebSocketOpen(final Session session) {
session.close(1000, "OK", Callback.NOOP);
}
}
} }

View File

@@ -18,30 +18,23 @@ import com.google.common.net.HttpHeaders;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import java.nio.ByteBuffer; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpURI; import org.eclipse.jetty.http.HttpURI;
import org.eclipse.jetty.io.Content;
import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response; import org.eclipse.jetty.server.Response;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse;
import org.glassfish.jersey.server.ExtendedUriInfo; import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.uri.UriTemplate; import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
class MetricsHttpEventHandlerTest { class MetricsHttpChannelListenerTest {
private final static String USER_AGENT = "Signal-Android/6.53.7 (Android 8.1)";
private MeterRegistry meterRegistry; private MeterRegistry meterRegistry;
private Counter requestCounter; private Counter requestCounter;
@@ -49,7 +42,7 @@ class MetricsHttpEventHandlerTest {
private Counter responseBytesCounter; private Counter responseBytesCounter;
private Counter requestBytesCounter; private Counter requestBytesCounter;
private ClientReleaseManager clientReleaseManager; private ClientReleaseManager clientReleaseManager;
private MetricsHttpEventHandler listener; private MetricsHttpChannelListener listener;
@BeforeEach @BeforeEach
void setup() { void setup() {
@@ -59,28 +52,27 @@ class MetricsHttpEventHandlerTest {
responseBytesCounter = mock(Counter.class); responseBytesCounter = mock(Counter.class);
requestBytesCounter = mock(Counter.class); requestBytesCounter = mock(Counter.class);
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestCounter); .thenReturn(requestCounter);
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestsByVersionCounter); .thenReturn(requestsByVersionCounter);
when(meterRegistry.counter(eq(MetricsHttpEventHandler.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsHttpChannelListener.RESPONSE_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(responseBytesCounter); .thenReturn(responseBytesCounter);
when(meterRegistry.counter(eq(MetricsHttpEventHandler.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter); .thenReturn(requestBytesCounter);
clientReleaseManager = mock(ClientReleaseManager.class); clientReleaseManager = mock(ClientReleaseManager.class);
listener = new MetricsHttpEventHandler(null, meterRegistry, clientReleaseManager, Set.of("/test")); listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet());
} }
@CartesianTest @ParameterizedTest
@ValueSource(booleans = {true, false})
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
void testRequests(@CartesianTest.Values(booleans = {true, false}) final boolean pathFromFilter, void testRequests(final boolean versionActive) {
@CartesianTest.Values(booleans = {true, false}) final boolean versionActive) {
final String path = "/test"; final String path = "/test";
final String method = "GET"; final String method = "GET";
final int statusCode = 200; final int statusCode = 200;
@@ -90,40 +82,30 @@ class MetricsHttpEventHandlerTest {
final Request request = mock(Request.class); final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method); when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
final HttpFields.Mutable requestHeaders = HttpFields.build();
requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT);
when(request.getHeaders()).thenReturn(requestHeaders);
when(request.getHttpURI()).thenReturn(httpUri); when(request.getHttpURI()).thenReturn(httpUri);
if (pathFromFilter) {
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT));
} else {
when(request.setAttribute(eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME), any())).thenAnswer(invocation -> {
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(invocation.getArgument(1));
return null;
});
}
final Response response = mock(Response.class); final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode); when(response.getStatus()).thenReturn(statusCode);
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive); when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive);
when(response.getContentCount()).thenReturn(1024L);
when(request.getResponse()).thenReturn(response);
when(request.getContentRead()).thenReturn(512L);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true)); listener.onComplete(request);
listener.onResponseWrite(request, true, ByteBuffer.allocate(1024));
listener.onComplete(request, statusCode, requestHeaders, null);
verify(requestCounter).increment(); verify(requestCounter).increment();
verify(responseBytesCounter).increment(1024L); verify(responseBytesCounter).increment(1024L);
verify(requestBytesCounter).increment(512L); verify(requestBytesCounter).increment(512L);
verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
final Set<Tag> tags = new HashSet<>(); final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagCaptor.getValue()) { for (final Tag tag : tagCaptor.getValue()) {
@@ -131,11 +113,11 @@ class MetricsHttpEventHandlerTest {
} }
assertEquals(versionActive ? 7 : 6, tags.size()); assertEquals(versionActive ? 7 : 6, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.PATH_TAG, path))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, path)));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.METHOD_TAG, method))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, method)));
assertTrue(tags.contains(Tag.of(MetricsHttpEventHandler.STATUS_CODE_TAG, String.valueOf(statusCode)))); assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(statusCode))));
assertTrue( assertTrue(
tags.contains(Tag.of(MetricsHttpEventHandler.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()))); tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.LIBSIGNAL_TAG, "false")));
assertEquals(versionActive, tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "6.53.7"))); assertEquals(versionActive, tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "6.53.7")));
@@ -155,22 +137,23 @@ class MetricsHttpEventHandlerTest {
final Request request = mock(Request.class); final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method); when(request.getMethod()).thenReturn(method);
final HttpFields.Mutable requestHeaders = HttpFields.build(); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
requestHeaders.put(HttpHeader.USER_AGENT, USER_AGENT);
when(request.getHeaders()).thenReturn(requestHeaders);
when(request.getHttpURI()).thenReturn(httpUri); when(request.getHttpURI()).thenReturn(httpUri);
when(request.getAttribute(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME))
.thenReturn(new MetricsHttpEventHandler.RequestInfo(path, method, USER_AGENT));
final Response response = mock(Response.class); final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode); when(response.getStatus()).thenReturn(statusCode);
listener.onRequestRead(request, Content.Chunk.from(ByteBuffer.allocate(512), true)); when(response.getContentCount()).thenReturn(1024L);
listener.onResponseWrite(request, true, ByteBuffer.allocate(1024)); when(request.getResponse()).thenReturn(response);
listener.onComplete(request, statusCode, requestHeaders, null); when(request.getContentRead()).thenReturn(512L);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
listener.onComplete(request);
if (versionActive) { if (versionActive) {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).counter(eq(MetricsHttpEventHandler.REQUESTS_BY_VERSION_COUNTER_NAME), verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME),
tagCaptor.capture()); tagCaptor.capture());
final Set<Tag> tags = new HashSet<>(); final Set<Tag> tags = new HashSet<>();
tags.clear(); tags.clear();
@@ -184,38 +167,7 @@ class MetricsHttpEventHandlerTest {
} else { } else {
verifyNoInteractions(requestsByVersionCounter); verifyNoInteractions(requestsByVersionCounter);
} }
}
@Test
void testResponseFilterSetsRequestInfo() {
final ContainerRequest request = mock(ContainerRequest.class);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test")));
when(request.getMethod()).thenReturn("GET");
when(request.getHeaders()).thenReturn(null);
when(request.getUriInfo()).thenReturn(extendedUriInfo);
when(request.getHeaderString(HttpHeaders.USER_AGENT)).thenReturn(USER_AGENT);
new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class));
verify(request).setProperty(
eq(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME),
eq(new MetricsHttpEventHandler.RequestInfo("/test", "GET", USER_AGENT)));
}
@Test
void testResponseFilterModifiesRequestInfo() {
final MetricsHttpEventHandler.RequestInfo requestInfo =
new MetricsHttpEventHandler.RequestInfo("unknown", "POST", USER_AGENT);
final ContainerRequest request = mock(ContainerRequest.class);
when(request.getProperty(MetricsHttpEventHandler.REQUEST_INFO_PROPERTY_NAME)).thenReturn(requestInfo);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate("/test")));
when(request.getUriInfo()).thenReturn(extendedUriInfo);
new MetricsHttpEventHandler.SetInfoRequestFilter().filter(request, mock(ContainerResponse.class));
assertEquals(new MetricsHttpEventHandler.RequestInfo("/test", "POST", USER_AGENT), requestInfo);
} }
} }

View File

@@ -36,9 +36,10 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse; import org.glassfish.jersey.server.ContainerResponse;
@@ -169,9 +170,11 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class); final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class); final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)"); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)");
when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of("Signal-Android/4.53.7 (Android 8.1)"))); when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of("Signal-Android/4.53.7 (Android 8.1)")));
@@ -183,15 +186,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter); .thenReturn(requestBytesCounter);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -235,9 +238,11 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class); final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class); final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(
@@ -247,15 +252,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter); .thenReturn(requestBytesCounter);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -302,9 +307,11 @@ class MetricsRequestEventListenerTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class); final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
final UpgradeRequest request = mock(UpgradeRequest.class); final UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn( when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(
@@ -314,15 +321,15 @@ class MetricsRequestEventListenerTest {
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class))) when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_BYTES_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestBytesCounter); .thenReturn(requestBytesCounter);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", final byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); final ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);

View File

@@ -149,13 +149,15 @@ class TlsCertificateExpirationUtilTest {
@Test @Test
void test() throws Exception { void test() throws Exception {
final Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64); try (Resource keystore = TestResource.fromBase64Mime("keystore", KEYSTORE_BASE64)) {
final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD);
final Map<String, Instant> expected = Map.of( final KeyStore keyStore = CertificateUtils.getKeyStore(keystore, "PKCS12", null, KEYSTORE_PASSWORD);
"localhost:EdDSA", EDDSA_EXPIRATION,
"localhost:RSA", RSA_EXPIRATION); final Map<String, Instant> expected = Map.of(
assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD)); "localhost:EdDSA", EDDSA_EXPIRATION,
"localhost:RSA", RSA_EXPIRATION);
assertEquals(expected, TlsCertificateExpirationUtil.getIdentifiersAndExpirations(keyStore, KEYSTORE_PASSWORD));
}
} }
} }

View File

@@ -4,6 +4,14 @@
*/ */
package org.whispersystems.textsecuregcm.tests.util; package org.whispersystems.textsecuregcm.tests.util;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@@ -11,14 +19,8 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
public class TestWebsocketListener implements Session.Listener.AutoDemanding { public class TestWebsocketListener implements WebSocketListener {
private final AtomicLong requestId = new AtomicLong(); private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>(); private final CompletableFuture<Session> started = new CompletableFuture<>();
@@ -32,7 +34,7 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding {
@Override @Override
public void onWebSocketOpen(final Session session) { public void onWebSocketConnect(final Session session) {
started.complete(session); started.complete(session);
} }
@@ -61,15 +63,19 @@ public class TestWebsocketListener implements Session.Listener.AutoDemanding {
responseFutures.put(id, future); responseFutures.put(id, future);
final byte[] requestBytes = messageFactory.createRequest( final byte[] requestBytes = messageFactory.createRequest(
Optional.of(id), verb, requestPath, headers, body).toByteArray(); Optional.of(id), verb, requestPath, headers, body).toByteArray();
session.sendBinary(ByteBuffer.wrap(requestBytes), Callback.NOOP); try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
return future; return future;
}); });
} }
@Override @Override
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try { try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId()) responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage()); .complete(webSocketMessage.getResponseMessage());

View File

@@ -6,9 +6,12 @@
package org.whispersystems.textsecuregcm.util.jetty; package org.whispersystems.textsecuregcm.util.jetty;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URI; import java.net.URI;
import java.nio.file.Path; import java.nio.channels.ReadableByteChannel;
import java.util.Base64; import java.util.Base64;
import org.eclipse.jetty.util.resource.Resource; import org.eclipse.jetty.util.resource.Resource;
@@ -27,15 +30,14 @@ public class TestResource extends Resource {
} }
@Override @Override
public Path getPath() { public boolean isContainedIn(final Resource r) throws MalformedURLException {
return null; return false;
} }
@Override @Override
public InputStream newInputStream() { public void close() {
return new ByteArrayInputStream(data);
}
}
@Override @Override
public boolean exists() { public boolean exists() {
@@ -48,13 +50,13 @@ public class TestResource extends Resource {
} }
@Override @Override
public boolean isReadable() { public long lastModified() {
return true; return 0;
} }
@Override @Override
public long length() { public long length() {
return data.length; return 0;
} }
@Override @Override
@@ -62,19 +64,43 @@ public class TestResource extends Resource {
return null; return null;
} }
@Override
public File getFile() throws IOException {
return null;
}
@Override @Override
public String getName() { public String getName() {
return name; return name;
} }
@Override @Override
public String getFileName() { public InputStream getInputStream() throws IOException {
return ""; return new ByteArrayInputStream(data);
} }
@Override @Override
public Resource resolve(final String subUriPath) { public ReadableByteChannel getReadableByteChannel() throws IOException {
return null; return null;
} }
@Override
public boolean delete() throws SecurityException {
return false;
}
@Override
public boolean renameTo(final Resource dest) throws SecurityException {
return false;
}
@Override
public String[] list() {
return new String[]{name};
}
@Override
public Resource addPath(final String path) throws IOException, MalformedURLException {
return this;
}
} }

View File

@@ -39,9 +39,10 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.server.ServerProperties;
@@ -142,12 +143,12 @@ class LoggingUnhandledExceptionMapperTest {
WebSocketResourceProvider<TestPrincipal> provider = createWebsocketProvider(userAgentHeader, session, WebSocketResourceProvider<TestPrincipal> provider = createWebsocketProvider(userAgentHeader, session,
responseFuture::complete); responseFuture::complete);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory() byte[] message = new ProtobufWebSocketMessageFactory()
.createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray(); .createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
responseFuture.get(1, TimeUnit.SECONDS); responseFuture.get(1, TimeUnit.SECONDS);
@@ -178,13 +179,15 @@ class LoggingUnhandledExceptionMapperTest {
TestPrincipal.authenticatedTestPrincipal("foo"), TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
doAnswer(answer -> { doAnswer(answer -> {
responseHandler.accept(answer.getArgument(0, ByteBuffer.class)); responseHandler.accept(answer.getArgument(0, ByteBuffer.class));
return null; return null;
}).when(session).sendBinary(any(ByteBuffer.class), any(Callback.class)); }).when(remoteEndpoint).sendBytes(any(), any(WriteCallback.class));
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn(userAgentHeader); when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn(userAgentHeader);
when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of(userAgentHeader))); when(request.getHeaders()).thenReturn(Map.of(HttpHeaders.USER_AGENT, List.of(userAgentHeader)));

View File

@@ -19,7 +19,7 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
@@ -44,7 +44,7 @@ class WebSocketAccountAuthenticatorTest {
private AccountAuthenticator accountAuthenticator; private AccountAuthenticator accountAuthenticator;
private JettyServerUpgradeRequest upgradeRequest; private UpgradeRequest upgradeRequest;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
@@ -56,7 +56,7 @@ class WebSocketAccountAuthenticatorTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty()); .thenReturn(Optional.empty());
upgradeRequest = mock(JettyServerUpgradeRequest.class); upgradeRequest = mock(UpgradeRequest.class);
} }
@ParameterizedTest @ParameterizedTest

View File

@@ -13,19 +13,15 @@
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.eclipse.jetty.websocket</groupId> <groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-websocket-jetty-api</artifactId> <artifactId>websocket-jetty-api</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.eclipse.jetty.websocket</groupId> <groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-websocket-jetty-server</artifactId> <artifactId>websocket-jetty-server</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.eclipse.jetty.ee10.websocket</groupId> <groupId>org.eclipse.jetty.websocket</groupId>
<artifactId>jetty-ee10-websocket-jetty-server</artifactId> <artifactId>websocket-servlet</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty.ee10.websocket</groupId>
<artifactId>jetty-ee10-websocket-servlet</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>io.dropwizard</groupId> <groupId>io.dropwizard</groupId>

View File

@@ -12,8 +12,9 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.exceptions.WebSocketException; import org.eclipse.jetty.websocket.api.exceptions.WebSocketException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -28,13 +29,15 @@ public class WebSocketClient {
private static final SecureRandom SECURE_RANDOM = new SecureRandom(); private static final SecureRandom SECURE_RANDOM = new SecureRandom();
private final Session session; private final Session session;
private final RemoteEndpoint remoteEndpoint;
private final WebSocketMessageFactory messageFactory; private final WebSocketMessageFactory messageFactory;
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper; private final Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper;
private final Instant created; private final Instant created;
public WebSocketClient(Session session, WebSocketMessageFactory messageFactory, public WebSocketClient(Session session, RemoteEndpoint remoteEndpoint, WebSocketMessageFactory messageFactory,
Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) { Map<Long, CompletableFuture<WebSocketResponseMessage>> pendingRequestMapper) {
this.session = session; this.session = session;
this.remoteEndpoint = remoteEndpoint;
this.messageFactory = messageFactory; this.messageFactory = messageFactory;
this.pendingRequestMapper = pendingRequestMapper; this.pendingRequestMapper = pendingRequestMapper;
this.created = Instant.now(); this.created = Instant.now();
@@ -52,9 +55,9 @@ public class WebSocketClient {
WebSocketMessage requestMessage = messageFactory.createRequest(Optional.of(requestId), verb, path, headers, body); WebSocketMessage requestMessage = messageFactory.createRequest(Optional.of(requestId), verb, path, headers, body);
try { try {
session.sendBinary(ByteBuffer.wrap(requestMessage.toByteArray()), new Callback() { remoteEndpoint.sendBytes(ByteBuffer.wrap(requestMessage.toByteArray()), new WriteCallback() {
@Override @Override
public void fail(Throwable x) { public void writeFailed(Throwable x) {
logger.debug("Write failed", x); logger.debug("Write failed", x);
pendingRequestMapper.remove(requestId); pendingRequestMapper.remove(requestId);
future.completeExceptionally(x); future.completeExceptionally(x);
@@ -82,9 +85,9 @@ public class WebSocketClient {
} }
public void close(final int code, final String message) { public void close(final int code, final String message) {
session.close(code, message, new Callback() { session.close(code, message, new WriteCallback() {
@Override @Override
public void fail(final Throwable throwable) { public void writeFailed(final Throwable throwable) {
try { try {
session.disconnect(); session.disconnect();
} catch (final Exception e) { } catch (final Exception e) {

View File

@@ -24,8 +24,10 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.exceptions.MessageTooLargeException; import org.eclipse.jetty.websocket.api.exceptions.MessageTooLargeException;
import org.glassfish.jersey.internal.MapPropertiesDelegate; import org.glassfish.jersey.internal.MapPropertiesDelegate;
import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ApplicationHandler;
@@ -45,7 +47,7 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketResourceProvider<T extends Principal> implements Session.Listener.AutoDemanding { public class WebSocketResourceProvider<T extends Principal> implements WebSocketListener {
/** /**
* A static exception instance passed to outstanding requests (via {@code completeExceptionally} in * A static exception instance passed to outstanding requests (via {@code completeExceptionally} in
@@ -66,6 +68,7 @@ public class WebSocketResourceProvider<T extends Principal> implements Session.L
private final String remoteAddressPropertyName; private final String remoteAddressPropertyName;
private Session session; private Session session;
private RemoteEndpoint remoteEndpoint;
private WebSocketSessionContext context; private WebSocketSessionContext context;
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade"); private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
@@ -89,10 +92,11 @@ public class WebSocketResourceProvider<T extends Principal> implements Session.L
} }
@Override @Override
public void onWebSocketOpen(Session session) { public void onWebSocketConnect(Session session) {
this.session = session; this.session = session;
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext( this.context = new WebSocketSessionContext(
new WebSocketClient(session, messageFactory, requestMap)); new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(reusableAuth.orElse(null)); this.context.setAuthenticated(reusableAuth.orElse(null));
this.session.setIdleTimeout(idleTimeout); this.session.setIdleTimeout(idleTimeout);
@@ -117,9 +121,9 @@ public class WebSocketResourceProvider<T extends Principal> implements Session.L
} }
@Override @Override
public void onWebSocketBinary(final ByteBuffer payload, final Callback callback) { public void onWebSocketBinary(byte[] payload, int offset, int length) {
try { try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload); WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
switch (webSocketMessage.getType()) { switch (webSocketMessage.getType()) {
case REQUEST_MESSAGE: case REQUEST_MESSAGE:
@@ -254,7 +258,7 @@ public class WebSocketResourceProvider<T extends Principal> implements Session.L
} }
private void close(Session session, int status, String message) { private void close(Session session, int status, String message) {
session.close(status, message, Callback.NOOP); session.close(status, message);
} }
private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response, private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response,
@@ -273,7 +277,7 @@ public class WebSocketResourceProvider<T extends Principal> implements Session.L
Optional.ofNullable(body)) Optional.ofNullable(body))
.toByteArray(); .toByteArray();
session.sendBinary(ByteBuffer.wrap(responseBytes), Callback.NOOP); remoteEndpoint.sendBytes(ByteBuffer.wrap(responseBytes), WriteCallback.NOOP);
} }
} }
@@ -285,7 +289,7 @@ public class WebSocketResourceProvider<T extends Principal> implements Session.L
getHeaderList(error.getStringHeaders()), getHeaderList(error.getStringHeaders()),
Optional.empty()); Optional.empty());
session.sendBinary(ByteBuffer.wrap(response.toByteArray()), Callback.NOOP); remoteEndpoint.sendBytes(ByteBuffer.wrap(response.toByteArray()), WriteCallback.NOOP);
} }
} }

View File

@@ -13,9 +13,11 @@ import java.security.Principal;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.ee10.websocket.server.JettyWebSocketCreator; import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
import org.eclipse.jetty.websocket.server.JettyWebSocketServlet;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.glassfish.jersey.CommonProperties; import org.glassfish.jersey.CommonProperties;
import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ApplicationHandler;
import org.slf4j.Logger; import org.slf4j.Logger;
@@ -23,20 +25,23 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactory<T extends Principal> implements JettyWebSocketCreator { public class WebSocketResourceProviderFactory<T extends Principal> extends JettyWebSocketServlet implements
JettyWebSocketCreator {
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class); private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class);
private final WebSocketEnvironment<T> environment; private final WebSocketEnvironment<T> environment;
private final ApplicationHandler jerseyApplicationHandler; private final ApplicationHandler jerseyApplicationHandler;
private final WebSocketConfiguration configuration;
private final String remoteAddressPropertyName; private final String remoteAddressPropertyName;
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass, public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
String remoteAddressPropertyName) { WebSocketConfiguration configuration, String remoteAddressPropertyName) {
this.environment = environment; this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
@@ -50,6 +55,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> implements Je
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;
this.remoteAddressPropertyName = remoteAddressPropertyName; this.remoteAddressPropertyName = remoteAddressPropertyName;
} }
@@ -83,7 +89,6 @@ public class WebSocketResourceProviderFactory<T extends Principal> implements Je
// Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database). // Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database).
// If that happens, we don't want to incorrectly tell clients that they provided bad credentials. // If that happens, we don't want to incorrectly tell clients that they provided bad credentials.
logger.warn("Authentication failure", e); logger.warn("Authentication failure", e);
try { try {
response.sendError(500, "Failure"); response.sendError(500, "Failure");
} catch (final IOException ignored) { } catch (final IOException ignored) {
@@ -92,6 +97,13 @@ public class WebSocketResourceProviderFactory<T extends Principal> implements Je
} }
} }
@Override
public void configure(JettyWebSocketServletFactory factory) {
factory.setCreator(this);
factory.setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize());
factory.setMaxTextMessageSize(configuration.getMaxTextMessageSize());
}
private String getRemoteAddress(JettyServerUpgradeRequest request) { private String getRemoteAddress(JettyServerUpgradeRequest request) {
final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName); final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName);
if (StringUtils.isBlank(remoteAddress)) { if (StringUtils.isBlank(remoteAddress)) {

View File

@@ -7,8 +7,8 @@ package org.whispersystems.websocket.auth;
import java.security.Principal; import java.security.Principal;
import java.util.Optional; import java.util.Optional;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
public interface AuthenticatedWebSocketUpgradeFilter<T extends Principal> { public interface AuthenticatedWebSocketUpgradeFilter<T extends Principal> {

View File

@@ -6,7 +6,7 @@ package org.whispersystems.websocket.auth;
import java.security.Principal; import java.security.Principal;
import java.util.Optional; import java.util.Optional;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
public interface WebSocketAuthenticator<T extends Principal> { public interface WebSocketAuthenticator<T extends Principal> {
@@ -20,5 +20,5 @@ public interface WebSocketAuthenticator<T extends Principal> {
* *
* @throws InvalidCredentialsException if credentials were provided, but could not be authenticated * @throws InvalidCredentialsException if credentials were provided, but could not be authenticated
*/ */
Optional<T> authenticate(JettyServerUpgradeRequest request) throws InvalidCredentialsException; Optional<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
} }

View File

@@ -5,23 +5,22 @@
package org.whispersystems.websocket.messages; package org.whispersystems.websocket.messages;
import java.nio.ByteBuffer;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public interface WebSocketMessageFactory { public interface WebSocketMessageFactory {
WebSocketMessage parseMessage(ByteBuffer serialized) public WebSocketMessage parseMessage(byte[] serialized, int offset, int len)
throws InvalidMessageException; throws InvalidMessageException;
WebSocketMessage createRequest(Optional<Long> requestId, public WebSocketMessage createRequest(Optional<Long> requestId,
String verb, String path, String verb, String path,
List<String> headers, List<String> headers,
Optional<byte[]> body); Optional<byte[]> body);
WebSocketMessage createResponse(long requestId, int status, String message, public WebSocketMessage createResponse(long requestId, int status, String message,
List<String> headers, List<String> headers,
Optional<byte[]> body); Optional<byte[]> body);
} }

View File

@@ -10,15 +10,14 @@ import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketRequestMessage; import org.whispersystems.websocket.messages.WebSocketRequestMessage;
import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import java.nio.ByteBuffer;
public class ProtobufWebSocketMessage implements WebSocketMessage { public class ProtobufWebSocketMessage implements WebSocketMessage {
private final SubProtocol.WebSocketMessage message; private final SubProtocol.WebSocketMessage message;
ProtobufWebSocketMessage(ByteBuffer buffer) throws InvalidMessageException { ProtobufWebSocketMessage(byte[] buffer, int offset, int length) throws InvalidMessageException {
try { try {
this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer)); this.message = SubProtocol.WebSocketMessage.parseFrom(ByteString.copyFrom(buffer, offset, length));
if (getType() == Type.REQUEST_MESSAGE) { if (getType() == Type.REQUEST_MESSAGE) {
if (!message.getRequest().hasVerb() || !message.getRequest().hasPath()) { if (!message.getRequest().hasVerb() || !message.getRequest().hasPath()) {

View File

@@ -9,17 +9,16 @@ import org.whispersystems.websocket.messages.InvalidMessageException;
import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory; import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import java.nio.ByteBuffer;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
public class ProtobufWebSocketMessageFactory implements WebSocketMessageFactory { public class ProtobufWebSocketMessageFactory implements WebSocketMessageFactory {
@Override @Override
public WebSocketMessage parseMessage(ByteBuffer serialized) public WebSocketMessage parseMessage(byte[] serialized, int offset, int len)
throws InvalidMessageException throws InvalidMessageException
{ {
return new ProtobufWebSocketMessage(serialized); return new ProtobufWebSocketMessage(serialized, offset, len);
} }
@Override @Override

View File

@@ -19,15 +19,17 @@ import java.io.IOException;
import java.security.Principal; import java.security.Principal;
import java.util.Optional; import java.util.Optional;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.InvalidCredentialsException; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactoryTest { public class WebSocketResourceProviderFactoryTest {
@@ -58,7 +60,8 @@ public class WebSocketResourceProviderFactoryTest {
when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException()); when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException());
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class, REMOTE_ADDRESS_PROPERTY_NAME); WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response); Object connection = factory.createWebSocket(request, response);
assertNull(connection); assertNull(connection);
@@ -77,15 +80,16 @@ public class WebSocketResourceProviderFactoryTest {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
when(request.getHttpServletRequest()).thenReturn(httpServletRequest); when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class, WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
REMOTE_ADDRESS_PROPERTY_NAME); mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response); Object connection = factory.createWebSocket(request, response);
assertNotNull(connection); assertNotNull(connection);
verifyNoMoreInteractions(response); verifyNoMoreInteractions(response);
verify(authenticator).authenticate(eq(request)); verify(authenticator).authenticate(eq(request));
((WebSocketResourceProvider<?>) connection).onWebSocketOpen(mock(Session.class)); ((WebSocketResourceProvider<?>) connection).onWebSocketConnect(mock(Session.class));
assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated()); assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account); assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
@@ -99,6 +103,7 @@ public class WebSocketResourceProviderFactoryTest {
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment, WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class, Account.class,
mock(WebSocketConfiguration.class),
REMOTE_ADDRESS_PROPERTY_NAME); REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response); Object connection = factory.createWebSocket(request, response);
@@ -107,6 +112,20 @@ public class WebSocketResourceProviderFactoryTest {
verify(authenticator).authenticate(eq(request)); verify(authenticator).authenticate(eq(request));
} }
@Test
void testConfigure() {
JettyWebSocketServletFactory servletFactory = mock(JettyWebSocketServletFactory.class);
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class),
REMOTE_ADDRESS_PROPERTY_NAME);
factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory));
}
@Test @Test
void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException { void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException {
final Account account = new Account(); final Account account = new Account();
@@ -118,11 +137,12 @@ public class WebSocketResourceProviderFactoryTest {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
when(request.getHttpServletRequest()).thenReturn(httpServletRequest); when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
final AuthenticatedWebSocketUpgradeFilter<Account> filter = mock(AuthenticatedWebSocketUpgradeFilter.class); final AuthenticatedWebSocketUpgradeFilter<Account> filter = mock(AuthenticatedWebSocketUpgradeFilter.class);
when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter); when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter);
final WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class, final WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
REMOTE_ADDRESS_PROPERTY_NAME); mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
assertNotNull(factory.createWebSocket(request, response)); assertNotNull(factory.createWebSocket(request, response));
verify(filter).handleAuthentication(reusableAuth, request, response); verify(filter).handleAuthentication(reusableAuth, request, response);

View File

@@ -47,9 +47,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.eclipse.jetty.websocket.api.Callback; import org.eclipse.jetty.websocket.api.CloseStatus;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.ContainerResponse; import org.glassfish.jersey.server.ContainerResponse;
@@ -88,10 +90,11 @@ class WebSocketResourceProviderTest {
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close(); verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
ArgumentCaptor<WebSocketSessionContext> contextArgumentCaptor = ArgumentCaptor.forClass( ArgumentCaptor<WebSocketSessionContext> contextArgumentCaptor = ArgumentCaptor.forClass(
WebSocketSessionContext.class); WebSocketSessionContext.class);
@@ -109,9 +112,11 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
ContainerResponse response = mock(ContainerResponse.class); ContainerResponse response = mock(ContainerResponse.class);
when(response.getStatus()).thenReturn(200); when(response.getStatus()).thenReturn(200);
@@ -141,15 +146,16 @@ class WebSocketResourceProviderTest {
return CompletableFuture.completedFuture(response); return CompletableFuture.completedFuture(response);
}); });
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close(); verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar",
new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class);
ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
@@ -163,7 +169,7 @@ class WebSocketResourceProviderTest {
assertThat(bundledRequest.getPath(false)).isEqualTo("bar"); assertThat(bundledRequest.getPath(false)).isEqualTo("bar");
verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response)); verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response));
verify(session).sendBinary(responseCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom( SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(
responseCaptor.getValue().array()); responseCaptor.getValue().array());
@@ -183,22 +189,25 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn( when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn(
CompletableFuture.failedFuture(new IllegalStateException("foo"))); CompletableFuture.failedFuture(new IllegalStateException("foo")));
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
verify(session, never()).close(anyInt(), anyString(), any(Callback.class)); verify(session, never()).close(anyInt(), anyString());
verify(session, never()).close(); verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar",
new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); ArgumentCaptor<ContainerRequest> requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class);
@@ -212,7 +221,7 @@ class WebSocketResourceProviderTest {
ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom( SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(
responseCaptor.getValue().array()); responseCaptor.getValue().array());
@@ -236,20 +245,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -274,20 +285,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET",
"/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray(); "/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -312,20 +325,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -350,20 +365,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -387,20 +404,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -425,20 +444,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -463,21 +484,23 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/some/testparam", List.of("Content-Type: application/json"), "/v1/test/some/testparam", List.of("Content-Type: application/json"),
Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray(); Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -502,21 +525,23 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/some/testparam", List.of("Content-Type: application/json"), "/v1/test/some/testparam", List.of("Content-Type: application/json"),
Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray(); Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -542,20 +567,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET",
"/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray(); "/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);
@@ -580,20 +607,22 @@ class WebSocketResourceProviderTest {
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
UpgradeRequest request = mock(UpgradeRequest.class); UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getUpgradeRequest()).thenReturn(request); when(session.getUpgradeRequest()).thenReturn(request);
when(session.getRemote()).thenReturn(remoteEndpoint);
provider.onWebSocketOpen(session); provider.onWebSocketConnect(session);
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive", byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive",
new LinkedList<>(), Optional.empty()).toByteArray(); new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(message), Callback.NOOP); provider.onWebSocketBinary(message, 0, message.length);
ArgumentCaptor<ByteBuffer> requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session).sendBinary(requestCaptor.capture(), any(Callback.class)); verify(remoteEndpoint).sendBytes(requestCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor); SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor);
assertThat(requestMessage.getVerb()).isEqualTo("GET"); assertThat(requestMessage.getVerb()).isEqualTo("GET");
@@ -603,11 +632,11 @@ class WebSocketResourceProviderTest {
byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK", byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK",
new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray(); new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray();
provider.onWebSocketBinary(ByteBuffer.wrap(clientResponse), Callback.NOOP); provider.onWebSocketBinary(clientResponse, 0, clientResponse.length);
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(session, times(2)).sendBinary(responseBytesCaptor.capture(), any(Callback.class)); verify(remoteEndpoint, times(2)).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor);