Compare commits

...

37 Commits

Author SHA1 Message Date
gram-signal
b17f41c3e8 Check if dashes work in dynamic configuration keys. 2022-08-29 15:51:37 -06:00
gram-signal
08db4ba54b Update authentication to use HKDF_SHA256. 2022-08-29 14:20:47 -06:00
gram-signal
cb6cc39679 Ignore null identity key. 2022-08-29 13:26:49 -06:00
Jon Chambers
b6bf6c994c Remove a spurious @Nullable annotation 2022-08-26 15:22:23 -04:00
Jon Chambers
3bb4709563 Add CLDR region as a dimension 2022-08-26 12:41:51 -04:00
Jon Chambers
b280c768a4 Allow signup captchas to target CLDR two-letter region codes 2022-08-26 12:41:51 -04:00
Chris Eager
d23e89fb9c Update micrometer to 1.9.3 2022-08-25 13:46:36 -07:00
Chris Eager
3a27bd0318 Update test dependencies 2022-08-25 13:40:46 -07:00
Chris Eager
616513edaf Remove unused jdbi dependency 2022-08-25 13:40:46 -07:00
Chris Eager
09a51020e9 Update stripe-java to 21.2.0 2022-08-25 13:40:46 -07:00
Chris Eager
cb8cb94d1a Update aws java v1 SDK to 1.12.287 2022-08-25 13:40:46 -07:00
Chris Eager
2440dc0089 Update netty to 4.1.79.Final 2022-08-25 13:40:46 -07:00
Chris Eager
2336eef333 Update aws java v2 SDK to 2.17.258 2022-08-25 13:40:46 -07:00
Chris Eager
a0e948627c Update jackson to 2.13.3 2022-08-25 13:40:46 -07:00
Chris Eager
88159af588 Update dropwizard to 2.0.32 2022-08-25 13:40:46 -07:00
Chris Eager
38b77bb550 Update libphonenumber to 8.12.54 2022-08-25 13:40:32 -07:00
Jon Chambers
e72d1d0b6f Stop reading attribute-based messages from the messages table 2022-08-22 13:37:39 -07:00
Ravi Khadiwala
1891622e69 Zero-pad discriminators less than initial width 2022-08-22 13:36:38 -07:00
Chris Eager
628a112b38 Include country code for verify failure 2022-08-19 12:21:05 -07:00
Jon Chambers
50f5d760c9 Use existing tagging tools for keepalive counters 2022-08-16 13:16:19 -07:00
Jon Chambers
7292a88ea3 Record table performance metrics around reported messages 2022-08-16 13:15:30 -07:00
Jon Chambers
07cb3ab576 Add a "sealed sender" dimension to the sent message counter 2022-08-16 13:11:12 -07:00
Chris Eager
27b749abbd Filter expired items from Dynamo 2022-08-16 13:09:47 -07:00
Chris Eager
27f67a077c Add metrics for report-verification-succeeded response 2022-08-16 13:08:16 -07:00
Ravi Khadiwala
393e15815b Rename secondary account key namespace for usernames 2022-08-15 10:51:52 -05:00
Ravi Khadiwala
a7f1cd25b9 Remove UAK normalization code
All accounts now have UAKs in top-level attributes
2022-08-15 10:47:52 -05:00
Ravi Khadiwala
953cd2ae0c Revert "Delete any leftover usernames in the accounts db"
This reverts commit a44c18e9b7.

Old username cleanup is finished.
2022-08-15 10:45:38 -05:00
ravi-signal
a84a7dbc3d Add support for generating discriminators
- adds `PUT accounts/username` endpoint
- adds `GET accounts/username/{username}` to lookup aci by username
- deletes `PUT accounts/username/{username}`, `GET profile/username/{username}`
- adds randomized discriminator generation
2022-08-15 10:44:36 -05:00
Chris Eager
24d01f1ab2 Revert "device capabilities: prevent stories downgrade"
This reverts commit 1c67233eb0.
2022-08-12 14:21:27 -05:00
Chris Eager
06eb890761 Improve e164 normalization check by re-parsing without country code 2022-08-12 10:52:55 -07:00
Chris Eager
6d0345d327 Clean up Util 2022-08-12 10:52:55 -07:00
Chris Eager
1c67233eb0 device capabilities: prevent stories downgrade 2022-08-12 10:51:16 -07:00
Jon Chambers
b4281c5a70 Send non-urgent push notifications with lower priority 2022-08-12 11:06:31 -04:00
Jon Chambers
5f6b66dad6 Add support for scheduling background push notifications 2022-08-12 10:57:59 -04:00
Jon Chambers
c2be0af9d9 Refactor ApnPushNotificationSchedulerTest to use a Clock 2022-08-12 10:57:59 -04:00
Jon Chambers
c111e9a35a Update to the latest version of the abusive message filter 2022-08-12 10:50:53 -04:00
Jon Chambers
a53a85d788 Refactor scheduled APNs notifications in preparation for future development 2022-08-12 10:47:49 -04:00
88 changed files with 2626 additions and 1774 deletions

22
pom.xml
View File

@@ -41,31 +41,31 @@
</modules>
<properties>
<aws.sdk.version>1.12.154</aws.sdk.version>
<aws.sdk2.version>2.17.125</aws.sdk2.version>
<aws.sdk.version>1.12.287</aws.sdk.version>
<aws.sdk2.version>2.17.258</aws.sdk2.version>
<commons-codec.version>1.15</commons-codec.version>
<commons-csv.version>1.8</commons-csv.version>
<commons-io.version>2.9.0</commons-io.version>
<dropwizard.version>2.0.28</dropwizard.version>
<dropwizard.version>2.0.32</dropwizard.version>
<dropwizard-metrics-datadog.version>1.1.13</dropwizard-metrics-datadog.version>
<gson.version>2.9.0</gson.version>
<guava.version>30.1.1-jre</guava.version>
<jackson.version>2.13.2.20220328</jackson.version>
<jackson.version>2.13.3</jackson.version>
<jaxb.version>2.3.1</jaxb.version>
<jedis.version>2.9.0</jedis.version>
<lettuce.version>6.1.9.RELEASE</lettuce.version>
<libphonenumber.version>8.12.50</libphonenumber.version>
<libphonenumber.version>8.12.54</libphonenumber.version>
<logstash.logback.version>7.0.1</logstash.logback.version>
<micrometer.version>1.5.3</micrometer.version>
<mockito.version>4.3.1</mockito.version>
<netty.version>4.1.65.Final</netty.version>
<micrometer.version>1.9.3</micrometer.version>
<mockito.version>4.7.0</mockito.version>
<netty.version>4.1.79.Final</netty.version>
<opentest4j.version>1.2.0</opentest4j.version>
<protobuf.version>3.19.4</protobuf.version>
<pushy.version>0.15.1</pushy.version>
<resilience4j.version>1.5.0</resilience4j.version>
<semver4j.version>3.1.0</semver4j.version>
<slf4j.version>1.7.30</slf4j.version>
<stripe.version>20.79.0</stripe.version>
<stripe.version>21.2.0</stripe.version>
<vavr.version>0.10.4</vavr.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
@@ -80,7 +80,7 @@
<dependency>
<groupId>com.fasterxml.jackson</groupId>
<artifactId>jackson-bom</artifactId>
<version>2.13.2.20220328</version>
<version>2.13.3</version>
<type>pom</type>
<scope>import</scope>
</dependency>
@@ -296,7 +296,7 @@
<dependency>
<groupId>com.github.tomakehurst</groupId>
<artifactId>wiremock-jre8</artifactId>
<version>2.32.0</version>
<version>2.33.2</version>
<scope>test</scope>
<exclusions>
<exclusion>

View File

@@ -43,10 +43,6 @@
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-core</artifactId>
</dependency>
<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-jdbi3</artifactId>
</dependency>
<dependency>
<groupId>io.dropwizard</groupId>
<artifactId>dropwizard-auth</artifactId>
@@ -118,19 +114,10 @@
<artifactId>logstash-logback-encoder</artifactId>
</dependency>
<dependency>
<groupId>org.jdbi</groupId>
<artifactId>jdbi3-core</artifactId>
</dependency>
<dependency>
<groupId>io.dropwizard.metrics</groupId>
<artifactId>metrics-core</artifactId>
</dependency>
<dependency>
<groupId>io.dropwizard.metrics</groupId>
<artifactId>metrics-jdbi3</artifactId>
</dependency>
<dependency>
<groupId>io.dropwizard.metrics</groupId>
<artifactId>metrics-healthchecks</artifactId>
@@ -428,14 +415,8 @@
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>DynamoDBLocal</artifactId>
<version>1.16.0</version>
<version>1.17.2</version>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>

View File

@@ -45,6 +45,7 @@ import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration;
import org.whispersystems.textsecuregcm.configuration.TestDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.TwilioConfiguration;
import org.whispersystems.textsecuregcm.configuration.UnidentifiedDeliveryConfiguration;
import org.whispersystems.textsecuregcm.configuration.UsernameConfiguration;
import org.whispersystems.textsecuregcm.configuration.VoiceVerificationConfiguration;
import org.whispersystems.textsecuregcm.configuration.ZkConfig;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@@ -247,6 +248,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private ReportMessageConfiguration reportMessage = new ReportMessageConfiguration();
@Valid
@NotNull
@JsonProperty
private UsernameConfiguration username = new UsernameConfiguration();
@Valid
@JsonProperty
private AbusiveMessageFilterConfiguration abusiveMessageFilter;
@@ -424,4 +430,8 @@ public class WhisperServerConfiguration extends Configuration {
public AbusiveMessageFilterConfiguration getAbusiveMessageFilterConfiguration() {
return abusiveMessageFilter;
}
public UsernameConfiguration getUsername() {
return username;
}
}

View File

@@ -136,14 +136,14 @@ import org.whispersystems.textsecuregcm.metrics.MicrometerRegistryManager;
import org.whispersystems.textsecuregcm.metrics.NetworkReceivedGauge;
import org.whispersystems.textsecuregcm.metrics.NetworkSentGauge;
import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener;
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisClusterHealthCheck;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.MessageSender;
@@ -185,7 +185,6 @@ import org.whispersystems.textsecuregcm.storage.MessagesCache;
import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.NonNormalizedAccountCrawlerListener;
import org.whispersystems.textsecuregcm.storage.UsernameCleaner;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
@@ -205,6 +204,7 @@ import org.whispersystems.textsecuregcm.stripe.StripeManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.HostnameUtil;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
@@ -443,19 +443,21 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MessagesCache messagesCache = new MessagesCache(messagesCluster, messagesCluster, keyspaceNotificationDispatchExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster, dynamicConfigurationManager);
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster, config.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, pushLatencyManager, reportMessageManager);
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager);
UsernameGenerator usernameGenerator = new UsernameGenerator(config.getUsername());
DeletedAccountsManager deletedAccountsManager = new DeletedAccountsManager(deletedAccounts,
deletedAccountsLockDynamoDbClient, config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
deletedAccountsManager, directoryQueue, keys, messagesManager, reservedUsernames, profilesManager,
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, clock);
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, usernameGenerator,
experimentEnrollmentManager, clock);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
DispatchManager dispatchManager = new DispatchManager(pubSubClientFactory, Optional.empty());
PubSubManager pubSubManager = new PubSubManager(pubsubClient, dispatchManager);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushSchedulerCluster, apnSender, accountsManager);
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnFallbackManager);
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), rateLimitersCluster);
DynamicRateLimiters dynamicRateLimiters = new DynamicRateLimiters(rateLimitersCluster, dynamicConfigurationManager);
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
@@ -475,8 +477,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(accountsManager);
reportMessageManager.addListener(reportedMessageMetricsListener);
AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(accountsManager);
AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager, experimentEnrollmentManager);
DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(accountsManager, experimentEnrollmentManager);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration(), dynamicConfigurationManager);
SmsSender smsSender = new SmsSender(twilioSmsSender);
@@ -532,17 +534,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getAccountDatabaseCrawlerConfiguration().getChunkIntervalMs()
);
AccountDatabaseCrawlerCache usernameCleanerAccountDatabaseCrawlerCache =
new AccountDatabaseCrawlerCache(cacheCluster, AccountDatabaseCrawlerCache.USERNAME_CLEANER_PREFIX);
AccountDatabaseCrawler usernameCleanerAccountDatabaseCrawler = new AccountDatabaseCrawler("username cleaner crawler",
accountsManager,
usernameCleanerAccountDatabaseCrawlerCache,
List.of(new UsernameCleaner(accountsManager)),
config.getAccountDatabaseCrawlerConfiguration().getChunkSize(),
config.getAccountDatabaseCrawlerConfiguration().getChunkIntervalMs()
);
// TODO listeners must be ordered so that ones that directly update accounts come last, so that read-only ones are not working with stale data
final List<AccountDatabaseCrawlerListener> accountDatabaseCrawlerListeners = List.of(
new NonNormalizedAccountCrawlerListener(accountsManager, metricsCluster),
@@ -567,12 +558,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
CurrencyConversionManager currencyManager = new CurrencyConversionManager(fixerClient, ftxClient, config.getPaymentsServiceConfiguration().getPaymentCurrencies());
environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(apnFallbackManager);
environment.lifecycle().manage(apnPushNotificationScheduler);
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(accountDatabaseCrawler);
environment.lifecycle().manage(directoryReconciliationAccountDatabaseCrawler);
environment.lifecycle().manage(accountCleanerAccountDatabaseCrawler);
environment.lifecycle().manage(usernameCleanerAccountDatabaseCrawler);
environment.lifecycle().manage(deletedAccountsTableCrawler);
environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(messagePersister);
@@ -625,7 +615,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getWebSocketConfiguration(), 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager, apnFallbackManager,
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager,
clientPresenceManager, websocketScheduledExecutor));
webSocketEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));
webSocketEnvironment.jersey().register(new ContentLengthFilter(TrafficSource.WEBSOCKET));
@@ -652,8 +642,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
ReceiptCredentialPresentation::new, stripeExecutor, config.getDonationConfiguration(), config.getStripe()),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager,
messagesManager, apnFallbackManager, reportMessageManager, multiRecipientMessageExecutor),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager, profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager),

View File

@@ -7,13 +7,14 @@ package org.whispersystems.textsecuregcm.auth;
import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class AccountAuthenticator extends BaseAccountAuthenticator implements
Authenticator<BasicCredentials, AuthenticatedAccount> {
public AccountAuthenticator(AccountsManager accountsManager) {
super(accountsManager);
public AccountAuthenticator(AccountsManager accountsManager, ExperimentEnrollmentManager enrollmentManager) {
super(accountsManager, enrollmentManager);
}
@Override

View File

@@ -4,7 +4,9 @@
*/
package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import org.apache.commons.codec.binary.Hex;
import org.signal.libsignal.protocol.kdf.HKDF;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
@@ -12,10 +14,18 @@ import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
public class AuthenticationCredentials {
private static final String V2_PREFIX = "2.";
private final String hashedAuthenticationToken;
private final String salt;
public enum Version {
V1,
V2,
}
public static final Version CURRENT_VERSION = Version.V2;
public AuthenticationCredentials(String hashedAuthenticationToken, String salt) {
this.hashedAuthenticationToken = hashedAuthenticationToken;
this.salt = salt;
@@ -23,7 +33,20 @@ public class AuthenticationCredentials {
public AuthenticationCredentials(String authenticationToken) {
this.salt = String.valueOf(Math.abs(new SecureRandom().nextInt()));
this.hashedAuthenticationToken = getHashedValue(salt, authenticationToken);
this.hashedAuthenticationToken = getV2HashedValue(salt, authenticationToken);
}
@VisibleForTesting
public AuthenticationCredentials v1ForTesting(String authenticationToken) {
String salt = String.valueOf(Math.abs(new SecureRandom().nextInt()));
return new AuthenticationCredentials(getV1HashedValue(salt, authenticationToken), salt);
}
public Version getVersion() {
if (this.hashedAuthenticationToken.startsWith(V2_PREFIX)) {
return Version.V2;
}
return Version.V1;
}
public String getHashedAuthenticationToken() {
@@ -35,11 +58,14 @@ public class AuthenticationCredentials {
}
public boolean verify(String authenticationToken) {
String theirValue = getHashedValue(salt, authenticationToken);
final String theirValue = switch (getVersion()) {
case V1 -> getV1HashedValue(salt, authenticationToken);
case V2 -> getV2HashedValue(salt, authenticationToken);
};
return MessageDigest.isEqual(theirValue.getBytes(StandardCharsets.UTF_8), this.hashedAuthenticationToken.getBytes(StandardCharsets.UTF_8));
}
private static String getHashedValue(String salt, String token) {
private static String getV1HashedValue(String salt, String token) {
try {
return new String(Hex.encodeHex(MessageDigest.getInstance("SHA1").digest((salt + token).getBytes(StandardCharsets.UTF_8))));
} catch (NoSuchAlgorithmException e) {
@@ -47,4 +73,13 @@ public class AuthenticationCredentials {
}
}
private static final byte[] AUTH_TOKEN_HKDF_INFO = "authtoken".getBytes(StandardCharsets.UTF_8);
private static String getV2HashedValue(String salt, String token) {
byte[] secret = HKDF.deriveSecrets(
token.getBytes(StandardCharsets.UTF_8), // key
salt.getBytes(StandardCharsets.UTF_8), // salt
AUTH_TOKEN_HKDF_INFO,
32);
return V2_PREFIX + Hex.encodeHexString(secret);
}
}

View File

@@ -17,6 +17,7 @@ import java.time.temporal.ChronoUnit;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -33,18 +34,21 @@ public class BaseAccountAuthenticator {
private static final String DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME = name(BaseAccountAuthenticator.class, "daysSinceLastSeen");
private static final String IS_PRIMARY_DEVICE_TAG = "isPrimary";
private static final String AUTH_V2_REWRITE_EXPERIMENT_NAME = "authv2-rewrite";
private final AccountsManager accountsManager;
private final Clock clock;
private final ExperimentEnrollmentManager enrollmentManager;
public BaseAccountAuthenticator(AccountsManager accountsManager) {
this(accountsManager, Clock.systemUTC());
public BaseAccountAuthenticator(AccountsManager accountsManager, ExperimentEnrollmentManager enrollmentManager) {
this(accountsManager, Clock.systemUTC(), enrollmentManager);
}
@VisibleForTesting
public BaseAccountAuthenticator(AccountsManager accountsManager, Clock clock) {
this.accountsManager = accountsManager;
this.clock = clock;
public BaseAccountAuthenticator(AccountsManager accountsManager, Clock clock, ExperimentEnrollmentManager enrollmentManager) {
this.accountsManager = accountsManager;
this.clock = clock;
this.enrollmentManager = enrollmentManager;
}
static Pair<String, Long> getIdentifierAndDeviceId(final String basicUsername) {
@@ -104,9 +108,17 @@ public class BaseAccountAuthenticator {
}
}
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
AuthenticationCredentials deviceAuthenticationCredentials = device.get().getAuthenticationCredentials();
if (deviceAuthenticationCredentials.verify(basicCredentials.getPassword())) {
succeeded = true;
final Account authenticatedAccount = updateLastSeen(account.get(), device.get());
Account authenticatedAccount = updateLastSeen(account.get(), device.get());
if (deviceAuthenticationCredentials.getVersion() != AuthenticationCredentials.CURRENT_VERSION
&& enrollmentManager.isEnrolled(accountUuid, AUTH_V2_REWRITE_EXPERIMENT_NAME)) {
authenticatedAccount = accountsManager.updateDeviceAuthentication(
authenticatedAccount,
device.get(),
new AuthenticationCredentials(basicCredentials.getPassword())); // new credentials have current version
}
return Optional.of(new AuthenticatedAccount(
new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
}
@@ -142,5 +154,4 @@ public class BaseAccountAuthenticator {
return account;
}
}

View File

@@ -8,13 +8,14 @@ package org.whispersystems.textsecuregcm.auth;
import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements
Authenticator<BasicCredentials, DisabledPermittedAuthenticatedAccount> {
public DisabledPermittedAccountAuthenticator(AccountsManager accountsManager) {
super(accountsManager);
public DisabledPermittedAccountAuthenticator(AccountsManager accountsManager, ExperimentEnrollmentManager enrollmentManager) {
super(accountsManager, enrollmentManager);
}
@Override

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.validation.constraints.Min;
public class UsernameConfiguration {
@JsonProperty
@Min(1)
private int discriminatorInitialWidth = 4;
@JsonProperty
@Min(1)
private int discriminatorMaxWidth = 9;
@JsonProperty
@Min(1)
private int attemptsPerWidth = 10;
public int getDiscriminatorInitialWidth() {
return discriminatorInitialWidth;
}
public int getDiscriminatorMaxWidth() {
return discriminatorMaxWidth;
}
public int getAttemptsPerWidth() {
return attemptsPerWidth;
}
}

View File

@@ -21,6 +21,14 @@ public class DynamicCaptchaConfiguration {
@NotNull
private Set<String> signupCountryCodes = Collections.emptySet();
@JsonProperty
@NotNull
private Set<String> signupRegions = Collections.emptySet();
public BigDecimal getScoreFloor() {
return scoreFloor;
}
public Set<String> getSignupCountryCodes() {
return signupCountryCodes;
}
@@ -30,7 +38,12 @@ public class DynamicCaptchaConfiguration {
this.signupCountryCodes = numbers;
}
public BigDecimal getScoreFloor() {
return scoreFloor;
@VisibleForTesting
public void setSignupRegions(final Set<String> signupRegions) {
this.signupRegions = signupRegions;
}
public Set<String> getSignupRegions() {
return signupRegions;
}
}

View File

@@ -48,10 +48,6 @@ public class DynamicConfiguration {
@Valid
private DynamicPushLatencyConfiguration pushLatency = new DynamicPushLatencyConfiguration(Collections.emptyMap());
@JsonProperty
@Valid
private DynamicUakMigrationConfiguration uakMigrationConfiguration = new DynamicUakMigrationConfiguration();
@JsonProperty
@Valid
private DynamicTurnConfiguration turn = new DynamicTurnConfiguration();
@@ -64,6 +60,10 @@ public class DynamicConfiguration {
@Valid
DynamicMessagePersisterConfiguration messagePersister = new DynamicMessagePersisterConfiguration();
@JsonProperty
@Valid
DynamicPushNotificationConfiguration pushNotifications = new DynamicPushNotificationConfiguration();
public Optional<DynamicExperimentEnrollmentConfiguration> getExperimentEnrollmentConfiguration(
final String experimentName) {
return Optional.ofNullable(experiments.get(experimentName));
@@ -111,10 +111,6 @@ public class DynamicConfiguration {
return pushLatency;
}
public DynamicUakMigrationConfiguration getUakMigrationConfiguration() {
return uakMigrationConfiguration;
}
public DynamicTurnConfiguration getTurnConfiguration() {
return turn;
}
@@ -126,4 +122,8 @@ public class DynamicConfiguration {
public DynamicMessagePersisterConfiguration getMessagePersisterConfiguration() {
return messagePersister;
}
public DynamicPushNotificationConfiguration getPushNotificationConfiguration() {
return pushNotifications;
}
}

View File

@@ -0,0 +1,18 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
public class DynamicPushNotificationConfiguration {
@JsonProperty
private boolean lowUrgencyEnabled = false;
public boolean isLowUrgencyEnabled() {
return lowUrgencyEnabled;
}
}

View File

@@ -1,19 +0,0 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
public class DynamicUakMigrationConfiguration {
@JsonProperty
private boolean enabled = true;
@JsonProperty
private int maxOutstandingNormalizes = 25;
public boolean isEnabled() {
return enabled;
}
public int getMaxOutstandingNormalizes() {
return maxOutstandingNormalizes;
}
}

View File

@@ -64,6 +64,7 @@ import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicCaptchaConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
@@ -73,6 +74,9 @@ import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.entities.UsernameRequest;
import org.whispersystems.textsecuregcm.entities.UsernameResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.PushNotification;
@@ -93,7 +97,7 @@ import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.ImpossiblePhoneNumberException;
import org.whispersystems.textsecuregcm.util.NonNormalizedPhoneNumberException;
import org.whispersystems.textsecuregcm.util.Username;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@@ -119,14 +123,17 @@ public class AccountController {
private static final String TWILIO_VERIFY_ERROR_COUNTER_NAME = name(AccountController.class, "twilioVerifyError");
private static final String INVALID_ACCEPT_LANGUAGE_COUNTER_NAME = name(AccountController.class, "invalidAcceptLanguage");
private static final String NONSTANDARD_USERNAME_COUNTER_NAME = name(AccountController.class, "nonStandardUsername");
private static final String CHALLENGE_PRESENT_TAG_NAME = "present";
private static final String CHALLENGE_MATCH_TAG_NAME = "matches";
private static final String COUNTRY_CODE_TAG_NAME = "countryCode";
private static final String REGION_TAG_NAME = "region";
private static final String VERIFICATION_TRANSPORT_TAG_NAME = "transport";
private static final String VERIFY_EXPERIMENT_TAG_NAME = "twilioVerify";
private final StoredVerificationCodeManager pendingAccounts;
private final AccountsManager accounts;
private final AbusiveHostRules abusiveHostRules;
@@ -226,11 +233,11 @@ public class AccountController {
if (requirement.isCaptchaRequired()) {
captchaRequiredMeter.mark();
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)));
Metrics.counter(CHALLENGE_ISSUED_COUNTER_NAME, tags).increment();
Metrics.counter(CHALLENGE_ISSUED_COUNTER_NAME, Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Tag.of(REGION_TAG_NAME, Util.getRegion(number))))
.increment();
if (requirement.isAutoBlock() && shouldAutoBlock(sourceHost)) {
logger.info("Auto-block: {}", sourceHost);
@@ -317,15 +324,13 @@ public class AccountController {
// TODO Remove this meter when external dependencies have been resolved
metricRegistry.meter(name(AccountController.class, "create", Util.getCountryCode(number))).mark();
{
final List<Tag> tags = new ArrayList<>();
tags.add(Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)));
tags.add(Tag.of(VERIFICATION_TRANSPORT_TAG_NAME, transport));
tags.add(UserAgentTagUtil.getPlatformTag(userAgent));
tags.add(Tag.of(VERIFY_EXPERIMENT_TAG_NAME, String.valueOf(enrolledInVerifyExperiment)));
Metrics.counter(ACCOUNT_CREATE_COUNTER_NAME, tags).increment();
}
Metrics.counter(ACCOUNT_CREATE_COUNTER_NAME, Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Tag.of(REGION_TAG_NAME, Util.getRegion(number)),
Tag.of(VERIFICATION_TRANSPORT_TAG_NAME, transport),
Tag.of(VERIFY_EXPERIMENT_TAG_NAME, String.valueOf(enrolledInVerifyExperiment))))
.increment();
return Response.ok().build();
}
@@ -358,7 +363,8 @@ public class AccountController {
}
storedVerificationCode.flatMap(StoredVerificationCode::getTwilioVerificationSid)
.ifPresent(smsSender::reportVerificationSucceeded);
.ifPresent(
verificationSid -> smsSender.reportVerificationSucceeded(verificationSid, userAgent, "registration"));
Optional<Account> existingAccount = accounts.getByE164(number);
@@ -375,16 +381,13 @@ public class AccountController {
Account account = accounts.create(number, password, signalAgent, accountAttributes,
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));
{
metricRegistry.meter(name(AccountController.class, "verify", Util.getCountryCode(number))).mark();
metricRegistry.meter(name(AccountController.class, "verify", Util.getCountryCode(number))).mark();
final List<Tag> tags = new ArrayList<>();
tags.add(Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)));
tags.add(UserAgentTagUtil.getPlatformTag(userAgent));
tags.add(Tag.of(VERIFY_EXPERIMENT_TAG_NAME, String.valueOf(storedVerificationCode.get().getTwilioVerificationSid().isPresent())));
Metrics.counter(ACCOUNT_VERIFY_COUNTER_NAME, tags).increment();
}
Metrics.counter(ACCOUNT_VERIFY_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Tag.of(REGION_TAG_NAME, Util.getRegion(number)),
Tag.of(VERIFY_EXPERIMENT_TAG_NAME, String.valueOf(storedVerificationCode.get().getTwilioVerificationSid().isPresent()))))
.increment();
return new AccountIdentityResponse(account.getUuid(),
account.getNumber(),
@@ -397,7 +400,9 @@ public class AccountController {
@PUT
@Path("/number")
@Produces(MediaType.APPLICATION_JSON)
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount, @NotNull @Valid final ChangePhoneNumberRequest request)
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount,
@NotNull @Valid final ChangePhoneNumberRequest request,
@HeaderParam("User-Agent") String userAgent)
throws RateLimitExceededException, InterruptedException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException {
if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) {
@@ -420,7 +425,8 @@ public class AccountController {
}
storedVerificationCode.flatMap(StoredVerificationCode::getTwilioVerificationSid)
.ifPresent(smsSender::reportVerificationSucceeded);
.ifPresent(
verificationSid -> smsSender.reportVerificationSucceeded(verificationSid, userAgent, "changeNumber"));
final Optional<Account> existingAccount = accounts.getByE164(number);
@@ -628,6 +634,7 @@ public class AccountController {
auth.getAccount().isStorageSupported());
}
@Timed
@DELETE
@Path("/username")
@Produces(MediaType.APPLICATION_JSON)
@@ -635,20 +642,66 @@ public class AccountController {
accounts.clearUsername(auth.getAccount());
}
@Timed
@PUT
@Path("/username/{username}")
@Path("/username")
@Produces(MediaType.APPLICATION_JSON)
public Response setUsername(@Auth AuthenticatedAccount auth, @PathParam("username") @Username String username)
throws RateLimitExceededException {
@Consumes(MediaType.APPLICATION_JSON)
public UsernameResponse setUsername(
@Auth AuthenticatedAccount auth,
@HeaderParam("X-Signal-Agent") String userAgent,
@NotNull @Valid UsernameRequest usernameRequest) throws RateLimitExceededException {
rateLimiters.getUsernameSetLimiter().validate(auth.getAccount().getUuid());
try {
accounts.setUsername(auth.getAccount(), username);
} catch (final UsernameNotAvailableException e) {
return Response.status(Response.Status.CONFLICT).build();
if (StringUtils.isNotBlank(usernameRequest.existingUsername()) &&
!UsernameGenerator.isStandardFormat(usernameRequest.existingUsername())) {
// Technically, a username may not be in the nickname#discriminator format
// if created through some out-of-band mechanism, but it is atypical.
Metrics.counter(NONSTANDARD_USERNAME_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.increment();
}
return Response.ok().build();
try {
final Account account = accounts.setUsername(auth.getAccount(), usernameRequest.nickname(),
usernameRequest.existingUsername());
return account
.getUsername()
.map(UsernameResponse::new)
.orElseThrow(() -> new IllegalStateException("Could not get username after setting"));
} catch (final UsernameNotAvailableException e) {
throw new WebApplicationException(Status.CONFLICT);
}
}
@Timed
@GET
@Path("/username/{username}")
@Produces(MediaType.APPLICATION_JSON)
public AccountIdentifierResponse lookupUsername(
@HeaderParam("X-Signal-Agent") final String userAgent,
@HeaderParam("X-Forwarded-For") final String forwardedFor,
@PathParam("username") final String username,
@Context final HttpServletRequest request) throws RateLimitExceededException {
// Disallow clients from making authenticated requests to this endpoint
if (StringUtils.isNotBlank(request.getHeader("Authorization"))) {
throw new BadRequestException();
}
if (!UsernameGenerator.isStandardFormat(username)) {
// Technically, a username may not be in the nickname#discriminator format
// if created through some out-of-band mechanism, but it is atypical.
Metrics.counter(NONSTANDARD_USERNAME_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.increment();
}
rateLimitByClientIp(rateLimiters.getUsernameLookupLimiter(), forwardedFor);
return accounts
.getByUsername(username)
.map(Account::getUuid)
.map(AccountIdentifierResponse::new)
.orElseThrow(() -> new WebApplicationException(Status.NOT_FOUND));
}
@HEAD
@@ -662,17 +715,7 @@ public class AccountController {
if (StringUtils.isNotBlank(request.getHeader("Authorization"))) {
throw new BadRequestException();
}
final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor)
.orElseThrow(() -> {
// Missing/malformed Forwarded-For, so we cannot check for a rate-limit.
// This shouldn't happen, so conservatively assume we're over the rate-limit
// and indicate that the client should retry
logger.error("Missing/bad Forwarded-For, cannot check account {}", uuid.toString());
return new RateLimitExceededException(Duration.ofHours(1));
});
rateLimiters.getCheckAccountExistenceLimiter().validate(mostRecentProxy);
rateLimitByClientIp(rateLimiters.getCheckAccountExistenceLimiter(), forwardedFor);
final Status status = accounts.getByAccountIdentifier(uuid)
.or(() -> accounts.getByPhoneNumberIdentifier(uuid))
@@ -681,6 +724,19 @@ public class AccountController {
return Response.status(status).build();
}
private void rateLimitByClientIp(final RateLimiter rateLimiter, final String forwardedFor) throws RateLimitExceededException {
final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor)
.orElseThrow(() -> {
// Missing/malformed Forwarded-For, so we cannot check for a rate-limit.
// This shouldn't happen, so conservatively assume we're over the rate-limit
// and indicate that the client should retry
logger.error("Missing/bad Forwarded-For: {}", forwardedFor);
return new RateLimitExceededException(Duration.ofHours(1));
});
rateLimiter.validate(mostRecentProxy);
}
private void verifyRegistrationLock(final Account existingAccount, @Nullable final String clientRegistrationLock)
throws RateLimitExceededException, WebApplicationException {
@@ -716,16 +772,17 @@ public class AccountController {
}
final String countryCode = Util.getCountryCode(number);
final String region = Util.getRegion(number);
if (captchaToken.isPresent()) {
boolean validToken = recaptchaClient.verify(captchaToken.get(), sourceHost);
{
final List<Tag> tags = new ArrayList<>();
tags.add(Tag.of("success", String.valueOf(validToken)));
tags.add(UserAgentTagUtil.getPlatformTag(userAgent));
tags.add(Tag.of(COUNTRY_CODE_TAG_NAME, countryCode));
Metrics.counter(CAPTCHA_ATTEMPT_COUNTER_NAME, tags).increment();
}
Metrics.counter(CAPTCHA_ATTEMPT_COUNTER_NAME, Tags.of(
Tag.of("success", String.valueOf(validToken)),
UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, countryCode),
Tag.of(REGION_TAG_NAME, region)))
.increment();
if (validToken) {
return new CaptchaRequirement(false, false);
@@ -737,6 +794,7 @@ public class AccountController {
{
final List<Tag> tags = new ArrayList<>();
tags.add(Tag.of(COUNTRY_CODE_TAG_NAME, countryCode));
tags.add(Tag.of(REGION_TAG_NAME, region));
try {
if (pushChallenge.isPresent()) {
@@ -762,7 +820,9 @@ public class AccountController {
DynamicCaptchaConfiguration captchaConfig = dynamicConfigurationManager.getConfiguration()
.getCaptchaConfiguration();
boolean countryFiltered = captchaConfig.getSignupCountryCodes().contains(countryCode);
boolean countryFiltered = captchaConfig.getSignupCountryCodes().contains(countryCode) ||
captchaConfig.getSignupRegions().contains(region);
if (abusiveHostRules.isBlocked(sourceHost)) {
blockedHostMeter.mark();

View File

@@ -13,9 +13,11 @@ import io.micrometer.core.instrument.Metrics;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
import io.micrometer.core.instrument.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@@ -30,8 +32,7 @@ public class KeepAliveController {
private final ClientPresenceManager clientPresenceManager;
private static final String NO_LOCAL_SUBSCRIPTION_COUNTER_NAME = name(KeepAliveController.class, "noLocalSubscription");
private static final String NO_LOCAL_SUBSCRIPTION_PLATFORM_TAG_NAME = "platform";
private static final String NO_LOCAL_SUBSCRIPTION_COUNTER_NAME = name(KeepAliveController.class, "noLocalSubscription");
public KeepAliveController(final ClientPresenceManager clientPresenceManager) {
this.clientPresenceManager = clientPresenceManager;
@@ -50,15 +51,9 @@ public class KeepAliveController {
context.getClient().close(1000, "OK");
String platform;
try {
platform = UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform().name().toLowerCase();
} catch (UnrecognizedUserAgentException e) {
platform = "unknown";
}
Metrics.counter(NO_LOCAL_SUBSCRIPTION_COUNTER_NAME, NO_LOCAL_SUBSCRIPTION_PLATFORM_TAG_NAME, platform).increment();
Metrics.counter(NO_LOCAL_SUBSCRIPTION_COUNTER_NAME,
Tags.of(UserAgentTagUtil.getPlatformTag(context.getClient().getUserAgent())))
.increment();
}
}

View File

@@ -77,11 +77,10 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
@@ -107,7 +106,7 @@ public class MessageController {
private final AccountsManager accountsManager;
private final DeletedAccountsManager deletedAccountsManager;
private final MessagesManager messagesManager;
private final ApnFallbackManager apnFallbackManager;
private final PushNotificationManager pushNotificationManager;
private final ReportMessageManager reportMessageManager;
private final ExecutorService multiRecipientMessageExecutor;
@@ -138,7 +137,7 @@ public class MessageController {
AccountsManager accountsManager,
DeletedAccountsManager deletedAccountsManager,
MessagesManager messagesManager,
ApnFallbackManager apnFallbackManager,
PushNotificationManager pushNotificationManager,
ReportMessageManager reportMessageManager,
@Nonnull ExecutorService multiRecipientMessageExecutor) {
this.rateLimiters = rateLimiters;
@@ -147,7 +146,7 @@ public class MessageController {
this.accountsManager = accountsManager;
this.deletedAccountsManager = deletedAccountsManager;
this.messagesManager = messagesManager;
this.apnFallbackManager = apnFallbackManager;
this.pushNotificationManager = pushNotificationManager;
this.reportMessageManager = reportMessageManager;
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
}
@@ -408,18 +407,14 @@ public class MessageController {
@Produces(MediaType.APPLICATION_JSON)
public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth,
@HeaderParam("User-Agent") String userAgent) {
assert auth.getAuthenticatedDevice() != null;
if (!Util.isEmpty(auth.getAuthenticatedDevice().getApnId())) {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice()));
}
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent);
final OutgoingMessageEntityList outgoingMessages;
{
final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
userAgent,
false);
outgoingMessages = new OutgoingMessageEntityList(messagesAndHasMore.first().stream()

View File

@@ -388,6 +388,7 @@ public class ProfileController {
private void checkFingerprintAndAdd(BatchIdentityCheckRequest.Element element,
Collection<BatchIdentityCheckResponse.Element> responseElements, MessageDigest md) {
accountsManager.getByAccountIdentifier(element.aci()).ifPresent(account -> {
if (account.getIdentityKey() == null) return;
byte[] identityKeyBytes;
try {
identityKeyBytes = Base64.getDecoder().decode(account.getIdentityKey());
@@ -503,24 +504,6 @@ public class ProfileController {
account.getPhoneNumberIdentifier());
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/username/{username}")
public BaseProfileResponse getProfileByUsername(
@Auth AuthenticatedAccount auth,
@Context ContainerRequestContext containerRequestContext,
@PathParam("username") String username)
throws RateLimitExceededException {
rateLimiters.getUsernameLookupLimiter().validate(auth.getAccount().getUuid());
final Account targetAccount = accountsManager.getByUsername(username).orElseThrow(NotFoundException::new);
final boolean isSelf = auth.getAccount().getUuid().equals(targetAccount.getUuid());
return buildBaseProfileResponseForAccountIdentity(targetAccount, isSelf, containerRequestContext);
}
private ProfileKeyCredentialResponse getProfileCredential(final String encodedProfileCredentialRequest,
final VersionedProfile profile,
final UUID uuid) {

View File

@@ -0,0 +1,10 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import javax.validation.constraints.NotNull;
import java.util.UUID;
public record AccountIdentifierResponse(@NotNull UUID uuid) {}

View File

@@ -0,0 +1,12 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import org.whispersystems.textsecuregcm.util.Nickname;
import javax.annotation.Nullable;
import javax.validation.Valid;
public record UsernameRequest(@Valid @Nickname String nickname, @Nullable String existingUsername) {}

View File

@@ -0,0 +1,8 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
public record UsernameResponse(String username) {}

View File

@@ -45,6 +45,11 @@ public class APNSender implements Managed, PushNotificationSender {
.setLocalizedAlertMessage("APN_Message")
.build();
@VisibleForTesting
static final String APN_BACKGROUND_PAYLOAD = new SimpleApnsPayloadBuilder()
.setContentAvailable(true)
.build();
@VisibleForTesting
static final Instant MAX_EXPIRATION = Instant.ofEpochMilli(Integer.MAX_VALUE * 1000L);
@@ -83,7 +88,13 @@ public class APNSender implements Managed, PushNotificationSender {
final boolean isVoip = notification.tokenType() == PushNotification.TokenType.APN_VOIP;
final String payload = switch (notification.notificationType()) {
case NOTIFICATION -> isVoip ? APN_VOIP_NOTIFICATION_PAYLOAD : APN_NSE_NOTIFICATION_PAYLOAD;
case NOTIFICATION -> {
if (isVoip) {
yield APN_VOIP_NOTIFICATION_PAYLOAD;
} else {
yield notification.urgent() ? APN_NSE_NOTIFICATION_PAYLOAD : APN_BACKGROUND_PAYLOAD;
}
}
case CHALLENGE -> new SimpleApnsPayloadBuilder()
.setSound("default")
@@ -98,8 +109,19 @@ public class APNSender implements Managed, PushNotificationSender {
.build();
};
final PushType pushType;
if (isVoip) {
pushType = PushType.VOIP;
} else {
pushType = notification.urgent() ? PushType.ALERT : PushType.BACKGROUND;
}
final DeliveryPriority deliveryPriority =
(notification.urgent() || isVoip) ? DeliveryPriority.IMMEDIATE : DeliveryPriority.CONSERVE_POWER;
final String collapseId =
(notification.notificationType() == PushNotification.NotificationType.NOTIFICATION && !isVoip)
(notification.notificationType() == PushNotification.NotificationType.NOTIFICATION && notification.urgent() && !isVoip)
? "incoming-message" : null;
final Instant start = Instant.now();
@@ -108,8 +130,8 @@ public class APNSender implements Managed, PushNotificationSender {
topic,
payload,
MAX_EXPIRATION,
DeliveryPriority.IMMEDIATE,
isVoip ? PushType.VOIP : PushType.ALERT,
deliveryPriority,
pushType,
collapseId))
.whenComplete((response, throwable) -> {
// Note that we deliberately run this small bit of non-blocking measurement on the "send notification" thread

View File

@@ -1,268 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.RatioGauge;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
public class ApnFallbackManager implements Managed {
private static final Logger logger = LoggerFactory.getLogger(ApnFallbackManager.class);
private static final String PENDING_NOTIFICATIONS_KEY = "PENDING_APN";
static final String NEXT_SLOT_TO_PERSIST_KEY = "pending_notification_next_slot";
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Meter delivered = metricRegistry.meter(name(ApnFallbackManager.class, "voip_delivered"));
private static final Meter sent = metricRegistry.meter(name(ApnFallbackManager.class, "voip_sent" ));
private static final Meter retry = metricRegistry.meter(name(ApnFallbackManager.class, "voip_retry"));
private static final Meter evicted = metricRegistry.meter(name(ApnFallbackManager.class, "voip_evicted"));
static {
metricRegistry.register(name(ApnFallbackManager.class, "voip_ratio"), new VoipRatioGauge(delivered, sent));
}
private final APNSender apnSender;
private final AccountsManager accountsManager;
private final FaultTolerantRedisCluster cluster;
private final ClusterLuaScript getScript;
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeScript;
private final Thread[] workerThreads = new Thread[WORKER_THREAD_COUNT];
private static final int WORKER_THREAD_COUNT = 4;
private final AtomicBoolean running = new AtomicBoolean(false);
class NotificationWorker implements Runnable {
@Override
public void run() {
while (running.get()) {
try {
final long entriesProcessed = processNextSlot();
if (entriesProcessed == 0) {
Util.sleep(1000);
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
}
}
long processNextSlot() {
final int slot = getNextSlot();
List<String> pendingDestinations;
long entriesProcessed = 0;
do {
pendingDestinations = getPendingDestinations(slot, 100);
entriesProcessed += pendingDestinations.size();
for (final String uuidAndDevice : pendingDestinations) {
final Optional<Pair<String, Long>> separated = getSeparated(uuidAndDevice);
final Optional<Account> maybeAccount = separated.map(Pair::first)
.map(UUID::fromString)
.flatMap(accountsManager::getByAccountIdentifier);
final Optional<Device> maybeDevice = separated.map(Pair::second)
.flatMap(deviceId -> maybeAccount.flatMap(account -> account.getDevice(deviceId)));
if (maybeAccount.isPresent() && maybeDevice.isPresent()) {
sendNotification(maybeAccount.get(), maybeDevice.get());
} else {
remove(uuidAndDevice);
}
}
} while (!pendingDestinations.isEmpty());
return entriesProcessed;
}
}
public ApnFallbackManager(FaultTolerantRedisCluster cluster,
APNSender apnSender,
AccountsManager accountsManager)
throws IOException
{
this.apnSender = apnSender;
this.accountsManager = accountsManager;
this.cluster = cluster;
this.getScript = ClusterLuaScript.fromResource(cluster, "lua/apn/get.lua", ScriptOutputType.MULTI);
this.insertScript = ClusterLuaScript.fromResource(cluster, "lua/apn/insert.lua", ScriptOutputType.VALUE);
this.removeScript = ClusterLuaScript.fromResource(cluster, "lua/apn/remove.lua", ScriptOutputType.INTEGER);
for (int i = 0; i < this.workerThreads.length; i++) {
this.workerThreads[i] = new Thread(new NotificationWorker(), "ApnFallbackManagerWorker-" + i);
}
}
public void schedule(Account account, Device device) {
schedule(account, device, System.currentTimeMillis());
}
@VisibleForTesting
void schedule(Account account, Device device, long timestamp) {
sent.mark();
insert(account, device, timestamp + (15 * 1000), (15 * 1000));
}
public void cancel(Account account, Device device) {
if (remove(account, device)) {
delivered.mark();
}
}
@Override
public synchronized void start() {
running.set(true);
for (final Thread workerThread : workerThreads) {
workerThread.start();
}
}
@Override
public synchronized void stop() throws InterruptedException {
running.set(false);
for (final Thread workerThread : workerThreads) {
workerThread.join();
}
}
private void sendNotification(final Account account, final Device device) {
String apnId = device.getVoipApnId();
if (apnId == null) {
remove(account, device);
return;
}
long deviceLastSeen = device.getLastSeen();
if (deviceLastSeen < System.currentTimeMillis() - TimeUnit.DAYS.toMillis(7)) {
evicted.mark();
remove(account, device);
return;
}
apnSender.sendNotification(new PushNotification(apnId, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device));
retry.mark();
}
@VisibleForTesting
static Optional<Pair<String, Long>> getSeparated(String encoded) {
try {
if (encoded == null) return Optional.empty();
String[] parts = encoded.split(":");
if (parts.length != 2) {
logger.warn("Got strange encoded number: " + encoded);
return Optional.empty();
}
return Optional.of(new Pair<>(parts[0], Long.parseLong(parts[1])));
} catch (NumberFormatException e) {
logger.warn("Badly formatted: " + encoded, e);
return Optional.empty();
}
}
private boolean remove(Account account, Device device) {
return remove(getEndpointKey(account, device));
}
private boolean remove(final String endpoint) {
return (long)removeScript.execute(List.of(getPendingNotificationQueueKey(endpoint), endpoint),
Collections.emptyList()) > 0;
}
@SuppressWarnings("unchecked")
@VisibleForTesting
List<String> getPendingDestinations(final int slot, final int limit) {
return (List<String>)getScript.execute(List.of(getPendingNotificationQueueKey(slot)),
List.of(String.valueOf(System.currentTimeMillis()), String.valueOf(limit)));
}
private void insert(final Account account, final Device device, final long timestamp, final long interval) {
final String endpoint = getEndpointKey(account, device);
insertScript.execute(List.of(getPendingNotificationQueueKey(endpoint), endpoint),
List.of(String.valueOf(timestamp),
String.valueOf(interval),
account.getUuid().toString(),
String.valueOf(device.getId())));
}
@VisibleForTesting
String getEndpointKey(final Account account, final Device device) {
return "apn_device::{" + account.getUuid() + "::" + device.getId() + "}";
}
private String getPendingNotificationQueueKey(final String endpoint) {
return getPendingNotificationQueueKey(SlotHash.getSlot(endpoint));
}
private String getPendingNotificationQueueKey(final int slot) {
return PENDING_NOTIFICATIONS_KEY + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
private int getNextSlot() {
return (int)(cluster.withCluster(connection -> connection.sync().incr(NEXT_SLOT_TO_PERSIST_KEY)) % SlotHash.SLOT_COUNT);
}
private static class VoipRatioGauge extends RatioGauge {
private final Meter success;
private final Meter attempts;
private VoipRatioGauge(Meter success, Meter attempts) {
this.success = success;
this.attempts = attempts;
}
@Override
protected Ratio getRatio() {
return RatioGauge.Ratio.of(success.getFiveMinuteRate(), attempts.getFiveMinuteRate());
}
}
}

View File

@@ -0,0 +1,390 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.Limit;
import io.lettuce.core.Range;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.SlotHash;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.Util;
public class ApnPushNotificationScheduler implements Managed {
private static final Logger logger = LoggerFactory.getLogger(ApnPushNotificationScheduler.class);
private static final String PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX = "PENDING_APN";
private static final String PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX = "PENDING_BACKGROUND_APN";
private static final String LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX = "LAST_BACKGROUND_NOTIFICATION";
@VisibleForTesting
static final String NEXT_SLOT_TO_PROCESS_KEY = "pending_notification_next_slot";
private static final Counter delivered = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_delivered"));
private static final Counter sent = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_sent"));
private static final Counter retry = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_retry"));
private static final Counter evicted = Metrics.counter(name(ApnPushNotificationScheduler.class, "voip_evicted"));
private static final Counter backgroundNotificationScheduledCounter = Metrics.counter(name(ApnPushNotificationScheduler.class, "backgroundNotification", "scheduled"));
private static final Counter backgroundNotificationSentCounter = Metrics.counter(name(ApnPushNotificationScheduler.class, "backgroundNotification", "sent"));
private final APNSender apnSender;
private final AccountsManager accountsManager;
private final FaultTolerantRedisCluster pushSchedulingCluster;
private final Clock clock;
private final ClusterLuaScript getPendingVoipDestinationsScript;
private final ClusterLuaScript insertPendingVoipDestinationScript;
private final ClusterLuaScript removePendingVoipDestinationScript;
private final ClusterLuaScript scheduleBackgroundNotificationScript;
private final Thread[] workerThreads = new Thread[WORKER_THREAD_COUNT];
private static final int WORKER_THREAD_COUNT = 4;
@VisibleForTesting
static final Duration BACKGROUND_NOTIFICATION_PERIOD = Duration.ofMinutes(20);
private final AtomicBoolean running = new AtomicBoolean(false);
class NotificationWorker implements Runnable {
private static final int PAGE_SIZE = 128;
@Override
public void run() {
while (running.get()) {
try {
final long entriesProcessed = processNextSlot();
if (entriesProcessed == 0) {
Util.sleep(1000);
}
} catch (Exception e) {
logger.warn("Exception while operating", e);
}
}
}
private long processNextSlot() {
final int slot = (int) (pushSchedulingCluster.withCluster(connection ->
connection.sync().incr(NEXT_SLOT_TO_PROCESS_KEY)) % SlotHash.SLOT_COUNT);
return processRecurringVoipNotifications(slot) + processScheduledBackgroundNotifications(slot);
}
@VisibleForTesting
long processRecurringVoipNotifications(final int slot) {
List<String> pendingDestinations;
long entriesProcessed = 0;
do {
pendingDestinations = getPendingDestinationsForRecurringVoipNotifications(slot, PAGE_SIZE);
entriesProcessed += pendingDestinations.size();
for (final String destination : pendingDestinations) {
try {
getAccountAndDeviceFromPairString(destination).ifPresentOrElse(
accountAndDevice -> sendRecurringVoipNotification(accountAndDevice.first(), accountAndDevice.second()),
() -> removeRecurringVoipNotificationEntry(destination));
} catch (final IllegalArgumentException e) {
logger.warn("Failed to parse account/device pair: {}", destination, e);
}
}
} while (!pendingDestinations.isEmpty());
return entriesProcessed;
}
@VisibleForTesting
long processScheduledBackgroundNotifications(final int slot) {
final long currentTimeMillis = clock.millis();
final String queueKey = getPendingBackgroundNotificationQueueKey(slot);
final long processedBackgroundNotifications = pushSchedulingCluster.withCluster(connection -> {
List<String> destinations;
long offset = 0;
do {
destinations = connection.sync().zrangebyscore(queueKey, Range.create(0, currentTimeMillis), Limit.create(offset, PAGE_SIZE));
for (final String destination : destinations) {
try {
getAccountAndDeviceFromPairString(destination).ifPresent(accountAndDevice ->
sendBackgroundNotification(accountAndDevice.first(), accountAndDevice.second()));
} catch (final IllegalArgumentException e) {
logger.warn("Failed to parse account/device pair: {}", destination, e);
}
}
offset += destinations.size();
} while (destinations.size() == PAGE_SIZE);
return offset;
});
pushSchedulingCluster.useCluster(connection ->
connection.sync().zremrangebyscore(queueKey, Range.create(0, currentTimeMillis)));
return processedBackgroundNotifications;
}
}
public ApnPushNotificationScheduler(FaultTolerantRedisCluster pushSchedulingCluster,
APNSender apnSender,
AccountsManager accountsManager) throws IOException {
this(pushSchedulingCluster, apnSender, accountsManager, Clock.systemUTC());
}
@VisibleForTesting
ApnPushNotificationScheduler(FaultTolerantRedisCluster pushSchedulingCluster,
APNSender apnSender,
AccountsManager accountsManager,
Clock clock) throws IOException {
this.apnSender = apnSender;
this.accountsManager = accountsManager;
this.pushSchedulingCluster = pushSchedulingCluster;
this.clock = clock;
this.getPendingVoipDestinationsScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/get.lua", ScriptOutputType.MULTI);
this.insertPendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/insert.lua", ScriptOutputType.VALUE);
this.removePendingVoipDestinationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/remove.lua", ScriptOutputType.INTEGER);
this.scheduleBackgroundNotificationScript = ClusterLuaScript.fromResource(pushSchedulingCluster, "lua/apn/schedule_background_notification.lua", ScriptOutputType.VALUE);
for (int i = 0; i < this.workerThreads.length; i++) {
this.workerThreads[i] = new Thread(new NotificationWorker(), "ApnFallbackManagerWorker-" + i);
}
}
void scheduleRecurringVoipNotification(Account account, Device device) {
sent.increment();
insertRecurringVoipNotificationEntry(account, device, clock.millis() + (15 * 1000), (15 * 1000));
}
void scheduleBackgroundNotification(final Account account, final Device device) {
backgroundNotificationScheduledCounter.increment();
scheduleBackgroundNotificationScript.execute(
List.of(
getLastBackgroundNotificationTimestampKey(account, device),
getPendingBackgroundNotificationQueueKey(account, device)),
List.of(
getPairString(account, device),
String.valueOf(clock.millis()),
String.valueOf(BACKGROUND_NOTIFICATION_PERIOD.toMillis())));
}
public void cancelScheduledNotifications(Account account, Device device) {
if (removeRecurringVoipNotificationEntry(account, device)) {
delivered.increment();
}
pushSchedulingCluster.useCluster(connection ->
connection.sync().zrem(getPendingBackgroundNotificationQueueKey(account, device), getPairString(account, device)));
}
@Override
public synchronized void start() {
running.set(true);
for (final Thread workerThread : workerThreads) {
workerThread.start();
}
}
@Override
public synchronized void stop() throws InterruptedException {
running.set(false);
for (final Thread workerThread : workerThreads) {
workerThread.join();
}
}
private void sendRecurringVoipNotification(final Account account, final Device device) {
String apnId = device.getVoipApnId();
if (apnId == null) {
removeRecurringVoipNotificationEntry(account, device);
return;
}
long deviceLastSeen = device.getLastSeen();
if (deviceLastSeen < clock.millis() - TimeUnit.DAYS.toMillis(7)) {
evicted.increment();
removeRecurringVoipNotificationEntry(account, device);
return;
}
apnSender.sendNotification(new PushNotification(apnId, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true));
retry.increment();
}
@VisibleForTesting
void sendBackgroundNotification(final Account account, final Device device) {
if (StringUtils.isNotBlank(device.getApnId())) {
// It's okay for the "last notification" timestamp to expire after the "cooldown" period has elapsed; a missing
// timestamp and a timestamp older than the period are functionally equivalent.
pushSchedulingCluster.useCluster(connection -> connection.sync().set(
getLastBackgroundNotificationTimestampKey(account, device),
String.valueOf(clock.millis()), new SetArgs().ex(BACKGROUND_NOTIFICATION_PERIOD)));
apnSender.sendNotification(new PushNotification(device.getApnId(), PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, false));
backgroundNotificationSentCounter.increment();
}
}
@VisibleForTesting
static Optional<Pair<String, Long>> getSeparated(String encoded) {
try {
if (encoded == null) return Optional.empty();
String[] parts = encoded.split(":");
if (parts.length != 2) {
logger.warn("Got strange encoded number: " + encoded);
return Optional.empty();
}
return Optional.of(new Pair<>(parts[0], Long.parseLong(parts[1])));
} catch (NumberFormatException e) {
logger.warn("Badly formatted: " + encoded, e);
return Optional.empty();
}
}
@VisibleForTesting
static String getPairString(final Account account, final Device device) {
return account.getUuid() + ":" + device.getId();
}
@VisibleForTesting
Optional<Pair<Account, Device>> getAccountAndDeviceFromPairString(final String endpoint) {
try {
if (StringUtils.isBlank(endpoint)) {
throw new IllegalArgumentException("Endpoint must not be blank");
}
final String[] parts = endpoint.split(":");
if (parts.length != 2) {
throw new IllegalArgumentException("Could not parse endpoint string: " + endpoint);
}
final Optional<Account> maybeAccount = accountsManager.getByAccountIdentifier(UUID.fromString(parts[0]));
return maybeAccount.flatMap(account -> account.getDevice(Long.parseLong(parts[1])))
.map(device -> new Pair<>(maybeAccount.get(), device));
} catch (final NumberFormatException e) {
throw new IllegalArgumentException(e);
}
}
private boolean removeRecurringVoipNotificationEntry(Account account, Device device) {
return removeRecurringVoipNotificationEntry(getEndpointKey(account, device));
}
private boolean removeRecurringVoipNotificationEntry(final String endpoint) {
return (long) removePendingVoipDestinationScript.execute(
List.of(getPendingRecurringVoipNotificationQueueKey(endpoint), endpoint),
Collections.emptyList()) > 0;
}
@SuppressWarnings("unchecked")
@VisibleForTesting
List<String> getPendingDestinationsForRecurringVoipNotifications(final int slot, final int limit) {
return (List<String>) getPendingVoipDestinationsScript.execute(
List.of(getPendingRecurringVoipNotificationQueueKey(slot)),
List.of(String.valueOf(clock.millis()), String.valueOf(limit)));
}
private void insertRecurringVoipNotificationEntry(final Account account, final Device device, final long timestamp, final long interval) {
final String endpoint = getEndpointKey(account, device);
insertPendingVoipDestinationScript.execute(
List.of(getPendingRecurringVoipNotificationQueueKey(endpoint), endpoint),
List.of(String.valueOf(timestamp),
String.valueOf(interval),
account.getUuid().toString(),
String.valueOf(device.getId())));
}
@VisibleForTesting
static String getEndpointKey(final Account account, final Device device) {
return "apn_device::{" + account.getUuid() + "::" + device.getId() + "}";
}
private static String getPendingRecurringVoipNotificationQueueKey(final String endpoint) {
return getPendingRecurringVoipNotificationQueueKey(SlotHash.getSlot(endpoint));
}
private static String getPendingRecurringVoipNotificationQueueKey(final int slot) {
return PENDING_RECURRING_VOIP_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
@VisibleForTesting
static String getPendingBackgroundNotificationQueueKey(final Account account, final Device device) {
return getPendingBackgroundNotificationQueueKey(SlotHash.getSlot(getPairString(account, device)));
}
private static String getPendingBackgroundNotificationQueueKey(final int slot) {
return PENDING_BACKGROUND_NOTIFICATIONS_KEY_PREFIX + "::{" + RedisClusterUtil.getMinimalHashTag(slot) + "}";
}
private static String getLastBackgroundNotificationTimestampKey(final Account account, final Device device) {
return LAST_BACKGROUND_NOTIFICATION_TIMESTAMP_KEY_PREFIX + "::{" + getPairString(account, device) + "}";
}
@VisibleForTesting
Optional<Instant> getLastBackgroundNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().get(getLastBackgroundNotificationTimestampKey(account, device))))
.map(timestampString -> Instant.ofEpochMilli(Long.parseLong(timestampString)));
}
@VisibleForTesting
Optional<Instant> getNextScheduledBackgroundNotificationTimestamp(final Account account, final Device device) {
return Optional.ofNullable(
pushSchedulingCluster.withCluster(connection ->
connection.sync().zscore(getPendingBackgroundNotificationQueueKey(account, device),
getPairString(account, device))))
.map(timestamp -> Instant.ofEpochMilli(timestamp.longValue()));
}
}

View File

@@ -84,7 +84,7 @@ public class FcmSender implements PushNotificationSender {
Message.Builder builder = Message.builder()
.setToken(pushNotification.deviceToken())
.setAndroidConfig(AndroidConfig.builder()
.setPriority(AndroidConfig.Priority.HIGH)
.setPriority(pushNotification.urgent() ? AndroidConfig.Priority.HIGH : AndroidConfig.Priority.NORMAL)
.build());
final String key = switch (pushNotification.notificationType()) {

View File

@@ -9,7 +9,6 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import io.micrometer.core.instrument.Metrics;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -40,6 +39,7 @@ public class MessageSender {
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String CLIENT_ONLINE_TAG_NAME = "clientOnline";
private static final String URGENT_TAG_NAME = "urgent";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender";
public MessageSender(ClientPresenceManager clientPresenceManager,
MessagesManager messagesManager,
@@ -51,7 +51,7 @@ public class MessageSender {
this.pushLatencyManager = pushLatencyManager;
}
public void sendMessage(final Account account, final Device device, final Envelope message, boolean online)
public void sendMessage(final Account account, final Device device, final Envelope message, final boolean online)
throws NotPushRegisteredException {
final String channel;
@@ -84,7 +84,7 @@ public class MessageSender {
if (!clientPresent) {
try {
pushNotificationManager.sendNewMessageNotification(account, device.getId());
pushNotificationManager.sendNewMessageNotification(account, device.getId(), message.getUrgent());
final boolean useVoip = StringUtils.isNotBlank(device.getVoipApnId());
RedisOperation.unchecked(() -> pushLatencyManager.recordPushSent(account.getUuid(), device.getId(), useVoip));
@@ -100,7 +100,8 @@ public class MessageSender {
CHANNEL_TAG_NAME, channel,
EPHEMERAL_TAG_NAME, String.valueOf(online),
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()))
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceUuid()))
.increment();
}
}

View File

@@ -1,9 +1,9 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.MetricRegistry;
import com.fasterxml.jackson.annotation.JsonCreator;
@@ -28,6 +28,7 @@ import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -103,7 +104,7 @@ public class PushLatencyManager {
this.clock = clock;
}
public void recordPushSent(final UUID accountUuid, final long deviceId, final boolean isVoip) {
void recordPushSent(final UUID accountUuid, final long deviceId, final boolean isVoip) {
try {
final String recordJson = SystemMapper.getMapper().writeValueAsString(
new PushRecord(Instant.now(clock), isVoip ? PushType.VOIP : PushType.STANDARD));
@@ -118,7 +119,7 @@ public class PushLatencyManager {
}
}
public void recordQueueRead(final UUID accountUuid, final long deviceId, final String userAgentString) {
void recordQueueRead(final UUID accountUuid, final long deviceId, final String userAgentString) {
takePushRecord(accountUuid, deviceId).thenAccept(pushRecord -> {
if (pushRecord != null) {
final Duration latency = Duration.between(pushRecord.getTimestamp(), Instant.now());

View File

@@ -14,7 +14,8 @@ public record PushNotification(String deviceToken,
NotificationType notificationType,
@Nullable String data,
@Nullable Account destination,
@Nullable Device destinationDevice) {
@Nullable Device destinationDevice,
boolean urgent) {
public enum NotificationType {
NOTIFICATION, CHALLENGE, RATE_LIMIT_CHALLENGE

View File

@@ -5,27 +5,31 @@
package org.whispersystems.textsecuregcm.push;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class PushNotificationManager {
private final AccountsManager accountsManager;
private final APNSender apnSender;
private final FcmSender fcmSender;
private final ApnFallbackManager fallbackManager;
private final ApnPushNotificationScheduler apnPushNotificationScheduler;
private final PushLatencyManager pushLatencyManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private static final String SENT_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "sentPushNotification");
private static final String FAILED_NOTIFICATION_COUNTER_NAME = name(PushNotificationManager.class, "failedPushNotification");
@@ -35,24 +39,32 @@ public class PushNotificationManager {
public PushNotificationManager(final AccountsManager accountsManager,
final APNSender apnSender,
final FcmSender fcmSender,
final ApnFallbackManager fallbackManager) {
final ApnPushNotificationScheduler apnPushNotificationScheduler,
final PushLatencyManager pushLatencyManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.accountsManager = accountsManager;
this.apnSender = apnSender;
this.fcmSender = fcmSender;
this.fallbackManager = fallbackManager;
this.apnPushNotificationScheduler = apnPushNotificationScheduler;
this.pushLatencyManager = pushLatencyManager;
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
public void sendNewMessageNotification(final Account destination, final long destinationDeviceId) throws NotPushRegisteredException {
public void sendNewMessageNotification(final Account destination, final long destinationDeviceId, final boolean urgent) throws NotPushRegisteredException {
final Device device = destination.getDevice(destinationDeviceId).orElseThrow(NotPushRegisteredException::new);
final Pair<String, PushNotification.TokenType> tokenAndType = getToken(device);
final boolean effectiveUrgent =
dynamicConfigurationManager.getConfiguration().getPushNotificationConfiguration().isLowUrgencyEnabled() ?
urgent : true;
sendNotification(new PushNotification(tokenAndType.first(), tokenAndType.second(),
PushNotification.NotificationType.NOTIFICATION, null, destination, device));
PushNotification.NotificationType.NOTIFICATION, null, destination, device, effectiveUrgent));
}
public void sendRegistrationChallengeNotification(final String deviceToken, final PushNotification.TokenType tokenType, final String challengeToken) {
sendNotification(new PushNotification(deviceToken, tokenType, PushNotification.NotificationType.CHALLENGE, challengeToken, null, null));
sendNotification(new PushNotification(deviceToken, tokenType, PushNotification.NotificationType.CHALLENGE, challengeToken, null, null, true));
}
public void sendRateLimitChallengeNotification(final Account destination, final String challengeToken)
@@ -62,7 +74,12 @@ public class PushNotificationManager {
final Pair<String, PushNotification.TokenType> tokenAndType = getToken(device);
sendNotification(new PushNotification(tokenAndType.first(), tokenAndType.second(),
PushNotification.NotificationType.RATE_LIMIT_CHALLENGE, challengeToken, destination, device));
PushNotification.NotificationType.RATE_LIMIT_CHALLENGE, challengeToken, destination, device, true));
}
public void handleMessagesRetrieved(final Account account, final Device device, final String userAgent) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(account.getUuid(), device.getId(), userAgent));
RedisOperation.unchecked(() -> apnPushNotificationScheduler.cancelScheduledNotifications(account, device));
}
@VisibleForTesting
@@ -84,44 +101,55 @@ public class PushNotificationManager {
@VisibleForTesting
void sendNotification(final PushNotification pushNotification) {
final PushNotificationSender sender = switch (pushNotification.tokenType()) {
case FCM -> fcmSender;
case APN, APN_VOIP -> apnSender;
};
if (pushNotification.tokenType() == PushNotification.TokenType.APN && !pushNotification.urgent()) {
// APNs imposes a per-device limit on background push notifications; schedule a notification for some time in the
// future (possibly even now!) rather than sending a notification directly
apnPushNotificationScheduler.scheduleBackgroundNotification(pushNotification.destination(),
pushNotification.destinationDevice());
} else {
final PushNotificationSender sender = switch (pushNotification.tokenType()) {
case FCM -> fcmSender;
case APN, APN_VOIP -> apnSender;
};
sender.sendNotification(pushNotification).whenComplete((result, throwable) -> {
if (throwable == null) {
Tags tags = Tags.of("tokenType", pushNotification.tokenType().name(),
"notificationType", pushNotification.notificationType().name(),
"accepted", String.valueOf(result.accepted()),
"unregistered", String.valueOf(result.unregistered()));
sender.sendNotification(pushNotification).whenComplete((result, throwable) -> {
if (throwable == null) {
Tags tags = Tags.of("tokenType", pushNotification.tokenType().name(),
"notificationType", pushNotification.notificationType().name(),
"urgent", String.valueOf(pushNotification.urgent()),
"accepted", String.valueOf(result.accepted()),
"unregistered", String.valueOf(result.unregistered()));
if (StringUtils.isNotBlank(result.errorCode())) {
tags = tags.and("errorCode", result.errorCode());
if (StringUtils.isNotBlank(result.errorCode())) {
tags = tags.and("errorCode", result.errorCode());
}
Metrics.counter(SENT_NOTIFICATION_COUNTER_NAME, tags).increment();
if (result.unregistered() && pushNotification.destination() != null
&& pushNotification.destinationDevice() != null) {
handleDeviceUnregistered(pushNotification.destination(), pushNotification.destinationDevice());
}
if (result.accepted() &&
pushNotification.tokenType() == PushNotification.TokenType.APN_VOIP &&
pushNotification.notificationType() == PushNotification.NotificationType.NOTIFICATION &&
pushNotification.destination() != null &&
pushNotification.destinationDevice() != null) {
RedisOperation.unchecked(
() -> apnPushNotificationScheduler.scheduleRecurringVoipNotification(pushNotification.destination(),
pushNotification.destinationDevice()));
}
} else {
logger.debug("Failed to deliver {} push notification to {} ({})",
pushNotification.notificationType(), pushNotification.deviceToken(), pushNotification.tokenType(),
throwable);
Metrics.counter(FAILED_NOTIFICATION_COUNTER_NAME, "cause", throwable.getClass().getSimpleName()).increment();
}
Metrics.counter(SENT_NOTIFICATION_COUNTER_NAME, tags).increment();
if (result.unregistered() && pushNotification.destination() != null && pushNotification.destinationDevice() != null) {
handleDeviceUnregistered(pushNotification.destination(), pushNotification.destinationDevice());
}
if (result.accepted() &&
pushNotification.tokenType() == PushNotification.TokenType.APN_VOIP &&
pushNotification.notificationType() == PushNotification.NotificationType.NOTIFICATION &&
pushNotification.destination() != null &&
pushNotification.destinationDevice() != null) {
RedisOperation.unchecked(() -> fallbackManager.schedule(pushNotification.destination(),
pushNotification.destinationDevice()));
}
} else {
logger.debug("Failed to deliver {} push notification to {} ({})",
pushNotification.notificationType(), pushNotification.deviceToken(), pushNotification.tokenType(), throwable);
Metrics.counter(FAILED_NOTIFICATION_COUNTER_NAME, "cause", throwable.getClass().getSimpleName()).increment();
}
});
});
}
}
private void handleDeviceUnregistered(final Account account, final Device device) {
@@ -131,7 +159,7 @@ public class PushNotificationManager {
d.setUninstalledFeedbackTimestamp(Util.todayInMillis()));
}
} else {
RedisOperation.unchecked(() -> fallbackManager.cancel(account, device));
RedisOperation.unchecked(() -> apnPushNotificationScheduler.cancelScheduledNotifications(account, device));
}
}
}

View File

@@ -52,7 +52,8 @@ public class ReceiptSender {
.setSourceDevice((int) sourceDeviceId)
.setDestinationUuid(destinationUuid.toString())
.setTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT);
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT)
.setUrgent(false);
return CompletableFuture.runAsync(() -> {
for (final Device destinationDevice : destinationAccount.getDevices()) {

View File

@@ -9,6 +9,7 @@ import java.util.List;
import java.util.Locale.LanguageRange;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class SmsSender {
@@ -32,7 +33,8 @@ public class SmsSender {
twilioSender.deliverVoxVerification(destination, verificationCode, languageRanges);
}
public CompletableFuture<Optional<String>> deliverSmsVerificationWithTwilioVerify(String destination, Optional<String> clientType,
public CompletableFuture<Optional<String>> deliverSmsVerificationWithTwilioVerify(String destination,
Optional<String> clientType,
String verificationCode, List<LanguageRange> languageRanges) {
// Fix up mexico numbers to 'mobile' format just for SMS delivery.
if (destination.startsWith("+52") && !destination.startsWith("+521")) {
@@ -42,13 +44,14 @@ public class SmsSender {
return twilioSender.deliverSmsVerificationWithVerify(destination, clientType, verificationCode, languageRanges);
}
public CompletableFuture<Optional<String>> deliverVoxVerificationWithTwilioVerify(String destination, String verificationCode,
public CompletableFuture<Optional<String>> deliverVoxVerificationWithTwilioVerify(String destination,
String verificationCode,
List<LanguageRange> languageRanges) {
return twilioSender.deliverVoxVerificationWithVerify(destination, verificationCode, languageRanges);
return twilioSender.deliverVoxVerificationWithVerify(destination, verificationCode, languageRanges);
}
public void reportVerificationSucceeded(String verificationSid) {
twilioSender.reportVerificationSucceeded(verificationSid);
public void reportVerificationSucceeded(String verificationSid, @Nullable String userAgent, String context) {
twilioSender.reportVerificationSucceeded(verificationSid, userAgent, context);
}
}

View File

@@ -133,7 +133,8 @@ public class TwilioSmsSender {
smsMeter.mark();
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(this::parseResponse).handle(this::processResponse);
.thenApply(this::parseResponse)
.handle((response, throwable) -> processResponse(response, throwable, destination));
}
private String getBodyFormatString(@Nonnull String destination, @Nullable String clientType) {
@@ -198,26 +199,29 @@ public class TwilioSmsSender {
voxMeter.mark();
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(this::parseResponse)
.handle(this::processResponse);
.thenApply(this::parseResponse)
.handle((response, throwable) -> processResponse(response, throwable, destination));
}
private String getRandom(Random random, List<String> elements) {
return elements.get(random.nextInt(elements.size()));
}
private boolean processResponse(TwilioResponse response, Throwable throwable) {
private boolean processResponse(TwilioResponse response, Throwable throwable, String destination) {
if (response != null && response.isSuccess()) {
priceMeter.mark((long) (response.successResponse.price * 1000));
return true;
} else if (response != null && response.isFailure()) {
logger.debug("Twilio request failed: " + response.failureResponse.status + "(code " + response.failureResponse.code + "), " + response.failureResponse.message);
Metrics.counter(FAILED_REQUEST_COUNTER_NAME,
SERVICE_NAME_TAG, "classic",
STATUS_CODE_TAG_NAME, String.valueOf(response.failureResponse.status),
ERROR_CODE_TAG_NAME, String.valueOf(response.failureResponse.code)).increment();
logger.info("Failed with code={}, country={}",
response.failureResponse.code,
Util.getCountryCode(destination));
return false;
} else if (throwable != null) {
logger.info("Twilio request failed", throwable);
@@ -246,23 +250,27 @@ public class TwilioSmsSender {
}
}
public CompletableFuture<Optional<String>> deliverSmsVerificationWithVerify(String destination, Optional<String> clientType, String verificationCode, List<LanguageRange> languageRanges) {
public CompletableFuture<Optional<String>> deliverSmsVerificationWithVerify(String destination,
Optional<String> clientType, String verificationCode, List<LanguageRange> languageRanges) {
smsMeter.mark();
return twilioVerifySender.deliverSmsVerificationWithVerify(destination, clientType, verificationCode, languageRanges);
return twilioVerifySender.deliverSmsVerificationWithVerify(destination, clientType, verificationCode,
languageRanges);
}
public CompletableFuture<Optional<String>> deliverVoxVerificationWithVerify(String destination, String verificationCode, List<LanguageRange> languageRanges) {
public CompletableFuture<Optional<String>> deliverVoxVerificationWithVerify(String destination,
String verificationCode, List<LanguageRange> languageRanges) {
voxMeter.mark();
return twilioVerifySender.deliverVoxVerificationWithVerify(destination, verificationCode, languageRanges);
}
public CompletableFuture<Boolean> reportVerificationSucceeded(String verificationSid) {
public CompletableFuture<Boolean> reportVerificationSucceeded(String verificationSid, @Nullable String userAgent,
String context) {
return twilioVerifySender.reportVerificationSucceeded(verificationSid);
return twilioVerifySender.reportVerificationSucceeded(verificationSid, userAgent, context);
}
public static class TwilioResponse {

View File

@@ -1,7 +1,12 @@
package org.whispersystems.textsecuregcm.sms;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpRequest;
@@ -14,13 +19,14 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import javax.validation.constraints.NotEmpty;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.TwilioConfiguration;
import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient;
import org.whispersystems.textsecuregcm.http.FormDataBodyPublisher;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
@@ -29,6 +35,13 @@ class TwilioVerifySender {
private static final Logger logger = LoggerFactory.getLogger(TwilioVerifySender.class);
private static final String VERIFICATION_SUCCEEDED_RESPONSE_COUNTER_NAME = name(TwilioVerifySender.class,
"verificationSucceeded");
private static final String CONTEXT_TAG_NAME = "context";
private static final String STATUS_CODE_TAG_NAME = "statusCode";
private static final String ERROR_CODE_TAG_NAME = "errorCode";
static final Set<String> TWILIO_VERIFY_LANGUAGES = Set.of(
"af",
"ar",
@@ -99,7 +112,7 @@ class TwilioVerifySender {
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(this::parseResponse)
.handle(this::extractVerifySid);
.handle((response, throwable) -> extractVerifySid(response, throwable, destination));
}
private Optional<String> findBestLocale(List<LanguageRange> priorityList) {
@@ -124,18 +137,19 @@ class TwilioVerifySender {
}
}
CompletableFuture<Optional<String>> deliverVoxVerificationWithVerify(String destination, String verificationCode,
List<LanguageRange> languageRanges) {
CompletableFuture<Optional<String>> deliverVoxVerificationWithVerify(String destination,
String verificationCode, List<LanguageRange> languageRanges) {
HttpRequest request = buildVerifyRequest("call", destination, verificationCode, findBestLocale(languageRanges),
Optional.empty());
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(this::parseResponse)
.handle(this::extractVerifySid);
.handle((response, throwable) -> extractVerifySid(response, throwable, destination));
}
private Optional<String> extractVerifySid(TwilioVerifyResponse twilioVerifyResponse, Throwable throwable) {
private Optional<String> extractVerifySid(TwilioVerifyResponse twilioVerifyResponse, Throwable throwable,
String destination) {
if (throwable != null) {
logger.warn("Failed to send Twilio request", throwable);
@@ -148,6 +162,10 @@ class TwilioVerifySender {
TwilioSmsSender.STATUS_CODE_TAG_NAME, String.valueOf(twilioVerifyResponse.failureResponse.status),
TwilioSmsSender.ERROR_CODE_TAG_NAME, String.valueOf(twilioVerifyResponse.failureResponse.code)).increment();
logger.info("Failed with code={}, country={}",
twilioVerifyResponse.failureResponse.code,
Util.getCountryCode(destination));
return Optional.empty();
}
@@ -170,11 +188,13 @@ class TwilioVerifySender {
.uri(verifyServiceUri)
.POST(FormDataBodyPublisher.of(requestParameters))
.header("Content-Type", "application/x-www-form-urlencoded")
.header("Authorization", "Basic " + Base64.getEncoder().encodeToString((accountId + ":" + accountToken).getBytes()))
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString((accountId + ":" + accountToken).getBytes()))
.build();
}
public CompletableFuture<Boolean> reportVerificationSucceeded(String verificationSid) {
public CompletableFuture<Boolean> reportVerificationSucceeded(String verificationSid, @Nullable String userAgent,
String context) {
final Map<String, String> requestParameters = new HashMap<>();
requestParameters.put("Status", "approved");
@@ -183,14 +203,47 @@ class TwilioVerifySender {
.uri(verifyApprovalBaseUri.resolve(verificationSid))
.POST(FormDataBodyPublisher.of(requestParameters))
.header("Content-Type", "application/x-www-form-urlencoded")
.header("Authorization", "Basic " + Base64.getEncoder().encodeToString((accountId + ":" + accountToken).getBytes()))
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString((accountId + ":" + accountToken).getBytes()))
.build();
return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(this::parseResponse)
.handle((response, throwable) -> throwable == null
&& response.isSuccess()
&& "approved".equals(response.successResponse.getStatus()));
.handle((response, throwable) -> processVerificationSucceededResponse(response, throwable, userAgent, context));
}
private boolean processVerificationSucceededResponse(@Nullable final TwilioVerifyResponse response,
@Nullable final Throwable throwable,
final String userAgent,
final String context) {
if (throwable == null) {
assert response != null;
final Tags tags = Tags.of(Tag.of(CONTEXT_TAG_NAME, context), UserAgentTagUtil.getPlatformTag(userAgent));
if (response.isSuccess() && "approved".equals(response.successResponse.getStatus())) {
// the other possible values of `status` are `pending` or `canceled`, but these can never happen in a response
// to this POST, so we dont consider them
Metrics.counter(VERIFICATION_SUCCEEDED_RESPONSE_COUNTER_NAME, tags)
.increment();
return true;
}
// at this point, response.isFailure() == true
Metrics.counter(
VERIFICATION_SUCCEEDED_RESPONSE_COUNTER_NAME,
Tags.of(ERROR_CODE_TAG_NAME, String.valueOf(response.failureResponse.code),
STATUS_CODE_TAG_NAME, String.valueOf(response.failureResponse.status))
.and(tags))
.increment();
} else {
logger.warn("Failed to send verification succeeded", throwable);
}
return false;
}
public static class TwilioVerifyResponse {

View File

@@ -19,7 +19,6 @@ public class AccountDatabaseCrawlerCache {
public static final String GENERAL_PURPOSE_PREFIX = "";
public static final String DIRECTORY_RECONCILER_PREFIX = "directory-reconciler";
public static final String ACCOUNT_CLEANER_PREFIX = "account-cleaner";
public static final String USERNAME_CLEANER_PREFIX = "username-cleaner";
private static final String ACTIVE_WORKER_KEY = "account_database_crawler_cache_active_worker";
private static final String LAST_UUID_KEY = "account_database_crawler_cache_last_uuid";

View File

@@ -9,24 +9,19 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
@@ -35,7 +30,6 @@ import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicUakMigrationConfiguration;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
@@ -47,7 +41,6 @@ import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedExce
import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughputExceededException;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
@@ -78,7 +71,6 @@ public class Accounts extends AbstractDynamoDbStore {
// unidentified access key; byte[] or null
static final String ATTR_UAK = "UAK";
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final DynamoDbClient client;
private final DynamoDbAsyncClient asyncClient;
@@ -101,11 +93,6 @@ public class Accounts extends AbstractDynamoDbStore {
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(Accounts.class, "getAllFrom"));
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(Accounts.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(Accounts.class, "delete"));
private static final Timer NORMALIZE_ITEM_TIMER = Metrics.timer(name(Accounts.class, "normalizeItem"));
private static final Counter UAK_NORMALIZE_SUCCESS_COUNT = Metrics.counter(name(Accounts.class, "normalizeUakSuccess"));
private static final String UAK_NORMALIZE_ERROR_NAME = name(Accounts.class, "normalizeUakError");
private static final String UAK_NORMALIZE_FAILURE_REASON_TAG_NAME = "reason";
private static final Logger log = LoggerFactory.getLogger(Accounts.class);
@@ -116,7 +103,6 @@ public class Accounts extends AbstractDynamoDbStore {
final int scanPageSize) {
super(client);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.client = client;
this.asyncClient = asyncClient;
this.phoneNumberConstraintTableName = phoneNumberConstraintTableName;
@@ -345,8 +331,15 @@ public class Accounts extends AbstractDynamoDbStore {
});
}
/**
* Set the account username
*
* @param account to update
* @param username believed to be available
* @throws ContestedOptimisticLockException if the account has been updated or the username taken by someone else
*/
public void setUsername(final Account account, final String username)
throws ContestedOptimisticLockException, UsernameNotAvailableException {
throws ContestedOptimisticLockException {
final long startNanos = System.nanoTime();
final Optional<String> maybeOriginalUsername = account.getUsername();
@@ -405,18 +398,14 @@ public class Accounts extends AbstractDynamoDbStore {
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
} catch (final TransactionCanceledException e) {
if ("ConditionalCheckFailed".equals(e.cancellationReasons().get(0).code())) {
throw new UsernameNotAvailableException();
} else if ("ConditionalCheckFailed".equals(e.cancellationReasons().get(1).code())) {
if (e.cancellationReasons().stream().map(CancellationReason::code).anyMatch("ConditionalCheckFailed"::equals)) {
throw new ContestedOptimisticLockException();
}
throw e;
} finally {
if (!succeeded) {
account.setUsername(maybeOriginalUsername.orElse(null));
}
SET_USERNAME_TIMER.record(System.nanoTime() - startNanos, TimeUnit.NANOSECONDS);
}
}
@@ -563,6 +552,14 @@ public class Accounts extends AbstractDynamoDbStore {
}
}
public boolean usernameAvailable(final String username) {
final GetItemResponse response = client.getItem(GetItemRequest.builder()
.tableName(usernamesConstraintTableName)
.key(Map.of(ATTR_USERNAME, AttributeValues.fromString(username)))
.build());
return !response.hasItem();
}
public Optional<Account> getByE164(String number) {
return GET_BY_NUMBER_TIMER.record(() -> {
@@ -686,82 +683,10 @@ public class Accounts extends AbstractDynamoDbStore {
return toRecord.get().whenComplete((ignoreT, ignoreE) -> timer.record(Duration.between(start, Instant.now())));
}
private List<Account> normalizeIfRequired(final List<Map<String, AttributeValue>> items) {
// The UAK top-level attribute may not exist on older records,
// if it is absent and there is a UAK in the account blob we'll
// add the UAK as a top-level attribute
// TODO: Can eliminate this once all uaks exist as top-level attributes
final List<Account> allAccounts = new ArrayList<>();
final List<Account> accountsToNormalize = new ArrayList<>();
for (Map<String, AttributeValue> item : items) {
final Account account = fromItem(item);
allAccounts.add(account);
boolean hasAttrUak = item.containsKey(ATTR_UAK);
if (!hasAttrUak && account.getUnidentifiedAccessKey().isPresent()) {
// the top level uak attribute doesn't exist, but there's a uak in the account
accountsToNormalize.add(account);
} else if (hasAttrUak && account.getUnidentifiedAccessKey().isPresent()) {
final AttributeValue attr = item.get(ATTR_UAK);
final byte[] nestedUak = account.getUnidentifiedAccessKey().get();
if (!Arrays.equals(attr.b().asByteArray(), nestedUak)) {
log.warn("Discovered mismatch between attribute UAK data UAK, normalizing");
accountsToNormalize.add(account);
}
}
}
final DynamicUakMigrationConfiguration currentConfig = this.dynamicConfigurationManager.getConfiguration().getUakMigrationConfiguration();
if (!currentConfig.isEnabled()) {
log.debug("Account normalization is disabled, skipping normalization for {} accounts", accountsToNormalize.size());
return allAccounts;
}
for (List<Account> accounts : Lists.partition(accountsToNormalize, currentConfig.getMaxOutstandingNormalizes())) {
try {
final CompletableFuture<?>[] accountFutures = accounts.stream()
.map(account -> record(NORMALIZE_ITEM_TIMER,
() -> this.updateAsync(account).whenComplete((result, throwable) -> {
if (throwable == null) {
UAK_NORMALIZE_SUCCESS_COUNT.increment();
return;
}
throwable = unwrap(throwable);
if (throwable instanceof ContestedOptimisticLockException) {
// Could succeed on retry, but just backoff since this is a housekeeping operation
Metrics.counter(UAK_NORMALIZE_ERROR_NAME,
Tags.of(UAK_NORMALIZE_FAILURE_REASON_TAG_NAME, "ContestedOptimisticLock")).increment();
} else if (throwable instanceof ProvisionedThroughputExceededException) {
Metrics.counter(UAK_NORMALIZE_ERROR_NAME,
Tags.of(UAK_NORMALIZE_FAILURE_REASON_TAG_NAME, "ProvisionedThroughPutExceeded"))
.increment();
} else {
log.warn("Failed to normalize account, skipping", throwable);
Metrics.counter(UAK_NORMALIZE_ERROR_NAME,
Tags.of(UAK_NORMALIZE_FAILURE_REASON_TAG_NAME, "unknown"))
.increment();
}
})).toCompletableFuture()).toArray(CompletableFuture[]::new);
// wait for a futures in batch to complete
CompletableFuture
.allOf(accountFutures)
// exceptions handled in individual futures
.exceptionally(e -> null)
.join();
} catch (Exception e) {
log.warn("Failed to update batch of {} accounts, skipping", accounts.size(), e);
}
}
return allAccounts;
}
private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) {
scanRequestBuilder.tableName(accountsTableName);
final List<Map<String, AttributeValue>> items = timer.record(() -> scan(scanRequestBuilder.build(), maxCount));
final List<Account> accounts = normalizeIfRequired(items);
final List<Account> accounts = items.stream().map(Accounts::fromItem).toList();
return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null);
}

View File

@@ -12,6 +12,8 @@ import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.micrometer.core.instrument.Metrics;
@@ -37,6 +39,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
@@ -46,7 +49,7 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameValidator;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import org.whispersystems.textsecuregcm.util.Util;
public class AccountsManager {
@@ -71,6 +74,9 @@ public class AccountsManager {
private static final String COUNTRY_CODE_TAG_NAME = "country";
private static final String DELETION_REASON_TAG_NAME = "reason";
@VisibleForTesting
public static final String USERNAME_EXPERIMENT_NAME = "usernames";
private final Logger logger = LoggerFactory.getLogger(AccountsManager.class);
private final Accounts accounts;
@@ -86,7 +92,9 @@ public class AccountsManager {
private final SecureStorageClient secureStorageClient;
private final SecureBackupClient secureBackupClient;
private final ClientPresenceManager clientPresenceManager;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final Clock clock;
private final UsernameGenerator usernameGenerator;
private static final ObjectMapper mapper = SystemMapper.getMapper();
@@ -126,6 +134,8 @@ public class AccountsManager {
final SecureStorageClient secureStorageClient,
final SecureBackupClient secureBackupClient,
final ClientPresenceManager clientPresenceManager,
final UsernameGenerator usernameGenerator,
final ExperimentEnrollmentManager experimentEnrollmentManager,
final Clock clock) {
this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers;
@@ -140,6 +150,8 @@ public class AccountsManager {
this.secureBackupClient = secureBackupClient;
this.clientPresenceManager = clientPresenceManager;
this.reservedUsernames = reservedUsernames;
this.usernameGenerator = usernameGenerator;
this.experimentEnrollmentManager = experimentEnrollmentManager;
this.clock = Objects.requireNonNull(clock);
}
@@ -276,32 +288,27 @@ public class AccountsManager {
final Account numberChangedAccount;
try {
numberChangedAccount = updateWithRetries(
account,
a -> {
//noinspection ConstantConditions
if (pniSignedPreKeys != null && pniRegistrationIds != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
numberChangedAccount = updateWithRetries(
account,
a -> {
//noinspection ConstantConditions
if (pniSignedPreKeys != null && pniRegistrationIds != null) {
pniSignedPreKeys.forEach((deviceId, signedPreKey) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentitySignedPreKey(signedPreKey)));
pniRegistrationIds.forEach((deviceId, registrationId) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
}
pniRegistrationIds.forEach((deviceId, registrationId) ->
a.getDevice(deviceId).ifPresent(device -> device.setPhoneNumberIdentityRegistrationId(registrationId)));
}
if (pniIdentityKey != null) {
a.setPhoneNumberIdentityKey(pniIdentityKey);
}
if (pniIdentityKey != null) {
a.setPhoneNumberIdentityKey(pniIdentityKey);
}
return true;
},
a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
} catch (UsernameNotAvailableException e) {
// This should never happen when changing numbers
throw new RuntimeException(e);
}
return true;
},
a -> accounts.changeNumber(a, number, phoneNumberIdentifier),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
updatedAccount.set(numberChangedAccount);
directoryQueue.changePhoneNumber(numberChangedAccount, originalNumber, number);
@@ -315,23 +322,31 @@ public class AccountsManager {
return updatedAccount.get();
}
public Account setUsername(final Account account, final String username) throws UsernameNotAvailableException {
final String canonicalUsername = UsernameValidator.getCanonicalUsername(username);
if (account.getUsername().map(canonicalUsername::equals).orElse(false)) {
return account;
public Account setUsername(final Account account, final String requestedNickname, final @Nullable String expectedOldUsername) throws UsernameNotAvailableException {
if (!experimentEnrollmentManager.isEnrolled(account.getUuid(), USERNAME_EXPERIMENT_NAME)) {
throw new UsernameNotAvailableException();
}
if (reservedUsernames.isReserved(canonicalUsername, account.getUuid())) {
if (reservedUsernames.isReserved(requestedNickname, account.getUuid())) {
throw new UsernameNotAvailableException();
}
final Optional<String> currentUsername = account.getUsername();
final Optional<String> currentNickname = currentUsername.map(UsernameGenerator::extractNickname);
if (currentNickname.map(requestedNickname::equals).orElse(false) && !Objects.equals(expectedOldUsername, currentUsername.orElse(null))) {
// The requested nickname matches what the server already has, and the
// client provided the wrong existing username. Treat this as a replayed
// request, assuming that the client has previously succeeded
return account;
}
redisDelete(account);
return updateWithRetries(
return failableUpdateWithRetries(
account,
a -> true,
a -> accounts.setUsername(a, canonicalUsername),
// In the future, this may also check for any forbidden discriminators
a -> accounts.setUsername(a, usernameGenerator.generateAvailableUsername(requestedNickname, accounts::usernameAvailable)),
() -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(),
AccountChangeValidator.USERNAME_CHANGE_VALIDATOR);
}
@@ -339,31 +354,20 @@ public class AccountsManager {
public Account clearUsername(final Account account) {
redisDelete(account);
try {
return updateWithRetries(
account,
a -> true,
accounts::clearUsername,
() -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(),
AccountChangeValidator.USERNAME_CHANGE_VALIDATOR);
} catch (UsernameNotAvailableException e) {
// This should never happen
throw new RuntimeException(e);
}
return updateWithRetries(
account,
a -> true,
accounts::clearUsername,
() -> accounts.getByAccountIdentifier(account.getUuid()).orElseThrow(),
AccountChangeValidator.USERNAME_CHANGE_VALIDATOR);
}
public Account update(Account account, Consumer<Account> updater) {
try {
return update(account, a -> {
updater.accept(a);
// assume that all updaters passed to the public method actually modify the account
return true;
});
} catch (UsernameNotAvailableException e) {
// This should never happen for general-purpose, public account updates
throw new RuntimeException(e);
}
return update(account, a -> {
updater.accept(a);
// assume that all updaters passed to the public method actually modify the account
return true;
});
}
/**
@@ -371,34 +375,38 @@ public class AccountsManager {
* redundant updates of {@code device.lastSeen}
*/
public Account updateDeviceLastSeen(Account account, Device device, final long lastSeen) {
return update(account, a -> {
try {
return update(account, a -> {
final Optional<Device> maybeDevice = a.getDevice(device.getId());
final Optional<Device> maybeDevice = a.getDevice(device.getId());
return maybeDevice.map(d -> {
if (d.getLastSeen() >= lastSeen) {
return false;
}
return maybeDevice.map(d -> {
if (d.getLastSeen() >= lastSeen) {
return false;
}
d.setLastSeen(lastSeen);
d.setLastSeen(lastSeen);
return true;
return true;
}).orElse(false);
});
}
}).orElse(false);
});
} catch (UsernameNotAvailableException e) {
// This should never happen when updating last-seen timestamps
throw new RuntimeException(e);
}
public Account updateDeviceAuthentication(final Account account, final Device device, final AuthenticationCredentials credentials) {
Preconditions.checkArgument(credentials.getVersion() == AuthenticationCredentials.CURRENT_VERSION);
return updateDevice(account, device.getId(), new Consumer<Device>() {
@Override
public void accept(final Device device) {
device.setAuthenticationCredentials(credentials);
}
});
}
/**
* @param account account to update
* @param updater must return {@code true} if the account was actually updated
*/
private Account update(Account account, Function<Account, Boolean> updater) throws UsernameNotAvailableException {
private Account update(Account account, Function<Account, Boolean> updater) {
final boolean wasVisibleBeforeUpdate = account.shouldBeVisibleInDirectory();
@@ -429,6 +437,19 @@ public class AccountsManager {
}
private Account updateWithRetries(Account account,
final Function<Account, Boolean> updater,
final Consumer<Account> persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) {
try {
return failableUpdateWithRetries(account, updater, persister::accept, retriever, changeValidator);
} catch (UsernameNotAvailableException e) {
// not possible
throw new IllegalStateException(e);
}
}
private Account failableUpdateWithRetries(Account account,
final Function<Account, Boolean> updater,
final AccountPersister persister,
final Supplier<Account> retriever,
@@ -482,16 +503,11 @@ public class AccountsManager {
}
public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
try {
return update(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
// assume that all updaters passed to the public method actually modify the device
return true;
});
} catch (UsernameNotAvailableException e) {
// This should never happen when updating devices
throw new RuntimeException(e);
}
return update(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
// assume that all updaters passed to the public method actually modify the device
return true;
});
}
public Optional<Account> getByE164(String number) {
@@ -522,12 +538,10 @@ public class AccountsManager {
public Optional<Account> getByUsername(final String username) {
try (final Timer.Context ignored = getByUsernameTimer.time()) {
final String canonicalUsername = UsernameValidator.getCanonicalUsername(username);
Optional<Account> account = redisGetByUsername(canonicalUsername);
Optional<Account> account = redisGetByUsername(username);
if (account.isEmpty()) {
account = accounts.getByUsername(canonicalUsername);
account = accounts.getByUsername(username);
account.ifPresent(this::redisSet);
}
@@ -604,6 +618,10 @@ public class AccountsManager {
clientPresenceManager.disconnectPresence(account.getUuid(), device.getId())));
}
private String getUsernameAccountMapKey(String key) {
return "UAccountMap::" + key;
}
private String getAccountMapKey(String key) {
return "AccountMap::" + key;
}
@@ -624,7 +642,7 @@ public class AccountsManager {
commands.setex(getAccountEntityKey(account.getUuid()), CACHE_TTL_SECONDS, accountJson);
account.getUsername().ifPresent(username ->
commands.setex(getAccountMapKey(username), CACHE_TTL_SECONDS, account.getUuid().toString()));
commands.setex(getUsernameAccountMapKey(username), CACHE_TTL_SECONDS, account.getUuid().toString()));
});
} catch (JsonProcessingException e) {
throw new IllegalStateException(e);
@@ -632,20 +650,20 @@ public class AccountsManager {
}
private Optional<Account> redisGetByPhoneNumberIdentifier(UUID uuid) {
return redisGetBySecondaryKey(uuid.toString(), redisPniGetTimer);
return redisGetBySecondaryKey(getAccountMapKey(uuid.toString()), redisPniGetTimer);
}
private Optional<Account> redisGetByE164(String e164) {
return redisGetBySecondaryKey(e164, redisNumberGetTimer);
return redisGetBySecondaryKey(getAccountMapKey(e164), redisNumberGetTimer);
}
private Optional<Account> redisGetByUsername(String username) {
return redisGetBySecondaryKey(username, redisUsernameGetTimer);
return redisGetBySecondaryKey(getUsernameAccountMapKey(username), redisUsernameGetTimer);
}
private Optional<Account> redisGetBySecondaryKey(String secondaryKey, Timer timer) {
try (Timer.Context ignored = timer.time()) {
final String uuid = cacheCluster.withCluster(connection -> connection.sync().get(getAccountMapKey(secondaryKey)));
final String uuid = cacheCluster.withCluster(connection -> connection.sync().get(secondaryKey));
if (uuid != null) return redisGetByAccountIdentifier(UUID.fromString(uuid));
else return Optional.empty();
@@ -691,7 +709,7 @@ public class AccountsManager {
getAccountMapKey(account.getPhoneNumberIdentifier().toString()),
getAccountEntityKey(account.getUuid()));
account.getUsername().ifPresent(username -> connection.sync().del(getAccountMapKey(username)));
account.getUsername().ifPresent(username -> connection.sync().del(getUsernameAccountMapKey(username)));
});
}
}

View File

@@ -110,7 +110,9 @@ public class ChangeNumberManager {
.setSourceUuid(sourceAndDestinationAccount.getUuid().toString())
.setSourceDevice((int) Device.MASTER_ID)
.setUpdatedPni(sourceAndDestinationAccount.getPhoneNumberIdentifier().toString())
.setUrgent(true)
.build();
messageSender.sendMessage(sourceAndDestinationAccount, destinationDevice.get(), envelope, false);
} catch (NotPushRegisteredException e) {
logger.debug("Not registered", e);

View File

@@ -1,44 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.SharedMetricRegistries;
import org.jdbi.v3.core.Jdbi;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.util.CircuitBreakerUtil;
import org.whispersystems.textsecuregcm.util.Constants;
import java.util.function.Consumer;
import java.util.function.Function;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
public class FaultTolerantDatabase {
private final Jdbi database;
private final CircuitBreaker circuitBreaker;
public FaultTolerantDatabase(String name, Jdbi database, CircuitBreakerConfiguration circuitBreakerConfiguration) {
this.database = database;
this.circuitBreaker = CircuitBreaker.of(name, circuitBreakerConfiguration.toCircuitBreakerConfig());
CircuitBreakerUtil.registerMetrics(SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME),
circuitBreaker,
FaultTolerantDatabase.class);
}
public void use(Consumer<Jdbi> consumer) {
this.circuitBreaker.executeRunnable(() -> consumer.accept(database));
}
public <T> T with(Function<Jdbi, T> consumer) {
return this.circuitBreaker.executeSupplier(() -> consumer.apply(database));
}
public Jdbi getDatabase() {
return database;
}
}

View File

@@ -11,8 +11,6 @@ import static io.micrometer.core.instrument.Metrics.timer;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.InvalidProtocolBufferException;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.nio.ByteBuffer;
import java.time.Duration;
@@ -26,9 +24,7 @@ import javax.annotation.Nonnull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@@ -51,26 +47,6 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
private static final String KEY_TTL = "E";
private static final String KEY_ENVELOPE_BYTES = "EB";
// TODO Stop reading messages by attribute value after DATE
@Deprecated
private static final String KEY_TYPE = "T";
@Deprecated
private static final String KEY_TIMESTAMP = "TS";
private static final String KEY_SOURCE_UUID = "SU";
@Deprecated
private static final String KEY_SOURCE_DEVICE = "SD";
@Deprecated
private static final String KEY_DESTINATION_UUID = "DU";
@Deprecated
private static final String KEY_UPDATED_PNI = "UP";
@Deprecated
private static final String KEY_CONTENT = "C";
private final Timer storeTimer = timer(name(getClass(), "store"));
private final Timer loadTimer = timer(name(getClass(), "load"));
private final Timer deleteByGuid = timer(name(getClass(), "delete", "guid"));
@@ -81,9 +57,6 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
private final String tableName;
private final Duration timeToLive;
private static final Counter GET_MESSAGE_WITH_ATTRIBUTES_COUNTER = Metrics.counter(name(MessagesDynamoDb.class, "loadMessage"), "format", "attributes");
private static final Counter GET_MESSAGE_WITH_ENVELOPE_COUNTER = Metrics.counter(name(MessagesDynamoDb.class, "loadMessage"), "format", "envelope");
private static final Logger logger = LoggerFactory.getLogger(MessagesDynamoDb.class);
public MessagesDynamoDb(DynamoDbClient dynamoDb, String tableName, Duration timeToLive) {
@@ -260,30 +233,8 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
@VisibleForTesting
static MessageProtos.Envelope convertItemToEnvelope(final Map<String, AttributeValue> item)
throws InvalidProtocolBufferException {
final MessageProtos.Envelope envelope;
if (item.containsKey(KEY_ENVELOPE_BYTES)) {
envelope = MessageProtos.Envelope.parseFrom(item.get(KEY_ENVELOPE_BYTES).b().asByteArray());
GET_MESSAGE_WITH_ENVELOPE_COUNTER.increment();
} else {
final SortKey sortKey = convertSortKey(item.get(KEY_SORT).b().asByteArray());
final UUID messageUuid = convertLocalIndexMessageUuidSortKey(item.get(LOCAL_INDEX_MESSAGE_UUID_KEY_SORT).b().asByteArray());
final int type = AttributeValues.getInt(item, KEY_TYPE, 0);
final long timestamp = AttributeValues.getLong(item, KEY_TIMESTAMP, 0L);
final UUID sourceUuid = AttributeValues.getUUID(item, KEY_SOURCE_UUID, null);
final int sourceDevice = AttributeValues.getInt(item, KEY_SOURCE_DEVICE, 0);
final UUID destinationUuid = AttributeValues.getUUID(item, KEY_DESTINATION_UUID, null);
final byte[] content = AttributeValues.getByteArray(item, KEY_CONTENT, null);
final UUID updatedPni = AttributeValues.getUUID(item, KEY_UPDATED_PNI, null);
envelope = new OutgoingMessageEntity(messageUuid, type, timestamp, sourceUuid, sourceDevice, destinationUuid,
updatedPni, content, sortKey.getServerTimestamp(), true).toEnvelope();
GET_MESSAGE_WITH_ATTRIBUTES_COUNTER.increment();
}
return envelope;
return MessageProtos.Envelope.parseFrom(item.get(KEY_ENVELOPE_BYTES).b().asByteArray());
}
private void deleteRowsMatchingQuery(AttributeValue partitionKey, QueryRequest querySpec) {
@@ -324,56 +275,7 @@ public class MessagesDynamoDb extends AbstractDynamoDbStore {
return AttributeValues.fromByteBuffer(byteBuffer.flip());
}
private static SortKey convertSortKey(final byte[] bytes) {
if (bytes.length != 32) {
throw new IllegalArgumentException("unexpected sort key byte length");
}
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
final long destinationDeviceId = byteBuffer.getLong();
final long serverTimestamp = byteBuffer.getLong();
final long mostSigBits = byteBuffer.getLong();
final long leastSigBits = byteBuffer.getLong();
return new SortKey(destinationDeviceId, serverTimestamp, new UUID(mostSigBits, leastSigBits));
}
private static AttributeValue convertLocalIndexMessageUuidSortKey(final UUID messageUuid) {
return AttributeValues.fromUUID(messageUuid);
}
private static UUID convertLocalIndexMessageUuidSortKey(final byte[] bytes) {
return convertUuidFromBytes(bytes, "local index message uuid sort key");
}
private static UUID convertUuidFromBytes(final byte[] bytes, final String name) {
try {
return UUIDUtil.fromBytes(bytes);
} catch (final IllegalArgumentException e) {
throw new IllegalArgumentException("unexpected " + name + " byte length; was " + bytes.length + " but expected 16");
}
}
private static final class SortKey {
private final long destinationDeviceId;
private final long serverTimestamp;
private final UUID messageUuid;
public SortKey(long destinationDeviceId, long serverTimestamp, UUID messageUuid) {
this.destinationDeviceId = destinationDeviceId;
this.serverTimestamp = serverTimestamp;
this.messageUuid = messageUuid;
}
public long getDestinationDeviceId() {
return destinationDeviceId;
}
public long getServerTimestamp() {
return serverTimestamp;
}
public UUID getMessageUuid() {
return messageUuid;
}
}
}

View File

@@ -15,8 +15,6 @@ import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
@@ -32,17 +30,14 @@ public class MessagesManager {
private final MessagesDynamoDb messagesDynamoDb;
private final MessagesCache messagesCache;
private final PushLatencyManager pushLatencyManager;
private final ReportMessageManager reportMessageManager;
public MessagesManager(
final MessagesDynamoDb messagesDynamoDb,
final MessagesCache messagesCache,
final PushLatencyManager pushLatencyManager,
final ReportMessageManager reportMessageManager) {
this.messagesDynamoDb = messagesDynamoDb;
this.messagesCache = messagesCache;
this.pushLatencyManager = pushLatencyManager;
this.reportMessageManager = reportMessageManager;
}
@@ -60,9 +55,7 @@ public class MessagesManager {
return messagesCache.hasMessages(destinationUuid, destinationDevice);
}
public Pair<List<Envelope>, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) {
RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent));
public Pair<List<Envelope>, Boolean> getMessagesForDevice(UUID destinationUuid, long destinationDevice, final boolean cachedMessagesOnly) {
List<Envelope> messageList = new ArrayList<>();
if (!cachedMessagesOnly) {

View File

@@ -5,11 +5,11 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import java.time.Clock;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
@@ -29,7 +29,8 @@ public class PushChallengeDynamoDb extends AbstractDynamoDbStore {
static final String ATTR_TTL = "T";
private static final Map<String, String> UUID_NAME_MAP = Map.of("#uuid", KEY_ACCOUNT_UUID);
private static final Map<String, String> CHALLENGE_TOKEN_NAME_MAP = Map.of("#challenge", ATTR_CHALLENGE_TOKEN);
private static final Map<String, String> CHALLENGE_TOKEN_NAME_MAP = Map.of("#challenge", ATTR_CHALLENGE_TOKEN, "#ttl",
ATTR_TTL);
public PushChallengeDynamoDb(final DynamoDbClient dynamoDB, final String tableName) {
this(dynamoDB, tableName, Clock.systemUTC());
@@ -87,9 +88,10 @@ public class PushChallengeDynamoDb extends AbstractDynamoDbStore {
db().deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountUuid)))
.conditionExpression("#challenge = :challenge")
.conditionExpression("#challenge = :challenge AND #ttl >= :currentTime")
.expressionAttributeNames(CHALLENGE_TOKEN_NAME_MAP)
.expressionAttributeValues(Map.of(":challenge", AttributeValues.fromByteArray(challengeToken)))
.expressionAttributeValues(Map.of(":challenge", AttributeValues.fromByteArray(challengeToken),
":currentTime", AttributeValues.fromLong(clock.instant().getEpochSecond())))
.build());
return true;
} catch (final ConditionalCheckFailedException e) {

View File

@@ -1,5 +1,7 @@
package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
@@ -10,6 +12,8 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class ReportMessageDynamoDb {
static final String KEY_HASH = "H";
@@ -19,6 +23,9 @@ public class ReportMessageDynamoDb {
private final String tableName;
private final Duration ttl;
private static final String REMOVED_MESSAGE_COUNTER_NAME = name(ReportMessageDynamoDb.class, "removed");
private static final Timer REMOVED_MESSAGE_AGE_TIMER = Metrics.timer(name(ReportMessageDynamoDb.class, "removedMessageAge"));
public ReportMessageDynamoDb(final DynamoDbClient dynamoDB, final String tableName, final Duration ttl) {
this.db = dynamoDB;
this.tableName = tableName;
@@ -41,6 +48,22 @@ public class ReportMessageDynamoDb {
.key(Map.of(KEY_HASH, AttributeValues.fromByteArray(hash)))
.returnValues(ReturnValue.ALL_OLD)
.build());
return !deleteItemResponse.attributes().isEmpty();
final boolean found = !deleteItemResponse.attributes().isEmpty();
if (found) {
if (deleteItemResponse.attributes().containsKey(ATTR_TTL)) {
final Instant expiration =
Instant.ofEpochSecond(Long.parseLong(deleteItemResponse.attributes().get(ATTR_TTL).n()));
final Duration approximateAge = ttl.minus(Duration.between(Instant.now(), expiration));
REMOVED_MESSAGE_AGE_TIMER.record(approximateAge);
}
}
Metrics.counter(REMOVED_MESSAGE_COUNTER_NAME, "found", String.valueOf(found)).increment();
return found;
}
}

View File

@@ -53,7 +53,7 @@ public class ReservedUsernames {
this.tableName = tableName;
}
public boolean isReserved(final String username, final UUID accountIdentifier) {
public boolean isReserved(final String nickname, final UUID accountIdentifier) {
return IS_RESERVED_TIMER.record(() -> {
final ScanIterable scanIterable = dynamoDbClient.scanPaginator(ScanRequest.builder()
.tableName(tableName)
@@ -66,7 +66,7 @@ public class ReservedUsernames {
final Pattern pattern = patternCache.get(item.get(KEY_PATTERN).s());
final UUID reservedFor = AttributeValues.getUUID(item, ATTR_RESERVED_FOR_UUID, null);
if (pattern.matcher(username).matches() && !accountIdentifier.equals(reservedFor)) {
if (pattern.matcher(nickname).matches() && !accountIdentifier.equals(reservedFor)) {
return true;
}
} catch (final Exception e) {

View File

@@ -1,52 +0,0 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class UsernameCleaner extends AccountDatabaseCrawlerListener {
private static final String DELETED_USERNAME_COUNTER = name(UsernameCleaner.class, "deletedUsernames");
private static final Logger logger = LoggerFactory.getLogger(UsernameCleaner.class);
private final AccountsManager accountsManager;
public UsernameCleaner(AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void onCrawlStart() {
}
@Override
protected void onCrawlChunk(final Optional<UUID> fromUuid, final List<Account> chunkAccounts) {
for (Account account : chunkAccounts) {
if (account.getUsername().isPresent()) {
logger.info("Deleting username present for account {}", account.getUuid());
try {
this.accountsManager.clearUsername(account);
Metrics.counter(DELETED_USERNAME_COUNTER, Tags.of("outcome", "success")).increment();
} catch (Exception e) {
logger.warn("Failed to clear username on account {}", account.getUuid(), e);
Metrics.counter(DELETED_USERNAME_COUNTER, Tags.of("outcome", "error")).increment();
}
}
}
}
@Override
public void onCrawlEnd(final Optional<UUID> fromUuid) {
logger.info("Username cleaner crawl completed");
}
}

View File

@@ -5,10 +5,15 @@
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
@@ -19,11 +24,6 @@ import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name;
public class VerificationCodeStore {
@@ -83,7 +83,8 @@ public class VerificationCodeStore {
try {
return response.hasItem()
? Optional.of(SystemMapper.getMapper().readValue(response.item().get(ATTR_STORED_CODE).s(), StoredVerificationCode.class))
? filterMaybeExpiredCode(
SystemMapper.getMapper().readValue(response.item().get(ATTR_STORED_CODE).s(), StoredVerificationCode.class))
: Optional.empty();
} catch (final JsonProcessingException e) {
log.error("Failed to parse stored verification code", e);
@@ -92,6 +93,16 @@ public class VerificationCodeStore {
});
}
private Optional<StoredVerificationCode> filterMaybeExpiredCode(StoredVerificationCode storedVerificationCode) {
// It's possible for DynamoDB to return items after their expiration time (although it is very unlikely for small
// tables)
if (getExpirationTimestamp(storedVerificationCode) < Instant.now().getEpochSecond()) {
return Optional.empty();
}
return Optional.of(storedVerificationCode);
}
public void remove(final String number) {
removeTimer.record(() -> {
dynamoDbClient.deleteItem(DeleteItemRequest.builder()

View File

@@ -5,21 +5,21 @@
package org.whispersystems.textsecuregcm.util;
import javax.validation.Constraint;
import javax.validation.Payload;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import javax.validation.Constraint;
import javax.validation.Payload;
@Target({ FIELD, PARAMETER })
@Retention(RUNTIME)
@Constraint(validatedBy = UsernameValidator.class)
public @interface Username {
@Constraint(validatedBy = NicknameValidator.class)
public @interface Nickname {
String message() default "{org.whispersystems.textsecuregcm.util.Username.message}";
String message() default "{org.whispersystems.textsecuregcm.util.Nickname.message}";
Class<?>[] groups() default { };

View File

@@ -0,0 +1,17 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
public class NicknameValidator implements ConstraintValidator<Nickname, String> {
@Override
public boolean isValid(final String nickname, final ConstraintValidatorContext context) {
return UsernameGenerator.isValidNickname(nickname);
}
}

View File

@@ -0,0 +1,143 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.math.IntMath;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.configuration.UsernameConfiguration;
import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class UsernameGenerator {
/**
* Nicknames are
* <list>
* <li> lowercase </li>
* <li> do not start with a number </li>
* <li> alphanumeric or underscores only </li>
* <li> minimum length 3 </li>
* <li> maximum length 32 </li>
* </list>
*
* Usernames typically consist of a nickname and an integer discriminator
*/
public static final Pattern NICKNAME_PATTERN = Pattern.compile("^[_a-z][_a-z0-9]{2,31}$");
public static final String SEPARATOR = "#";
private static final Counter USERNAME_NOT_AVAILABLE_COUNTER = Metrics.counter(name(UsernameGenerator.class, "usernameNotAvailable"));
private static final DistributionSummary DISCRIMINATOR_ATTEMPT_COUNTER = Metrics.summary(name(UsernameGenerator.class, "discriminatorAttempts"));
private final int initialWidth;
private final int discriminatorMaxWidth;
private final int attemptsPerWidth;
public UsernameGenerator(UsernameConfiguration configuration) {
this(configuration.getDiscriminatorInitialWidth(), configuration.getDiscriminatorMaxWidth(), configuration.getAttemptsPerWidth());
}
@VisibleForTesting
public UsernameGenerator(int initialWidth, int discriminatorMaxWidth, int attemptsPerWidth) {
this.initialWidth = initialWidth;
this.discriminatorMaxWidth = discriminatorMaxWidth;
this.attemptsPerWidth = attemptsPerWidth;
}
/**
* Generate a username with a random discriminator
*
* @param nickname The string nickname
* @param usernameAvailableFun A {@link Predicate} that returns true if the provided username is available
* @return The nickname appended with a random discriminator
* @throws UsernameNotAvailableException if we failed to find a nickname+discriminator pair that was available
*/
public String generateAvailableUsername(final String nickname, final Predicate<String> usernameAvailableFun) throws UsernameNotAvailableException {
int rangeMin = 1;
int rangeMax = IntMath.pow(10, initialWidth);
int totalMax = IntMath.pow(10, discriminatorMaxWidth);
int attempts = 0;
while (rangeMax <= totalMax) {
// check discriminators of the current width up to attemptsPerWidth times
for (int i = 0; i < attemptsPerWidth; i++) {
int discriminator = ThreadLocalRandom.current().nextInt(rangeMin, rangeMax);
String username = fromParts(nickname, discriminator);
attempts++;
if (usernameAvailableFun.test(username)) {
DISCRIMINATOR_ATTEMPT_COUNTER.record(attempts);
return username;
}
}
// update the search range to look for numbers of one more digit
// than the previous iteration
rangeMin = rangeMax;
rangeMax *= 10;
}
USERNAME_NOT_AVAILABLE_COUNTER.increment();
throw new UsernameNotAvailableException();
}
/**
* Strips the discriminator from a username, if it is present
*
* @param username the string username
* @return the nickname prefix of the username
*/
public static String extractNickname(final String username) {
int sep = username.indexOf(SEPARATOR);
return sep == -1 ? username : username.substring(0, sep);
}
/**
* Generate a username from a nickname and discriminator
*/
public String fromParts(final String nickname, final int discriminator) throws IllegalArgumentException {
if (!isValidNickname(nickname)) {
throw new IllegalArgumentException("Invalid nickname " + nickname);
}
// zero pad discriminators less than the discriminator initial width
return String.format("%s#%0" + initialWidth + "d", nickname, discriminator);
}
public static boolean isValidNickname(final String nickname) {
return StringUtils.isNotBlank(nickname) && NICKNAME_PATTERN.matcher(nickname).matches();
}
/**
* Checks if the username consists of a valid nickname followed by an integer discriminator
*
* @param username string username to check
* @return true if the username is in standard form
*/
public static boolean isStandardFormat(final String username) {
if (username == null) {
return false;
}
int sep = username.indexOf(SEPARATOR);
if (sep == -1) {
return false;
}
final String nickname = username.substring(0, sep);
if (!isValidNickname(nickname)) {
return false;
}
try {
int discriminator = Integer.parseInt(username.substring(sep + 1));
return discriminator > 0;
} catch (NumberFormatException e) {
return false;
}
}
}

View File

@@ -1,26 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.apache.commons.lang3.StringUtils;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import java.util.regex.Pattern;
public class UsernameValidator implements ConstraintValidator<Username, String> {
private static final Pattern USERNAME_PATTERN =
Pattern.compile("^[a-z_][a-z0-9_]{3,25}$", Pattern.CASE_INSENSITIVE);
@Override
public boolean isValid(final String username, final ConstraintValidatorContext context) {
return StringUtils.isNotBlank(username) && USERNAME_PATTERN.matcher(getCanonicalUsername(username)).matches();
}
public static String getCanonicalUsername(final String username) {
return username != null ? username.toLowerCase() : null;
}
}

View File

@@ -8,12 +8,6 @@ import com.google.i18n.phonenumbers.NumberParseException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import com.google.i18n.phonenumbers.PhoneNumberUtil.PhoneNumberFormat;
import com.google.i18n.phonenumbers.Phonenumber.PhoneNumber;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Clock;
import java.time.Duration;
import java.time.temporal.ChronoField;
@@ -22,12 +16,12 @@ import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Locale.LanguageRange;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.StringUtils;
public class Util {
@@ -35,18 +29,6 @@ public class Util {
private static final PhoneNumberUtil PHONE_NUMBER_UTIL = PhoneNumberUtil.getInstance();
public static byte[] getContactToken(String number) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA1");
byte[] result = digest.digest(number.getBytes());
byte[] truncated = Util.truncate(result, 10);
return truncated;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
/**
* Checks that the given number is a valid, E164-normalized phone number.
*
@@ -61,11 +43,30 @@ public class Util {
}
try {
final PhoneNumber phoneNumber = PHONE_NUMBER_UTIL.parse(number, null);
final String normalizedNumber = PHONE_NUMBER_UTIL.format(phoneNumber, PhoneNumberFormat.E164);
final PhoneNumber inputNumber = PHONE_NUMBER_UTIL.parse(number, null);
if (!number.equals(normalizedNumber)) {
throw new NonNormalizedPhoneNumberException(number, normalizedNumber);
// For normalization, we want to format from a version parsed with the country code removed.
// This handles some cases of "possible", but non-normalized input numbers with a doubled country code, that is
// with the format "+{country code} {country code} {national number}"
final int countryCode = inputNumber.getCountryCode();
final String region = PHONE_NUMBER_UTIL.getRegionCodeForCountryCode(countryCode);
final PhoneNumber normalizedNumber = switch (region) {
// the country code has no associated region. Be lenient (and simple) and accept the input number
case "ZZ", "001" -> inputNumber;
default -> {
final String maybeLeadingZero =
inputNumber.hasItalianLeadingZero() && inputNumber.isItalianLeadingZero() ? "0" : "";
yield PHONE_NUMBER_UTIL.parse(
maybeLeadingZero + inputNumber.getNationalNumber(), region);
}
};
final String normalizedE164 = PHONE_NUMBER_UTIL.format(normalizedNumber,
PhoneNumberFormat.E164);
if (!number.equals(normalizedE164)) {
throw new NonNormalizedPhoneNumberException(number, normalizedE164);
}
} catch (final NumberParseException e) {
throw new ImpossiblePhoneNumberException(e);
@@ -79,6 +80,15 @@ public class Util {
else return "0";
}
public static String getRegion(final String number) {
try {
final PhoneNumber phoneNumber = PHONE_NUMBER_UTIL.parse(number, null);
return StringUtils.defaultIfBlank(PHONE_NUMBER_UTIL.getRegionCodeForNumber(phoneNumber), "ZZ");
} catch (final NumberParseException e) {
return "ZZ";
}
}
public static String getNumberPrefix(String number) {
String countryCode = getCountryCode(number);
int remaining = number.length() - (1 + countryCode.length());
@@ -87,24 +97,6 @@ public class Util {
return number.substring(0, 1 + countryCode.length() + prefixLength);
}
public static String encodeFormParams(Map<String, String> params) {
try {
StringBuffer buffer = new StringBuffer();
for (String key : params.keySet()) {
buffer.append(String.format("%s=%s",
URLEncoder.encode(key, "UTF-8"),
URLEncoder.encode(params.get(key), "UTF-8")));
buffer.append("&");
}
buffer.deleteCharAt(buffer.length()-1);
return buffer.toString();
} catch (UnsupportedEncodingException e) {
throw new AssertionError(e);
}
}
public static boolean isEmpty(String param) {
return param == null || param.length() == 0;
}
@@ -146,20 +138,6 @@ public class Util {
return parts;
}
public static byte[] generateSecretBytes(int size) {
byte[] data = new byte[size];
new SecureRandom().nextBytes(data);
return data;
}
public static byte[] longToByteArray(long value) {
final ByteBuffer longBuffer = ByteBuffer.allocate(Long.BYTES);
longBuffer.putLong(value);
return longBuffer.array();
}
public static int toIntExact(long value) {
if ((int) value != value) {
throw new ArithmeticException("integer overflow");
@@ -171,16 +149,6 @@ public class Util {
return toIntExact(clock.millis() / 1000 / 60/ 60 / 24);
}
/**
* Returns the current number of days since the epoch.
*
* @deprecated use {@link #currentDaysSinceEpoch(Clock)} instead
*/
@Deprecated
public static int currentDaysSinceEpoch() {
return currentDaysSinceEpoch(Clock.systemUTC());
}
public static void sleep(long i) {
try {
Thread.sleep(i);
@@ -207,10 +175,6 @@ public class Util {
return Arrays.hashCode(objects);
}
public static boolean isEquals(Object first, Object second) {
return (first == null && second == null) || (first == second) || (first != null && first.equals(second));
}
public static long todayInMillis() {
return todayInMillis(Clock.systemUTC());
}

View File

@@ -18,7 +18,6 @@ import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@@ -44,21 +43,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager;
private final ApnFallbackManager apnFallbackManager;
private final ClientPresenceManager clientPresenceManager;
private final ScheduledExecutorService scheduledExecutorService;
public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager,
PushNotificationManager pushNotificationManager,
ApnFallbackManager apnFallbackManager,
ClientPresenceManager clientPresenceManager,
ScheduledExecutorService scheduledExecutorService)
{
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
this.apnFallbackManager = apnFallbackManager;
this.clientPresenceManager = clientPresenceManager;
this.scheduledExecutorService = scheduledExecutorService;
}
@@ -75,7 +71,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
scheduledExecutorService);
openWebsocketCounter.inc();
RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), device));
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), device, context.getClient().getUserAgent());
final AtomicReference<ScheduledFuture<?>> renewPresenceFutureReference = new AtomicReference<>();
@@ -100,7 +96,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) {
try {
pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId());
pushNotificationManager.sendNewMessageNotification(auth.getAccount(), device.getId(), true);
} catch (NotPushRegisteredException ignored) {
}
}

View File

@@ -320,8 +320,8 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) {
try {
final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager
.getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
final Pair<List<Envelope>, Boolean> messagesAndHasMore = messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(), device.getId(), cachedMessagesOnly);
final List<Envelope> messages = messagesAndHasMore.first();
final boolean hasMore = messagesAndHasMore.second();

View File

@@ -25,12 +25,14 @@ import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
@@ -50,6 +52,7 @@ import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@@ -68,11 +71,11 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
public void configure(Subparser subparser) {
super.configure(subparser);
subparser.addArgument("-n", "--username")
.dest("username")
subparser.addArgument("-n", "--nickname")
.dest("nickname")
.type(String.class)
.required(true)
.help("The username to assign");
.help("The nickname (without discriminator) to assign");
subparser.addArgument("-a", "--aci")
.dest("aci")
@@ -109,6 +112,9 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
configuration.getAppConfig().getConfigurationName(), DynamicConfiguration.class);
dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
DynamoDbAsyncClient dynamoDbAsyncClient = DynamoDbFromConfig.asyncClient(
configuration.getDynamoDbClientConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
@@ -180,22 +186,25 @@ public class AssignUsernameCommand extends EnvironmentCommand<WhisperServerConfi
configuration.getReportMessageConfiguration().getReportTtl());
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
configuration.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, pushLatencyManager,
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
reportMessageManager);
DeletedAccountsManager deletedAccountsManager = new DeletedAccountsManager(deletedAccounts,
deletedAccountsLockDynamoDbClient,
configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName());
UsernameGenerator usernameGenerator = new UsernameGenerator(configuration.getUsername());
StoredVerificationCodeManager pendingAccountsManager = new StoredVerificationCodeManager(pendingAccounts);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
deletedAccountsManager, directoryQueue, keys, messagesManager, reservedUsernames, profilesManager,
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, Clock.systemUTC());
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, usernameGenerator,
experimentEnrollmentManager, Clock.systemUTC());
final String username = namespace.getString("username");
final String nickname = namespace.getString("nickname");
final UUID accountIdentifier = UUID.fromString(namespace.getString("aci"));
accountsManager.getByAccountIdentifier(accountIdentifier).ifPresentOrElse(account -> {
try {
accountsManager.setUsername(account, username);
final Account result = accountsManager.setUsername(account, nickname, null);
System.out.println("New username: " + result.getUsername());
} catch (final UsernameNotAvailableException e) {
throw new IllegalArgumentException("Username already taken");
}

View File

@@ -27,7 +27,8 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
@@ -53,6 +54,7 @@ import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@@ -112,6 +114,9 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
configuration.getAppConfig().getConfigurationName(), DynamicConfiguration.class);
dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
DynamoDbAsyncClient dynamoDbAsyncClient = DynamoDbFromConfig.asyncClient(
configuration.getDynamoDbClientConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
@@ -183,15 +188,17 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
configuration.getReportMessageConfiguration().getReportTtl());
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
configuration.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, pushLatencyManager,
reportMessageManager);
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
reportMessageManager);
DeletedAccountsManager deletedAccountsManager = new DeletedAccountsManager(deletedAccounts,
deletedAccountsLockDynamoDbClient,
configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName());
StoredVerificationCodeManager pendingAccountsManager = new StoredVerificationCodeManager(pendingAccounts);
UsernameGenerator usernameGenerator = new UsernameGenerator(configuration.getUsername());
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
deletedAccountsManager, directoryQueue, keys, messagesManager, reservedUsernames, profilesManager,
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, clock);
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, usernameGenerator,
experimentEnrollmentManager, clock);
for (String user : users) {
Optional<Account> account = accountsManager.getByE164(user);

View File

@@ -26,7 +26,8 @@ import net.sourceforge.argparse4j.inf.Subparser;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
@@ -51,6 +52,7 @@ import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@@ -115,6 +117,9 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
configuration.getAppConfig().getConfigurationName(), DynamicConfiguration.class);
dynamicConfigurationManager.start();
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
DynamoDbAsyncClient dynamoDbAsyncClient = DynamoDbFromConfig.asyncClient(
configuration.getDynamoDbClientConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
@@ -184,15 +189,17 @@ public class SetUserDiscoverabilityCommand extends EnvironmentCommand<WhisperSer
configuration.getReportMessageConfiguration().getReportTtl());
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, rateLimitersCluster,
configuration.getReportMessageConfiguration().getCounterTtl());
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, pushLatencyManager,
reportMessageManager);
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
reportMessageManager);
DeletedAccountsManager deletedAccountsManager = new DeletedAccountsManager(deletedAccounts,
deletedAccountsLockDynamoDbClient,
configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName());
StoredVerificationCodeManager pendingAccountsManager = new StoredVerificationCodeManager(pendingAccounts);
UsernameGenerator usernameGenerator = new UsernameGenerator(configuration.getUsername());
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
deletedAccountsManager, directoryQueue, keys, messagesManager, reservedUsernames, profilesManager,
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, clock);
pendingAccountsManager, secureStorageClient, secureBackupClient, clientPresenceManager, usernameGenerator,
experimentEnrollmentManager, clock);
Optional<Account> maybeAccount;

View File

@@ -0,0 +1,17 @@
local lastBackgroundNotificationTimestampKey = KEYS[1]
local queueKey = KEYS[2]
local accountDevicePair = ARGV[1]
local currentTimeMillis = tonumber(ARGV[2])
local backgroundNotificationPeriod = tonumber(ARGV[3])
local lastBackgroundNotificationTimestamp = redis.call("GET", lastBackgroundNotificationTimestampKey)
local nextNotificationTimestamp
if (lastBackgroundNotificationTimestamp) then
nextNotificationTimestamp = tonumber(lastBackgroundNotificationTimestamp) + backgroundNotificationPeriod
else
nextNotificationTimestamp = currentTimeMillis
end
redis.call("ZADD", queueKey, "NX", nextNotificationTimestamp, accountDevicePair)

View File

@@ -14,6 +14,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -30,6 +31,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -54,7 +56,9 @@ class BaseAccountAuthenticatorTest {
void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
ExperimentEnrollmentManager enrollmentManager = mock(ExperimentEnrollmentManager.class);
when(enrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock, enrollmentManager);
// We use static UUIDs here because the UUID affects the "date last seen" offset
acct1 = AccountsHelper.generateTestAccount("+14088675309", UUID.fromString("c139cb3e-f70c-4460-b221-815e8bdf778f"), UUID.randomUUID(), List.of(generateTestDevice(yesterday)), null);
@@ -164,6 +168,7 @@ class BaseAccountAuthenticatorTest {
when(device.isEnabled()).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(AuthenticationCredentials.CURRENT_VERSION);
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid.toString(), password), true);
@@ -171,6 +176,7 @@ class BaseAccountAuthenticatorTest {
assertThat(maybeAuthenticatedAccount).isPresent();
assertThat(maybeAuthenticatedAccount.get().getAccount().getUuid()).isEqualTo(uuid);
assertThat(maybeAuthenticatedAccount.get().getAuthenticatedDevice()).isEqualTo(device);
verify(accountsManager, never()).updateDeviceAuthentication(any(), any(), any());;
}
@Test
@@ -192,6 +198,7 @@ class BaseAccountAuthenticatorTest {
when(device.isEnabled()).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(AuthenticationCredentials.CURRENT_VERSION);
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid + "." + deviceId, password), true);
@@ -199,6 +206,7 @@ class BaseAccountAuthenticatorTest {
assertThat(maybeAuthenticatedAccount).isPresent();
assertThat(maybeAuthenticatedAccount.get().getAccount().getUuid()).isEqualTo(uuid);
assertThat(maybeAuthenticatedAccount.get().getAuthenticatedDevice()).isEqualTo(device);
verify(accountsManager, never()).updateDeviceAuthentication(any(), any(), any());
}
@ParameterizedTest
@@ -221,6 +229,7 @@ class BaseAccountAuthenticatorTest {
when(device.isEnabled()).thenReturn(false);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(AuthenticationCredentials.CURRENT_VERSION);
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid.toString(), password), enabledRequired);
@@ -234,6 +243,37 @@ class BaseAccountAuthenticatorTest {
}
}
@Test
void testAuthenticateV1() {
final UUID uuid = UUID.randomUUID();
final long deviceId = 1;
final String password = "12345";
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final AuthenticationCredentials credentials = mock(AuthenticationCredentials.class);
when(clock.instant()).thenReturn(Instant.now());
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(AuthenticationCredentials.Version.V1);
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid.toString(), password), true);
assertThat(maybeAuthenticatedAccount).isPresent();
assertThat(maybeAuthenticatedAccount.get().getAccount().getUuid()).isEqualTo(uuid);
assertThat(maybeAuthenticatedAccount.get().getAuthenticatedDevice()).isEqualTo(device);
verify(accountsManager, times(1)).updateDeviceAuthentication(
any(), // this won't be 'account', because it'll already be updated by updateDeviceLastSeen
eq(device), any());
}
@Test
void testAuthenticateAccountNotFound() {
assertThat(baseAccountAuthenticator.authenticate(new BasicCredentials(UUID.randomUUID().toString(), "password"), true))
@@ -259,6 +299,7 @@ class BaseAccountAuthenticatorTest {
when(device.isEnabled()).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(AuthenticationCredentials.CURRENT_VERSION);
final Optional<AuthenticatedAccount> maybeAuthenticatedAccount =
baseAccountAuthenticator.authenticate(new BasicCredentials(uuid + "." + (deviceId + 1), password), true);
@@ -286,6 +327,7 @@ class BaseAccountAuthenticatorTest {
when(device.isEnabled()).thenReturn(true);
when(device.getAuthenticationCredentials()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(AuthenticationCredentials.CURRENT_VERSION);
final String incorrectPassword = password + "incorrect";

View File

@@ -54,6 +54,9 @@ class DynamicConfigurationTest {
uuidsOnly:
enrolledUuids:
- 71618739-114c-4b1f-bb0d-6478a44eb600
uuids-with-dash:
enrolledUuids:
- 71618739-114c-4b1f-bb0d-6478ffffffff
""");
final DynamicConfiguration config =
@@ -77,6 +80,11 @@ class DynamicConfigurationTest {
assertEquals(0, config.getExperimentEnrollmentConfiguration("uuidsOnly").get().getEnrollmentPercentage());
assertEquals(Set.of(UUID.fromString("71618739-114c-4b1f-bb0d-6478a44eb600")),
config.getExperimentEnrollmentConfiguration("uuidsOnly").get().getEnrolledUuids());
assertTrue(config.getExperimentEnrollmentConfiguration("uuids-with-dash").isPresent());
assertEquals(0, config.getExperimentEnrollmentConfiguration("uuids-with-dash").get().getEnrollmentPercentage());
assertEquals(Set.of(UUID.fromString("71618739-114c-4b1f-bb0d-6478ffffffff")),
config.getExperimentEnrollmentConfiguration("uuids-with-dash").get().getEnrolledUuids());
}
}

View File

@@ -67,8 +67,8 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@@ -107,7 +107,7 @@ class MessageControllerTest {
private static final MessagesManager messagesManager = mock(MessagesManager.class);
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
@@ -119,7 +119,7 @@ class MessageControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager,
messagesManager, apnFallbackManager, reportMessageManager, multiRecipientMessageExecutor))
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor))
.build();
@BeforeEach
@@ -170,7 +170,7 @@ class MessageControllerTest {
messagesManager,
rateLimiters,
rateLimiter,
apnFallbackManager,
pushNotificationManager,
reportMessageManager
);
}
@@ -435,13 +435,16 @@ class MessageControllerTest {
.map(OutgoingMessageEntity::fromEnvelope)
.toList(), false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
final String userAgent = "Test-UA";
OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.header("USer-Agent", userAgent)
.accept(MediaType.APPLICATION_JSON_TYPE)
.get(OutgoingMessageEntityList.class);
@@ -458,6 +461,8 @@ class MessageControllerTest {
assertEquals(updatedPniOne, response.messages().get(0).updatedPni());
assertNull(response.messages().get(1).updatedPni());
verify(pushNotificationManager).handleMessagesRetrieved(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE, userAgent);
}
@Test
@@ -472,7 +477,7 @@ class MessageControllerTest {
UUID.randomUUID(), 2, AuthHelper.VALID_UUID, null, null, 0)
);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(messages, false));
Response response =

View File

@@ -80,8 +80,6 @@ import org.whispersystems.textsecuregcm.configuration.BadgeConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPaymentsConfiguration;
import org.whispersystems.textsecuregcm.controllers.ProfileController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.Badge;
import org.whispersystems.textsecuregcm.entities.BadgeSvg;
import org.whispersystems.textsecuregcm.entities.BaseProfileResponse;
@@ -362,23 +360,6 @@ class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(401);
}
@Test
void testProfileGetByUsername() throws RateLimitExceededException {
BaseProfileResponse profile = resources.getJerseyTest()
.target("/v1/profile/username/n00bkiller")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(BaseProfileResponse.class);
assertThat(profile.getIdentityKey()).isEqualTo("bar");
assertThat(profile.getUuid()).isEqualTo(AuthHelper.VALID_UUID_TWO);
assertThat(profile.getBadges()).hasSize(1).element(0).has(new Condition<>(
badge -> "Test Badge".equals(badge.getName()), "has badge with expected name"));
verify(accountsManager).getByUsername("n00bkiller");
verify(usernameRateLimiter, times(1)).validate(eq(AuthHelper.VALID_UUID));
}
@Test
void testProfileGetUnauthorized() {
Response response = resources.getJerseyTest()
@@ -389,31 +370,6 @@ class ProfileControllerTest {
assertThat(response.getStatus()).isEqualTo(401);
}
@Test
void testProfileGetByUsernameUnauthorized() {
Response response = resources.getJerseyTest()
.target("/v1/profile/username/n00bkiller")
.request()
.get();
assertThat(response.getStatus()).isEqualTo(401);
}
@Test
void testProfileGetByUsernameNotFound() throws RateLimitExceededException {
Response response = resources.getJerseyTest()
.target("/v1/profile/username/n00bkillerzzzzz")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertThat(response.getStatus()).isEqualTo(404);
verify(accountsManager).getByUsername("n00bkillerzzzzz");
verify(usernameRateLimiter).validate(eq(AuthHelper.VALID_UUID));
}
@Test
void testProfileGetDisabled() {

View File

@@ -17,14 +17,18 @@ import com.eatthepath.pushy.apns.ApnsClient;
import com.eatthepath.pushy.apns.ApnsPushNotification;
import com.eatthepath.pushy.apns.DeliveryPriority;
import com.eatthepath.pushy.apns.PushNotificationResponse;
import com.eatthepath.pushy.apns.PushType;
import com.eatthepath.pushy.apns.util.SimpleApnsPushNotification;
import com.eatthepath.pushy.apns.util.concurrent.PushNotificationFuture;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -33,45 +37,51 @@ import org.whispersystems.textsecuregcm.tests.util.SynchronousExecutorService;
class APNSenderTest {
private static final String DESTINATION_APN_ID = "foo";
private static final String DESTINATION_DEVICE_TOKEN = RandomStringUtils.randomAlphanumeric(32);
private static final String BUNDLE_ID = "org.signal.test";
private Account destinationAccount;
private Device destinationDevice;
private ApnsClient apnsClient;
private APNSender apnSender;
@BeforeEach
void setup() {
destinationAccount = mock(Account.class);
destinationDevice = mock(Device.class);
destinationDevice = mock(Device.class);
apnsClient = mock(ApnsClient.class);
apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, BUNDLE_ID);
when(destinationAccount.getDevice(1)).thenReturn(Optional.of(destinationDevice));
when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID);
when(destinationDevice.getApnId()).thenReturn(DESTINATION_DEVICE_TOKEN);
}
@Test
void testSendVoip() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testSendVoip(final boolean urgent) {
PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class);
when(response.isAccepted()).thenReturn(true);
when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
.thenAnswer(
(Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
PushNotification pushNotification = new PushNotification(DESTINATION_APN_ID, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, "foo");
PushNotification pushNotification = new PushNotification(DESTINATION_DEVICE_TOKEN, PushNotification.TokenType.APN_VOIP,
PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice, urgent);
final SendPushNotificationResult result = apnSender.sendNotification(pushNotification).join();
ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class);
verify(apnsClient).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_DEVICE_TOKEN);
assertThat(notification.getValue().getExpiration()).isEqualTo(APNSender.MAX_EXPIRATION);
assertThat(notification.getValue().getPayload()).isEqualTo(APNSender.APN_VOIP_NOTIFICATION_PAYLOAD);
// Delivery priority should always be `IMMEDIATE` for VOIP notifications
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(notification.getValue().getTopic()).isEqualTo("foo.voip");
assertThat(notification.getValue().getTopic()).isEqualTo(BUNDLE_ID + ".voip");
assertThat(result.accepted()).isTrue();
assertThat(result.errorCode()).isNull();
@@ -80,27 +90,41 @@ class APNSenderTest {
verifyNoMoreInteractions(apnsClient);
}
@Test
void testSendApns() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testSendApns(final boolean urgent) {
PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class);
when(response.isAccepted()).thenReturn(true);
when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
.thenAnswer(
(Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
PushNotification pushNotification = new PushNotification(DESTINATION_APN_ID, PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, "foo");
PushNotification pushNotification = new PushNotification(DESTINATION_DEVICE_TOKEN, PushNotification.TokenType.APN,
PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice, urgent);
final SendPushNotificationResult result = apnSender.sendNotification(pushNotification).join();
ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class);
verify(apnsClient).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_DEVICE_TOKEN);
assertThat(notification.getValue().getExpiration()).isEqualTo(APNSender.MAX_EXPIRATION);
assertThat(notification.getValue().getPayload()).isEqualTo(APNSender.APN_NSE_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
assertThat(notification.getValue().getTopic()).isEqualTo("foo");
assertThat(notification.getValue().getPayload())
.isEqualTo(urgent ? APNSender.APN_NSE_NOTIFICATION_PAYLOAD : APNSender.APN_BACKGROUND_PAYLOAD);
assertThat(notification.getValue().getPriority())
.isEqualTo(urgent ? DeliveryPriority.IMMEDIATE : DeliveryPriority.CONSERVE_POWER);
assertThat(notification.getValue().getTopic()).isEqualTo(BUNDLE_ID);
assertThat(notification.getValue().getPushType())
.isEqualTo(urgent ? PushType.ALERT : PushType.BACKGROUND);
if (urgent) {
assertThat(notification.getValue().getCollapseId()).isNotNull();
} else {
assertThat(notification.getValue().getCollapseId()).isNull();
}
assertThat(result.accepted()).isTrue();
assertThat(result.errorCode()).isNull();
@@ -110,19 +134,19 @@ class APNSenderTest {
}
@Test
void testUnregisteredUser() throws Exception {
void testUnregisteredUser() {
PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class);
when(response.isAccepted()).thenReturn(false);
when(response.getRejectionReason()).thenReturn(Optional.of("Unregistered"));
when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
.thenAnswer(
(Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
PushNotification pushNotification = new PushNotification(DESTINATION_DEVICE_TOKEN, PushNotification.TokenType.APN_VOIP,
PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice, true);
PushNotification pushNotification = new PushNotification(DESTINATION_APN_ID, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, "foo");
when(destinationDevice.getApnId()).thenReturn(DESTINATION_APN_ID);
when(destinationDevice.getApnId()).thenReturn(DESTINATION_DEVICE_TOKEN);
when(destinationDevice.getPushTimestamp()).thenReturn(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(11));
final SendPushNotificationResult result = apnSender.sendNotification(pushNotification).join();
@@ -130,7 +154,7 @@ class APNSenderTest {
ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class);
verify(apnsClient).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_DEVICE_TOKEN);
assertThat(notification.getValue().getExpiration()).isEqualTo(APNSender.MAX_EXPIRATION);
assertThat(notification.getValue().getPayload()).isEqualTo(APNSender.APN_VOIP_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
@@ -142,24 +166,23 @@ class APNSenderTest {
@Test
void testGenericFailure() {
ApnsClient apnsClient = mock(ApnsClient.class);
PushNotificationResponse<SimpleApnsPushNotification> response = mock(PushNotificationResponse.class);
when(response.isAccepted()).thenReturn(false);
when(response.getRejectionReason()).thenReturn(Optional.of("BadTopic"));
when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
.thenAnswer(
(Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
PushNotification pushNotification = new PushNotification(DESTINATION_APN_ID, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, "foo");
PushNotification pushNotification = new PushNotification(DESTINATION_DEVICE_TOKEN, PushNotification.TokenType.APN_VOIP,
PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice, true);
final SendPushNotificationResult result = apnSender.sendNotification(pushNotification).join();
ArgumentCaptor<SimpleApnsPushNotification> notification = ArgumentCaptor.forClass(SimpleApnsPushNotification.class);
verify(apnsClient).sendNotification(notification.capture());
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_APN_ID);
assertThat(notification.getValue().getToken()).isEqualTo(DESTINATION_DEVICE_TOKEN);
assertThat(notification.getValue().getExpiration()).isEqualTo(APNSender.MAX_EXPIRATION);
assertThat(notification.getValue().getPayload()).isEqualTo(APNSender.APN_VOIP_NOTIFICATION_PAYLOAD);
assertThat(notification.getValue().getPriority()).isEqualTo(DeliveryPriority.IMMEDIATE);
@@ -175,10 +198,11 @@ class APNSenderTest {
when(response.isAccepted()).thenReturn(true);
when(apnsClient.sendNotification(any(SimpleApnsPushNotification.class)))
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), new IOException("lost connection")));
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0),
new IOException("lost connection")));
PushNotification pushNotification = new PushNotification(DESTINATION_APN_ID, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice);
APNSender apnSender = new APNSender(new SynchronousExecutorService(), apnsClient, "foo");
PushNotification pushNotification = new PushNotification(DESTINATION_DEVICE_TOKEN, PushNotification.TokenType.APN_VOIP,
PushNotification.NotificationType.NOTIFICATION, null, destinationAccount, destinationDevice, true);
assertThatThrownBy(() -> apnSender.sendNotification(pushNotification).join())
.isInstanceOf(CompletionException.class)
@@ -189,7 +213,8 @@ class APNSenderTest {
verifyNoMoreInteractions(apnsClient);
}
private static class MockPushNotificationFuture <P extends ApnsPushNotification, V> extends PushNotificationFuture<P, V> {
private static class MockPushNotificationFuture<P extends ApnsPushNotification, V> extends
PushNotificationFuture<P, V> {
MockPushNotificationFuture(final P pushNotification, final V response) {
super(pushNotification);

View File

@@ -1,114 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
class ApnFallbackManagerTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private Account account;
private Device device;
private APNSender apnSender;
private ApnFallbackManager apnFallbackManager;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final long DEVICE_ID = 1L;
private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32);
@BeforeEach
void setUp() throws Exception {
device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
when(device.getVoipApnId()).thenReturn(VOIP_APN_ID);
when(device.getLastSeen()).thenReturn(System.currentTimeMillis());
account = mock(Account.class);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(ACCOUNT_NUMBER);
when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
final AccountsManager accountsManager = mock(AccountsManager.class);
when(accountsManager.getByE164(ACCOUNT_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account));
apnSender = mock(APNSender.class);
apnFallbackManager = new ApnFallbackManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), apnSender, accountsManager);
}
@Test
void testClusterInsert() {
final String endpoint = apnFallbackManager.getEndpointKey(account, device);
assertTrue(apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 1).isEmpty());
apnFallbackManager.schedule(account, device, System.currentTimeMillis() - 30_000);
final List<String> pendingDestinations = apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 2);
assertEquals(1, pendingDestinations.size());
final Optional<Pair<String, Long>> maybeUuidAndDeviceId = ApnFallbackManager.getSeparated(
pendingDestinations.get(0));
assertTrue(maybeUuidAndDeviceId.isPresent());
assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first());
assertEquals(DEVICE_ID, (long) maybeUuidAndDeviceId.get().second());
assertTrue(apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 1).isEmpty());
}
@Test
void testProcessNextSlot() {
final ApnFallbackManager.NotificationWorker worker = apnFallbackManager.new NotificationWorker();
apnFallbackManager.schedule(account, device, System.currentTimeMillis() - 30_000);
final int slot = SlotHash.getSlot(apnFallbackManager.getEndpointKey(account, device));
final int previousSlot = (slot + SlotHash.SLOT_COUNT - 1) % SlotHash.SLOT_COUNT;
REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> connection.sync()
.set(ApnFallbackManager.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(previousSlot)));
assertEquals(1, worker.processNextSlot());
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(VOIP_APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(0, worker.processNextSlot());
}
}

View File

@@ -0,0 +1,217 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.cluster.SlotHash;
import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
class ApnPushNotificationSchedulerTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private Account account;
private Device device;
private APNSender apnSender;
private Clock clock;
private ApnPushNotificationScheduler apnPushNotificationScheduler;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String ACCOUNT_NUMBER = "+18005551234";
private static final long DEVICE_ID = 1L;
private static final String APN_ID = RandomStringUtils.randomAlphanumeric(32);
private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32);
@BeforeEach
void setUp() throws Exception {
device = mock(Device.class);
when(device.getId()).thenReturn(DEVICE_ID);
when(device.getApnId()).thenReturn(APN_ID);
when(device.getVoipApnId()).thenReturn(VOIP_APN_ID);
when(device.getLastSeen()).thenReturn(System.currentTimeMillis());
account = mock(Account.class);
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(ACCOUNT_NUMBER);
when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device));
final AccountsManager accountsManager = mock(AccountsManager.class);
when(accountsManager.getByE164(ACCOUNT_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(ACCOUNT_UUID)).thenReturn(Optional.of(account));
apnSender = mock(APNSender.class);
clock = mock(Clock.class);
apnPushNotificationScheduler = new ApnPushNotificationScheduler(REDIS_CLUSTER_EXTENSION.getRedisCluster(), apnSender, accountsManager, clock);
}
@Test
void testClusterInsert() {
final String endpoint = ApnPushNotificationScheduler.getEndpointKey(account, device);
final long currentTimeMillis = System.currentTimeMillis();
assertTrue(
apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
when(clock.millis()).thenReturn(currentTimeMillis - 30_000);
apnPushNotificationScheduler.scheduleRecurringVoipNotification(account, device);
when(clock.millis()).thenReturn(currentTimeMillis);
final List<String> pendingDestinations = apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 2);
assertEquals(1, pendingDestinations.size());
final Optional<Pair<String, Long>> maybeUuidAndDeviceId = ApnPushNotificationScheduler.getSeparated(
pendingDestinations.get(0));
assertTrue(maybeUuidAndDeviceId.isPresent());
assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first());
assertEquals(DEVICE_ID, (long) maybeUuidAndDeviceId.get().second());
assertTrue(
apnPushNotificationScheduler.getPendingDestinationsForRecurringVoipNotifications(SlotHash.getSlot(endpoint), 1).isEmpty());
}
@Test
void testProcessRecurringVoipNotifications() {
final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker();
final long currentTimeMillis = System.currentTimeMillis();
when(clock.millis()).thenReturn(currentTimeMillis - 30_000);
apnPushNotificationScheduler.scheduleRecurringVoipNotification(account, device);
when(clock.millis()).thenReturn(currentTimeMillis);
final int slot = SlotHash.getSlot(ApnPushNotificationScheduler.getEndpointKey(account, device));
assertEquals(1, worker.processRecurringVoipNotifications(slot));
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(VOIP_APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(0, worker.processRecurringVoipNotifications(slot));
}
@Test
void testScheduleBackgroundNotificationWithNoRecentNotification() {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
when(clock.millis()).thenReturn(now.toEpochMilli());
assertEquals(Optional.empty(),
apnPushNotificationScheduler.getLastBackgroundNotificationTimestamp(account, device));
assertEquals(Optional.empty(),
apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device));
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device);
assertEquals(Optional.of(now),
apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device));
}
@Test
void testScheduleBackgroundNotificationWithRecentNotification() {
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
final Instant recentNotificationTimestamp =
now.minus(ApnPushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD.dividedBy(2));
// Insert a timestamp for a recently-sent background push notification
when(clock.millis()).thenReturn(recentNotificationTimestamp.toEpochMilli());
apnPushNotificationScheduler.sendBackgroundNotification(account, device);
when(clock.millis()).thenReturn(now.toEpochMilli());
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device);
final Instant expectedScheduledTimestamp =
recentNotificationTimestamp.plus(ApnPushNotificationScheduler.BACKGROUND_NOTIFICATION_PERIOD);
assertEquals(Optional.of(expectedScheduledTimestamp),
apnPushNotificationScheduler.getNextScheduledBackgroundNotificationTimestamp(account, device));
}
@Test
void testProcessScheduledBackgroundNotifications() {
final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker();
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
when(clock.millis()).thenReturn(now.toEpochMilli());
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device);
final int slot =
SlotHash.getSlot(ApnPushNotificationScheduler.getPendingBackgroundNotificationQueueKey(account, device));
when(clock.millis()).thenReturn(now.minusMillis(1).toEpochMilli());
assertEquals(0, worker.processScheduledBackgroundNotifications(slot));
when(clock.millis()).thenReturn(now.toEpochMilli());
assertEquals(1, worker.processScheduledBackgroundNotifications(slot));
final ArgumentCaptor<PushNotification> notificationCaptor = ArgumentCaptor.forClass(PushNotification.class);
verify(apnSender).sendNotification(notificationCaptor.capture());
final PushNotification pushNotification = notificationCaptor.getValue();
assertEquals(PushNotification.TokenType.APN, pushNotification.tokenType());
assertEquals(APN_ID, pushNotification.deviceToken());
assertEquals(account, pushNotification.destination());
assertEquals(device, pushNotification.destinationDevice());
assertEquals(PushNotification.NotificationType.NOTIFICATION, pushNotification.notificationType());
assertFalse(pushNotification.urgent());
assertEquals(0, worker.processRecurringVoipNotifications(slot));
}
@Test
void testProcessScheduledBackgroundNotificationsCancelled() {
final ApnPushNotificationScheduler.NotificationWorker worker = apnPushNotificationScheduler.new NotificationWorker();
final Instant now = Instant.now().truncatedTo(ChronoUnit.MILLIS);
when(clock.millis()).thenReturn(now.toEpochMilli());
apnPushNotificationScheduler.scheduleBackgroundNotification(account, device);
apnPushNotificationScheduler.cancelScheduledNotifications(account, device);
final int slot =
SlotHash.getSlot(ApnPushNotificationScheduler.getPendingBackgroundNotificationQueueKey(account, device));
assertEquals(0, worker.processScheduledBackgroundNotifications(slot));
verify(apnSender, never()).sendNotification(any());
}
}

View File

@@ -54,7 +54,7 @@ class FcmSenderTest {
@Test
void testSendMessage() {
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null);
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null, true);
final SettableApiFuture<String> sendFuture = SettableApiFuture.create();
sendFuture.set("message-id");
@@ -71,7 +71,7 @@ class FcmSenderTest {
@Test
void testSendMessageRejected() {
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null);
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null, true);
final FirebaseMessagingException invalidArgumentException = mock(FirebaseMessagingException.class);
when(invalidArgumentException.getMessagingErrorCode()).thenReturn(MessagingErrorCode.INVALID_ARGUMENT);
@@ -91,7 +91,7 @@ class FcmSenderTest {
@Test
void testSendMessageUnregistered() {
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null);
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null, true);
final FirebaseMessagingException unregisteredException = mock(FirebaseMessagingException.class);
when(unregisteredException.getMessagingErrorCode()).thenReturn(MessagingErrorCode.UNREGISTERED);
@@ -111,7 +111,7 @@ class FcmSenderTest {
@Test
void testSendMessageException() {
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null);
final PushNotification pushNotification = new PushNotification("foo", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, null, null, true);
final SettableApiFuture<String> sendFuture = SettableApiFuture.create();
sendFuture.setException(new IOException());

View File

@@ -26,7 +26,6 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@@ -117,7 +116,7 @@ class MessageSenderTest {
messageSender.sendMessage(account, device, message, false);
verify(messagesManager).insert(ACCOUNT_UUID, DEVICE_ID, message);
verify(pushNotificationManager).sendNewMessageNotification(account, device.getId());
verify(pushNotificationManager).sendNewMessageNotification(account, device.getId(), message.getUrgent());
}
@Test
@@ -128,7 +127,7 @@ class MessageSenderTest {
messageSender.sendMessage(account, device, message, false);
verify(messagesManager).insert(ACCOUNT_UUID, DEVICE_ID, message);
verify(pushNotificationManager).sendNewMessageNotification(account, device.getId());
verify(pushNotificationManager).sendNewMessageNotification(account, device.getId(), message.getUrgent());
}
@Test
@@ -137,7 +136,7 @@ class MessageSenderTest {
when(device.getFetchesMessages()).thenReturn(true);
doThrow(NotPushRegisteredException.class)
.when(pushNotificationManager).sendNewMessageNotification(account, DEVICE_ID);
.when(pushNotificationManager).sendNewMessageNotification(account, DEVICE_ID, message.getUrgent());
assertDoesNotThrow(() -> messageSender.sendMessage(account, device, message, false));
verify(messagesManager).insert(ACCOUNT_UUID, DEVICE_ID, message);

View File

@@ -1,9 +1,9 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.metrics;
package org.whispersystems.textsecuregcm.push;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
@@ -23,8 +23,9 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPushLatencyConfiguration;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager.PushRecord;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager.PushType;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager.PushRecord;
import org.whispersystems.textsecuregcm.push.PushLatencyManager.PushType;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;

View File

@@ -14,12 +14,18 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPushNotificationConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.Util;
@@ -28,7 +34,9 @@ class PushNotificationManagerTest {
private AccountsManager accountsManager;
private APNSender apnSender;
private FcmSender fcmSender;
private ApnFallbackManager apnFallbackManager;
private ApnPushNotificationScheduler apnPushNotificationScheduler;
private PushLatencyManager pushLatencyManager;
private DynamicPushNotificationConfiguration pushNotificationConfiguration;
private PushNotificationManager pushNotificationManager;
@@ -37,15 +45,28 @@ class PushNotificationManagerTest {
accountsManager = mock(AccountsManager.class);
apnSender = mock(APNSender.class);
fcmSender = mock(FcmSender.class);
apnFallbackManager = mock(ApnFallbackManager.class);
apnPushNotificationScheduler = mock(ApnPushNotificationScheduler.class);
pushLatencyManager = mock(PushLatencyManager.class);
pushNotificationConfiguration = mock(DynamicPushNotificationConfiguration.class);
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getPushNotificationConfiguration()).thenReturn(pushNotificationConfiguration);
when(pushNotificationConfiguration.isLowUrgencyEnabled()).thenReturn(true);
AccountsHelper.setupMockUpdate(accountsManager);
pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnFallbackManager);
pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender,
apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
}
@Test
void sendNewMessageNotification() throws NotPushRegisteredException {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void sendNewMessageNotification(final boolean urgent) throws NotPushRegisteredException {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
@@ -58,8 +79,30 @@ class PushNotificationManagerTest {
when(fcmSender.sendNotification(any()))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
pushNotificationManager.sendNewMessageNotification(account, Device.MASTER_ID);
verify(fcmSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device));
pushNotificationManager.sendNewMessageNotification(account, Device.MASTER_ID, urgent);
verify(fcmSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device, urgent));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void sendNewMessageNotificationLowUrgencyDisabled(final boolean urgent) throws NotPushRegisteredException {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final String deviceToken = "token";
when(device.getId()).thenReturn(Device.MASTER_ID);
when(device.getApnId()).thenReturn(deviceToken);
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
when(pushNotificationConfiguration.isLowUrgencyEnabled()).thenReturn(false);
when(apnSender.sendNotification(any()))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
pushNotificationManager.sendNewMessageNotification(account, Device.MASTER_ID, urgent);
verify(apnSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, true));
}
@Test
@@ -71,7 +114,7 @@ class PushNotificationManagerTest {
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
pushNotificationManager.sendRegistrationChallengeNotification(deviceToken, PushNotification.TokenType.APN_VOIP, challengeToken);
verify(apnSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.CHALLENGE, challengeToken, null, null));
verify(apnSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.CHALLENGE, challengeToken, null, null, true));
}
@Test
@@ -90,11 +133,12 @@ class PushNotificationManagerTest {
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
pushNotificationManager.sendRateLimitChallengeNotification(account, challengeToken);
verify(apnSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.APN, PushNotification.NotificationType.RATE_LIMIT_CHALLENGE, challengeToken, account, device));
verify(apnSender).sendNotification(new PushNotification(deviceToken, PushNotification.TokenType.APN, PushNotification.NotificationType.RATE_LIMIT_CHALLENGE, challengeToken, account, device, true));
}
@Test
void testSendNotification() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testSendNotificationFcm(final boolean urgent) {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
@@ -102,7 +146,7 @@ class PushNotificationManagerTest {
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
final PushNotification pushNotification = new PushNotification(
"token", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device);
"token", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device, urgent);
when(fcmSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
@@ -113,11 +157,12 @@ class PushNotificationManagerTest {
verifyNoInteractions(apnSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.MASTER_ID), any());
verify(device, never()).setUninstalledFeedbackTimestamp(Util.todayInMillis());
verifyNoInteractions(apnFallbackManager);
verifyNoInteractions(apnPushNotificationScheduler);
}
@Test
void testSendNotificationApnVoip() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testSendNotificationApn(final boolean urgent) {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
@@ -125,7 +170,35 @@ class PushNotificationManagerTest {
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
final PushNotification pushNotification = new PushNotification(
"token", PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device);
"token", PushNotification.TokenType.APN, PushNotification.NotificationType.NOTIFICATION, null, account, device, urgent);
when(apnSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
pushNotificationManager.sendNotification(pushNotification);
verifyNoInteractions(fcmSender);
if (urgent) {
verify(apnSender).sendNotification(pushNotification);
verifyNoInteractions(apnPushNotificationScheduler);
} else {
verifyNoInteractions(apnSender);
verify(apnPushNotificationScheduler).scheduleBackgroundNotification(account, device);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testSendNotificationApnVoip(final boolean urgent) {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.MASTER_ID);
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
final PushNotification pushNotification = new PushNotification(
"token", PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, urgent);
when(apnSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(true, null, false)));
@@ -133,10 +206,12 @@ class PushNotificationManagerTest {
pushNotificationManager.sendNotification(pushNotification);
verify(apnSender).sendNotification(pushNotification);
verifyNoInteractions(fcmSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.MASTER_ID), any());
verify(device, never()).setUninstalledFeedbackTimestamp(Util.todayInMillis());
verify(apnFallbackManager).schedule(account, device);
verify(apnPushNotificationScheduler).scheduleRecurringVoipNotification(account, device);
verify(apnPushNotificationScheduler, never()).scheduleBackgroundNotification(any(), any());
}
@Test
@@ -149,7 +224,7 @@ class PushNotificationManagerTest {
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
final PushNotification pushNotification = new PushNotification(
"token", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device);
"token", PushNotification.TokenType.FCM, PushNotification.NotificationType.NOTIFICATION, null, account, device, true);
when(fcmSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, null, true)));
@@ -159,7 +234,7 @@ class PushNotificationManagerTest {
verify(accountsManager).updateDevice(eq(account), eq(Device.MASTER_ID), any());
verify(device).setUninstalledFeedbackTimestamp(Util.todayInMillis());
verifyNoInteractions(apnSender);
verifyNoInteractions(apnFallbackManager);
verifyNoInteractions(apnPushNotificationScheduler);
}
@Test
@@ -171,7 +246,7 @@ class PushNotificationManagerTest {
when(account.getDevice(Device.MASTER_ID)).thenReturn(Optional.of(device));
final PushNotification pushNotification = new PushNotification(
"token", PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device);
"token", PushNotification.TokenType.APN_VOIP, PushNotification.NotificationType.NOTIFICATION, null, account, device, true);
when(apnSender.sendNotification(pushNotification))
.thenReturn(CompletableFuture.completedFuture(new SendPushNotificationResult(false, null, true)));
@@ -181,6 +256,22 @@ class PushNotificationManagerTest {
verifyNoInteractions(fcmSender);
verify(accountsManager, never()).updateDevice(eq(account), eq(Device.MASTER_ID), any());
verify(device, never()).setUninstalledFeedbackTimestamp(Util.todayInMillis());
verify(apnFallbackManager).cancel(account, device);
verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device);
}
@Test
void testHandleMessagesRetrieved() {
final UUID accountIdentifier = UUID.randomUUID();
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final String userAgent = "User-Agent";
when(account.getUuid()).thenReturn(accountIdentifier);
when(device.getId()).thenReturn(Device.MASTER_ID);
pushNotificationManager.handleMessagesRetrieved(account, device, userAgent);
verify(pushLatencyManager).recordQueueRead(accountIdentifier, Device.MASTER_ID, userAgent);
verify(apnPushNotificationScheduler).cancelScheduledNotifications(account, device);
}
}

View File

@@ -216,7 +216,7 @@ class TwilioVerifySenderTest {
.withHeader("Content-Type", "application/json")
.withBody("{\"status\": \"approved\", \"sid\": \"" + VERIFICATION_SID + "\"}")));
final Boolean success = sender.reportVerificationSucceeded(VERIFICATION_SID).get();
final Boolean success = sender.reportVerificationSucceeded(VERIFICATION_SID, null, "test").get();
assertThat(success).isTrue();
@@ -225,4 +225,24 @@ class TwilioVerifySenderTest {
.withHeader("Content-Type", equalTo("application/x-www-form-urlencoded"))
.withRequestBody(equalTo("Status=approved")));
}
@Test
void reportVerificationFailed() throws Exception {
wireMock.stubFor(post(urlEqualTo("/v2/Services/" + VERIFY_SERVICE_SID + "/Verifications/" + VERIFICATION_SID))
.withBasicAuth(ACCOUNT_ID, ACCOUNT_TOKEN)
.willReturn(aResponse()
.withStatus(404)
.withHeader("Content-Type", "application/json")
.withBody("{\"status\": 404, \"code\": 20404}")));
final Boolean success = sender.reportVerificationSucceeded(VERIFICATION_SID, null, "test").get();
assertThat(success).isFalse();
wireMock.verify(1,
postRequestedFor(urlEqualTo("/v2/Services/" + VERIFY_SERVICE_SID + "/Verifications/" + VERIFICATION_SID))
.withHeader("Content-Type", equalTo("application/x-www-form-urlencoded"))
.withRequestBody(equalTo("Status=approved")));
}
}

View File

@@ -27,11 +27,13 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
@@ -197,6 +199,8 @@ class AccountsManagerChangeNumberIntegrationTest {
secureStorageClient,
secureBackupClient,
clientPresenceManager,
mock(UsernameGenerator.class),
mock(ExperimentEnrollmentManager.class),
mock(Clock.class));
}
}

View File

@@ -42,6 +42,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -50,6 +51,7 @@ import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
@@ -164,6 +166,8 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(SecureStorageClient.class),
mock(SecureBackupClient.class),
mock(ClientPresenceManager.class),
mock(UsernameGenerator.class),
mock(ExperimentEnrollmentManager.class),
mock(Clock.class)
);
}

View File

@@ -10,10 +10,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.AdditionalMatchers.and;
import static org.mockito.AdditionalMatchers.not;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.startsWith;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
@@ -43,11 +46,14 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.UsernameConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
@@ -55,6 +61,7 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
class AccountsManagerTest {
@@ -65,6 +72,7 @@ class AccountsManagerTest {
private MessagesManager messagesManager;
private ProfilesManager profilesManager;
private ReservedUsernames reservedUsernames;
private ExperimentEnrollmentManager enrollmentManager;
private Map<String, UUID> phoneNumberIdentifiersByE164;
@@ -129,6 +137,10 @@ class AccountsManagerTest {
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
enrollmentManager = mock(ExperimentEnrollmentManager.class);
when(enrollmentManager.isEnrolled(any(UUID.class), eq(AccountsManager.USERNAME_EXPERIMENT_NAME))).thenReturn(true);
when(accounts.usernameAvailable(any())).thenReturn(true);
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
@@ -143,6 +155,8 @@ class AccountsManagerTest {
storageClient,
backupClient,
mock(ClientPresenceManager.class),
new UsernameGenerator(new UsernameConfiguration()),
enrollmentManager,
mock(Clock.class));
}
@@ -211,7 +225,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID();
String username = "test";
when(commands.get(eq("AccountMap::" + username))).thenReturn(uuid.toString());
when(commands.get(eq("UAccountMap::" + username))).thenReturn(uuid.toString());
when(commands.get(eq("Account3::" + uuid))).thenReturn("{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\", \"username\": \"test\"}");
Optional<Account> account = accountsManager.getByUsername(username);
@@ -221,7 +235,7 @@ class AccountsManagerTest {
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
assertEquals(Optional.of(username), account.get().getUsername());
verify(commands).get(eq("AccountMap::" + username));
verify(commands).get(eq("UAccountMap::" + username));
verify(commands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands);
@@ -309,7 +323,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account.setUsername(username);
when(commands.get(eq("AccountMap::" + username))).thenReturn(null);
when(commands.get(eq("UAccountMap::" + username))).thenReturn(null);
when(accounts.getByUsername(username)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByUsername(username);
@@ -317,8 +331,8 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + username));
verify(commands).setex(eq("AccountMap::" + username), anyLong(), eq(uuid.toString()));
verify(commands).get(eq("UAccountMap::" + username));
verify(commands).setex(eq("UAccountMap::" + username), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("AccountMap::" + account.getPhoneNumberIdentifier()), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString());
@@ -409,7 +423,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
account.setUsername(username);
when(commands.get(eq("AccountMap::" + username))).thenThrow(new RedisException("OH NO"));
when(commands.get(eq("UAccountMap::" + username))).thenThrow(new RedisException("OH NO"));
when(accounts.getByUsername(username)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByUsername(username);
@@ -417,8 +431,8 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + username));
verify(commands).setex(eq("AccountMap::" + username), anyLong(), eq(uuid.toString()));
verify(commands).get(eq("UAccountMap::" + username));
verify(commands).setex(eq("UAccountMap::" + username), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("AccountMap::" + account.getPhoneNumberIdentifier()), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("AccountMap::+14152222222"), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString());
@@ -716,45 +730,82 @@ class AccountsManagerTest {
}
@Test
void testSetUsername() throws UsernameNotAvailableException {
void testSetUsername() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String username = "test";
assertDoesNotThrow(() -> accountsManager.setUsername(account, username));
verify(accounts).setUsername(account, username);
final String nickname = "test";
assertDoesNotThrow(() -> accountsManager.setUsername(account, nickname, null));
verify(accounts).setUsername(eq(account), startsWith(nickname));
}
@Test
void testSetUsernameSameUsername() throws UsernameNotAvailableException {
void testSetUsernameSameUsername() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String username = "test";
account.setUsername(username);
final String nickname = "test";
account.setUsername(nickname + "#123");
assertDoesNotThrow(() -> accountsManager.setUsername(account, username));
// should be treated as a replayed request
assertDoesNotThrow(() -> accountsManager.setUsername(account, nickname, null));
verify(accounts, never()).setUsername(eq(account), any());
}
@Test
void testSetUsernameNotAvailable() throws UsernameNotAvailableException {
void testSetUsernameReroll() throws UsernameNotAvailableException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String username = "test";
final String nickname = "test";
final String username = nickname + "#ZZZ";
account.setUsername(username);
doThrow(new UsernameNotAvailableException()).when(accounts).setUsername(account, username);
// given the correct old username, should reroll discriminator even if the nick matches
accountsManager.setUsername(account, nickname, username);
verify(accounts).setUsername(eq(account), and(startsWith(nickname), not(eq(username))));
}
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, username));
verify(accounts).setUsername(account, username);
@Test
void testSetUsernameExpandDiscriminator() throws UsernameNotAvailableException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String nickname = "test";
ArgumentMatcher<String> isWide = (String username) -> {
String[] spl = username.split(UsernameGenerator.SEPARATOR);
assertEquals(spl.length, 2);
int discriminator = Integer.parseInt(spl[1]);
// require a 7 digit discriminator
return discriminator > 1_000_000;
};
when(accounts.usernameAvailable(any())).thenReturn(false);
when(accounts.usernameAvailable(argThat(isWide))).thenReturn(true);
accountsManager.setUsername(account, nickname, null);
verify(accounts).setUsername(eq(account), and(startsWith(nickname), argThat(isWide)));
}
@Test
void testChangeUsername() throws UsernameNotAvailableException {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String nickname = "test";
account.setUsername("old#123");
accountsManager.setUsername(account, nickname, "old#123");
verify(accounts).setUsername(eq(account), startsWith(nickname));
}
@Test
void testSetUsernameNotAvailable() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
final String nickname = "unavailable";
when(accounts.usernameAvailable(startsWith(nickname))).thenReturn(false);
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, nickname, null));
verify(accounts, never()).setUsername(any(), any());
assertTrue(account.getUsername().isEmpty());
}
@Test
void testSetUsernameReserved() {
final String username = "reserved";
when(reservedUsernames.isReserved(eq(username), any())).thenReturn(true);
final String nickname = "reserved";
when(reservedUsernames.isReserved(eq(nickname), any())).thenReturn(true);
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, username));
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, nickname, null));
assertTrue(account.getUsername().isEmpty());
}
@@ -765,6 +816,13 @@ class AccountsManagerTest {
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsername("test")));
}
@Test
void testSetUsernameDisabled() {
final Account account = AccountsHelper.generateTestAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID(), new ArrayList<>(), new byte[16]);
when(enrollmentManager.isEnrolled(account.getUuid(), AccountsManager.USERNAME_EXPERIMENT_NAME)).thenReturn(false);
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, "n00bkiller", null));
}
private static Device generateTestDevice(final long lastSeen) {
final Device device = new Device();
device.setId(Device.MASTER_ID);

View File

@@ -0,0 +1,240 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import software.amazon.awssdk.services.dynamodb.model.*;
import java.time.Clock;
import java.util.*;
import java.util.function.Consumer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
class AccountsManagerUsernameIntegrationTest {
private static final String ACCOUNTS_TABLE_NAME = "accounts_test";
private static final String NUMBERS_TABLE_NAME = "numbers_test";
private static final String PNI_ASSIGNMENT_TABLE_NAME = "pni_assignment_test";
private static final String USERNAMES_TABLE_NAME = "usernames_test";
private static final String PNI_TABLE_NAME = "pni_test";
private static final int SCAN_PAGE_SIZE = 1;
@RegisterExtension
static DynamoDbExtension ACCOUNTS_DYNAMO_EXTENSION = DynamoDbExtension.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.hashKey(Accounts.KEY_ACCOUNT_UUID)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(Accounts.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build())
.build();
@RegisterExtension
static DynamoDbExtension PNI_DYNAMO_EXTENSION = DynamoDbExtension.builder()
.tableName(PNI_TABLE_NAME)
.hashKey(PhoneNumberIdentifiers.KEY_E164)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(PhoneNumberIdentifiers.KEY_E164)
.attributeType(ScalarAttributeType.S)
.build())
.build();
@RegisterExtension
static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private AccountsManager accountsManager;
private Accounts accounts;
private UsernameGenerator usernameGenerator;
@BeforeEach
void setup() throws InterruptedException {
CreateTableRequest createNumbersTableRequest = CreateTableRequest.builder()
.tableName(NUMBERS_TABLE_NAME)
.keySchema(KeySchemaElement.builder()
.attributeName(Accounts.ATTR_ACCOUNT_E164)
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName(Accounts.ATTR_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S)
.build())
.provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT)
.build();
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient().createTable(createNumbersTableRequest);
CreateTableRequest createUsernamesTableRequest = CreateTableRequest.builder()
.tableName(USERNAMES_TABLE_NAME)
.keySchema(KeySchemaElement.builder()
.attributeName(Accounts.ATTR_USERNAME)
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName(Accounts.ATTR_USERNAME)
.attributeType(ScalarAttributeType.S)
.build())
.provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT)
.build();
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient().createTable(createUsernamesTableRequest);
CreateTableRequest createPhoneNumberIdentifierTableRequest = CreateTableRequest.builder()
.tableName(PNI_ASSIGNMENT_TABLE_NAME)
.keySchema(KeySchemaElement.builder()
.attributeName(Accounts.ATTR_PNI_UUID)
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName(Accounts.ATTR_PNI_UUID)
.attributeType(ScalarAttributeType.B)
.build())
.provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT)
.build();
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient().createTable(createPhoneNumberIdentifierTableRequest);
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
accounts = Mockito.spy(new Accounts(
dynamicConfigurationManager,
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient(),
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbAsyncClient(),
ACCOUNTS_DYNAMO_EXTENSION.getTableName(),
NUMBERS_TABLE_NAME,
PNI_ASSIGNMENT_TABLE_NAME,
USERNAMES_TABLE_NAME,
SCAN_PAGE_SIZE));
final DeletedAccountsManager deletedAccountsManager = mock(DeletedAccountsManager.class);
doAnswer((final InvocationOnMock invocationOnMock) -> {
@SuppressWarnings("unchecked")
Consumer<Optional<UUID>> consumer = invocationOnMock.getArgument(1, Consumer.class);
consumer.accept(Optional.empty());
return null;
}).when(deletedAccountsManager).lockAndTake(any(), any());
final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(PNI_DYNAMO_EXTENSION.getDynamoDbClient(), PNI_TABLE_NAME);
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(AccountsManager.USERNAME_EXPERIMENT_NAME)))
.thenReturn(true);
usernameGenerator = new UsernameGenerator(1, 2, 10);
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
deletedAccountsManager,
mock(DirectoryQueue.class),
mock(Keys.class),
mock(MessagesManager.class),
mock(ReservedUsernames.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),
mock(SecureStorageClient.class),
mock(SecureBackupClient.class),
mock(ClientPresenceManager.class),
usernameGenerator,
experimentEnrollmentManager,
mock(Clock.class));
}
private static int discriminator(String username) {
return Integer.parseInt(username.split(UsernameGenerator.SEPARATOR)[1]);
}
@Test
void testSetClearUsername() throws UsernameNotAvailableException, InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
account = accountsManager.setUsername(account, "n00bkiller", null);
assertThat(account.getUsername()).isPresent();
assertThat(account.getUsername().get()).startsWith("n00bkiller");
int discriminator = discriminator(account.getUsername().get());
assertThat(discriminator).isGreaterThan(0).isLessThan(10);
assertThat(accountsManager.getByUsername(account.getUsername().get()).orElseThrow().getUuid()).isEqualTo(
account.getUuid());
// reroll
account = accountsManager.setUsername(account, "n00bkiller", account.getUsername().get());
final String newUsername = account.getUsername().orElseThrow();
assertThat(discriminator(account.getUsername().orElseThrow())).isNotEqualTo(discriminator);
// clear
account = accountsManager.clearUsername(account);
assertThat(accountsManager.getByUsername(newUsername)).isEmpty();
assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsername()).isEmpty();
}
@Test
void testNoUsernames() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
for (int i = 1; i <= 99; i++) {
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder()
.tableName(USERNAMES_TABLE_NAME)
.item(Map.of(
Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(UUID.randomUUID()),
Accounts.ATTR_USERNAME, AttributeValues.fromString(usernameGenerator.fromParts("n00bkiller", i))))
.build());
}
assertThrows(UsernameNotAvailableException.class, () -> accountsManager.setUsername(account, "n00bkiller", null));
assertThat(accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getUsername()).isEmpty();
}
@Test
void testUsernameSnatched() throws InterruptedException, UsernameNotAvailableException {
final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
for (int i = 1; i <= 9; i++) {
ACCOUNTS_DYNAMO_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder()
.tableName(USERNAMES_TABLE_NAME)
.item(Map.of(
Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(UUID.randomUUID()),
Accounts.ATTR_USERNAME, AttributeValues.fromString(usernameGenerator.fromParts("n00bkiller", i))))
.build());
}
// first time this is called lie and say the username is available
// this simulates seeing an available username and then it being taken
// by someone before the write
doReturn(true).doCallRealMethod().when(accounts).usernameAvailable(any());
final String username = accountsManager
.setUsername(account, "n00bkiller", null)
.getUsername().orElseThrow();
assertThat(username).startsWith("n00bkiller");
assertThat(discriminator(username)).isGreaterThanOrEqualTo(10).isLessThan(100);
// 1 attempt on first try (returns true),
// 10 (attempts per width) on width=2 discriminators (all taken)
verify(accounts, times(11)).usernameAvailable(argThat(un -> discriminator(un) < 10));
// 1 final attempt on width=3 discriminators
verify(accounts, times(1)).usernameAvailable(argThat(un -> discriminator(un) >= 10));
}
}

View File

@@ -16,11 +16,9 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.uuid.UUIDComparator;
import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -28,16 +26,11 @@ import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.jdbi.v3.core.transaction.TransactionException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
@@ -46,7 +39,6 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
@@ -498,60 +490,6 @@ class AccountsTest {
assertThat(retrieved.isPresent()).isFalse();
}
@Test
@Disabled("Need fault tolerant dynamodb")
void testBreaker() throws InterruptedException {
CircuitBreakerConfiguration configuration = new CircuitBreakerConfiguration();
configuration.setWaitDurationInOpenStateInSeconds(1);
configuration.setRingBufferSizeInHalfOpenState(1);
configuration.setRingBufferSizeInClosedState(2);
configuration.setFailureRateThreshold(50);
final DynamoDbClient client = mock(DynamoDbClient.class);
final DynamoDbAsyncClient asyncClient = mock(DynamoDbAsyncClient.class);
when(client.transactWriteItems(any(TransactWriteItemsRequest.class)))
.thenThrow(RuntimeException.class);
when(asyncClient.updateItem(any(UpdateItemRequest.class)))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
Accounts accounts = new Accounts(mockDynamicConfigManager, client, asyncClient, ACCOUNTS_TABLE_NAME, NUMBER_CONSTRAINT_TABLE_NAME,
PNI_CONSTRAINT_TABLE_NAME, USERNAME_CONSTRAINT_TABLE_NAME, SCAN_PAGE_SIZE);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID());
try {
accounts.update(account);
throw new AssertionError();
} catch (TransactionException e) {
// good
}
try {
accounts.update(account);
throw new AssertionError();
} catch (TransactionException e) {
// good
}
try {
accounts.update(account);
throw new AssertionError();
} catch (CallNotPermittedException e) {
// good
}
Thread.sleep(1100);
try {
accounts.update(account);
throw new AssertionError();
} catch (TransactionException e) {
// good
}
}
@Test
void testCanonicallyDiscoverableSet() {
Device device = generateDevice(1);
@@ -722,7 +660,7 @@ class AccountsTest {
assertThat(maybeAccount).isPresent();
verifyStoredState(firstAccount.getNumber(), firstAccount.getUuid(), firstAccount.getPhoneNumberIdentifier(), maybeAccount.get(), firstAccount);
assertThatExceptionOfType(UsernameNotAvailableException.class)
assertThatExceptionOfType(ContestedOptimisticLockException.class)
.isThrownBy(() -> accounts.setUsername(secondAccount, username));
assertThat(secondAccount.getUsername()).isEmpty();
@@ -781,164 +719,6 @@ class AccountsTest {
assertThat(account.getUsername()).hasValueSatisfying(u -> assertThat(u).isEqualTo(username));
}
@Test
void testAddUakMissingInJson() {
// If there's no uak in the json, we shouldn't add an attribute on crawl
final UUID accountIdentifier = UUID.randomUUID();
final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
account.setUnidentifiedAccessKey(null);
accounts.create(account);
// there should be no top level uak
Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
// crawling should return 1 account
final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
assertThat(allFromStart.getAccounts()).hasSize(1);
assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
// there should still be no top level uak
item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
}
@Test
void testUakMismatch() {
// If there's a UAK mismatch, we should correct it
final UUID accountIdentifier = UUID.randomUUID();
final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
accounts.create(account);
// set the uak to garbage in the attributes
dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
.expressionAttributeValues(Map.of(":uak", AttributeValues.fromByteArray("bad-uak".getBytes())))
.updateExpression("SET #uak = :uak").build());
// crawling should return 1 account and fix the uak mismatch
final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
assertThat(allFromStart.getAccounts()).hasSize(1);
assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
assertThat(allFromStart.getAccounts().get(0).getUnidentifiedAccessKey().get()).isEqualTo(account.getUnidentifiedAccessKey().get());
// the top level uak should be the original
final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
assertThat(item).containsEntry(
Accounts.ATTR_UAK,
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testAddMissingUakAttribute(boolean normalizeDisabled) throws JsonProcessingException {
final UUID accountIdentifier = UUID.randomUUID();
if (normalizeDisabled) {
final DynamicConfiguration config = DynamicConfigurationManager.parseConfiguration("""
captcha:
scoreFloor: 1.0
uakMigrationConfiguration:
enabled: false
""", DynamicConfiguration.class).orElseThrow();
when(mockDynamicConfigManager.getConfiguration()).thenReturn(config);
}
final Account account = generateAccount("+18005551234", accountIdentifier, UUID.randomUUID());
accounts.create(account);
// remove the top level uak (simulates old format)
dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
.updateExpression("REMOVE #uak").build());
// crawling should return 1 account, and fix the discrepancy between
// the json blob and the top level attributes if normalization is enabled
final AccountCrawlChunk allFromStart = accounts.getAllFromStart(1);
assertThat(allFromStart.getAccounts()).hasSize(1);
assertThat(allFromStart.getAccounts().get(0).getUuid()).isEqualTo(accountIdentifier);
// check whether normalization happened
final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(accountIdentifier)))
.consistentRead(true)
.build()).item();
if (normalizeDisabled) {
assertThat(item).doesNotContainKey(Accounts.ATTR_UAK);
} else {
assertThat(item).containsEntry(Accounts.ATTR_UAK,
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
}
}
@ParameterizedTest
@ValueSource(ints = {24, 25, 26, 101})
void testAddMissingUakAttributeBatched(int n) {
// generate N + 5 accounts
List<Account> allAccounts = IntStream.range(0, n + 5)
.mapToObj(i -> generateAccount(String.format("+1800555%04d", i), UUID.randomUUID(), UUID.randomUUID()))
.collect(Collectors.toList());
allAccounts.forEach(accounts::create);
// delete the UAK on n of them
Collections.shuffle(allAccounts);
allAccounts.stream().limit(n).forEach(account ->
dynamoDbExtension.getDynamoDbClient().updateItem(UpdateItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.expressionAttributeNames(Map.of("#uak", Accounts.ATTR_UAK))
.updateExpression("REMOVE #uak")
.build()));
// crawling should fix the discrepancy between
// the json blob and the top level attributes
AccountCrawlChunk chunk = accounts.getAllFromStart(7);
long verifiedCount = 0;
while (true) {
for (Account account : chunk.getAccounts()) {
// check that the attribute now exists at top level
final Map<String, AttributeValue> item = dynamoDbExtension.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.key(Map.of(Accounts.KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.consistentRead(true)
.build()).item();
assertThat(item).containsEntry(Accounts.ATTR_UAK,
AttributeValues.fromByteArray(account.getUnidentifiedAccessKey().get()));
verifiedCount++;
}
if (chunk.getLastUuid().isPresent()) {
chunk = accounts.getAllFrom(chunk.getLastUuid().get(), 7);
} else {
break;
}
}
assertThat(verifiedCount).isEqualTo(n + 5);
}
private Device generateDevice(long id) {
return DevicesHelper.createDevice(id);
}

View File

@@ -32,7 +32,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@@ -73,8 +73,7 @@ class MessagePersisterIntegrationTest {
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class),
mock(ReportMessageManager.class));
messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class));
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY);

View File

@@ -14,7 +14,7 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.PushLatencyManager;
class MessagesManagerTest {
@@ -24,7 +24,7 @@ class MessagesManagerTest {
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
pushLatencyManager, reportMessageManager);
reportMessageManager);
@Test
void insert() {

View File

@@ -62,6 +62,8 @@ class PushChallengeDynamoDbTest {
assertFalse(pushChallengeDynamoDb.remove(uuid, token));
assertTrue(pushChallengeDynamoDb.add(uuid, token, Duration.ofMinutes(1)));
assertTrue(pushChallengeDynamoDb.remove(uuid, token));
assertTrue(pushChallengeDynamoDb.add(uuid, token, Duration.ofMinutes(-1)));
assertFalse(pushChallengeDynamoDb.remove(uuid, token));
}
@Test

View File

@@ -1,35 +0,0 @@
/*
* Copyright 2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.*;
import static org.mockito.Mockito.*;
class UsernameCleanerTest {
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final Account hasUsername = mock(Account.class);
private final Account noUsername = mock(Account.class);
@BeforeEach
void setup() {
when(hasUsername.getUsername()).thenReturn(Optional.of("n00bkiller"));
when(noUsername.getUsername()).thenReturn(Optional.empty());
}
@Test
void testAccounts() throws AccountDatabaseCrawlerRestartException {
UsernameCleaner accountCleaner = new UsernameCleaner(accountsManager);
accountCleaner.onCrawlStart();
accountCleaner.timeAndProcessCrawlChunk(Optional.empty(), Arrays.asList(hasUsername, noUsername));
accountCleaner.onCrawlEnd(Optional.empty());
verify(accountsManager).clearUsername(hasUsername);
verify(accountsManager, never()).clearUsername(noUsername);
}
}

View File

@@ -5,27 +5,33 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.time.Duration;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import java.util.Objects;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
class VerificationCodeStoreTest {
private VerificationCodeStore verificationCodeStore;
private static final String TABLE_NAME = "verification_code_test";
private static final String PHONE_NUMBER = "+14151112222";
private static final long VALID_TIMESTAMP = Instant.now().toEpochMilli();
private static final long EXPIRED_TIMESTAMP = Instant.now().minus(StoredVerificationCode.EXPIRATION).minus(
Duration.ofHours(1)).toEpochMilli();
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = DynamoDbExtension.builder()
.tableName(TABLE_NAME)
@@ -45,8 +51,8 @@ class VerificationCodeStoreTest {
void testStoreAndFind() {
assertEquals(Optional.empty(), verificationCodeStore.findForNumber(PHONE_NUMBER));
final StoredVerificationCode originalCode = new StoredVerificationCode("1234", 1111, "abcd", "0987");
final StoredVerificationCode secondCode = new StoredVerificationCode("5678", 2222, "efgh", "7890");
final StoredVerificationCode originalCode = new StoredVerificationCode("1234", VALID_TIMESTAMP, "abcd", "0987");
final StoredVerificationCode secondCode = new StoredVerificationCode("5678", VALID_TIMESTAMP, "efgh", "7890");
verificationCodeStore.insert(PHONE_NUMBER, originalCode);
{
@@ -69,11 +75,14 @@ class VerificationCodeStoreTest {
void testRemove() {
assertEquals(Optional.empty(), verificationCodeStore.findForNumber(PHONE_NUMBER));
verificationCodeStore.insert(PHONE_NUMBER, new StoredVerificationCode("1234", 1111, "abcd", "0987"));
verificationCodeStore.insert(PHONE_NUMBER, new StoredVerificationCode("1234", VALID_TIMESTAMP, "abcd", "0987"));
assertTrue(verificationCodeStore.findForNumber(PHONE_NUMBER).isPresent());
verificationCodeStore.remove(PHONE_NUMBER);
assertFalse(verificationCodeStore.findForNumber(PHONE_NUMBER).isPresent());
verificationCodeStore.insert(PHONE_NUMBER, new StoredVerificationCode("1234", EXPIRED_TIMESTAMP, "abcd", "0987"));
assertFalse(verificationCodeStore.findForNumber(PHONE_NUMBER).isPresent());
}
private static boolean storedVerificationCodesAreEqual(final StoredVerificationCode first, final StoredVerificationCode second) {

View File

@@ -17,7 +17,7 @@ class AuthenticationCredentialsTest {
AuthenticationCredentials credentials = new AuthenticationCredentials("mypassword");
assertThat(credentials.getSalt()).isNotEmpty();
assertThat(credentials.getHashedAuthenticationToken()).isNotEmpty();
assertThat(credentials.getHashedAuthenticationToken().length()).isEqualTo(40);
assertThat(credentials.getHashedAuthenticationToken().length()).isEqualTo(66);
}
@Test

View File

@@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
@@ -23,6 +24,7 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@@ -50,6 +52,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
@@ -66,6 +69,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.AccountIdentifierResponse;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ChangePhoneNumberRequest;
@@ -74,6 +78,8 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessage;
import org.whispersystems.textsecuregcm.entities.RegistrationLock;
import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.UsernameRequest;
import org.whispersystems.textsecuregcm.entities.UsernameResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
@@ -138,6 +144,7 @@ class AccountControllerTest {
private static RateLimiter smsVoicePrefixLimiter = mock(RateLimiter.class);
private static RateLimiter autoBlockLimiter = mock(RateLimiter.class);
private static RateLimiter usernameSetLimiter = mock(RateLimiter.class);
private static RateLimiter usernameLookupLimiter = mock(RateLimiter.class);
private static SmsSender smsSender = mock(SmsSender.class);
private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class);
private static Account senderPinAccount = mock(Account.class);
@@ -201,6 +208,7 @@ class AccountControllerTest {
when(rateLimiters.getSmsVoicePrefixLimiter()).thenReturn(smsVoicePrefixLimiter);
when(rateLimiters.getAutoBlockLimiter()).thenReturn(autoBlockLimiter);
when(rateLimiters.getUsernameSetLimiter()).thenReturn(usernameSetLimiter);
when(rateLimiters.getUsernameLookupLimiter()).thenReturn(usernameLookupLimiter);
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
when(senderPinAccount.getRegistrationLock()).thenReturn(new StoredRegistrationLock(Optional.empty(), Optional.empty(), System.currentTimeMillis()));
@@ -246,7 +254,7 @@ class AccountControllerTest {
return account;
});
when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername"))
when(accountsManager.setUsername(AuthHelper.VALID_ACCOUNT, "takenusername", null))
.thenThrow(new UsernameNotAvailableException());
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
@@ -311,6 +319,7 @@ class AccountControllerTest {
smsVoicePrefixLimiter,
autoBlockLimiter,
usernameSetLimiter,
usernameLookupLimiter,
smsSender,
turnTokenGenerator,
senderPinAccount,
@@ -976,7 +985,8 @@ class AccountControllerTest {
.thenReturn(enrolledInVerifyExperiment);
final String challenge = "challenge";
when(pendingAccountsManager.getCodeForNumber(RESTRICTED_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("123456", System.currentTimeMillis(), challenge, null)));
when(pendingAccountsManager.getCodeForNumber(RESTRICTED_NUMBER))
.thenReturn(Optional.of(new StoredVerificationCode("123456", System.currentTimeMillis(), challenge, null)));
Response response =
resources.getJerseyTest()
@@ -992,6 +1002,54 @@ class AccountControllerTest {
verifyNoMoreInteractions(smsSender);
}
@ParameterizedTest
@CsvSource({
"+12025550123, true, true",
"+12025550123, false, true",
"+12505550199, true, false",
"+12505550199, false, false",
})
void testRestrictedRegion(final String number, final boolean enrolledInVerifyExperiment, final boolean expectSendCode) {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final DynamicCaptchaConfiguration signupCaptchaConfig = new DynamicCaptchaConfiguration();
signupCaptchaConfig.setSignupRegions(Set.of("CA"));
when(dynamicConfiguration.getCaptchaConfiguration()).thenReturn(signupCaptchaConfig);
when(verifyExperimentEnrollmentManager.isEnrolled(any(), anyString(), anyList(), anyString()))
.thenReturn(enrolledInVerifyExperiment);
final String challenge = "challenge";
when(pendingAccountsManager.getCodeForNumber(number))
.thenReturn(Optional.of(new StoredVerificationCode("123456", System.currentTimeMillis(), challenge, null)));
when(smsSender.deliverSmsVerificationWithTwilioVerify(any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/sms/code/%s", number))
.queryParam("challenge", challenge)
.request()
.header("X-Forwarded-For", NICE_HOST)
.get();
if (expectSendCode) {
assertThat(response.getStatus()).isEqualTo(200);
if (enrolledInVerifyExperiment) {
verify(smsSender).deliverSmsVerificationWithTwilioVerify(eq(number), any(), any(), any());
} else {
verify(smsSender).deliverSmsVerification(eq(number), any(), any());
}
} else {
assertThat(response.getStatus()).isEqualTo(402);
verifyNoMoreInteractions(smsSender);
}
}
@ParameterizedTest
@ValueSource(booleans = {false, true})
void testSendRestrictedIn(final boolean enrolledInVerifyExperiment) throws Exception {
@@ -1060,7 +1118,7 @@ class AccountControllerTest {
verify(accountsManager).create(eq(SENDER), eq("bar"), any(), any(), anyList());
if (enrolledInVerifyExperiment) {
verify(smsSender).reportVerificationSucceeded("VerificationSid");
verify(smsSender).reportVerificationSucceeded(eq("VerificationSid"), any(), eq("registration"));
}
}
@@ -1572,7 +1630,7 @@ class AccountControllerTest {
assertThat(pinCapture.getValue()).isNotEmpty();
assertThat(pinSaltCapture.getValue()).isNotEmpty();
assertThat(pinCapture.getValue().length()).isEqualTo(40);
assertThat(pinCapture.getValue().length()).isEqualTo(66);
}
@Test
@@ -1660,25 +1718,29 @@ class AccountControllerTest {
}
@Test
void testSetUsername() {
void testSetUsername() throws UsernameNotAvailableException {
Account account = mock(Account.class);
when(account.getUsername()).thenReturn(Optional.of("n00bkiller#1234"));
when(accountsManager.setUsername(any(), eq("n00bkiller"), isNull()))
.thenReturn(account);
Response response =
resources.getJerseyTest()
.target("/v1/accounts/username/n00bkiller")
.target("/v1/accounts/username")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.text(""));
.put(Entity.json(new UsernameRequest("n00bkiller", null)));
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.readEntity(UsernameResponse.class).username()).isEqualTo("n00bkiller#1234");
}
@Test
void testSetTakenUsername() {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/username/takenusername")
.target("/v1/accounts/username/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.text(""));
.put(Entity.json(new UsernameRequest("takenusername", null)));
assertThat(response.getStatus()).isEqualTo(409);
}
@@ -1687,35 +1749,34 @@ class AccountControllerTest {
void testSetInvalidUsername() {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/username/pаypal")
.target("/v1/accounts/username")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.text(""));
// contains non-ascii character
.put(Entity.json(new UsernameRequest("pаypal", null)));
assertThat(response.getStatus()).isEqualTo(400);
assertThat(response.getStatus()).isEqualTo(422);
}
@Test
void testSetInvalidPrefixUsername() {
void testSetInvalidPrefixUsername() throws JsonProcessingException {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/username/0n00bkiller")
.target("/v1/accounts/username")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.text(""));
assertThat(response.getStatus()).isEqualTo(400);
.put(Entity.json(new UsernameRequest("0n00bkiller", null)));
assertThat(response.getStatus()).isEqualTo(422);
}
@Test
void testSetUsernameBadAuth() {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/username/n00bkiller")
.target("/v1/accounts/username")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.put(Entity.text(""));
.put(Entity.json(new UsernameRequest("n00bkiller", null)));
assertThat(response.getStatus()).isEqualTo(401);
}
@@ -1935,4 +1996,43 @@ class AccountControllerTest {
.head()
.getStatus()).isEqualTo(400);
}
@Test
void testLookupUsername() {
final Account account = mock(Account.class);
final UUID uuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(uuid);
when(accountsManager.getByUsername(eq("n00bkiller#1234"))).thenReturn(Optional.of(account));
Response response = resources.getJerseyTest()
.target("v1/accounts/username/n00bkiller#1234")
.request()
.header("X-Forwarded-For", "127.0.0.1")
.get();
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.readEntity(AccountIdentifierResponse.class).uuid()).isEqualTo(uuid);
}
@Test
void testLookupUsernameDoesNotExist() {
when(accountsManager.getByUsername(eq("n00bkiller#1234"))).thenReturn(Optional.empty());
assertThat(resources.getJerseyTest()
.target("v1/accounts/username/n00bkiller#1234")
.request()
.header("X-Forwarded-For", "127.0.0.1")
.get().getStatus()).isEqualTo(404);
}
@Test
void testLookupUsernameRateLimited() throws RateLimitExceededException {
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(usernameLookupLimiter).validate("127.0.0.1");
final Response response = resources.getJerseyTest()
.target("/v1/accounts/username/test#123")
.request()
.header("X-Forwarded-For", "127.0.0.1")
.get();
assertThat(response.getStatus()).isEqualTo(413);
assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds()));
}
}

View File

@@ -62,16 +62,17 @@ class CertificateControllerTest {
private static final String caPrivateKey = "EO3Mnf0kfVlVnwSaqPoQnAxhnnGL1JTdXqktCKEe9Eo=";
private static final String signingCertificate = "CiUIDBIhBbTz4h1My+tt+vw+TVscgUe/DeHS0W02tPWAWbTO2xc3EkD+go4bJnU0AcnFfbOLKoiBfCzouZtDYMOVi69rE7r4U9cXREEqOkUmU2WJBjykAxWPCcSTmVTYHDw7hkSp/puG";
private static final String signingKey = "ABOxG29xrfq4E7IrW11Eg7+HBbtba9iiS0500YoBjn4=";
private static final String signingKey = "ABOxG29xrfq4E7IrW11Eg7+HBbtba9iiS0500YoBjn4=";
private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate();
private static final CertificateGenerator certificateGenerator;
private static final ServerSecretParams serverSecretParams = ServerSecretParams.generate();
private static final CertificateGenerator certificateGenerator;
private static final ServerZkAuthOperations serverZkAuthOperations;
private static final Clock clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
static {
try {
certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(signingCertificate), Curve.decodePrivatePoint(Base64.getDecoder().decode(signingKey)), 1);
certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(signingCertificate),
Curve.decodePrivatePoint(Base64.getDecoder().decode(signingKey)), 1);
serverZkAuthOperations = new ServerZkAuthOperations(serverSecretParams);
} catch (IOException e) {
throw new AssertionError(e);
@@ -91,86 +92,98 @@ class CertificateControllerTest {
@Test
void testValidCertificate() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(DeliveryCertificate.class);
.target("/v1/certificate/delivery")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(DeliveryCertificate.class);
SenderCertificate certificateHolder = SenderCertificate.parseFrom(certificateObject.getCertificate());
SenderCertificate.Certificate certificate = SenderCertificate.Certificate.parseFrom(
certificateHolder.getCertificate());
SenderCertificate certificateHolder = SenderCertificate.parseFrom(certificateObject.getCertificate());
SenderCertificate.Certificate certificate = SenderCertificate.Certificate.parseFrom(certificateHolder.getCertificate());
ServerCertificate serverCertificateHolder = certificate.getSigner();
ServerCertificate.Certificate serverCertificate = ServerCertificate.Certificate.parseFrom(
serverCertificateHolder.getCertificate());
ServerCertificate serverCertificateHolder = certificate.getSigner();
ServerCertificate.Certificate serverCertificate = ServerCertificate.Certificate.parseFrom(serverCertificateHolder.getCertificate());
assertTrue(Curve.verifySignature(Curve.decodePoint(serverCertificate.getKey().toByteArray(), 0), certificateHolder.getCertificate().toByteArray(), certificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(Base64.getDecoder().decode(caPublicKey), 0), serverCertificateHolder.getCertificate().toByteArray(), serverCertificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(serverCertificate.getKey().toByteArray(), 0),
certificateHolder.getCertificate().toByteArray(), certificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(Base64.getDecoder().decode(caPublicKey), 0),
serverCertificateHolder.getCertificate().toByteArray(), serverCertificateHolder.getSignature().toByteArray()));
assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER);
assertEquals(certificate.getSenderDevice(), 1L);
assertTrue(certificate.hasSenderUuid());
assertEquals(AuthHelper.VALID_UUID.toString(), certificate.getSenderUuid());
assertArrayEquals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
assertArrayEquals(certificate.getIdentityKey().toByteArray(),
Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
}
@Test
void testValidCertificateWithUuid() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.queryParam("includeUuid", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(DeliveryCertificate.class);
.target("/v1/certificate/delivery")
.queryParam("includeUuid", "true")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(DeliveryCertificate.class);
SenderCertificate certificateHolder = SenderCertificate.parseFrom(certificateObject.getCertificate());
SenderCertificate.Certificate certificate = SenderCertificate.Certificate.parseFrom(
certificateHolder.getCertificate());
SenderCertificate certificateHolder = SenderCertificate.parseFrom(certificateObject.getCertificate());
SenderCertificate.Certificate certificate = SenderCertificate.Certificate.parseFrom(certificateHolder.getCertificate());
ServerCertificate serverCertificateHolder = certificate.getSigner();
ServerCertificate.Certificate serverCertificate = ServerCertificate.Certificate.parseFrom(
serverCertificateHolder.getCertificate());
ServerCertificate serverCertificateHolder = certificate.getSigner();
ServerCertificate.Certificate serverCertificate = ServerCertificate.Certificate.parseFrom(serverCertificateHolder.getCertificate());
assertTrue(Curve.verifySignature(Curve.decodePoint(serverCertificate.getKey().toByteArray(), 0), certificateHolder.getCertificate().toByteArray(), certificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(Base64.getDecoder().decode(caPublicKey), 0), serverCertificateHolder.getCertificate().toByteArray(), serverCertificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(serverCertificate.getKey().toByteArray(), 0),
certificateHolder.getCertificate().toByteArray(), certificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(Base64.getDecoder().decode(caPublicKey), 0),
serverCertificateHolder.getCertificate().toByteArray(), serverCertificateHolder.getSignature().toByteArray()));
assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER);
assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertArrayEquals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
assertArrayEquals(certificate.getIdentityKey().toByteArray(),
Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
}
@Test
void testValidCertificateWithUuidNoE164() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.queryParam("includeUuid", "true")
.queryParam("includeE164", "false")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(DeliveryCertificate.class);
.target("/v1/certificate/delivery")
.queryParam("includeUuid", "true")
.queryParam("includeE164", "false")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(DeliveryCertificate.class);
SenderCertificate certificateHolder = SenderCertificate.parseFrom(certificateObject.getCertificate());
SenderCertificate.Certificate certificate = SenderCertificate.Certificate.parseFrom(
certificateHolder.getCertificate());
SenderCertificate certificateHolder = SenderCertificate.parseFrom(certificateObject.getCertificate());
SenderCertificate.Certificate certificate = SenderCertificate.Certificate.parseFrom(certificateHolder.getCertificate());
ServerCertificate serverCertificateHolder = certificate.getSigner();
ServerCertificate.Certificate serverCertificate = ServerCertificate.Certificate.parseFrom(
serverCertificateHolder.getCertificate());
ServerCertificate serverCertificateHolder = certificate.getSigner();
ServerCertificate.Certificate serverCertificate = ServerCertificate.Certificate.parseFrom(serverCertificateHolder.getCertificate());
assertTrue(Curve.verifySignature(Curve.decodePoint(serverCertificate.getKey().toByteArray(), 0), certificateHolder.getCertificate().toByteArray(), certificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(Base64.getDecoder().decode(caPublicKey), 0), serverCertificateHolder.getCertificate().toByteArray(), serverCertificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(serverCertificate.getKey().toByteArray(), 0),
certificateHolder.getCertificate().toByteArray(), certificateHolder.getSignature().toByteArray()));
assertTrue(Curve.verifySignature(Curve.decodePoint(Base64.getDecoder().decode(caPublicKey), 0),
serverCertificateHolder.getCertificate().toByteArray(), serverCertificateHolder.getSignature().toByteArray()));
assertTrue(StringUtils.isBlank(certificate.getSender()));
assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertArrayEquals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
assertArrayEquals(certificate.getIdentityKey().toByteArray(),
Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
}
@Test
void testBadAuthentication() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.get();
.target("/v1/certificate/delivery")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.get();
assertEquals(response.getStatus(), 401);
}
@@ -179,9 +192,9 @@ class CertificateControllerTest {
@Test
void testNoAuthentication() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.get();
.target("/v1/certificate/delivery")
.request()
.get();
assertEquals(response.getStatus(), 401);
}
@@ -190,10 +203,10 @@ class CertificateControllerTest {
@Test
void testUnidentifiedAuthentication() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.get();
.target("/v1/certificate/delivery")
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1234".getBytes()))
.get();
assertEquals(response.getStatus(), 401);
}
@@ -201,10 +214,10 @@ class CertificateControllerTest {
@Test
void testDisabledAuthentication() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.get();
.target("/v1/certificate/delivery")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.get();
assertEquals(response.getStatus(), 401);
}
@@ -212,59 +225,62 @@ class CertificateControllerTest {
@Test
void testGetSingleAuthCredential() {
GroupCredentials credentials = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + Util.currentDaysSinceEpoch())
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(GroupCredentials.class);
.target("/v1/certificate/group/" + currentDaysSinceEpoch() + "/" + currentDaysSinceEpoch())
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(GroupCredentials.class);
assertThat(credentials.credentials().size()).isEqualTo(1);
assertThat(credentials.credentials().get(0).redemptionTime()).isEqualTo(Util.currentDaysSinceEpoch());
assertThat(credentials.credentials().get(0).redemptionTime()).isEqualTo(currentDaysSinceEpoch());
ClientZkAuthOperations clientZkAuthOperations = new ClientZkAuthOperations(serverSecretParams.getPublicParams());
assertThatCode(() ->
clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, Util.currentDaysSinceEpoch(), new AuthCredentialResponse(credentials.credentials().get(0).credential())))
clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, currentDaysSinceEpoch(),
new AuthCredentialResponse(credentials.credentials().get(0).credential())))
.doesNotThrowAnyException();
}
@Test
void testGetSingleAuthCredentialByPni() {
GroupCredentials credentials = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + Util.currentDaysSinceEpoch())
.target("/v1/certificate/group/" + currentDaysSinceEpoch() + "/" + currentDaysSinceEpoch())
.queryParam("identity", "pni")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(GroupCredentials.class);
assertThat(credentials.credentials().size()).isEqualTo(1);
assertThat(credentials.credentials().get(0).redemptionTime()).isEqualTo(Util.currentDaysSinceEpoch());
assertThat(credentials.credentials().get(0).redemptionTime()).isEqualTo(currentDaysSinceEpoch());
ClientZkAuthOperations clientZkAuthOperations = new ClientZkAuthOperations(serverSecretParams.getPublicParams());
assertThatExceptionOfType(VerificationFailedException.class)
.isThrownBy(() ->
clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, Util.currentDaysSinceEpoch(), new AuthCredentialResponse(credentials.credentials().get(0).credential())));
clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, currentDaysSinceEpoch(),
new AuthCredentialResponse(credentials.credentials().get(0).credential())));
}
@Test
void testGetWeekLongAuthCredentials() {
GroupCredentials credentials = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + (Util.currentDaysSinceEpoch() + 7))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(GroupCredentials.class);
.target("/v1/certificate/group/" + currentDaysSinceEpoch() + "/" + (currentDaysSinceEpoch() + 7))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(GroupCredentials.class);
assertThat(credentials.credentials().size()).isEqualTo(8);
for (int i=0;i<=7;i++) {
assertThat(credentials.credentials().get(i).redemptionTime()).isEqualTo(Util.currentDaysSinceEpoch() + i);
for (int i = 0; i <= 7; i++) {
assertThat(credentials.credentials().get(i).redemptionTime()).isEqualTo(currentDaysSinceEpoch() + i);
ClientZkAuthOperations clientZkAuthOperations = new ClientZkAuthOperations(serverSecretParams.getPublicParams());
final int time = i;
assertThatCode(() ->
clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, Util.currentDaysSinceEpoch() + time , new AuthCredentialResponse(credentials.credentials().get(time).credential())))
clientZkAuthOperations.receiveAuthCredential(AuthHelper.VALID_UUID, currentDaysSinceEpoch() + time,
new AuthCredentialResponse(credentials.credentials().get(time).credential())))
.doesNotThrowAnyException();
}
}
@@ -272,10 +288,10 @@ class CertificateControllerTest {
@Test
void testTooManyDaysOut() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + (Util.currentDaysSinceEpoch() + 8))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
.target("/v1/certificate/group/" + currentDaysSinceEpoch() + "/" + (currentDaysSinceEpoch() + 8))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertThat(response.getStatus()).isEqualTo(400);
}
@@ -283,10 +299,10 @@ class CertificateControllerTest {
@Test
void testBackwardsInTime() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/group/" + (Util.currentDaysSinceEpoch() - 1) + "/" + (Util.currentDaysSinceEpoch() + 7))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
.target("/v1/certificate/group/" + (currentDaysSinceEpoch() - 1) + "/" + (currentDaysSinceEpoch() + 7))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertThat(response.getStatus()).isEqualTo(400);
}
@@ -294,10 +310,10 @@ class CertificateControllerTest {
@Test
void testBadAuth() {
Response response = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + (Util.currentDaysSinceEpoch() + 7))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.get();
.target("/v1/certificate/group/" + currentDaysSinceEpoch() + "/" + (currentDaysSinceEpoch() + 7))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.INVALID_PASSWORD))
.get();
assertThat(response.getStatus()).isEqualTo(401);
}
@@ -387,7 +403,8 @@ class CertificateControllerTest {
Arguments.of(clock.instant().minus(Duration.ofDays(1)), clock.instant()),
// End is too far in the future
Arguments.of(clock.instant(), clock.instant().plus(CertificateController.MAX_REDEMPTION_DURATION).plus(Duration.ofDays(1))),
Arguments.of(clock.instant(),
clock.instant().plus(CertificateController.MAX_REDEMPTION_DURATION).plus(Duration.ofDays(1))),
// Start is not at a day boundary
Arguments.of(clock.instant().plusSeconds(17), clock.instant().plus(Duration.ofDays(1))),
@@ -396,4 +413,8 @@ class CertificateControllerTest {
Arguments.of(clock.instant(), clock.instant().plusSeconds(17))
);
}
private static int currentDaysSinceEpoch() {
return Util.currentDaysSinceEpoch(Clock.systemUTC());
}
}

View File

@@ -20,6 +20,7 @@ import java.util.UUID;
import java.util.function.Consumer;
import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -69,10 +70,13 @@ public class AccountsHelper {
});
when(mockAccountsManager.updateDeviceLastSeen(any(), any(), anyLong())).thenAnswer(answer -> {
answer.getArgument(1, Device.class).setLastSeen(answer.getArgument(2, Long.class));
return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {
});
return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {});
});
when(mockAccountsManager.updateDeviceAuthentication(any(), any(), any())).thenAnswer(answer -> {
answer.getArgument(1, Device.class).setAuthenticationCredentials(answer.getArgument(2, AuthenticationCredentials.class));
return mockAccountsManager.update(answer.getArgument(0, Account.class), account -> {});
});
}

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.util;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
@@ -25,6 +26,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -189,10 +191,12 @@ public class AuthHelper {
testAccount.setup(ACCOUNTS_MANAGER);
}
ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), any())).thenReturn(true);
AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter = new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>().setAuthenticator(
new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
new AccountAuthenticator(ACCOUNTS_MANAGER, experimentEnrollmentManager)).buildAuthFilter();
AuthFilter<BasicCredentials, DisabledPermittedAuthenticatedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAuthenticatedAccount>().setAuthenticator(
new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER, experimentEnrollmentManager)).buildAuthFilter();
return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter,
DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter));

View File

@@ -0,0 +1,141 @@
package org.whispersystems.textsecuregcm.tests.util;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.storage.UsernameNotAvailableException;
import org.whispersystems.textsecuregcm.util.UsernameGenerator;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;
import static org.assertj.core.api.Assertions.assertThat;
public class UsernameGeneratorTest {
@ParameterizedTest(name = "[{index}]:{0} ({2})")
@MethodSource
public void nicknameValidation(String nickname, boolean valid, String testCaseName) {
assertThat(UsernameGenerator.isValidNickname(nickname)).isEqualTo(valid);
}
static Stream<Arguments> nicknameValidation() {
return Stream.of(
Arguments.of("Test", false, "upper case"),
Arguments.of("tesT", false, "upper case"),
Arguments.of("te-st", false, "illegal character"),
Arguments.of("ab\uD83D\uDC1B", false, "illegal character"),
Arguments.of("1test", false, "illegal start"),
Arguments.of("test#123", false, "illegal character"),
Arguments.of("ab", false, "too short"),
Arguments.of("", false, ""),
Arguments.of("_123456789_123456789_123456789123", false, "33 characters"),
Arguments.of("_test", true, ""),
Arguments.of("test", true, ""),
Arguments.of("test123", true, ""),
Arguments.of("abc", true, ""),
Arguments.of("_123456789_123456789_12345678912", true, "32 characters")
);
}
@ParameterizedTest(name="[{index}]: {0}")
@MethodSource
public void nonStandardUsernames(final String username, final boolean isStandard) {
assertThat(UsernameGenerator.isStandardFormat(username)).isEqualTo(isStandard);
}
static Stream<Arguments> nonStandardUsernames() {
return Stream.of(
Arguments.of("Test#123", false),
Arguments.of("test#-123", false),
Arguments.of("test#0", false),
Arguments.of("test#", false),
Arguments.of("test#1_00", false),
Arguments.of("test#1", true),
Arguments.of("abc#1234", true)
);
}
@Test
public void zeroPadDiscriminators() {
final UsernameGenerator generator = new UsernameGenerator(4, 5, 1);
assertThat(generator.fromParts("test", 1)).isEqualTo("test#0001");
assertThat(generator.fromParts("test", 123)).isEqualTo("test#0123");
assertThat(generator.fromParts("test", 9999)).isEqualTo("test#9999");
assertThat(generator.fromParts("test", 99999)).isEqualTo("test#99999");
}
@Test
public void expectedWidth() throws UsernameNotAvailableException {
String username = new UsernameGenerator(1, 6, 1).generateAvailableUsername("test", t -> true);
assertThat(extractDiscriminator(username)).isGreaterThan(0).isLessThan(10);
username = new UsernameGenerator(2, 6, 1).generateAvailableUsername("test", t -> true);
assertThat(extractDiscriminator(username)).isGreaterThan(0).isLessThan(100);
}
@Test
public void expandDiscriminator() throws UsernameNotAvailableException {
UsernameGenerator ug = new UsernameGenerator(1, 6, 10);
final String username = ug.generateAvailableUsername("test", allowDiscriminator(d -> d >= 10000));
int discriminator = extractDiscriminator(username);
assertThat(discriminator).isGreaterThanOrEqualTo(10000).isLessThan(100000);
}
@Test
public void expandDiscriminatorToMax() throws UsernameNotAvailableException {
UsernameGenerator ug = new UsernameGenerator(1, 6, 10);
final String username = ug.generateAvailableUsername("test", allowDiscriminator(d -> d >= 100000));
int discriminator = extractDiscriminator(username);
assertThat(discriminator).isGreaterThanOrEqualTo(100000).isLessThan(1000000);
}
@Test
public void exhaustDiscriminator() {
UsernameGenerator ug = new UsernameGenerator(1, 6, 10);
Assertions.assertThrows(UsernameNotAvailableException.class, () -> {
// allow greater than our max width
ug.generateAvailableUsername("test", allowDiscriminator(d -> d >= 1000000));
});
}
@Test
public void randomCoverageMinWidth() throws UsernameNotAvailableException {
UsernameGenerator ug = new UsernameGenerator(1, 6, 10);
final Set<Integer> seen = new HashSet<>();
for (int i = 0; i < 1000 && seen.size() < 9; i++) {
seen.add(extractDiscriminator(ug.generateAvailableUsername("test", ignored -> true)));
}
// after 1K iterations, probability of a missed value is (9/10)^999
assertThat(seen.size()).isEqualTo(9);
assertThat(seen).allMatch(i -> i > 0 && i < 10);
}
@Test
public void randomCoverageMidWidth() throws UsernameNotAvailableException {
UsernameGenerator ug = new UsernameGenerator(1, 6, 10);
final Set<Integer> seen = new HashSet<>();
for (int i = 0; i < 100000 && seen.size() < 90; i++) {
seen.add(extractDiscriminator(ug.generateAvailableUsername("test", allowDiscriminator(d -> d >= 10))));
}
// after 100K iterations, probability of a missed value is (99/100)^99999
assertThat(seen.size()).isEqualTo(90);
assertThat(seen).allMatch(i -> i >= 10 && i < 100);
}
private static Predicate<String> allowDiscriminator(Predicate<Integer> p) {
return username -> p.test(extractDiscriminator(username));
}
private static int extractDiscriminator(final String username) {
return Integer.parseInt(username.split(UsernameGenerator.SEPARATOR)[1]);
}
}

View File

@@ -24,11 +24,13 @@ class ValidNumberTest {
"+71234567890",
"+447535742222",
"+4915174108888",
"+2250707312345",
"+298123456",
"+299123456",
"+376123456",
"+68512345",
"+689123456"})
"+689123456",
"+80011111111"})
void requireNormalizedNumber(final String number) {
assertDoesNotThrow(() -> Util.requireNormalizedNumber(number));
}
@@ -56,6 +58,7 @@ class ValidNumberTest {
@ParameterizedTest
@ValueSource(strings = {
"+4407700900111",
"+49493023125000", // double country code - this e164 is "possible"
"+1 415 123 1234",
"+1 (415) 123-1234",
"+1 415)123-1234",

View File

@@ -12,14 +12,14 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
class UsernameValidatorTest {
class NicknameValidatorTest {
@ParameterizedTest
@MethodSource
void isValid(final String username, final boolean expectValid) {
final UsernameValidator usernameValidator = new UsernameValidator();
final NicknameValidator nicknameValidator = new NicknameValidator();
assertEquals(expectValid, usernameValidator.isValid(username, null));
assertEquals(expectValid, nicknameValidator.isValid(username, null));
}
private static Stream<Arguments> isValid() {
@@ -28,8 +28,8 @@ class UsernameValidatorTest {
Arguments.of("_test", true),
Arguments.of("test123", true),
Arguments.of("a", false), // Too short
Arguments.of("thisIsAReallyReallyReallyLongUsernameThatWeWouldNotAllow", false),
Arguments.of("Illegal character", false),
Arguments.of("thisisareallyreallyreallylongusernamethatwewouldnotalllow", false),
Arguments.of("illegal character", false),
Arguments.of("0test", false), // Illegal first character
Arguments.of("pаypal", false), // Unicode confusable characters
Arguments.of("test\uD83D\uDC4E", false), // Emoji
@@ -38,19 +38,4 @@ class UsernameValidatorTest {
Arguments.of(null, false)
);
}
@ParameterizedTest
@MethodSource
void getCanonicalUsername(final String username, final String expectedCanonicalUsername) {
assertEquals(expectedCanonicalUsername, UsernameValidator.getCanonicalUsername(username));
}
private static Stream<Arguments> getCanonicalUsername() {
return Stream.of(
Arguments.of("test", "test"),
Arguments.of("TEst", "test"),
Arguments.of("t_e_S_T", "t_e_s_t"),
Arguments.of(null, null)
);
}
}

View File

@@ -43,7 +43,6 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.storage.Account;
@@ -100,7 +99,7 @@ class WebSocketConnectionIntegrationTest {
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,

View File

@@ -52,7 +52,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -83,7 +83,7 @@ class WebSocketConnectionTest {
private AuthenticatedAccount auth;
private UpgradeRequest upgradeRequest;
private ReceiptSender receiptSender;
private ApnFallbackManager apnFallbackManager;
private PushNotificationManager pushNotificationManager;
private ScheduledExecutorService retrySchedulingExecutor;
@BeforeEach
@@ -95,7 +95,7 @@ class WebSocketConnectionTest {
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
upgradeRequest = mock(UpgradeRequest.class);
receiptSender = mock(ReceiptSender.class);
apnFallbackManager = mock(ApnFallbackManager.class);
pushNotificationManager = mock(PushNotificationManager.class);
retrySchedulingExecutor = mock(ScheduledExecutorService.class);
}
@@ -104,7 +104,7 @@ class WebSocketConnectionTest {
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages,
mock(PushNotificationManager.class), apnFallbackManager, mock(ClientPresenceManager.class),
mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
retrySchedulingExecutor);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
@@ -166,7 +166,7 @@ class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
.thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -221,9 +221,8 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(Collections.emptyList(), false))
.thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), false))
.thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), false));
@@ -316,7 +315,7 @@ class WebSocketConnectionTest {
String userAgent = "user-agent";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
.thenReturn(new Pair<>(pendingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -362,12 +361,11 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer(
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)).thenAnswer(
(Answer<OutgoingMessageEntityList>) invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
@@ -415,7 +413,7 @@ class WebSocketConnectionTest {
}
});
verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), anyString(), eq(false));
verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), eq(false));
}
@Test
@@ -429,7 +427,6 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<Envelope> firstPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"),
@@ -438,7 +435,7 @@ class WebSocketConnectionTest {
final List<Envelope> secondPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false))
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false))
.thenReturn(new Pair<>(firstPageMessages, true))
.thenReturn(new Pair<>(secondPageMessages, false));
@@ -473,13 +470,12 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID();
final List<Envelope> messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false))
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false))
.thenReturn(new Pair<>(messages, false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -530,9 +526,8 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -560,7 +555,6 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<Envelope> firstPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"),
@@ -569,7 +563,7 @@ class WebSocketConnectionTest {
final List<Envelope> secondPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(firstPageMessages, false))
.thenReturn(new Pair<>(secondPageMessages, false))
.thenReturn(new Pair<>(Collections.emptyList(), false));
@@ -609,9 +603,8 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -623,11 +616,11 @@ class WebSocketConnectionTest {
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), false);
verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), false);
connection.handleNewMessagesAvailable();
verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), true);
verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), true);
}
@Test
@@ -643,9 +636,8 @@ class WebSocketConnectionTest {
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
.thenReturn(new Pair<>(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
@@ -658,7 +650,7 @@ class WebSocketConnectionTest {
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), false);
verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), false);
}
@Test
@@ -692,7 +684,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Desktop/1.2.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
.thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -762,7 +754,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Android/4.68.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
.thenReturn(new Pair<>(outgoingMessages, false));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@@ -814,7 +806,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Android/4.68.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
.thenThrow(new RedisException("OH NO"));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
@@ -848,7 +840,7 @@ class WebSocketConnectionTest {
String userAgent = "Signal-Android/4.68.3";
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
.thenThrow(new RedisException("OH NO"));
final WebSocketClient client = mock(WebSocketClient.class);