Compare commits

...

89 Commits

Author SHA1 Message Date
Jon Chambers
ef9a7fda9a Publish outstanding SQS operation count as a gauge. 2021-07-27 11:15:41 -04:00
Chris Eager
13447df1e0 Update validation for NotNull items in IncomingMessagesList 2021-07-27 10:39:30 -04:00
Jon Chambers
3608c5bfb0 Wait for outstanding requests to be resolved before shutting down the directory queue. 2021-07-27 10:36:53 -04:00
Jon Chambers
34dbff6786 Switch to an async SQS client. 2021-07-27 10:36:53 -04:00
Jon Chambers
a6066bfc2f Migrate DirectoryQueueTest to JUnit 5. 2021-07-27 10:36:53 -04:00
Jon Chambers
8579190cdf Consolidate account creation/directory updates into AccountsManager 2021-07-27 10:27:47 -04:00
Chris Eager
917f667229 Remove AccountController and KeysController from websocket 2021-07-26 14:27:43 -05:00
Chris Eager
317a551bdb Migrate MetricsRequestEventListenerTest to JUnit 5 2021-07-26 12:06:29 -05:00
Chris Eager
27e9271473 Add request path and user agent to unhandled exception logging 2021-07-26 12:06:29 -05:00
Fedor Indutny
11dff6c546 more controllers 2021-07-26 12:06:17 -05:00
Fedor Indutny
e6712937ca fix indent 2021-07-26 12:06:17 -05:00
Fedor Indutny
cf8887bb5a Provide more WebSocket endpoints 2021-07-26 12:06:17 -05:00
Chris Eager
696340f780 Migrate DeviceControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
86ddcbaa08 Migrate CertificateControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
2144d2a8d8 Migrate AttachmentControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
f7af861b31 Migrate SecureStorageControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
208a09b3ae Migrate RemoteConfigControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
831023e41d Migrate PaymentsControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
ff627793d6 Migrate DirectoryControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
f971c76a99 Migrate StickerControllerTest to JUnit 5 2021-07-26 11:18:17 -05:00
Chris Eager
8f41176c76 Enable "sms" transport for +98 2021-07-26 10:40:05 -05:00
Ehren Kret
31bbbbb5e0 Raise default message TTL to 14 days 2021-07-20 14:08:08 -05:00
Jon Chambers
effcd6038d Also record dimensional metrics for circuit breakers and retries. 2021-07-19 16:56:16 -04:00
Jon Chambers
12be7d49c2 Clear one-time pre-keys on re-registration. 2021-07-19 10:05:01 -04:00
Jon Chambers
14863b575e Clear one-time pre-keys when a device is unlinked. 2021-07-19 10:05:01 -04:00
Jon Chambers
32a95f96ff Add a pessimistic locking system for operations on recently-deleted account records 2021-07-16 16:52:58 -04:00
Jon Chambers
b757d4b334 Measure how many "send message" requests are still using e164-based addressing. 2021-07-16 16:52:58 -04:00
Chris Eager
bd03d910fe Set authenticated device after updating last seen 2021-07-16 16:52:58 -04:00
Chris Eager
01ef855157 Return a non-stale account from base authenticator when last seen is updated 2021-07-16 16:52:58 -04:00
Chris Eager
817866caf3 Use fresh accounts to update in PushFeedbackProcessor 2021-07-16 16:52:58 -04:00
Chris Eager
158d65c6a7 Add optimistic locking to account updates 2021-07-16 16:52:58 -04:00
realturner
62022c7de1 Migrate AppConfig to SDK v2 to detect and use web identify token 2021-07-16 16:48:33 -04:00
Chris Eager
a824b5575d Add dynamic configuration for using DynamoDB in AccountsDatabaseCrawler 2021-07-06 13:01:24 -05:00
Jon Chambers
78819d5382 Remove expiration logic when checking token validity.
The data store will no longer return tokens that have expired, and we no longer need to check for expiration in application space.
2021-07-06 11:03:49 -04:00
Jon Chambers
d128bc782a Retire Postgres-backed pending account/device tables. 2021-07-06 11:03:49 -04:00
Chris Eager
530b2a310f Ensure active future is always completed 2021-07-02 15:05:11 -05:00
Chris Eager
d5b0d99a54 Remove unused method 2021-07-02 15:05:11 -05:00
Chris Eager
43be72d076 Add test for ManagedPeriodicWork; fix shutdown not awaiting active execution 2021-07-02 15:05:11 -05:00
Chris Eager
9558944e22 Add needsReconciliationIndexName to sample.yml 2021-07-02 15:05:11 -05:00
Chris Eager
0f6c866c8d Update imports 2021-07-02 15:05:11 -05:00
Chris Eager
bac78e9291 Switch DeletedAccountsTableCrawler metrics to a basic Metrics#summary 2021-07-02 15:05:11 -05:00
Chris Eager
c22ea78672 Add crawler to process migration retry accounts 2021-07-02 15:05:11 -05:00
Chris Eager
a85afe827d Avoid NPE by using scheduledFuture as the Gauge state object 2021-07-02 15:05:11 -05:00
Chris Eager
abaed821ec Add additional case to unit test 2021-07-02 15:05:11 -05:00
Chris Eager
6fa9dcd954 Refactor to use shared recurringJobExecutor 2021-07-02 15:05:11 -05:00
Chris Eager
819d59cd79 Update reconciliation crawler to use secondary index 2021-07-02 15:05:11 -05:00
Chris Eager
2f88f0eedb Refactor to use single threaded scheduled executor 2021-07-02 15:05:11 -05:00
Chris Eager
74ff491671 Rename ManagedPeriodicWorkCache to ManagedPeriodicWorkLock 2021-07-02 15:05:11 -05:00
Chris Eager
eac48a6617 Don’t delete accounts after reconciling 2021-07-02 15:05:11 -05:00
Chris Eager
19617c14f8 Improved logging in ManagedPeriodcWork 2021-07-02 15:05:11 -05:00
Chris Eager
fc7291c3e8 Migrate DeletedAccountsTableCrawler to ManagedPeriodicWork 2021-07-02 15:05:11 -05:00
Chris Eager
88db808298 Add abstract ManagedPeriodicWork 2021-07-02 15:05:11 -05:00
Chris Eager
5193abdab3 Add DeletedAccountsTableCrawler 2021-07-02 15:05:11 -05:00
Chris Eager
a315c9be92 Add DeletedAccounts DynamoDB table 2021-07-02 15:05:11 -05:00
Chris Eager
fc1541591a Add AbstractDynamoDbStore#scan 2021-07-02 15:05:11 -05:00
Chris Eager
ae97c4db9f Use editorconfig in AbstractDynamoDbStore 2021-07-02 15:05:11 -05:00
Chris Eager
26bc5973b5 Clear message queue before and after removing a device 2021-07-02 10:48:42 -05:00
Chris Eager
e52b8c8423 Implement DatadogConfig in DatadogConfiguration 2021-07-02 10:48:05 -05:00
Jon Chambers
7395489bac Add tests for pending account/device managers. 2021-07-02 11:30:13 -04:00
Jon Chambers
b384ed7f5c Add a counter for requests for delivery certificates with/without e164s. 2021-07-01 10:59:10 -04:00
Jon Chambers
e3afcae7d3 Gather data to verify safety of retiring legacy reglock system. 2021-07-01 10:58:47 -04:00
Jon Chambers
9faeed7b20 Count E164 authentications versus UUID authentications. 2021-07-01 10:51:34 -04:00
Jon Chambers
49adcca80e Use Optional.isEmpty(). 2021-07-01 10:51:34 -04:00
Jon Chambers
49c43a6816 Simplify distribution summary for "days since last seen." 2021-07-01 10:51:34 -04:00
Jon Chambers
84f85ae098 Collapse various account meters into a single, multi-dimensional counter. 2021-07-01 10:51:34 -04:00
Jon Chambers
3d581941ab Add plumbing and configuration to migrate pending accounts/devices to DynamoDB. 2021-07-01 10:50:52 -04:00
Jon Chambers
d2d39baede Add a DynamoDB-backed stored verification code store. 2021-07-01 10:50:52 -04:00
Jon Chambers
111f5ba024 Use java.time classes for stored verification code expiration; add tests. 2021-07-01 10:50:52 -04:00
Jon Chambers
ce3fb7fa99 Extract a common base class for verification code store tests. 2021-07-01 10:50:52 -04:00
Jon Chambers
fc421d3f21 Introduce a common interface for verification code stores. 2021-07-01 10:50:52 -04:00
Jon Chambers
71bea759c6 Consolidate StoredVerificationCode constructors. 2021-07-01 10:50:52 -04:00
Jon Chambers
bf1dd791a5 Drop caching for pending accounts/devices. 2021-07-01 10:50:52 -04:00
Chris Eager
4c99577c08 Add configuration for Datadog batch size 2021-06-30 16:44:25 -05:00
Graeme Connell
5d5c63e6d4 Update profile controller to S3 AWSv2. 2021-06-30 13:09:18 -06:00
Graeme Connell
42ff3f8432 Switch SQS to Amazon SDKv2. 2021-06-30 12:46:12 -06:00
Chris Eager
be6ef76486 Update DynamoDBLocal to 1.16.0 2021-06-23 13:50:58 -05:00
Chris Eager
bc297e6d34 Update wiremock-jre8 to 2.28.1 2021-06-23 13:50:58 -05:00
Chris Eager
3a526dcbd7 Update mockito to 3.11.1 2021-06-23 13:50:58 -05:00
Ehren Kret
7883352b74 Match random capability generation in test 2021-06-21 17:32:31 -05:00
Ehren Kret
982d122d18 Match random capability generation in test 2021-06-21 17:32:31 -05:00
Ehren Kret
d8d94407c6 Create announcement group capability 2021-06-21 17:32:31 -05:00
Chris Eager
28cfc54170 Update FunctionCounter builder to use non-null object and method 2021-06-11 11:27:45 -05:00
Jon Chambers
2ee7279743 Pause nstat counters. 2021-06-11 12:26:56 -04:00
Jon Chambers
eb1b073385 Add a hostname-aware reporter factory. 2021-06-10 14:23:05 -04:00
Jon Chambers
c634185b6f Standardize a utility method for getting local host names. 2021-06-10 14:23:05 -04:00
Ehren Kret
827a3af419 Code cleanup 2021-06-09 20:44:18 -05:00
Jon Chambers
2c33d22a30 Stop recording specific client versions in metrics until we know we need them again. 2021-06-08 12:25:31 -04:00
Chris Eager
b41ed9d810 Update sample.yml config 2021-06-07 17:21:36 -04:00
Jon Chambers
58d3a12eff Set hostname to lowercase to avoid strange case mismatch issues; log hostname failures. 2021-06-07 17:17:46 -04:00
119 changed files with 5430 additions and 2626 deletions

10
pom.xml
View File

@@ -35,6 +35,7 @@
<commons-csv.version>1.8</commons-csv.version>
<commons-io.version>2.9.0</commons-io.version>
<dropwizard.version>2.0.22</dropwizard.version>
<dropwizard-metrics-datadog.version>1.1.13</dropwizard-metrics-datadog.version>
<guava.version>30.1.1-jre</guava.version>
<jaxb.version>2.3.1</jaxb.version>
<jedis.version>2.9.0</jedis.version>
@@ -42,7 +43,7 @@
<libphonenumber.version>8.12.23</libphonenumber.version>
<logstash.logback.version>6.6</logstash.logback.version>
<micrometer.version>1.5.3</micrometer.version>
<mockito.version>2.25.1</mockito.version>
<mockito.version>3.11.1</mockito.version>
<netty.version>4.1.65.Final</netty.version>
<netty.tcnative-boringssl-static.version>2.0.39.Final</netty.tcnative-boringssl-static.version>
<opentest4j.version>1.2.0</opentest4j.version>
@@ -171,6 +172,11 @@
<artifactId>commons-csv</artifactId>
<version>${commons-csv.version}</version>
</dependency>
<dependency>
<groupId>org.coursera</groupId>
<artifactId>dropwizard-metrics-datadog</artifactId>
<version>${dropwizard-metrics-datadog.version}</version>
</dependency>
<dependency>
<groupId>org.glassfish.jaxb</groupId>
<artifactId>jaxb-runtime</artifactId>
@@ -230,7 +236,7 @@
<dependency>
<groupId>com.github.tomakehurst</groupId>
<artifactId>wiremock-jre8</artifactId>
<version>2.27.2</version>
<version>2.28.1</version>
<scope>test</scope>
<exclusions>
<exclusion>

View File

@@ -5,7 +5,7 @@
package org.whispersystems.dispatch;
public interface DispatchChannel {
public void onDispatchMessage(String channel, byte[] message);
public void onDispatchSubscribed(String channel);
public void onDispatchUnsubscribed(String channel);
void onDispatchMessage(String channel, byte[] message);
void onDispatchSubscribed(String channel);
void onDispatchUnsubscribed(String channel);
}

View File

@@ -59,9 +59,7 @@ public class DispatchManager extends Thread {
logger.warn("Subscription error", e);
}
if (previous.isPresent()) {
dispatchUnsubscription(name, previous.get());
}
previous.ifPresent(channel -> dispatchUnsubscription(name, channel));
}
public synchronized void unsubscribe(String name, DispatchChannel channel) {
@@ -132,46 +130,28 @@ public class DispatchManager extends Thread {
}
private void resubscribeAll() {
new Thread() {
@Override
public void run() {
synchronized (DispatchManager.this) {
try {
for (String name : subscriptions.keySet()) {
pubSubConnection.subscribe(name);
}
} catch (IOException e) {
logger.warn("***** RESUBSCRIPTION ERROR *****", e);
new Thread(() -> {
synchronized (DispatchManager.this) {
try {
for (String name : subscriptions.keySet()) {
pubSubConnection.subscribe(name);
}
} catch (IOException e) {
logger.warn("***** RESUBSCRIPTION ERROR *****", e);
}
}
}.start();
}).start();
}
private void dispatchMessage(final String name, final DispatchChannel channel, final byte[] message) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchMessage(name, message);
}
});
executor.execute(() -> channel.onDispatchMessage(name, message));
}
private void dispatchSubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchSubscribed(name);
}
});
executor.execute(() -> channel.onDispatchSubscribed(name));
}
private void dispatchUnsubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchUnsubscribed(name);
}
});
executor.execute(() -> channel.onDispatchUnsubscribed(name));
}
}

View File

@@ -8,6 +8,6 @@ import org.whispersystems.dispatch.redis.PubSubConnection;
public interface RedisPubSubConnectionFactory {
public PubSubConnection connect();
PubSubConnection connect();
}

View File

@@ -21,9 +21,6 @@ twilio: # Twilio gateway configuration
push:
queueSize: # Size of push pending queue
redphone:
authKey: # Deprecated
turn: # TURN server configuration
secret: # TURN server secret
uris:
@@ -36,6 +33,23 @@ cacheCluster: # Redis server configuration for cache cluster
urls:
- redis://redis.example.com:6379/
clientPresenceCluster: # Redis server configuration for client presence cluster
urls:
- redis://redis.example.com:6379/
pubsub: # Redis server configuration for pubsub cluster
url: redis://redis.example.com:6379/
replicaUrls:
- redis://redis.example.com:6379/
pushSchedulerCluster: # Redis server configuration for push scheduler cluster
urls:
- redis://redis.example.com:6379/
rateLimitersCluster: # Redis server configuration for rate limiters cluster
urls:
- redis://redis.example.com:6379/
directory:
client: # Configuration for interfacing with Contact Discovery Service cluster
userAuthenticationTokenSharedSecret: # hex-encoded secret shared with CDS used to generate auth tokens for Signal users
@@ -43,13 +57,13 @@ directory:
sqs:
accessKey: # AWS SQS accessKey
accessSecret: # AWS SQS accessSecret
queueUrl: # AWS SQS queue url
server:
replicationUrl: # CDS replication endpoint base url
replicationPassword: # CDS replication endpoint password
replicationCaCertificate: # CDS replication endpoint TLS certificate trust root
reconciliationChunkSize: # CDS reconciliation chunk size
reconciliationChunkIntervalMs: # CDS reconciliation chunk interval, in milliseconds
queueUrls: # AWS SQS queue urls
- https://sqs.example.com/directory.fifo
server: # One or more CDS servers
- replicationName: # CDS replication name
replicationUrl: # CDS replication endpoint base url
replicationPassword: # CDS replication endpoint password
replicationCaCertificate: # CDS replication endpoint TLS certificate trust root
messageCache: # Redis server configuration for message store cache
persistDelayMinutes:
@@ -58,16 +72,44 @@ messageCache: # Redis server configuration for message store cache
urls:
- redis://redis.example.com:6379/
messageStore: # Postgresql database configuration for message store
driverClass: org.postgresql.Driver
user:
password:
url:
metricsCluster:
urls:
- redis://redis.example.com:6379/
messageDynamoDb: # DynamoDB table configuration
region:
tableName:
keysDynamoDb: # DynamoDB table configuration
region:
tableName:
accountsDynamoDb: # DynamoDB table configuration
region:
tableName:
phoneNumberTableName:
deletedAccountsDynamoDb: # DynamoDb table configuration
region:
tableName:
needsReconciliationIndexName:
migrationDeletedAccountsDynamoDb: # DynamoDB table configuration
region:
tableName:
migrationRetryAccountsDynamoDb: # DynamoDB table configuration
region:
tableName:
pushChallengeDynamoDb: # DynamoDB table configuration
region:
tableName:
reportMessageDynamoDb: # DynamoDB table configuration
region:
tableName:
awsAttachments: # AWS S3 configuration
accessKey:
accessSecret:
@@ -81,18 +123,22 @@ gcpAttachments: # GCP Storage configuration
pathPrefix:
rsaSigningKey:
profiles: # AWS S3 configuration
accessKey:
accessSecret:
bucket:
region:
database: # Postgresql database configuration
abuseDatabase: # Postgresql database configuration
driverClass: org.postgresql.Driver
user:
password:
url:
accountsDatabase: # Postgresql database configuration
driverClass: org.postgresql.Driver
user:
password:
url:
accountDatabaseCrawler:
chunkSize: # accounts per run
chunkIntervalMs: # time per run
apn: # Apple Push Notifications configuration
sandbox: true
bundleId:
@@ -104,11 +150,52 @@ gcm: # GCM Configuration
senderId:
apiKey:
micrometer: # Micrometer metrics config
- name: "example"
- uri: "https://metrics.example.com/"
- apiKey:
- accountId:
cdn:
accessKey: # AWS Access Key ID
accessSecret: # AWS Access Secret
bucket: # S3 Bucket name
region: # AWS region
wavefront: # Wavefront micrometer metrics config
uri: # Wavefront proxy endpoint
batchSize: # Number of measurements to send per request
datadog:
apiKey:
environment:
unidentifiedDelivery:
certificate:
privateKey:
expiresDays:
voiceVerification:
url: https://cdn-ca.signal.org/verification/
locales:
- en
recaptcha:
secret:
storageService:
uri:
userAuthenticationTokenSharedSecret:
storageCaCertificate:
backupService:
uri:
userAuthenticationTokenSharedSecret:
backupCaCertificate:
zkConfig:
serverPublic:
serverSecret:
enabled:
appConfig:
application:
environment:
configuration:
remoteConfig:
authorizedTokens:
@@ -118,9 +205,22 @@ remoteConfig:
- # Nth authorized token
globalConfig: # keys and values that are given to clients on GET /v1/config
paymentService:
paymentsService:
userAuthenticationTokenSharedSecret: # hex-encoded 32-byte secret shared with MobileCoin services used to generate auth tokens for Signal users
torExitNodeList:
s3Region:
s3Bucket:
objectKey:
maxSize:
asnTable:
s3Region:
s3Bucket:
objectKey:
maxSize:
donation:
uri: # value
apiKey: # value

View File

@@ -226,7 +226,10 @@
<groupId>io.micrometer</groupId>
<artifactId>micrometer-registry-datadog</artifactId>
</dependency>
<dependency>
<groupId>org.coursera</groupId>
<artifactId>dropwizard-metrics-datadog</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
@@ -252,14 +255,26 @@
<artifactId>jackson-jaxrs-json-provider</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sts</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sqs</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>dynamodb</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>appconfig</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-core</artifactId>
@@ -270,11 +285,14 @@
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sqs</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-appconfig</artifactId>
<artifactId>dynamodb-lock-client</artifactId>
<version>1.1.0</version>
<exclusions>
<exclusion>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
@@ -407,7 +425,7 @@
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>DynamoDBLocal</artifactId>
<version>1.13.6</version>
<version>1.16.0</version>
<scope>test</scope>
<exclusions>
<exclusion>

View File

@@ -22,6 +22,7 @@ import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguratio
import org.whispersystems.textsecuregcm.configuration.CdnConfiguration;
import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration;
import org.whispersystems.textsecuregcm.configuration.DatadogConfiguration;
import org.whispersystems.textsecuregcm.configuration.DeletedAccountsDynamoDbConfiguration;
import org.whispersystems.textsecuregcm.configuration.DirectoryConfiguration;
import org.whispersystems.textsecuregcm.configuration.DonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.DynamoDbConfiguration;
@@ -157,6 +158,16 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private DynamoDbConfiguration migrationRetryAccountsDynamoDb;
@Valid
@NotNull
@JsonProperty
private DeletedAccountsDynamoDbConfiguration deletedAccountsDynamoDb;
@Valid
@NotNull
@JsonProperty
private DynamoDbConfiguration deletedAccountsLockDynamoDb;
@Valid
@NotNull
@JsonProperty
@@ -167,6 +178,16 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private DynamoDbConfiguration reportMessageDynamoDb;
@Valid
@NotNull
@JsonProperty
private DynamoDbConfiguration pendingAccountsDynamoDb;
@Valid
@NotNull
@JsonProperty
private DynamoDbConfiguration pendingDevicesDynamoDb;
@Valid
@NotNull
@JsonProperty
@@ -371,6 +392,14 @@ public class WhisperServerConfiguration extends Configuration {
return migrationRetryAccountsDynamoDb;
}
public DeletedAccountsDynamoDbConfiguration getDeletedAccountsDynamoDbConfiguration() {
return deletedAccountsDynamoDb;
}
public DynamoDbConfiguration getDeletedAccountsLockDynamoDbConfiguration() {
return deletedAccountsLockDynamoDb;
}
public DatabaseConfiguration getAbuseDatabaseConfiguration() {
return abuseDatabase;
}
@@ -465,6 +494,14 @@ public class WhisperServerConfiguration extends Configuration {
return reportMessageDynamoDb;
}
public DynamoDbConfiguration getPendingAccountsDynamoDbConfiguration() {
return pendingAccountsDynamoDb;
}
public DynamoDbConfiguration getPendingDevicesDynamoDbConfiguration() {
return pendingDevicesDynamoDb;
}
public MonitoredS3ObjectConfiguration getTorExitNodeListConfiguration() {
return torExitNodeList;
}

View File

@@ -6,12 +6,10 @@ package org.whispersystems.textsecuregcm;
import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.jdbi3.strategies.DefaultNameStrategy;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
@@ -37,12 +35,9 @@ import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.config.MeterFilter;
import io.micrometer.core.instrument.distribution.DistributionStatisticConfig;
import io.micrometer.datadog.DatadogConfig;
import io.micrometer.datadog.DatadogMeterRegistry;
import io.micrometer.wavefront.WavefrontConfig;
import io.micrometer.wavefront.WavefrontMeterRegistry;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.net.http.HttpClient;
import java.time.Duration;
import java.util.ArrayList;
@@ -65,6 +60,8 @@ import org.jdbi.v3.core.Jdbi;
import org.signal.zkgroup.ServerSecretParams;
import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.signal.zkgroup.profiles.ServerZkProfileOperations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
@@ -122,7 +119,6 @@ import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener;
import org.whispersystems.textsecuregcm.metrics.MetricsRequestEventListener;
import org.whispersystems.textsecuregcm.metrics.NetworkReceivedGauge;
import org.whispersystems.textsecuregcm.metrics.NetworkSentGauge;
import org.whispersystems.textsecuregcm.metrics.NstatCounters;
import org.whispersystems.textsecuregcm.metrics.OperatingSystemMemoryGauge;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.metrics.TrafficSource;
@@ -159,6 +155,10 @@ import org.whispersystems.textsecuregcm.storage.AccountsDynamoDb;
import org.whispersystems.textsecuregcm.storage.AccountsDynamoDbMigrator;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ActiveUserCounter;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsDirectoryReconciler;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsTableCrawler;
import org.whispersystems.textsecuregcm.storage.DirectoryReconciler;
import org.whispersystems.textsecuregcm.storage.DirectoryReconciliationClient;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@@ -170,10 +170,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.MigrationDeletedAccounts;
import org.whispersystems.textsecuregcm.storage.MigrationRetryAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevices;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.storage.MigrationRetryAccountsTableCrawler;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
@@ -185,12 +182,16 @@ import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.Usernames;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.AsnManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.HostnameUtil;
import org.whispersystems.textsecuregcm.util.TorExitNodeManager;
import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper;
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
import org.whispersystems.textsecuregcm.websocket.DeadLetterHandler;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
@@ -205,11 +206,17 @@ import org.whispersystems.textsecuregcm.workers.VacuumCommand;
import org.whispersystems.textsecuregcm.workers.ZkParamsCommand;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.s3.S3Client;
public class WhisperServerService extends Application<WhisperServerConfiguration> {
private static final Logger log = LoggerFactory.getLogger(WhisperServerService.class);
@Override
public void initialize(Bootstrap<WhisperServerConfiguration> bootstrap) {
bootstrap.addCommand(new VacuumCommand());
@@ -275,45 +282,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
});
{
final String hostname;
{
String localHostName = "unknown";
try {
localHostName = InetAddress.getLocalHost().getHostName();
} catch (final UnknownHostException ignored) {
}
hostname = localHostName;
}
final DatadogMeterRegistry datadogMeterRegistry = new DatadogMeterRegistry(new DatadogConfig() {
@Override
public String get(final String key) {
return null;
}
@Override
public String apiKey() {
return config.getDatadogConfiguration().getApiKey();
}
@Override
public Duration step() {
return config.getDatadogConfiguration().getStep();
}
@Override
public String hostTag() {
return "host";
}
}, Clock.SYSTEM);
final DatadogMeterRegistry datadogMeterRegistry = new DatadogMeterRegistry(config.getDatadogConfiguration(), Clock.SYSTEM);
datadogMeterRegistry.config().commonTags(
Tags.of(
"service", "chat",
"host", hostname,
"host", HostnameUtil.getLocalHostname(),
"version", WhisperServerVersion.getServerVersion(),
"env", config.getDatadogConfiguration().getEnvironment()))
.meterFilter(MeterFilter.denyNameStartsWith(MetricsRequestEventListener.REQUEST_COUNTER_NAME))
@@ -358,6 +332,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create(),
accountsDynamoDbMigrationThreadPool);
DynamoDbClient deletedAccountsDynamoDbClient = DynamoDbFromConfig.client(config.getDeletedAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
DynamoDbClient recentlyDeletedAccountsDynamoDb = DynamoDbFromConfig.client(config.getMigrationDeletedAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
@@ -370,13 +347,25 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DynamoDbClient migrationRetryAccountsDynamoDb = DynamoDbFromConfig.client(config.getMigrationRetryAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
DynamoDbClient pendingAccountsDynamoDbClient = DynamoDbFromConfig.client(config.getPendingAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
DynamoDbClient pendingDevicesDynamoDbClient = DynamoDbFromConfig.client(config.getPendingDevicesDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
AmazonDynamoDB deletedAccountsLockDynamoDbClient = AmazonDynamoDBClientBuilder.standard()
.withRegion(config.getDeletedAccountsLockDynamoDbConfiguration().getRegion())
.withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getDeletedAccountsLockDynamoDbConfiguration().getClientExecutionTimeout().toMillis()))
.withRequestTimeout((int) config.getDeletedAccountsLockDynamoDbConfiguration().getClientRequestTimeout().toMillis()))
.withCredentials(InstanceProfileCredentialsProvider.getInstance())
.build();
DeletedAccounts deletedAccounts = new DeletedAccounts(deletedAccountsDynamoDbClient, config.getDeletedAccountsDynamoDbConfiguration().getTableName(), config.getDeletedAccountsDynamoDbConfiguration().getNeedsReconciliationIndexName());
MigrationDeletedAccounts migrationDeletedAccounts = new MigrationDeletedAccounts(recentlyDeletedAccountsDynamoDb, config.getMigrationDeletedAccountsDynamoDbConfiguration().getTableName());
MigrationRetryAccounts migrationRetryAccounts = new MigrationRetryAccounts(migrationRetryAccountsDynamoDb, config.getMigrationRetryAccountsDynamoDbConfiguration().getTableName());
Accounts accounts = new Accounts(accountDatabase);
AccountsDynamoDb accountsDynamoDb = new AccountsDynamoDb(accountsDynamoDbClient, accountsDynamoDbAsyncClient, accountsDynamoDbMigrationThreadPool, config.getAccountsDynamoDbConfiguration().getTableName(), config.getAccountsDynamoDbConfiguration().getPhoneNumberTableName(), migrationDeletedAccounts, migrationRetryAccounts);
PendingAccounts pendingAccounts = new PendingAccounts(accountDatabase);
PendingDevices pendingDevices = new PendingDevices (accountDatabase);
Usernames usernames = new Usernames(accountDatabase);
ReservedUsernames reservedUsernames = new ReservedUsernames(accountDatabase);
Profiles profiles = new Profiles(accountDatabase);
@@ -386,6 +375,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
RemoteConfigs remoteConfigs = new RemoteConfigs(accountDatabase);
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(pushChallengeDynamoDbClient, config.getPushChallengeDynamoDbConfiguration().getTableName());
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(reportMessageDynamoDbClient, config.getReportMessageDynamoDbConfiguration().getTableName());
VerificationCodeStore pendingAccounts = new VerificationCodeStore(pendingAccountsDynamoDbClient, config.getPendingAccountsDynamoDbConfiguration().getTableName());
VerificationCodeStore pendingDevices = new VerificationCodeStore(pendingDevicesDynamoDbClient, config.getPendingDevicesDynamoDbConfiguration().getTableName());
RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache", config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool();
@@ -412,7 +403,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
BlockingQueue<Runnable> keyspaceNotificationDispatchQueue = new ArrayBlockingQueue<>(10_000);
Metrics.gaugeCollectionSize(name(getClass(), "keyspaceNotificationDispatchQueueSize"), Collections.emptyList(), keyspaceNotificationDispatchQueue);
ScheduledExecutorService recurringJobExecutor = environment.lifecycle().scheduledExecutorService(name(getClass(), "recurringJob-%d")).threads(2).build();
ScheduledExecutorService recurringJobExecutor = environment.lifecycle().scheduledExecutorService(name(getClass(), "recurringJob-%d")).threads(3).build();
ScheduledExecutorService declinedMessageReceiptExecutor = environment.lifecycle().scheduledExecutorService(name(getClass(), "declined-receipt-%d")).threads(2).build();
ScheduledExecutorService retrySchedulingExecutor = environment.lifecycle().scheduledExecutorService(name(getClass(), "retry-%d")).threads(2).build();
ExecutorService keyspaceNotificationDispatchExecutor = environment.lifecycle().executorService(name(getClass(), "keyspaceNotification-%d")).maxThreads(16).workQueue(keyspaceNotificationDispatchQueue).build();
@@ -442,15 +433,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, storageServiceExecutor, config.getSecureStorageServiceConfiguration());
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, recurringJobExecutor, keyspaceNotificationDispatchExecutor);
DirectoryQueue directoryQueue = new DirectoryQueue(config.getDirectoryConfiguration().getSqsConfiguration());
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, cacheCluster);
PendingDevicesManager pendingDevicesManager = new PendingDevicesManager(pendingDevices, cacheCluster);
StoredVerificationCodeManager pendingAccountsManager = new StoredVerificationCodeManager(pendingAccounts);
StoredVerificationCodeManager pendingDevicesManager = new StoredVerificationCodeManager(pendingDevices);
UsernamesManager usernamesManager = new UsernamesManager(usernames, reservedUsernames, cacheCluster);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, messagesCluster, keyspaceNotificationDispatchExecutor);
PushLatencyManager pushLatencyManager = new PushLatencyManager(metricsCluster);
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, Metrics.globalRegistry);
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, pushLatencyManager, reportMessageManager);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
DeletedAccountsManager deletedAccountsManager = new DeletedAccountsManager(deletedAccounts, deletedAccountsLockDynamoDbClient, config.getDeletedAccountsLockDynamoDbConfiguration().getTableName());
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccountsManager, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, pendingAccountsManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(accountsManager, messagesManager);
DispatchManager dispatchManager = new DispatchManager(pubSubClientFactory, Optional.of(deadLetterHandler));
@@ -482,13 +474,17 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
final List<DeletedAccountsDirectoryReconciler> deletedAccountsDirectoryReconcilers = new ArrayList<>();
final List<AccountDatabaseCrawlerListener> accountDatabaseCrawlerListeners = new ArrayList<>();
accountDatabaseCrawlerListeners.add(new PushFeedbackProcessor(accountsManager, directoryQueue));
accountDatabaseCrawlerListeners.add(new PushFeedbackProcessor(accountsManager));
accountDatabaseCrawlerListeners.add(new ActiveUserCounter(config.getMetricsFactory(), cacheCluster));
for (DirectoryServerConfiguration directoryServerConfiguration : config.getDirectoryConfiguration().getDirectoryServerConfiguration()) {
final DirectoryReconciliationClient directoryReconciliationClient = new DirectoryReconciliationClient(directoryServerConfiguration);
final DirectoryReconciler directoryReconciler = new DirectoryReconciler(directoryServerConfiguration.getReplicationName(), directoryReconciliationClient);
accountDatabaseCrawlerListeners.add(directoryReconciler);
final DeletedAccountsDirectoryReconciler deletedAccountsDirectoryReconciler = new DeletedAccountsDirectoryReconciler(directoryServerConfiguration.getReplicationName(), directoryReconciliationClient);
deletedAccountsDirectoryReconcilers.add(deletedAccountsDirectoryReconciler);
}
accountDatabaseCrawlerListeners.add(new AccountCleaner(accountsManager));
accountDatabaseCrawlerListeners.add(new RegistrationLockVersionCounter(metricsCluster, config.getMetricsFactory()));
@@ -500,13 +496,18 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
CurrencyConversionManager currencyManager = new CurrencyConversionManager(fixerClient, ftxClient, config.getPaymentsServiceConfiguration().getPaymentCurrencies());
AccountDatabaseCrawlerCache accountDatabaseCrawlerCache = new AccountDatabaseCrawlerCache(cacheCluster);
AccountDatabaseCrawler accountDatabaseCrawler = new AccountDatabaseCrawler(accountsManager, accountDatabaseCrawlerCache, accountDatabaseCrawlerListeners, config.getAccountDatabaseCrawlerConfiguration().getChunkSize(), config.getAccountDatabaseCrawlerConfiguration().getChunkIntervalMs());
AccountDatabaseCrawler accountDatabaseCrawler = new AccountDatabaseCrawler(accountsManager, accountDatabaseCrawlerCache, accountDatabaseCrawlerListeners, config.getAccountDatabaseCrawlerConfiguration().getChunkSize(), config.getAccountDatabaseCrawlerConfiguration().getChunkIntervalMs(), dynamicConfigurationManager);
DeletedAccountsTableCrawler deletedAccountsTableCrawler = new DeletedAccountsTableCrawler(deletedAccountsManager, deletedAccountsDirectoryReconcilers, cacheCluster, recurringJobExecutor);
MigrationRetryAccountsTableCrawler migrationRetryAccountsTableCrawler = new MigrationRetryAccountsTableCrawler(migrationRetryAccounts, accountsManager, accountsDynamoDb, cacheCluster, recurringJobExecutor);
apnSender.setApnFallbackManager(apnFallbackManager);
environment.lifecycle().manage(apnFallbackManager);
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(messageSender);
environment.lifecycle().manage(accountDatabaseCrawler);
environment.lifecycle().manage(deletedAccountsTableCrawler);
environment.lifecycle().manage(migrationRetryAccountsTableCrawler);
environment.lifecycle().manage(remoteConfigsManager);
environment.lifecycle().manage(messagesCache);
environment.lifecycle().manage(messagePersister);
@@ -514,10 +515,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(currencyManager);
environment.lifecycle().manage(torExitNodeManager);
environment.lifecycle().manage(asnManager);
environment.lifecycle().manage(directoryQueue);
AWSCredentials credentials = new BasicAWSCredentials(config.getCdnConfiguration().getAccessKey(), config.getCdnConfiguration().getAccessSecret());
AWSCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials);
AmazonS3 cdnS3Client = AmazonS3Client.builder().withCredentials(credentialsProvider).withRegion(config.getCdnConfiguration().getRegion()).build();
StaticCredentialsProvider cdnCredentialsProvider = StaticCredentialsProvider
.create(AwsBasicCredentials.create(
config.getCdnConfiguration().getAccessKey(),
config.getCdnConfiguration().getAccessSecret()));
S3Client cdnS3Client = S3Client.builder()
.credentialsProvider(cdnCredentialsProvider)
.region(Region.of(config.getCdnConfiguration().getRegion()))
.build();
PostPolicyGenerator profileCdnPolicyGenerator = new PostPolicyGenerator(config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket(), config.getCdnConfiguration().getAccessKey());
PolicySigner profileCdnPolicySigner = new PolicySigner(config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion());
@@ -526,17 +533,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ServerZkAuthOperations zkAuthOperations = new ServerZkAuthOperations(zkSecretParams);
boolean isZkEnabled = config.getZkConfig().isEnabled();
AttachmentControllerV1 attachmentControllerV1 = new AttachmentControllerV1(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getBucket());
AttachmentControllerV2 attachmentControllerV2 = new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket());
AttachmentControllerV3 attachmentControllerV3 = new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey());
DonationController donationController = new DonationController(donationExecutor, config.getDonationConfiguration());
KeysController keysController = new KeysController(rateLimiters, keysDynamoDb, accountsManager, directoryQueue, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager);
MessageController messageController = new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, reportMessageManager, metricsCluster, declinedMessageReceiptExecutor);
ProfileController profileController = new ProfileController(rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations, isZkEnabled);
StickerController stickerController = new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket());
RemoteConfigController remoteConfigController = new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getGlobalConfig());
ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(accountAuthenticator).buildAuthFilter ();
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(disabledPermittedAccountAuthenticator).buildAuthFilter();
@@ -549,25 +545,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter)));
environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)));
environment.jersey().register(new TimestampResponseFilter());
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, usernamesManager, abusiveHostRules, rateLimiters, smsSender, directoryQueue, messagesManager, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient, gcmSender, apnSender, backupCredentialsGenerator, verifyExperimentEnrollmentManager));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, messagesManager, directoryQueue, rateLimiters, config.getMaxDevices()));
environment.jersey().register(new DirectoryController(directoryCredentialsGenerator));
environment.jersey().register(new ProvisioningController(rateLimiters, provisioningManager));
environment.jersey().register(new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, isZkEnabled));
environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales()));
environment.jersey().register(new SecureStorageController(storageCredentialsGenerator));
environment.jersey().register(new SecureBackupController(backupCredentialsGenerator));
environment.jersey().register(new PaymentsController(currencyManager, paymentsCredentialsGenerator));
environment.jersey().register(attachmentControllerV1);
environment.jersey().register(attachmentControllerV2);
environment.jersey().register(attachmentControllerV3);
environment.jersey().register(donationController);
environment.jersey().register(keysController);
environment.jersey().register(messageController);
environment.jersey().register(profileController);
environment.jersey().register(stickerController);
environment.jersey().register(remoteConfigController);
environment.jersey().register(challengeController);
///
WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000);
@@ -576,13 +554,34 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
webSocketEnvironment.jersey().register(messageController);
webSocketEnvironment.jersey().register(profileController);
webSocketEnvironment.jersey().register(attachmentControllerV1);
webSocketEnvironment.jersey().register(attachmentControllerV2);
webSocketEnvironment.jersey().register(attachmentControllerV3);
webSocketEnvironment.jersey().register(donationController);
webSocketEnvironment.jersey().register(remoteConfigController);
// these should be common, but use @Auth DisabledPermittedAccount, which isnt supported yet on websocket
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, usernamesManager, abusiveHostRules, rateLimiters, smsSender, dynamicConfigurationManager, turnTokenGenerator, config.getTestDevices(), recaptchaClient, gcmSender, apnSender, backupCredentialsGenerator, verifyExperimentEnrollmentManager));
environment.jersey().register(new KeysController(rateLimiters, keysDynamoDb, accountsManager, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager));
final List<Object> commonControllers = List.of(
new AttachmentControllerV1(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getBucket()),
new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()),
new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, isZkEnabled),
new ChallengeController(rateLimitChallengeManager),
new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keysDynamoDb, rateLimiters, config.getMaxDevices()),
new DirectoryController(directoryCredentialsGenerator),
new DonationController(donationExecutor, config.getDonationConfiguration()),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, reportMessageManager, metricsCluster, declinedMessageReceiptExecutor),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations, isZkEnabled),
new ProvisioningController(rateLimiters, provisioningManager),
new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getGlobalConfig()),
new SecureBackupController(backupCredentialsGenerator),
new SecureStorageController(storageCredentialsGenerator),
new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket())
);
for (Object controller : commonControllers) {
environment.jersey().register(controller);
webSocketEnvironment.jersey().register(controller);
}
WebSocketEnvironment<Account> provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000);
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
@@ -628,23 +627,24 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
BufferPoolGauges.registerMetrics();
GarbageCollectionGauges.registerMetrics();
new NstatCounters().registerMetrics(recurringJobExecutor, wavefrontConfig.step());
}
private void registerExceptionMappers(Environment environment, WebSocketEnvironment<Account> webSocketEnvironment, WebSocketEnvironment<Account> provisioningEnvironment) {
environment.jersey().register(new LoggingUnhandledExceptionMapper());
environment.jersey().register(new IOExceptionMapper());
environment.jersey().register(new RateLimitExceededExceptionMapper());
environment.jersey().register(new InvalidWebsocketAddressExceptionMapper());
environment.jersey().register(new DeviceLimitExceededExceptionMapper());
environment.jersey().register(new RetryLaterExceptionMapper());
webSocketEnvironment.jersey().register(new LoggingUnhandledExceptionMapper());
webSocketEnvironment.jersey().register(new IOExceptionMapper());
webSocketEnvironment.jersey().register(new RateLimitExceededExceptionMapper());
webSocketEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper());
webSocketEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper());
webSocketEnvironment.jersey().register(new RetryLaterExceptionMapper());
provisioningEnvironment.jersey().register(new LoggingUnhandledExceptionMapper());
provisioningEnvironment.jersey().register(new IOExceptionMapper());
provisioningEnvironment.jersey().register(new RateLimitExceededExceptionMapper());
provisioningEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper());

View File

@@ -5,45 +5,33 @@
package org.whispersystems.textsecuregcm.auth;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.basic.BasicCredentials;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import io.micrometer.core.instrument.Tags;
import java.time.Clock;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Util;
public class BaseAccountAuthenticator {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter authenticationFailedMeter = metricRegistry.meter(name(getClass(), "authentication", "failed" ));
private final Meter authenticationSucceededMeter = metricRegistry.meter(name(getClass(), "authentication", "succeeded" ));
private final Meter noSuchAccountMeter = metricRegistry.meter(name(getClass(), "authentication", "noSuchAccount" ));
private final Meter noSuchDeviceMeter = metricRegistry.meter(name(getClass(), "authentication", "noSuchDevice" ));
private final Meter accountDisabledMeter = metricRegistry.meter(name(getClass(), "authentication", "accountDisabled"));
private final Meter deviceDisabledMeter = metricRegistry.meter(name(getClass(), "authentication", "deviceDisabled" ));
private final Meter invalidAuthHeaderMeter = metricRegistry.meter(name(getClass(), "authentication", "invalidHeader" ));
private final String daysSinceLastSeenDistributionName = name(getClass(), "authentication", "daysSinceLastSeen");
private static final String AUTHENTICATION_COUNTER_NAME = name(BaseAccountAuthenticator.class, "authentication");
private static final String AUTHENTICATION_SUCCEEDED_TAG_NAME = "succeeded";
private static final String AUTHENTICATION_FAILURE_REASON_TAG_NAME = "reason";
private static final String AUTHENTICATION_ENABLED_REQUIRED_TAG_NAME = "enabledRequired";
private static final String AUTHENTICATION_CREDENTIAL_TYPE_TAG_NAME = "credentialType";
private static final String DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME = name(BaseAccountAuthenticator.class, "daysSinceLastSeen");
private static final String IS_PRIMARY_DEVICE_TAG = "isPrimary";
private final Logger logger = LoggerFactory.getLogger(BaseAccountAuthenticator.class);
private final AccountsManager accountsManager;
private final Clock clock;
@@ -58,64 +46,81 @@ public class BaseAccountAuthenticator {
}
public Optional<Account> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) {
boolean succeeded = false;
String failureReason = null;
String credentialType = null;
try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword());
Optional<Account> account = accountsManager.get(authorizationHeader.getIdentifier());
if (!account.isPresent()) {
noSuchAccountMeter.mark();
credentialType = authorizationHeader.getIdentifier().hasNumber() ? "e164" : "uuid";
if (account.isEmpty()) {
failureReason = "noSuchAccount";
return Optional.empty();
}
Optional<Device> device = account.get().getDevice(authorizationHeader.getDeviceId());
if (!device.isPresent()) {
noSuchDeviceMeter.mark();
if (device.isEmpty()) {
failureReason = "noSuchDevice";
return Optional.empty();
}
if (enabledRequired) {
if (!device.get().isEnabled()) {
deviceDisabledMeter.mark();
failureReason = "deviceDisabled";
return Optional.empty();
}
if (!account.get().isEnabled()) {
accountDisabledMeter.mark();
failureReason = "accountDisabled";
return Optional.empty();
}
}
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
authenticationSucceededMeter.mark();
account.get().setAuthenticatedDevice(device.get());
updateLastSeen(account.get(), device.get());
return account;
succeeded = true;
final Account authenticatedAccount = updateLastSeen(account.get(), device.get());
authenticatedAccount.setAuthenticatedDevice(device.get());
return Optional.of(authenticatedAccount);
}
authenticationFailedMeter.mark();
return Optional.empty();
} catch (IllegalArgumentException | InvalidAuthorizationHeaderException iae) {
invalidAuthHeaderMeter.mark();
failureReason = "invalidHeader";
return Optional.empty();
} finally {
Tags tags = Tags.of(
AUTHENTICATION_SUCCEEDED_TAG_NAME, String.valueOf(succeeded),
AUTHENTICATION_ENABLED_REQUIRED_TAG_NAME, String.valueOf(enabledRequired));
if (StringUtils.isNotBlank(failureReason)) {
tags = tags.and(AUTHENTICATION_FAILURE_REASON_TAG_NAME, failureReason);
}
if (StringUtils.isNotBlank(credentialType)) {
tags = tags.and(AUTHENTICATION_CREDENTIAL_TYPE_TAG_NAME, credentialType);
}
Metrics.counter(AUTHENTICATION_COUNTER_NAME, tags).increment();
}
}
@VisibleForTesting
public void updateLastSeen(Account account, Device device) {
public Account updateLastSeen(Account account, Device device) {
final long lastSeenOffsetSeconds = Math.abs(account.getUuid().getLeastSignificantBits()) % ChronoUnit.DAYS.getDuration().toSeconds();
final long todayInMillisWithOffset = Util.todayInMillisGivenOffsetFromNow(clock, Duration.ofSeconds(lastSeenOffsetSeconds).negated());
if (device.getLastSeen() < todayInMillisWithOffset) {
DistributionSummary.builder(daysSinceLastSeenDistributionName)
.tags(IS_PRIMARY_DEVICE_TAG, String.valueOf(device.isMaster()))
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
Metrics.summary(DAYS_SINCE_LAST_SEEN_DISTRIBUTION_NAME, IS_PRIMARY_DEVICE_TAG, String.valueOf(device.isMaster()))
.record(Duration.ofMillis(todayInMillisWithOffset - device.getLastSeen()).toDays());
device.setLastSeen(Util.todayInMillis(clock));
accountsManager.update(account);
return accountsManager.updateDevice(account, device.getId(), d -> d.setLastSeen(Util.todayInMillis(clock)));
}
return account;
}
}

View File

@@ -33,7 +33,7 @@ public class CertificateGenerator {
this.serverCertificate = ServerCertificate.parseFrom(serverCertificate);
}
public byte[] createFor(Account account, Device device, boolean includeE164) throws IOException, InvalidKeyException {
public byte[] createFor(Account account, Device device, boolean includeE164) throws InvalidKeyException {
SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder()
.setSenderDevice(Math.toIntExact(device.getId()))
.setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays))

View File

@@ -5,36 +5,38 @@
package org.whispersystems.textsecuregcm.auth;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.util.Util;
import java.security.MessageDigest;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.util.Util;
public class StoredVerificationCode {
@JsonProperty
private String code;
private final String code;
@JsonProperty
private long timestamp;
private final long timestamp;
@JsonProperty
private String pushCode;
private final String pushCode;
@JsonProperty
private String twilioVerificationSid;
@Nullable
private final String twilioVerificationSid;
public StoredVerificationCode() {
}
public static final Duration EXPIRATION = Duration.ofMinutes(10);
public StoredVerificationCode(String code, long timestamp, String pushCode) {
this(code, timestamp, pushCode, null);
}
@JsonCreator
public StoredVerificationCode(
@JsonProperty("code") final String code,
@JsonProperty("timestamp") final long timestamp,
@JsonProperty("pushCode") final String pushCode,
@JsonProperty("twilioVerificationSid") @Nullable final String twilioVerificationSid) {
public StoredVerificationCode(String code, long timestamp, String pushCode, String twilioVerificationSid) {
this.code = code;
this.timestamp = timestamp;
this.pushCode = pushCode;
@@ -58,10 +60,6 @@ public class StoredVerificationCode {
}
public boolean isValid(String theirCodeString) {
if (timestamp + TimeUnit.MINUTES.toMillis(10) < System.currentTimeMillis()) {
return false;
}
if (Util.isEmpty(code) || Util.isEmpty(theirCodeString)) {
return false;
}

View File

@@ -7,15 +7,15 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import java.time.Duration;
import io.github.resilience4j.circuitbreaker.CircuitBreakerConfig;
public class CircuitBreakerConfiguration {
@JsonProperty
@@ -39,6 +39,9 @@ public class CircuitBreakerConfiguration {
@Min(1)
private long waitDurationInOpenStateInSeconds = 10;
@JsonProperty
private List<String> ignoredExceptions = Collections.emptyList();
public int getFailureRateThreshold() {
return failureRateThreshold;
@@ -56,6 +59,18 @@ public class CircuitBreakerConfiguration {
return waitDurationInOpenStateInSeconds;
}
public List<Class> getIgnoredExceptions() {
return ignoredExceptions.stream()
.map(name -> {
try {
return Class.forName(name);
} catch (final ClassNotFoundException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList());
}
@VisibleForTesting
public void setFailureRateThreshold(int failureRateThreshold) {
this.failureRateThreshold = failureRateThreshold;
@@ -76,9 +91,15 @@ public class CircuitBreakerConfiguration {
this.waitDurationInOpenStateInSeconds = seconds;
}
@VisibleForTesting
public void setIgnoredExceptions(final List<String> ignoredExceptions) {
this.ignoredExceptions = ignoredExceptions;
}
public CircuitBreakerConfig toCircuitBreakerConfig() {
return CircuitBreakerConfig.custom()
.failureRateThreshold(getFailureRateThreshold())
.ignoreExceptions(getIgnoredExceptions().toArray(new Class[0]))
.ringBufferSizeInHalfOpenState(getRingBufferSizeInHalfOpenState())
.waitDurationInOpenState(Duration.ofSeconds(getWaitDurationInOpenStateInSeconds()))
.ringBufferSizeInClosedState(getRingBufferSizeInClosedState())

View File

@@ -6,11 +6,13 @@
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.micrometer.datadog.DatadogConfig;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import java.time.Duration;
public class DatadogConfiguration {
public class DatadogConfiguration implements DatadogConfig {
@JsonProperty
@NotBlank
@@ -24,15 +26,36 @@ public class DatadogConfiguration {
@NotBlank
private String environment;
public String getApiKey() {
@JsonProperty
@Min(1)
private int batchSize = 5_000;
@Override
public String apiKey() {
return apiKey;
}
public Duration getStep() {
@Override
public Duration step() {
return step;
}
public String getEnvironment() {
return environment;
}
@Override
public int batchSize() {
return batchSize;
}
@Override
public String hostTag() {
return "host";
}
@Override
public String get(final String key) {
return null;
}
}

View File

@@ -0,0 +1,13 @@
package org.whispersystems.textsecuregcm.configuration;
import javax.validation.constraints.NotNull;
public class DeletedAccountsDynamoDbConfiguration extends DynamoDbConfiguration {
@NotNull
private String needsReconciliationIndexName;
public String getNeedsReconciliationIndexName() {
return needsReconciliationIndexName;
}
}

View File

@@ -5,13 +5,12 @@
package org.whispersystems.textsecuregcm.configuration;
import javax.validation.Valid;
import javax.validation.constraints.NotEmpty;
import java.time.Duration;
import javax.validation.Valid;
public class MessageDynamoDbConfiguration extends DynamoDbConfiguration {
private Duration timeToLive = Duration.ofDays(7);
private Duration timeToLive = Duration.ofDays(14);
@Valid
public Duration getTimeToLive() {

View File

@@ -23,6 +23,12 @@ public class DynamicAccountsDynamoDbMigrationConfiguration {
@JsonProperty
boolean logMismatches;
@JsonProperty
boolean dynamoCrawlerEnabled;
@JsonProperty
int dynamoCrawlerScanPageSize = 10;
public boolean isBackgroundMigrationEnabled() {
return backgroundMigrationEnabled;
}
@@ -59,4 +65,12 @@ public class DynamicAccountsDynamoDbMigrationConfiguration {
public boolean isLogMismatches() {
return logMismatches;
}
public boolean isDynamoCrawlerEnabled() {
return dynamoCrawlerEnabled;
}
public int getDynamoCrawlerScanPageSize() {
return dynamoCrawlerScanPageSize;
}
}

View File

@@ -22,7 +22,6 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import javax.validation.Valid;
@@ -68,15 +67,13 @@ import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioVerifyExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
@@ -90,7 +87,6 @@ public class AccountController {
private final Logger logger = LoggerFactory.getLogger(AccountController.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter newUserMeter = metricRegistry.meter(name(AccountController.class, "brand_new_user" ));
private final Meter blockedHostMeter = metricRegistry.meter(name(AccountController.class, "blocked_host" ));
private final Meter filteredHostMeter = metricRegistry.meter(name(AccountController.class, "filtered_host" ));
private final Meter rateLimitedHostMeter = metricRegistry.meter(name(AccountController.class, "rate_limited_host" ));
@@ -112,14 +108,12 @@ public class AccountController {
private static final String VERIFY_EXPERIMENT_TAG_NAME = "twilioVerify";
private final PendingAccountsManager pendingAccounts;
private final StoredVerificationCodeManager pendingAccounts;
private final AccountsManager accounts;
private final UsernamesManager usernames;
private final AbusiveHostRules abusiveHostRules;
private final RateLimiters rateLimiters;
private final SmsSender smsSender;
private final DirectoryQueue directoryQueue;
private final MessagesManager messagesManager;
private final DynamicConfigurationManager dynamicConfigurationManager;
private final TurnTokenGenerator turnTokenGenerator;
private final Map<String, Integer> testDevices;
@@ -130,14 +124,12 @@ public class AccountController {
private final TwilioVerifyExperimentEnrollmentManager verifyExperimentEnrollmentManager;
public AccountController(PendingAccountsManager pendingAccounts,
public AccountController(StoredVerificationCodeManager pendingAccounts,
AccountsManager accounts,
UsernamesManager usernames,
AbusiveHostRules abusiveHostRules,
RateLimiters rateLimiters,
SmsSender smsSenderFactory,
DirectoryQueue directoryQueue,
MessagesManager messagesManager,
DynamicConfigurationManager dynamicConfigurationManager,
TurnTokenGenerator turnTokenGenerator,
Map<String, Integer> testDevices,
@@ -153,8 +145,6 @@ public class AccountController {
this.abusiveHostRules = abusiveHostRules;
this.rateLimiters = rateLimiters;
this.smsSender = smsSenderFactory;
this.directoryQueue = directoryQueue;
this.messagesManager = messagesManager;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.testDevices = testDevices;
this.turnTokenGenerator = turnTokenGenerator;
@@ -184,7 +174,8 @@ public class AccountController {
String pushChallenge = generatePushChallenge();
StoredVerificationCode storedVerificationCode = new StoredVerificationCode(null,
System.currentTimeMillis(),
pushChallenge);
pushChallenge,
null);
pendingAccounts.store(number, storedVerificationCode);
@@ -217,10 +208,6 @@ public class AccountController {
throw new WebApplicationException(Response.status(400).build());
}
if (number.startsWith("+98")) {
transport = "voice";
}
String requester = ForwardedIpUtil.getMostRecentProxy(forwardedFor).orElseThrow();
Optional<StoredVerificationCode> storedChallenge = pendingAccounts.getCodeForNumber(number);
@@ -317,6 +304,7 @@ public class AccountController {
});
});
// TODO Remove this meter when external dependencies have been resolved
metricRegistry.meter(name(AccountController.class, "create", Util.getCountryCode(number))).mark();
{
@@ -391,7 +379,7 @@ public class AccountController {
throw new WebApplicationException(Response.status(409).build());
}
Account account = createAccount(number, password, signalAgent, accountAttributes);
Account account = accounts.create(number, password, signalAgent, accountAttributes);
{
metricRegistry.meter(name(AccountController.class, "verify", Util.getCountryCode(number))).mark();
@@ -430,7 +418,6 @@ public class AccountController {
public void setGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid GcmRegistrationId registrationId) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
if (device.getGcmId() != null &&
device.getGcmId().equals(registrationId.getGcmRegistrationId()))
@@ -438,16 +425,12 @@ public class AccountController {
return;
}
device.setApnId(null);
device.setVoipApnId(null);
device.setGcmId(registrationId.getGcmRegistrationId());
device.setFetchesMessages(false);
accounts.update(account);
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
}
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setVoipApnId(null);
d.setGcmId(registrationId.getGcmRegistrationId());
d.setFetchesMessages(false);
});
}
@Timed
@@ -456,12 +439,12 @@ public class AccountController {
public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
device.setGcmId(null);
device.setFetchesMessages(false);
device.setUserAgent("OWA");
accounts.update(account);
directoryQueue.refreshRegisteredUser(account);
accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null);
d.setFetchesMessages(false);
d.setUserAgent("OWA");
});
}
@Timed
@@ -471,17 +454,13 @@ public class AccountController {
public void setApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid ApnRegistrationId registrationId) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
device.setApnId(registrationId.getApnRegistrationId());
device.setVoipApnId(registrationId.getVoipRegistrationId());
device.setGcmId(null);
device.setFetchesMessages(false);
accounts.update(account);
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
}
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(registrationId.getApnRegistrationId());
d.setVoipApnId(registrationId.getVoipRegistrationId());
d.setGcmId(null);
d.setFetchesMessages(false);
});
}
@Timed
@@ -490,16 +469,16 @@ public class AccountController {
public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
device.setFetchesMessages(false);
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
accounts.update(account);
directoryQueue.refreshRegisteredUser(account);
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
d.setFetchesMessages(false);
if (d.getId() == 1) {
d.setUserAgent("OWI");
} else {
d.setUserAgent("OWP");
}
});
}
@Timed
@@ -508,37 +487,43 @@ public class AccountController {
@Path("/registration_lock")
public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) {
AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock());
account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
account.setPin(null);
accounts.update(account);
accounts.update(account, a -> {
a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
a.setPin(null);
});
}
@Timed
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Auth Account account) {
account.setRegistrationLock(null, null);
accounts.update(account);
accounts.update(account, a -> a.setRegistrationLock(null, null));
}
@Timed
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/pin/")
public void setPin(@Auth Account account, @Valid DeprecatedPin accountLock) {
account.setPin(accountLock.getPin());
account.setRegistrationLock(null, null);
public void setPin(@Auth Account account, @Valid DeprecatedPin accountLock, @HeaderParam("User-Agent") String userAgent) {
// TODO Remove once PIN-based reglocks have been deprecated
logger.info("PIN set by User-Agent: {}", userAgent);
accounts.update(account);
accounts.update(account, a -> {
a.setPin(accountLock.getPin());
a.setRegistrationLock(null, null);
});
}
@Timed
@DELETE
@Path("/pin/")
public void removePin(@Auth Account account) {
account.setPin(null);
accounts.update(account);
public void removePin(@Auth Account account, @HeaderParam("User-Agent") String userAgent) {
// TODO Remove once PIN-based reglocks have been deprecated
logger.info("PIN removed by User-Agent: {}", userAgent);
accounts.update(account, a -> a.setPin(null));
}
@Timed
@@ -546,8 +531,8 @@ public class AccountController {
@Path("/name/")
public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) {
Account account = disabledPermittedAccount.getAccount();
account.getAuthenticatedDevice().get().setName(deviceName.getDeviceName());
accounts.update(account);
Device device = account.getAuthenticatedDevice().get();
accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName()));
}
@Timed
@@ -565,28 +550,25 @@ public class AccountController {
@Valid AccountAttributes attributes)
{
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
long deviceId = account.getAuthenticatedDevice().get().getId();
device.setFetchesMessages(attributes.getFetchesMessages());
device.setName(attributes.getName());
device.setLastSeen(Util.todayInMillis());
device.setCapabilities(attributes.getCapabilities());
device.setRegistrationId(attributes.getRegistrationId());
device.setUserAgent(userAgent);
accounts.update(account, a-> {
setAccountRegistrationLockFromAttributes(account, attributes);
a.getDevice(deviceId).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
d.setUserAgent(userAgent);
});
final boolean hasDiscoverabilityChange = (account.isDiscoverableByPhoneNumber() != attributes.isDiscoverableByPhoneNumber());
a.setRegistrationLockFromAttributes(attributes);
account.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber());
accounts.update(account);
if (hasDiscoverabilityChange) {
directoryQueue.refreshRegisteredUser(account);
}
a.setUnidentifiedAccessKey(attributes.getUnidentifiedAccessKey());
a.setUnrestrictedUnidentifiedAccess(attributes.isUnrestrictedUnidentifiedAccess());
a.setDiscoverableByPhoneNumber(attributes.isDiscoverableByPhoneNumber());
});
}
@GET
@@ -724,7 +706,7 @@ public class AccountController {
@Timed
@DELETE
@Path("/me")
public void deleteAccount(@Auth Account account) {
public void deleteAccount(@Auth Account account) throws InterruptedException {
accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST);
}
@@ -738,52 +720,6 @@ public class AccountController {
return false;
}
private Account createAccount(String number, String password, String signalAgent, AccountAttributes accountAttributes) {
Optional<Account> maybeExistingAccount = accounts.get(number);
Device device = new Device();
device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
Account account = new Account();
account.setNumber(number);
account.setUuid(UUID.randomUUID());
account.addDevice(device);
setAccountRegistrationLockFromAttributes(account, accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(accountAttributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(accountAttributes.isDiscoverableByPhoneNumber());
if (accounts.create(account)) {
newUserMeter.mark();
}
directoryQueue.refreshRegisteredUser(account);
maybeExistingAccount.ifPresent(definitelyExistingAccount -> messagesManager.clear(definitelyExistingAccount.getUuid()));
pendingAccounts.remove(number);
return account;
}
private void setAccountRegistrationLockFromAttributes(Account account, @Valid AccountAttributes attributes) {
if (!Util.isEmpty(attributes.getPin())) {
account.setPin(attributes.getPin());
} else if (!Util.isEmpty(attributes.getRegistrationLock())) {
AuthenticationCredentials credentials = new AuthenticationCredentials(attributes.getRegistrationLock());
account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
} else {
account.setPin(null);
account.setRegistrationLock(null, null);
}
}
@VisibleForTesting protected
VerificationCode generateVerificationCode(String number) {
if (testDevices.containsKey(number)) {

View File

@@ -5,17 +5,16 @@
package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
@@ -24,22 +23,24 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/certificate")
public class CertificateController {
private final Logger logger = LoggerFactory.getLogger(CertificateController.class);
private final CertificateGenerator certificateGenerator;
private final ServerZkAuthOperations serverZkAuthOperations;
private final boolean isZkEnabled;
private static final String GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME = name(CertificateGenerator.class, "generateCertificate");
private static final String INCLUDE_E164_TAG_NAME = "includeE164";
public CertificateController(CertificateGenerator certificateGenerator, ServerZkAuthOperations serverZkAuthOperations, boolean isZkEnabled) {
this.certificateGenerator = certificateGenerator;
this.serverZkAuthOperations = serverZkAuthOperations;
@@ -51,8 +52,8 @@ public class CertificateController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/delivery")
public DeliveryCertificate getDeliveryCertificate(@Auth Account account,
@QueryParam("includeE164") Optional<Boolean> includeE164)
throws IOException, InvalidKeyException
@QueryParam("includeE164") Optional<Boolean> maybeIncludeE164)
throws InvalidKeyException
{
if (account.getAuthenticatedDevice().isEmpty()) {
throw new AssertionError();
@@ -61,7 +62,11 @@ public class CertificateController {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
return new DeliveryCertificate(certificateGenerator.createFor(account, account.getAuthenticatedDevice().get(), includeE164.orElse(true)));
final boolean includeE164 = maybeIncludeE164.orElse(true);
Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164)).increment();
return new DeliveryCertificate(certificateGenerator.createFor(account, account.getAuthenticatedDevice().get(), includeE164));
}
@Timed

View File

@@ -23,8 +23,9 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
@@ -55,24 +56,24 @@ public class DeviceController {
private static final int MAX_DEVICES = 6;
private final PendingDevicesManager pendingDevices;
private final StoredVerificationCodeManager pendingDevices;
private final AccountsManager accounts;
private final MessagesManager messages;
private final KeysDynamoDb keys;
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;
private final DirectoryQueue directoryQueue;
public DeviceController(PendingDevicesManager pendingDevices,
public DeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts,
MessagesManager messages,
DirectoryQueue directoryQueue,
KeysDynamoDb keys,
RateLimiters rateLimiters,
Map<String, Integer> maxDeviceConfiguration)
{
this.pendingDevices = pendingDevices;
this.accounts = accounts;
this.messages = messages;
this.directoryQueue = directoryQueue;
this.keys = keys;
this.rateLimiters = rateLimiters;
this.maxDeviceConfiguration = maxDeviceConfiguration;
}
@@ -99,9 +100,10 @@ public class DeviceController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
account.removeDevice(deviceId);
accounts.update(account);
directoryQueue.refreshRegisteredUser(account);
messages.clear(account.getUuid(), deviceId);
account = accounts.update(account, a -> a.removeDevice(deviceId));
keys.delete(account, deviceId);
// ensure any messages that came in after the first clear() are also removed
messages.clear(account.getUuid(), deviceId);
}
@@ -131,6 +133,7 @@ public class DeviceController {
VerificationCode verificationCode = generateVerificationCode();
StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(),
System.currentTimeMillis(),
null,
null);
pendingDevices.store(account.getNumber(), storedVerificationCode);
@@ -189,15 +192,16 @@ public class DeviceController {
device.setName(accountAttributes.getName());
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setId(account.get().getNextDeviceId());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
account.get().addDevice(device);
messages.clear(account.get().getUuid(), device.getId());
accounts.update(account.get());
accounts.update(account.get(), a -> {
device.setId(account.get().getNextDeviceId());
messages.clear(account.get().getUuid(), device.getId());
a.addDevice(device);
});;
pendingDevices.remove(number);
@@ -221,8 +225,8 @@ public class DeviceController {
@Path("/capabilities")
public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) {
assert(account.getAuthenticatedDevice().isPresent());
account.getAuthenticatedDevice().get().setCapabilities(capabilities);
accounts.update(account);
final long deviceId = account.getAuthenticatedDevice().get().getId();
accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities));
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {
@@ -234,13 +238,9 @@ public class DeviceController {
private boolean isCapabilityDowngrade(Account account, DeviceCapabilities capabilities, String userAgent) {
boolean isDowngrade = false;
if (account.isSenderKeySupported() && !capabilities.isSenderKey()) {
isDowngrade = true;
}
if (account.isGv1MigrationSupported() && !capabilities.isGv1Migration()) {
isDowngrade = true;
}
isDowngrade |= account.isAnnouncementGroupSupported() && !capabilities.isAnnouncementGroup();
isDowngrade |= account.isSenderKeySupported() && !capabilities.isSenderKey();
isDowngrade |= account.isGv1MigrationSupported() && !capabilities.isGv1Migration();
if (account.isGroupsV2Supported()) {
try {

View File

@@ -55,7 +55,6 @@ public class KeysController {
private final RateLimiters rateLimiters;
private final KeysDynamoDb keysDynamoDb;
private final AccountsManager accounts;
private final DirectoryQueue directoryQueue;
private final PreKeyRateLimiter preKeyRateLimiter;
private final DynamicConfigurationManager dynamicConfigurationManager;
@@ -69,13 +68,12 @@ public class KeysController {
private static final String PREKEY_TARGET_IDENTIFIER_TAG_NAME = "identifierType";
public KeysController(RateLimiters rateLimiters, KeysDynamoDb keysDynamoDb, AccountsManager accounts,
DirectoryQueue directoryQueue, PreKeyRateLimiter preKeyRateLimiter,
PreKeyRateLimiter preKeyRateLimiter,
DynamicConfigurationManager dynamicConfigurationManager,
RateLimitChallengeManager rateLimitChallengeManager) {
this.rateLimiters = rateLimiters;
this.keysDynamoDb = keysDynamoDb;
this.accounts = accounts;
this.directoryQueue = directoryQueue;
this.preKeyRateLimiter = preKeyRateLimiter;
this.dynamicConfigurationManager = dynamicConfigurationManager;
@@ -100,25 +98,21 @@ public class KeysController {
public void setKeys(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid PreKeyState preKeys) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
boolean updateAccount = false;
if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) {
device.setSignedPreKey(preKeys.getSignedPreKey());
updateAccount = true;
}
if (!preKeys.getIdentityKey().equals(account.getIdentityKey())) {
account.setIdentityKey(preKeys.getIdentityKey());
updateAccount = true;
}
if (updateAccount) {
accounts.update(account);
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
}
account = accounts.update(account, a -> {
a.getDevice(device.getId()).ifPresent(d -> d.setSignedPreKey(preKeys.getSignedPreKey()));
a.setIdentityKey(preKeys.getIdentityKey());
});
}
keysDynamoDb.store(account, device.getId(), preKeys.getPreKeys());
@@ -197,15 +191,9 @@ public class KeysController {
@Path("/signed")
@Consumes(MediaType.APPLICATION_JSON)
public void setSignedKey(@Auth Account account, @Valid SignedPreKey signedPreKey) {
Device device = account.getAuthenticatedDevice().get();
boolean wasAccountEnabled = account.isEnabled();
Device device = account.getAuthenticatedDevice().get();
device.setSignedPreKey(signedPreKey);
accounts.update(account);
if (!wasAccountEnabled && account.isEnabled()) {
directoryQueue.refreshRegisteredUser(account);
}
accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey));
}
@Timed

View File

@@ -56,6 +56,7 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import io.micrometer.core.instrument.Tags;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -63,6 +64,7 @@ import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountMismatchedDevices;
import org.whispersystems.textsecuregcm.entities.AccountStaleDevices;
@@ -146,6 +148,7 @@ public class MessageController {
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String SENDER_TYPE_TAG_NAME = "senderType";
private static final String SENDER_COUNTRY_TAG_NAME = "senderCountry";
private static final String DESTINATION_TYPE_TAG_NAME = "destinationType";
private static final long MAX_MESSAGE_SIZE = DataSize.kibibytes(256).toBytes();
@@ -232,7 +235,7 @@ public class MessageController {
contentLength += message.getBody().length();
}
Metrics.summary(CONTENT_SIZE_DISTRIBUTION_NAME, UserAgentTagUtil.getUserAgentTags(userAgent)).record(contentLength);
Metrics.summary(CONTENT_SIZE_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).record(contentLength);
if (contentLength > MAX_MESSAGE_SIZE) {
rejectOver256kibMessageMeter.mark();
@@ -302,7 +305,8 @@ public class MessageController {
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
Tag.of(SENDER_TYPE_TAG_NAME, senderType));
Tag.of(SENDER_TYPE_TAG_NAME, senderType),
Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid"));
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.get().getDevice(incomingMessage.getDestinationDeviceId());
@@ -495,6 +499,11 @@ public class MessageController {
public OutgoingMessageEntityList getPendingMessages(@Auth Account account, @HeaderParam("User-Agent") String userAgent) {
assert account.getAuthenticatedDevice().isPresent();
// TODO Remove once PIN-based reglocks have been deprecated
if (account.getRegistrationLock().requiresClientRegistrationLock() && account.getRegistrationLock().hasDeprecatedPin()) {
logger.info("User-Agent with deprecated PIN-based registration lock: {}", userAgent);
}
if (!Util.isEmpty(account.getAuthenticatedDevice().get().getApnId())) {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get()));
}

View File

@@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.services.s3.AmazonS3;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import java.security.SecureRandom;
@@ -60,6 +59,8 @@ import org.whispersystems.textsecuregcm.storage.VersionedProfile;
import org.whispersystems.textsecuregcm.util.ExactlySize;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/profile")
@@ -78,7 +79,7 @@ public class ProfileController {
private final ServerZkProfileOperations zkProfileOperations;
private final boolean isZkEnabled;
private final AmazonS3 s3client;
private final S3Client s3client;
private final String bucket;
public ProfileController(RateLimiters rateLimiters,
@@ -86,7 +87,7 @@ public class ProfileController {
ProfilesManager profilesManager,
UsernamesManager usernamesManager,
DynamicConfigurationManager dynamicConfigurationManager,
AmazonS3 s3client,
S3Client s3client,
PostPolicyGenerator policyGenerator,
PolicySigner policySigner,
String bucket,
@@ -147,15 +148,19 @@ public class ProfileController {
currentAvatar = Optional.of(account.getAvatar());
}
currentAvatar.ifPresent(s -> s3client.deleteObject(bucket, s));
currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(s)
.build()));
response = Optional.of(generateAvatarUploadForm(avatar));
}
account.setProfileName(request.getName());
account.setAvatar(avatar);
account.setCurrentProfileVersion(request.getVersion());
accountsManager.update(account);
accountsManager.update(account, a -> {
a.setProfileName(request.getName());
a.setAvatar(avatar);
a.setCurrentProfileVersion(request.getVersion());
});
if (response.isPresent()) return Response.ok(response).build();
else return Response.ok().build();
@@ -313,8 +318,7 @@ public class ProfileController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/name/{name}")
public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) {
account.setProfileName(name.orElse(null));
accountsManager.update(account);
accountsManager.update(account, a -> a.setProfileName(name.orElse(null)));
}
@Deprecated
@@ -372,11 +376,13 @@ public class ProfileController {
ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName);
if (previousAvatar != null && previousAvatar.startsWith("profiles/")) {
s3client.deleteObject(bucket, previousAvatar);
s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(previousAvatar)
.build());
}
account.setAvatar(objectName);
accountsManager.update(account);
accountsManager.update(account, a -> a.setAvatar(objectName));
return profileAvatarUploadAttributes;
}

View File

@@ -4,21 +4,17 @@
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import java.util.List;
public class IncomingMessageList {
@JsonProperty
@NotNull
@Valid
@JsonInclude(Include.NON_NULL)
private List<IncomingMessage> messages;
private List<@NotNull IncomingMessage> messages;
@JsonProperty
private long timestamp;

View File

@@ -14,7 +14,8 @@ public class UserCapabilities {
return new UserCapabilities(
account.isGroupsV2Supported(),
account.isGv1MigrationSupported(),
account.isSenderKeySupported());
account.isSenderKeySupported(),
account.isAnnouncementGroupSupported());
}
@JsonProperty
@@ -26,12 +27,16 @@ public class UserCapabilities {
@JsonProperty
private boolean senderKey;
@JsonProperty
private boolean announcementGroup;
public UserCapabilities() {}
public UserCapabilities(boolean gv2, boolean gv1Migration, final boolean senderKey) {
public UserCapabilities(boolean gv2, boolean gv1Migration, final boolean senderKey, final boolean announcementGroup) {
this.gv2 = gv2;
this.gv1Migration = gv1Migration;
this.senderKey = senderKey;
this.announcementGroup = announcementGroup;
}
public boolean isGv2() {
@@ -45,4 +50,8 @@ public class UserCapabilities {
public boolean isSenderKey() {
return senderKey;
}
public boolean isAnnouncementGroup() {
return announcementGroup;
}
}

View File

@@ -18,7 +18,7 @@ public class PreKeyRateLimiter {
private static final String RATE_LIMIT_RESET_COUNTER_NAME = name(PreKeyRateLimiter.class, "reset");
private static final String RATE_LIMITED_PREKEYS_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimited");
private static final String RATE_LIMITED_PREKEYS_TOTAL_ACCOUNTS_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimited");
private static final String RATE_LIMITED_PREKEYS_TOTAL_ACCOUNTS_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimitedTotal");
private static final String RATE_LIMITED_PREKEYS_ACCOUNTS_ENFORCED_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimitedAccountsEnforced");
private static final String RATE_LIMITED_PREKEYS_ACCOUNTS_UNENFORCED_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimitedAccountsUnenforced");

View File

@@ -18,8 +18,13 @@ public class RateLimitResetMetricsManager {
}
void initializeFunctionCounters(String counterKey, String hllKey) {
FunctionCounter.builder(counterKey, null, (ignored) ->
metricsCluster.<Long>withCluster(conn -> conn.sync().pfcount(hllKey))).register(meterRegistry);
FunctionCounter
.builder(counterKey, this, manager -> manager.getCount(hllKey))
.register(meterRegistry);
}
Long getCount(final String hllKey) {
return metricsCluster.<Long>withCluster(conn -> conn.sync().pfcount(hllKey));
}
void recordMetrics(Account account, boolean enforced, String counterKey, String hllEnforcedKey, String hllTotalKey,

View File

@@ -21,6 +21,7 @@ import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.HostnameUtil;
import javax.ws.rs.core.UriBuilder;
import java.io.ByteArrayOutputStream;
@@ -83,7 +84,7 @@ public class JsonMetricsReporter extends ScheduledReporter {
{
super(registry, "json-reporter", filter, rateUnit, durationUnit, Executors.newSingleThreadScheduledExecutor(new NamedThreadFactory("json-reporter")), true, disabledMetricAttributes);
this.httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_2).build();
this.uri = UriBuilder.fromUri(uri).queryParam("h", InetAddress.getLocalHost().getHostName()).build();
this.uri = UriBuilder.fromUri(uri).queryParam("h", HostnameUtil.getLocalHostname()).build();
}
@Override

View File

@@ -27,6 +27,7 @@ import javax.validation.constraints.NotEmpty;
import net.logstash.logback.appender.LogstashTcpSocketAppender;
import net.logstash.logback.encoder.LogstashEncoder;
import org.whispersystems.textsecuregcm.WhisperServerVersion;
import org.whispersystems.textsecuregcm.util.HostnameUtil;
@JsonTypeName("logstashtcpsocket")
public class LogstashTcpSocketAppenderFactory extends AbstractAppenderFactory<ILoggingEvent> {
@@ -77,11 +78,7 @@ public class LogstashTcpSocketAppenderFactory extends AbstractAppenderFactory<IL
final LogstashEncoder encoder = new LogstashEncoder();
final ObjectNode customFieldsNode = new ObjectNode(JsonNodeFactory.instance);
try {
customFieldsNode.set("host", TextNode.valueOf(InetAddress.getLocalHost().getHostName()));
} catch (UnknownHostException e) {
customFieldsNode.set("host", TextNode.valueOf("unknown"));
}
customFieldsNode.set("host", TextNode.valueOf(HostnameUtil.getLocalHostname()));
customFieldsNode.set("service", TextNode.valueOf("chat"));
customFieldsNode.set("ddsource", TextNode.valueOf("logstash"));
customFieldsNode.set("ddtags", TextNode.valueOf("env:" + environment + ",version:" + WhisperServerVersion.getServerVersion()));

View File

@@ -12,9 +12,9 @@ import com.vdurmont.semver4j.SemverException;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.whispersystems.textsecuregcm.util.logging.UriInfoUtil;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
@@ -71,7 +71,7 @@ public class MetricsRequestEventListener implements RequestEventListener {
if (event.getType() == RequestEvent.Type.FINISHED) {
if (!event.getUriInfo().getMatchedTemplates().isEmpty()) {
final List<Tag> tags = new ArrayList<>(5);
tags.add(Tag.of(PATH_TAG, getPathTemplate(event.getUriInfo())));
tags.add(Tag.of(PATH_TAG, UriInfoUtil.getPathTemplate(event.getUriInfo())));
tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(event.getContainerResponse().getStatus())));
tags.add(Tag.of(TRAFFIC_SOURCE_TAG, trafficSource.name().toLowerCase()));
@@ -149,14 +149,4 @@ public class MetricsRequestEventListener implements RequestEventListener {
}
}
@VisibleForTesting
static String getPathTemplate(final ExtendedUriInfo uriInfo) {
final StringBuilder pathBuilder = new StringBuilder();
for (int i = uriInfo.getMatchedTemplates().size() - 1; i >= 0; i--) {
pathBuilder.append(uriInfo.getMatchedTemplates().get(i).getTemplate());
}
return pathBuilder.toString();
}
}

View File

@@ -10,6 +10,7 @@ import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import java.time.Duration;
@@ -49,7 +50,7 @@ public class PushLatencyManager {
public void recordQueueRead(final UUID accountUuid, final long deviceId, final String userAgent) {
getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).thenAccept(latency -> {
if (latency != null) {
Metrics.timer(TIMER_NAME, UserAgentTagUtil.getUserAgentTags(userAgent)).record(latency, TimeUnit.MILLISECONDS);
Metrics.timer(TIMER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))).record(latency, TimeUnit.MILLISECONDS);
}
});
}

View File

@@ -0,0 +1,74 @@
/*
* This is derived from Coursera's dropwizard datadog reporter.
* https://github.com/coursera/metrics-datadog
*/
package org.whispersystems.textsecuregcm.metrics;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.ScheduledReporter;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import io.dropwizard.metrics.BaseReporterFactory;
import java.util.EnumSet;
import java.util.List;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.coursera.metrics.datadog.DatadogReporter;
import org.coursera.metrics.datadog.DatadogReporter.Expansion;
import org.coursera.metrics.datadog.DefaultMetricNameFormatterFactory;
import org.coursera.metrics.datadog.DynamicTagsCallbackFactory;
import org.coursera.metrics.datadog.MetricNameFormatterFactory;
import org.coursera.metrics.datadog.transport.AbstractTransportFactory;
import org.whispersystems.textsecuregcm.util.HostnameUtil;
@JsonTypeName("signal-datadog")
public class SignalDatadogReporterFactory extends BaseReporterFactory {
@JsonProperty
private List<String> tags = null;
@Valid
@JsonProperty
private DynamicTagsCallbackFactory dynamicTagsCallback = null;
@JsonProperty
private String prefix = null;
@Valid
@NotNull
@JsonProperty
private MetricNameFormatterFactory metricNameFormatter = new DefaultMetricNameFormatterFactory();
@Valid
@NotNull
@JsonProperty
private AbstractTransportFactory transport = null;
private static final EnumSet<Expansion> EXPANSIONS = EnumSet.of(
Expansion.COUNT,
Expansion.MIN,
Expansion.MAX,
Expansion.MEAN,
Expansion.MEDIAN,
Expansion.P75,
Expansion.P95,
Expansion.P99,
Expansion.P999
);
public ScheduledReporter build(MetricRegistry registry) {
return DatadogReporter.forRegistry(registry)
.withTransport(transport.build())
.withHost(HostnameUtil.getLocalHostname())
.withTags(tags)
.withPrefix(prefix)
.withExpansions(EXPANSIONS)
.withMetricNameFormatter(metricNameFormatter.build())
.withDynamicTagCallback(dynamicTagsCallback != null ? dynamicTagsCallback.build() : null)
.filter(getFilter())
.convertDurationsTo(getDurationUnit())
.convertRatesTo(getRateUnit())
.build();
}
}

View File

@@ -110,8 +110,8 @@ public class GCMSender {
Device device = account.get().getDevice(message.getDeviceId()).get();
if (device.getUninstalledFeedbackTimestamp() == 0) {
device.setUninstalledFeedbackTimestamp(Util.todayInMillis());
accountsManager.update(account.get());
accountsManager.updateDevice(account.get(), message.getDeviceId(), d ->
d.setUninstalledFeedbackTimestamp(Util.todayInMillis()));
}
}
@@ -122,15 +122,11 @@ public class GCMSender {
logger.warn(String.format("Actually received 'CanonicalRegistrationId' ::: (canonical=%s), (original=%s)",
result.getCanonicalRegistrationId(), message.getGcmId()));
Optional<Account> account = getAccountForEvent(message);
if (account.isPresent()) {
//noinspection OptionalGetWithoutIsPresent
Device device = account.get().getDevice(message.getDeviceId()).get();
device.setGcmId(result.getCanonicalRegistrationId());
accountsManager.update(account.get());
}
getAccountForEvent(message).ifPresent(account ->
accountsManager.updateDevice(
account,
message.getDeviceId(),
d -> d.setGcmId(result.getCanonicalRegistrationId())));
canonical.mark();
}

View File

@@ -6,34 +6,32 @@ package org.whispersystems.textsecuregcm.sqs;
import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Metrics;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import com.google.common.collect.Iterables;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.SqsConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkServiceException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
public class DirectoryQueue {
public class DirectoryQueue implements Managed {
private static final Logger logger = LoggerFactory.getLogger(DirectoryQueue.class);
@@ -42,75 +40,113 @@ public class DirectoryQueue {
private final Meter clientErrorMeter = metricRegistry.meter(name(DirectoryQueue.class, "clientError"));
private final Timer sendMessageBatchTimer = metricRegistry.timer(name(DirectoryQueue.class, "sendMessageBatch"));
private final List<String> queueUrls;
private final AmazonSQS sqs;
private final List<String> queueUrls;
private final SqsAsyncClient sqs;
public DirectoryQueue(SqsConfiguration sqsConfig) {
final AWSCredentials credentials = new BasicAWSCredentials(sqsConfig.getAccessKey(), sqsConfig.getAccessSecret());
final AWSStaticCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider(credentials);
private final AtomicInteger outstandingRequests = new AtomicInteger();
this.queueUrls = sqsConfig.getQueueUrls();
this.sqs = AmazonSQSClientBuilder.standard().withRegion(sqsConfig.getRegion()).withCredentials(credentialsProvider).build();
}
private enum UpdateAction {
ADD("add"),
DELETE("delete");
@VisibleForTesting
DirectoryQueue(final List<String> queueUrls, final AmazonSQS sqs) {
this.queueUrls = queueUrls;
this.sqs = sqs;
}
private final String action;
public void refreshRegisteredUser(final Account account) {
refreshRegisteredUsers(List.of(account));
}
UpdateAction(final String action) {
this.action = action;
}
public void refreshRegisteredUsers(final List<Account> accounts) {
final List<Pair<Account, String>> accountsAndActions = accounts.stream()
.map(account -> new Pair<>(account, account.isEnabled() && account.isDiscoverableByPhoneNumber() ? "add" : "delete"))
.collect(Collectors.toList());
sendUpdateMessages(accountsAndActions);
}
public void deleteAccount(final Account account) {
sendUpdateMessages(List.of(new Pair<>(account, "delete")));
}
private void sendUpdateMessages(final List<Pair<Account, String>> accountsAndActions) {
for (final String queueUrl : queueUrls) {
for (final List<Pair<Account, String>> partition : Iterables.partition(accountsAndActions, 10)) {
final List<SendMessageBatchRequestEntry> entries = partition.stream().map(pair -> {
final Account account = pair.first();
final String action = pair.second();
return new SendMessageBatchRequestEntry()
.withMessageBody("-")
.withId(UUID.randomUUID().toString())
.withMessageDeduplicationId(UUID.randomUUID().toString())
.withMessageGroupId(account.getNumber())
.withMessageAttributes(Map.of(
"id", new MessageAttributeValue().withDataType("String").withStringValue(account.getNumber()),
"uuid", new MessageAttributeValue().withDataType("String").withStringValue(account.getUuid().toString()),
"action", new MessageAttributeValue().withDataType("String").withStringValue(action)
));
}).collect(Collectors.toList());
final SendMessageBatchRequest sendMessageBatchRequest = new SendMessageBatchRequest()
.withQueueUrl(queueUrl)
.withEntries(entries);
try (final Timer.Context ignored = sendMessageBatchTimer.time()) {
sqs.sendMessageBatch(sendMessageBatchRequest);
} catch (AmazonServiceException ex) {
serviceErrorMeter.mark();
logger.warn("sqs service error: ", ex);
} catch (AmazonClientException ex) {
clientErrorMeter.mark();
logger.warn("sqs client error: ", ex);
} catch (Throwable t) {
logger.warn("sqs unexpected error: ", t);
}
}
public MessageAttributeValue toMessageAttributeValue() {
return MessageAttributeValue.builder().dataType("String").stringValue(action).build();
}
}
public DirectoryQueue(SqsConfiguration sqsConfig) {
StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(AwsBasicCredentials.create(
sqsConfig.getAccessKey(), sqsConfig.getAccessSecret()));
this.queueUrls = sqsConfig.getQueueUrls();
this.sqs = SqsAsyncClient.builder()
.region(Region.of(sqsConfig.getRegion()))
.credentialsProvider(credentialsProvider)
.build();
Metrics.gauge(name(getClass(), "outstandingRequests"), outstandingRequests);
}
@VisibleForTesting
DirectoryQueue(final List<String> queueUrls, final SqsAsyncClient sqs) {
this.queueUrls = queueUrls;
this.sqs = sqs;
}
@Override
public void start() throws Exception {
}
@Override
public void stop() throws Exception {
synchronized (outstandingRequests) {
while (outstandingRequests.get() > 0) {
outstandingRequests.wait();
}
}
sqs.close();
}
public boolean isDiscoverable(final Account account) {
return account.isEnabled() && account.isDiscoverableByPhoneNumber();
}
public void refreshAccount(final Account account) {
sendUpdateMessage(account, isDiscoverable(account) ? UpdateAction.ADD : UpdateAction.DELETE);
}
public void deleteAccount(final Account account) {
sendUpdateMessage(account, UpdateAction.DELETE);
}
private void sendUpdateMessage(final Account account, final UpdateAction action) {
for (final String queueUrl : queueUrls) {
final Timer.Context timerContext = sendMessageBatchTimer.time();
final SendMessageRequest request = SendMessageRequest.builder()
.queueUrl(queueUrl)
.messageBody("-")
.messageDeduplicationId(UUID.randomUUID().toString())
.messageGroupId(account.getNumber())
.messageAttributes(Map.of(
"id", MessageAttributeValue.builder().dataType("String").stringValue(account.getNumber()).build(),
"uuid", MessageAttributeValue.builder().dataType("String").stringValue(account.getUuid().toString()).build(),
"action", action.toMessageAttributeValue()
))
.build();
synchronized (outstandingRequests) {
outstandingRequests.incrementAndGet();
}
sqs.sendMessage(request).whenComplete((response, cause) -> {
try {
if (cause instanceof SdkServiceException) {
serviceErrorMeter.mark();
logger.warn("sqs service error", cause);
} else if (cause instanceof SdkClientException) {
clientErrorMeter.mark();
logger.warn("sqs client error", cause);
} else if (cause != null) {
logger.warn("sqs unexpected error", cause);
}
} finally {
synchronized (outstandingRequests) {
outstandingRequests.decrementAndGet();
outstandingRequests.notifyAll();
}
timerContext.close();
}
});
}
}
}

View File

@@ -5,77 +5,91 @@
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import static io.micrometer.core.instrument.Metrics.counter;
import static io.micrometer.core.instrument.Metrics.timer;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import static com.codahale.metrics.MetricRegistry.name;
import static io.micrometer.core.instrument.Metrics.counter;
import static io.micrometer.core.instrument.Metrics.timer;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchWriteItemResponse;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
public class AbstractDynamoDbStore {
private final DynamoDbClient dynamoDbClient;
private final DynamoDbClient dynamoDbClient;
private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true");
private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false");
private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed"));
private final Timer batchWriteItemsFirstPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "true");
private final Timer batchWriteItemsRetryPass = timer(name(getClass(), "batchWriteItems"), "firstAttempt", "false");
private final Counter batchWriteItemsUnprocessed = counter(name(getClass(), "batchWriteItemsUnprocessed"));
private final Logger logger = LoggerFactory.getLogger(getClass());
private final Logger logger = LoggerFactory.getLogger(getClass());
private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25; // This was arbitrarily chosen and may be entirely too high.
public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25; // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this.
public static final int RESULT_SET_CHUNK_SIZE = 100;
private static final int MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE = 25; // This was arbitrarily chosen and may be entirely too high.
public static final int DYNAMO_DB_MAX_BATCH_SIZE = 25; // This limit comes from Amazon Dynamo DB itself. It will reject batch writes larger than this.
public static final int RESULT_SET_CHUNK_SIZE = 100;
public AbstractDynamoDbStore(final DynamoDbClient dynamoDbClient) {
this.dynamoDbClient = dynamoDbClient;
public AbstractDynamoDbStore(final DynamoDbClient dynamoDbClient) {
this.dynamoDbClient = dynamoDbClient;
}
protected DynamoDbClient db() {
return dynamoDbClient;
}
protected void executeTableWriteItemsUntilComplete(final Map<String, List<WriteRequest>> items) {
AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>();
batchWriteItemsFirstPass.record(
() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder().requestItems(items).build())));
int attemptCount = 0;
while (!outcome.get().unprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) {
batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder()
.requestItems(outcome.get().unprocessedItems())
.build())));
++attemptCount;
}
protected DynamoDbClient db() {
return dynamoDbClient;
if (!outcome.get().unprocessedItems().isEmpty()) {
int totalItems = outcome.get().unprocessedItems().values().stream().mapToInt(List::size).sum();
logger.error(
"Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.",
attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, totalItems);
batchWriteItemsUnprocessed.increment(totalItems);
}
}
protected void executeTableWriteItemsUntilComplete(final Map<String,List<WriteRequest>> items) {
AtomicReference<BatchWriteItemResponse> outcome = new AtomicReference<>();
batchWriteItemsFirstPass.record(() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder().requestItems(items).build())));
int attemptCount = 0;
while (!outcome.get().unprocessedItems().isEmpty() && attemptCount < MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE) {
batchWriteItemsRetryPass.record(() -> outcome.set(dynamoDbClient.batchWriteItem(BatchWriteItemRequest.builder()
.requestItems(outcome.get().unprocessedItems())
.build())));
++attemptCount;
}
if (!outcome.get().unprocessedItems().isEmpty()) {
int totalItems = outcome.get().unprocessedItems().values().stream().mapToInt(List::size).sum();
logger.error("Attempt count ({}) reached max ({}}) before applying all batch writes to dynamo. {} unprocessed items remain.", attemptCount, MAX_ATTEMPTS_TO_SAVE_BATCH_WRITE, totalItems);
batchWriteItemsUnprocessed.increment(totalItems);
}
protected List<Map<String, AttributeValue>> scan(ScanRequest scanRequest, int max) {
return db().scanPaginator(scanRequest)
.items()
.stream()
.limit(max)
.collect(Collectors.toList());
}
static <T> void writeInBatches(final Iterable<T> items, final Consumer<List<T>> action) {
final List<T> batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE);
for (T item : items) {
batch.add(item);
if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) {
action.accept(batch);
batch.clear();
}
}
static <T> void writeInBatches(final Iterable<T> items, final Consumer<List<T>> action) {
final List<T> batch = new ArrayList<>(DYNAMO_DB_MAX_BATCH_SIZE);
for (T item : items) {
batch.add(item);
if (batch.size() == DYNAMO_DB_MAX_BATCH_SIZE) {
action.accept(batch);
batch.clear();
}
}
if (!batch.isEmpty()) {
action.accept(batch);
}
if (!batch.isEmpty()) {
action.accept(batch);
}
}
}

View File

@@ -14,11 +14,19 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.security.auth.Subject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.util.Util;
public class Account implements Principal {
@JsonIgnore
private static final Logger logger = LoggerFactory.getLogger(Account.class);
@JsonIgnore
private UUID uuid;
@@ -58,12 +66,15 @@ public class Account implements Principal {
@JsonProperty("inCds")
private boolean discoverableByPhoneNumber = true;
@JsonProperty("_ddbV")
private int dynamoDbMigrationVersion;
@JsonIgnore
private Device authenticatedDevice;
@JsonProperty
private int version;
@JsonIgnore
private boolean stale;
public Account() {}
@VisibleForTesting
@@ -75,47 +86,68 @@ public class Account implements Principal {
}
public Optional<Device> getAuthenticatedDevice() {
requireNotStale();
return Optional.ofNullable(authenticatedDevice);
}
public void setAuthenticatedDevice(Device device) {
requireNotStale();
this.authenticatedDevice = device;
}
public UUID getUuid() {
// this is the one method that may be called on a stale account
return uuid;
}
public void setUuid(UUID uuid) {
requireNotStale();
this.uuid = uuid;
}
public void setNumber(String number) {
requireNotStale();
this.number = number;
}
public String getNumber() {
requireNotStale();
return number;
}
public void addDevice(Device device) {
requireNotStale();
this.devices.remove(device);
this.devices.add(device);
}
public void removeDevice(long deviceId) {
requireNotStale();
this.devices.remove(new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "NA", 0, null));
}
public Set<Device> getDevices() {
requireNotStale();
return devices;
}
public Optional<Device> getMasterDevice() {
requireNotStale();
return getDevice(Device.MASTER_ID);
}
public Optional<Device> getDevice(long deviceId) {
requireNotStale();
for (Device device : devices) {
if (device.getId() == deviceId) {
return Optional.of(device);
@@ -126,36 +158,58 @@ public class Account implements Principal {
}
public boolean isGroupsV2Supported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(Device::isGroupsV2Supported);
}
public boolean isStorageSupported() {
requireNotStale();
return devices.stream().anyMatch(device -> device.getCapabilities() != null && device.getCapabilities().isStorage());
}
public boolean isTransferSupported() {
requireNotStale();
return getMasterDevice().map(Device::getCapabilities).map(Device.DeviceCapabilities::isTransfer).orElse(false);
}
public boolean isGv1MigrationSupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isGv1Migration());
}
public boolean isSenderKeySupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isSenderKey());
}
public boolean isAnnouncementGroupSupported() {
requireNotStale();
return devices.stream()
.filter(Device::isEnabled)
.allMatch(device -> device.getCapabilities() != null && device.getCapabilities().isAnnouncementGroup());
}
public boolean isEnabled() {
requireNotStale();
return getMasterDevice().map(Device::isEnabled).orElse(false);
}
public long getNextDeviceId() {
requireNotStale();
long highestDevice = Device.MASTER_ID;
for (Device device : devices) {
@@ -170,6 +224,8 @@ public class Account implements Principal {
}
public int getEnabledDeviceCount() {
requireNotStale();
int count = 0;
for (Device device : devices) {
@@ -180,22 +236,32 @@ public class Account implements Principal {
}
public boolean isRateLimited() {
requireNotStale();
return true;
}
public Optional<String> getRelay() {
requireNotStale();
return Optional.empty();
}
public void setIdentityKey(String identityKey) {
requireNotStale();
this.identityKey = identityKey;
}
public String getIdentityKey() {
requireNotStale();
return identityKey;
}
public long getLastSeen() {
requireNotStale();
long lastSeen = 0;
for (Device device : devices) {
@@ -208,78 +274,139 @@ public class Account implements Principal {
}
public Optional<String> getCurrentProfileVersion() {
requireNotStale();
return Optional.ofNullable(currentProfileVersion);
}
public void setCurrentProfileVersion(String currentProfileVersion) {
requireNotStale();
this.currentProfileVersion = currentProfileVersion;
}
public String getProfileName() {
requireNotStale();
return name;
}
public void setProfileName(String name) {
requireNotStale();
this.name = name;
}
public String getAvatar() {
requireNotStale();
return avatar;
}
public void setAvatar(String avatar) {
requireNotStale();
this.avatar = avatar;
}
public void setPin(String pin) {
requireNotStale();
this.pin = pin;
}
public void setRegistrationLockFromAttributes(final AccountAttributes attributes) {
if (!Util.isEmpty(attributes.getPin())) {
setPin(attributes.getPin());
} else if (!Util.isEmpty(attributes.getRegistrationLock())) {
AuthenticationCredentials credentials = new AuthenticationCredentials(attributes.getRegistrationLock());
setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
} else {
setPin(null);
setRegistrationLock(null, null);
}
}
public void setRegistrationLock(String registrationLock, String registrationLockSalt) {
requireNotStale();
this.registrationLock = registrationLock;
this.registrationLockSalt = registrationLockSalt;
}
public StoredRegistrationLock getRegistrationLock() {
requireNotStale();
return new StoredRegistrationLock(Optional.ofNullable(registrationLock), Optional.ofNullable(registrationLockSalt), Optional.ofNullable(pin), getLastSeen());
}
public Optional<byte[]> getUnidentifiedAccessKey() {
requireNotStale();
return Optional.ofNullable(unidentifiedAccessKey);
}
public void setUnidentifiedAccessKey(byte[] unidentifiedAccessKey) {
requireNotStale();
this.unidentifiedAccessKey = unidentifiedAccessKey;
}
public boolean isUnrestrictedUnidentifiedAccess() {
requireNotStale();
return unrestrictedUnidentifiedAccess;
}
public void setUnrestrictedUnidentifiedAccess(boolean unrestrictedUnidentifiedAccess) {
requireNotStale();
this.unrestrictedUnidentifiedAccess = unrestrictedUnidentifiedAccess;
}
public boolean isFor(AmbiguousIdentifier identifier) {
requireNotStale();
if (identifier.hasUuid()) return identifier.getUuid().equals(uuid);
else if (identifier.hasNumber()) return identifier.getNumber().equals(number);
else throw new AssertionError();
}
public boolean isDiscoverableByPhoneNumber() {
requireNotStale();
return this.discoverableByPhoneNumber;
}
public void setDiscoverableByPhoneNumber(final boolean discoverableByPhoneNumber) {
requireNotStale();
this.discoverableByPhoneNumber = discoverableByPhoneNumber;
}
public int getDynamoDbMigrationVersion() {
return dynamoDbMigrationVersion;
public int getVersion() {
requireNotStale();
return version;
}
public void setDynamoDbMigrationVersion(int dynamoDbMigrationVersion) {
this.dynamoDbMigrationVersion = dynamoDbMigrationVersion;
public void setVersion(int version) {
requireNotStale();
this.version = version;
}
public void markStale() {
stale = true;
}
private void requireNotStale() {
assert !stale;
//noinspection ConstantConditions
if (stale) {
logger.error("Accessor called on stale account", new RuntimeException());
}
}
// Principal implementation

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public class AccountCrawlChunk {
private final List<Account> accounts;
@Nullable
private final UUID lastUuid;
public AccountCrawlChunk(final List<Account> accounts, @Nullable final UUID lastUuid) {
this.accounts = accounts;
this.lastUuid = lastUuid;
}
public List<Account> getAccounts() {
return accounts;
}
public Optional<UUID> getLastUuid() {
return Optional.ofNullable(lastUuid);
}
}

View File

@@ -4,22 +4,21 @@
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import io.dropwizard.lifecycle.Managed;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AccountDatabaseCrawler implements Managed, Runnable {
@@ -38,6 +37,8 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
private final AccountDatabaseCrawlerCache cache;
private final List<AccountDatabaseCrawlerListener> listeners;
private final DynamicConfigurationManager dynamicConfigurationManager;
private AtomicBoolean running = new AtomicBoolean(false);
private boolean finished;
@@ -45,7 +46,8 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
AccountDatabaseCrawlerCache cache,
List<AccountDatabaseCrawlerListener> listeners,
int chunkSize,
long chunkIntervalMs)
long chunkIntervalMs,
DynamicConfigurationManager dynamicConfigurationManager)
{
this.accounts = accounts;
this.chunkSize = chunkSize;
@@ -53,6 +55,8 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
this.workerId = UUID.randomUUID().toString();
this.cache = cache;
this.listeners = listeners;
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
@Override
@@ -93,14 +97,15 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
@VisibleForTesting
public boolean doPeriodicWork() {
if (cache.claimActiveWork(workerId, WORKER_TTL_MS)) {
try {
long startTimeMs = System.currentTimeMillis();
final long startTimeMs = System.currentTimeMillis();
processChunk();
if (cache.isAccelerated()) {
return true;
}
long endTimeMs = System.currentTimeMillis();
long sleepIntervalMs = chunkIntervalMs - (endTimeMs - startTimeMs);
final long endTimeMs = System.currentTimeMillis();
final long sleepIntervalMs = chunkIntervalMs - (endTimeMs - startTimeMs);
if (sleepIntervalMs > 0) sleepWhileRunning(sleepIntervalMs);
} finally {
cache.releaseActiveWork(workerId);
@@ -110,42 +115,67 @@ public class AccountDatabaseCrawler implements Managed, Runnable {
}
private void processChunk() {
Optional<UUID> fromUuid = cache.getLastUuid();
final boolean useDynamo = dynamicConfigurationManager.getConfiguration()
.getAccountsDynamoDbMigrationConfiguration()
.isDynamoCrawlerEnabled();
if (!fromUuid.isPresent()) {
listeners.stream().filter(listener -> listener instanceof DirectoryReconciler)
.forEach(reconciler -> ((DirectoryReconciler) reconciler).setUseV3Endpoints(useDynamo));
final Optional<UUID> fromUuid = getLastUuid(useDynamo);
if (fromUuid.isEmpty()) {
listeners.forEach(AccountDatabaseCrawlerListener::onCrawlStart);
}
List<Account> chunkAccounts = readChunk(fromUuid, chunkSize);
final AccountCrawlChunk chunkAccounts = readChunk(fromUuid, chunkSize, useDynamo);
if (chunkAccounts.isEmpty()) {
if (chunkAccounts.getAccounts().isEmpty()) {
logger.info("Finished crawl");
listeners.forEach(listener -> listener.onCrawlEnd(fromUuid));
cache.setLastUuid(Optional.empty());
cacheLastUuid(Optional.empty(), useDynamo);
cache.setAccelerated(false);
} else {
try {
for (AccountDatabaseCrawlerListener listener : listeners) {
listener.timeAndProcessCrawlChunk(fromUuid, chunkAccounts);
listener.timeAndProcessCrawlChunk(fromUuid, chunkAccounts.getAccounts());
}
cache.setLastUuid(Optional.of(chunkAccounts.get(chunkAccounts.size() - 1).getUuid()));
cacheLastUuid(chunkAccounts.getLastUuid(), useDynamo);
} catch (AccountDatabaseCrawlerRestartException e) {
cache.setLastUuid(Optional.empty());
cacheLastUuid(Optional.empty(), useDynamo);
cache.setAccelerated(false);
}
}
}
private List<Account> readChunk(Optional<UUID> fromUuid, int chunkSize) {
private AccountCrawlChunk readChunk(Optional<UUID> fromUuid, int chunkSize, boolean useDynamo) {
try (Timer.Context timer = readChunkTimer.time()) {
if (fromUuid.isPresent()) {
return accounts.getAllFrom(fromUuid.get(), chunkSize);
return useDynamo
? accounts.getAllFromDynamo(fromUuid.get(), chunkSize)
: accounts.getAllFrom(fromUuid.get(), chunkSize);
}
return accounts.getAllFrom(chunkSize);
return useDynamo
? accounts.getAllFromDynamo(chunkSize)
: accounts.getAllFrom(chunkSize);
}
}
private Optional<UUID> getLastUuid(final boolean useDynamo) {
if (useDynamo) {
return cache.getLastUuidDynamo();
} else {
return cache.getLastUuid();
}
}
private void cacheLastUuid(final Optional<UUID> lastUuid, final boolean useDynamo) {
if (useDynamo) {
cache.setLastUuidDynamo(lastUuid);
} else {
cache.setLastUuid(lastUuid);
}
}

View File

@@ -21,6 +21,8 @@ public class AccountDatabaseCrawlerCache {
private static final String LAST_UUID_KEY = "account_database_crawler_cache_last_uuid";
private static final String ACCELERATE_KEY = "account_database_crawler_cache_accelerate";
private static final String LAST_UUID_DYNAMO_KEY = "account_database_crawler_cache_last_uuid_dynamo";
private static final long LAST_NUMBER_TTL_MS = 86400_000L;
private final FaultTolerantRedisCluster cacheCluster;
@@ -66,4 +68,19 @@ public class AccountDatabaseCrawlerCache {
}
}
public Optional<UUID> getLastUuidDynamo() {
final String lastUuidString = cacheCluster.withCluster(connection -> connection.sync().get(LAST_UUID_DYNAMO_KEY));
if (lastUuidString == null) return Optional.empty();
else return Optional.of(UUID.fromString(lastUuidString));
}
public void setLastUuidDynamo(Optional<UUID> lastUuid) {
if (lastUuid.isPresent()) {
cacheCluster.useCluster(connection -> connection.sync().psetex(LAST_UUID_DYNAMO_KEY, LAST_NUMBER_TTL_MS, lastUuid.get().toString()));
} else {
cacheCluster.useCluster(connection -> connection.sync().del(LAST_UUID_DYNAMO_KEY));
}
}
}

View File

@@ -7,7 +7,7 @@ public interface AccountStore {
boolean create(Account account);
void update(Account account);
void update(Account account) throws ContestedOptimisticLockException;
Optional<Account> get(String number);

View File

@@ -9,9 +9,11 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.codahale.metrics.Timer.Context;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.jdbi.v3.core.transaction.TransactionIsolationLevel;
@@ -21,10 +23,11 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
public class Accounts implements AccountStore {
public static final String ID = "id";
public static final String UID = "uuid";
public static final String ID = "id";
public static final String UID = "uuid";
public static final String NUMBER = "number";
public static final String DATA = "data";
public static final String DATA = "data";
public static final String VERSION = "version";
private static final ObjectMapper mapper = SystemMapper.getMapper();
@@ -49,15 +52,19 @@ public class Accounts implements AccountStore {
public boolean create(Account account) {
return database.with(jdbi -> jdbi.inTransaction(TransactionIsolationLevel.SERIALIZABLE, handle -> {
try (Timer.Context ignored = createTimer.time()) {
UUID uuid = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET data = EXCLUDED.data RETURNING uuid")
.bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.mapTo(UUID.class)
.findOnly();
final Map<String, Object> resultMap = handle.createQuery("INSERT INTO accounts (" + NUMBER + ", " + UID + ", " + DATA + ") VALUES (:number, :uuid, CAST(:data AS json)) ON CONFLICT(number) DO UPDATE SET " + DATA + " = EXCLUDED.data, " + VERSION + " = accounts.version + 1 RETURNING uuid, version")
.bind("number", account.getNumber())
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.mapToMap()
.findOnly();
final UUID uuid = (UUID) resultMap.get(UID);
final int version = (int) resultMap.get(VERSION);
boolean isNew = uuid.equals(account.getUuid());
account.setUuid(uuid);
account.setVersion(version);
return isNew;
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
@@ -66,13 +73,23 @@ public class Accounts implements AccountStore {
}
@Override
public void update(Account account) {
public void update(Account account) throws ContestedOptimisticLockException {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = updateTimer.time()) {
handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json) WHERE " + UID + " = :uuid")
final int newVersion = account.getVersion() + 1;
int rowsModified = handle.createUpdate("UPDATE accounts SET " + DATA + " = CAST(:data AS json), " + VERSION + " = :newVersion WHERE " + UID + " = :uuid AND " + VERSION + " = :version")
.bind("uuid", account.getUuid())
.bind("data", mapper.writeValueAsString(account))
.bind("version", account.getVersion())
.bind("newVersion", newVersion)
.execute();
if (rowsModified == 0) {
throw new ContestedOptimisticLockException();
}
account.setVersion(newVersion);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
@@ -103,20 +120,21 @@ public class Accounts implements AccountStore {
}));
}
public List<Account> getAllFrom(UUID from, int length) {
return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context ignored = getAllFromOffsetTimer.time()) {
public AccountCrawlChunk getAllFrom(UUID from, int length) {
final List<Account> accounts = database.with(jdbi -> jdbi.withHandle(handle -> {
try (Context ignored = getAllFromOffsetTimer.time()) {
return handle.createQuery("SELECT * FROM accounts WHERE " + UID + " > :from ORDER BY " + UID + " LIMIT :limit")
.bind("from", from)
.bind("limit", length)
.mapTo(Account.class)
.list();
.bind("from", from)
.bind("limit", length)
.mapTo(Account.class)
.list();
}
}));
return buildChunkForAccounts(accounts);
}
public List<Account> getAllFrom(int length) {
return database.with(jdbi -> jdbi.withHandle(handle -> {
public AccountCrawlChunk getAllFrom(int length) {
final List<Account> accounts = database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context ignored = getAllFromTimer.time()) {
return handle.createQuery("SELECT * FROM accounts ORDER BY " + UID + " LIMIT :limit")
.bind("limit", length)
@@ -124,6 +142,12 @@ public class Accounts implements AccountStore {
.list();
}
}));
return buildChunkForAccounts(accounts);
}
private AccountCrawlChunk buildChunkForAccounts(final List<Account> accounts) {
return new AccountCrawlChunk(accounts, accounts.isEmpty() ? null : accounts.get(accounts.size() - 1).getUuid());
}
@Override

View File

@@ -1,3 +1,7 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
@@ -26,15 +30,20 @@ import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.CancellationReason;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
import software.amazon.awssdk.services.dynamodb.model.Delete;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.Put;
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
import software.amazon.awssdk.services.dynamodb.model.ReturnValuesOnConditionCheckFailure;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException;
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemResponse;
public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountStore {
@@ -44,8 +53,8 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
static final String ATTR_ACCOUNT_E164 = "P";
// account, serialized to JSON
static final String ATTR_ACCOUNT_DATA = "D";
static final String ATTR_MIGRATION_VERSION = "V";
// internal version for optimistic locking
static final String ATTR_VERSION = "V";
private final DynamoDbClient client;
private final DynamoDbAsyncClient asyncClient;
@@ -62,6 +71,8 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
private static final Timer UPDATE_TIMER = Metrics.timer(name(AccountsDynamoDb.class, "update"));
private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(AccountsDynamoDb.class, "getByNumber"));
private static final Timer GET_BY_UUID_TIMER = Metrics.timer(name(AccountsDynamoDb.class, "getByUuid"));
private static final Timer GET_ALL_FROM_START_TIMER = Metrics.timer(name(AccountsDynamoDb.class, "getAllFrom"));
private static final Timer GET_ALL_FROM_OFFSET_TIMER = Metrics.timer(name(AccountsDynamoDb.class, "getAllFromOffset"));
private static final Timer DELETE_TIMER = Metrics.timer(name(AccountsDynamoDb.class, "delete"));
private final Logger logger = LoggerFactory.getLogger(AccountsDynamoDb.class);
@@ -115,11 +126,19 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
ByteBuffer actualAccountUuid = phoneNumberConstraintCancellationReason.item().get(KEY_ACCOUNT_UUID).b().asByteBuffer();
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final int version = get(account.getUuid()).get().getVersion();
account.setVersion(version);
update(account);
return false;
}
if ("TransactionConflict".equals(accountCancellationReason.code())) {
// this should only happen during concurrent update()s for an account migration
throw new ContestedOptimisticLockException();
}
// this shouldnt happen
throw new RuntimeException("could not create account: " + extractCancellationReasonCodes(e));
}
@@ -139,7 +158,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(uuid),
ATTR_ACCOUNT_E164, AttributeValues.fromString(account.getNumber()),
ATTR_ACCOUNT_DATA, AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
ATTR_MIGRATION_VERSION, AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))
ATTR_VERSION, AttributeValues.fromInt(account.getVersion())))
.build())
.build();
}
@@ -165,28 +184,44 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
}
@Override
public void update(Account account) {
public void update(Account account) throws ContestedOptimisticLockException {
UPDATE_TIMER.record(() -> {
UpdateItemRequest updateItemRequest;
try {
updateItemRequest = UpdateItemRequest.builder()
.tableName(accountsTableName)
.key(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(account.getUuid())))
.updateExpression("SET #data = :data, #version = :version")
.conditionExpression("attribute_exists(#number)")
.updateExpression("SET #data = :data ADD #version :version_increment")
.conditionExpression("attribute_exists(#number) AND #version = :version")
.expressionAttributeNames(Map.of("#number", ATTR_ACCOUNT_E164,
"#data", ATTR_ACCOUNT_DATA,
"#version", ATTR_MIGRATION_VERSION))
"#version", ATTR_VERSION))
.expressionAttributeValues(Map.of(
":data", AttributeValues.fromByteArray(SystemMapper.getMapper().writeValueAsBytes(account)),
":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion())))
":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1)))
.returnValues(ReturnValue.UPDATED_NEW)
.build();
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
client.updateItem(updateItemRequest);
try {
UpdateItemResponse response = client.updateItem(updateItemRequest);
account.setVersion(AttributeValues.getInt(response.attributes(), "V", account.getVersion() + 1));
} catch (final TransactionConflictException e) {
throw new ContestedOptimisticLockException();
} catch (final ConditionalCheckFailedException e) {
// the exception doesnt give details about which condition failed,
// but we can infer it was an optimistic locking failure if the UUID is known
throw get(account.getUuid()).isPresent() ? new ContestedOptimisticLockException() : e;
}
});
}
@@ -230,6 +265,33 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
});
}
public AccountCrawlChunk getAllFrom(final UUID from, final int maxCount, final int pageSize) {
final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder()
.limit(pageSize)
.exclusiveStartKey(Map.of(KEY_ACCOUNT_UUID, AttributeValues.fromUUID(from)));
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_OFFSET_TIMER);
}
public AccountCrawlChunk getAllFromStart(final int maxCount, final int pageSize) {
final ScanRequest.Builder scanRequestBuilder = ScanRequest.builder()
.limit(pageSize);
return scanForChunk(scanRequestBuilder, maxCount, GET_ALL_FROM_START_TIMER);
}
private AccountCrawlChunk scanForChunk(final ScanRequest.Builder scanRequestBuilder, final int maxCount, final Timer timer) {
scanRequestBuilder.tableName(accountsTableName);
final List<Account> accounts = timer.record(() -> scan(scanRequestBuilder.build(), maxCount)
.stream()
.map(AccountsDynamoDb::fromItem)
.collect(Collectors.toList()));
return new AccountCrawlChunk(accounts, accounts.size() > 0 ? accounts.get(accounts.size() - 1).getUuid() : null);
}
private void delete(UUID uuid, boolean saveInDeletedAccountsTable) {
if (saveInDeletedAccountsTable) {
@@ -309,9 +371,9 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
.conditionExpression("attribute_not_exists(#uuid) OR (attribute_exists(#uuid) AND #version < :version)")
.expressionAttributeNames(Map.of(
"#uuid", KEY_ACCOUNT_UUID,
"#version", ATTR_MIGRATION_VERSION))
"#version", ATTR_VERSION))
.expressionAttributeValues(Map.of(
":version", AttributeValues.fromInt(account.getDynamoDbMigrationVersion()))));
":version", AttributeValues.fromInt(account.getVersion()))));
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, accountPut).build();
@@ -332,7 +394,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
return;
}
try {
migrationRetryAccounts.put(account.getUuid());
migrationRetryAccounts.put(account.getUuid());
} catch (final Exception e) {
logger.error("Could not store account {}", account.getUuid());
}
@@ -361,6 +423,7 @@ public class AccountsDynamoDb extends AbstractDynamoDbStore implements AccountSt
Account account = SystemMapper.getMapper().readValue(item.get(ATTR_ACCOUNT_DATA).b().asByteArray(), Account.class);
account.setNumber(item.get(ATTR_ACCOUNT_E164).s());
account.setUuid(UUIDUtil.fromByteBuffer(item.get(KEY_ACCOUNT_UUID).b().asByteBuffer()));
account.setVersion(Integer.parseInt(item.get(ATTR_VERSION).n()));
return account;

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
@@ -17,9 +18,9 @@ import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
@@ -27,12 +28,17 @@ import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import net.logstash.logback.argument.StructuredArguments;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
@@ -41,7 +47,6 @@ import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException;
public class AccountsManager {
@@ -52,11 +57,16 @@ public class AccountsManager {
private static final Timer getByUuidTimer = metricRegistry.timer(name(AccountsManager.class, "getByUuid" ));
private static final Timer deleteTimer = metricRegistry.timer(name(AccountsManager.class, "delete"));
// TODO Remove this meter when external dependencies have been resolved
// Note that this is deliberately namespaced to `AccountController` for metric continuity.
private static final Meter newUserMeter = metricRegistry.meter(name(AccountController.class, "brand_new_user"));
private static final Timer redisSetTimer = metricRegistry.timer(name(AccountsManager.class, "redisSet" ));
private static final Timer redisNumberGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisNumberGet"));
private static final Timer redisUuidGetTimer = metricRegistry.timer(name(AccountsManager.class, "redisUuidGet" ));
private static final Timer redisDeleteTimer = metricRegistry.timer(name(AccountsManager.class, "redisDelete" ));
private static final String CREATE_COUNTER_NAME = name(AccountsManager.class, "createCounter");
private static final String DELETE_COUNTER_NAME = name(AccountsManager.class, "deleteCounter");
private static final String DELETE_ERROR_COUNTER_NAME = name(AccountsManager.class, "deleteError");
private static final String COUNTRY_CODE_TAG_NAME = "country";
@@ -71,11 +81,13 @@ public class AccountsManager {
private final Accounts accounts;
private final AccountsDynamoDb accountsDynamoDb;
private final FaultTolerantRedisCluster cacheCluster;
private final DeletedAccountsManager deletedAccountsManager;
private final DirectoryQueue directoryQueue;
private final KeysDynamoDb keysDynamoDb;
private final MessagesManager messagesManager;
private final UsernamesManager usernamesManager;
private final ProfilesManager profilesManager;
private final StoredVerificationCodeManager pendingAccounts;
private final SecureStorageClient secureStorageClient;
private final SecureBackupClient secureBackupClient;
private final ObjectMapper mapper;
@@ -97,33 +109,65 @@ public class AccountsManager {
}
}
public AccountsManager(Accounts accounts, AccountsDynamoDb accountsDynamoDb, FaultTolerantRedisCluster cacheCluster, final DirectoryQueue directoryQueue,
public AccountsManager(Accounts accounts, AccountsDynamoDb accountsDynamoDb, FaultTolerantRedisCluster cacheCluster,
final DeletedAccountsManager deletedAccountsManager,
final DirectoryQueue directoryQueue,
final KeysDynamoDb keysDynamoDb, final MessagesManager messagesManager, final UsernamesManager usernamesManager,
final ProfilesManager profilesManager, final SecureStorageClient secureStorageClient,
final ProfilesManager profilesManager,
final StoredVerificationCodeManager pendingAccounts,
final SecureStorageClient secureStorageClient,
final SecureBackupClient secureBackupClient,
final ExperimentEnrollmentManager experimentEnrollmentManager, final DynamicConfigurationManager dynamicConfigurationManager) {
final ExperimentEnrollmentManager experimentEnrollmentManager,
final DynamicConfigurationManager dynamicConfigurationManager) {
this.accounts = accounts;
this.accountsDynamoDb = accountsDynamoDb;
this.cacheCluster = cacheCluster;
this.deletedAccountsManager = deletedAccountsManager;
this.directoryQueue = directoryQueue;
this.keysDynamoDb = keysDynamoDb;
this.messagesManager = messagesManager;
this.usernamesManager = usernamesManager;
this.profilesManager = profilesManager;
this.pendingAccounts = pendingAccounts;
this.secureStorageClient = secureStorageClient;
this.secureBackupClient = secureBackupClient;
this.mapper = SystemMapper.getMapper();
this.migrationComparisonMapper = mapper.copy();
migrationComparisonMapper.addMixIn(Account.class, AccountComparisonMixin.class);
migrationComparisonMapper.addMixIn(Device.class, DeviceComparisonMixin.class);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.experimentEnrollmentManager = experimentEnrollmentManager;
}
public boolean create(Account account) {
public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes) {
try (Timer.Context ignored = createTimer.time()) {
Optional<Account> maybeExistingAccount = get(number);
Device device = new Device();
device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
Account account = new Account();
account.setNumber(number);
account.setUuid(UUID.randomUUID());
account.addDevice(device);
account.setRegistrationLockFromAttributes(accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(accountAttributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(accountAttributes.isDiscoverableByPhoneNumber());
final UUID originalUuid = account.getUuid();
boolean freshUser = databaseCreate(account);
@@ -161,29 +205,120 @@ public class AccountsManager {
redisSet(account);
return freshUser;
final Tags tags;
if (freshUser) {
tags = Tags.of("type", "new");
} else {
tags = Tags.of("type", "reregister");
}
Metrics.counter(CREATE_COUNTER_NAME, tags).increment();
if (!account.isDiscoverableByPhoneNumber()) {
// The newly-created account has explicitly opted out of discoverability
directoryQueue.deleteAccount(account);
}
maybeExistingAccount.ifPresent(definitelyExistingAccount -> {
messagesManager.clear(definitelyExistingAccount.getUuid());
keysDynamoDb.delete(definitelyExistingAccount);
});
pendingAccounts.remove(number);
return account;
}
}
public void update(Account account) {
public Account update(Account account, Consumer<Account> updater) {
final boolean wasDiscoverableBeforeUpdate = directoryQueue.isDiscoverable(account);
final Account updatedAccount;
try (Timer.Context ignored = updateTimer.time()) {
account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1);
redisSet(account);
databaseUpdate(account);
updater.accept(account);
{
// optimistically increment version
final int originalVersion = account.getVersion();
account.setVersion(originalVersion + 1);
redisSet(account);
account.setVersion(originalVersion);
}
final UUID uuid = account.getUuid();
updatedAccount = updateWithRetries(account, updater, this::databaseUpdate, () -> databaseGet(uuid).get());
if (dynamoWriteEnabled()) {
runSafelyAndRecordMetrics(() -> {
try {
dynamoUpdate(account);
} catch (final ConditionalCheckFailedException e) {
dynamoCreate(account);
final Optional<Account> dynamoAccount = dynamoGet(uuid);
if (dynamoAccount.isPresent()) {
updater.accept(dynamoAccount.get());
Account dynamoUpdatedAccount = updateWithRetries(dynamoAccount.get(),
updater,
this::dynamoUpdate,
() -> dynamoGet(uuid).get());
return Optional.of(dynamoUpdatedAccount);
}
return true;
}, Optional.of(account.getUuid()), true,
(databaseSuccess, dynamoSuccess) -> Optional.empty(), // both values are always true
return Optional.empty();
}, Optional.of(uuid), Optional.of(updatedAccount),
this::compareAccounts,
"update");
}
// set the cache again, so that all updates are coalesced
redisSet(updatedAccount);
}
final boolean isDiscoverableAfterUpdate = directoryQueue.isDiscoverable(updatedAccount);
if (wasDiscoverableBeforeUpdate != isDiscoverableAfterUpdate) {
directoryQueue.refreshAccount(updatedAccount);
}
return updatedAccount;
}
private Account updateWithRetries(Account account, Consumer<Account> updater, Consumer<Account> persister, Supplier<Account> retriever) {
final int maxTries = 10;
int tries = 0;
while (tries < maxTries) {
try {
persister.accept(account);
final Account updatedAccount;
try {
updatedAccount = mapper.readValue(mapper.writeValueAsBytes(account), Account.class);
updatedAccount.setUuid(account.getUuid());
} catch (final IOException e) {
// this should really, truly, never happen
throw new IllegalArgumentException(e);
}
account.markStale();
return updatedAccount;
} catch (final ContestedOptimisticLockException e) {
tries++;
account = retriever.get();
updater.accept(account);
}
}
throw new OptimisticLockRetryLimitExceededException();
}
public Account updateDevice(Account account, long deviceId, Consumer<Device> deviceUpdater) {
return update(account, a -> a.getDevice(deviceId).ifPresent(deviceUpdater));
}
public Optional<Account> get(AmbiguousIdentifier identifier) {
@@ -229,15 +364,27 @@ public class AccountsManager {
}
public List<Account> getAllFrom(int length) {
public AccountCrawlChunk getAllFrom(int length) {
return accounts.getAllFrom(length);
}
public List<Account> getAllFrom(UUID uuid, int length) {
public AccountCrawlChunk getAllFrom(UUID uuid, int length) {
return accounts.getAllFrom(uuid, length);
}
public void delete(final Account account, final DeletionReason deletionReason) {
public AccountCrawlChunk getAllFromDynamo(int length) {
final int maxPageSize = dynamicConfigurationManager.getConfiguration().getAccountsDynamoDbMigrationConfiguration()
.getDynamoCrawlerScanPageSize();
return accountsDynamoDb.getAllFromStart(length, maxPageSize);
}
public AccountCrawlChunk getAllFromDynamo(UUID uuid, int length) {
final int maxPageSize = dynamicConfigurationManager.getConfiguration().getAccountsDynamoDbMigrationConfiguration()
.getDynamoCrawlerScanPageSize();
return accountsDynamoDb.getAllFrom(uuid, length, maxPageSize);
}
public void delete(final Account account, final DeletionReason deletionReason) throws InterruptedException {
try (final Timer.Context ignored = deleteTimer.time()) {
final CompletableFuture<Void> deleteStorageServiceDataFuture = secureStorageClient.deleteStoredData(account.getUuid());
final CompletableFuture<Void> deleteBackupServiceDataFuture = secureBackupClient.deleteBackups(account.getUuid());
@@ -263,7 +410,9 @@ public class AccountsManager {
}
}
} catch (final Exception e) {
deletedAccountsManager.put(account.getUuid(), account.getNumber());
} catch (final RuntimeException | InterruptedException e) {
logger.warn("Failed to delete account", e);
Metrics.counter(DELETE_ERROR_COUNTER_NAME,
@@ -427,6 +576,10 @@ public class AccountsManager {
return Optional.of("number");
}
if (databaseAccount.getVersion() != dynamoAccount.getVersion()) {
return Optional.of("version");
}
if (!Objects.equals(databaseAccount.getIdentityKey(), dynamoAccount.getIdentityKey())) {
return Optional.of("identityKey");
}
@@ -548,13 +701,6 @@ public class AccountsManager {
.collect(Collectors.joining(" -> "));
}
private static abstract class AccountComparisonMixin extends Account {
@JsonIgnore
private int dynamoDbMigrationVersion;
}
private static abstract class DeviceComparisonMixin extends Device {
@JsonIgnore

View File

@@ -0,0 +1,16 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class ChunkProcessingFailedException extends Exception {
public ChunkProcessingFailedException(String message) {
super(message);
}
public ChunkProcessingFailedException(Exception cause) {
super(cause);
}
}

View File

@@ -0,0 +1,13 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class ContestedOptimisticLockException extends RuntimeException {
public ContestedOptimisticLockException() {
super(null, null, true, false);
}
}

View File

@@ -0,0 +1,132 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.BatchGetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.BatchGetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.KeysAndAttributes;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
public class DeletedAccounts extends AbstractDynamoDbStore {
// e164, primary key
static final String KEY_ACCOUNT_E164 = "P";
static final String ATTR_ACCOUNT_UUID = "U";
static final String ATTR_EXPIRES = "E";
static final String ATTR_NEEDS_CDS_RECONCILIATION = "R";
static final Duration TIME_TO_LIVE = Duration.ofDays(30);
// Note that this limit is imposed by DynamoDB itself; going above 100 will result in errors
static final int GET_BATCH_SIZE = 100;
private final String tableName;
private final String needsReconciliationIndexName;
public DeletedAccounts(final DynamoDbClient dynamoDb, final String tableName, final String needsReconciliationIndexName) {
super(dynamoDb);
this.tableName = tableName;
this.needsReconciliationIndexName = needsReconciliationIndexName;
}
void put(UUID uuid, String e164) {
db().putItem(PutItemRequest.builder()
.tableName(tableName)
.item(Map.of(
KEY_ACCOUNT_E164, AttributeValues.fromString(e164),
ATTR_ACCOUNT_UUID, AttributeValues.fromUUID(uuid),
ATTR_EXPIRES, AttributeValues.fromLong(Instant.now().plus(TIME_TO_LIVE).getEpochSecond()),
ATTR_NEEDS_CDS_RECONCILIATION, AttributeValues.fromInt(1)))
.build());
}
List<Pair<UUID, String>> listAccountsToReconcile(final int max) {
final ScanRequest scanRequest = ScanRequest.builder()
.tableName(tableName)
.indexName(needsReconciliationIndexName)
.limit(max)
.build();
return scan(scanRequest, max)
.stream()
.map(item -> new Pair<>(
AttributeValues.getUUID(item, ATTR_ACCOUNT_UUID, null),
AttributeValues.getString(item, KEY_ACCOUNT_E164, null)))
.collect(Collectors.toList());
}
Set<String> getAccountsNeedingReconciliation(final Collection<String> e164s) {
final Queue<Map<String, AttributeValue>> pendingKeys = e164s.stream()
.map(e164 -> Map.of(KEY_ACCOUNT_E164, AttributeValues.fromString(e164)))
.collect(Collectors.toCollection(() -> new ArrayDeque<>(e164s.size())));
final Set<String> accountsNeedingReconciliation = new HashSet<>(e164s.size());
final List<Map<String, AttributeValue>> batchKeys = new ArrayList<>(GET_BATCH_SIZE);
while (!pendingKeys.isEmpty()) {
batchKeys.clear();
for (int i = 0; i < GET_BATCH_SIZE && !pendingKeys.isEmpty(); i++) {
batchKeys.add(pendingKeys.remove());
}
final BatchGetItemResponse response = db().batchGetItem(BatchGetItemRequest.builder()
.requestItems(Map.of(tableName, KeysAndAttributes.builder()
.consistentRead(true)
.keys(batchKeys)
.build()))
.build());
response.responses().getOrDefault(tableName, Collections.emptyList()).stream()
.filter(attributes -> AttributeValues.getInt(attributes, ATTR_NEEDS_CDS_RECONCILIATION, 0) == 1)
.map(attributes -> AttributeValues.getString(attributes, KEY_ACCOUNT_E164, null))
.forEach(accountsNeedingReconciliation::add);
if (response.hasUnprocessedKeys() && response.unprocessedKeys().containsKey(tableName)) {
pendingKeys.addAll(response.unprocessedKeys().get(tableName).keys());
}
}
return accountsNeedingReconciliation;
}
void markReconciled(final Collection<String> phoneNumbersReconciled) {
phoneNumbersReconciled.forEach(number -> db().updateItem(
UpdateItemRequest.builder()
.tableName(tableName)
.key(Map.of(
KEY_ACCOUNT_E164, AttributeValues.fromString(number)
))
.updateExpression("REMOVE #needs_reconciliation")
.expressionAttributeNames(Map.of(
"#needs_reconciliation", ATTR_NEEDS_CDS_RECONCILIATION
))
.build()
));
}
}

View File

@@ -0,0 +1,72 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest.User;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationResponse;
public class DeletedAccountsDirectoryReconciler {
private final Logger logger = LoggerFactory.getLogger(DeletedAccountsDirectoryReconciler.class);
private final DirectoryReconciliationClient directoryReconciliationClient;
private final Timer deleteTimer;
private final Counter errorCounter;
public DeletedAccountsDirectoryReconciler(
final String replicationName,
final DirectoryReconciliationClient directoryReconciliationClient) {
this.directoryReconciliationClient = directoryReconciliationClient;
deleteTimer = Timer.builder(name(DeletedAccountsDirectoryReconciler.class, "delete"))
.tag("replicationName", replicationName)
.register(Metrics.globalRegistry);
errorCounter = Counter.builder(name(DeletedAccountsDirectoryReconciler.class, "error"))
.tag("replicationName", replicationName)
.register(Metrics.globalRegistry);
}
public void onCrawlChunk(final List<User> deletedUsers) throws ChunkProcessingFailedException {
try {
deleteTimer.recordCallable(() -> {
try {
final DirectoryReconciliationResponse response = directoryReconciliationClient.delete(new DirectoryReconciliationRequest(null, null, deletedUsers));
if (response.getStatus() != DirectoryReconciliationResponse.Status.OK) {
errorCounter.increment();
throw new ChunkProcessingFailedException("Response status: " + response.getStatus());
}
} catch (final Exception e) {
errorCounter.increment();
throw new ChunkProcessingFailedException(e);
}
return null;
});
} catch (final ChunkProcessingFailedException e) {
throw e;
} catch (final Exception e) {
logger.warn("Unexpected exception", e);
throw new RuntimeException(e);
}
}
}

View File

@@ -0,0 +1,120 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.amazonaws.services.dynamodbv2.AcquireLockOptions;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBLockClient;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBLockClientOptions;
import com.amazonaws.services.dynamodbv2.LockItem;
import com.amazonaws.services.dynamodbv2.ReleaseLockOptions;
import com.amazonaws.services.dynamodbv2.model.LockCurrentlyUnavailableException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.Pair;
public class DeletedAccountsManager {
private final DeletedAccounts deletedAccounts;
private final AmazonDynamoDBLockClient lockClient;
private static final Logger log = LoggerFactory.getLogger(DeletedAccountsManager.class);
@FunctionalInterface
public interface DeletedAccountReconciliationConsumer {
/**
* Reconcile a list of deleted account records.
*
* @param deletedAccounts the account records to reconcile
* @return a list of account records that were successfully reconciled; accounts that were not successfully
* reconciled may be retried later
* @throws ChunkProcessingFailedException in the event of an error while processing the batch of account records
*/
Collection<String> reconcile(List<Pair<UUID, String>> deletedAccounts) throws ChunkProcessingFailedException;
}
public DeletedAccountsManager(final DeletedAccounts deletedAccounts, final AmazonDynamoDB lockDynamoDb, final String lockTableName) {
this.deletedAccounts = deletedAccounts;
lockClient = new AmazonDynamoDBLockClient(
AmazonDynamoDBLockClientOptions.builder(lockDynamoDb, lockTableName)
.withPartitionKeyName(DeletedAccounts.KEY_ACCOUNT_E164)
.withLeaseDuration(15L)
.withHeartbeatPeriod(2L)
.withTimeUnit(TimeUnit.SECONDS)
.withCreateHeartbeatBackgroundThread(true)
.build());
}
public void put(final UUID uuid, final String e164) throws InterruptedException {
withLock(e164, () -> deletedAccounts.put(uuid, e164));
}
private void withLock(final String e164, final Runnable task) throws InterruptedException {
final LockItem lockItem = lockClient.acquireLock(AcquireLockOptions.builder(e164)
.withAcquireReleasedLocksConsistently(true)
.build());
try {
task.run();
} finally {
lockClient.releaseLock(ReleaseLockOptions.builder(lockItem)
.withBestEffort(true)
.build());
}
}
public void lockAndReconcileAccounts(final int max, final DeletedAccountReconciliationConsumer consumer) throws ChunkProcessingFailedException {
final List<LockItem> lockItems = new ArrayList<>();
final List<Pair<UUID, String>> reconciliationCandidates = deletedAccounts.listAccountsToReconcile(max).stream()
.filter(pair -> {
boolean lockAcquired = false;
try {
lockItems.add(lockClient.acquireLock(AcquireLockOptions.builder(pair.second())
.withAcquireReleasedLocksConsistently(true)
.withShouldSkipBlockingWait(true)
.build()));
lockAcquired = true;
} catch (final InterruptedException e) {
log.warn("Interrupted while acquiring lock for reconciliation", e);
} catch (final LockCurrentlyUnavailableException ignored) {
}
return lockAcquired;
})
.collect(Collectors.toList());
assert lockItems.size() == reconciliationCandidates.size();
// A deleted account's status may have changed in the time between getting a list of candidates and acquiring a lock
// on the candidate records. Now that we hold the lock, check which of the candidates still need to be reconciled.
final Set<String> numbersNeedingReconciliationAfterLock =
deletedAccounts.getAccountsNeedingReconciliation(reconciliationCandidates.stream()
.map(Pair::second)
.collect(Collectors.toList()));
final List<Pair<UUID, String>> accountsToReconcile = reconciliationCandidates.stream()
.filter(candidate -> numbersNeedingReconciliationAfterLock.contains(candidate.second()))
.collect(Collectors.toList());
try {
deletedAccounts.markReconciled(consumer.reconcile(accountsToReconcile));
} finally {
lockItems.forEach(lockItem -> lockClient.releaseLock(ReleaseLockOptions.builder(lockItem).withBestEffort(true).build()));
}
}
}

View File

@@ -0,0 +1,66 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest.User;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Pair;
public class DeletedAccountsTableCrawler extends ManagedPeriodicWork {
private static final Duration WORKER_TTL = Duration.ofMinutes(2);
private static final Duration RUN_INTERVAL = Duration.ofMinutes(15);
private static final String ACTIVE_WORKER_KEY = "deleted_accounts_crawler_cache_active_worker";
private static final int MAX_BATCH_SIZE = 5_000;
private static final String BATCH_SIZE_DISTRIBUTION_NAME = name(DeletedAccountsTableCrawler.class, "batchSize");
private final DeletedAccountsManager deletedAccountsManager;
private final List<DeletedAccountsDirectoryReconciler> reconcilers;
public DeletedAccountsTableCrawler(
final DeletedAccountsManager deletedAccountsManager,
final List<DeletedAccountsDirectoryReconciler> reconcilers,
final FaultTolerantRedisCluster cluster,
final ScheduledExecutorService executorService) throws IOException {
super(new ManagedPeriodicWorkLock(ACTIVE_WORKER_KEY, cluster), WORKER_TTL, RUN_INTERVAL, executorService);
this.deletedAccountsManager = deletedAccountsManager;
this.reconcilers = reconcilers;
}
@Override
public void doPeriodicWork() throws Exception {
deletedAccountsManager.lockAndReconcileAccounts(MAX_BATCH_SIZE, deletedAccounts -> {
final List<User> deletedUsers = deletedAccounts.stream()
.map(pair -> new User(pair.first(), pair.second()))
.collect(Collectors.toList());
for (DeletedAccountsDirectoryReconciler reconciler : reconcilers) {
reconciler.onCrawlChunk(deletedUsers);
}
final List<String> reconciledPhoneNumbers = deletedAccounts.stream()
.map(Pair::second)
.collect(Collectors.toList());
Metrics.summary(BATCH_SIZE_DISTRIBUTION_NAME).record(reconciledPhoneNumbers.size());
return reconciledPhoneNumbers;
});
}
}

View File

@@ -276,10 +276,13 @@ public class Device {
@JsonProperty
private boolean senderKey;
@JsonProperty
private boolean announcementGroup;
public DeviceCapabilities() {}
public DeviceCapabilities(boolean gv2, final boolean gv2_2, final boolean gv2_3, boolean storage, boolean transfer,
boolean gv1Migration, final boolean senderKey) {
boolean gv1Migration, final boolean senderKey, final boolean announcementGroup) {
this.gv2 = gv2;
this.gv2_2 = gv2_2;
this.gv2_3 = gv2_3;
@@ -287,6 +290,7 @@ public class Device {
this.transfer = transfer;
this.gv1Migration = gv1Migration;
this.senderKey = senderKey;
this.announcementGroup = announcementGroup;
}
public boolean isGv2() {
@@ -316,5 +320,9 @@ public class Device {
public boolean isSenderKey() {
return senderKey;
}
public boolean isAnnouncementGroup() {
return announcementGroup;
}
}
}

View File

@@ -4,24 +4,23 @@
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationResponse;
import org.whispersystems.textsecuregcm.util.Constants;
import javax.ws.rs.ProcessingException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
import javax.ws.rs.ProcessingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationRequest;
import org.whispersystems.textsecuregcm.entities.DirectoryReconciliationResponse;
import org.whispersystems.textsecuregcm.util.Constants;
public class DirectoryReconciler extends AccountDatabaseCrawlerListener {
@@ -32,6 +31,8 @@ public class DirectoryReconciler extends AccountDatabaseCrawlerListener {
private final Timer sendChunkTimer;
private final Meter sendChunkErrorMeter;
private boolean useV3Endpoints;
public DirectoryReconciler(String name, DirectoryReconciliationClient reconciliationClient) {
this.reconciliationClient = reconciliationClient;
sendChunkTimer = metricRegistry.timer(name(DirectoryReconciler.class, name, "sendChunk"));
@@ -45,6 +46,10 @@ public class DirectoryReconciler extends AccountDatabaseCrawlerListener {
public void onCrawlEnd(Optional<UUID> fromUuid) {
DirectoryReconciliationRequest request = new DirectoryReconciliationRequest(fromUuid.orElse(null), null, Collections.emptyList());
sendChunk(request);
if (useV3Endpoints) {
reconciliationClient.complete();
}
}
@Override
@@ -76,7 +81,12 @@ public class DirectoryReconciler extends AccountDatabaseCrawlerListener {
private DirectoryReconciliationResponse sendChunk(DirectoryReconciliationRequest request) {
try (Timer.Context timer = sendChunkTimer.time()) {
DirectoryReconciliationResponse response = reconciliationClient.sendChunk(request);
DirectoryReconciliationResponse response;
if (useV3Endpoints) {
response = reconciliationClient.sendChunkV3(request);
} else {
response = reconciliationClient.sendChunk(request);
}
if (response.getStatus() != DirectoryReconciliationResponse.Status.OK) {
sendChunkErrorMeter.mark();
logger.warn("reconciliation error: " + response.getStatus());
@@ -89,4 +99,7 @@ public class DirectoryReconciler extends AccountDatabaseCrawlerListener {
}
}
public void setUseV3Endpoints(final boolean useV3Endpoints) {
this.useV3Endpoints = useV3Endpoints;
}
}

View File

@@ -4,7 +4,16 @@
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.SharedMetricRegistries;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import javax.net.ssl.SSLContext;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import org.glassfish.jersey.SslConfigurator;
import org.glassfish.jersey.client.authentication.HttpAuthenticationFeature;
import org.whispersystems.textsecuregcm.configuration.DirectoryServerConfiguration;
@@ -14,16 +23,6 @@ import org.whispersystems.textsecuregcm.util.CertificateExpirationGauge;
import org.whispersystems.textsecuregcm.util.CertificateUtil;
import org.whispersystems.textsecuregcm.util.Constants;
import javax.net.ssl.SSLContext;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import static com.codahale.metrics.MetricRegistry.name;
public class DirectoryReconciliationClient {
private final String replicationUrl;
@@ -47,6 +46,27 @@ public class DirectoryReconciliationClient {
.put(Entity.json(request), DirectoryReconciliationResponse.class);
}
public DirectoryReconciliationResponse sendChunkV3(DirectoryReconciliationRequest request) {
return client.target(replicationUrl)
.path("/v3/directory/exists")
.request(MediaType.APPLICATION_JSON_TYPE)
.put(Entity.json(request), DirectoryReconciliationResponse.class);
}
public DirectoryReconciliationResponse delete(DirectoryReconciliationRequest request) {
return client.target(replicationUrl)
.path("/v3/directory/deletes")
.request(MediaType.APPLICATION_JSON_TYPE)
.put(Entity.json(request), DirectoryReconciliationResponse.class);
}
public DirectoryReconciliationResponse complete() {
return client.target(replicationUrl)
.path("/v3/directory/complete")
.request(MediaType.APPLICATION_JSON_TYPE)
.post(null, DirectoryReconciliationResponse.class);
}
private static Client initializeClient(DirectoryServerConfiguration directoryServerConfiguration)
throws CertificateException
{

View File

@@ -1,31 +1,29 @@
package org.whispersystems.textsecuregcm.storage;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.services.appconfig.AmazonAppConfig;
import com.amazonaws.services.appconfig.AmazonAppConfigClient;
import com.amazonaws.services.appconfig.model.GetConfigurationRequest;
import com.amazonaws.services.appconfig.model.GetConfigurationResult;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.google.common.annotations.VisibleForTesting;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import javax.validation.ConstraintViolation;
import javax.validation.Validation;
import javax.validation.Validator;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.ConstraintViolation;
import javax.validation.Validation;
import javax.validation.Validator;
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.services.appconfig.AppConfigClient;
import software.amazon.awssdk.services.appconfig.model.GetConfigurationRequest;
import software.amazon.awssdk.services.appconfig.model.GetConfigurationResponse;
public class DynamicConfigurationManager {
@@ -33,11 +31,11 @@ public class DynamicConfigurationManager {
private final String environment;
private final String configurationName;
private final String clientId;
private final AmazonAppConfig appConfigClient;
private final AppConfigClient appConfigClient;
private final AtomicReference<DynamicConfiguration> configuration = new AtomicReference<>();
private GetConfigurationResult lastConfigResult;
private GetConfigurationResponse lastConfigResult;
private boolean initialized = false;
@@ -50,15 +48,20 @@ public class DynamicConfigurationManager {
private static final Logger logger = LoggerFactory.getLogger(DynamicConfigurationManager.class);
public DynamicConfigurationManager(String application, String environment, String configurationName) {
this(AmazonAppConfigClient.builder()
.withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(10000).withRequestTimeout(10000))
.withCredentials(InstanceProfileCredentialsProvider.getInstance())
.build(),
application, environment, configurationName, UUID.randomUUID().toString());
this(AppConfigClient.builder()
.overrideConfiguration(ClientOverrideConfiguration.builder()
.apiCallTimeout(Duration.ofMillis(10000))
.apiCallAttemptTimeout(Duration.ofMillis(10000)).build())
/* To specify specific credential provider:
https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials.html
*/
.build(),
application, environment, configurationName, UUID.randomUUID().toString());
}
@VisibleForTesting
public DynamicConfigurationManager(AmazonAppConfig appConfigClient, String application, String environment, String configurationName, String clientId) {
public DynamicConfigurationManager(AppConfigClient appConfigClient, String application, String environment,
String configurationName, String clientId) {
this.appConfigClient = appConfigClient;
this.application = application;
this.environment = environment;
@@ -99,21 +102,24 @@ public class DynamicConfigurationManager {
}
private Optional<DynamicConfiguration> retrieveDynamicConfiguration() throws JsonProcessingException {
final String previousVersion = lastConfigResult != null ? lastConfigResult.getConfigurationVersion() : null;
final String previousVersion = lastConfigResult != null ? lastConfigResult.configurationVersion() : null;
lastConfigResult = appConfigClient.getConfiguration(new GetConfigurationRequest().withApplication(application)
.withEnvironment(environment)
.withConfiguration(configurationName)
.withClientId(clientId)
.withClientConfigurationVersion(previousVersion));
lastConfigResult = appConfigClient.getConfiguration(GetConfigurationRequest.builder()
.application(application)
.environment(environment)
.configuration(configurationName)
.clientId(clientId)
.clientConfigurationVersion(previousVersion)
.build());
final Optional<DynamicConfiguration> maybeDynamicConfiguration;
if (!StringUtils.equals(lastConfigResult.getConfigurationVersion(), previousVersion)) {
logger.info("Received new config version: {}", lastConfigResult.getConfigurationVersion());
if (!StringUtils.equals(lastConfigResult.configurationVersion(), previousVersion)) {
logger.info("Received new config version: {}", lastConfigResult.configurationVersion());
maybeDynamicConfiguration =
parseConfiguration(StandardCharsets.UTF_8.decode(lastConfigResult.getContent().asReadOnlyBuffer()).toString());
parseConfiguration(
StandardCharsets.UTF_8.decode(lastConfigResult.content().asByteBuffer().asReadOnlyBuffer()).toString());
} else {
// No change since last version
maybeDynamicConfiguration = Optional.empty();
@@ -123,7 +129,8 @@ public class DynamicConfigurationManager {
}
@VisibleForTesting
public static Optional<DynamicConfiguration> parseConfiguration(final String configurationYaml) throws JsonProcessingException {
public static Optional<DynamicConfiguration> parseConfiguration(final String configurationYaml)
throws JsonProcessingException {
final DynamicConfiguration configuration = OBJECT_MAPPER.readValue(configurationYaml, DynamicConfiguration.class);
final Set<ConstraintViolation<DynamicConfiguration>> violations = VALIDATOR.validate(configuration);

View File

@@ -168,8 +168,7 @@ public class KeysDynamoDb extends AbstractDynamoDbStore {
});
}
@VisibleForTesting
void delete(final Account account, final long deviceId) {
public void delete(final Account account, final long deviceId) {
DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> {
final QueryRequest queryRequest = QueryRequest.builder()
.tableName(tableName)

View File

@@ -0,0 +1,118 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Metrics;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.Util;
public abstract class ManagedPeriodicWork implements Managed {
private final Logger logger = LoggerFactory.getLogger(getClass());
private static final String FUTURE_DONE_GAUGE_NAME = "futureDone";
private final ManagedPeriodicWorkLock lock;
private final Duration workerTtl;
private final Duration runInterval;
private final String workerId;
private final ScheduledExecutorService executorService;
private Duration sleepDurationAfterUnexpectedException = Duration.ofSeconds(10);
@Nullable
private ScheduledFuture<?> scheduledFuture;
private AtomicReference<CompletableFuture<Void>> activeExecutionFuture = new AtomicReference<>(CompletableFuture.completedFuture(null));
public ManagedPeriodicWork(final ManagedPeriodicWorkLock lock, final Duration workerTtl, final Duration runInterval, final ScheduledExecutorService scheduledExecutorService) {
this.lock = lock;
this.workerTtl = workerTtl;
this.runInterval = runInterval;
this.workerId = UUID.randomUUID().toString();
this.executorService = scheduledExecutorService;
}
abstract protected void doPeriodicWork() throws Exception;
@Override
public synchronized void start() throws Exception {
if (scheduledFuture != null) {
return;
}
scheduledFuture = executorService.scheduleAtFixedRate(() -> {
try {
execute();
} catch (final Exception e) {
logger.warn("Error in execution", e);
// wait a bit, in case the error is caused by external instability
Util.sleep(sleepDurationAfterUnexpectedException.toMillis());
}
}, 0, runInterval.getSeconds(), TimeUnit.SECONDS);
Metrics.gauge(name(getClass(), FUTURE_DONE_GAUGE_NAME), scheduledFuture, future -> future.isDone() ? 1 : 0);
}
@Override
public synchronized void stop() throws Exception {
if (scheduledFuture != null) {
scheduledFuture.cancel(false);
try {
activeExecutionFuture.get().join();
} catch (final Exception e) {
logger.warn("error while awaiting final execution", e);
}
}
}
public void setSleepDurationAfterUnexpectedException(final Duration sleepDurationAfterUnexpectedException) {
this.sleepDurationAfterUnexpectedException = sleepDurationAfterUnexpectedException;
}
private void execute() {
if (lock.claimActiveWork(workerId, workerTtl)) {
try {
activeExecutionFuture.set(new CompletableFuture<>());
logger.info("Starting execution");
doPeriodicWork();
logger.info("Execution complete");
} catch (final Exception e) {
logger.warn("Periodic work failed", e);
// wait a bit, in case the error is caused by external instability
Util.sleep(sleepDurationAfterUnexpectedException.toMillis());
} finally {
try {
lock.releaseActiveWork(workerId);
} finally {
activeExecutionFuture.get().complete(null);
}
}
}
}
}

View File

@@ -0,0 +1,35 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.SetArgs;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
public class ManagedPeriodicWorkLock {
private final String activeWorkerKey;
private final FaultTolerantRedisCluster cacheCluster;
private final ClusterLuaScript unlockClusterScript;
public ManagedPeriodicWorkLock(final String activeWorkerKey, final FaultTolerantRedisCluster cacheCluster) throws IOException {
this.activeWorkerKey = activeWorkerKey;
this.cacheCluster = cacheCluster;
this.unlockClusterScript = ClusterLuaScript.fromResource(cacheCluster, "lua/periodic_worker/unlock.lua", ScriptOutputType.INTEGER);
}
public boolean claimActiveWork(String workerId, Duration ttl) {
return "OK".equals(cacheCluster.withCluster(connection -> connection.sync().set(activeWorkerKey, workerId, SetArgs.Builder.nx().px(ttl.toMillis()))));
}
public void releaseActiveWork(String workerId) {
unlockClusterScript.execute(List.of(activeWorkerKey), List.of(workerId));
}
}

View File

@@ -5,12 +5,15 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.DeleteRequest;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
import software.amazon.awssdk.services.dynamodb.model.WriteRequest;
public class MigrationRetryAccounts extends AbstractDynamoDbStore {
@@ -58,4 +61,16 @@ public class MigrationRetryAccounts extends AbstractDynamoDbStore {
return Map.of(KEY_UUID, AttributeValues.fromUUID(uuid));
}
public void delete(final List<UUID> uuidsToDelete) {
writeInBatches(uuidsToDelete, (uuids -> {
final List<WriteRequest> deletes = uuids.stream()
.map(uuid -> WriteRequest.builder().deleteRequest(
DeleteRequest.builder().key(Map.of(KEY_UUID, AttributeValues.fromUUID(uuid))).build()).build())
.collect(Collectors.toList());
executeTableWriteItemsUntilComplete(Map.of(tableName, deletes));
}));
}
}

View File

@@ -0,0 +1,88 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class MigrationRetryAccountsTableCrawler extends ManagedPeriodicWork {
private static final Logger logger = LoggerFactory.getLogger(MigrationRetryAccountsTableCrawler.class);
private static final Duration WORKER_TTL = Duration.ofMinutes(2);
private static final Duration RUN_INTERVAL = Duration.ofMinutes(15);
private static final String ACTIVE_WORKER_KEY = "migration_retry_accounts_crawler_cache_active_worker";
private static final int MAX_BATCH_SIZE = 5_000;
private static final Counter MIGRATED_COUNTER = Metrics.counter(name(MigrationRetryAccountsTableCrawler.class, "migrated"));
private static final Counter ERROR_COUNTER = Metrics.counter(name(MigrationRetryAccountsTableCrawler.class, "error"));
private static final Counter TOTAL_COUNTER = Metrics.counter(name(MigrationRetryAccountsTableCrawler.class, "total"));
private final MigrationRetryAccounts retryAccounts;
private final AccountsManager accountsManager;
private final AccountsDynamoDb accountsDynamoDb;
public MigrationRetryAccountsTableCrawler(
final MigrationRetryAccounts retryAccounts,
final AccountsManager accountsManager,
final AccountsDynamoDb accountsDynamoDb,
final FaultTolerantRedisCluster cluster,
final ScheduledExecutorService executorService) throws IOException {
super(new ManagedPeriodicWorkLock(ACTIVE_WORKER_KEY, cluster), WORKER_TTL, RUN_INTERVAL, executorService);
this.retryAccounts = retryAccounts;
this.accountsManager = accountsManager;
this.accountsDynamoDb = accountsDynamoDb;
}
@Override
public void doPeriodicWork() {
final List<UUID> uuids = this.retryAccounts.getUuids(MAX_BATCH_SIZE);
final List<UUID> processedUuids = new ArrayList<>(uuids.size());
try {
for (UUID uuid : uuids) {
try {
final Optional<Account> maybeDynamoAccount = accountsDynamoDb.get(uuid);
if (maybeDynamoAccount.isEmpty()) {
accountsManager.get(uuid).ifPresent(account -> {
accountsDynamoDb.migrate(account);
MIGRATED_COUNTER.increment();
});
}
processedUuids.add(uuid);
TOTAL_COUNTER.increment();
} catch (final Exception e) {
ERROR_COUNTER.increment();
logger.warn("Failed to migrate account");
}
}
} finally {
this.retryAccounts.delete(processedUuids);
}
}
}

View File

@@ -0,0 +1,10 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class OptimisticLockRetryLimitExceededException extends RuntimeException {
}

View File

@@ -1,86 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import java.util.Optional;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper;
import org.whispersystems.textsecuregcm.util.Constants;
public class PendingAccounts {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer insertTimer = metricRegistry.timer(name(PendingAccounts.class, "insert" ));
private final Timer getCodeForNumberTimer = metricRegistry.timer(name(PendingAccounts.class, "getCodeForNumber"));
private final Timer removeTimer = metricRegistry.timer(name(PendingAccounts.class, "remove" ));
private final Timer vacuumTimer = metricRegistry.timer(name(PendingAccounts.class, "vacuum" ));
private final FaultTolerantDatabase database;
public PendingAccounts(FaultTolerantDatabase database) {
this.database = database;
this.database.getDatabase().registerRowMapper(new StoredVerificationCodeRowMapper());
}
@VisibleForTesting
public void insert (String number, String verificationCode, long timestamp, String pushCode) {
insert(number, verificationCode, timestamp, pushCode, null);
}
public void insert(String number, String verificationCode, long timestamp, String pushCode, String twilioVerificationSid) {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = insertTimer.time()) {
handle.createUpdate("INSERT INTO pending_accounts (number, verification_code, timestamp, push_code, twilio_verification_sid) " +
"VALUES (:number, :verification_code, :timestamp, :push_code, :twilio_verification_sid) " +
"ON CONFLICT(number) DO UPDATE " +
"SET verification_code = EXCLUDED.verification_code, timestamp = EXCLUDED.timestamp, push_code = EXCLUDED.push_code, twilio_verification_sid = EXCLUDED.twilio_verification_sid")
.bind("verification_code", verificationCode)
.bind("timestamp", timestamp)
.bind("number", number)
.bind("push_code", pushCode)
.bind("twilio_verification_sid", twilioVerificationSid)
.execute();
}
}));
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return database.with(jdbi ->jdbi.withHandle(handle -> {
try (Timer.Context ignored = getCodeForNumberTimer.time()) {
return handle.createQuery("SELECT verification_code, timestamp, push_code, twilio_verification_sid FROM pending_accounts WHERE number = :number")
.bind("number", number)
.mapTo(StoredVerificationCode.class)
.findFirst();
}
}));
}
public void remove(String number) {
database.use(jdbi-> jdbi.useHandle(handle -> {
try (Timer.Context ignored = removeTimer.time()) {
handle.createUpdate("DELETE FROM pending_accounts WHERE number = :number")
.bind("number", number)
.execute();
}
}));
}
public void vacuum() {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context ignored = vacuumTimer.time()) {
handle.execute("VACUUM pending_accounts");
}
}));
}
}

View File

@@ -1,82 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import java.util.Optional;
public class PendingAccountsManager {
private final Logger logger = LoggerFactory.getLogger(PendingAccountsManager.class);
private static final String CACHE_PREFIX = "pending_account2::";
private final PendingAccounts pendingAccounts;
private final FaultTolerantRedisCluster cacheCluster;
private final ObjectMapper mapper;
public PendingAccountsManager(PendingAccounts pendingAccounts, FaultTolerantRedisCluster cacheCluster)
{
this.pendingAccounts = pendingAccounts;
this.cacheCluster = cacheCluster;
this.mapper = SystemMapper.getMapper();
}
public void store(String number, StoredVerificationCode code) {
memcacheSet(number, code);
pendingAccounts.insert(number, code.getCode(), code.getTimestamp(), code.getPushCode(),
code.getTwilioVerificationSid().orElse(null));
}
public void remove(String number) {
memcacheDelete(number);
pendingAccounts.remove(number);
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
Optional<StoredVerificationCode> code = memcacheGet(number);
if (!code.isPresent()) {
code = pendingAccounts.getCodeForNumber(number);
code.ifPresent(storedVerificationCode -> memcacheSet(number, storedVerificationCode));
}
return code;
}
private void memcacheSet(String number, StoredVerificationCode code) {
try {
final String verificationCodeJson = mapper.writeValueAsString(code);
cacheCluster.useCluster(connection -> connection.sync().set(CACHE_PREFIX + number, verificationCodeJson));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private Optional<StoredVerificationCode> memcacheGet(String number) {
try {
final String json = cacheCluster.withCluster(connection -> connection.sync().get(CACHE_PREFIX + number));
if (json == null) return Optional.empty();
else return Optional.of(mapper.readValue(json, StoredVerificationCode.class));
} catch (IOException e) {
logger.warn("Error deserializing value...", e);
return Optional.empty();
}
}
private void memcacheDelete(String number) {
cacheCluster.useCluster(connection -> connection.sync().del(CACHE_PREFIX + number));
}
}

View File

@@ -1,65 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.storage.mappers.StoredVerificationCodeRowMapper;
import org.whispersystems.textsecuregcm.util.Constants;
public class PendingDevices {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Timer insertTimer = metricRegistry.timer(name(PendingDevices.class, "insert" ));
private final Timer getCodeForNumberTimer = metricRegistry.timer(name(PendingDevices.class, "getcodeForNumber"));
private final Timer removeTimer = metricRegistry.timer(name(PendingDevices.class, "remove" ));
private final FaultTolerantDatabase database;
public PendingDevices(FaultTolerantDatabase database) {
this.database = database;
this.database.getDatabase().registerRowMapper(new StoredVerificationCodeRowMapper());
}
public void insert(String number, String verificationCode, long timestamp) {
database.use(jdbi ->jdbi.useHandle(handle -> {
try (Timer.Context timer = insertTimer.time()) {
handle.createUpdate("WITH upsert AS (UPDATE pending_devices SET verification_code = :verification_code, timestamp = :timestamp WHERE number = :number RETURNING *) " +
"INSERT INTO pending_devices (number, verification_code, timestamp) SELECT :number, :verification_code, :timestamp WHERE NOT EXISTS (SELECT * FROM upsert)")
.bind("number", number)
.bind("verification_code", verificationCode)
.bind("timestamp", timestamp)
.execute();
}
}));
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return database.with(jdbi -> jdbi.withHandle(handle -> {
try (Timer.Context timer = getCodeForNumberTimer.time()) {
return handle.createQuery("SELECT verification_code, timestamp, NULL as push_code, NULL as twilio_verification_sid FROM pending_devices WHERE number = :number")
.bind("number", number)
.mapTo(StoredVerificationCode.class)
.findFirst();
}
}));
}
public void remove(String number) {
database.use(jdbi -> jdbi.useHandle(handle -> {
try (Timer.Context timer = removeTimer.time()) {
handle.createUpdate("DELETE FROM pending_devices WHERE number = :number")
.bind("number", number)
.execute();
}
}));
}
}

View File

@@ -1,81 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import java.util.Optional;
public class PendingDevicesManager {
private final Logger logger = LoggerFactory.getLogger(PendingDevicesManager.class);
private static final String CACHE_PREFIX = "pending_devices2::";
private final PendingDevices pendingDevices;
private final FaultTolerantRedisCluster cacheCluster;
private final ObjectMapper mapper;
public PendingDevicesManager(PendingDevices pendingDevices, FaultTolerantRedisCluster cacheCluster) {
this.pendingDevices = pendingDevices;
this.cacheCluster = cacheCluster;
this.mapper = SystemMapper.getMapper();
}
public void store(String number, StoredVerificationCode code) {
memcacheSet(number, code);
pendingDevices.insert(number, code.getCode(), code.getTimestamp());
}
public void remove(String number) {
memcacheDelete(number);
pendingDevices.remove(number);
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
Optional<StoredVerificationCode> code = memcacheGet(number);
if (!code.isPresent()) {
code = pendingDevices.getCodeForNumber(number);
code.ifPresent(storedVerificationCode -> memcacheSet(number, storedVerificationCode));
}
return code;
}
private void memcacheSet(String number, StoredVerificationCode code) {
try {
final String verificationCodeJson = mapper.writeValueAsString(code);
cacheCluster.useCluster(connection -> connection.sync().set(CACHE_PREFIX + number, verificationCodeJson));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private Optional<StoredVerificationCode> memcacheGet(String number) {
try {
final String json = cacheCluster.withCluster(connection -> connection.sync().get(CACHE_PREFIX + number));
if (json == null) return Optional.empty();
else return Optional.of(mapper.readValue(json, StoredVerificationCode.class));
} catch (IOException e) {
logger.warn("Could not parse pending device stored verification json");
return Optional.empty();
}
}
private void memcacheDelete(String number) {
cacheCluster.useCluster(connection -> connection.sync().del(CACHE_PREFIX + number));
}
}

View File

@@ -5,20 +5,20 @@
package org.whispersystems.textsecuregcm.storage;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
@@ -27,11 +27,9 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
private final Meter recovered = metricRegistry.meter(name(getClass(), "unregistered", "recovered"));
private final AccountsManager accountsManager;
private final DirectoryQueue directoryQueue;
public PushFeedbackProcessor(AccountsManager accountsManager, DirectoryQueue directoryQueue) {
public PushFeedbackProcessor(AccountsManager accountsManager) {
this.accountsManager = accountsManager;
this.directoryQueue = directoryQueue;
}
@Override
@@ -42,47 +40,58 @@ public class PushFeedbackProcessor extends AccountDatabaseCrawlerListener {
@Override
protected void onCrawlChunk(Optional<UUID> fromUuid, List<Account> chunkAccounts) {
final List<Account> directoryUpdateAccounts = new ArrayList<>();
for (Account account : chunkAccounts) {
boolean update = false;
for (Device device : account.getDevices()) {
if (device.getUninstalledFeedbackTimestamp() != 0 &&
device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis())
{
if (device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis()) {
if (!Util.isEmpty(device.getApnId())) {
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
} else if (!Util.isEmpty(device.getGcmId())) {
device.setUserAgent("OWA");
}
device.setGcmId(null);
device.setApnId(null);
device.setVoipApnId(null);
device.setFetchesMessages(false);
final Set<Device> devices = account.getDevices();
for (Device device : devices) {
if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) {
expired.mark();
} else {
device.setUninstalledFeedbackTimestamp(0);
recovered.mark();
}
update = true;
}
}
if (update) {
accountsManager.update(account);
directoryUpdateAccounts.add(account);
// fetch a new version, since the chunk is shared and implicitly read-only
accountsManager.get(account.getUuid()).ifPresent(accountToUpdate -> {
accountsManager.update(accountToUpdate, a -> {
for (Device device : a.getDevices()) {
if (deviceNeedsUpdate(device)) {
if (deviceExpired(device)) {
if (!Util.isEmpty(device.getApnId())) {
if (device.getId() == 1) {
device.setUserAgent("OWI");
} else {
device.setUserAgent("OWP");
}
} else if (!Util.isEmpty(device.getGcmId())) {
device.setUserAgent("OWA");
}
device.setGcmId(null);
device.setApnId(null);
device.setVoipApnId(null);
device.setFetchesMessages(false);
} else {
device.setUninstalledFeedbackTimestamp(0);
}
}
}
});
});
}
}
}
if (!directoryUpdateAccounts.isEmpty()) {
directoryQueue.refreshRegisteredUsers(directoryUpdateAccounts);
}
private boolean deviceNeedsUpdate(final Device device) {
return device.getUninstalledFeedbackTimestamp() != 0 &&
device.getUninstalledFeedbackTimestamp() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis();
}
private boolean deviceExpired(final Device device) {
return device.getLastSeen() + TimeUnit.DAYS.toMillis(2) <= Util.todayInMillis();
}
}

View File

@@ -0,0 +1,29 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
public class StoredVerificationCodeManager {
private final VerificationCodeStore verificationCodeStore;
public StoredVerificationCodeManager(final VerificationCodeStore verificationCodeStore) {
this.verificationCodeStore = verificationCodeStore;
}
public void store(String number, StoredVerificationCode code) {
verificationCodeStore.insert(number, code);
}
public void remove(String number) {
verificationCodeStore.remove(number);
}
public Optional<StoredVerificationCode> getCodeForNumber(String number) {
return verificationCodeStore.findForNumber(number);
}
}

View File

@@ -0,0 +1,103 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemRequest;
import software.amazon.awssdk.services.dynamodb.model.GetItemResponse;
import software.amazon.awssdk.services.dynamodb.model.PutItemRequest;
import java.time.Instant;
import java.util.Map;
import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name;
public class VerificationCodeStore {
private final DynamoDbClient dynamoDbClient;
private final String tableName;
private final Timer insertTimer;
private final Timer getTimer;
private final Timer removeTimer;
@VisibleForTesting
static final String KEY_E164 = "P";
private static final String ATTR_STORED_CODE = "C";
private static final String ATTR_TTL = "E";
private static final Logger log = LoggerFactory.getLogger(VerificationCodeStore.class);
public VerificationCodeStore(final DynamoDbClient dynamoDbClient, final String tableName) {
this.dynamoDbClient = dynamoDbClient;
this.tableName = tableName;
this.insertTimer = Metrics.timer(name(getClass(), "insert"), "table", tableName);
this.getTimer = Metrics.timer(name(getClass(), "get"), "table", tableName);
this.removeTimer = Metrics.timer(name(getClass(), "remove"), "table", tableName);
}
public void insert(final String number, final StoredVerificationCode verificationCode) {
insertTimer.record(() -> {
try {
dynamoDbClient.putItem(PutItemRequest.builder()
.tableName(tableName)
.item(Map.of(
KEY_E164, AttributeValues.fromString(number),
ATTR_STORED_CODE, AttributeValues.fromString(SystemMapper.getMapper().writeValueAsString(verificationCode)),
ATTR_TTL, AttributeValues.fromLong(getExpirationTimestamp(verificationCode))))
.build());
} catch (final JsonProcessingException e) {
// This should never happen when writing directly to a string except in cases of serious misconfiguration, which
// would be caught by tests.
throw new AssertionError(e);
}
});
}
private long getExpirationTimestamp(final StoredVerificationCode storedVerificationCode) {
return Instant.ofEpochMilli(storedVerificationCode.getTimestamp()).plus(StoredVerificationCode.EXPIRATION).getEpochSecond();
}
public Optional<StoredVerificationCode> findForNumber(final String number) {
return getTimer.record(() -> {
final GetItemResponse response = dynamoDbClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.consistentRead(true)
.key(Map.of(KEY_E164, AttributeValues.fromString(number)))
.build());
try {
return response.hasItem()
? Optional.of(SystemMapper.getMapper().readValue(response.item().get(ATTR_STORED_CODE).s(), StoredVerificationCode.class))
: Optional.empty();
} catch (final JsonProcessingException e) {
log.error("Failed to parse stored verification code", e);
return Optional.empty();
}
});
}
public void remove(final String number) {
removeTimer.record(() -> {
dynamoDbClient.deleteItem(DeleteItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_E164, AttributeValues.fromString(number)))
.build());
});
}
}

View File

@@ -27,6 +27,7 @@ public class AccountRowMapper implements RowMapper<Account> {
Account account = mapper.readValue(resultSet.getString(Accounts.DATA), Account.class);
account.setNumber(resultSet.getString(Accounts.NUMBER));
account.setUuid(UUID.fromString(resultSet.getString(Accounts.UID)));
account.setVersion(resultSet.getInt(Accounts.VERSION));
return account;
} catch (IOException e) {
throw new SQLException(e);

View File

@@ -1,24 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage.mappers;
import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import java.sql.ResultSet;
import java.sql.SQLException;
public class StoredVerificationCodeRowMapper implements RowMapper<StoredVerificationCode> {
@Override
public StoredVerificationCode map(ResultSet resultSet, StatementContext ctx) throws SQLException {
return new StoredVerificationCode(resultSet.getString("verification_code"),
resultSet.getLong("timestamp"),
resultSet.getString("push_code"),
resultSet.getString("twilio_verification_sid"));
}
}

View File

@@ -11,19 +11,59 @@ import com.codahale.metrics.MetricRegistry;
import static com.codahale.metrics.MetricRegistry.name;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.retry.Retry;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
public class CircuitBreakerUtil {
private static final String CIRCUIT_BREAKER_CALL_COUNTER_NAME = name(CircuitBreakerUtil.class, "breaker", "call");
private static final String CIRCUIT_BREAKER_STATE_GAUGE_NAME = name(CircuitBreakerUtil.class, "breaker", "state");
private static final String RETRY_CALL_COUNTER_NAME = name(CircuitBreakerUtil.class, "retry", "call");
private static final String NAME_TAG_NAME = "name";
private static final String OUTCOME_TAG_NAME = "outcome";
public static void registerMetrics(MetricRegistry metricRegistry, CircuitBreaker circuitBreaker, Class<?> clazz) {
Meter successMeter = metricRegistry.meter(name(clazz, circuitBreaker.getName(), "success" ));
Meter failureMeter = metricRegistry.meter(name(clazz, circuitBreaker.getName(), "failure" ));
Meter unpermittedMeter = metricRegistry.meter(name(clazz, circuitBreaker.getName(), "unpermitted"));
final String breakerName = clazz.getSimpleName() + "/" + circuitBreaker.getName();
final Counter successCounter = Metrics.counter(CIRCUIT_BREAKER_CALL_COUNTER_NAME,
NAME_TAG_NAME, breakerName,
OUTCOME_TAG_NAME, "success");
final Counter failureCounter = Metrics.counter(CIRCUIT_BREAKER_CALL_COUNTER_NAME,
NAME_TAG_NAME, breakerName,
OUTCOME_TAG_NAME, "failure");
final Counter unpermittedCounter = Metrics.counter(CIRCUIT_BREAKER_CALL_COUNTER_NAME,
NAME_TAG_NAME, breakerName,
OUTCOME_TAG_NAME, "unpermitted");
circuitBreaker.getEventPublisher().onSuccess(event -> {
successMeter.mark();
successCounter.increment();
});
circuitBreaker.getEventPublisher().onError(event -> {
failureMeter.mark();
failureCounter.increment();
});
circuitBreaker.getEventPublisher().onCallNotPermitted(event -> {
unpermittedMeter.mark();
unpermittedCounter.increment();
});
metricRegistry.gauge(name(clazz, circuitBreaker.getName(), "state"), () -> ()-> circuitBreaker.getState().getOrder());
circuitBreaker.getEventPublisher().onSuccess(event -> successMeter.mark());
circuitBreaker.getEventPublisher().onError(event -> failureMeter.mark());
circuitBreaker.getEventPublisher().onCallNotPermitted(event -> unpermittedMeter.mark());
Metrics.gauge(CIRCUIT_BREAKER_STATE_GAUGE_NAME,
Tags.of(Tag.of(NAME_TAG_NAME, circuitBreaker.getName())),
circuitBreaker, breaker -> breaker.getState().getOrder());
}
public static void registerMetrics(MetricRegistry metricRegistry, Retry retry, Class<?> clazz) {
@@ -32,10 +72,43 @@ public class CircuitBreakerUtil {
Meter errorMeter = metricRegistry.meter(name(clazz, retry.getName(), "error" ));
Meter ignoredErrorMeter = metricRegistry.meter(name(clazz, retry.getName(), "ignored_error"));
retry.getEventPublisher().onSuccess(event -> successMeter.mark());
retry.getEventPublisher().onRetry(event -> retryMeter.mark());
retry.getEventPublisher().onError(event -> errorMeter.mark());
retry.getEventPublisher().onIgnoredError(event -> ignoredErrorMeter.mark());
final String retryName = clazz.getSimpleName() + "/" + retry.getName();
final Counter successCounter = Metrics.counter(RETRY_CALL_COUNTER_NAME,
NAME_TAG_NAME, retryName,
OUTCOME_TAG_NAME, "success");
final Counter retryCounter = Metrics.counter(RETRY_CALL_COUNTER_NAME,
NAME_TAG_NAME, retryName,
OUTCOME_TAG_NAME, "retry");
final Counter errorCounter = Metrics.counter(RETRY_CALL_COUNTER_NAME,
NAME_TAG_NAME, retryName,
OUTCOME_TAG_NAME, "error");
final Counter ignoredErrorCounter = Metrics.counter(RETRY_CALL_COUNTER_NAME,
NAME_TAG_NAME, retryName,
OUTCOME_TAG_NAME, "ignored_error");
retry.getEventPublisher().onSuccess(event -> {
successMeter.mark();
successCounter.increment();
});
retry.getEventPublisher().onRetry(event -> {
retryMeter.mark();
retryCounter.increment();
});
retry.getEventPublisher().onError(event -> {
errorMeter.mark();
errorCounter.increment();
});
retry.getEventPublisher().onIgnoredError(event -> {
ignoredErrorMeter.mark();
ignoredErrorCounter.increment();
});
}
}

View File

@@ -0,0 +1,26 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Locale;
public class HostnameUtil {
private static final Logger log = LoggerFactory.getLogger(HostnameUtil.class);
public static String getLocalHostname() {
try {
return InetAddress.getLocalHost().getHostName().toLowerCase(Locale.US);
} catch (final UnknownHostException e) {
log.warn("Failed to get hostname", e);
return "unknown";
}
}
}

View File

@@ -0,0 +1,62 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util.logging;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.jersey.errors.LoggingExceptionMapper;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.Context;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class LoggingUnhandledExceptionMapper extends LoggingExceptionMapper<Throwable> {
@Context
private HttpServletRequest request;
@Context
private ExtendedUriInfo uriInfo;
public LoggingUnhandledExceptionMapper() {
super();
}
@VisibleForTesting
LoggingUnhandledExceptionMapper(final Logger logger) {
super(logger);
}
@Override
protected String formatLogMessage(final long id, final Throwable exception) {
String requestMethod = "unknown method";
String userAgent = "missing";
String requestPath = "/{unknown path}";
try {
// request and uriInfo shouldnt be `null`, but it is technically possible
requestMethod = request.getMethod();
requestPath = UriInfoUtil.getPathTemplate(uriInfo);
userAgent = request.getHeader("user-agent");
// streamline the user-agent if it is recognized
final UserAgent ua = UserAgentUtil.parseUserAgentString(userAgent);
userAgent = String.format("%s %s", ua.getPlatform(), ua.getVersion());
} catch (final UnrecognizedUserAgentException ignored) {
} catch (final Exception e) {
logger.warn("Unexpected exception getting request details", e);
}
return String.format("%s at %s %s (%s)",
super.formatLogMessage(id, exception),
requestMethod,
requestPath,
userAgent) ;
}
}

View File

@@ -0,0 +1,21 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util.logging;
import org.glassfish.jersey.server.ExtendedUriInfo;
public class UriInfoUtil {
public static String getPathTemplate(final ExtendedUriInfo uriInfo) {
final StringBuilder pathBuilder = new StringBuilder();
for (int i = uriInfo.getMatchedTemplates().size() - 1; i >= 0; i--) {
pathBuilder.append(uriInfo.getMatchedTemplates().get(i).getTemplate());
}
return pathBuilder.toString();
}
}

View File

@@ -68,6 +68,11 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
context.getClient(),
retrySchedulingExecutor);
// TODO Remove once PIN-based reglocks have been deprecated
if (account.getRegistrationLock().requiresClientRegistrationLock() && account.getRegistrationLock().hasDeprecatedPin()) {
log.info("User-Agent with deprecated PIN-based registration lock: {}", context.getClient().getUserAgent());
}
openWebsocketCounter.inc();
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device));

View File

@@ -9,6 +9,8 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder;
import com.fasterxml.jackson.databind.DeserializationFeature;
import io.dropwizard.Application;
import io.dropwizard.cli.EnvironmentCommand;
@@ -39,6 +41,8 @@ import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsDynamoDb;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.AccountsManager.DeletionReason;
import org.whispersystems.textsecuregcm.storage.DeletedAccounts;
import org.whispersystems.textsecuregcm.storage.DeletedAccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
@@ -52,8 +56,10 @@ import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageDynamoDb;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.Usernames;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.storage.VerificationCodeStore;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
@@ -112,6 +118,8 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
DynamoDbAsyncClient accountsDynamoDbAsyncClient = DynamoDbFromConfig.asyncClient(configuration.getAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create(),
accountsDynamoDbMigrationThreadPool);
DynamoDbClient deletedAccountsDynamoDbClient = DynamoDbFromConfig.client(configuration.getDeletedAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", configuration.getCacheClusterConfiguration(), redisClusterClientResources);
@@ -131,9 +139,20 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
DynamoDbClient migrationRetryAccountsDynamoDb = DynamoDbFromConfig.client(configuration.getMigrationRetryAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
DynamoDbClient pendingAccountsDynamoDbClient = DynamoDbFromConfig.client(configuration.getPendingAccountsDynamoDbConfiguration(),
software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider.create());
AmazonDynamoDB deletedAccountsLockDynamoDbClient = AmazonDynamoDBClientBuilder.standard()
.withRegion(configuration.getDeletedAccountsLockDynamoDbConfiguration().getRegion())
.withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) configuration.getDeletedAccountsLockDynamoDbConfiguration().getClientExecutionTimeout().toMillis()))
.withRequestTimeout((int) configuration.getDeletedAccountsLockDynamoDbConfiguration().getClientRequestTimeout().toMillis()))
.withCredentials(InstanceProfileCredentialsProvider.getInstance())
.build();
DeletedAccounts deletedAccounts = new DeletedAccounts(deletedAccountsDynamoDbClient, configuration.getDeletedAccountsDynamoDbConfiguration().getTableName(), configuration.getDeletedAccountsDynamoDbConfiguration().getNeedsReconciliationIndexName());
MigrationDeletedAccounts migrationDeletedAccounts = new MigrationDeletedAccounts(migrationDeletedAccountsDynamoDb, configuration.getMigrationDeletedAccountsDynamoDbConfiguration().getTableName());
MigrationRetryAccounts migrationRetryAccounts = new MigrationRetryAccounts(migrationRetryAccountsDynamoDb, configuration.getMigrationRetryAccountsDynamoDbConfiguration().getTableName());
VerificationCodeStore pendingAccounts = new VerificationCodeStore(pendingAccountsDynamoDbClient, configuration.getPendingAccountsDynamoDbConfiguration().getTableName());
Accounts accounts = new Accounts(accountDatabase);
AccountsDynamoDb accountsDynamoDb = new AccountsDynamoDb(accountsDynamoDbClient, accountsDynamoDbAsyncClient, accountsDynamoDbMigrationThreadPool, configuration.getAccountsDynamoDbConfiguration().getTableName(), configuration.getAccountsDynamoDbConfiguration().getPhoneNumberTableName(), migrationDeletedAccounts, migrationRetryAccounts);
@@ -155,7 +174,9 @@ public class DeleteUserCommand extends EnvironmentCommand<WhisperServerConfigura
ReportMessageDynamoDb reportMessageDynamoDb = new ReportMessageDynamoDb(reportMessagesDynamoDb, configuration.getReportMessageDynamoDbConfiguration().getTableName());
ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, Metrics.globalRegistry);
MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, pushLatencyManager, reportMessageManager);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
DeletedAccountsManager deletedAccountsManager = new DeletedAccountsManager(deletedAccounts, deletedAccountsLockDynamoDbClient, configuration.getDeletedAccountsLockDynamoDbConfiguration().getTableName());
StoredVerificationCodeManager pendingAccountsManager = new StoredVerificationCodeManager(pendingAccounts);
AccountsManager accountsManager = new AccountsManager(accounts, accountsDynamoDb, cacheCluster, deletedAccountsManager, directoryQueue, keysDynamoDb, messagesManager, usernamesManager, profilesManager, pendingAccountsManager, secureStorageClient, secureBackupClient, experimentEnrollmentManager, dynamicConfigurationManager);
for (String user: users) {
Optional<Account> account = accountsManager.get(user);

View File

@@ -13,7 +13,6 @@ import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.configuration.DatabaseConfiguration;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.FaultTolerantDatabase;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import io.dropwizard.cli.ConfiguredCommand;
import io.dropwizard.setup.Bootstrap;
@@ -38,14 +37,10 @@ public class VacuumCommand extends ConfiguredCommand<WhisperServerConfiguration>
FaultTolerantDatabase accountDatabase = new FaultTolerantDatabase("account_database_vacuum", accountJdbi, accountDbConfig.getCircuitBreakerConfiguration());
Accounts accounts = new Accounts(accountDatabase);
PendingAccounts pendingAccounts = new PendingAccounts(accountDatabase);
logger.info("Vacuuming accounts...");
accounts.vacuum();
logger.info("Vacuuming pending_accounts...");
pendingAccounts.vacuum();
Thread.sleep(3000);
System.exit(0);
}

View File

@@ -1 +1,2 @@
org.whispersystems.textsecuregcm.metrics.JsonMetricsReporterFactory
org.whispersystems.textsecuregcm.metrics.SignalDatadogReporterFactory

View File

@@ -375,4 +375,10 @@
</addColumn>
</changeSet>
<changeSet id="25" author="chris">
<addColumn tableName="accounts">
<column name="version" type="int" defaultValue="0"/>
</addColumn>
</changeSet>
</databaseChangeLog>

View File

@@ -0,0 +1,8 @@
-- keys: lock_key
-- argv: lock_value
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end

View File

@@ -0,0 +1,30 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
class StoredVerificationCodeTest {
@ParameterizedTest
@MethodSource
void isValid(final StoredVerificationCode storedVerificationCode, final String code, final boolean expectValid) {
assertEquals(expectValid, storedVerificationCode.isValid(code));
}
private static Stream<Arguments> isValid() {
return Stream.of(
Arguments.of(new StoredVerificationCode("code", System.currentTimeMillis(), null, null), "code", true),
Arguments.of(new StoredVerificationCode("code", System.currentTimeMillis(), null, null), "incorrect", false),
Arguments.of(new StoredVerificationCode("", System.currentTimeMillis(), null, null), "", false)
);
}
}

View File

@@ -5,6 +5,16 @@
package org.whispersystems.textsecuregcm.metrics;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.InvalidProtocolBufferException;
import com.vdurmont.semver4j.Semver;
@@ -13,8 +23,18 @@ import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
@@ -25,9 +45,11 @@ import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.uri.UriTemplate;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
@@ -38,33 +60,7 @@ import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFa
import org.whispersystems.websocket.messages.protobuf.SubProtocol;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyIterable;
import static org.mockito.ArgumentMatchers.anyVararg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(JUnitParamsRunner.class)
public class MetricsRequestEventListenerTest {
class MetricsRequestEventListenerTest {
private MeterRegistry meterRegistry;
private Counter counter;
@@ -72,8 +68,8 @@ public class MetricsRequestEventListenerTest {
private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP;
@Before
public void setup() {
@BeforeEach
void setup() {
meterRegistry = mock(MeterRegistry.class);
counter = mock(Counter.class);
listener = new MetricsRequestEventListener(TRAFFIC_SOURCE, meterRegistry);
@@ -81,7 +77,7 @@ public class MetricsRequestEventListenerTest {
@Test
@SuppressWarnings("unchecked")
public void testOnEvent() {
void testOnEvent() {
final String path = "/test";
final int statusCode = 200;
@@ -125,19 +121,7 @@ public class MetricsRequestEventListenerTest {
}
@Test
public void testGetPathTemplate() {
final UriTemplate firstComponent = new UriTemplate("/first");
final UriTemplate secondComponent = new UriTemplate("/second");
final UriTemplate thirdComponent = new UriTemplate("/{param}/{moreDifferentParam}");
final ExtendedUriInfo uriInfo = mock(ExtendedUriInfo.class);
when(uriInfo.getMatchedTemplates()).thenReturn(Arrays.asList(thirdComponent, secondComponent, firstComponent));
assertEquals("/first/second/{param}/{moreDifferentParam}", MetricsRequestEventListener.getPathTemplate(uriInfo));
}
@Test
public void testActualRouteMessageSuccess() throws InvalidProtocolBufferException {
void testActualRouteMessageSuccess() throws InvalidProtocolBufferException {
MetricsApplicationEventListener applicationEventListener = mock(MetricsApplicationEventListener.class);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
@@ -197,7 +181,7 @@ public class MetricsRequestEventListenerTest {
}
@Test
public void testActualRouteMessageSuccessNoUserAgent() throws InvalidProtocolBufferException {
void testActualRouteMessageSuccessNoUserAgent() throws InvalidProtocolBufferException {
MetricsApplicationEventListener applicationEventListener = mock(MetricsApplicationEventListener.class);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
@@ -254,9 +238,9 @@ public class MetricsRequestEventListenerTest {
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized")));
}
@Test
@Parameters(method = "argumentsForTestRecordDesktopOperatingSystem")
public void testRecordDesktopOperatingSystem(final UserAgent userAgent, final String expectedOperatingSystem) {
@ParameterizedTest
@MethodSource
void testRecordDesktopOperatingSystem(final UserAgent userAgent, final String expectedOperatingSystem) {
when(meterRegistry.counter(eq(MetricsRequestEventListener.DESKTOP_REQUEST_COUNTER_NAME), (String)any())).thenReturn(counter);
listener.recordDesktopOperatingSystem(userAgent);
@@ -271,20 +255,20 @@ public class MetricsRequestEventListenerTest {
}
}
private static Object argumentsForTestRecordDesktopOperatingSystem() {
return new Object[] {
new Object[] { new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux"), "linux" },
new Object[] { new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "macOS"), "macos" },
new Object[] { new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Windows"), "windows" },
new Object[] { new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3")), null },
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25"), null },
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)"), null },
};
private static Stream<Arguments> testRecordDesktopOperatingSystem() {
return Stream.of(
Arguments.of( new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux"), "linux" ),
Arguments.of( new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "macOS"), "macos" ),
Arguments.of( new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Windows"), "windows" ),
Arguments.of( new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3")), null ),
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25"), null ),
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)"), null )
);
}
@Test
@Parameters(method = "argumentsForTestRecordAndroidSdkVersion")
public void testRecordAndroidSdkVersion(final UserAgent userAgent, final String expectedSdkVersion) {
@ParameterizedTest
@MethodSource
void testRecordAndroidSdkVersion(final UserAgent userAgent, final String expectedSdkVersion) {
when(meterRegistry.counter(eq(MetricsRequestEventListener.ANDROID_REQUEST_COUNTER_NAME), (String)any())).thenReturn(counter);
listener.recordAndroidSdkVersion(userAgent);
@@ -299,21 +283,21 @@ public class MetricsRequestEventListenerTest {
}
}
private static Object argumentsForTestRecordAndroidSdkVersion() {
return new Object[] {
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/1"), null },
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25"), "25" },
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/700000"), null },
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/"), null },
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), null), null },
new Object[] { new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux"), null },
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)"), null }
};
private static Stream<Arguments> testRecordAndroidSdkVersion() {
return Stream.of(
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/1"), null ),
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25"), "25" ),
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/700000"), null ),
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/"), null ),
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), null), null ),
Arguments.of( new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux"), null ),
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)"), null )
);
}
@Test
@Parameters(method = "argumentsForTestRecordIosVersion")
public void testRecordIosVersion(final UserAgent userAgent, final String expectedIosVersion) {
@ParameterizedTest
@MethodSource
void testRecordIosVersion(final UserAgent userAgent, final String expectedIosVersion) {
when(meterRegistry.counter(eq(MetricsRequestEventListener.IOS_REQUEST_COUNTER_NAME), (String)any())).thenReturn(counter);
listener.recordIosVersion(userAgent);
@@ -328,16 +312,16 @@ public class MetricsRequestEventListenerTest {
}
}
private static Object argumentsForTestRecordIosVersion() {
return new Object[] {
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/14.2"), "14.2" },
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)"), "12.2" },
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0")), null },
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/bogus"), null },
new Object[] { new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS bogus; Scale/3.00)"), null },
new Object[] { new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25"), null },
new Object[] { new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux"), null }
};
private static Stream<Arguments> testRecordIosVersion() {
return Stream.of(
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/14.2"), "14.2" ),
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS 12.2; Scale/3.00)"), "12.2" ),
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0")), null ),
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "iOS/bogus"), null ),
Arguments.of( new UserAgent(ClientPlatform.IOS, new Semver("3.9.0"), "(iPhone; iOS bogus; Scale/3.00)"), null ),
Arguments.of( new UserAgent(ClientPlatform.ANDROID, new Semver("4.68.3"), "Android/25"), null ),
Arguments.of( new UserAgent(ClientPlatform.DESKTOP, new Semver("1.2.3"), "Linux"), null )
);
}
private static SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor<ByteBuffer> responseCaptor) throws InvalidProtocolBufferException {

View File

@@ -5,120 +5,123 @@
package org.whispersystems.textsecuregcm.sqs;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.Account;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(JUnitParamsRunner.class)
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.Account;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
public class DirectoryQueueTest {
@Test
@Parameters(method = "argumentsForTestRefreshRegisteredUser")
public void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) {
final AmazonSQS sqs = mock(AmazonSQS.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
private SqsAsyncClient sqsAsyncClient;
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(accountEnabled);
when(account.isDiscoverableByPhoneNumber()).thenReturn(accountDiscoverableByPhoneNumber);
@BeforeEach
void setUp() {
sqsAsyncClient = mock(SqsAsyncClient.class);
directoryQueue.refreshRegisteredUser(account);
when(sqsAsyncClient.sendMessage(any(SendMessageRequest.class)))
.thenReturn(CompletableFuture.completedFuture(SendMessageResponse.builder().build()));
}
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs).sendMessageBatch(requestCaptor.capture());
@ParameterizedTest
@MethodSource("argumentsForTestRefreshRegisteredUser")
void testRefreshRegisteredUser(final boolean accountEnabled, final boolean accountDiscoverableByPhoneNumber, final String expectedAction) {
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqsAsyncClient);
assertEquals(1, requestCaptor.getValue().getEntries().size());
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(accountEnabled);
when(account.isDiscoverableByPhoneNumber()).thenReturn(accountDiscoverableByPhoneNumber);
final Map<String, MessageAttributeValue> messageAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(expectedAction), messageAttributes.get("action"));
directoryQueue.refreshAccount(account);
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
verify(sqsAsyncClient).sendMessage(requestCaptor.capture());
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue(expectedAction).build(),
requestCaptor.getValue().messageAttributes().get("action"));
}
@SuppressWarnings("unused")
private static Stream<Arguments> argumentsForTestRefreshRegisteredUser() {
return Stream.of(
Arguments.of(true, true, "add"),
Arguments.of(true, false, "delete"),
Arguments.of(false, true, "delete"),
Arguments.of(false, false, "delete"));
}
@Test
void testSendMessageMultipleQueues() {
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqsAsyncClient);
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(true);
when(account.isDiscoverableByPhoneNumber()).thenReturn(true);
directoryQueue.refreshAccount(account);
final ArgumentCaptor<SendMessageRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
verify(sqsAsyncClient, times(2)).sendMessage(requestCaptor.capture());
for (final SendMessageRequest sendMessageRequest : requestCaptor.getAllValues()) {
assertEquals(MessageAttributeValue.builder().dataType("String").stringValue("add").build(),
sendMessageRequest.messageAttributes().get("action"));
}
}
@Test
public void testRefreshBatch() {
final AmazonSQS sqs = mock(AmazonSQS.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqs);
@Test
void testStop() {
final CompletableFuture<SendMessageResponse> sendMessageFuture = new CompletableFuture<>();
when(sqsAsyncClient.sendMessage(any(SendMessageRequest.class))).thenReturn(sendMessageFuture);
final Account discoverableAccount = mock(Account.class);
when(discoverableAccount.getNumber()).thenReturn("+18005556543");
when(discoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
when(discoverableAccount.isEnabled()).thenReturn(true);
when(discoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(true);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://test"), sqsAsyncClient);
final Account undiscoverableAccount = mock(Account.class);
when(undiscoverableAccount.getNumber()).thenReturn("+18005550987");
when(undiscoverableAccount.getUuid()).thenReturn(UUID.randomUUID());
when(undiscoverableAccount.isEnabled()).thenReturn(true);
when(undiscoverableAccount.isDiscoverableByPhoneNumber()).thenReturn(false);
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(true);
when(account.isDiscoverableByPhoneNumber()).thenReturn(true);
directoryQueue.refreshRegisteredUsers(List.of(discoverableAccount, undiscoverableAccount));
directoryQueue.refreshAccount(account);
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs).sendMessageBatch(requestCaptor.capture());
final CompletableFuture<Boolean> stopFuture = CompletableFuture.supplyAsync(() -> {
try {
directoryQueue.stop();
return true;
} catch (final Exception e) {
return false;
}
});
assertEquals(2, requestCaptor.getValue().getEntries().size());
assertThrows(TimeoutException.class, () -> stopFuture.get(1, TimeUnit.SECONDS),
"Directory queue should not finish shutting down until all outstanding requests are resolved");
final Map<String, MessageAttributeValue> discoverableAccountAttributes = requestCaptor.getValue().getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getNumber()), discoverableAccountAttributes.get("id"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(discoverableAccount.getUuid().toString()), discoverableAccountAttributes.get("uuid"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), discoverableAccountAttributes.get("action"));
final Map<String, MessageAttributeValue> undiscoverableAccountAttributes = requestCaptor.getValue().getEntries().get(1).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getNumber()), undiscoverableAccountAttributes.get("id"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue(undiscoverableAccount.getUuid().toString()), undiscoverableAccountAttributes.get("uuid"));
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("delete"), undiscoverableAccountAttributes.get("action"));
}
@Test
public void testSendMessageMultipleQueues() {
final AmazonSQS sqs = mock(AmazonSQS.class);
final DirectoryQueue directoryQueue = new DirectoryQueue(List.of("sqs://first", "sqs://second"), sqs);
final Account account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005556543");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.isEnabled()).thenReturn(true);
when(account.isDiscoverableByPhoneNumber()).thenReturn(true);
directoryQueue.refreshRegisteredUser(account);
final ArgumentCaptor<SendMessageBatchRequest> requestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
verify(sqs, times(2)).sendMessageBatch(requestCaptor.capture());
for (final SendMessageBatchRequest sendMessageBatchRequest : requestCaptor.getAllValues()) {
assertEquals(1, requestCaptor.getValue().getEntries().size());
final Map<String, MessageAttributeValue> messageAttributes = sendMessageBatchRequest.getEntries().get(0).getMessageAttributes();
assertEquals(new MessageAttributeValue().withDataType("String").withStringValue("add"), messageAttributes.get("action"));
}
}
@SuppressWarnings("unused")
private Object argumentsForTestRefreshRegisteredUser() {
return new Object[] {
new Object[] { true, true, "add" },
new Object[] { true, false, "delete" },
new Object[] { false, true, "delete" },
new Object[] { false, false, "delete" }
};
}
sendMessageFuture.complete(SendMessageResponse.builder().build());
assertTrue(stopFuture.join());
}
}

View File

@@ -5,22 +5,25 @@
package org.whispersystems.textsecuregcm.storage;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static org.junit.Assert.*;
import static org.junit.Assert.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterTest {
private static final UUID FIRST_UUID = UUID.fromString("82339e80-81cd-48e2-9ed2-ccd5dd262ad9");
@@ -32,6 +35,8 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
private AccountsManager accountsManager;
private AccountDatabaseCrawlerListener listener;
private DynamicConfigurationManager dynamicConfigurationManager;
private AccountDatabaseCrawler accountDatabaseCrawler;
private static final int CHUNK_SIZE = 1;
@@ -47,16 +52,22 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
accountsManager = mock(AccountsManager.class);
listener = mock(AccountDatabaseCrawlerListener.class);
dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
when(firstAccount.getUuid()).thenReturn(FIRST_UUID);
when(secondAccount.getUuid()).thenReturn(SECOND_UUID);
when(accountsManager.getAllFrom(CHUNK_SIZE)).thenReturn(List.of(firstAccount));
when(accountsManager.getAllFrom(FIRST_UUID, CHUNK_SIZE))
.thenReturn(List.of(secondAccount))
.thenReturn(Collections.emptyList());
when(accountsManager.getAllFrom(CHUNK_SIZE)).thenReturn(new AccountCrawlChunk(List.of(firstAccount), FIRST_UUID));
when(accountsManager.getAllFrom(any(UUID.class), eq(CHUNK_SIZE)))
.thenReturn(new AccountCrawlChunk(List.of(secondAccount), SECOND_UUID))
.thenReturn(new AccountCrawlChunk(Collections.emptyList(), null));
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getAccountsDynamoDbMigrationConfiguration()).thenReturn(mock(DynamicAccountsDynamoDbMigrationConfiguration.class));
final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(getRedisCluster());
accountDatabaseCrawler = new AccountDatabaseCrawler(accountsManager, crawlerCache, List.of(listener), CHUNK_SIZE, CHUNK_INTERVAL_MS);
accountDatabaseCrawler = new AccountDatabaseCrawler(accountsManager, crawlerCache, List.of(listener), CHUNK_SIZE, CHUNK_INTERVAL_MS, dynamicConfigurationManager);
}
@Test

View File

@@ -1,21 +1,24 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.fasterxml.uuid.UUIDComparator;
import io.github.resilience4j.circuitbreaker.CallNotPermittedException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
@@ -47,6 +50,7 @@ import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
import software.amazon.awssdk.services.dynamodb.model.ScanResponse;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest;
import software.amazon.awssdk.services.dynamodb.model.TransactionConflictException;
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
class AccountsDynamoDbTest {
@@ -54,7 +58,7 @@ class AccountsDynamoDbTest {
private static final String ACCOUNTS_TABLE_NAME = "accounts_test";
private static final String NUMBERS_TABLE_NAME = "numbers_test";
private static final String MIGRATION_DELETED_ACCOUNTS_TABLE_NAME = "migration_deleted_accounts_test";
private static final String MIGRATION_RETRY_ACCOUNTS_TABLE_NAME = "miration_retry_accounts_test";
private static final String MIGRATION_RETRY_ACCOUNTS_TABLE_NAME = "migration_retry_accounts_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
@@ -208,6 +212,10 @@ class AccountsDynamoDbTest {
verifyStoredState("+14151112222", account.getUuid(), account);
account.setProfileName("name");
accountsDynamoDb.update(account);
UUID secondUuid = UUID.randomUUID();
device = generateDevice(1);
@@ -249,13 +257,93 @@ class AccountsDynamoDbTest {
assertThatThrownBy(() -> accountsDynamoDb.update(unknownAccount)).isInstanceOfAny(ConditionalCheckFailedException.class);
account.setDynamoDbMigrationVersion(5);
account.setProfileName("name");
accountsDynamoDb.update(account);
assertThat(account.getVersion()).isEqualTo(2);
verifyStoredState("+14151112222", account.getUuid(), account);
account.setVersion(1);
assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
account.setVersion(2);
account.setProfileName("name2");
accountsDynamoDb.update(account);
verifyStoredState("+14151112222", account.getUuid(), account);
}
@Test
void testUpdateWithMockTransactionConflictException() {
final DynamoDbClient dynamoDbClient = mock(DynamoDbClient.class);
accountsDynamoDb = new AccountsDynamoDb(dynamoDbClient, mock(DynamoDbAsyncClient.class),
new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()),
dynamoDbExtension.getTableName(), NUMBERS_TABLE_NAME, mock(MigrationDeletedAccounts.class),
mock(MigrationRetryAccounts.class));
when(dynamoDbClient.updateItem(any(UpdateItemRequest.class)))
.thenThrow(TransactionConflictException.class);
Device device = generateDevice (1 );
Account account = generateAccount("+14151112222", UUID.randomUUID(), Collections.singleton(device));
assertThatThrownBy(() -> accountsDynamoDb.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class);
}
@Test
void testRetrieveFrom() {
List<Account> users = new ArrayList<>();
for (int i = 1; i <= 100; i++) {
Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID());
users.add(account);
accountsDynamoDb.create(account);
}
users.sort((account, t1) -> UUIDComparator.staticCompare(account.getUuid(), t1.getUuid()));
AccountCrawlChunk retrieved = accountsDynamoDb.getAllFromStart(10, 1);
assertThat(retrieved.getAccounts().size()).isEqualTo(10);
for (int i = 0; i < retrieved.getAccounts().size(); i++) {
final Account retrievedAccount = retrieved.getAccounts().get(i);
final Account expectedAccount = users.stream()
.filter(account -> account.getUuid().equals(retrievedAccount.getUuid()))
.findAny()
.orElseThrow();
verifyStoredState(expectedAccount.getNumber(), expectedAccount.getUuid(), retrievedAccount, expectedAccount);
users.remove(expectedAccount);
}
for (int j = 0; j < 9; j++) {
retrieved = accountsDynamoDb.getAllFrom(retrieved.getLastUuid().orElseThrow(), 10, 1);
assertThat(retrieved.getAccounts().size()).isEqualTo(10);
for (int i = 0; i < retrieved.getAccounts().size(); i++) {
final Account retrievedAccount = retrieved.getAccounts().get(i);
final Account expectedAccount = users.stream()
.filter(account -> account.getUuid().equals(retrievedAccount.getUuid()))
.findAny()
.orElseThrow();
verifyStoredState(expectedAccount.getNumber(), expectedAccount.getUuid(), retrievedAccount, expectedAccount);
users.remove(expectedAccount);
}
}
assertThat(users).isEmpty();
}
@Test
void testDelete() {
final Device deletedDevice = generateDevice (1);
@@ -411,7 +499,7 @@ class AccountsDynamoDbTest {
assertThat(migrated).isFalse();
verifyStoredState("+14151112222", firstUuid, account);
account.setDynamoDbMigrationVersion(account.getDynamoDbMigrationVersion() + 1);
account.setVersion(account.getVersion() + 1);
migrated = accountsDynamoDb.migrate(account).get();
@@ -423,7 +511,7 @@ class AccountsDynamoDbTest {
SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(), "testSignature-" + random.nextInt());
return new Device(id, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(), "testSalt-" + random.nextInt(),
"testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(), random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(), "testUserAgent-" + random.nextInt() , 0, new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
false));
random.nextBoolean(), random.nextBoolean()));
}
private Account generateAccount(String number, UUID uuid) {
@@ -452,8 +540,8 @@ class AccountsDynamoDbTest {
String data = new String(get.item().get(AccountsDynamoDb.ATTR_ACCOUNT_DATA).b().asByteArray(), StandardCharsets.UTF_8);
assertThat(data).isNotEmpty();
assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_MIGRATION_VERSION, -1))
.isEqualTo(expecting.getDynamoDbMigrationVersion());
assertThat(AttributeValues.getInt(get.item(), AccountsDynamoDb.ATTR_VERSION, -1))
.isEqualTo(expecting.getVersion());
Account result = AccountsDynamoDb.fromItem(get.item());
verifyStoredState(number, uuid, result, expecting);
@@ -466,6 +554,7 @@ class AccountsDynamoDbTest {
assertThat(result.getNumber()).isEqualTo(number);
assertThat(result.getLastSeen()).isEqualTo(expecting.getLastSeen());
assertThat(result.getUuid()).isEqualTo(uuid);
assertThat(result.getVersion()).isEqualTo(expecting.getVersion());
assertThat(Arrays.equals(result.getUnidentifiedAccessKey().get(), expecting.getUnidentifiedAccessKey().get())).isTrue();
for (Device expectingDevice : expecting.getDevices()) {

View File

@@ -0,0 +1,275 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.opentable.db.postgres.embedded.LiquibasePreparer;
import com.opentable.db.postgres.junit5.EmbeddedPostgresExtension;
import com.opentable.db.postgres.junit5.PreparedDbExtension;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicAccountsDynamoDbMigrationConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.securebackup.SecureBackupClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.CreateTableRequest;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
class AccountsManagerConcurrentModificationIntegrationTest {
@RegisterExtension
static PreparedDbExtension db = EmbeddedPostgresExtension.preparedDatabase(LiquibasePreparer.forClasspathLocation("accountsdb.xml"));
private static final String ACCOUNTS_TABLE_NAME = "accounts_test";
private static final String NUMBERS_TABLE_NAME = "numbers_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
.tableName(ACCOUNTS_TABLE_NAME)
.hashKey(AccountsDynamoDb.KEY_ACCOUNT_UUID)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(AccountsDynamoDb.KEY_ACCOUNT_UUID)
.attributeType(ScalarAttributeType.B)
.build())
.build();
private Accounts accounts;
private AccountsDynamoDb accountsDynamoDb;
private AccountsManager accountsManager;
private RedisAdvancedClusterCommands<String, String> commands;
private Executor mutationExecutor = new ThreadPoolExecutor(20, 20, 5, TimeUnit.SECONDS, new LinkedBlockingDeque<>(20));
@BeforeEach
void setup() {
{
CreateTableRequest createNumbersTableRequest = CreateTableRequest.builder()
.tableName(NUMBERS_TABLE_NAME)
.keySchema(KeySchemaElement.builder()
.attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164)
.keyType(KeyType.HASH)
.build())
.attributeDefinitions(AttributeDefinition.builder()
.attributeName(AccountsDynamoDb.ATTR_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S)
.build())
.provisionedThroughput(DynamoDbExtension.DEFAULT_PROVISIONED_THROUGHPUT)
.build();
dynamoDbExtension.getDynamoDbClient().createTable(createNumbersTableRequest);
}
accountsDynamoDb = new AccountsDynamoDb(
dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getDynamoDbAsyncClient(),
new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingDeque<>()),
dynamoDbExtension.getTableName(),
NUMBERS_TABLE_NAME,
mock(MigrationDeletedAccounts.class),
mock(MigrationRetryAccounts.class));
{
final CircuitBreakerConfiguration circuitBreakerConfiguration = new CircuitBreakerConfiguration();
circuitBreakerConfiguration.setIgnoredExceptions(List.of("org.whispersystems.textsecuregcm.storage.ContestedOptimisticLockException"));
FaultTolerantDatabase faultTolerantDatabase = new FaultTolerantDatabase("accountsTest",
Jdbi.create(db.getTestDatabase()),
circuitBreakerConfiguration);
accounts = new Accounts(faultTolerantDatabase);
}
{
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
final DynamicAccountsDynamoDbMigrationConfiguration config = dynamicConfiguration
.getAccountsDynamoDbMigrationConfiguration();
config.setDeleteEnabled(true);
config.setReadEnabled(true);
config.setWriteEnabled(true);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), anyString())).thenReturn(true);
commands = mock(RedisAdvancedClusterCommands.class);
accountsManager = new AccountsManager(
accounts,
accountsDynamoDb,
RedisClusterHelper.buildMockRedisCluster(commands),
mock(DeletedAccountsManager.class),
mock(DirectoryQueue.class),
mock(KeysDynamoDb.class),
mock(MessagesManager.class),
mock(UsernamesManager.class),
mock(ProfilesManager.class),
mock(StoredVerificationCodeManager.class),
mock(SecureStorageClient.class),
mock(SecureBackupClient.class),
experimentEnrollmentManager,
dynamicConfigurationManager);
}
}
@Test
void testConcurrentUpdate() throws IOException {
final UUID uuid;
{
final Account account = accountsManager.update(
accountsManager.create("+14155551212", "password", null, new AccountAttributes()),
a -> {
a.setUnidentifiedAccessKey(new byte[16]);
final Random random = new Random();
final SignedPreKey signedPreKey = new SignedPreKey(random.nextInt(), "testPublicKey-" + random.nextInt(),
"testSignature-" + random.nextInt());
a.removeDevice(1);
a.addDevice(new Device(1, "testName-" + random.nextInt(), "testAuthToken-" + random.nextInt(),
"testSalt-" + random.nextInt(),
"testGcmId-" + random.nextInt(), "testApnId-" + random.nextInt(), "testVoipApnId-" + random.nextInt(),
random.nextBoolean(), random.nextInt(), signedPreKey, random.nextInt(), random.nextInt(),
"testUserAgent-" + random.nextInt(), 0,
new Device.DeviceCapabilities(random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean(), random.nextBoolean(),
random.nextBoolean(), random.nextBoolean())));
});
uuid = account.getUuid();
}
final String profileName = "name";
final String avatar = "avatar";
final boolean discoverableByPhoneNumber = false;
final String currentProfileVersion = "cpv";
final String identityKey = "ikey";
final byte[] unidentifiedAccessKey = new byte[]{1};
final String pin = "1234";
final String registrationLock = "reglock";
final AuthenticationCredentials credentials = new AuthenticationCredentials(registrationLock);
final boolean unrestrictedUnidentifiedAccess = true;
final long lastSeen = Instant.now().getEpochSecond();
CompletableFuture.allOf(
modifyAccount(uuid, account -> account.setProfileName(profileName)),
modifyAccount(uuid, account -> account.setAvatar(avatar)),
modifyAccount(uuid, account -> account.setDiscoverableByPhoneNumber(discoverableByPhoneNumber)),
modifyAccount(uuid, account -> account.setCurrentProfileVersion(currentProfileVersion)),
modifyAccount(uuid, account -> account.setIdentityKey(identityKey)),
modifyAccount(uuid, account -> account.setUnidentifiedAccessKey(unidentifiedAccessKey)),
modifyAccount(uuid, account -> account.setPin(pin)),
modifyAccount(uuid, account -> account.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt())),
modifyAccount(uuid, account -> account.setUnrestrictedUnidentifiedAccess(unrestrictedUnidentifiedAccess)),
modifyDevice(uuid, Device.MASTER_ID, device-> device.setLastSeen(lastSeen)),
modifyDevice(uuid, Device.MASTER_ID, device-> device.setName("deviceName"))
).join();
final Account managerAccount = accountsManager.get(uuid).get();
final Account dbAccount = accounts.get(uuid).get();
final Account dynamoAccount = accountsDynamoDb.get(uuid).get();
final Account redisAccount = getLastAccountFromRedisMock(commands);
Stream.of(
new Pair<>("manager", managerAccount),
new Pair<>("db", dbAccount),
new Pair<>("dynamo", dynamoAccount),
new Pair<>("redis", redisAccount)
).forEach(pair ->
verifyAccount(pair.first(), pair.second(), profileName, avatar, discoverableByPhoneNumber,
currentProfileVersion, identityKey, unidentifiedAccessKey, pin, registrationLock,
unrestrictedUnidentifiedAccess, lastSeen)
);
}
private Account getLastAccountFromRedisMock(RedisAdvancedClusterCommands<String, String> commands) throws IOException {
ArgumentCaptor<String> redisSetArgumentCapture = ArgumentCaptor.forClass(String.class);
verify(commands, atLeast(20)).set(anyString(), redisSetArgumentCapture.capture());
return JsonHelpers.fromJson(redisSetArgumentCapture.getValue(), Account.class);
}
private void verifyAccount(final String name, final Account account, final String profileName, final String avatar, final boolean discoverableByPhoneNumber, final String currentProfileVersion, final String identityKey, final byte[] unidentifiedAccessKey, final String pin, final String clientRegistrationLock, final boolean unrestrictedUnidentifiedAcces, final long lastSeen) {
assertAll(name,
() -> assertEquals(profileName, account.getProfileName()),
() -> assertEquals(avatar, account.getAvatar()),
() -> assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()),
() -> assertEquals(currentProfileVersion, account.getCurrentProfileVersion().get()),
() -> assertEquals(identityKey, account.getIdentityKey()),
() -> assertArrayEquals(unidentifiedAccessKey, account.getUnidentifiedAccessKey().get()),
() -> assertTrue(account.getRegistrationLock().verify(clientRegistrationLock, pin)),
() -> assertEquals(unrestrictedUnidentifiedAcces, account.isUnrestrictedUnidentifiedAccess())
);
}
private CompletableFuture<?> modifyAccount(final UUID uuid, final Consumer<Account> accountMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.get(uuid).get();
accountsManager.update(account, accountMutation);
}, mutationExecutor);
}
private CompletableFuture<?> modifyDevice(final UUID uuid, final long deviceId, final Consumer<Device> deviceMutation) {
return CompletableFuture.runAsync(() -> {
final Account account = accountsManager.get(uuid).get();
accountsManager.updateDevice(account, deviceId, deviceMutation);
}, mutationExecutor);
}
}

View File

@@ -0,0 +1,142 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.api.function.Executable;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.Projection;
import software.amazon.awssdk.services.dynamodb.model.ProjectionType;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import java.lang.Thread.State;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
class DeletedAccountsManagerTest {
private static final String NEEDS_RECONCILIATION_INDEX_NAME = "needs_reconciliation_test";
@RegisterExtension
static final DynamoDbExtension DELETED_ACCOUNTS_DYNAMODB_EXTENSION = DynamoDbExtension.builder()
.tableName("deleted_accounts_test")
.hashKey(DeletedAccounts.KEY_ACCOUNT_E164)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(DeletedAccounts.KEY_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S).build())
.attributeDefinition(AttributeDefinition.builder()
.attributeName(DeletedAccounts.ATTR_NEEDS_CDS_RECONCILIATION)
.attributeType(ScalarAttributeType.N)
.build())
.globalSecondaryIndex(GlobalSecondaryIndex.builder()
.indexName(NEEDS_RECONCILIATION_INDEX_NAME)
.keySchema(KeySchemaElement.builder().attributeName(DeletedAccounts.KEY_ACCOUNT_E164).keyType(KeyType.HASH).build(),
KeySchemaElement.builder().attributeName(DeletedAccounts.ATTR_NEEDS_CDS_RECONCILIATION).keyType(KeyType.RANGE).build())
.projection(Projection.builder().projectionType(ProjectionType.INCLUDE).nonKeyAttributes(DeletedAccounts.ATTR_ACCOUNT_UUID).build())
.provisionedThroughput(ProvisionedThroughput.builder().readCapacityUnits(10L).writeCapacityUnits(10L).build())
.build())
.build();
@RegisterExtension
static DynamoDbExtension DELETED_ACCOUNTS_LOCK_DYNAMODB_EXTENSION = DynamoDbExtension.builder()
.tableName("deleted_accounts_lock_test")
.hashKey(DeletedAccounts.KEY_ACCOUNT_E164)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(DeletedAccounts.KEY_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S).build())
.build();
private DeletedAccountsManager deletedAccountsManager;
@BeforeEach
void setUp() {
final DeletedAccounts deletedAccounts = new DeletedAccounts(DELETED_ACCOUNTS_DYNAMODB_EXTENSION.getDynamoDbClient(),
DELETED_ACCOUNTS_DYNAMODB_EXTENSION.getTableName(),
NEEDS_RECONCILIATION_INDEX_NAME);
deletedAccountsManager = new DeletedAccountsManager(deletedAccounts,
DELETED_ACCOUNTS_LOCK_DYNAMODB_EXTENSION.getLegacyDynamoClient(),
DELETED_ACCOUNTS_LOCK_DYNAMODB_EXTENSION.getTableName());
}
@Test
void testReconciliationLockContention() throws ChunkProcessingFailedException, InterruptedException {
final UUID[] uuids = new UUID[3];
final String[] e164s = new String[uuids.length];
for (int i = 0; i < uuids.length; i++) {
uuids[i] = UUID.randomUUID();
e164s[i] = String.format("+1800555%04d", i);
}
final Map<String, UUID> expectedReconciledAccounts = new HashMap<>();
for (int i = 0; i < uuids.length; i++) {
deletedAccountsManager.put(uuids[i], e164s[i]);
expectedReconciledAccounts.put(e164s[i], uuids[i]);
}
final UUID replacedUUID = UUID.randomUUID();
final Map<String, UUID> reconciledAccounts = new HashMap<>();
final Thread putThread = new Thread(() -> {
try {
deletedAccountsManager.put(replacedUUID, e164s[0]);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
},
getClass().getSimpleName() + "-put");
final Thread reconcileThread = new Thread(() -> {
try {
deletedAccountsManager.lockAndReconcileAccounts(uuids.length, deletedAccounts -> {
// We hold the lock for the first account, so a thread trying to operate on that first count should block
// waiting for the lock.
putThread.start();
// Make sure the other thread really does actually block at some point
while (putThread.getState() != State.TIMED_WAITING) {
Thread.yield();
}
deletedAccounts.forEach(pair -> reconciledAccounts.put(pair.second(), pair.first()));
return reconciledAccounts.keySet();
});
} catch (ChunkProcessingFailedException e) {
throw new AssertionError(e);
}
}, getClass().getSimpleName() + "-reconcile");
reconcileThread.start();
assertDoesNotThrow((Executable) reconcileThread::join);
assertDoesNotThrow((Executable) putThread::join);
assertEquals(expectedReconciledAccounts, reconciledAccounts);
// The "put" thread should have completed after the reconciliation thread wrapped up. We can verify that's true by
// reconciling again; the updated account (and only that account) should appear in the "needs reconciliation" list.
deletedAccountsManager.lockAndReconcileAccounts(uuids.length, deletedAccounts -> {
assertEquals(1, deletedAccounts.size());
assertEquals(replacedUUID, deletedAccounts.get(0).first());
assertEquals(e164s[0], deletedAccounts.get(0).second());
return List.of(deletedAccounts.get(0).second());
});
}
}

View File

@@ -0,0 +1,130 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.GlobalSecondaryIndex;
import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement;
import software.amazon.awssdk.services.dynamodb.model.KeyType;
import software.amazon.awssdk.services.dynamodb.model.Projection;
import software.amazon.awssdk.services.dynamodb.model.ProjectionType;
import software.amazon.awssdk.services.dynamodb.model.ProvisionedThroughput;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
class DeletedAccountsTest {
private static final String NEEDS_RECONCILIATION_INDEX_NAME = "needs_reconciliation_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
.tableName("deleted_accounts_test")
.hashKey(DeletedAccounts.KEY_ACCOUNT_E164)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(DeletedAccounts.KEY_ACCOUNT_E164)
.attributeType(ScalarAttributeType.S).build())
.attributeDefinition(AttributeDefinition.builder()
.attributeName(DeletedAccounts.ATTR_NEEDS_CDS_RECONCILIATION)
.attributeType(ScalarAttributeType.N)
.build())
.globalSecondaryIndex(GlobalSecondaryIndex.builder()
.indexName(NEEDS_RECONCILIATION_INDEX_NAME)
.keySchema(KeySchemaElement.builder().attributeName(DeletedAccounts.KEY_ACCOUNT_E164).keyType(KeyType.HASH).build(),
KeySchemaElement.builder().attributeName(DeletedAccounts.ATTR_NEEDS_CDS_RECONCILIATION).keyType(KeyType.RANGE).build())
.projection(Projection.builder().projectionType(ProjectionType.INCLUDE).nonKeyAttributes(DeletedAccounts.ATTR_ACCOUNT_UUID).build())
.provisionedThroughput(ProvisionedThroughput.builder().readCapacityUnits(10L).writeCapacityUnits(10L).build())
.build())
.build();
private DeletedAccounts deletedAccounts;
@BeforeEach
void setUp() {
deletedAccounts = new DeletedAccounts(dynamoDbExtension.getDynamoDbClient(),
dynamoDbExtension.getTableName(),
NEEDS_RECONCILIATION_INDEX_NAME);
}
@Test
void testPutList() {
UUID firstUuid = UUID.randomUUID();
UUID secondUuid = UUID.randomUUID();
UUID thirdUuid = UUID.randomUUID();
String firstNumber = "+14152221234";
String secondNumber = "+14152225678";
String thirdNumber = "+14159998765";
assertTrue(deletedAccounts.listAccountsToReconcile(1).isEmpty());
deletedAccounts.put(firstUuid, firstNumber);
deletedAccounts.put(secondUuid, secondNumber);
deletedAccounts.put(thirdUuid, thirdNumber);
assertEquals(1, deletedAccounts.listAccountsToReconcile(1).size());
assertTrue(deletedAccounts.listAccountsToReconcile(10).containsAll(
List.of(
new Pair<>(firstUuid, firstNumber),
new Pair<>(secondUuid, secondNumber))));
deletedAccounts.markReconciled(List.of(firstNumber, secondNumber));
assertEquals(List.of(new Pair<>(thirdUuid, thirdNumber)), deletedAccounts.listAccountsToReconcile(10));
deletedAccounts.markReconciled(List.of(thirdNumber));
assertTrue(deletedAccounts.listAccountsToReconcile(1).isEmpty());
}
@Test
void testGetAccountsNeedingReconciliation() {
final UUID firstUuid = UUID.randomUUID();
final UUID secondUuid = UUID.randomUUID();
final String firstNumber = "+14152221234";
final String secondNumber = "+14152225678";
final String thirdNumber = "+14159998765";
assertEquals(Collections.emptySet(),
deletedAccounts.getAccountsNeedingReconciliation(List.of(firstNumber, secondNumber, thirdNumber)));
deletedAccounts.put(firstUuid, firstNumber);
deletedAccounts.put(secondUuid, secondNumber);
assertEquals(Set.of(firstNumber, secondNumber),
deletedAccounts.getAccountsNeedingReconciliation(List.of(firstNumber, secondNumber, thirdNumber)));
}
@Test
void testGetAccountsNeedingReconciliationLargeBatch() {
final int itemCount = (DeletedAccounts.GET_BATCH_SIZE * 3) + 1;
final Set<String> expectedAccountsNeedingReconciliation = new HashSet<>(itemCount);
for (int i = 0; i < itemCount; i++) {
final String e164 = String.format("+18000555%04d", i);
deletedAccounts.put(UUID.randomUUID(), e164);
expectedAccountsNeedingReconciliation.add(e164);
}
final Set<String> accountsNeedingReconciliation =
deletedAccounts.getAccountsNeedingReconciliation(expectedAccountsNeedingReconciliation);
assertEquals(expectedAccountsNeedingReconciliation, accountsNeedingReconciliation);
}
}

View File

@@ -70,7 +70,7 @@ public class DeviceTest {
@Parameters(method = "argumentsForTestIsGroupsV2Supported")
public void testIsGroupsV2Supported(final boolean master, final String apnId, final boolean gv2Capability, final boolean gv2_2Capability, final boolean gv2_3Capability, final boolean expectGv2Supported) {
final Device.DeviceCapabilities capabilities = new Device.DeviceCapabilities(gv2Capability, gv2_2Capability, gv2_3Capability, false, false, false,
false);
false, false);
final Device device = new Device(master ? 1 : 2, "test", "auth-token", "salt",
null, apnId, null, false, 1, null, 0, 0, "user-agent", 0, capabilities);

View File

@@ -1,42 +1,40 @@
package org.whispersystems.textsecuregcm.storage;
import com.amazonaws.services.appconfig.AmazonAppConfig;
import com.amazonaws.services.appconfig.model.GetConfigurationRequest;
import com.amazonaws.services.appconfig.model.GetConfigurationResult;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import java.nio.ByteBuffer;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.appconfig.AppConfigClient;
import software.amazon.awssdk.services.appconfig.model.GetConfigurationRequest;
import software.amazon.awssdk.services.appconfig.model.GetConfigurationResponse;
public class DynamicConfigurationManagerTest {
private DynamicConfigurationManager dynamicConfigurationManager;
private AmazonAppConfig appConfig;
private AppConfigClient appConfig;
@Before
public void setup() {
this.appConfig = mock(AmazonAppConfig.class);
this.appConfig = mock(AppConfigClient.class);
this.dynamicConfigurationManager = new DynamicConfigurationManager(appConfig, "foo", "bar", "baz", "poof");
}
@Test
public void testGetConfig() {
ArgumentCaptor<GetConfigurationRequest> captor = ArgumentCaptor.forClass(GetConfigurationRequest.class);
when(appConfig.getConfiguration(captor.capture())).thenReturn(new GetConfigurationResult().withContent(ByteBuffer.wrap("test: true".getBytes()))
.withConfigurationVersion("1"));
when(appConfig.getConfiguration(captor.capture())).thenReturn(
GetConfigurationResponse.builder().content(SdkBytes.fromByteArray("test: true".getBytes())).configurationVersion("1").build());
dynamicConfigurationManager.start();
assertThat(captor.getValue().getApplication()).isEqualTo("foo");
assertThat(captor.getValue().getEnvironment()).isEqualTo("bar");
assertThat(captor.getValue().getConfiguration()).isEqualTo("baz");
assertThat(captor.getValue().getClientId()).isEqualTo("poof");
assertThat(captor.getValue().application()).isEqualTo("foo");
assertThat(captor.getValue().environment()).isEqualTo("bar");
assertThat(captor.getValue().configuration()).isEqualTo("baz");
assertThat(captor.getValue().clientId()).isEqualTo("poof");
assertThat(dynamicConfigurationManager.getConfiguration()).isNotNull();
}

View File

@@ -1,6 +1,13 @@
package org.whispersystems.textsecuregcm.storage;
import com.almworks.sqlite4java.SQLite;
import com.amazonaws.ClientConfiguration;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder;
import com.amazonaws.services.dynamodbv2.local.main.ServerRunner;
import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer;
import java.net.ServerSocket;
@@ -46,6 +53,7 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
private DynamoDbClient dynamoDB2;
private DynamoDbAsyncClient dynamoAsyncDB2;
private AmazonDynamoDB legacyDynamoClient;
private DynamoDbExtension(String tableName, String hashKey, String rangeKey, List<AttributeDefinition> attributeDefinitions, List<GlobalSecondaryIndex> globalSecondaryIndexes, long readCapacityUnits,
long writeCapacityUnits) {
@@ -137,6 +145,11 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
.credentialsProvider(StaticCredentialsProvider.create(
AwsBasicCredentials.create("accessKey", "secretKey")))
.build();
legacyDynamoClient = AmazonDynamoDBClientBuilder.standard()
.withEndpointConfiguration(
new AwsClientBuilder.EndpointConfiguration("http://localhost:" + port, "local-test-region"))
.withCredentials(new AWSStaticCredentialsProvider(new BasicAWSCredentials("accessKey", "secretKey")))
.build();
}
static class DynamoDbExtensionBuilder {
@@ -194,6 +207,10 @@ public class DynamoDbExtension implements BeforeEachCallback, AfterEachCallback
return dynamoAsyncDB2;
}
public AmazonDynamoDB getLegacyDynamoClient() {
return legacyDynamoClient;
}
public String getTableName() {
return tableName;
}

View File

@@ -0,0 +1,169 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.time.Duration;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.Util;
class ManagedPeriodicWorkTest {
private ScheduledExecutorService scheduledExecutorService;
private ManagedPeriodicWorkLock lock;
private TestWork testWork;
@BeforeEach
void setup() {
scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
lock = mock(ManagedPeriodicWorkLock.class);
testWork = new TestWork(lock, Duration.ofMinutes(5), Duration.ofMinutes(5),
scheduledExecutorService);
}
@AfterEach
void teardown() throws Exception {
scheduledExecutorService.shutdown();
assertTrue(scheduledExecutorService.awaitTermination(5, TimeUnit.SECONDS));
}
@Test
void test() throws Exception {
when(lock.claimActiveWork(any(), any())).thenReturn(true);
testWork.start();
synchronized (testWork) {
Util.wait(testWork);
}
testWork.stop();
verify(lock, atLeastOnce()).claimActiveWork(anyString(), any(Duration.class));
verify(lock, atLeastOnce()).releaseActiveWork(anyString());
assertTrue(1 <= testWork.getCount());
}
@Test
void testSlowWorkShutdown() throws Exception {
when(lock.claimActiveWork(any(), any())).thenReturn(true);
testWork.setWorkSleepDuration(Duration.ofSeconds(1));
testWork.start();
synchronized (testWork) {
Util.wait(testWork);
}
long startMillis = System.currentTimeMillis();
testWork.stop();
long runMillis = System.currentTimeMillis() - startMillis;
assertTrue(runMillis > 500);
verify(lock, atLeastOnce()).claimActiveWork(anyString(), any(Duration.class));
verify(lock, atLeastOnce()).releaseActiveWork(anyString());
assertTrue(1 <= testWork.getCount());
}
@Test
void testWorkExceptionReleasesLock() throws Exception {
when(lock.claimActiveWork(any(), any())).thenReturn(true);
testWork = new ExceptionalTestWork(lock, Duration.ofMinutes(5), Duration.ofMinutes(5), scheduledExecutorService);
testWork.setSleepDurationAfterUnexpectedException(Duration.ZERO);
testWork.start();
synchronized (testWork) {
Util.wait(testWork);
}
testWork.stop();
verify(lock, atLeastOnce()).claimActiveWork(anyString(), any(Duration.class));
verify(lock, atLeastOnce()).releaseActiveWork(anyString());
assertEquals(0, testWork.getCount());
}
private static class TestWork extends ManagedPeriodicWork {
private final AtomicInteger workCounter = new AtomicInteger();
private Duration workSleepDuration = Duration.ZERO;
public TestWork(final ManagedPeriodicWorkLock lock, final Duration workerTtl, final Duration runInterval,
final ScheduledExecutorService scheduledExecutorService) {
super(lock, workerTtl, runInterval, scheduledExecutorService);
}
@Override
protected void doPeriodicWork() throws Exception {
notifyStarted();
if (!workSleepDuration.isZero()) {
Util.sleep(workSleepDuration.toMillis());
}
workCounter.incrementAndGet();
}
synchronized void notifyStarted() {
notifyAll();
}
int getCount() {
return workCounter.get();
}
void setWorkSleepDuration(final Duration workSleepDuration) {
this.workSleepDuration = workSleepDuration;
}
}
private static class ExceptionalTestWork extends TestWork {
public ExceptionalTestWork(final ManagedPeriodicWorkLock lock, final Duration workerTtl, final Duration runInterval,
final ScheduledExecutorService scheduledExecutorService) {
super(lock, workerTtl, runInterval, scheduledExecutorService);
}
@Override
protected void doPeriodicWork() throws Exception {
notifyStarted();
throw new RuntimeException();
}
}
}

View File

@@ -0,0 +1,62 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
class StoredVerificationCodeManagerTest {
private VerificationCodeStore verificationCodeStore;
private StoredVerificationCodeManager storedVerificationCodeManager;
@BeforeEach
void setUp() {
verificationCodeStore = mock(VerificationCodeStore.class);
storedVerificationCodeManager = new StoredVerificationCodeManager(verificationCodeStore);
}
@Test
void store() {
final String number = "+18005551234";
final StoredVerificationCode code = mock(StoredVerificationCode.class);
storedVerificationCodeManager.store(number, code);
verify(verificationCodeStore).insert(number, code);
}
@Test
void remove() {
final String number = "+18005551234";
storedVerificationCodeManager.remove(number);
verify(verificationCodeStore).remove(number);
}
@Test
void getCodeForNumber() {
final String number = "+18005551234";
when(verificationCodeStore.findForNumber(number)).thenReturn(Optional.empty());
assertEquals(Optional.empty(), storedVerificationCodeManager.getCodeForNumber(number));
final StoredVerificationCode storedVerificationCode = mock(StoredVerificationCode.class);
when(verificationCodeStore.findForNumber(number)).thenReturn(Optional.of(storedVerificationCode));
assertEquals(Optional.of(storedVerificationCode), storedVerificationCodeManager.getCodeForNumber(number));
}
}

View File

@@ -0,0 +1,91 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition;
import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType;
import java.util.Objects;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
class VerificationCodeStoreTest {
private VerificationCodeStore verificationCodeStore;
private static final String TABLE_NAME = "verification_code_test";
private static final String PHONE_NUMBER = "+14151112222";
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = DynamoDbExtension.builder()
.tableName(TABLE_NAME)
.hashKey(VerificationCodeStore.KEY_E164)
.attributeDefinition(AttributeDefinition.builder()
.attributeName(VerificationCodeStore.KEY_E164)
.attributeType(ScalarAttributeType.S)
.build())
.build();
@BeforeEach
void setUp() {
verificationCodeStore = new VerificationCodeStore(DYNAMO_DB_EXTENSION.getDynamoDbClient(), TABLE_NAME);
}
@Test
void testStoreAndFind() {
assertEquals(Optional.empty(), verificationCodeStore.findForNumber(PHONE_NUMBER));
final StoredVerificationCode originalCode = new StoredVerificationCode("1234", 1111, "abcd", "0987");
final StoredVerificationCode secondCode = new StoredVerificationCode("5678", 2222, "efgh", "7890");
verificationCodeStore.insert(PHONE_NUMBER, originalCode);
{
final Optional<StoredVerificationCode> maybeCode = verificationCodeStore.findForNumber(PHONE_NUMBER);
assertTrue(maybeCode.isPresent());
assertTrue(storedVerificationCodesAreEqual(originalCode, maybeCode.get()));
}
verificationCodeStore.insert(PHONE_NUMBER, secondCode);
{
final Optional<StoredVerificationCode> maybeCode = verificationCodeStore.findForNumber(PHONE_NUMBER);
assertTrue(maybeCode.isPresent());
assertTrue(storedVerificationCodesAreEqual(secondCode, maybeCode.get()));
}
}
@Test
void testRemove() {
assertEquals(Optional.empty(), verificationCodeStore.findForNumber(PHONE_NUMBER));
verificationCodeStore.insert(PHONE_NUMBER, new StoredVerificationCode("1234", 1111, "abcd", "0987"));
assertTrue(verificationCodeStore.findForNumber(PHONE_NUMBER).isPresent());
verificationCodeStore.remove(PHONE_NUMBER);
assertFalse(verificationCodeStore.findForNumber(PHONE_NUMBER).isPresent());
}
private static boolean storedVerificationCodesAreEqual(final StoredVerificationCode first, final StoredVerificationCode second) {
if (first == null && second == null) {
return true;
} else if (first == null || second == null) {
return false;
}
return Objects.equals(first.getCode(), second.getCode()) &&
first.getTimestamp() == second.getTimestamp() &&
Objects.equals(first.getPushCode(), second.getPushCode()) &&
Objects.equals(first.getTwilioVerificationSid(), second.getTwilioVerificationSid());
}
}

View File

@@ -5,104 +5,129 @@
package org.whispersystems.textsecuregcm.tests.auth;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.time.Clock;
import java.time.Instant;
import java.util.Random;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class BaseAccountAuthenticatorTest {
import java.time.Clock;
import java.time.Instant;
import java.util.Random;
import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L;
private final long currentTime = today + 68_000_000L;
class BaseAccountAuthenticatorTest {
private AccountsManager accountsManager;
private BaseAccountAuthenticator baseAccountAuthenticator;
private Clock clock;
private Account acct1;
private Account acct2;
private Account oldAccount;
private final Random random = new Random(867_5309L);
private final long today = 1590451200000L;
private final long yesterday = today - 86_400_000L;
private final long oldTime = yesterday - 86_400_000L;
private final long currentTime = today + 68_000_000L;
@Before
public void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
private AccountsManager accountsManager;
private BaseAccountAuthenticator baseAccountAuthenticator;
private Clock clock;
private Account acct1;
private Account acct2;
private Account oldAccount;
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
}
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
clock = mock(Clock.class);
baseAccountAuthenticator = new BaseAccountAuthenticator(accountsManager, clock);
@Test
public void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));
acct1 = new Account("+14088675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
acct2 = new Account("+14098675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, yesterday, 0, null, 0, null)), null);
oldAccount = new Account("+14108675309", AuthHelper.getRandomUUID(random), Set.of(new Device(1, null, null, null,
null, null, null, false, 0, null, oldTime, 0, null, 0, null)), null);
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
AccountsHelper.setupMockUpdate(accountsManager);
}
verify(accountsManager, never()).update(acct1);
verify(accountsManager).update(acct2);
@Test
void testUpdateLastSeenMiddleOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(currentTime));
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
@Test
public void testUpdateLastSeenStartOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
final Account updatedAcct1 = baseAccountAuthenticator.updateLastSeen(acct1, device1);
final Account updatedAcct2 = baseAccountAuthenticator.updateLastSeen(acct2, device2);
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager).updateDevice(eq(acct2), anyLong(), any());
verify(accountsManager, never()).update(acct1);
verify(accountsManager, never()).update(acct2);
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(today);
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(yesterday);
}
assertThat(acct1).isSameAs(updatedAcct1);
assertThat(acct2).isNotSameAs(updatedAcct2);
}
@Test
public void testUpdateLastSeenEndOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1));
@Test
void testUpdateLastSeenStartOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
baseAccountAuthenticator.updateLastSeen(acct1, acct1.getDevices().stream().findFirst().get());
baseAccountAuthenticator.updateLastSeen(acct2, acct2.getDevices().stream().findFirst().get());
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
verify(accountsManager).update(acct1);
verify(accountsManager).update(acct2);
final Account updatedAcct1 = baseAccountAuthenticator.updateLastSeen(acct1, device1);
final Account updatedAcct2 = baseAccountAuthenticator.updateLastSeen(acct2, device2);
assertThat(acct1.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
assertThat(acct2.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
verify(accountsManager, never()).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager, never()).updateDevice(eq(acct2), anyLong(), any());
@Test
public void testNeverWriteYesterday() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
assertThat(device1.getLastSeen()).isEqualTo(yesterday);
assertThat(device2.getLastSeen()).isEqualTo(yesterday);
baseAccountAuthenticator.updateLastSeen(oldAccount, oldAccount.getDevices().stream().findFirst().get());
assertThat(acct1).isSameAs(updatedAcct1);
assertThat(acct2).isSameAs(updatedAcct2);
}
verify(accountsManager).update(oldAccount);
@Test
void testUpdateLastSeenEndOfDay() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today + 86_400_000L - 1));
assertThat(oldAccount.getDevices().stream().findFirst().get().getLastSeen()).isEqualTo(today);
}
final Device device1 = acct1.getDevices().stream().findFirst().get();
final Device device2 = acct2.getDevices().stream().findFirst().get();
final Account updatedAcct1 = baseAccountAuthenticator.updateLastSeen(acct1, device1);
final Account updatedAcct2 = baseAccountAuthenticator.updateLastSeen(acct2, device2);
verify(accountsManager).updateDevice(eq(acct1), anyLong(), any());
verify(accountsManager).updateDevice(eq(acct2), anyLong(), any());
assertThat(device1.getLastSeen()).isEqualTo(today);
assertThat(device2.getLastSeen()).isEqualTo(today);
assertThat(updatedAcct1).isNotSameAs(acct1);
assertThat(updatedAcct2).isNotSameAs(acct2);
}
@Test
void testNeverWriteYesterday() {
when(clock.instant()).thenReturn(Instant.ofEpochMilli(today));
final Device device = oldAccount.getDevices().stream().findFirst().get();
baseAccountAuthenticator.updateLastSeen(oldAccount, device);
verify(accountsManager).updateDevice(eq(oldAccount), anyLong(), any());
assertThat(device.getLastSeen()).isEqualTo(today);
}
}

View File

@@ -10,7 +10,17 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@@ -41,7 +51,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
@@ -69,15 +79,14 @@ import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioVerifyExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRule;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.storage.UsernamesManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@@ -107,7 +116,7 @@ class AccountControllerTest {
private static final String VALID_CAPTCHA_TOKEN = "valid_token";
private static final String INVALID_CAPTCHA_TOKEN = "invalid_token";
private static PendingAccountsManager pendingAccountsManager = mock(PendingAccountsManager.class);
private static StoredVerificationCodeManager pendingAccountsManager = mock(StoredVerificationCodeManager.class);
private static AccountsManager accountsManager = mock(AccountsManager.class);
private static AbusiveHostRules abusiveHostRules = mock(AbusiveHostRules.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class);
@@ -118,8 +127,6 @@ class AccountControllerTest {
private static RateLimiter autoBlockLimiter = mock(RateLimiter.class);
private static RateLimiter usernameSetLimiter = mock(RateLimiter.class);
private static SmsSender smsSender = mock(SmsSender.class);
private static DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
private static MessagesManager storedMessages = mock(MessagesManager.class);
private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class);
private static Account senderPinAccount = mock(Account.class);
private static Account senderRegLockAccount = mock(Account.class);
@@ -151,8 +158,6 @@ class AccountControllerTest {
abusiveHostRules,
rateLimiters,
smsSender,
directoryQueue,
storedMessages,
dynamicConfigurationManager,
turnTokenGenerator,
new HashMap<>(),
@@ -171,6 +176,8 @@ class AccountControllerTest {
new SecureRandom().nextBytes(registration_lock_key);
AuthenticationCredentials registrationLockCredentials = new AuthenticationCredentials(Hex.toStringCondensed(registration_lock_key));
AccountsHelper.setupMockUpdate(accountsManager);
when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationDailyLimiter()).thenReturn(rateLimiter);
@@ -192,15 +199,15 @@ class AccountControllerTest {
when(senderRegLockAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
when(senderRegLockAccount.getUuid()).thenReturn(SENDER_REG_LOCK_UUID);
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis(), "1234-push")));
when(pendingAccountsManager.getCodeForNumber(SENDER_OLD)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("333333", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_REG_LOCK)).thenReturn(Optional.of(new StoredVerificationCode("666666", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OVER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("444444", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OVER_PREFIX)).thenReturn(Optional.of(new StoredVerificationCode("777777", System.currentTimeMillis(), "1234-push")));
when(pendingAccountsManager.getCodeForNumber(SENDER_PREAUTH)).thenReturn(Optional.of(new StoredVerificationCode("555555", System.currentTimeMillis(), "validchallenge")));
when(pendingAccountsManager.getCodeForNumber(SENDER_HAS_STORAGE)).thenReturn(Optional.of(new StoredVerificationCode("666666", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_TRANSFER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis(), null)));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis(), "1234-push", null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OLD)).thenReturn(Optional.empty());
when(pendingAccountsManager.getCodeForNumber(SENDER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("333333", System.currentTimeMillis(), null, null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_REG_LOCK)).thenReturn(Optional.of(new StoredVerificationCode("666666", System.currentTimeMillis(), null, null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OVER_PIN)).thenReturn(Optional.of(new StoredVerificationCode("444444", System.currentTimeMillis(), null, null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_OVER_PREFIX)).thenReturn(Optional.of(new StoredVerificationCode("777777", System.currentTimeMillis(), "1234-push", null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_PREAUTH)).thenReturn(Optional.of(new StoredVerificationCode("555555", System.currentTimeMillis(), "validchallenge", null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_HAS_STORAGE)).thenReturn(Optional.of(new StoredVerificationCode("666666", System.currentTimeMillis(), null, null)));
when(pendingAccountsManager.getCodeForNumber(SENDER_TRANSFER)).thenReturn(Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis(), null, null)));
when(accountsManager.get(eq(SENDER_PIN))).thenReturn(Optional.of(senderPinAccount));
when(accountsManager.get(eq(SENDER_REG_LOCK))).thenReturn(Optional.of(senderRegLockAccount));
@@ -211,6 +218,14 @@ class AccountControllerTest {
when(accountsManager.get(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage));
when(accountsManager.get(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer));
when(accountsManager.create(any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn(invocation.getArgument(0, String.class));
return account;
});
when(usernamesManager.put(eq(AuthHelper.VALID_UUID), eq("n00bkiller"))).thenReturn(true);
when(usernamesManager.put(eq(AuthHelper.VALID_UUID), eq("takenusername"))).thenReturn(false);
@@ -254,8 +269,6 @@ class AccountControllerTest {
autoBlockLimiter,
usernameSetLimiter,
smsSender,
directoryQueue,
storedMessages,
turnTokenGenerator,
senderPinAccount,
senderRegLockAccount,
@@ -884,7 +897,7 @@ class AccountControllerTest {
final String number = "+12345678901";
final String challenge = "challenge";
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(new StoredVerificationCode("123456", System.currentTimeMillis(), challenge)));
when(pendingAccountsManager.getCodeForNumber(number)).thenReturn(Optional.of(new StoredVerificationCode("123456", System.currentTimeMillis(), challenge, null)));
Response response =
resources.getJerseyTest()
@@ -914,23 +927,13 @@ class AccountControllerTest {
Optional.of(new StoredVerificationCode("1234", System.currentTimeMillis(), "1234-push", "VerificationSid")));;
}
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(), MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
assertThat(result.isStorageCapable()).isFalse();
final ArgumentCaptor<Account> accountArgumentCaptor = ArgumentCaptor.forClass(Account.class);
verify(accountsManager, times(1)).create(accountArgumentCaptor.capture());
verify(directoryQueue, times(1)).refreshRegisteredUser(argThat(account -> SENDER.equals(account.getNumber())));
assertThat(accountArgumentCaptor.getValue().isDiscoverableByPhoneNumber()).isTrue();
verify(accountsManager).create(eq(SENDER), eq("bar"), any(), any());
if (enrolledInVerifyExperiment) {
verify(smsSender).reportVerificationSucceeded("VerificationSid");
@@ -938,52 +941,14 @@ class AccountControllerTest {
}
@Test
void testVerifyCodeUndiscoverable() throws Exception {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, false, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
assertThat(result.isStorageCapable()).isFalse();
final ArgumentCaptor<Account> accountArgumentCaptor = ArgumentCaptor.forClass(Account.class);
verify(accountsManager, times(1)).create(accountArgumentCaptor.capture());
verify(directoryQueue, times(1)).refreshRegisteredUser(argThat(account -> SENDER.equals(account.getNumber())));
assertThat(accountArgumentCaptor.getValue().isDiscoverableByPhoneNumber()).isFalse();
}
@Test
void testVerifySupportsStorage() throws Exception {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_HAS_STORAGE, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
assertThat(result.isStorageCapable()).isTrue();
verify(accountsManager, times(1)).create(isA(Account.class));
verify(directoryQueue, times(1)).refreshRegisteredUser(argThat(account -> SENDER_HAS_STORAGE.equals(account.getNumber())));
}
@Test
void testVerifyCodeOld() throws Exception {
void testVerifyCodeOld() {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OLD, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OLD, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
@@ -991,14 +956,14 @@ class AccountControllerTest {
}
@Test
void testVerifyBadCode() throws Exception {
void testVerifyBadCode() {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1111"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "1111"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403);
@@ -1009,11 +974,11 @@ class AccountControllerTest {
void testVerifyPin() throws Exception {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "333333"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, "31337", null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
.target(String.format("/v1/accounts/code/%s", "333333"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, "31337", null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
@@ -1024,11 +989,11 @@ class AccountControllerTest {
void testVerifyRegistrationLock() throws Exception {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, Hex.toStringCondensed(registration_lock_key), true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, Hex.toStringCondensed(registration_lock_key), true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
@@ -1037,36 +1002,24 @@ class AccountControllerTest {
@Test
void testVerifyRegistrationLockSetsRegistrationLockOnNewAccount() throws Exception {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, Hex.toStringCondensed(registration_lock_key), true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, Hex.toStringCondensed(registration_lock_key), true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
verify(pinLimiter).validate(eq(SENDER_REG_LOCK));
verify(accountsManager).create(argThat(new ArgumentMatcher<>() {
@Override
public boolean matches(final Account account) {
final StoredRegistrationLock regLock = account.getRegistrationLock();
return regLock.requiresClientRegistrationLock() && regLock.verify(Hex.toStringCondensed(registration_lock_key), null);
}
@Override
public String toString() {
return "Account that has registration lock set";
}
}));
verify(accountsManager).create(eq(SENDER_REG_LOCK), eq("bar"), any(), argThat(
attributes -> Hex.toStringCondensed(registration_lock_key).equals(attributes.getRegistrationLock())));
}
@Test
void testVerifyRegistrationLockOld() throws Exception {
void testVerifyRegistrationLockOld() {
StoredRegistrationLock lock = senderRegLockAccount.getRegistrationLock();
try {
@@ -1074,11 +1027,11 @@ class AccountControllerTest {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
@@ -1092,11 +1045,11 @@ class AccountControllerTest {
void testVerifyWrongPin() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "333333"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, "31338", null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "333333"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, "31338", null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
@@ -1107,12 +1060,12 @@ class AccountControllerTest {
void testVerifyWrongRegistrationLock() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null,
Hex.toStringCondensed(new byte[32]), null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null,
Hex.toStringCondensed(new byte[32]), null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
@@ -1123,11 +1076,11 @@ class AccountControllerTest {
void testVerifyNoPin() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "333333"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "333333"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
@@ -1141,11 +1094,11 @@ class AccountControllerTest {
void testVerifyNoRegistrationLock() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "666666"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_REG_LOCK, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(423);
@@ -1164,11 +1117,11 @@ class AccountControllerTest {
void testVerifyLimitPin() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "444444"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OVER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, "31337", null, true, null),
MediaType.APPLICATION_JSON_TYPE));
.target(String.format("/v1/accounts/code/%s", "444444"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OVER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, "31337", null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(413);
@@ -1182,11 +1135,11 @@ class AccountControllerTest {
AccountCreationResult result =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "444444"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OVER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
.target(String.format("/v1/accounts/code/%s", "444444"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_OVER_PIN, "bar"))
.put(Entity.entity(new AccountAttributes(false, 3333, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE), AccountCreationResult.class);
assertThat(result.getUuid()).isNotNull();
@@ -1200,13 +1153,13 @@ class AccountControllerTest {
when(senderTransfer.isTransferSupported()).thenReturn(true);
final Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.queryParam("transfer", true)
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.queryParam("transfer", true)
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
}
@@ -1216,13 +1169,13 @@ class AccountControllerTest {
when(senderTransfer.isTransferSupported()).thenReturn(false);
final Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.queryParam("transfer", true)
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.queryParam("transfer", true)
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
}
@@ -1232,12 +1185,12 @@ class AccountControllerTest {
when(senderTransfer.isTransferSupported()).thenReturn(true);
final Response response =
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
resources.getJerseyTest()
.target(String.format("/v1/accounts/code/%s", "1234"))
.request()
.header("Authorization", AuthHelper.getAuthHeader(SENDER_TRANSFER, "bar"))
.put(Entity.entity(new AccountAttributes(false, 2222, null, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
}
@@ -1352,8 +1305,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("c00lz0rz"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
}
@Test
@@ -1368,8 +1320,7 @@ class AccountControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.DISABLED_DEVICE, times(1)).setGcmId(eq("z000"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
}
@Test
@@ -1385,8 +1336,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("second"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
}
@Test
@@ -1402,8 +1352,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("first"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(null);
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
}
@Test
@@ -1419,8 +1368,7 @@ class AccountControllerTest {
verify(AuthHelper.DISABLED_DEVICE, times(1)).setApnId(eq("third"));
verify(AuthHelper.DISABLED_DEVICE, times(1)).setVoipApnId(eq("fourth"));
verify(accountsManager, times(1)).update(eq(AuthHelper.DISABLED_ACCOUNT));
verify(directoryQueue, never()).refreshRegisteredUser(any(Account.class));
verify(accountsManager, times(1)).updateDevice(eq(AuthHelper.DISABLED_ACCOUNT), anyLong(), any());
}
@ParameterizedTest
@@ -1531,7 +1479,6 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, true, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, never()).refreshRegisteredUser(any());
}
@Test
@@ -1544,7 +1491,6 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, true, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.UNDISCOVERABLE_ACCOUNT);
}
@Test
@@ -1557,11 +1503,10 @@ class AccountControllerTest {
.put(Entity.json(new AccountAttributes(false, 2222, null, null, null, false, null)));
assertThat(response.getStatus()).isEqualTo(204);
verify(directoryQueue, times(1)).refreshRegisteredUser(AuthHelper.VALID_ACCOUNT);
}
@Test
void testDeleteAccount() {
void testDeleteAccount() throws InterruptedException {
Response response =
resources.getJerseyTest()
.target("/v1/accounts/me")
@@ -1573,6 +1518,21 @@ class AccountControllerTest {
verify(accountsManager).delete(AuthHelper.VALID_ACCOUNT, AccountsManager.DeletionReason.USER_REQUEST);
}
@Test
void testDeleteAccountInterrupted() throws InterruptedException {
doThrow(InterruptedException.class).when(accountsManager).delete(any(), any());
Response response =
resources.getJerseyTest()
.target("/v1/accounts/me")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.delete();
assertThat(response.getStatus()).isEqualTo(500);
verify(accountsManager).delete(AuthHelper.VALID_ACCOUNT, AccountsManager.DeletionReason.USER_REQUEST);
}
@ParameterizedTest
@MethodSource
void testSignupCaptcha(final String message, final boolean enforced, final Set<String> countryCodes, final int expectedResponseStatusCode) {

View File

@@ -11,7 +11,8 @@ import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
@@ -30,8 +31,8 @@ import org.assertj.core.api.Assertions;
import org.assertj.core.api.Condition;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV1;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
@@ -46,7 +47,8 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class AttachmentControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class AttachmentControllerTest {
private static RateLimiters rateLimiters = mock(RateLimiters.class );
private static RateLimiter rateLimiter = mock(RateLimiter.class );
@@ -71,12 +73,11 @@ public class AttachmentControllerTest {
}
}
@ClassRule
public static final ResourceTestRule resources;
private static final ResourceExtension resources;
static {
try {
resources = ResourceTestRule.builder()
resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
@@ -91,7 +92,7 @@ public class AttachmentControllerTest {
}
@Test
public void testV3Form() {
void testV3Form() {
AttachmentDescriptorV3 descriptor = resources.getJerseyTest()
.target("/v3/attachments/form/upload")
.request()
@@ -147,7 +148,7 @@ public class AttachmentControllerTest {
}
@Test
public void testV3FormDisabled() {
void testV3FormDisabled() {
Response response = resources.getJerseyTest()
.target("/v3/attachments/form/upload")
.request()
@@ -158,7 +159,7 @@ public class AttachmentControllerTest {
}
@Test
public void testV2Form() throws IOException {
void testV2Form() throws IOException {
AttachmentDescriptorV2 descriptor = resources.getJerseyTest()
.target("/v2/attachments/form/upload")
.request()
@@ -186,7 +187,7 @@ public class AttachmentControllerTest {
}
@Test
public void testV2FormDisabled() {
void testV2FormDisabled() {
Response response = resources.getJerseyTest()
.target("/v2/attachments/form/upload")
.request()
@@ -198,7 +199,7 @@ public class AttachmentControllerTest {
@Test
public void testAcceleratedPut() {
void testAcceleratedPut() {
AttachmentDescriptorV1 descriptor = resources.getJerseyTest()
.target("/v1/attachments/")
.request()
@@ -211,7 +212,7 @@ public class AttachmentControllerTest {
}
@Test
public void testUnacceleratedPut() {
void testUnacceleratedPut() {
AttachmentDescriptorV1 descriptor = resources.getJerseyTest()
.target("/v1/attachments/")
.request()
@@ -224,7 +225,7 @@ public class AttachmentControllerTest {
}
@Test
public void testAcceleratedGet() throws MalformedURLException {
void testAcceleratedGet() throws MalformedURLException {
AttachmentUri uri = resources.getJerseyTest()
.target("/v1/attachments/1234")
.request()
@@ -235,7 +236,7 @@ public class AttachmentControllerTest {
}
@Test
public void testUnacceleratedGet() throws MalformedURLException {
void testUnacceleratedGet() throws MalformedURLException {
AttachmentUri uri = resources.getJerseyTest()
.target("/v1/attachments/1234")
.request()

View File

@@ -5,21 +5,22 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import static junit.framework.TestCase.assertTrue;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException;
import java.util.Arrays;
import java.util.Base64;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.signal.zkgroup.InvalidInputException;
import org.signal.zkgroup.ServerSecretParams;
import org.signal.zkgroup.VerificationFailedException;
@@ -41,7 +42,8 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
public class CertificateControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class CertificateControllerTest {
private static final String caPublicKey = "BWh+UOhT1hD8bkb+MFRvb6tVqhoG8YYGCzOd7mgjo8cV";
private static final String caPrivateKey = "EO3Mnf0kfVlVnwSaqPoQnAxhnnGL1JTdXqktCKEe9Eo=";
@@ -63,17 +65,16 @@ public class CertificateControllerTest {
}
@ClassRule
public static final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true))
.build();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true))
.build();
@Test
public void testValidCertificate() throws Exception {
void testValidCertificate() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
@@ -94,11 +95,11 @@ public class CertificateControllerTest {
assertEquals(certificate.getSenderDevice(), 1L);
assertTrue(certificate.hasSenderUuid());
assertEquals(AuthHelper.VALID_UUID.toString(), certificate.getSenderUuid());
assertTrue(Arrays.equals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY)));
assertArrayEquals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
}
@Test
public void testValidCertificateWithUuid() throws Exception {
void testValidCertificateWithUuid() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.queryParam("includeUuid", "true")
@@ -119,11 +120,11 @@ public class CertificateControllerTest {
assertEquals(certificate.getSender(), AuthHelper.VALID_NUMBER);
assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertTrue(Arrays.equals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY)));
assertArrayEquals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
}
@Test
public void testValidCertificateWithUuidNoE164() throws Exception {
void testValidCertificateWithUuidNoE164() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.queryParam("includeUuid", "true")
@@ -145,11 +146,11 @@ public class CertificateControllerTest {
assertTrue(StringUtils.isBlank(certificate.getSender()));
assertEquals(certificate.getSenderDevice(), 1L);
assertEquals(certificate.getSenderUuid(), AuthHelper.VALID_UUID.toString());
assertTrue(Arrays.equals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY)));
assertArrayEquals(certificate.getIdentityKey().toByteArray(), Base64.getDecoder().decode(AuthHelper.VALID_IDENTITY));
}
@Test
public void testBadAuthentication() throws Exception {
void testBadAuthentication() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
@@ -161,7 +162,7 @@ public class CertificateControllerTest {
@Test
public void testNoAuthentication() throws Exception {
void testNoAuthentication() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
@@ -172,7 +173,7 @@ public class CertificateControllerTest {
@Test
public void testUnidentifiedAuthentication() throws Exception {
void testUnidentifiedAuthentication() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
@@ -183,7 +184,7 @@ public class CertificateControllerTest {
}
@Test
public void testDisabledAuthentication() throws Exception {
void testDisabledAuthentication() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/certificate/delivery")
.request()
@@ -194,7 +195,7 @@ public class CertificateControllerTest {
}
@Test
public void testGetSingleAuthCredential() throws InvalidInputException, VerificationFailedException {
void testGetSingleAuthCredential() throws InvalidInputException, VerificationFailedException {
GroupCredentials credentials = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + Util.currentDaysSinceEpoch())
.request()
@@ -209,7 +210,7 @@ public class CertificateControllerTest {
}
@Test
public void testGetWeekLongAuthCredentials() throws InvalidInputException, VerificationFailedException {
void testGetWeekLongAuthCredentials() throws InvalidInputException, VerificationFailedException {
GroupCredentials credentials = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + (Util.currentDaysSinceEpoch() + 7))
.request()
@@ -227,7 +228,7 @@ public class CertificateControllerTest {
}
@Test
public void testTooManyDaysOut() throws InvalidInputException {
void testTooManyDaysOut() throws InvalidInputException {
Response response = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + (Util.currentDaysSinceEpoch() + 8))
.request()
@@ -238,7 +239,7 @@ public class CertificateControllerTest {
}
@Test
public void testBackwardsInTime() throws InvalidInputException {
void testBackwardsInTime() throws InvalidInputException {
Response response = resources.getJerseyTest()
.target("/v1/certificate/group/" + (Util.currentDaysSinceEpoch() - 1) + "/" + (Util.currentDaysSinceEpoch() + 7))
.request()
@@ -249,7 +250,7 @@ public class CertificateControllerTest {
}
@Test
public void testBadAuth() throws InvalidInputException {
void testBadAuth() throws InvalidInputException {
Response response = resources.getJerseyTest()
.target("/v1/certificate/group/" + Util.currentDaysSinceEpoch() + "/" + (Util.currentDaysSinceEpoch() + 7))
.request()

View File

@@ -5,31 +5,38 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import javax.ws.rs.Path;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.controllers.DeviceController;
@@ -38,28 +45,29 @@ import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@RunWith(JUnitParamsRunner.class)
public class DeviceControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class DeviceControllerTest {
@Path("/v1/devices")
static class DumbVerificationDeviceController extends DeviceController {
public DumbVerificationDeviceController(PendingDevicesManager pendingDevices,
public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts,
MessagesManager messages,
DirectoryQueue cdsSender,
KeysDynamoDb keys,
RateLimiters rateLimiters,
Map<String, Integer> deviceConfiguration)
{
super(pendingDevices, accounts, messages, cdsSender, rateLimiters, deviceConfiguration);
super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration);
}
@Override
@@ -68,37 +76,34 @@ public class DeviceControllerTest {
}
}
private PendingDevicesManager pendingDevicesManager = mock(PendingDevicesManager.class);
private AccountsManager accountsManager = mock(AccountsManager.class );
private MessagesManager messagesManager = mock(MessagesManager.class);
private DirectoryQueue directoryQueue = mock(DirectoryQueue.class);
private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class );
private Account account = mock(Account.class );
private Account maxedAccount = mock(Account.class);
private Device masterDevice = mock(Device.class);
private static StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class);
private static AccountsManager accountsManager = mock(AccountsManager.class );
private static MessagesManager messagesManager = mock(MessagesManager.class);
private static KeysDynamoDb keys = mock(KeysDynamoDb.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class );
private static RateLimiter rateLimiter = mock(RateLimiter.class );
private static Account account = mock(Account.class );
private static Account maxedAccount = mock(Account.class);
private static Device masterDevice = mock(Device.class);
private Map<String, Integer> deviceConfiguration = new HashMap<String, Integer>() {{
private static Map<String, Integer> deviceConfiguration = new HashMap<>();
}};
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
directoryQueue,
rateLimiters,
deviceConfiguration))
.build();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
keys,
rateLimiters,
deviceConfiguration))
.build();
@Before
public void setup() throws Exception {
@BeforeEach
void setup() throws Exception {
when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter);
@@ -116,15 +121,33 @@ public class DeviceControllerTest {
when(account.isGroupsV2Supported()).thenReturn(true);
when(account.isGv1MigrationSupported()).thenReturn(true);
when(account.isSenderKeySupported()).thenReturn(true);
when(account.isAnnouncementGroupSupported()).thenReturn(true);
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(new StoredVerificationCode("1112223", System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(31), null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty());
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
AccountsHelper.setupMockUpdate(accountsManager);
}
@AfterEach
void teardown() {
reset(
pendingDevicesManager,
accountsManager,
messagesManager,
keys,
rateLimiters,
rateLimiter,
account,
maxedAccount,
masterDevice
);
}
@Test
public void validDeviceRegisterTest() throws Exception {
void validDeviceRegisterTest() throws Exception {
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
@@ -149,7 +172,7 @@ public class DeviceControllerTest {
}
@Test
public void disabledDeviceRegisterTest() throws Exception {
void disabledDeviceRegisterTest() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
@@ -160,7 +183,7 @@ public class DeviceControllerTest {
}
@Test
public void invalidDeviceRegisterTest() throws Exception {
void invalidDeviceRegisterTest() throws Exception {
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
@@ -182,7 +205,7 @@ public class DeviceControllerTest {
}
@Test
public void oldDeviceRegisterTest() throws Exception {
void oldDeviceRegisterTest() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/devices/1112223")
.request()
@@ -196,7 +219,7 @@ public class DeviceControllerTest {
}
@Test
public void maxDevicesTest() throws Exception {
void maxDevicesTest() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
@@ -208,7 +231,7 @@ public class DeviceControllerTest {
}
@Test
public void longNameTest() throws Exception {
void longNameTest() throws Exception {
Response response = resources.getJerseyTest()
.target("/v1/devices/5678901")
.request()
@@ -220,10 +243,10 @@ public class DeviceControllerTest {
verifyNoMoreInteractions(messagesManager);
}
@Test
@Parameters(method = "argumentsForDeviceDowngradeCapabilitiesTest")
public void deviceDowngradeCapabilitiesTest(final String userAgent, final boolean gv2, final boolean gv2_2, final boolean gv2_3, final int expectedStatus) throws Exception {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(gv2, gv2_2, gv2_3, true, false, true, true);
@ParameterizedTest
@MethodSource
void deviceDowngradeCapabilitiesTest(final String userAgent, final boolean gv2, final boolean gv2_2, final boolean gv2_3, final int expectedStatus) throws Exception {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(gv2, gv2_2, gv2_3, true, false, true, true, true);
AccountAttributes accountAttributes = new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
Response response = resources.getJerseyTest()
.target("/v1/devices/5678901")
@@ -239,31 +262,31 @@ public class DeviceControllerTest {
}
}
private static Object argumentsForDeviceDowngradeCapabilitiesTest() {
return new Object[] {
// User-Agent gv2 gv2-2 gv2-3 expected
new Object[] { "Signal-Android/4.68.3 Android/25", false, false, false, 409 },
new Object[] { "Signal-Android/4.68.3 Android/25", true, false, false, 409 },
new Object[] { "Signal-Android/4.68.3 Android/25", false, true, false, 409 },
new Object[] { "Signal-Android/4.68.3 Android/25", false, false, true, 200 },
new Object[] { "Signal-iOS/3.9.0", false, false, false, 409 },
new Object[] { "Signal-iOS/3.9.0", true, false, false, 409 },
new Object[] { "Signal-iOS/3.9.0", false, true, false, 200 },
new Object[] { "Signal-iOS/3.9.0", false, false, true, 200 },
new Object[] { "Signal-Desktop/1.32.0-beta.3", false, false, false, 409 },
new Object[] { "Signal-Desktop/1.32.0-beta.3", true, false, false, 409 },
new Object[] { "Signal-Desktop/1.32.0-beta.3", false, true, false, 409 },
new Object[] { "Signal-Desktop/1.32.0-beta.3", false, false, true, 200 },
new Object[] { "Old client with unparsable UA", false, false, false, 409 },
new Object[] { "Old client with unparsable UA", true, false, false, 409 },
new Object[] { "Old client with unparsable UA", false, true, false, 409 },
new Object[] { "Old client with unparsable UA", false, false, true, 409 }
};
private static Stream<Arguments> deviceDowngradeCapabilitiesTest() {
return Stream.of(
// User-Agent gv2 gv2-2 gv2-3 expected
Arguments.of( "Signal-Android/4.68.3 Android/25", false, false, false, 409 ),
Arguments.of( "Signal-Android/4.68.3 Android/25", true, false, false, 409 ),
Arguments.of( "Signal-Android/4.68.3 Android/25", false, true, false, 409 ),
Arguments.of( "Signal-Android/4.68.3 Android/25", false, false, true, 200 ),
Arguments.of( "Signal-iOS/3.9.0", false, false, false, 409 ),
Arguments.of( "Signal-iOS/3.9.0", true, false, false, 409 ),
Arguments.of( "Signal-iOS/3.9.0", false, true, false, 200 ),
Arguments.of( "Signal-iOS/3.9.0", false, false, true, 200 ),
Arguments.of( "Signal-Desktop/1.32.0-beta.3", false, false, false, 409 ),
Arguments.of( "Signal-Desktop/1.32.0-beta.3", true, false, false, 409 ),
Arguments.of( "Signal-Desktop/1.32.0-beta.3", false, true, false, 409 ),
Arguments.of( "Signal-Desktop/1.32.0-beta.3", false, false, true, 200 ),
Arguments.of( "Old client with unparsable UA", false, false, false, 409 ),
Arguments.of( "Old client with unparsable UA", true, false, false, 409 ),
Arguments.of( "Old client with unparsable UA", false, true, false, 409 ),
Arguments.of( "Old client with unparsable UA", false, false, true, 409 )
);
}
@Test
public void deviceDowngradeGv1MigrationTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true, false, false, true);
void deviceDowngradeGv1MigrationTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true, false, false, true, true);
AccountAttributes accountAttributes = new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
Response response = resources.getJerseyTest()
.target("/v1/devices/5678901")
@@ -274,7 +297,7 @@ public class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(409);
deviceCapabilities = new DeviceCapabilities(true, true, true, true, false, true, true);
deviceCapabilities = new DeviceCapabilities(true, true, true, true, false, true, true, true);
accountAttributes = new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
response = resources.getJerseyTest()
.target("/v1/devices/5678901")
@@ -287,8 +310,8 @@ public class DeviceControllerTest {
}
@Test
public void deviceDowngradeSenderKeyTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true, true, true, false);
void deviceDowngradeSenderKeyTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true, true, true, false, true);
AccountAttributes accountAttributes =
new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
Response response = resources
@@ -300,7 +323,7 @@ public class DeviceControllerTest {
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
deviceCapabilities = new DeviceCapabilities(true, true, true, true, true, true, true);
deviceCapabilities = new DeviceCapabilities(true, true, true, true, true, true, true, true);
accountAttributes = new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
response = resources
.getJerseyTest()
@@ -311,4 +334,57 @@ public class DeviceControllerTest {
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
}
@Test
void deviceDowngradeAnnouncementGroupTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true, true, true, true, false);
AccountAttributes accountAttributes =
new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
Response response = resources
.getJerseyTest()
.target("/v1/devices/5678901")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.42.8675309 Android/30")
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
deviceCapabilities = new DeviceCapabilities(true, true, true, true, true, true, true, true);
accountAttributes = new AccountAttributes(false, 1234, null, null, null, true, deviceCapabilities);
response = resources
.getJerseyTest()
.target("/v1/devices/5678901")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.42.8675309 Android/30")
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
}
@Test
void deviceRemovalClearsMessagesAndKeys() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
final long deviceId = 2;
final Response response = resources
.getJerseyTest()
.target("/v1/devices/" + deviceId)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.42.8675309 Android/30")
.delete();
assertThat(response.getStatus()).isEqualTo(204);
verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId);
verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any());
verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId);
// The account instance may have changed as part of a call to `AccountManager#update`
verify(keys).delete(argThat(account -> account.getUuid().equals(AuthHelper.VALID_UUID)), eq(deviceId));
}
}

View File

@@ -5,13 +5,24 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.Collections;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status.Family;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
@@ -19,37 +30,26 @@ import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status.Family;
import java.util.Collections;
@ExtendWith(DropwizardExtensionsSupport.class)
class DirectoryControllerTest {
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
private static final ExternalServiceCredentialGenerator directoryCredentialsGenerator = mock(ExternalServiceCredentialGenerator.class);
private static final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
public class DirectoryControllerTest {
private final ExternalServiceCredentialGenerator directoryCredentialsGenerator = mock(ExternalServiceCredentialGenerator.class);
private final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DirectoryController(directoryCredentialsGenerator))
.build();
@Before
public void setup() {
@BeforeEach
void setup() {
when(directoryCredentialsGenerator.generateFor(eq(AuthHelper.VALID_NUMBER))).thenReturn(validCredentials);
}
@Test
public void testFeedbackOk() {
void testFeedbackOk() {
Response response =
resources.getJerseyTest()
.target("/v1/directory/feedback-v3/ok")
@@ -60,7 +60,7 @@ public class DirectoryControllerTest {
}
@Test
public void testGetAuthToken() {
void testGetAuthToken() {
ExternalServiceCredentials token =
resources.getJerseyTest()
.target("/v1/directory/auth")
@@ -72,7 +72,7 @@ public class DirectoryControllerTest {
}
@Test
public void testDisabledGetAuthToken() {
void testDisabledGetAuthToken() {
Response response =
resources.getJerseyTest()
.target("/v1/directory/auth")
@@ -84,7 +84,7 @@ public class DirectoryControllerTest {
@Test
public void testContactIntersection() {
void testContactIntersection() {
Response response =
resources.getJerseyTest()
.target("/v1/directory/tokens/")

View File

@@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
@@ -15,6 +16,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.AccountsHelper.eqUuid;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
@@ -55,12 +57,12 @@ import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -89,7 +91,6 @@ class KeysControllerTest {
private final static KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
private final static DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final static PreKeyRateLimiter preKeyRateLimiter = mock(PreKeyRateLimiter.class );
private final static RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class );
private final static Account existsAccount = mock(Account.class );
@@ -104,7 +105,7 @@ class KeysControllerTest {
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager))
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager))
.build();
@BeforeEach
@@ -114,13 +115,15 @@ class KeysControllerTest {
final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
Set<Device> allDevices = new HashSet<Device>() {{
Set<Device> allDevices = new HashSet<>() {{
add(sampleDevice);
add(sampleDevice2);
add(sampleDevice3);
add(sampleDevice4);
}};
AccountsHelper.setupMockUpdate(accounts);
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
@@ -142,7 +145,7 @@ class KeysControllerTest {
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice(22L)).thenReturn(Optional.<Device>empty());
when(existsAccount.getDevice(22L)).thenReturn(Optional.empty());
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isEnabled()).thenReturn(true);
when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey");
@@ -179,7 +182,6 @@ class KeysControllerTest {
reset(
keysDynamoDb,
accounts,
directoryQueue,
preKeyRateLimiter,
existsAccount,
rateLimiters,
@@ -256,7 +258,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
}
@Test
@@ -271,7 +273,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyLong(), any());
}
@@ -578,7 +580,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keysDynamoDb).store(eq(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
verify(keysDynamoDb).store(eqUuid(AuthHelper.VALID_ACCOUNT), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -587,7 +589,7 @@ class KeysControllerTest {
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.VALID_ACCOUNT);
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any());
}
@Test
@@ -612,7 +614,7 @@ class KeysControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keysDynamoDb).store(eq(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
verify(keysDynamoDb).store(eqUuid(AuthHelper.DISABLED_ACCOUNT), eq(1L), listCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue();
assertThat(capturedList.size()).isEqualTo(1);
@@ -621,7 +623,7 @@ class KeysControllerTest {
verify(AuthHelper.DISABLED_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.DISABLED_ACCOUNT);
verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any());
}
@Test

View File

@@ -70,6 +70,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
@@ -157,19 +158,19 @@ class MessageControllerTest {
Set<Device> singleDeviceList = new HashSet<Device>() {{
add(new Device(1, null, "foo", "bar",
"isgcm", null, null, false, 111, new SignedPreKey(333, "baz", "boop"), System.currentTimeMillis(), System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(true, false, false, true, true, false,
false)));
false, false)));
}};
Set<Device> multiDeviceList = new HashSet<Device>() {{
add(new Device(1, null, "foo", "bar",
"isgcm", null, null, false, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis(), System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(true, false, false, true, false, false,
false)));
false, false)));
add(new Device(2, null, "foo", "bar",
"isgcm", null, null, false, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis(), System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(true, false, false, true, false, false,
false)));
false, false)));
add(new Device(3, null, "foo", "bar",
"isgcm", null, null, false, 444, null, System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31), System.currentTimeMillis(), "Test", 0, new Device.DeviceCapabilities(false, false, false, false, false, false,
false)));
false, false)));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, singleDeviceList, "1234".getBytes());
@@ -246,6 +247,19 @@ class MessageControllerTest {
assertTrue(captor.getValue().hasSourceDevice());
}
@Test
void testNullMessageInList() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", SINGLE_DEVICE_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_null_message_in_list.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertThat("Bad request", response.getStatus(), is(equalTo(422)));
}
@Test
void testInternationalUnsealedSenderFromRateLimitedHost() throws Exception {
final String senderHost = "10.0.0.1";
@@ -484,6 +498,7 @@ class MessageControllerTest {
OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList);
when(AuthHelper.VALID_ACCOUNT.getRegistrationLock()).thenReturn(mock(StoredRegistrationLock.class));
OutgoingMessageEntityList response =
resources.getJerseyTest().target("/v1/messages/")

View File

@@ -5,11 +5,23 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
@@ -20,42 +32,30 @@ import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import javax.ws.rs.core.Response;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class PaymentsControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class PaymentsControllerTest {
private static final ExternalServiceCredentialGenerator paymentsCredentialGenerator = mock(ExternalServiceCredentialGenerator.class);
private static final CurrencyConversionManager currencyManager = mock(CurrencyConversionManager.class);
private final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
@ClassRule
public static final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator))
.build();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator))
.build();
@Before
public void setup() {
@BeforeEach
void setup() {
when(paymentsCredentialGenerator.generateFor(eq(AuthHelper.VALID_UUID.toString()))).thenReturn(validCredentials);
when(currencyManager.getCurrencyConversions()).thenReturn(Optional.of(new CurrencyConversionEntityList(List.of(new CurrencyConversionEntity("FOO", Map.of("USD", 2.35, "EUR", 1.89)), new CurrencyConversionEntity("BAR", Map.of("USD", 1.50, "EUR", 0.98))), System.currentTimeMillis())));
}
@Test
public void testGetAuthToken() {
void testGetAuthToken() {
ExternalServiceCredentials token =
resources.getJerseyTest()
.target("/v1/payments/auth")
@@ -68,7 +68,7 @@ public class PaymentsControllerTest {
}
@Test
public void testInvalidAuthGetAuthToken() {
void testInvalidAuthGetAuthToken() {
Response response =
resources.getJerseyTest()
.target("/v1/payments/auth")
@@ -80,7 +80,7 @@ public class PaymentsControllerTest {
}
@Test
public void testDisabledGetAuthToken() {
void testDisabledGetAuthToken() {
Response response =
resources.getJerseyTest()
.target("/v1/payments/auth")
@@ -91,7 +91,7 @@ public class PaymentsControllerTest {
}
@Test
public void testGetCurrencyConversions() {
void testGetCurrencyConversions() {
CurrencyConversionEntityList conversions =
resources.getJerseyTest()
.target("/v1/payments/conversions")

Some files were not shown because too many files have changed in this diff Show More