Compare commits

...

24 Commits
v0.16 ... v0.24

Author SHA1 Message Date
Moxie Marlinspike
8a2131416d Bump version to 0.24
// FREEBIE
2014-11-27 16:24:27 -08:00
Moxie Marlinspike
2525304215 Account for websocket-resources changes.
// FREEBIE
2014-11-15 09:48:09 -08:00
Moxie Marlinspike
fdb35d4f77 Switch to WebSocket-Resources
// FREEBIE
2014-11-14 17:59:50 -08:00
Moxie Marlinspike
222c7ea641 Support for signature token based account verification. 2014-11-13 14:56:24 -08:00
Moxie Marlinspike
8f2722263f Bump version to 0.23 2014-11-04 19:33:07 -08:00
Moxie Marlinspike
fd662e3401 Add vacuum command.
// FREEBIE
2014-11-04 19:32:35 -08:00
Moxie Marlinspike
bc65461ecb Bump version to 0.22 2014-10-01 15:03:25 -07:00
Moxie Marlinspike
30017371df Reconnect even when Smack thinks it doesn't need to. 2014-10-01 14:07:12 -07:00
Moxie Marlinspike
b944b86bf8 Bump version to 0.21
// FREEBIE
2014-07-30 11:45:45 -07:00
Moxie Marlinspike
6ba8352fa6 Update sample config to include GCM senderId
// FREEBIE
2014-07-30 11:38:23 -07:00
Moxie Marlinspike
aadf76692e Bump version to 0.20
// FREEBIE
2014-07-30 11:36:54 -07:00
Moxie Marlinspike
c9a1386a55 Fix for PubSub channel.
1) Create channels based on numbers rather than DB row ids.

2) Ensure that stored messages are cleared at reregistration
   time.
2014-07-26 20:41:25 -07:00
Moxie Marlinspike
4eb88a3e02 Server side support for delivery receipts. 2014-07-25 15:48:34 -07:00
Moxie Marlinspike
160c0bfe14 Switch from Java serialization to JSON for memcache storage. 2014-07-23 18:02:35 -07:00
Moxie Marlinspike
4cd098af1d Switch to GCM CCS and add support for APN feedback processing. 2014-07-23 18:00:49 -07:00
Moxie Marlinspike
362abd618f Bump version to 0.19
// FREEBIE
2014-07-21 01:20:57 -07:00
Moxie Marlinspike
69de9f6684 Fix stored message retrieval.
// FREEBIE
2014-07-21 01:20:14 -07:00
Moxie Marlinspike
2aa379bf21 Bumping version to 0.18
// FREEBIE
2014-07-17 11:05:38 -07:00
Moxie Marlinspike
820a2f1a63 Break FederationController into V1 and V2 2014-07-16 17:24:01 -07:00
Moxie Marlinspike
6fac7614f5 Allow device to query their currently stored signed prekey. 2014-07-16 14:44:00 -07:00
Moxie Marlinspike
b724ea8d3b Renamed 'device key' to 'signed prekey'. 2014-07-11 10:37:19 -07:00
Moxie Marlinspike
06f80c320d Introduce V2 API for PreKey updates and requests.
1) A /v2/keys controller.

2) Separate wire protocol PreKey POJOs from database PreKey
   objects.

3) Separate wire protocol PreKey submission and response POJOs.

4) Introduce a new update/response JSON format for /v2/keys.
2014-07-10 18:06:45 -07:00
Moxie Marlinspike
d9de015eab Bump version to 0.17 2014-07-10 17:45:11 -07:00
Moxie Marlinspike
dd36c861ba Pipeline directory update redis flow for a 10x speedup. 2014-07-10 17:31:39 -07:00
73 changed files with 4307 additions and 1809 deletions

View File

@@ -15,6 +15,7 @@ nexmo:
number: number:
gcm: gcm:
senderId:
apiKey: apiKey:
# Optional. Only if iOS clients are supported. # Optional. Only if iOS clients are supported.

14
pom.xml
View File

@@ -9,7 +9,7 @@
<groupId>org.whispersystems.textsecure</groupId> <groupId>org.whispersystems.textsecure</groupId>
<artifactId>TextSecureServer</artifactId> <artifactId>TextSecureServer</artifactId>
<version>0.16</version> <version>0.24</version>
<properties> <properties>
<dropwizard.version>0.7.0</dropwizard.version> <dropwizard.version>0.7.0</dropwizard.version>
@@ -104,7 +104,7 @@
<dependency> <dependency>
<groupId>com.google.protobuf</groupId> <groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId> <artifactId>protobuf-java</artifactId>
<version>2.4.1</version> <version>2.5.0</version>
</dependency> </dependency>
<dependency> <dependency>
@@ -125,6 +125,16 @@
<artifactId>postgresql</artifactId> <artifactId>postgresql</artifactId>
<version>9.1-901.jdbc4</version> <version>9.1-901.jdbc4</version>
</dependency> </dependency>
<dependency>
<groupId>org.igniterealtime.smack</groupId>
<artifactId>smack-tcp</artifactId>
<version>4.0.0</version>
</dependency>
<dependency>
<groupId>org.whispersystems.websocket</groupId>
<artifactId>websocket-resources</artifactId>
<version>0.1-SNAPSHOT</version>
</dependency>
</dependencies> </dependencies>

View File

@@ -20,6 +20,15 @@ option java_package = "org.whispersystems.textsecuregcm.entities";
option java_outer_classname = "MessageProtos"; option java_outer_classname = "MessageProtos";
message OutgoingMessageSignal { message OutgoingMessageSignal {
enum Type {
UNKNOWN = 0;
CIPHERTEXT = 1;
KEY_EXCHANGE = 2;
PREKEY_BUNDLE = 3;
PLAINTEXT = 4;
RECEIPT = 5;
}
optional uint32 type = 1; optional uint32 type = 1;
optional string source = 2; optional string source = 2;
optional uint32 sourceDevice = 7; optional uint32 sourceDevice = 7;

View File

@@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.configuration.MemcacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.MetricsConfiguration; import org.whispersystems.textsecuregcm.configuration.MetricsConfiguration;
import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration; import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedPhoneConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisConfiguration;
import org.whispersystems.textsecuregcm.configuration.S3Configuration; import org.whispersystems.textsecuregcm.configuration.S3Configuration;
import org.whispersystems.textsecuregcm.configuration.TwilioConfiguration; import org.whispersystems.textsecuregcm.configuration.TwilioConfiguration;
@@ -47,6 +48,7 @@ public class WhisperServerConfiguration extends Configuration {
private NexmoConfiguration nexmo; private NexmoConfiguration nexmo;
@NotNull @NotNull
@Valid
@JsonProperty @JsonProperty
private GcmConfiguration gcm; private GcmConfiguration gcm;
@@ -94,6 +96,9 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty @JsonProperty
private WebsocketConfiguration websocket = new WebsocketConfiguration(); private WebsocketConfiguration websocket = new WebsocketConfiguration();
@JsonProperty
private RedPhoneConfiguration redphone = new RedPhoneConfiguration();
public WebsocketConfiguration getWebsocketConfiguration() { public WebsocketConfiguration getWebsocketConfiguration() {
return websocket; return websocket;
} }
@@ -145,4 +150,8 @@ public class WhisperServerConfiguration extends Configuration {
public MetricsConfiguration getMetricsConfiguration() { public MetricsConfiguration getMetricsConfiguration() {
return viz; return viz;
} }
public RedPhoneConfiguration getRedphoneConfiguration() {
return redphone;
}
} }

View File

@@ -32,9 +32,12 @@ import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.AttachmentController; import org.whispersystems.textsecuregcm.controllers.AttachmentController;
import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.controllers.DirectoryController; import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.controllers.FederationController; import org.whispersystems.textsecuregcm.controllers.FederationControllerV1;
import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.FederationControllerV2;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV1;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV2;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ReceiptController;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.FederatedPeer; import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@@ -49,7 +52,11 @@ import org.whispersystems.textsecuregcm.providers.MemcacheHealthCheck;
import org.whispersystems.textsecuregcm.providers.MemcachedClientFactory; import org.whispersystems.textsecuregcm.providers.MemcachedClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory; import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.providers.RedisHealthCheck; import org.whispersystems.textsecuregcm.providers.RedisHealthCheck;
import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.sms.NexmoSmsSender; import org.whispersystems.textsecuregcm.sms.NexmoSmsSender;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
@@ -66,8 +73,12 @@ import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.UrlSigner; import org.whispersystems.textsecuregcm.util.UrlSigner;
import org.whispersystems.textsecuregcm.websocket.WebsocketControllerFactory; import org.whispersystems.textsecuregcm.websocket.ConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.DirectoryCommand; import org.whispersystems.textsecuregcm.workers.DirectoryCommand;
import org.whispersystems.textsecuregcm.workers.VacuumCommand;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
import javax.servlet.DispatcherType; import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration; import javax.servlet.FilterRegistration;
@@ -95,6 +106,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
@Override @Override
public void initialize(Bootstrap<WhisperServerConfiguration> bootstrap) { public void initialize(Bootstrap<WhisperServerConfiguration> bootstrap) {
bootstrap.addCommand(new DirectoryCommand()); bootstrap.addCommand(new DirectoryCommand());
bootstrap.addCommand(new VacuumCommand());
bootstrap.addBundle(new MigrationsBundle<WhisperServerConfiguration>() { bootstrap.addBundle(new MigrationsBundle<WhisperServerConfiguration>() {
@Override @Override
public DataSourceFactory getDataSourceFactory(WhisperServerConfiguration configuration) { public DataSourceFactory getDataSourceFactory(WhisperServerConfiguration configuration) {
@@ -134,6 +146,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
StoredMessages storedMessages = new StoredMessages(redisClient); StoredMessages storedMessages = new StoredMessages(redisClient);
PubSubManager pubSubManager = new PubSubManager(redisClient); PubSubManager pubSubManager = new PubSubManager(redisClient);
APNSender apnSender = new APNSender(accountsManager, pubSubManager, storedMessages, memcachedClient,
config.getApnConfiguration().getCertificate(),
config.getApnConfiguration().getKey());
GCMSender gcmSender = new GCMSender(accountsManager,
config.getGcmConfiguration().getSenderId(),
config.getGcmConfiguration().getApiKey());
WebsocketSender websocketSender = new WebsocketSender(storedMessages, pubSubManager);
environment.lifecycle().manage(apnSender);
environment.lifecycle().manage(gcmSender);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager); AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient); RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient);
@@ -141,13 +166,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
Optional<NexmoSmsSender> nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration()); Optional<NexmoSmsSender> nexmoSmsSender = initializeNexmoSmsSender(config.getNexmoConfiguration());
SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational()); SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational());
UrlSigner urlSigner = new UrlSigner(config.getS3Configuration()); UrlSigner urlSigner = new UrlSigner(config.getS3Configuration());
PushSender pushSender = new PushSender(config.getGcmConfiguration(), PushSender pushSender = new PushSender(gcmSender, apnSender, websocketSender);
config.getApnConfiguration(), Optional<byte[]> authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey();
storedMessages, pubSubManager,
accountsManager);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner); AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysController keysController = new KeysController(rateLimiters, keys, accountsManager, federatedClientManager); KeysControllerV1 keysControllerV1 = new KeysControllerV1(rateLimiters, keys, accountsManager, federatedClientManager);
KeysControllerV2 keysControllerV2 = new KeysControllerV2(rateLimiters, keys, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager); MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
environment.jersey().register(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()), environment.jersey().register(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()),
@@ -155,19 +179,23 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
deviceAuthenticator, deviceAuthenticator,
Device.class, "WhisperServer")); Device.class, "WhisperServer"));
environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender)); environment.jersey().register(new AccountController(pendingAccountsManager, accountsManager, rateLimiters, smsSender, storedMessages, new TimeProvider(), authorizationKey));
environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters)); environment.jersey().register(new DeviceController(pendingDevicesManager, accountsManager, rateLimiters));
environment.jersey().register(new DirectoryController(rateLimiters, directory)); environment.jersey().register(new DirectoryController(rateLimiters, directory));
environment.jersey().register(new FederationController(accountsManager, attachmentController, keysController, messageController)); environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));
environment.jersey().register(new FederationControllerV2(accountsManager, attachmentController, messageController, keysControllerV2));
environment.jersey().register(new ReceiptController(accountsManager, federatedClientManager, pushSender));
environment.jersey().register(attachmentController); environment.jersey().register(attachmentController);
environment.jersey().register(keysController); environment.jersey().register(keysControllerV1);
environment.jersey().register(keysControllerV2);
environment.jersey().register(messageController); environment.jersey().register(messageController);
if (config.getWebsocketConfiguration().isEnabled()) { if (config.getWebsocketConfiguration().isEnabled()) {
WebsocketControllerFactory servlet = new WebsocketControllerFactory(deviceAuthenticator, WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment(environment);
pushSender, webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(deviceAuthenticator));
storedMessages, webSocketEnvironment.setConnectListener(new ConnectListener(accountsManager, pushSender, storedMessages, pubSubManager));
pubSubManager);
WebSocketResourceProviderFactory servlet = new WebSocketResourceProviderFactory(webSocketEnvironment);
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", servlet); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", servlet);
websocket.addMapping("/v1/websocket/*"); websocket.addMapping("/v1/websocket/*");

View File

@@ -0,0 +1,76 @@
package org.whispersystems.textsecuregcm.auth;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.Util;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.concurrent.TimeUnit;
public class AuthorizationToken {
private final Logger logger = LoggerFactory.getLogger(AuthorizationToken.class);
private final String token;
private final byte[] key;
public AuthorizationToken(String token, byte[] key) {
this.token = token;
this.key = key;
}
public boolean isValid(String number, long currentTimeMillis) {
String[] parts = token.split(":");
if (parts.length != 3) {
return false;
}
if (!number.equals(parts[0])) {
return false;
}
if (!isValidTime(parts[1], currentTimeMillis)) {
return false;
}
return isValidSignature(parts[0] + ":" + parts[1], parts[2]);
}
private boolean isValidTime(String timeString, long currentTimeMillis) {
try {
long tokenTime = Long.parseLong(timeString);
long ourTime = TimeUnit.MILLISECONDS.toSeconds(currentTimeMillis);
return TimeUnit.SECONDS.toHours(Math.abs(ourTime - tokenTime)) < 24;
} catch (NumberFormatException e) {
logger.warn("Number Format", e);
return false;
}
}
private boolean isValidSignature(String prefix, String suffix) {
try {
Mac hmac = Mac.getInstance("HmacSHA256");
hmac.init(new SecretKeySpec(key, "HmacSHA256"));
byte[] ourSuffix = Util.truncate(hmac.doFinal(prefix.getBytes()), 10);
byte[] theirSuffix = Hex.decodeHex(suffix.toCharArray());
return MessageDigest.isEqual(ourSuffix, theirSuffix);
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
} catch (DecoderException e) {
logger.warn("Authorizationtoken", e);
return false;
}
}
}

View File

@@ -19,8 +19,14 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty; import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
public class GcmConfiguration { public class GcmConfiguration {
@NotNull
@JsonProperty
private long senderId;
@NotEmpty @NotEmpty
@JsonProperty @JsonProperty
private String apiKey; private String apiKey;
@@ -28,4 +34,8 @@ public class GcmConfiguration {
public String getApiKey() { public String getApiKey() {
return apiKey; return apiKey;
} }
public long getSenderId() {
return senderId;
}
} }

View File

@@ -0,0 +1,20 @@
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Optional;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
public class RedPhoneConfiguration {
@JsonProperty
private String authKey;
public Optional<byte[]> getAuthorizationKey() throws DecoderException {
if (authKey == null || authKey.trim().length() == 0) {
return Optional.absent();
}
return Optional.of(Hex.decodeHex(authKey.toCharArray()));
}
}

View File

@@ -27,15 +27,19 @@ import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.auth.AuthorizationToken;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@@ -65,16 +69,25 @@ public class AccountController {
private final AccountsManager accounts; private final AccountsManager accounts;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final SmsSender smsSender; private final SmsSender smsSender;
private final StoredMessages storedMessages;
private final TimeProvider timeProvider;
private final Optional<byte[]> authorizationKey;
public AccountController(PendingAccountsManager pendingAccounts, public AccountController(PendingAccountsManager pendingAccounts,
AccountsManager accounts, AccountsManager accounts,
RateLimiters rateLimiters, RateLimiters rateLimiters,
SmsSender smsSenderFactory) SmsSender smsSenderFactory,
StoredMessages storedMessages,
TimeProvider timeProvider,
Optional<byte[]> authorizationKey)
{ {
this.pendingAccounts = pendingAccounts; this.pendingAccounts = pendingAccounts;
this.accounts = accounts; this.accounts = accounts;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.smsSender = smsSenderFactory; this.smsSender = smsSenderFactory;
this.storedMessages = storedMessages;
this.timeProvider = timeProvider;
this.authorizationKey = authorizationKey;
} }
@Timed @Timed
@@ -140,30 +153,46 @@ public class AccountController {
throw new WebApplicationException(Response.status(417).build()); throw new WebApplicationException(Response.status(417).build());
} }
Device device = new Device(); createAccount(number, password, accountAttributes);
device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
Account account = new Account();
account.setNumber(number);
account.setSupportsSms(accountAttributes.getSupportsSms());
account.addDevice(device);
accounts.create(account);
pendingAccounts.remove(number);
logger.debug("Stored device...");
} catch (InvalidAuthorizationHeaderException e) { } catch (InvalidAuthorizationHeaderException e) {
logger.info("Bad Authorization Header", e); logger.info("Bad Authorization Header", e);
throw new WebApplicationException(Response.status(401).build()); throw new WebApplicationException(Response.status(401).build());
} }
} }
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Path("/token/{verification_token}")
public void verifyToken(@PathParam("verification_token") String verificationToken,
@HeaderParam("Authorization") String authorizationHeader,
@Valid AccountAttributes accountAttributes)
throws RateLimitExceededException
{
try {
AuthorizationHeader header = AuthorizationHeader.fromFullHeader(authorizationHeader);
String number = header.getNumber();
String password = header.getPassword();
rateLimiters.getVerifyLimiter().validate(number);
if (!authorizationKey.isPresent()) {
logger.debug("Attempt to authorize with key but not configured...");
throw new WebApplicationException(Response.status(403).build());
}
AuthorizationToken token = new AuthorizationToken(verificationToken, authorizationKey.get());
if (!token.isValid(number, timeProvider.getCurrentTimeMillis())) {
throw new WebApplicationException(Response.status(403).build());
}
createAccount(number, password, accountAttributes);
} catch (InvalidAuthorizationHeaderException e) {
logger.info("Bad authorization header", e);
throw new WebApplicationException(Response.status(401).build());
}
}
@Timed @Timed
@PUT @PUT
@@ -214,6 +243,26 @@ public class AccountController {
encodedVerificationText)).build(); encodedVerificationText)).build();
} }
private void createAccount(String number, String password, AccountAttributes accountAttributes) {
Device device = new Device();
device.setId(Device.MASTER_ID);
device.setAuthenticationCredentials(new AuthenticationCredentials(password));
device.setSignalingKey(accountAttributes.getSignalingKey());
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
Account account = new Account();
account.setNumber(number);
account.setSupportsSms(accountAttributes.getSupportsSms());
account.addDevice(device);
accounts.create(account);
storedMessages.clear(new WebsocketAddress(number, Device.MASTER_ID));
pendingAccounts.remove(number);
logger.debug("Stored device...");
}
@VisibleForTesting protected VerificationCode generateVerificationCode() { @VisibleForTesting protected VerificationCode generateVerificationCode() {
try { try {
SecureRandom random = SecureRandom.getInstance("SHA1PRNG"); SecureRandom random = SecureRandom.getInstance("SHA1PRNG");

View File

@@ -1,168 +1,19 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.Valid;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.Auth;
@Path("/v1/federation")
public class FederationController { public class FederationController {
private final Logger logger = LoggerFactory.getLogger(FederationController.class); protected final AccountsManager accounts;
protected final AttachmentController attachmentController;
private static final int ACCOUNT_CHUNK_SIZE = 10000; protected final MessageController messageController;
private final AccountsManager accounts;
private final AttachmentController attachmentController;
private final KeysController keysController;
private final MessageController messageController;
public FederationController(AccountsManager accounts, public FederationController(AccountsManager accounts,
AttachmentController attachmentController, AttachmentController attachmentController,
KeysController keysController,
MessageController messageController) MessageController messageController)
{ {
this.accounts = accounts; this.accounts = accounts;
this.attachmentController = attachmentController; this.attachmentController = attachmentController;
this.keysController = keysController;
this.messageController = messageController; this.messageController = messageController;
} }
@Timed
@GET
@Path("/attachment/{attachmentId}")
@Produces(MediaType.APPLICATION_JSON)
public AttachmentUri getSignedAttachmentUri(@Auth FederatedPeer peer,
@PathParam("attachmentId") long attachmentId)
throws IOException
{
return attachmentController.redirectToAttachment(new NonLimitedAccount("Unknown", -1, peer.getName()),
attachmentId, Optional.<String>absent());
}
@Timed
@GET
@Path("/key/{number}")
@Produces(MediaType.APPLICATION_JSON)
public PreKey getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
throws IOException
{
try {
return keysController.get(new NonLimitedAccount("Unknown", -1, peer.getName()), number, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getKeys(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysController.getDeviceKey(new NonLimitedAccount("Unknown", -1, peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@PUT
@Path("/messages/{source}/{sourceDeviceId}/{destination}")
public void sendMessages(@Auth FederatedPeer peer,
@PathParam("source") String source,
@PathParam("sourceDeviceId") long sourceDeviceId,
@PathParam("destination") String destination,
@Valid IncomingMessageList messages)
throws IOException
{
try {
messages.setRelay(null);
messageController.sendMessage(new NonLimitedAccount(source, sourceDeviceId, peer.getName()), destination, messages);
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/user_count")
@Produces(MediaType.APPLICATION_JSON)
public AccountCount getUserCount(@Auth FederatedPeer peer) {
return new AccountCount((int)accounts.getCount());
}
@Timed
@GET
@Path("/user_tokens/{offset}")
@Produces(MediaType.APPLICATION_JSON)
public ClientContacts getUserTokens(@Auth FederatedPeer peer,
@PathParam("offset") int offset)
{
List<Account> accountList = accounts.getAll(offset, ACCOUNT_CHUNK_SIZE);
List<ClientContact> clientContacts = new LinkedList<>();
for (Account account : accountList) {
byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
if (!account.isActive()) {
clientContact.setInactive(true);
}
clientContacts.add(clientContact);
}
return new ClientContacts(clientContacts);
}
} }

View File

@@ -0,0 +1,164 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.AccountCount;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.Valid;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.Auth;
@Path("/v1/federation")
public class FederationControllerV1 extends FederationController {
private final Logger logger = LoggerFactory.getLogger(FederationControllerV1.class);
private static final int ACCOUNT_CHUNK_SIZE = 10000;
private final KeysControllerV1 keysControllerV1;
public FederationControllerV1(AccountsManager accounts,
AttachmentController attachmentController,
MessageController messageController,
KeysControllerV1 keysControllerV1)
{
super(accounts, attachmentController, messageController);
this.keysControllerV1 = keysControllerV1;
}
@Timed
@GET
@Path("/attachment/{attachmentId}")
@Produces(MediaType.APPLICATION_JSON)
public AttachmentUri getSignedAttachmentUri(@Auth FederatedPeer peer,
@PathParam("attachmentId") long attachmentId)
throws IOException
{
return attachmentController.redirectToAttachment(new NonLimitedAccount("Unknown", -1, peer.getName()),
attachmentId, Optional.<String>absent());
}
@Timed
@GET
@Path("/key/{number}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyV1> getKey(@Auth FederatedPeer peer,
@PathParam("number") String number)
throws IOException
{
try {
return keysControllerV1.get(new NonLimitedAccount("Unknown", -1, peer.getName()),
number, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponseV1> getKeysV1(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysControllerV1.getDeviceKey(new NonLimitedAccount("Unknown", -1, peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@PUT
@Path("/messages/{source}/{sourceDeviceId}/{destination}")
public void sendMessages(@Auth FederatedPeer peer,
@PathParam("source") String source,
@PathParam("sourceDeviceId") long sourceDeviceId,
@PathParam("destination") String destination,
@Valid IncomingMessageList messages)
throws IOException
{
try {
messages.setRelay(null);
messageController.sendMessage(new NonLimitedAccount(source, sourceDeviceId, peer.getName()), destination, messages);
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
@Timed
@GET
@Path("/user_count")
@Produces(MediaType.APPLICATION_JSON)
public AccountCount getUserCount(@Auth FederatedPeer peer) {
return new AccountCount((int)accounts.getCount());
}
@Timed
@GET
@Path("/user_tokens/{offset}")
@Produces(MediaType.APPLICATION_JSON)
public ClientContacts getUserTokens(@Auth FederatedPeer peer,
@PathParam("offset") int offset)
{
List<Account> accountList = accounts.getAll(offset, ACCOUNT_CHUNK_SIZE);
List<ClientContact> clientContacts = new LinkedList<>();
for (Account account : accountList) {
byte[] token = Util.getContactToken(account.getNumber());
ClientContact clientContact = new ClientContact(token, null, account.getSupportsSms());
if (!account.isActive()) {
clientContact.setInactive(true);
}
clientContacts.add(clientContact);
}
return new ClientContacts(clientContacts);
}
}

View File

@@ -0,0 +1,51 @@
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV2;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.federation.NonLimitedAccount;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import io.dropwizard.auth.Auth;
@Path("/v2/federation")
public class FederationControllerV2 extends FederationController {
private final Logger logger = LoggerFactory.getLogger(FederationControllerV2.class);
private final KeysControllerV2 keysControllerV2;
public FederationControllerV2(AccountsManager accounts, AttachmentController attachmentController, MessageController messageController, KeysControllerV2 keysControllerV2) {
super(accounts, attachmentController, messageController);
this.keysControllerV2 = keysControllerV2;
}
@Timed
@GET
@Path("/key/{number}/{device}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponseV2> getKeysV2(@Auth FederatedPeer peer,
@PathParam("number") String number,
@PathParam("device") String device)
throws IOException
{
try {
return keysControllerV2.getDeviceKeys(new NonLimitedAccount("Unknown", -1, peer.getName()),
number, device, Optional.<String>absent());
} catch (RateLimitExceededException e) {
logger.warn("Rate limiting on federated channel", e);
throw new IOException(e);
}
}
}

View File

@@ -1,5 +1,5 @@
/** /**
* Copyright (C) 2013 Open WhisperSystems * Copyright (C) 2014 Open Whisper Systems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by * it under the terms of the GNU Affero General Public License as published by
@@ -18,45 +18,30 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import org.slf4j.Logger; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyList;
import org.whispersystems.textsecuregcm.entities.PreKeyStatus;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
@Path("/v1/keys")
public class KeysController { public class KeysController {
private final Logger logger = LoggerFactory.getLogger(KeysController.class); protected final RateLimiters rateLimiters;
protected final Keys keys;
private final RateLimiters rateLimiters; protected final AccountsManager accounts;
private final Keys keys; protected final FederatedClientManager federatedClientManager;
private final AccountsManager accounts;
private final FederatedClientManager federatedClientManager;
public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts, public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts,
FederatedClientManager federatedClientManager) FederatedClientManager federatedClientManager)
@@ -67,119 +52,65 @@ public class KeysController {
this.federatedClientManager = federatedClientManager; this.federatedClientManager = federatedClientManager;
} }
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Account account, @Valid PreKeyList preKeys) {
Device device = account.getAuthenticatedDevice().get();
String identityKey = preKeys.getLastResortKey().getIdentityKey();
if (!identityKey.equals(account.getIdentityKey())) {
account.setIdentityKey(identityKey);
accounts.update(account);
}
keys.store(account.getNumber(), device.getId(), preKeys.getKeys(), preKeys.getLastResortKey());
}
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public PreKeyStatus getStatus(@Auth Account account) { public PreKeyCount getStatus(@Auth Account account) {
int count = keys.getCount(account.getNumber(), account.getAuthenticatedDevice().get().getId()); int count = keys.getCount(account.getNumber(), account.getAuthenticatedDevice().get().getId());
if (count > 0) { if (count > 0) {
count = count - 1; count = count - 1;
} }
return new PreKeyStatus(count); return new PreKeyCount(count);
} }
@Timed protected TargetKeys getLocalKeys(String number, String deviceIdSelector)
@GET throws NoSuchUserException
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public UnstructuredPreKeyList getDeviceKey(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{ {
try {
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
Optional<UnstructuredPreKeyList> results;
if (!relay.isPresent()) results = getLocalKeys(number, deviceId);
else results = federatedClientManager.getClient(relay.get()).getKeys(number, deviceId);
if (results.isPresent()) return results.get();
else throw new WebApplicationException(Response.status(404).build());
} catch (NoSuchPeerException e) {
throw new WebApplicationException(Response.status(404).build());
}
}
@Timed
@GET
@Path("/{number}")
@Produces(MediaType.APPLICATION_JSON)
public PreKey get(@Auth Account account,
@PathParam("number") String number,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
UnstructuredPreKeyList results = getDeviceKey(account, number, String.valueOf(Device.MASTER_ID), relay);
return results.getKeys().get(0);
}
private Optional<UnstructuredPreKeyList> getLocalKeys(String number, String deviceIdSelector) {
Optional<Account> destination = accounts.get(number); Optional<Account> destination = accounts.get(number);
if (!destination.isPresent() || !destination.get().isActive()) { if (!destination.isPresent() || !destination.get().isActive()) {
return Optional.absent(); throw new NoSuchUserException("Target account is inactive");
} }
try { try {
if (deviceIdSelector.equals("*")) { if (deviceIdSelector.equals("*")) {
Optional<UnstructuredPreKeyList> preKeys = keys.get(number); Optional<List<KeyRecord>> preKeys = keys.get(number);
return getActiveKeys(destination.get(), preKeys); return new TargetKeys(destination.get(), preKeys);
} }
long deviceId = Long.parseLong(deviceIdSelector); long deviceId = Long.parseLong(deviceIdSelector);
Optional<Device> targetDevice = destination.get().getDevice(deviceId); Optional<Device> targetDevice = destination.get().getDevice(deviceId);
if (!targetDevice.isPresent() || !targetDevice.get().isActive()) { if (!targetDevice.isPresent() || !targetDevice.get().isActive()) {
return Optional.absent(); throw new NoSuchUserException("Target device is inactive.");
} }
Optional<UnstructuredPreKeyList> preKeys = keys.get(number, deviceId); Optional<List<KeyRecord>> preKeys = keys.get(number, deviceId);
return getActiveKeys(destination.get(), preKeys); return new TargetKeys(destination.get(), preKeys);
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build()); throw new WebApplicationException(Response.status(422).build());
} }
} }
private Optional<UnstructuredPreKeyList> getActiveKeys(Account destination,
Optional<UnstructuredPreKeyList> preKeys)
{
if (!preKeys.isPresent()) return Optional.absent();
List<PreKey> filteredKeys = new LinkedList<>(); public static class TargetKeys {
private final Account destination;
private final Optional<List<KeyRecord>> keys;
for (PreKey preKey : preKeys.get().getKeys()) { public TargetKeys(Account destination, Optional<List<KeyRecord>> keys) {
Optional<Device> device = destination.getDevice(preKey.getDeviceId()); this.destination = destination;
this.keys = keys;
}
if (device.isPresent() && device.get().isActive()) { public Optional<List<KeyRecord>> getKeys() {
preKey.setRegistrationId(device.get().getRegistrationId()); return keys;
preKey.setIdentityKey(destination.getIdentityKey()); }
filteredKeys.add(preKey);
public Account getDestination() {
return destination;
} }
} }
if (filteredKeys.isEmpty()) return Optional.absent();
else return Optional.of(new UnstructuredPreKeyList(filteredKeys));
}
} }

View File

@@ -0,0 +1,136 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyStateV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.Auth;
@Path("/v1/keys")
public class KeysControllerV1 extends KeysController {
private final Logger logger = LoggerFactory.getLogger(KeysControllerV1.class);
public KeysControllerV1(RateLimiters rateLimiters, Keys keys, AccountsManager accounts,
FederatedClientManager federatedClientManager)
{
super(rateLimiters, keys, accounts, federatedClientManager);
}
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Account account, @Valid PreKeyStateV1 preKeys) {
Device device = account.getAuthenticatedDevice().get();
String identityKey = preKeys.getLastResortKey().getIdentityKey();
if (!identityKey.equals(account.getIdentityKey())) {
account.setIdentityKey(identityKey);
accounts.update(account);
}
keys.store(account.getNumber(), device.getId(), preKeys.getKeys(), preKeys.getLastResortKey());
}
@Timed
@GET
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponseV1> getDeviceKey(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
try {
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
if (relay.isPresent()) {
return federatedClientManager.getClient(relay.get()).getKeysV1(number, deviceId);
}
TargetKeys targetKeys = getLocalKeys(number, deviceId);
if (!targetKeys.getKeys().isPresent()) {
return Optional.absent();
}
List<PreKeyV1> preKeys = new LinkedList<>();
Account destination = targetKeys.getDestination();
for (KeyRecord record : targetKeys.getKeys().get()) {
Optional<Device> device = destination.getDevice(record.getDeviceId());
if (device.isPresent() && device.get().isActive()) {
preKeys.add(new PreKeyV1(record.getDeviceId(), record.getKeyId(),
record.getPublicKey(), destination.getIdentityKey(),
device.get().getRegistrationId()));
}
}
if (preKeys.isEmpty()) return Optional.absent();
else return Optional.of(new PreKeyResponseV1(preKeys));
} catch (NoSuchPeerException | NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
}
}
@Timed
@GET
@Path("/{number}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyV1> get(@Auth Account account,
@PathParam("number") String number,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
Optional<PreKeyResponseV1> results = getDeviceKey(account, number, String.valueOf(Device.MASTER_ID), relay);
if (results.isPresent()) return Optional.of(results.get().getKeys().get(0));
else return Optional.absent();
}
}

View File

@@ -0,0 +1,156 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseItemV2;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV2;
import org.whispersystems.textsecuregcm.entities.PreKeyStateV2;
import org.whispersystems.textsecuregcm.entities.PreKeyV2;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.Auth;
@Path("/v2/keys")
public class KeysControllerV2 extends KeysController {
public KeysControllerV2(RateLimiters rateLimiters, Keys keys, AccountsManager accounts,
FederatedClientManager federatedClientManager)
{
super(rateLimiters, keys, accounts, federatedClientManager);
}
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth Account account, @Valid PreKeyStateV2 preKeys) {
Device device = account.getAuthenticatedDevice().get();
boolean updateAccount = false;
if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) {
device.setSignedPreKey(preKeys.getSignedPreKey());
updateAccount = true;
}
if (!preKeys.getIdentityKey().equals(account.getIdentityKey())) {
account.setIdentityKey(preKeys.getIdentityKey());
updateAccount = true;
}
if (updateAccount) {
accounts.update(account);
}
keys.store(account.getNumber(), device.getId(), preKeys.getPreKeys(), preKeys.getLastResortKey());
}
@Timed
@GET
@Path("/{number}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponseV2> getDeviceKeys(@Auth Account account,
@PathParam("number") String number,
@PathParam("device_id") String deviceId,
@QueryParam("relay") Optional<String> relay)
throws RateLimitExceededException
{
try {
if (account.isRateLimited()) {
rateLimiters.getPreKeysLimiter().validate(account.getNumber() + "__" + number + "." + deviceId);
}
if (relay.isPresent()) {
return federatedClientManager.getClient(relay.get()).getKeysV2(number, deviceId);
}
TargetKeys targetKeys = getLocalKeys(number, deviceId);
Account destination = targetKeys.getDestination();
List<PreKeyResponseItemV2> devices = new LinkedList<>();
for (Device device : destination.getDevices()) {
if (device.isActive() && (deviceId.equals("*") || device.getId() == Long.parseLong(deviceId))) {
SignedPreKey signedPreKey = device.getSignedPreKey();
PreKeyV2 preKey = null;
if (targetKeys.getKeys().isPresent()) {
for (KeyRecord keyRecord : targetKeys.getKeys().get()) {
if (keyRecord.getDeviceId() == device.getId()) {
preKey = new PreKeyV2(keyRecord.getKeyId(), keyRecord.getPublicKey());
}
}
}
if (signedPreKey != null || preKey != null) {
devices.add(new PreKeyResponseItemV2(device.getId(), device.getRegistrationId(), signedPreKey, preKey));
}
}
}
if (devices.isEmpty()) return Optional.absent();
else return Optional.of(new PreKeyResponseV2(destination.getIdentityKey(), devices));
} catch (NoSuchPeerException | NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
}
}
@Timed
@PUT
@Path("/signed")
@Consumes(MediaType.APPLICATION_JSON)
public void setSignedKey(@Auth Account account, @Valid SignedPreKey signedPreKey) {
Device device = account.getAuthenticatedDevice().get();
device.setSignedPreKey(signedPreKey);
accounts.update(account);
}
@Timed
@GET
@Path("/signed")
@Produces(MediaType.APPLICATION_JSON)
public Optional<SignedPreKey> getSignedKey(@Auth Account account) {
Device device = account.getAuthenticatedDevice().get();
SignedPreKey signedPreKey = device.getSignedPreKey();
if (signedPreKey != null) return Optional.of(signedPreKey);
else return Optional.absent();
}
}

View File

@@ -142,7 +142,7 @@ public class MessageController {
Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId()); Optional<Device> destinationDevice = destination.getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) { if (destinationDevice.isPresent()) {
sendLocalMessage(source, destination, destinationDevice.get(), incomingMessage); sendLocalMessage(source, destination, destinationDevice.get(), messages.getTimestamp(), incomingMessage);
} }
} }
} }
@@ -150,6 +150,7 @@ public class MessageController {
private void sendLocalMessage(Account source, private void sendLocalMessage(Account source,
Account destinationAccount, Account destinationAccount,
Device destinationDevice, Device destinationDevice,
long timestamp,
IncomingMessage incomingMessage) IncomingMessage incomingMessage)
throws NoSuchUserException, IOException throws NoSuchUserException, IOException
{ {
@@ -159,7 +160,7 @@ public class MessageController {
messageBuilder.setType(incomingMessage.getType()) messageBuilder.setType(incomingMessage.getType())
.setSource(source.getNumber()) .setSource(source.getNumber())
.setTimestamp(System.currentTimeMillis()) .setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp)
.setSourceDevice((int)source.getAuthenticatedDevice().get().getId()); .setSourceDevice((int)source.getAuthenticatedDevice().get().getId());
if (messageBody.isPresent()) { if (messageBody.isPresent()) {

View File

@@ -0,0 +1,108 @@
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.List;
import io.dropwizard.auth.Auth;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
@Path("/v1/receipt")
public class ReceiptController {
private final AccountsManager accountManager;
private final PushSender pushSender;
private final FederatedClientManager federatedClientManager;
public ReceiptController(AccountsManager accountManager,
FederatedClientManager federatedClientManager,
PushSender pushSender)
{
this.accountManager = accountManager;
this.federatedClientManager = federatedClientManager;
this.pushSender = pushSender;
}
@Timed
@PUT
@Path("/{destination}/{messageId}")
public void sendDeliveryReceipt(@Auth Account source,
@PathParam("destination") String destination,
@PathParam("messageId") long messageId,
@QueryParam("relay") Optional<String> relay)
throws IOException
{
try {
if (relay.isPresent()) sendRelayedReceipt(source, destination, messageId, relay.get());
else sendDirectReceipt(source, destination, messageId);
} catch (NoSuchUserException | NotPushRegisteredException e) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
} catch (TransientPushFailureException e) {
throw new IOException(e);
}
}
private void sendRelayedReceipt(Account source, String destination, long messageId, String relay)
throws NoSuchUserException, IOException
{
try {
federatedClientManager.getClient(relay)
.sendDeliveryReceipt(source.getNumber(),
source.getAuthenticatedDevice().get().getId(),
destination, messageId);
} catch (NoSuchPeerException e) {
throw new NoSuchUserException(e);
}
}
private void sendDirectReceipt(Account source, String destination, long messageId)
throws NotPushRegisteredException, TransientPushFailureException, NoSuchUserException
{
Account destinationAccount = getDestinationAccount(destination);
List<Device> destinationDevices = destinationAccount.getDevices();
OutgoingMessageSignal.Builder message =
OutgoingMessageSignal.newBuilder()
.setSource(source.getNumber())
.setSourceDevice((int) source.getAuthenticatedDevice().get().getId())
.setTimestamp(messageId)
.setType(OutgoingMessageSignal.Type.RECEIPT_VALUE);
if (source.getRelay().isPresent()) {
message.setRelay(source.getRelay().get());
}
for (Device destinationDevice : destinationDevices) {
pushSender.sendMessage(destinationAccount, destinationDevice, message.build());
}
}
private Account getDestinationAccount(String destination)
throws NoSuchUserException
{
Optional<Account> account = accountManager.get(destination);
if (!account.isPresent()) {
throw new NoSuchUserException(destination);
}
return account.get();
}
}

View File

@@ -1,223 +0,0 @@
package org.whispersystems.textsecuregcm.controllers;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import org.eclipse.jetty.websocket.api.CloseStatus;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.IncomingWebsocketMessage;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubListener;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubMessage;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import org.whispersystems.textsecuregcm.websocket.WebsocketMessage;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import io.dropwizard.auth.AuthenticationException;
import io.dropwizard.auth.basic.BasicCredentials;
public class WebsocketController implements WebSocketListener, PubSubListener {
private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class);
private static final ObjectMapper mapper = new ObjectMapper();
private static final Map<Long, String> pendingMessages = new HashMap<>();
private final AccountAuthenticator accountAuthenticator;
private final PubSubManager pubSubManager;
private final StoredMessages storedMessages;
private final PushSender pushSender;
private WebsocketAddress address;
private Account account;
private Device device;
private Session session;
private long pendingMessageSequence;
public WebsocketController(AccountAuthenticator accountAuthenticator,
PushSender pushSender,
PubSubManager pubSubManager,
StoredMessages storedMessages)
{
this.accountAuthenticator = accountAuthenticator;
this.pushSender = pushSender;
this.pubSubManager = pubSubManager;
this.storedMessages = storedMessages;
}
@Override
public void onWebSocketConnect(Session session) {
try {
UpgradeRequest request = session.getUpgradeRequest();
Map<String, String[]> parameters = request.getParameterMap();
String[] usernames = parameters.get("login" );
String[] passwords = parameters.get("password");
if (usernames == null || usernames.length == 0 ||
passwords == null || passwords.length == 0)
{
session.close(new CloseStatus(4001, "Unauthorized"));
return;
}
BasicCredentials credentials = new BasicCredentials(usernames[0], passwords[0]);
Optional<Account> account = accountAuthenticator.authenticate(credentials);
if (!account.isPresent()) {
session.close(new CloseStatus(4001, "Unauthorized"));
return;
}
this.account = account.get();
this.device = account.get().getAuthenticatedDevice().get();
this.address = new WebsocketAddress(this.account.getId(), this.device.getId());
this.session = session;
this.session.setIdleTimeout(10 * 60 * 1000);
this.pubSubManager.subscribe(this.address, this);
handleQueryDatabase();
} catch (AuthenticationException e) {
try { session.close(1011, "Server Error");} catch (IOException e1) {}
} catch (IOException ioe) {
logger.info("Abrupt session close.");
}
}
@Override
public void onWebSocketText(String body) {
try {
IncomingWebsocketMessage incomingMessage = mapper.readValue(body, IncomingWebsocketMessage.class);
switch (incomingMessage.getType()) {
case IncomingWebsocketMessage.TYPE_ACKNOWLEDGE_MESSAGE:
handleMessageAck(body);
break;
default:
close(new CloseStatus(1008, "Unknown Type"));
}
} catch (IOException e) {
logger.debug("Parse", e);
close(new CloseStatus(1008, "Badly Formatted"));
}
}
@Override
public void onWebSocketClose(int i, String s) {
pubSubManager.unsubscribe(this.address, this);
List<String> remainingMessages = new LinkedList<>();
synchronized (pendingMessages) {
Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]);
Arrays.sort(pendingKeys);
for (long pendingKey : pendingKeys) {
remainingMessages.add(pendingMessages.get(pendingKey));
}
pendingMessages.clear();
}
for (String remainingMessage : remainingMessages) {
try {
pushSender.sendMessage(account, device, new EncryptedOutgoingMessage(remainingMessage));
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("onWebSocketClose", e);
storedMessages.insert(account.getId(), device.getId(), remainingMessage);
}
}
}
@Override
public void onPubSubMessage(PubSubMessage outgoingMessage) {
switch (outgoingMessage.getType()) {
case PubSubMessage.TYPE_DELIVER:
handleDeliverOutgoingMessage(outgoingMessage.getContents());
break;
case PubSubMessage.TYPE_QUERY_DB:
handleQueryDatabase();
break;
default:
logger.warn("Unknown pubsub message: " + outgoingMessage.getType());
}
}
private void handleDeliverOutgoingMessage(String message) {
try {
long messageSequence;
synchronized (pendingMessages) {
messageSequence = pendingMessageSequence++;
pendingMessages.put(messageSequence, message);
}
WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message);
session.getRemote().sendStringByFuture(mapper.writeValueAsString(websocketMessage));
} catch (IOException e) {
logger.debug("Response failed", e);
close(null);
}
}
private void handleMessageAck(String message) {
try {
AcknowledgeWebsocketMessage ack = mapper.readValue(message, AcknowledgeWebsocketMessage.class);
synchronized (pendingMessages) {
pendingMessages.remove(ack.getId());
}
} catch (IOException e) {
logger.warn("Mapping", e);
}
}
private void handleQueryDatabase() {
List<String> messages = storedMessages.getMessagesForDevice(account.getId(), device.getId());
for (String message : messages) {
handleDeliverOutgoingMessage(message);
}
}
@Override
public void onWebSocketBinary(byte[] bytes, int i, int i2) {
logger.info("Received binary message!");
}
@Override
public void onWebSocketError(Throwable throwable) {
logger.info("onWebSocketError", throwable);
}
private void close(CloseStatus closeStatus) {
try {
if (this.session != null) {
if (closeStatus != null) this.session.close(closeStatus);
else this.session.close();
}
} catch (IOException e) {
logger.info("close()", e);
}
}
}

View File

@@ -55,10 +55,6 @@ public class EncryptedOutgoingMessage {
this.serialized = Base64.encodeBytes(ciphertext); this.serialized = Base64.encodeBytes(ciphertext);
} }
public EncryptedOutgoingMessage(String serialized) {
this.serialized = serialized;
}
public String serialize() { public String serialize() {
return serialized; return serialized;
} }

View File

@@ -41,7 +41,7 @@ public class IncomingMessage {
private String relay; private String relay;
@JsonProperty @JsonProperty
private long timestamp; private long timestamp; // deprecated
public String getDestination() { public String getDestination() {

View File

@@ -32,6 +32,9 @@ public class IncomingMessageList {
@JsonProperty @JsonProperty
private String relay; private String relay;
@JsonProperty
private long timestamp;
public IncomingMessageList() {} public IncomingMessageList() {}
public List<IncomingMessage> getMessages() { public List<IncomingMessage> getMessages() {
@@ -45,4 +48,8 @@ public class IncomingMessageList {
public void setRelay(String relay) { public void setRelay(String relay) {
this.relay = relay; this.relay = relay;
} }
public long getTimestamp() {
return timestamp;
}
} }

View File

@@ -0,0 +1,60 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
public class PendingMessage {
@JsonProperty
private String sender;
@JsonProperty
private long messageId;
@JsonProperty
private String encryptedOutgoingMessage;
@JsonProperty
private boolean receipt;
public PendingMessage() {}
public PendingMessage(String sender, long messageId, boolean receipt, String encryptedOutgoingMessage) {
this.sender = sender;
this.messageId = messageId;
this.receipt = receipt;
this.encryptedOutgoingMessage = encryptedOutgoingMessage;
}
public String getEncryptedOutgoingMessage() {
return encryptedOutgoingMessage;
}
public long getMessageId() {
return messageId;
}
public String getSender() {
return sender;
}
public boolean isReceipt() {
return receipt;
}
@Override
public boolean equals(Object other) {
if (other == null || !(other instanceof PendingMessage)) return false;
PendingMessage that = (PendingMessage)other;
return
this.sender.equals(that.sender) &&
this.messageId == that.messageId &&
this.receipt == that.receipt &&
this.encryptedOutgoingMessage.equals(that.encryptedOutgoingMessage);
}
@Override
public int hashCode() {
return this.sender.hashCode() ^ (int)this.messageId ^ this.encryptedOutgoingMessage.hashCode() ^ (receipt ? 1 : 0);
}
}

View File

@@ -0,0 +1,8 @@
package org.whispersystems.textsecuregcm.entities;
public interface PreKeyBase {
public long getKeyId();
public String getPublicKey();
}

View File

@@ -3,16 +3,16 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
public class PreKeyStatus { public class PreKeyCount {
@JsonProperty @JsonProperty
private int count; private int count;
public PreKeyStatus(int count) { public PreKeyCount(int count) {
this.count = count; this.count = count;
} }
public PreKeyStatus() {} public PreKeyCount() {}
public int getCount() { public int getCount() {
return count; return count;

View File

@@ -0,0 +1,64 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
public class PreKeyResponseItemV2 {
@JsonProperty
private long deviceId;
@JsonProperty
private int registrationId;
@JsonProperty
private SignedPreKey signedPreKey;
@JsonProperty
private PreKeyV2 preKey;
public PreKeyResponseItemV2() {}
public PreKeyResponseItemV2(long deviceId, int registrationId, SignedPreKey signedPreKey, PreKeyV2 preKey) {
this.deviceId = deviceId;
this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
this.preKey = preKey;
}
@VisibleForTesting
public SignedPreKey getSignedPreKey() {
return signedPreKey;
}
@VisibleForTesting
public PreKeyV2 getPreKey() {
return preKey;
}
@VisibleForTesting
public int getRegistrationId() {
return registrationId;
}
@VisibleForTesting
public long getDeviceId() {
return deviceId;
}
}

View File

@@ -18,7 +18,6 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
@@ -26,36 +25,36 @@ import java.util.Iterator;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
public class UnstructuredPreKeyList { public class PreKeyResponseV1 {
@JsonProperty @JsonProperty
@NotNull @NotNull
@Valid @Valid
private List<PreKey> keys; private List<PreKeyV1> keys;
@VisibleForTesting @VisibleForTesting
public UnstructuredPreKeyList() {} public PreKeyResponseV1() {}
public UnstructuredPreKeyList(PreKey preKey) { public PreKeyResponseV1(PreKeyV1 preKey) {
this.keys = new LinkedList<PreKey>(); this.keys = new LinkedList<>();
this.keys.add(preKey); this.keys.add(preKey);
} }
public UnstructuredPreKeyList(List<PreKey> preKeys) { public PreKeyResponseV1(List<PreKeyV1> preKeys) {
this.keys = preKeys; this.keys = preKeys;
} }
public List<PreKey> getKeys() { public List<PreKeyV1> getKeys() {
return keys; return keys;
} }
@VisibleForTesting @VisibleForTesting
public boolean equals(Object o) { public boolean equals(Object o) {
if (!(o instanceof UnstructuredPreKeyList) || if (!(o instanceof PreKeyResponseV1) ||
((UnstructuredPreKeyList) o).keys.size() != keys.size()) ((PreKeyResponseV1) o).keys.size() != keys.size())
return false; return false;
Iterator<PreKey> otherKeys = ((UnstructuredPreKeyList) o).keys.iterator(); Iterator<PreKeyV1> otherKeys = ((PreKeyResponseV1) o).keys.iterator();
for (PreKey key : keys) { for (PreKeyV1 key : keys) {
if (!otherKeys.next().equals(key)) if (!otherKeys.next().equals(key))
return false; return false;
} }
@@ -64,7 +63,7 @@ public class UnstructuredPreKeyList {
public int hashCode() { public int hashCode() {
int ret = 0xFBA4C795 * keys.size(); int ret = 0xFBA4C795 * keys.size();
for (PreKey key : keys) for (PreKeyV1 key : keys)
ret ^= key.getPublicKey().hashCode(); ret ^= key.getPublicKey().hashCode();
return ret; return ret;
} }

View File

@@ -0,0 +1,48 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import java.util.List;
public class PreKeyResponseV2 {
@JsonProperty
private String identityKey;
@JsonProperty
private List<PreKeyResponseItemV2> devices;
public PreKeyResponseV2() {}
public PreKeyResponseV2(String identityKey, List<PreKeyResponseItemV2> devices) {
this.identityKey = identityKey;
this.devices = devices;
}
@VisibleForTesting
public String getIdentityKey() {
return identityKey;
}
@VisibleForTesting
public List<PreKeyResponseItemV2> getDevices() {
return devices;
}
}

View File

@@ -1,5 +1,5 @@
/** /**
* Copyright (C) 2013 Open WhisperSystems * Copyright (C) 2014 Open Whisper Systems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by * it under the terms of the GNU Affero General Public License as published by
@@ -18,39 +18,38 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import java.util.List; import java.util.List;
public class PreKeyList { public class PreKeyStateV1 {
@JsonProperty @JsonProperty
@NotNull @NotNull
@Valid @Valid
private PreKey lastResortKey; private PreKeyV1 lastResortKey;
@JsonProperty @JsonProperty
@NotNull @NotNull
@Valid @Valid
private List<PreKey> keys; private List<PreKeyV1> keys;
public List<PreKey> getKeys() { public List<PreKeyV1> getKeys() {
return keys; return keys;
} }
@VisibleForTesting @VisibleForTesting
public void setKeys(List<PreKey> keys) { public void setKeys(List<PreKeyV1> keys) {
this.keys = keys; this.keys = keys;
} }
public PreKey getLastResortKey() { public PreKeyV1 getLastResortKey() {
return lastResortKey; return lastResortKey;
} }
@VisibleForTesting @VisibleForTesting
public void setLastResortKey(PreKey lastResortKey) { public void setLastResortKey(PreKeyV1 lastResortKey) {
this.lastResortKey = lastResortKey; this.lastResortKey = lastResortKey;
} }
} }

View File

@@ -0,0 +1,76 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import java.util.List;
public class PreKeyStateV2 {
@JsonProperty
@NotNull
@Valid
private List<PreKeyV2> preKeys;
@JsonProperty
@NotNull
@Valid
private SignedPreKey signedPreKey;
@JsonProperty
@NotNull
@Valid
private PreKeyV2 lastResortKey;
@JsonProperty
@NotEmpty
private String identityKey;
public PreKeyStateV2() {}
@VisibleForTesting
public PreKeyStateV2(String identityKey, SignedPreKey signedPreKey,
List<PreKeyV2> keys, PreKeyV2 lastResortKey)
{
this.identityKey = identityKey;
this.signedPreKey = signedPreKey;
this.preKeys = keys;
this.lastResortKey = lastResortKey;
}
public List<PreKeyV2> getPreKeys() {
return preKeys;
}
public SignedPreKey getSignedPreKey() {
return signedPreKey;
}
public String getIdentityKey() {
return identityKey;
}
public PreKeyV2 getLastResortKey() {
return lastResortKey;
}
}

View File

@@ -1,5 +1,5 @@
/** /**
* Copyright (C) 2013 Open WhisperSystems * Copyright (C) 2014 Open Whisper Systems
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by * it under the terms of the GNU Affero General Public License as published by
@@ -17,23 +17,14 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.xml.bind.annotation.XmlTransient;
import java.io.Serializable;
@JsonInclude(JsonInclude.Include.NON_DEFAULT) @JsonInclude(JsonInclude.Include.NON_DEFAULT)
public class PreKey { public class PreKeyV1 implements PreKeyBase {
@JsonIgnore
private long id;
@JsonIgnore
private String number;
@JsonProperty @JsonProperty
private long deviceId; private long deviceId;
@@ -50,89 +41,43 @@ public class PreKey {
@NotNull @NotNull
private String identityKey; private String identityKey;
@JsonProperty
private boolean lastResort;
@JsonProperty @JsonProperty
private int registrationId; private int registrationId;
public PreKey() {} public PreKeyV1() {}
public PreKey(long id, String number, long deviceId, long keyId, public PreKeyV1(long deviceId, long keyId, String publicKey, String identityKey, int registrationId)
String publicKey, boolean lastResort)
{ {
this.id = id;
this.number = number;
this.deviceId = deviceId;
this.keyId = keyId;
this.publicKey = publicKey;
this.lastResort = lastResort;
}
@VisibleForTesting
public PreKey(long id, String number, long deviceId, long keyId,
String publicKey, String identityKey, boolean lastResort)
{
this.id = id;
this.number = number;
this.deviceId = deviceId; this.deviceId = deviceId;
this.keyId = keyId; this.keyId = keyId;
this.publicKey = publicKey; this.publicKey = publicKey;
this.identityKey = identityKey; this.identityKey = identityKey;
this.lastResort = lastResort; this.registrationId = registrationId;
} }
@XmlTransient @VisibleForTesting
public long getId() { public PreKeyV1(long deviceId, long keyId, String publicKey, String identityKey)
return id; {
} this.deviceId = deviceId;
this.keyId = keyId;
public void setId(long id) { this.publicKey = publicKey;
this.id = id; this.identityKey = identityKey;
}
@XmlTransient
public String getNumber() {
return number;
}
public void setNumber(String number) {
this.number = number;
} }
@Override
public String getPublicKey() { public String getPublicKey() {
return publicKey; return publicKey;
} }
public void setPublicKey(String publicKey) { @Override
this.publicKey = publicKey;
}
public long getKeyId() { public long getKeyId() {
return keyId; return keyId;
} }
public void setKeyId(long keyId) {
this.keyId = keyId;
}
public String getIdentityKey() { public String getIdentityKey() {
return identityKey; return identityKey;
} }
public void setIdentityKey(String identityKey) {
this.identityKey = identityKey;
}
@XmlTransient
public boolean isLastResort() {
return lastResort;
}
public void setLastResort(boolean lastResort) {
this.lastResort = lastResort;
}
public void setDeviceId(long deviceId) { public void setDeviceId(long deviceId) {
this.deviceId = deviceId; this.deviceId = deviceId;
} }

View File

@@ -0,0 +1,82 @@
package org.whispersystems.textsecuregcm.entities;
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
public class PreKeyV2 implements PreKeyBase {
@JsonProperty
@NotNull
private long keyId;
@JsonProperty
@NotEmpty
private String publicKey;
public PreKeyV2() {}
public PreKeyV2(long keyId, String publicKey)
{
this.keyId = keyId;
this.publicKey = publicKey;
}
@Override
public String getPublicKey() {
return publicKey;
}
public void setPublicKey(String publicKey) {
this.publicKey = publicKey;
}
@Override
public long getKeyId() {
return keyId;
}
public void setKeyId(long keyId) {
this.keyId = keyId;
}
@Override
public boolean equals(Object object) {
if (object == null || !(object instanceof PreKeyV2)) return false;
PreKeyV2 that = (PreKeyV2)object;
if (publicKey == null) {
return this.keyId == that.keyId && that.publicKey == null;
} else {
return this.keyId == that.keyId && this.publicKey.equals(that.publicKey);
}
}
@Override
public int hashCode() {
if (publicKey == null) {
return (int)this.keyId;
} else {
return ((int)this.keyId) ^ publicKey.hashCode();
}
}
}

View File

@@ -0,0 +1,46 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.hibernate.validator.constraints.NotEmpty;
import java.io.Serializable;
public class SignedPreKey extends PreKeyV2 {
@JsonProperty
@NotEmpty
private String signature;
public SignedPreKey() {}
public SignedPreKey(long keyId, String publicKey, String signature) {
super(keyId, publicKey);
this.signature = signature;
}
public String getSignature() {
return signature;
}
@Override
public boolean equals(Object object) {
if (object == null || !(object instanceof SignedPreKey)) return false;
SignedPreKey that = (SignedPreKey) object;
if (signature == null) {
return super.equals(object) && that.signature == null;
} else {
return super.equals(object) && this.signature.equals(that.signature);
}
}
@Override
public int hashCode() {
if (signature == null) {
return super.hashCode();
} else {
return super.hashCode() ^ signature.hashCode();
}
}
}

View File

@@ -36,7 +36,8 @@ import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.ClientContacts; import org.whispersystems.textsecuregcm.entities.ClientContacts;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV2;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
@@ -65,8 +66,10 @@ public class FederatedClient {
private static final String USER_COUNT_PATH = "/v1/federation/user_count"; private static final String USER_COUNT_PATH = "/v1/federation/user_count";
private static final String USER_TOKENS_PATH = "/v1/federation/user_tokens/%d"; private static final String USER_TOKENS_PATH = "/v1/federation/user_tokens/%d";
private static final String RELAY_MESSAGE_PATH = "/v1/federation/messages/%s/%d/%s"; private static final String RELAY_MESSAGE_PATH = "/v1/federation/messages/%s/%d/%s";
private static final String PREKEY_PATH_DEVICE = "/v1/federation/key/%s/%s"; private static final String PREKEY_PATH_DEVICE_V1 = "/v1/federation/key/%s/%s";
private static final String PREKEY_PATH_DEVICE_V2 = "/v2/federation/key/%s/%s";
private static final String ATTACHMENT_URI_PATH = "/v1/federation/attachment/%d"; private static final String ATTACHMENT_URI_PATH = "/v1/federation/attachment/%d";
private static final String RECEIPT_PATH = "/v1/receipt/%s/%d/%s/%d";
private final FederatedPeer peer; private final FederatedPeer peer;
private final Client client; private final Client client;
@@ -107,9 +110,9 @@ public class FederatedClient {
} }
} }
public Optional<UnstructuredPreKeyList> getKeys(String destination, String device) { public Optional<PreKeyResponseV1> getKeysV1(String destination, String device) {
try { try {
WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH_DEVICE, destination, device)); WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH_DEVICE_V1, destination, device));
ClientResponse response = resource.accept(MediaType.APPLICATION_JSON) ClientResponse response = resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader) .header("Authorization", authorizationHeader)
@@ -119,7 +122,7 @@ public class FederatedClient {
throw new WebApplicationException(clientResponseToResponse(response)); throw new WebApplicationException(clientResponseToResponse(response));
} }
return Optional.of(response.getEntity(UnstructuredPreKeyList.class)); return Optional.of(response.getEntity(PreKeyResponseV1.class));
} catch (UniformInterfaceException | ClientHandlerException e) { } catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("PreKey", e); logger.warn("PreKey", e);
@@ -127,6 +130,27 @@ public class FederatedClient {
} }
} }
public Optional<PreKeyResponseV2> getKeysV2(String destination, String device) {
try {
WebResource resource = client.resource(peer.getUrl()).path(String.format(PREKEY_PATH_DEVICE_V2, destination, device));
ClientResponse response = resource.accept(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.get(ClientResponse.class);
if (response.getStatus() < 200 || response.getStatus() >= 300) {
throw new WebApplicationException(clientResponseToResponse(response));
}
return Optional.of(response.getEntity(PreKeyResponseV2.class));
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("PreKey", e);
return Optional.absent();
}
}
public int getUserCount() { public int getUserCount() {
try { try {
WebResource resource = client.resource(peer.getUrl()).path(USER_COUNT_PATH); WebResource resource = client.resource(peer.getUrl()).path(USER_COUNT_PATH);
@@ -174,6 +198,25 @@ public class FederatedClient {
} }
} }
public void sendDeliveryReceipt(String source, long sourceDeviceId, String destination, long messageId)
throws IOException
{
try {
String path = String.format(RECEIPT_PATH, source, sourceDeviceId, destination, messageId);
WebResource resource = client.resource(peer.getUrl()).path(path);
ClientResponse response = resource.type(MediaType.APPLICATION_JSON)
.header("Authorization", authorizationHeader)
.put(ClientResponse.class);
if (response.getStatus() != 200 && response.getStatus() != 204) {
throw new WebApplicationException(clientResponseToResponse(response));
}
} catch (UniformInterfaceException | ClientHandlerException e) {
logger.warn("sendMessage", e);
throw new IOException(e);
}
}
private String getAuthorizationHeader(String federationName, FederatedPeer peer) { private String getAuthorizationHeader(String federationName, FederatedPeer peer) {
return "Basic " + Base64.encodeBytes((federationName + ":" + peer.getAuthenticationToken()).getBytes()); return "Basic " + Base64.encodeBytes((federationName + ":" + peer.getAuthenticationToken()).getBytes());
} }

View File

@@ -40,6 +40,6 @@ public class NonLimitedAccount extends Account {
@Override @Override
public Optional<Device> getAuthenticatedDevice() { public Optional<Device> getAuthenticatedDevice() {
return Optional.of(new Device(deviceId, null, null, null, null, null, false, 0)); return Optional.of(new Device(deviceId, null, null, null, null, null, false, 0, null));
} }
} }

View File

@@ -45,6 +45,8 @@ public class MemcacheHealthCheck extends HealthCheck {
return Result.unhealthy("Fetch failed"); return Result.unhealthy("Fetch failed");
} }
this.client.delete("HEALTH" + random);
return Result.healthy(); return Result.healthy();
} }

View File

@@ -0,0 +1,7 @@
package org.whispersystems.textsecuregcm.providers;
public class TimeProvider {
public long getCurrentTimeMillis() {
return System.currentTimeMillis();
}
}

View File

@@ -19,20 +19,24 @@ package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.notnoop.apns.APNS; import com.notnoop.apns.APNS;
import com.notnoop.apns.ApnsService; import com.notnoop.apns.ApnsService;
import com.notnoop.exceptions.NetworkIOException; import com.notnoop.exceptions.NetworkIOException;
import net.spy.memcached.MemcachedClient;
import org.bouncycastle.openssl.PEMReader; import org.bouncycastle.openssl.PEMReader;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubMessage; import org.whispersystems.textsecuregcm.storage.PubSubMessage;
import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
@@ -47,10 +51,16 @@ import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate; import java.security.cert.Certificate;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
public class APNSender { public class APNSender implements Managed {
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter websocketMeter = metricRegistry.meter(name(getClass(), "websocket")); private final Meter websocketMeter = metricRegistry.meter(name(getClass(), "websocket"));
@@ -60,39 +70,56 @@ public class APNSender {
private static final String MESSAGE_BODY = "m"; private static final String MESSAGE_BODY = "m";
private final Optional<ApnsService> apnService; private static final ObjectMapper mapper = SystemMapper.getMapper();
private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
private final AccountsManager accounts;
private final PubSubManager pubSubManager; private final PubSubManager pubSubManager;
private final StoredMessages storedMessages; private final StoredMessages storedMessages;
private final MemcachedClient memcachedClient;
public APNSender(PubSubManager pubSubManager, private final String apnCertificate;
private final String apnKey;
private Optional<ApnsService> apnService;
public APNSender(AccountsManager accounts,
PubSubManager pubSubManager,
StoredMessages storedMessages, StoredMessages storedMessages,
MemcachedClient memcachedClient,
String apnCertificate, String apnKey) String apnCertificate, String apnKey)
throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
{ {
this.accounts = accounts;
this.pubSubManager = pubSubManager; this.pubSubManager = pubSubManager;
this.storedMessages = storedMessages; this.storedMessages = storedMessages;
this.apnCertificate = apnCertificate;
if (!Util.isEmpty(apnCertificate) && !Util.isEmpty(apnKey)) { this.apnKey = apnKey;
byte[] keyStore = initializeKeyStore(apnCertificate, apnKey); this.memcachedClient = memcachedClient;
this.apnService = Optional.of(APNS.newService()
.withCert(new ByteArrayInputStream(keyStore), "insecure")
.withSandboxDestination().build());
} else {
this.apnService = Optional.absent();
}
} }
public void sendMessage(Account account, Device device, public void sendMessage(Account account, Device device,
String registrationId, EncryptedOutgoingMessage message) String registrationId, PendingMessage message)
throws TransientPushFailureException, NotPushRegisteredException throws TransientPushFailureException
{ {
if (pubSubManager.publish(new WebsocketAddress(account.getId(), device.getId()), try {
new PubSubMessage(PubSubMessage.TYPE_DELIVER, message.serialize()))) String serializedPendingMessage = mapper.writeValueAsString(message);
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
if (pubSubManager.publish(websocketAddress, new PubSubMessage(PubSubMessage.TYPE_DELIVER,
serializedPendingMessage)))
{ {
websocketMeter.mark(); websocketMeter.mark();
} else { } else {
storedMessages.insert(account.getId(), device.getId(), message.serialize()); memcacheSet(registrationId, account.getNumber());
sendPush(registrationId, message.serialize()); storedMessages.insert(websocketAddress, message);
if (!message.isReceipt()) {
sendPush(registrationId, serializedPendingMessage);
}
}
} catch (IOException e) {
throw new TransientPushFailureException(e);
} }
} }
@@ -143,4 +170,79 @@ public class APNSender {
return baos.toByteArray(); return baos.toByteArray();
} }
@Override
public void start() throws Exception {
if (!Util.isEmpty(apnCertificate) && !Util.isEmpty(apnKey)) {
byte[] keyStore = initializeKeyStore(apnCertificate, apnKey);
this.apnService = Optional.of(APNS.newService()
.withCert(new ByteArrayInputStream(keyStore), "insecure")
.asQueued()
.withSandboxDestination().build());
this.executor.scheduleAtFixedRate(new FeedbackRunnable(), 0, 1, TimeUnit.HOURS);
} else {
this.apnService = Optional.absent();
}
}
@Override
public void stop() throws Exception {
if (apnService.isPresent()) {
apnService.get().stop();
}
}
private void memcacheSet(String registrationId, String number) {
if (memcachedClient != null) {
memcachedClient.set("APN-" + registrationId, 60 * 60 * 24, number);
}
}
private Optional<String> memcacheGet(String registrationId) {
if (memcachedClient != null) {
return Optional.fromNullable((String)memcachedClient.get("APN-" + registrationId));
} else {
return Optional.absent();
}
}
private class FeedbackRunnable implements Runnable {
private void updateAccount(Account account, String registrationId) {
boolean needsUpdate = false;
for (Device device : account.getDevices()) {
if (registrationId.equals(device.getApnId())) {
needsUpdate = true;
device.setApnId(null);
}
}
if (needsUpdate) {
accounts.update(account);
}
}
@Override
public void run() {
if (apnService.isPresent()) {
Map<String, Date> inactiveDevices = apnService.get().getInactiveDevices();
for (String registrationId : inactiveDevices.keySet()) {
Optional<String> number = memcacheGet(registrationId);
if (number.isPresent()) {
Optional<Account> account = accounts.get(number.get());
if (account.isPresent()) {
updateAccount(account.get(), registrationId);
}
} else {
logger.warn("APN unregister event received for uncached ID: " + registrationId);
}
}
}
}
}
} }

View File

@@ -1,69 +1,424 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.google.android.gcm.server.Constants; import com.google.common.base.Optional;
import com.google.android.gcm.server.Message; import org.jivesoftware.smack.ConnectionConfiguration;
import com.google.android.gcm.server.Result; import org.jivesoftware.smack.ConnectionListener;
import com.google.android.gcm.server.Sender; import org.jivesoftware.smack.PacketListener;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.jivesoftware.smack.SmackException;
import org.jivesoftware.smack.XMPPConnection;
import org.jivesoftware.smack.XMPPException;
import org.jivesoftware.smack.filter.PacketTypeFilter;
import org.jivesoftware.smack.packet.DefaultPacketExtension;
import org.jivesoftware.smack.packet.Message;
import org.jivesoftware.smack.packet.Packet;
import org.jivesoftware.smack.packet.PacketExtension;
import org.jivesoftware.smack.provider.PacketExtensionProvider;
import org.jivesoftware.smack.provider.ProviderManager;
import org.jivesoftware.smack.tcp.XMPPTCPConnection;
import org.jivesoftware.smack.util.StringUtils;
import org.json.simple.JSONObject;
import org.json.simple.JSONValue;
import org.json.simple.parser.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Util;
import org.xmlpull.v1.XmlPullParser;
import javax.net.ssl.SSLSocketFactory;
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
public class GCMSender { public class GCMSender implements Managed, PacketListener {
private final Logger logger = LoggerFactory.getLogger(GCMSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(org.whispersystems.textsecuregcm.util.Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(org.whispersystems.textsecuregcm.util.Constants.METRICS_NAME);
private final Meter success = metricRegistry.meter(name(getClass(), "sent", "success")); private final Meter success = metricRegistry.meter(name(getClass(), "sent", "success"));
private final Meter failure = metricRegistry.meter(name(getClass(), "sent", "failure")); private final Meter failure = metricRegistry.meter(name(getClass(), "sent", "failure"));
private final Meter unregistered = metricRegistry.meter(name(getClass(), "sent", "unregistered"));
private final Sender sender; private static final String GCM_SERVER = "gcm.googleapis.com";
private static final int GCM_PORT = 5235;
public GCMSender(String apiKey) { private static final String GCM_ELEMENT_NAME = "gcm";
this.sender = new Sender(apiKey); private static final String GCM_NAMESPACE = "google:mobile:data";
private final Map<String, UnacknowledgedMessage> pendingMessages = new ConcurrentHashMap<>();
private final long senderId;
private final String apiKey;
private final AccountsManager accounts;
private XMPPTCPConnection connection;
public GCMSender(AccountsManager accounts, long senderId, String apiKey) {
this.accounts = accounts;
this.senderId = senderId;
this.apiKey = apiKey;
ProviderManager.addExtensionProvider(GCM_ELEMENT_NAME, GCM_NAMESPACE,
new GcmPacketExtensionProvider());
} }
public String sendMessage(String gcmRegistrationId, EncryptedOutgoingMessage outgoingMessage) public void sendMessage(String destinationNumber, long destinationDeviceId,
throws NotPushRegisteredException, TransientPushFailureException String registrationId, PendingMessage message)
{ {
String messageId = "m-" + UUID.randomUUID().toString();
UnacknowledgedMessage unacknowledgedMessage = new UnacknowledgedMessage(destinationNumber,
destinationDeviceId,
registrationId, message);
sendMessage(messageId, unacknowledgedMessage);
}
public void sendMessage(String messageId, UnacknowledgedMessage message) {
try { try {
Message gcmMessage = new Message.Builder().addData("type", "message") boolean isReceipt = message.getPendingMessage().isReceipt();
.addData("message", outgoingMessage.serialize())
.build();
Result result = sender.send(gcmMessage, gcmRegistrationId, 5); Map<String, String> dataObject = new HashMap<>();
dataObject.put("type", "message");
dataObject.put(isReceipt ? "receipt" : "message", message.getPendingMessage().getEncryptedOutgoingMessage());
if (result.getMessageId() != null) { Map<String, Object> messageObject = new HashMap<>();
messageObject.put("to", message.getRegistrationId());
messageObject.put("message_id", messageId);
messageObject.put("data", dataObject);
String json = JSONObject.toJSONString(messageObject);
pendingMessages.put(messageId, message);
connection.sendPacket(new GcmPacketExtension(json).toPacket());
} catch (SmackException.NotConnectedException e) {
logger.warn("GCMClient", "No connection", e);
}
}
@Override
public void start() throws Exception {
this.connection = connect(senderId, apiKey);
}
@Override
public void stop() throws Exception {
this.connection.disconnect();
}
@Override
public void processPacket(Packet packet) throws SmackException.NotConnectedException {
Message incomingMessage = (Message) packet;
GcmPacketExtension gcmPacket = (GcmPacketExtension) incomingMessage.getExtension(GCM_NAMESPACE);
String json = gcmPacket.getJson();
try {
Map<String, Object> jsonObject = (Map<String, Object>) JSONValue.parseWithException(json);
Object messageType = jsonObject.get("message_type");
if (messageType == null) {
handleUpstreamMessage(jsonObject);
return;
}
switch (messageType.toString()) {
case "ack" : handleAckReceipt(jsonObject); break;
case "nack" : handleNackReceipt(jsonObject); break;
case "receipt" : handleDeliveryReceipt(jsonObject); break;
case "control" : handleControlMessage(jsonObject); break;
default:
logger.warn("Received unknown GCM message: " + messageType.toString());
}
} catch (ParseException e) {
logger.warn("GCMClient", "Received unparsable message", e);
} catch (Exception e) {
logger.warn("GCMClient", "Failed to process packet", e);
}
}
private void handleControlMessage(Map<String, Object> message) {
String controlType = (String) message.get("control_type");
if ("CONNECTION_DRAINING".equals(controlType)) {
logger.warn("GCM Connection is draining! Initiating reconnect...");
reconnect();
} else {
logger.warn("Received unknown GCM control message: " + controlType);
}
}
private void handleDeliveryReceipt(Map<String, Object> message) {
logger.warn("Got delivery receipt!");
}
private void handleNackReceipt(Map<String, Object> message) {
String messageId = (String) message.get("message_id");
String errorCode = (String) message.get("error");
if (errorCode == null) {
logger.warn("Null GCM error code!");
if (messageId != null) {
pendingMessages.remove(messageId);
}
return;
}
switch (errorCode) {
case "BAD_REGISTRATION" : handleBadRegistration(message); break;
case "DEVICE_UNREGISTERED" : handleBadRegistration(message); break;
case "INTERNAL_SERVER_ERROR" : handleServerFailure(message); break;
case "INVALID_JSON" : handleClientFailure(message); break;
case "QUOTA_EXCEEDED" : handleClientFailure(message); break;
case "SERVICE_UNAVAILABLE" : handleServerFailure(message); break;
}
}
private void handleAckReceipt(Map<String, Object> message) {
success.mark(); success.mark();
return result.getCanonicalRegistrationId();
} else { String messageId = (String) message.get("message_id");
if (messageId != null) {
pendingMessages.remove(messageId);
}
}
private void handleUpstreamMessage(Map<String, Object> message)
throws SmackException.NotConnectedException
{
logger.warn("Got upstream message from GCM Server!");
for (String key : message.keySet()) {
logger.warn(key + " : " + message.get(key));
}
Map<String, Object> ack = new HashMap<>();
message.put("message_type", "ack");
message.put("to", message.get("from"));
message.put("message_id", message.get("message_id"));
String json = JSONValue.toJSONString(ack);
Packet request = new GcmPacketExtension(json).toPacket();
connection.sendPacket(request);
}
private void handleBadRegistration(Map<String, Object> message) {
unregistered.mark();
String messageId = (String) message.get("message_id");
if (messageId != null) {
UnacknowledgedMessage unacknowledgedMessage = pendingMessages.remove(messageId);
if (unacknowledgedMessage != null) {
Optional<Account> account = accounts.get(unacknowledgedMessage.getDestinationNumber());
if (account.isPresent()) {
Optional<Device> device = account.get().getDevice(unacknowledgedMessage.getDestinationDeviceId());
if (device.isPresent()) {
device.get().setGcmId(null);
accounts.update(account.get());
}
}
}
}
}
private void handleServerFailure(Map<String, Object> message) {
failure.mark(); failure.mark();
if (result.getErrorCodeName().equals(Constants.ERROR_NOT_REGISTERED)) {
throw new NotPushRegisteredException("Device no longer registered with GCM."); String messageId = (String)message.get("message_id");
} else {
throw new TransientPushFailureException("GCM Failed: " + result.getErrorCodeName()); if (messageId != null) {
UnacknowledgedMessage unacknowledgedMessage = pendingMessages.remove(messageId);
if (unacknowledgedMessage != null) {
sendMessage(messageId, unacknowledgedMessage);
} }
} }
} catch (IOException e) { }
throw new TransientPushFailureException(e);
private void handleClientFailure(Map<String, Object> message) {
failure.mark();
logger.warn("Unrecoverable error: " + message.get("error"));
String messageId = (String)message.get("message_id");
if (messageId != null) {
pendingMessages.remove(messageId);
}
}
private void reconnect() {
try {
this.connection.disconnect();
} catch (SmackException.NotConnectedException e) {
logger.warn("GCMClient", "Disconnect attempt", e);
}
while (true) {
try {
this.connection = connect(senderId, apiKey);
return;
} catch (XMPPException | IOException | SmackException e) {
logger.warn("GCMClient", "Reconnecting", e);
Util.sleep(1000);
}
}
}
private XMPPTCPConnection connect(long senderId, String apiKey)
throws XMPPException, IOException, SmackException
{
ConnectionConfiguration config = new ConnectionConfiguration(GCM_SERVER, GCM_PORT);
config.setSecurityMode(ConnectionConfiguration.SecurityMode.enabled);
config.setReconnectionAllowed(true);
config.setRosterLoadedAtLogin(false);
config.setSendPresence(false);
config.setSocketFactory(SSLSocketFactory.getDefault());
XMPPTCPConnection connection = new XMPPTCPConnection(config);
connection.connect();
connection.addConnectionListener(new LoggingConnectionListener());
connection.addPacketListener(this, new PacketTypeFilter(Message.class));
connection.login(senderId + "@gcm.googleapis.com", apiKey);
return connection;
}
private static class GcmPacketExtensionProvider implements PacketExtensionProvider {
@Override
public PacketExtension parseExtension(XmlPullParser xmlPullParser) throws Exception {
String json = xmlPullParser.nextText();
return new GcmPacketExtension(json);
}
}
private static final class GcmPacketExtension extends DefaultPacketExtension {
private final String json;
public GcmPacketExtension(String json) {
super(GCM_ELEMENT_NAME, GCM_NAMESPACE);
this.json = json;
}
public String getJson() {
return json;
}
@Override
public String toXML() {
return String.format("<%s xmlns=\"%s\">%s</%s>", GCM_ELEMENT_NAME, GCM_NAMESPACE,
StringUtils.escapeForXML(json), GCM_ELEMENT_NAME);
}
public Packet toPacket() {
Message message = new Message();
message.addExtension(this);
return message;
}
}
private class LoggingConnectionListener implements ConnectionListener {
@Override
public void connected(XMPPConnection xmppConnection) {
logger.warn("GCM XMPP Connected.");
}
@Override
public void authenticated(XMPPConnection xmppConnection) {
logger.warn("GCM XMPP Authenticated.");
}
@Override
public void reconnectionSuccessful() {
logger.warn("GCM XMPP Reconnecting..");
Iterator<Map.Entry<String, UnacknowledgedMessage>> iterator =
pendingMessages.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, UnacknowledgedMessage> entry = iterator.next();
iterator.remove();
sendMessage(entry.getKey(), entry.getValue());
}
}
@Override
public void reconnectionFailed(Exception e) {
logger.warn("GCM XMPP Reconnection failed!", e);
reconnect();
}
@Override
public void reconnectingIn(int seconds) {
logger.warn(String.format("GCM XMPP Reconnecting in %d secs", seconds));
}
@Override
public void connectionClosedOnError(Exception e) {
logger.warn("GCM XMPP Connection closed on error.");
}
@Override
public void connectionClosed() {
logger.warn("GCM XMPP Connection closed.");
reconnect();
}
}
private static class UnacknowledgedMessage {
private final String destinationNumber;
private final long destinationDeviceId;
private final String registrationId;
private final PendingMessage pendingMessage;
private UnacknowledgedMessage(String destinationNumber,
long destinationDeviceId,
String registrationId,
PendingMessage pendingMessage)
{
this.destinationNumber = destinationNumber;
this.destinationDeviceId = destinationDeviceId;
this.registrationId = registrationId;
this.pendingMessage = pendingMessage;
}
private String getRegistrationId() {
return registrationId;
}
private PendingMessage getPendingMessage() {
return pendingMessage;
}
public String getDestinationNumber() {
return destinationNumber;
}
public long getDestinationDeviceId() {
return destinationDeviceId;
} }
} }
} }

View File

@@ -18,106 +18,75 @@ package org.whispersystems.textsecuregcm.push;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.GcmConfiguration;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import java.io.IOException; import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
public class PushSender { public class PushSender {
private final Logger logger = LoggerFactory.getLogger(PushSender.class); private final Logger logger = LoggerFactory.getLogger(PushSender.class);
private final AccountsManager accounts;
private final GCMSender gcmSender; private final GCMSender gcmSender;
private final APNSender apnSender; private final APNSender apnSender;
private final WebsocketSender webSocketSender; private final WebsocketSender webSocketSender;
public PushSender(GcmConfiguration gcmConfiguration, public PushSender(GCMSender gcmClient,
ApnConfiguration apnConfiguration, APNSender apnSender,
StoredMessages storedMessages, WebsocketSender websocketSender)
PubSubManager pubSubManager,
AccountsManager accounts)
throws CertificateException, NoSuchAlgorithmException, KeyStoreException, IOException
{ {
this.accounts = accounts; this.gcmSender = gcmClient;
this.webSocketSender = new WebsocketSender(storedMessages, pubSubManager); this.apnSender = apnSender;
this.gcmSender = new GCMSender(gcmConfiguration.getApiKey()); this.webSocketSender = websocketSender;
this.apnSender = new APNSender(pubSubManager, storedMessages,
apnConfiguration.getCertificate(),
apnConfiguration.getKey());
} }
public void sendMessage(Account account, Device device, MessageProtos.OutgoingMessageSignal message) public void sendMessage(Account account, Device device, OutgoingMessageSignal message)
throws NotPushRegisteredException, TransientPushFailureException throws NotPushRegisteredException, TransientPushFailureException
{ {
try { try {
boolean isReceipt = message.getType() == OutgoingMessageSignal.Type.RECEIPT_VALUE;
String signalingKey = device.getSignalingKey(); String signalingKey = device.getSignalingKey();
EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, signalingKey); EncryptedOutgoingMessage encryptedMessage = new EncryptedOutgoingMessage(message, signalingKey);
PendingMessage pendingMessage = new PendingMessage(message.getSource(),
message.getTimestamp(),
isReceipt,
encryptedMessage.serialize());
sendMessage(account, device, encryptedMessage); sendMessage(account, device, pendingMessage);
} catch (CryptoEncodingException e) { } catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e); throw new NotPushRegisteredException(e);
} }
} }
public void sendMessage(Account account, Device device, EncryptedOutgoingMessage message) public void sendMessage(Account account, Device device, PendingMessage pendingMessage)
throws NotPushRegisteredException, TransientPushFailureException throws NotPushRegisteredException, TransientPushFailureException
{ {
if (device.getGcmId() != null) sendGcmMessage(account, device, message); if (device.getGcmId() != null) sendGcmMessage(account, device, pendingMessage);
else if (device.getApnId() != null) sendApnMessage(account, device, message); else if (device.getApnId() != null) sendApnMessage(account, device, pendingMessage);
else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, message); else if (device.getFetchesMessages()) sendWebSocketMessage(account, device, pendingMessage);
else throw new NotPushRegisteredException("No delivery possible!"); else throw new NotPushRegisteredException("No delivery possible!");
} }
private void sendGcmMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) private void sendGcmMessage(Account account, Device device, PendingMessage pendingMessage) {
throws NotPushRegisteredException, TransientPushFailureException String number = account.getNumber();
long deviceId = device.getId();
String registrationId = device.getGcmId();
gcmSender.sendMessage(number, deviceId, registrationId, pendingMessage);
}
private void sendApnMessage(Account account, Device device, PendingMessage outgoingMessage)
throws TransientPushFailureException
{ {
try {
String canonicalId = gcmSender.sendMessage(device.getGcmId(), outgoingMessage);
if (canonicalId != null) {
device.setGcmId(canonicalId);
accounts.update(account);
}
} catch (NotPushRegisteredException e) {
logger.debug("No Such User", e);
device.setGcmId(null);
accounts.update(account);
throw new NotPushRegisteredException(e);
}
}
private void sendApnMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws TransientPushFailureException, NotPushRegisteredException
{
try {
apnSender.sendMessage(account, device, device.getApnId(), outgoingMessage); apnSender.sendMessage(account, device, device.getApnId(), outgoingMessage);
} catch (NotPushRegisteredException e) {
device.setApnId(null);
accounts.update(account);
throw new NotPushRegisteredException(e);
}
} }
private void sendWebSocketMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) private void sendWebSocketMessage(Account account, Device device, PendingMessage outgoingMessage)
throws NotPushRegisteredException
{ {
try {
webSocketSender.sendMessage(account, device, outgoingMessage); webSocketSender.sendMessage(account, device, outgoingMessage);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
} }
} }

View File

@@ -19,26 +19,32 @@ package org.whispersystems.textsecuregcm.push;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import com.fasterxml.jackson.core.JsonProcessingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubMessage; import org.whispersystems.textsecuregcm.storage.PubSubMessage;
import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.List;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
public class WebsocketSender { public class WebsocketSender {
private static final Logger logger = LoggerFactory.getLogger(WebsocketSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter onlineMeter = metricRegistry.meter(name(getClass(), "online")); private final Meter onlineMeter = metricRegistry.meter(name(getClass(), "online"));
private final Meter offlineMeter = metricRegistry.meter(name(getClass(), "offline")); private final Meter offlineMeter = metricRegistry.meter(name(getClass(), "offline"));
private static final ObjectMapper mapper = SystemMapper.getMapper();
private final StoredMessages storedMessages; private final StoredMessages storedMessages;
private final PubSubManager pubSubManager; private final PubSubManager pubSubManager;
@@ -47,22 +53,21 @@ public class WebsocketSender {
this.pubSubManager = pubSubManager; this.pubSubManager = pubSubManager;
} }
public void sendMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) public void sendMessage(Account account, Device device, PendingMessage pendingMessage) {
throws CryptoEncodingException try {
{ String serialized = mapper.writeValueAsString(pendingMessage);
sendMessage(account, device, outgoingMessage.serialize()); WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
} PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serialized);
private void sendMessage(Account account, Device device, String serializedMessage) {
WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage);
if (pubSubManager.publish(address, pubSubMessage)) { if (pubSubManager.publish(address, pubSubMessage)) {
onlineMeter.mark(); onlineMeter.mark();
} else { } else {
offlineMeter.mark(); offlineMeter.mark();
storedMessages.insert(account.getId(), device.getId(), serializedMessage); storedMessages.insert(address, pendingMessage);
pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null)); pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
} }
} catch (JsonProcessingException e) {
logger.warn("WebsocketSender", "Unable to serialize json", e);
}
} }
} }

View File

@@ -22,16 +22,12 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import java.io.Serializable;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
public class Account implements Serializable { public class Account {
public static final int MEMCACHE_VERION = 3; public static final int MEMCACHE_VERION = 5;
@JsonIgnore
private long id;
@JsonProperty @JsonProperty
private String number; private String number;
@@ -57,14 +53,6 @@ public class Account implements Serializable {
this.devices = devices; this.devices = devices;
} }
public long getId() {
return id;
}
public void setId(long id) {
this.id = id;
}
public Optional<Device> getAuthenticatedDevice() { public Optional<Device> getAuthenticatedDevice() {
return authenticatedDevice; return authenticatedDevice;
} }

View File

@@ -33,6 +33,7 @@ import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction; import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper; import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper; import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException; import java.io.IOException;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
@@ -51,12 +52,7 @@ public abstract class Accounts {
private static final String NUMBER = "number"; private static final String NUMBER = "number";
private static final String DATA = "data"; private static final String DATA = "data";
private static final ObjectMapper mapper = new ObjectMapper(); private static final ObjectMapper mapper = SystemMapper.getMapper();
static {
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
}
@SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))") @SqlUpdate("INSERT INTO accounts (" + NUMBER + ", " + DATA + ") VALUES (:number, CAST(:data AS json))")
@GetGeneratedKeys @GetGeneratedKeys
@@ -89,6 +85,9 @@ public abstract class Accounts {
return insertStep(account); return insertStep(account);
} }
@SqlUpdate("VACUUM accounts")
public abstract void vacuum();
public static class AccountMapper implements ResultSetMapper<Account> { public static class AccountMapper implements ResultSetMapper<Account> {
@Override @Override
public Account map(int i, ResultSet resultSet, StatementContext statementContext) public Account map(int i, ResultSet resultSet, StatementContext statementContext)
@@ -96,7 +95,7 @@ public abstract class Accounts {
{ {
try { try {
Account account = mapper.readValue(resultSet.getString(DATA), Account.class); Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
account.setId(resultSet.getLong(ID)); // account.setId(resultSet.getLong(ID));
return account; return account;
} catch (IOException e) { } catch (IOException e) {

View File

@@ -17,19 +17,30 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import net.spy.memcached.MemcachedClient; import net.spy.memcached.MemcachedClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
public class AccountsManager { public class AccountsManager {
private final Logger logger = LoggerFactory.getLogger(AccountsManager.class);
private final Accounts accounts; private final Accounts accounts;
private final MemcachedClient memcachedClient; private final MemcachedClient memcachedClient;
private final DirectoryManager directory; private final DirectoryManager directory;
private final ObjectMapper mapper;
public AccountsManager(Accounts accounts, public AccountsManager(Accounts accounts,
DirectoryManager directory, DirectoryManager directory,
@@ -38,6 +49,7 @@ public class AccountsManager {
this.accounts = accounts; this.accounts = accounts;
this.directory = directory; this.directory = directory;
this.memcachedClient = memcachedClient; this.memcachedClient = memcachedClient;
this.mapper = SystemMapper.getMapper();
} }
public long getCount() { public long getCount() {
@@ -54,40 +66,28 @@ public class AccountsManager {
public void create(Account account) { public void create(Account account) {
accounts.create(account); accounts.create(account);
memcacheSet(account.getNumber(), account);
if (memcachedClient != null) {
memcachedClient.set(getKey(account.getNumber()), 0, account);
}
updateDirectory(account); updateDirectory(account);
} }
public void update(Account account) { public void update(Account account) {
if (memcachedClient != null) { memcacheSet(account.getNumber(), account);
memcachedClient.set(getKey(account.getNumber()), 0, account);
}
accounts.update(account); accounts.update(account);
updateDirectory(account); updateDirectory(account);
} }
public Optional<Account> get(String number) { public Optional<Account> get(String number) {
Account account = null; Optional<Account> account = memcacheGet(number);
if (memcachedClient != null) { if (!account.isPresent()) {
account = (Account)memcachedClient.get(getKey(number)); account = Optional.fromNullable(accounts.get(number));
}
if (account == null) { if (account.isPresent()) {
account = accounts.get(number); memcacheSet(number, account.get());
if (account != null && memcachedClient != null) {
memcachedClient.set(getKey(number), 0, account);
} }
} }
if (account != null) return Optional.of(account); return account;
else return Optional.absent();
} }
public boolean isRelayListed(String number) { public boolean isRelayListed(String number) {
@@ -111,4 +111,30 @@ public class AccountsManager {
return Account.class.getSimpleName() + Account.MEMCACHE_VERION + number; return Account.class.getSimpleName() + Account.MEMCACHE_VERION + number;
} }
private void memcacheSet(String number, Account account) {
if (memcachedClient != null) {
try {
String json = mapper.writeValueAsString(account);
memcachedClient.set(getKey(number), 0, json);
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
}
private Optional<Account> memcacheGet(String number) {
if (memcachedClient == null) return Optional.absent();
try {
String json = (String)memcachedClient.get(getKey(number));
if (json != null) return Optional.of(mapper.readValue(json, Account.class));
else return Optional.absent();
} catch (IOException e) {
logger.warn("AccountsManager", "Deserialization error", e);
return Optional.absent();
}
}
} }

View File

@@ -19,11 +19,12 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import java.io.Serializable; import java.io.Serializable;
public class Device implements Serializable { public class Device {
public static final long MASTER_ID = 1; public static final long MASTER_ID = 1;
@@ -51,11 +52,15 @@ public class Device implements Serializable {
@JsonProperty @JsonProperty
private int registrationId; private int registrationId;
@JsonProperty
private SignedPreKey signedPreKey;
public Device() {} public Device() {}
public Device(long id, String authToken, String salt, public Device(long id, String authToken, String salt,
String signalingKey, String gcmId, String apnId, String signalingKey, String gcmId, String apnId,
boolean fetchesMessages, int registrationId) boolean fetchesMessages, int registrationId,
SignedPreKey signedPreKey)
{ {
this.id = id; this.id = id;
this.authToken = authToken; this.authToken = authToken;
@@ -65,6 +70,7 @@ public class Device implements Serializable {
this.apnId = apnId; this.apnId = apnId;
this.fetchesMessages = fetchesMessages; this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId; this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
} }
public String getApnId() { public String getApnId() {
@@ -131,4 +137,12 @@ public class Device implements Serializable {
public void setRegistrationId(int registrationId) { public void setRegistrationId(int registrationId) {
this.registrationId = registrationId; this.registrationId = registrationId;
} }
public SignedPreKey getSignedPreKey() {
return signedPreKey;
}
public void setSignedPreKey(SignedPreKey signedPreKey) {
this.signedPreKey = signedPreKey;
}
} }

View File

@@ -76,6 +76,11 @@ public class DirectoryManager {
pipeline.hset(DIRECTORY_KEY, contact.getToken(), new Gson().toJson(tokenValue).getBytes()); pipeline.hset(DIRECTORY_KEY, contact.getToken(), new Gson().toJson(tokenValue).getBytes());
} }
public PendingClientContact get(BatchOperationHandle handle, byte[] token) {
Pipeline pipeline = handle.pipeline;
return new PendingClientContact(token, pipeline.hget(DIRECTORY_KEY, token));
}
public Optional<ClientContact> get(byte[] token) { public Optional<ClientContact> get(byte[] token) {
Jedis jedis = redisPool.getResource(); Jedis jedis = redisPool.getResource();
@@ -162,4 +167,26 @@ public class DirectoryManager {
this.supportsSms = supportsSms; this.supportsSms = supportsSms;
} }
} }
public static class PendingClientContact {
private final byte[] token;
private final Response<byte[]> response;
PendingClientContact(byte[] token, Response<byte[]> response) {
this.token = token;
this.response = response;
}
public Optional<ClientContact> get() {
byte[] result = response.get();
if (result == null) {
return Optional.absent();
}
TokenValue tokenValue = new Gson().fromJson(new String(result), TokenValue.class);
return Optional.of(new ClientContact(token, tokenValue.relay, tokenValue.supportsSms));
}
}
} }

View File

@@ -0,0 +1,46 @@
package org.whispersystems.textsecuregcm.storage;
public class KeyRecord {
private long id;
private String number;
private long deviceId;
private long keyId;
private String publicKey;
private boolean lastResort;
public KeyRecord(long id, String number, long deviceId, long keyId,
String publicKey, boolean lastResort)
{
this.id = id;
this.number = number;
this.deviceId = deviceId;
this.keyId = keyId;
this.publicKey = publicKey;
this.lastResort = lastResort;
}
public long getId() {
return id;
}
public String getNumber() {
return number;
}
public long getDeviceId() {
return deviceId;
}
public long getKeyId() {
return keyId;
}
public String getPublicKey() {
return publicKey;
}
public boolean isLastResort() {
return lastResort;
}
}

View File

@@ -30,8 +30,7 @@ import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.Transaction; import org.skife.jdbi.v2.sqlobject.Transaction;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper; import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper; import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKeyBase;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType; import java.lang.annotation.ElementType;
@@ -40,6 +39,7 @@ import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.LinkedList;
import java.util.List; import java.util.List;
public abstract class Keys { public abstract class Keys {
@@ -52,67 +52,69 @@ public abstract class Keys {
@SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " + @SqlBatch("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " +
"(:number, :device_id, :key_id, :public_key, :last_resort)") "(:number, :device_id, :key_id, :public_key, :last_resort)")
abstract void append(@PreKeyBinder List<PreKey> preKeys); abstract void append(@PreKeyBinder List<KeyRecord> preKeys);
@SqlUpdate("INSERT INTO keys (number, device_id, key_id, public_key, last_resort) VALUES " +
"(:number, :device_id, :key_id, :public_key, :last_resort)")
abstract void append(@PreKeyBinder PreKey preKey);
@SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE") @SqlQuery("SELECT * FROM keys WHERE number = :number AND device_id = :device_id ORDER BY key_id ASC FOR UPDATE")
@Mapper(PreKeyMapper.class) @Mapper(PreKeyMapper.class)
abstract PreKey retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId); abstract KeyRecord retrieveFirst(@Bind("number") String number, @Bind("device_id") long deviceId);
@SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC") @SqlQuery("SELECT DISTINCT ON (number, device_id) * FROM keys WHERE number = :number ORDER BY number, device_id, key_id ASC")
@Mapper(PreKeyMapper.class) @Mapper(PreKeyMapper.class)
abstract List<PreKey> retrieveFirst(@Bind("number") String number); abstract List<KeyRecord> retrieveFirst(@Bind("number") String number);
@SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id") @SqlQuery("SELECT COUNT(*) FROM keys WHERE number = :number AND device_id = :device_id")
public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId); public abstract int getCount(@Bind("number") String number, @Bind("device_id") long deviceId);
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public void store(String number, long deviceId, List<PreKey> keys, PreKey lastResortKey) { public void store(String number, long deviceId, List<? extends PreKeyBase> keys, PreKeyBase lastResortKey) {
for (PreKey key : keys) { List<KeyRecord> records = new LinkedList<>();
key.setNumber(number);
key.setDeviceId(deviceId); for (PreKeyBase key : keys) {
records.add(new KeyRecord(0, number, deviceId, key.getKeyId(), key.getPublicKey(), false));
} }
lastResortKey.setNumber(number); records.add(new KeyRecord(0, number, deviceId, lastResortKey.getKeyId(),
lastResortKey.setDeviceId(deviceId); lastResortKey.getPublicKey(), true));
lastResortKey.setLastResort(true);
removeKeys(number, deviceId); removeKeys(number, deviceId);
append(keys); append(records);
append(lastResortKey);
} }
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<UnstructuredPreKeyList> get(String number, long deviceId) { public Optional<List<KeyRecord>> get(String number, long deviceId) {
PreKey preKey = retrieveFirst(number, deviceId); final KeyRecord record = retrieveFirst(number, deviceId);
if (preKey != null && !preKey.isLastResort()) { if (record != null && !record.isLastResort()) {
removeKey(preKey.getId()); removeKey(record.getId());
} else if (record == null) {
return Optional.absent();
} }
if (preKey != null) return Optional.of(new UnstructuredPreKeyList(preKey)); List<KeyRecord> results = new LinkedList<>();
else return Optional.absent(); results.add(record);
return Optional.of(results);
} }
@Transaction(TransactionIsolationLevel.SERIALIZABLE) @Transaction(TransactionIsolationLevel.SERIALIZABLE)
public Optional<UnstructuredPreKeyList> get(String number) { public Optional<List<KeyRecord>> get(String number) {
List<PreKey> preKeys = retrieveFirst(number); List<KeyRecord> preKeys = retrieveFirst(number);
if (preKeys != null) { if (preKeys != null) {
for (PreKey preKey : preKeys) { for (KeyRecord preKey : preKeys) {
if (!preKey.isLastResort()) { if (!preKey.isLastResort()) {
removeKey(preKey.getId()); removeKey(preKey.getId());
} }
} }
} }
if (preKeys != null) return Optional.of(new UnstructuredPreKeyList(preKeys)); if (preKeys != null) return Optional.of(preKeys);
else return Optional.absent(); else return Optional.absent();
} }
@SqlUpdate("VACUUM keys")
public abstract void vacuum();
@BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class) @BindingAnnotation(PreKeyBinder.PreKeyBinderFactory.class)
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER}) @Target({ElementType.PARAMETER})
@@ -120,16 +122,16 @@ public abstract class Keys {
public static class PreKeyBinderFactory implements BinderFactory { public static class PreKeyBinderFactory implements BinderFactory {
@Override @Override
public Binder build(Annotation annotation) { public Binder build(Annotation annotation) {
return new Binder<PreKeyBinder, PreKey>() { return new Binder<PreKeyBinder, KeyRecord>() {
@Override @Override
public void bind(SQLStatement<?> sql, PreKeyBinder accountBinder, PreKey preKey) public void bind(SQLStatement<?> sql, PreKeyBinder accountBinder, KeyRecord record)
{ {
sql.bind("id", preKey.getId()); sql.bind("id", record.getId());
sql.bind("number", preKey.getNumber()); sql.bind("number", record.getNumber());
sql.bind("device_id", preKey.getDeviceId()); sql.bind("device_id", record.getDeviceId());
sql.bind("key_id", preKey.getKeyId()); sql.bind("key_id", record.getKeyId());
sql.bind("public_key", preKey.getPublicKey()); sql.bind("public_key", record.getPublicKey());
sql.bind("last_resort", preKey.isLastResort() ? 1 : 0); sql.bind("last_resort", record.isLastResort() ? 1 : 0);
} }
}; };
} }
@@ -137,14 +139,14 @@ public abstract class Keys {
} }
public static class PreKeyMapper implements ResultSetMapper<PreKey> { public static class PreKeyMapper implements ResultSetMapper<KeyRecord> {
@Override @Override
public PreKey map(int i, ResultSet resultSet, StatementContext statementContext) public KeyRecord map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException throws SQLException
{ {
return new PreKey(resultSet.getLong("id"), resultSet.getString("number"), resultSet.getLong("device_id"), return new KeyRecord(resultSet.getLong("id"), resultSet.getString("number"),
resultSet.getLong("key_id"), resultSet.getString("public_key"), resultSet.getLong("device_id"), resultSet.getLong("key_id"),
resultSet.getInt("last_resort") == 1); resultSet.getString("public_key"), resultSet.getInt("last_resort") == 1);
} }
} }

View File

@@ -31,4 +31,7 @@ public interface PendingAccounts {
@SqlUpdate("DELETE FROM pending_accounts WHERE number = :number") @SqlUpdate("DELETE FROM pending_accounts WHERE number = :number")
void remove(@Bind("number") String number); void remove(@Bind("number") String number);
@SqlUpdate("VACUUM pending_accounts")
public void vacuum();
} }

View File

@@ -34,35 +34,46 @@ public class PendingAccountsManager {
} }
public void store(String number, String code) { public void store(String number, String code) {
if (memcachedClient != null) { memcacheSet(number, code);
memcachedClient.set(MEMCACHE_PREFIX + number, 0, code);
}
pendingAccounts.insert(number, code); pendingAccounts.insert(number, code);
} }
public void remove(String number) { public void remove(String number) {
if (memcachedClient != null) memcacheDelete(number);
memcachedClient.delete(MEMCACHE_PREFIX + number);
pendingAccounts.remove(number); pendingAccounts.remove(number);
} }
public Optional<String> getCodeForNumber(String number) { public Optional<String> getCodeForNumber(String number) {
String code = null; Optional<String> code = memcacheGet(number);
if (memcachedClient != null) { if (!code.isPresent()) {
code = (String)memcachedClient.get(MEMCACHE_PREFIX + number); code = Optional.fromNullable(pendingAccounts.getCodeForNumber(number));
if (code.isPresent()) {
memcacheSet(number, code.get());
}
} }
if (code == null) { return code;
code = pendingAccounts.getCodeForNumber(number); }
if (code != null && memcachedClient != null) { private void memcacheSet(String number, String code) {
if (memcachedClient != null) {
memcachedClient.set(MEMCACHE_PREFIX + number, 0, code); memcachedClient.set(MEMCACHE_PREFIX + number, 0, code);
} }
} }
if (code != null) return Optional.of(code); private Optional<String> memcacheGet(String number) {
else return Optional.absent(); if (memcachedClient != null) {
return Optional.fromNullable((String)memcachedClient.get(MEMCACHE_PREFIX + number));
} else {
return Optional.absent();
}
}
private void memcacheDelete(String number) {
if (memcachedClient != null) {
memcachedClient.delete(MEMCACHE_PREFIX + number);
}
} }
} }

View File

@@ -34,37 +34,47 @@ public class PendingDevicesManager {
} }
public void store(String number, String code) { public void store(String number, String code) {
if (memcachedClient != null) { memcacheSet(number, code);
memcachedClient.set(MEMCACHE_PREFIX + number, 0, code);
}
pendingDevices.insert(number, code); pendingDevices.insert(number, code);
} }
public void remove(String number) { public void remove(String number) {
if (memcachedClient != null) { memcacheDelete(number);
memcachedClient.delete(MEMCACHE_PREFIX + number);
}
pendingDevices.remove(number); pendingDevices.remove(number);
} }
public Optional<String> getCodeForNumber(String number) { public Optional<String> getCodeForNumber(String number) {
String code = null; Optional<String> code = memcacheGet(number);
if (memcachedClient != null) { if (!code.isPresent()) {
code = (String)memcachedClient.get(MEMCACHE_PREFIX + number); code = Optional.fromNullable(pendingDevices.getCodeForNumber(number));
if (code.isPresent()) {
memcacheSet(number, code.get());
}
} }
if (code == null) { return code;
code = pendingDevices.getCodeForNumber(number); }
if (code != null && memcachedClient != null) { private void memcacheSet(String number, String code) {
if (memcachedClient != null) {
memcachedClient.set(MEMCACHE_PREFIX + number, 0, code); memcachedClient.set(MEMCACHE_PREFIX + number, 0, code);
} }
} }
if (code != null) return Optional.of(code); private Optional<String> memcacheGet(String number) {
else return Optional.absent(); if (memcachedClient != null) {
return Optional.fromNullable((String)memcachedClient.get(MEMCACHE_PREFIX + number));
} else {
return Optional.absent();
} }
}
private void memcacheDelete(String number) {
if (memcachedClient != null) {
memcachedClient.delete(MEMCACHE_PREFIX + number);
}
}
} }

View File

@@ -4,6 +4,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException; import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
@@ -17,10 +18,12 @@ import redis.clients.jedis.JedisPubSub;
public class PubSubManager { public class PubSubManager {
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = SystemMapper.getMapper();
private final SubscriptionListener baseListener = new SubscriptionListener(); private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<WebsocketAddress, PubSubListener> listeners = new HashMap<>(); private final Map<String, PubSubListener> listeners = new HashMap<>();
private final JedisPool jedisPool; private final JedisPool jedisPool;
private boolean subscribed = false; private boolean subscribed = false;
@@ -32,25 +35,29 @@ public class PubSubManager {
} }
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) { public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
listeners.put(address, listener); listeners.put(address.serialize(), listener);
baseListener.subscribe(address.toString()); baseListener.subscribe(address.serialize());
} }
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) { public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
if (listeners.get(address) == listener) { if (listeners.get(address.serialize()) == listener) {
listeners.remove(address); listeners.remove(address.serialize());
baseListener.unsubscribe(address.toString()); baseListener.unsubscribe(address.serialize());
} }
} }
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) { public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize(), message);
}
private synchronized boolean publish(String channel, PubSubMessage message) {
try { try {
String serialized = mapper.writeValueAsString(message); String serialized = mapper.writeValueAsString(message);
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
return jedis.publish(address.toString(), serialized) != 0; return jedis.publish(channel, serialized) != 0;
} finally { } finally {
if (jedis != null) if (jedis != null)
jedisPool.returnResource(jedis); jedisPool.returnResource(jedis);
@@ -78,7 +85,7 @@ public class PubSubManager {
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString()); jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
logger.warn("**** Unsubscribed from holding channel!!! ******"); logger.warn("**** Unsubscribed from holding channel!!! ******");
} finally { } finally {
if (jedis != null) if (jedis != null)
@@ -94,7 +101,7 @@ public class PubSubManager {
for (;;) { for (;;) {
try { try {
Thread.sleep(20000); Thread.sleep(20000);
publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo")); publish(KEEPALIVE_CHANNEL, new PubSubMessage(0, "foo"));
} catch (InterruptedException e) { } catch (InterruptedException e) {
throw new AssertionError(e); throw new AssertionError(e);
} }
@@ -108,18 +115,15 @@ public class PubSubManager {
@Override @Override
public void onMessage(String channel, String message) { public void onMessage(String channel, String message) {
try { try {
WebsocketAddress address = new WebsocketAddress(channel);
PubSubListener listener; PubSubListener listener;
synchronized (PubSubManager.this) { synchronized (PubSubManager.this) {
listener = listeners.get(address); listener = listeners.get(channel);
} }
if (listener != null) { if (listener != null) {
listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class)); listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
} }
} catch (InvalidWebsocketAddressException e) {
logger.warn("Address", e);
} catch (IOException e) { } catch (IOException e) {
logger.warn("IOE", e); logger.warn("IOE", e);
} }
@@ -132,18 +136,12 @@ public class PubSubManager {
@Override @Override
public void onSubscribe(String channel, int count) { public void onSubscribe(String channel, int count) {
try { if (KEEPALIVE_CHANNEL.equals(channel)) {
WebsocketAddress address = new WebsocketAddress(channel);
if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
synchronized (PubSubManager.this) { synchronized (PubSubManager.this) {
subscribed = true; subscribed = true;
PubSubManager.this.notifyAll(); PubSubManager.this.notifyAll();
} }
} }
} catch (InvalidWebsocketAddressException e) {
logger.warn("Weird address", e);
}
} }
@Override @Override

View File

@@ -19,9 +19,17 @@ package org.whispersystems.textsecuregcm.storage;
import com.codahale.metrics.Histogram; import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.io.IOException;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@@ -31,9 +39,13 @@ import redis.clients.jedis.JedisPool;
public class StoredMessages { public class StoredMessages {
private static final Logger logger = LoggerFactory.getLogger(StoredMessages.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Histogram queueSizeHistogram = metricRegistry.histogram(name(getClass(), "queue_size")); private final Histogram queueSizeHistogram = metricRegistry.histogram(name(getClass(), "queue_size"));
private static final ObjectMapper mapper = SystemMapper.getMapper();
private static final String QUEUE_PREFIX = "msgs"; private static final String QUEUE_PREFIX = "msgs";
private final JedisPool jedisPool; private final JedisPool jedisPool;
@@ -42,34 +54,54 @@ public class StoredMessages {
this.jedisPool = jedisPool; this.jedisPool = jedisPool;
} }
public void insert(long accountId, long deviceId, String message) { public void clear(WebsocketAddress address) {
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
jedis.del(getKey(address));
long queueSize = jedis.lpush(getKey(accountId, deviceId), message);
queueSizeHistogram.update(queueSize);
if (queueSize > 1000) {
jedis.ltrim(getKey(accountId, deviceId), 0, 999);
}
} finally { } finally {
if (jedis != null) if (jedis != null)
jedisPool.returnResource(jedis); jedisPool.returnResource(jedis);
} }
} }
public List<String> getMessagesForDevice(long accountId, long deviceId) { public void insert(WebsocketAddress address, PendingMessage message) {
List<String> messages = new LinkedList<>(); Jedis jedis = null;
try {
jedis = jedisPool.getResource();
String serializedMessage = mapper.writeValueAsString(message);
long queueSize = jedis.lpush(getKey(address), serializedMessage);
queueSizeHistogram.update(queueSize);
if (queueSize > 1000) {
jedis.ltrim(getKey(address), 0, 999);
}
} catch (JsonProcessingException e) {
logger.warn("StoredMessages", "Unable to store correctly", e);
} finally {
if (jedis != null)
jedisPool.returnResource(jedis);
}
}
public List<PendingMessage> getMessagesForDevice(WebsocketAddress address) {
List<PendingMessage> messages = new LinkedList<>();
Jedis jedis = null; Jedis jedis = null;
try { try {
jedis = jedisPool.getResource(); jedis = jedisPool.getResource();
String message; String message;
while ((message = jedis.rpop(QUEUE_PREFIX + accountId + ":" + deviceId)) != null) { while ((message = jedis.rpop(getKey(address))) != null) {
messages.add(message); try {
messages.add(mapper.readValue(message, PendingMessage.class));
} catch (IOException e) {
logger.warn("StoredMessages", "Not a valid PendingMessage", e);
}
} }
return messages; return messages;
@@ -79,8 +111,8 @@ public class StoredMessages {
} }
} }
private String getKey(long accountId, long deviceId) { private String getKey(WebsocketAddress address) {
return QUEUE_PREFIX + ":" + accountId + ":" + deviceId; return QUEUE_PREFIX + ":" + address.serialize();
} }
} }

View File

@@ -1,41 +0,0 @@
/**
* Copyright (C) 2014 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.util;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
public class CORSHeaderFilter implements Filter {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
if (response instanceof HttpServletResponse) {
((HttpServletResponse) response).addHeader("Access-Control-Allow-Origin", "*");
((HttpServletResponse) response).addHeader("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE");
((HttpServletResponse) response).addHeader("Access-Control-Allow-Headers", "Authorization, Content-type");
}
chain.doFilter(request, response);
}
@Override public void init(FilterConfig filterConfig) throws ServletException { }
@Override public void destroy() { }
}

View File

@@ -0,0 +1,20 @@
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
public class SystemMapper {
private static final ObjectMapper mapper = new ObjectMapper();
static {
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
}
public static ObjectMapper getMapper() {
return mapper;
}
}

View File

@@ -83,4 +83,40 @@ public class Util {
return result; return result;
} }
public static byte[][] split(byte[] input, int firstLength, int secondLength) {
byte[][] parts = new byte[2][];
parts[0] = new byte[firstLength];
System.arraycopy(input, 0, parts[0], 0, firstLength);
parts[1] = new byte[secondLength];
System.arraycopy(input, firstLength, parts[1], 0, secondLength);
return parts;
}
public static byte[][] split(byte[] input, int firstLength, int secondLength, int thirdLength, int fourthLength) {
byte[][] parts = new byte[4][];
parts[0] = new byte[firstLength];
System.arraycopy(input, 0, parts[0], 0, firstLength);
parts[1] = new byte[secondLength];
System.arraycopy(input, firstLength, parts[1], 0, secondLength);
parts[2] = new byte[thirdLength];
System.arraycopy(input, firstLength + secondLength, parts[2], 0, thirdLength);
parts[3] = new byte[fourthLength];
System.arraycopy(input, firstLength + secondLength + thirdLength, parts[3], 0, fourthLength);
return parts;
}
public static void sleep(int i) {
try {
Thread.sleep(i);
} catch (InterruptedException ie) {}
}
} }

View File

@@ -0,0 +1,65 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
public class ConnectListener implements WebSocketConnectListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final AccountsManager accountsManager;
private final PushSender pushSender;
private final StoredMessages storedMessages;
private final PubSubManager pubSubManager;
public ConnectListener(AccountsManager accountsManager, PushSender pushSender,
StoredMessages storedMessages, PubSubManager pubSubManager)
{
this.accountsManager = accountsManager;
this.pushSender = pushSender;
this.storedMessages = storedMessages;
this.pubSubManager = pubSubManager;
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
Optional<Account> account = context.getAuthenticated(Account.class);
if (!account.isPresent()) {
logger.debug("WS Connection with no authentication...");
context.getClient().close(4001, "Authentication failed");
return;
}
Optional<Device> device = account.get().getAuthenticatedDevice();
if (!device.isPresent()) {
logger.debug("WS Connection with no authenticated device...");
context.getClient().close(4001, "Device authentication failed");
return;
}
final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
storedMessages, pubSubManager,
account.get(), device.get(),
context.getClient());
connection.onConnected();
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
connection.onConnectionLost();
}
});
}
}

View File

@@ -0,0 +1,43 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.common.base.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.Map;
import io.dropwizard.auth.basic.BasicCredentials;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Account> {
private final AccountAuthenticator accountAuthenticator;
public WebSocketAccountAuthenticator(AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
}
@Override
public Optional<Account> authenticate(UpgradeRequest request) throws AuthenticationException {
try {
Map<String, String[]> parameters = request.getParameterMap();
String[] usernames = parameters.get("login");
String[] passwords = parameters.get("password");
if (usernames == null || usernames.length == 0 ||
passwords == null || passwords.length == 0)
{
return Optional.absent();
}
BasicCredentials credentials = new BasicCredentials(usernames[0], passwords[0]);
return accountAuthenticator.authenticate(credentials);
} catch (io.dropwizard.auth.AuthenticationException e) {
throw new AuthenticationException(e);
}
}
}

View File

@@ -0,0 +1,160 @@
package org.whispersystems.textsecuregcm.websocket;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubListener;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubMessage;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.List;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
public class WebSocketConnection implements PubSubListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final ObjectMapper objectMapper = SystemMapper.getMapper();
private final AccountsManager accountsManager;
private final PushSender pushSender;
private final StoredMessages storedMessages;
private final PubSubManager pubSubManager;
private final Account account;
private final Device device;
private final WebsocketAddress address;
private final WebSocketClient client;
public WebSocketConnection(AccountsManager accountsManager,
PushSender pushSender,
StoredMessages storedMessages,
PubSubManager pubSubManager,
Account account,
Device device,
WebSocketClient client)
{
this.accountsManager = accountsManager;
this.pushSender = pushSender;
this.storedMessages = storedMessages;
this.pubSubManager = pubSubManager;
this.account = account;
this.device = device;
this.client = client;
this.address = new WebsocketAddress(account.getNumber(), device.getId());
}
public void onConnected() {
pubSubManager.subscribe(address, this);
processStoredMessages();
}
public void onConnectionLost() {
pubSubManager.unsubscribe(address, this);
}
@Override
public void onPubSubMessage(PubSubMessage message) {
try {
switch (message.getType()) {
case PubSubMessage.TYPE_QUERY_DB:
processStoredMessages();
break;
case PubSubMessage.TYPE_DELIVER:
PendingMessage pendingMessage = objectMapper.readValue(message.getContents(),
PendingMessage.class);
sendMessage(pendingMessage);
break;
default:
logger.warn("Unknown pubsub message: " + message.getType());
}
} catch (IOException e) {
logger.warn("Error deserializing PendingMessage", e);
}
}
private void sendMessage(final PendingMessage message) {
String content = message.getEncryptedOutgoingMessage();
Optional<byte[]> body = Optional.fromNullable(content.getBytes());
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/api/v1/message", body);
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(@Nullable WebSocketResponseMessage response) {
if (isSuccessResponse(response) && !message.isReceipt()) {
sendDeliveryReceiptFor(message);
} else if (!isSuccessResponse(response)) {
requeueMessage(message);
}
}
@Override
public void onFailure(@Nonnull Throwable throwable) {
requeueMessage(message);
}
private boolean isSuccessResponse(WebSocketResponseMessage response) {
return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
}
});
}
private void requeueMessage(PendingMessage message) {
try {
pushSender.sendMessage(account, device, message);
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("requeueMessage", e);
storedMessages.insert(address, message);
}
}
private void sendDeliveryReceiptFor(PendingMessage message) {
try {
Optional<Account> source = accountsManager.get(message.getSender());
if (!source.isPresent()) {
logger.warn("Source account disappeared? (%s)", message.getSender());
return;
}
OutgoingMessageSignal.Builder receipt =
OutgoingMessageSignal.newBuilder()
.setSource(account.getNumber())
.setSourceDevice((int) device.getId())
.setTimestamp(message.getMessageId())
.setType(OutgoingMessageSignal.Type.RECEIPT_VALUE);
for (Device device : source.get().getDevices()) {
pushSender.sendMessage(source.get(), device, receipt.build());
}
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("sendDeliveryReceiptFor", "Delivery receipet", e);
}
}
private void processStoredMessages() {
List<PendingMessage> messages = storedMessages.getMessagesForDevice(address);
for (PendingMessage message : messages) {
sendMessage(message);
}
}
}

View File

@@ -2,39 +2,20 @@ package org.whispersystems.textsecuregcm.websocket;
public class WebsocketAddress { public class WebsocketAddress {
private final long accountId; private final String number;
private final long deviceId; private final long deviceId;
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException { public WebsocketAddress(String number, long deviceId) {
try { this.number = number;
String[] parts = serialized.split(":");
if (parts == null || parts.length != 2) {
throw new InvalidWebsocketAddressException(serialized);
}
this.accountId = Long.parseLong(parts[0]);
this.deviceId = Long.parseLong(parts[1]);
} catch (NumberFormatException e) {
throw new InvalidWebsocketAddressException(e);
}
}
public WebsocketAddress(long accountId, long deviceId) {
this.accountId = accountId;
this.deviceId = deviceId; this.deviceId = deviceId;
} }
public long getAccountId() { public String serialize() {
return accountId; return number + ":" + deviceId;
}
public long getDeviceId() {
return deviceId;
} }
public String toString() { public String toString() {
return accountId + ":" + deviceId; return serialize();
} }
@Override @Override
@@ -45,13 +26,13 @@ public class WebsocketAddress {
WebsocketAddress that = (WebsocketAddress)other; WebsocketAddress that = (WebsocketAddress)other;
return return
this.accountId == that.accountId && this.number.equals(that.number) &&
this.deviceId == that.deviceId; this.deviceId == that.deviceId;
} }
@Override @Override
public int hashCode() { public int hashCode() {
return (int)accountId ^ (int)deviceId; return number.hashCode() ^ (int)deviceId;
} }
} }

View File

@@ -1,47 +0,0 @@
package org.whispersystems.textsecuregcm.websocket;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.controllers.WebsocketController;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
public class WebsocketControllerFactory extends WebSocketServlet implements WebSocketCreator {
private final Logger logger = LoggerFactory.getLogger(WebsocketControllerFactory.class);
private final PushSender pushSender;
private final StoredMessages storedMessages;
private final PubSubManager pubSubManager;
private final AccountAuthenticator accountAuthenticator;
public WebsocketControllerFactory(AccountAuthenticator accountAuthenticator,
PushSender pushSender,
StoredMessages storedMessages,
PubSubManager pubSubManager)
{
this.accountAuthenticator = accountAuthenticator;
this.pushSender = pushSender;
this.storedMessages = storedMessages;
this.pubSubManager = pubSubManager;
}
@Override
public void configure(WebSocketServletFactory factory) {
factory.setCreator(this);
}
@Override
public Object createWebSocket(UpgradeRequest upgradeRequest, UpgradeResponse upgradeResponse) {
return new WebsocketController(accountAuthenticator, pushSender, pubSubManager, storedMessages);
}
}

View File

@@ -1,18 +0,0 @@
package org.whispersystems.textsecuregcm.websocket;
import com.fasterxml.jackson.annotation.JsonProperty;
public class WebsocketMessage {
@JsonProperty
private long id;
@JsonProperty
private String message;
public WebsocketMessage(long id, String message) {
this.id = id;
this.message = message;
}
}

View File

@@ -27,12 +27,14 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle; import org.whispersystems.textsecuregcm.storage.DirectoryManager.BatchOperationHandle;
import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedList;
import java.util.List; import java.util.List;
import static org.whispersystems.textsecuregcm.storage.DirectoryManager.PendingClientContact;
public class DirectoryUpdater { public class DirectoryUpdater {
private final Logger logger = LoggerFactory.getLogger(DirectoryUpdater.class); private final Logger logger = LoggerFactory.getLogger(DirectoryUpdater.class);
@@ -82,13 +84,14 @@ public class DirectoryUpdater {
public void updateFromPeers() { public void updateFromPeers() {
logger.info("Updating peer directories."); logger.info("Updating peer directories.");
int contactsAdded = 0;
int contactsRemoved = 0;
List<FederatedClient> clients = federatedClientManager.getClients(); List<FederatedClient> clients = federatedClientManager.getClients();
for (FederatedClient client : clients) { for (FederatedClient client : clients) {
logger.info("Updating directory from peer: " + client.getPeerName()); logger.info("Updating directory from peer: " + client.getPeerName());
// BatchOperationHandle handle = directory.startBatchOperation();
try {
int userCount = client.getUserCount(); int userCount = client.getUserCount();
int retrieved = 0; int retrieved = 0;
@@ -96,40 +99,54 @@ public class DirectoryUpdater {
while (retrieved < userCount) { while (retrieved < userCount) {
logger.info("Retrieving remote tokens..."); logger.info("Retrieving remote tokens...");
List<ClientContact> clientContacts = client.getUserTokens(retrieved); List<ClientContact> remoteContacts = client.getUserTokens(retrieved);
List<PendingClientContact> localContacts = new LinkedList<>();
BatchOperationHandle handle = directory.startBatchOperation();
if (clientContacts == null) { if (remoteContacts == null) {
logger.info("Remote tokens empty, ending..."); logger.info("Remote tokens empty, ending...");
break; break;
} else { } else {
logger.info("Retrieved " + clientContacts.size() + " remote tokens..."); logger.info("Retrieved " + remoteContacts.size() + " remote tokens...");
} }
for (ClientContact clientContact : clientContacts) { for (ClientContact remoteContact : remoteContacts) {
clientContact.setRelay(client.getPeerName()); localContacts.add(directory.get(handle, remoteContact.getToken()));
}
Optional<ClientContact> existing = directory.get(clientContact.getToken()); directory.stopBatchOperation(handle);
if (!clientContact.isInactive() && (!existing.isPresent() || client.getPeerName().equals(existing.get().getRelay()))) { handle = directory.startBatchOperation();
// directory.add(handle, clientContact); Iterator<ClientContact> remoteContactIterator = remoteContacts.iterator();
directory.add(clientContact); Iterator<PendingClientContact> localContactIterator = localContacts.iterator();
while (remoteContactIterator.hasNext() && localContactIterator.hasNext()) {
ClientContact remoteContact = remoteContactIterator.next();
Optional<ClientContact> localContact = localContactIterator.next().get();
remoteContact.setRelay(client.getPeerName());
if (!remoteContact.isInactive() && (!localContact.isPresent() || client.getPeerName().equals(localContact.get().getRelay()))) {
contactsAdded++;
directory.add(handle, remoteContact);
} else { } else {
if (existing.isPresent() && client.getPeerName().equals(existing.get().getRelay())) { if (localContact.isPresent() && client.getPeerName().equals(localContact.get().getRelay())) {
directory.remove(clientContact.getToken()); contactsRemoved++;
directory.remove(handle, remoteContact.getToken());
} }
} }
} }
retrieved += clientContacts.size(); directory.stopBatchOperation(handle);
retrieved += remoteContacts.size();
logger.info("Processed: " + retrieved + " remote tokens."); logger.info("Processed: " + retrieved + " remote tokens.");
} }
logger.info("Update from peer complete."); logger.info("Update from peer complete.");
} finally {
// directory.stopBatchOperation(handle);
}
} }
logger.info("Update from peer directories complete."); logger.info("Update from peer directories complete.");
logger.info(String.format("Added %d and removed %d remove contacts.", contactsAdded, contactsRemoved));
} }
} }

View File

@@ -0,0 +1,59 @@
package org.whispersystems.textsecuregcm.workers;
import net.sourceforge.argparse4j.inf.Namespace;
import org.skife.jdbi.v2.DBI;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import io.dropwizard.cli.ConfiguredCommand;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.jdbi.ImmutableListContainerFactory;
import io.dropwizard.jdbi.ImmutableSetContainerFactory;
import io.dropwizard.jdbi.OptionalContainerFactory;
import io.dropwizard.jdbi.args.OptionalArgumentFactory;
import io.dropwizard.setup.Bootstrap;
public class VacuumCommand extends ConfiguredCommand<WhisperServerConfiguration> {
private final Logger logger = LoggerFactory.getLogger(DirectoryCommand.class);
public VacuumCommand() {
super("vacuum", "Vacuum Postgres Tables");
}
@Override
protected void run(Bootstrap<WhisperServerConfiguration> bootstrap,
Namespace namespace,
WhisperServerConfiguration config)
throws Exception
{
DataSourceFactory dbConfig = config.getDataSourceFactory();
DBI dbi = new DBI(dbConfig.getUrl(), dbConfig.getUser(), dbConfig.getPassword());
dbi.registerArgumentFactory(new OptionalArgumentFactory(dbConfig.getDriverClass()));
dbi.registerContainerFactory(new ImmutableListContainerFactory());
dbi.registerContainerFactory(new ImmutableSetContainerFactory());
dbi.registerContainerFactory(new OptionalContainerFactory());
Accounts accounts = dbi.onDemand(Accounts.class );
Keys keys = dbi.onDemand(Keys.class );
PendingAccounts pendingAccounts = dbi.onDemand(PendingAccounts.class);
logger.warn("Vacuuming accounts...");
accounts.vacuum();
logger.warn("Vacuuming pending_accounts...");
pendingAccounts.vacuum();
logger.warn("Vacuuming keys...");
keys.vacuum();
Thread.sleep(3000);
System.exit(0);
}
}

View File

@@ -2,6 +2,8 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.ClientResponse;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@@ -9,10 +11,12 @@ import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
@@ -31,6 +35,9 @@ public class AccountControllerTest {
private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class ); private RateLimiter rateLimiter = mock(RateLimiter.class );
private SmsSender smsSender = mock(SmsSender.class ); private SmsSender smsSender = mock(SmsSender.class );
private StoredMessages storedMessages = mock(StoredMessages.class );
private TimeProvider timeProvider = mock(TimeProvider.class );
private static byte[] authorizationKey = decodeHex("3a078586eea8971155f5c1ebd73c8c923cbec1c3ed22a54722e4e88321dc749f");
@Rule @Rule
public final ResourceTestRule resources = ResourceTestRule.builder() public final ResourceTestRule resources = ResourceTestRule.builder()
@@ -38,7 +45,10 @@ public class AccountControllerTest {
.addResource(new AccountController(pendingAccountsManager, .addResource(new AccountController(pendingAccountsManager,
accountsManager, accountsManager,
rateLimiters, rateLimiters,
smsSender)) smsSender,
storedMessages,
timeProvider,
Optional.of(authorizationKey)))
.build(); .build();
@@ -48,6 +58,8 @@ public class AccountControllerTest {
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter);
when(timeProvider.getCurrentTimeMillis()).thenReturn(System.currentTimeMillis());
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of("1234")); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of("1234"));
} }
@@ -90,4 +102,84 @@ public class AccountControllerTest {
verifyNoMoreInteractions(accountsManager); verifyNoMoreInteractions(accountsManager);
} }
@Test
public void testVerifyToken() throws Exception {
when(timeProvider.getCurrentTimeMillis()).thenReturn(1415917053106L);
String token = SENDER + ":1415906573:af4f046107c21721224a";
ClientResponse response =
resources.client().resource(String.format("/v1/accounts/token/%s", token))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new AccountAttributes("keykeykeykey", false, false, 4444))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
assertThat(response.getStatus()).isEqualTo(204);
verify(accountsManager, times(1)).create(isA(Account.class));
}
@Test
public void testVerifyBadToken() throws Exception {
when(timeProvider.getCurrentTimeMillis()).thenReturn(1415917053106L);
String token = SENDER + ":1415906574:af4f046107c21721224a";
ClientResponse response =
resources.client().resource(String.format("/v1/accounts/token/%s", token))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new AccountAttributes("keykeykeykey", false, false, 4444))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
assertThat(response.getStatus()).isEqualTo(403);
verifyNoMoreInteractions(accountsManager);
}
@Test
public void testVerifyWrongToken() throws Exception {
when(timeProvider.getCurrentTimeMillis()).thenReturn(1415917053106L);
String token = SENDER + ":1415906573:af4f046107c21721224a";
ClientResponse response =
resources.client().resource(String.format("/v1/accounts/token/%s", token))
.header("Authorization", AuthHelper.getAuthHeader("+14151111111", "bar"))
.entity(new AccountAttributes("keykeykeykey", false, false, 4444))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
assertThat(response.getStatus()).isEqualTo(403);
verifyNoMoreInteractions(accountsManager);
}
@Test
public void testVerifyExpiredToken() throws Exception {
when(timeProvider.getCurrentTimeMillis()).thenReturn(1416003757901L);
String token = SENDER + ":1415906573:af4f046107c21721224a";
ClientResponse response =
resources.client().resource(String.format("/v1/accounts/token/%s", token))
.header("Authorization", AuthHelper.getAuthHeader(SENDER, "bar"))
.entity(new AccountAttributes("keykeykeykey", false, false, 4444))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class);
assertThat(response.getStatus()).isEqualTo(403);
verifyNoMoreInteractions(accountsManager);
}
private static byte[] decodeHex(String hex) {
try {
return Hex.decodeHex(hex.toCharArray());
} catch (DecoderException e) {
throw new AssertionError(e);
}
}
} }

View File

@@ -4,13 +4,19 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.ClientResponse;
import org.hamcrest.CoreMatchers;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.FederationController; import org.whispersystems.textsecuregcm.controllers.FederationControllerV1;
import org.whispersystems.textsecuregcm.controllers.FederationControllerV2;
import org.whispersystems.textsecuregcm.controllers.KeysControllerV2;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList; import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseItemV2;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseV2;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@@ -44,16 +50,19 @@ public class FederatedControllerTest {
private RateLimiters rateLimiters = mock(RateLimiters.class ); private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class ); private RateLimiter rateLimiter = mock(RateLimiter.class );
private final SignedPreKey signedPreKey = new SignedPreKey(3333, "foo", "baar");
private final PreKeyResponseV2 preKeyResponseV2 = new PreKeyResponseV2("foo", new LinkedList<PreKeyResponseItemV2>());
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
private final MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager); private final MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
private final KeysControllerV2 keysControllerV2 = mock(KeysControllerV2.class);
@Rule @Rule
public final ResourceTestRule resources = ResourceTestRule.builder() public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthenticator()) .addProvider(AuthHelper.getAuthenticator())
.addResource(new FederationController(accountsManager, .addResource(new FederationControllerV1(accountsManager, null, messageController, null))
null, null, .addResource(new FederationControllerV2(accountsManager, null, messageController, keysControllerV2))
messageController))
.build(); .build();
@@ -61,12 +70,12 @@ public class FederatedControllerTest {
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
List<Device> singleDeviceList = new LinkedList<Device>() {{ List<Device> singleDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111)); add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111, null));
}}; }};
List<Device> multiDeviceList = new LinkedList<Device>() {{ List<Device> multiDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222)); add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222, null));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333)); add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333, null));
}}; }};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList); Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);
@@ -76,6 +85,10 @@ public class FederatedControllerTest {
when(accountsManager.get(eq(MULTI_DEVICE_RECIPIENT))).thenReturn(Optional.of(multiDeviceAccount)); when(accountsManager.get(eq(MULTI_DEVICE_RECIPIENT))).thenReturn(Optional.of(multiDeviceAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(keysControllerV2.getSignedKey(any(Account.class))).thenReturn(Optional.of(signedPreKey));
when(keysControllerV2.getDeviceKeys(any(Account.class), anyString(), anyString(), any(Optional.class)))
.thenReturn(Optional.of(preKeyResponseV2));
} }
@Test @Test
@@ -92,5 +105,14 @@ public class FederatedControllerTest {
verify(pushSender).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class)); verify(pushSender).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class));
} }
@Test
public void testSignedPreKeyV2() throws Exception {
PreKeyResponseV2 response =
resources.client().resource("/v2/federation/key/+14152223333/1")
.header("Authorization", AuthHelper.getAuthHeader("cyanogen", "foofoo"))
.get(PreKeyResponseV2.class);
assertThat("good response", response.getIdentityKey().equals(preKeyResponseV2.getIdentityKey()));
}
} }

View File

@@ -6,18 +6,22 @@ import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock; import org.whispersystems.textsecuregcm.controllers.KeysControllerV1;
import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.controllers.KeysControllerV2;
import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyList; import org.whispersystems.textsecuregcm.entities.PreKeyResponseV1;
import org.whispersystems.textsecuregcm.entities.PreKeyStatus; import org.whispersystems.textsecuregcm.entities.PreKeyResponseV2;
import org.whispersystems.textsecuregcm.entities.UnstructuredPreKeyList; import org.whispersystems.textsecuregcm.entities.PreKeyStateV1;
import org.whispersystems.textsecuregcm.entities.PreKeyStateV2;
import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV2;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeyRecord;
import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@@ -36,10 +40,18 @@ public class KeyControllerTest {
private static int SAMPLE_REGISTRATION_ID = 999; private static int SAMPLE_REGISTRATION_ID = 999;
private static int SAMPLE_REGISTRATION_ID2 = 1002; private static int SAMPLE_REGISTRATION_ID2 = 1002;
private static int SAMPLE_REGISTRATION_ID4 = 1555;
private final KeyRecord SAMPLE_KEY = new KeyRecord(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", false);
private final KeyRecord SAMPLE_KEY2 = new KeyRecord(2, EXISTS_NUMBER, 2, 5667, "test3", false );
private final KeyRecord SAMPLE_KEY3 = new KeyRecord(3, EXISTS_NUMBER, 3, 334, "test5", false );
private final KeyRecord SAMPLE_KEY4 = new KeyRecord(4, EXISTS_NUMBER, 4, 336, "test6", false );
private final SignedPreKey SAMPLE_SIGNED_KEY = new SignedPreKey(1111, "foofoo", "sig11");
private final SignedPreKey SAMPLE_SIGNED_KEY2 = new SignedPreKey(2222, "foobar", "sig22");
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey(3333, "barfoo", "sig33");
private final PreKey SAMPLE_KEY = new PreKey(1, EXISTS_NUMBER, Device.MASTER_ID, 1234, "test1", "test2", false);
private final PreKey SAMPLE_KEY2 = new PreKey(2, EXISTS_NUMBER, 2, 5667, "test3", "test4,", false );
private final PreKey SAMPLE_KEY3 = new PreKey(3, EXISTS_NUMBER, 3, 334, "test5", "test6", false );
private final Keys keys = mock(Keys.class ); private final Keys keys = mock(Keys.class );
private final AccountsManager accounts = mock(AccountsManager.class); private final AccountsManager accounts = mock(AccountsManager.class);
private final Account existsAccount = mock(Account.class ); private final Account existsAccount = mock(Account.class );
@@ -50,25 +62,47 @@ public class KeyControllerTest {
@Rule @Rule
public final ResourceTestRule resources = ResourceTestRule.builder() public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthenticator()) .addProvider(AuthHelper.getAuthenticator())
.addResource(new KeysController(rateLimiters, keys, accounts, null)) .addResource(new KeysControllerV1(rateLimiters, keys, accounts, null))
.addResource(new KeysControllerV2(rateLimiters, keys, accounts, null))
.build(); .build();
@Before @Before
public void setup() { public void setup() {
Device sampleDevice = mock(Device.class ); final Device sampleDevice = mock(Device.class);
Device sampleDevice2 = mock(Device.class); final Device sampleDevice2 = mock(Device.class);
Device sampleDevice3 = mock(Device.class); final Device sampleDevice3 = mock(Device.class);
final Device sampleDevice4 = mock(Device.class);
List<Device> allDevices = new LinkedList<Device>() {{
add(sampleDevice);
add(sampleDevice2);
add(sampleDevice3);
add(sampleDevice4);
}};
when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID); when(sampleDevice.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID);
when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice2.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2); when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.isActive()).thenReturn(true); when(sampleDevice.isActive()).thenReturn(true);
when(sampleDevice2.isActive()).thenReturn(true); when(sampleDevice2.isActive()).thenReturn(true);
when(sampleDevice3.isActive()).thenReturn(false); when(sampleDevice3.isActive()).thenReturn(false);
when(sampleDevice4.isActive()).thenReturn(true);
when(sampleDevice.getSignedPreKey()).thenReturn(SAMPLE_SIGNED_KEY);
when(sampleDevice2.getSignedPreKey()).thenReturn(SAMPLE_SIGNED_KEY2);
when(sampleDevice3.getSignedPreKey()).thenReturn(SAMPLE_SIGNED_KEY3);
when(sampleDevice4.getSignedPreKey()).thenReturn(null);
when(sampleDevice.getId()).thenReturn(1L);
when(sampleDevice2.getId()).thenReturn(2L);
when(sampleDevice3.getId()).thenReturn(3L);
when(sampleDevice4.getId()).thenReturn(4L);
when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice)); when(existsAccount.getDevice(1L)).thenReturn(Optional.of(sampleDevice));
when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2)); when(existsAccount.getDevice(2L)).thenReturn(Optional.of(sampleDevice2));
when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3)); when(existsAccount.getDevice(3L)).thenReturn(Optional.of(sampleDevice3));
when(existsAccount.getDevice(4L)).thenReturn(Optional.of(sampleDevice4));
when(existsAccount.getDevice(22L)).thenReturn(Optional.<Device>absent());
when(existsAccount.getDevices()).thenReturn(allDevices);
when(existsAccount.isActive()).thenReturn(true); when(existsAccount.isActive()).thenReturn(true);
when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey"); when(existsAccount.getIdentityKey()).thenReturn("existsidentitykey");
@@ -77,37 +111,31 @@ public class KeyControllerTest {
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenAnswer(new Answer<Optional<UnstructuredPreKeyList>>() { List<KeyRecord> singleDevice = new LinkedList<>();
@Override singleDevice.add(SAMPLE_KEY);
public Optional<UnstructuredPreKeyList> answer(InvocationOnMock invocationOnMock) throws Throwable { when(keys.get(eq(EXISTS_NUMBER), eq(1L))).thenReturn(Optional.of(singleDevice));
return Optional.of(new UnstructuredPreKeyList(cloneKey(SAMPLE_KEY)));
}
});
when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<UnstructuredPreKeyList>absent()); when(keys.get(eq(NOT_EXISTS_NUMBER), eq(1L))).thenReturn(Optional.<List<KeyRecord>>absent());
when(keys.get(EXISTS_NUMBER)).thenAnswer(new Answer<Optional<UnstructuredPreKeyList>>() { List<KeyRecord> multiDevice = new LinkedList<>();
@Override multiDevice.add(SAMPLE_KEY);
public Optional<UnstructuredPreKeyList> answer(InvocationOnMock invocationOnMock) throws Throwable { multiDevice.add(SAMPLE_KEY2);
List<PreKey> allKeys = new LinkedList<>(); multiDevice.add(SAMPLE_KEY3);
allKeys.add(cloneKey(SAMPLE_KEY)); multiDevice.add(SAMPLE_KEY4);
allKeys.add(cloneKey(SAMPLE_KEY2)); when(keys.get(EXISTS_NUMBER)).thenReturn(Optional.of(multiDevice));
allKeys.add(cloneKey(SAMPLE_KEY3));
return Optional.of(new UnstructuredPreKeyList(allKeys));
}
});
when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5); when(keys.getCount(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(5);
when(AuthHelper.VALID_DEVICE.getSignedPreKey()).thenReturn(new SignedPreKey(89898, "zoofarb", "sigvalid"));
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null); when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
} }
@Test @Test
public void validKeyStatusTest() throws Exception { public void validKeyStatusTestV1() throws Exception {
PreKeyStatus result = resources.client().resource("/v1/keys") PreKeyCount result = resources.client().resource("/v1/keys")
.header("Authorization", .header("Authorization",
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyStatus.class); .get(PreKeyCount.class);
assertThat(result.getCount() == 4); assertThat(result.getCount() == 4);
@@ -115,48 +143,145 @@ public class KeyControllerTest {
} }
@Test @Test
public void validLegacyRequestTest() throws Exception { public void validKeyStatusTestV2() throws Exception {
PreKey result = resources.client().resource(String.format("/v1/keys/%s", EXISTS_NUMBER)) PreKeyCount result = resources.client().resource("/v2/keys")
.header("Authorization",
AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyCount.class);
assertThat(result.getCount() == 4);
verify(keys).getCount(eq(AuthHelper.VALID_NUMBER), eq(1L));
}
@Test
public void getSignedPreKeyV2() throws Exception {
SignedPreKey result = resources.client().resource("/v2/keys/signed")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKey.class); .get(SignedPreKey.class);
assertThat(result.equals(SAMPLE_SIGNED_KEY));
}
@Test
public void putSignedPreKeyV2() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
ClientResponse response = resources.client().resource("/v2/keys/signed")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class, test);
assertThat(response.getStatus() == 204);
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test));
verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT));
}
@Test
public void validLegacyRequestTest() throws Exception {
PreKeyV1 result = resources.client().resource(String.format("/v1/keys/%s", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyV1.class);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);
verify(keys).get(eq(EXISTS_NUMBER), eq(1L)); verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys); verifyNoMoreInteractions(keys);
} }
@Test @Test
public void validMultiRequestTest() throws Exception { public void validSingleRequestTestV2() throws Exception {
UnstructuredPreKeyList results = resources.client().resource(String.format("/v1/keys/%s/*", EXISTS_NUMBER)) PreKeyResponseV2 result = resources.client().resource(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(UnstructuredPreKeyList.class); .get(PreKeyResponseV2.class);
assertThat(results.getKeys().size()).isEqualTo(2); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getDevices().size()).isEqualTo(1);
assertThat(result.getDevices().get(0).getPreKey().getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getDevices().get(0).getPreKey().getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getDevices().get(0).getSignedPreKey()).isEqualTo(existsAccount.getDevice(1).get().getSignedPreKey());
PreKey result = results.getKeys().get(0); verify(keys).get(eq(EXISTS_NUMBER), eq(1L));
verifyNoMoreInteractions(keys);
}
@Test
public void validMultiRequestTestV1() throws Exception {
PreKeyResponseV1 results = resources.client().resource(String.format("/v1/keys/%s/*", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponseV1.class);
assertThat(results.getKeys().size()).isEqualTo(3);
PreKeyV1 result = results.getKeys().get(0);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId()); assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey()); assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID); assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(result.getId() == 0);
assertThat(result.getNumber() == null);
result = results.getKeys().get(1); result = results.getKeys().get(1);
assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId()); assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey()); assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey()); assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID2); assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertThat(result.getId() == 0); result = results.getKeys().get(2);
assertThat(result.getNumber() == null); assertThat(result.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId());
assertThat(result.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey());
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
assertThat(result.getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID4);
verify(keys).get(eq(EXISTS_NUMBER));
verifyNoMoreInteractions(keys);
}
@Test
public void validMultiRequestTestV2() throws Exception {
PreKeyResponseV2 results = resources.client().resource(String.format("/v2/keys/%s/*", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(PreKeyResponseV2.class);
assertThat(results.getDevices().size()).isEqualTo(3);
assertThat(results.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey());
PreKeyV2 signedPreKey = results.getDevices().get(0).getSignedPreKey();
PreKeyV2 preKey = results.getDevices().get(0).getPreKey();
long registrationId = results.getDevices().get(0).getRegistrationId();
long deviceId = results.getDevices().get(0).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY.getKeyId());
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY.getPublicKey());
assertThat(deviceId).isEqualTo(1);
signedPreKey = results.getDevices().get(1).getSignedPreKey();
preKey = results.getDevices().get(1).getPreKey();
registrationId = results.getDevices().get(1).getRegistrationId();
deviceId = results.getDevices().get(1).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY2.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY2.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID2);
assertThat(signedPreKey.getKeyId()).isEqualTo(SAMPLE_SIGNED_KEY2.getKeyId());
assertThat(signedPreKey.getPublicKey()).isEqualTo(SAMPLE_SIGNED_KEY2.getPublicKey());
assertThat(deviceId).isEqualTo(2);
signedPreKey = results.getDevices().get(2).getSignedPreKey();
preKey = results.getDevices().get(2).getPreKey();
registrationId = results.getDevices().get(2).getRegistrationId();
deviceId = results.getDevices().get(2).getDeviceId();
assertThat(preKey.getKeyId()).isEqualTo(SAMPLE_KEY4.getKeyId());
assertThat(preKey.getPublicKey()).isEqualTo(SAMPLE_KEY4.getPublicKey());
assertThat(registrationId).isEqualTo(SAMPLE_REGISTRATION_ID4);
assertThat(signedPreKey).isNull();
assertThat(deviceId).isEqualTo(4);
verify(keys).get(eq(EXISTS_NUMBER)); verify(keys).get(eq(EXISTS_NUMBER));
verifyNoMoreInteractions(keys); verifyNoMoreInteractions(keys);
@@ -164,7 +289,7 @@ public class KeyControllerTest {
@Test @Test
public void invalidRequestTest() throws Exception { public void invalidRequestTestV1() throws Exception {
ClientResponse response = resources.client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER)) ClientResponse response = resources.client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(ClientResponse.class); .get(ClientResponse.class);
@@ -173,7 +298,25 @@ public class KeyControllerTest {
} }
@Test @Test
public void unauthorizedRequestTest() throws Exception { public void invalidRequestTestV2() throws Exception {
ClientResponse response = resources.client().resource(String.format("/v2/keys/%s", NOT_EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(ClientResponse.class);
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(404);
}
@Test
public void anotherInvalidRequestTestV2() throws Exception {
ClientResponse response = resources.client().resource(String.format("/v2/keys/%s/22", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.get(ClientResponse.class);
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(404);
}
@Test
public void unauthorizedRequestTestV1() throws Exception {
ClientResponse response = ClientResponse response =
resources.client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER)) resources.client().resource(String.format("/v1/keys/%s", NOT_EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD))
@@ -189,15 +332,31 @@ public class KeyControllerTest {
} }
@Test @Test
public void putKeysTest() throws Exception { public void unauthorizedRequestTestV2() throws Exception {
final PreKey newKey = new PreKey(0, null, 1L, 31337, "foobar", "foobarbaz", false); ClientResponse response =
final PreKey lastResortKey = new PreKey(0, null, 1L, 0xFFFFFF, "fooz", "foobarbaz", false); resources.client().resource(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.INVALID_PASSWORD))
.get(ClientResponse.class);
List<PreKey> preKeys = new LinkedList<PreKey>() {{ assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(401);
response =
resources.client().resource(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.get(ClientResponse.class);
assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(401);
}
@Test
public void putKeysTestV1() throws Exception {
final PreKeyV1 newKey = new PreKeyV1(1L, 31337, "foobar", "foobarbaz");
final PreKeyV1 lastResortKey = new PreKeyV1(1L, 0xFFFFFF, "fooz", "foobarbaz");
List<PreKeyV1> preKeys = new LinkedList<PreKeyV1>() {{
add(newKey); add(newKey);
}}; }};
PreKeyList preKeyList = new PreKeyList(); PreKeyStateV1 preKeyList = new PreKeyStateV1();
preKeyList.setKeys(preKeys); preKeyList.setKeys(preKeys);
preKeyList.setLastResortKey(lastResortKey); preKeyList.setLastResortKey(lastResortKey);
@@ -210,10 +369,10 @@ public class KeyControllerTest {
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(204); assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class ); ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class );
ArgumentCaptor<PreKey> lastResortCaptor = ArgumentCaptor.forClass(PreKey.class); ArgumentCaptor<PreKeyV1> lastResortCaptor = ArgumentCaptor.forClass(PreKeyV1.class);
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture(), lastResortCaptor.capture()); verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture(), lastResortCaptor.capture());
List<PreKey> capturedList = listCaptor.getValue(); List<PreKeyV1> capturedList = listCaptor.getValue();
assertThat(capturedList.size() == 1); assertThat(capturedList.size() == 1);
assertThat(capturedList.get(0).getIdentityKey().equals("foobarbaz")); assertThat(capturedList.get(0).getIdentityKey().equals("foobarbaz"));
assertThat(capturedList.get(0).getKeyId() == 31337); assertThat(capturedList.get(0).getKeyId() == 31337);
@@ -226,9 +385,39 @@ public class KeyControllerTest {
verify(accounts).update(AuthHelper.VALID_ACCOUNT); verify(accounts).update(AuthHelper.VALID_ACCOUNT);
} }
private PreKey cloneKey(PreKey source) { @Test
return new PreKey(source.getId(), source.getNumber(), source.getDeviceId(), source.getKeyId(), public void putKeysTestV2() throws Exception {
source.getPublicKey(), source.getIdentityKey(), source.isLastResort()); final PreKeyV2 preKey = new PreKeyV2(31337, "foobar");
final PreKeyV2 lastResortKey = new PreKeyV2(31339, "barbar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
List<PreKeyV2> preKeys = new LinkedList<PreKeyV2>() {{
add(preKey);
}};
PreKeyStateV2 preKeyState = new PreKeyStateV2(identityKey, signedPreKey, preKeys, lastResortKey);
ClientResponse response =
resources.client().resource("/v2/keys")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.type(MediaType.APPLICATION_JSON_TYPE)
.put(ClientResponse.class, preKeyState);
assertThat(response.getClientResponseStatus().getStatusCode()).isEqualTo(204);
ArgumentCaptor<List> listCaptor = ArgumentCaptor.forClass(List.class);
verify(keys).store(eq(AuthHelper.VALID_NUMBER), eq(1L), listCaptor.capture(), eq(lastResortKey));
List<PreKeyV2> capturedList = listCaptor.getValue();
assertThat(capturedList.size() == 1);
assertThat(capturedList.get(0).getKeyId() == 31337);
assertThat(capturedList.get(0).getPublicKey().equals("foobar"));
verify(AuthHelper.VALID_ACCOUNT).setIdentityKey(eq("barbar"));
verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey));
verify(accounts).update(AuthHelper.VALID_ACCOUNT);
} }
} }

View File

@@ -4,7 +4,6 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional; import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse; import com.sun.jersey.api.client.ClientResponse;
import org.junit.Before; import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
@@ -59,12 +58,12 @@ public class MessageControllerTest {
@Before @Before
public void setup() throws Exception { public void setup() throws Exception {
List<Device> singleDeviceList = new LinkedList<Device>() {{ List<Device> singleDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111)); add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111, null));
}}; }};
List<Device> multiDeviceList = new LinkedList<Device>() {{ List<Device> multiDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222)); add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222, null));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333)); add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333, null));
}}; }};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList); Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);

View File

@@ -0,0 +1,91 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ReceiptController;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.testing.junit.ResourceTestRule;
import static org.fest.assertions.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
public class ReceiptControllerTest {
private static final String SINGLE_DEVICE_RECIPIENT = "+14151111111";
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private final PushSender pushSender = mock(PushSender.class );
private final FederatedClientManager federatedClientManager = mock(FederatedClientManager.class);
private final AccountsManager accountsManager = mock(AccountsManager.class );
private final ObjectMapper mapper = new ObjectMapper();
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthenticator())
.addResource(new ReceiptController(accountsManager, federatedClientManager, pushSender))
.build();
@Before
public void setup() throws Exception {
List<Device> singleDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111, null));
}};
List<Device> multiDeviceList = new LinkedList<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222, null));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333, null));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);
Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, false, multiDeviceList);
when(accountsManager.get(eq(SINGLE_DEVICE_RECIPIENT))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.get(eq(MULTI_DEVICE_RECIPIENT))).thenReturn(Optional.of(multiDeviceAccount));
}
@Test
public synchronized void testSingleDeviceCurrent() throws Exception {
ClientResponse response =
resources.client().resource(String.format("/v1/receipt/%s/%d", SINGLE_DEVICE_RECIPIENT, 1234))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(ClientResponse.class);
assertThat(response.getStatus() == 204);
verify(pushSender, times(1)).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class));
}
@Test
public synchronized void testMultiDeviceCurrent() throws Exception {
ClientResponse response =
resources.client().resource(String.format("/v1/receipt/%s/%d", MULTI_DEVICE_RECIPIENT, 12345))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(ClientResponse.class);
assertThat(response.getStatus() == 204);
verify(pushSender, times(2)).sendMessage(any(Account.class), any(Device.class), any(MessageProtos.OutgoingMessageSignal.class));
}
}

View File

@@ -1,127 +0,0 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import org.eclipse.jetty.websocket.api.CloseStatus;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.controllers.WebsocketController;
import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import org.whispersystems.textsecuregcm.websocket.WebsocketControllerFactory;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.basic.BasicCredentials;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
public class WebsocketControllerTest {
private static final ObjectMapper mapper = new ObjectMapper();
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
@Test
public void testCredentials() throws Exception {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.<Account>absent());
when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
WebsocketController controller = new WebsocketController(accountAuthenticator, pushSender, pubSubManager, storedMessages);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
}});
controller.onWebSocketConnect(session);
verify(session, never()).close();
verify(session, never()).close(any(CloseStatus.class));
verify(session, never()).close(anyInt(), anyString());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
}});
controller.onWebSocketConnect(session);
verify(session).close(any(CloseStatus.class));
}
@Test
public void testOpen() throws Exception {
RemoteEndpoint remote = mock(RemoteEndpoint.class);
List<String> outgoingMessages = new LinkedList<String>() {{
add("first");
add("second");
add("third");
}};
when(device.getId()).thenReturn(2L);
when(account.getId()).thenReturn(31337L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(session.getRemote()).thenReturn(remote);
when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
}});
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
when(storedMessages.getMessagesForDevice(account.getId(), device.getId())).thenReturn(outgoingMessages);
WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, pushSender, storedMessages, pubSubManager);
WebsocketController controller = (WebsocketController) factory.createWebSocket(null, null);
controller.onWebSocketConnect(session);
verify(pubSubManager).subscribe(eq(new WebsocketAddress(31337L, 2L)), eq((controller)));
verify(remote, times(3)).sendStringByFuture(anyString());
controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));
controller.onWebSocketClose(1000, "Closed");
List<String> pending = new LinkedList<String>() {{
add("first");
add("third");
}};
verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(EncryptedOutgoingMessage.class));
}
}

View File

@@ -2,7 +2,8 @@ package org.whispersystems.textsecuregcm.tests.entities;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.entities.ClientContact; import org.whispersystems.textsecuregcm.entities.ClientContact;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKeyV1;
import org.whispersystems.textsecuregcm.entities.PreKeyV2;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
@@ -13,8 +14,8 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.*;
public class PreKeyTest { public class PreKeyTest {
@Test @Test
public void serializeToJSON() throws Exception { public void serializeToJSONV1() throws Exception {
PreKey preKey = new PreKey(1, "+14152222222", 1, 1234, "test", "identityTest", false); PreKeyV1 preKey = new PreKeyV1(1, 1234, "test", "identityTest");
preKey.setRegistrationId(987); preKey.setRegistrationId(987);
assertThat("Basic Contact Serialization works", assertThat("Basic Contact Serialization works",
@@ -23,7 +24,7 @@ public class PreKeyTest {
} }
@Test @Test
public void deserializeFromJSON() throws Exception { public void deserializeFromJSONV() throws Exception {
ClientContact contact = new ClientContact(Util.getContactToken("+14152222222"), ClientContact contact = new ClientContact(Util.getContactToken("+14152222222"),
"whisper", true); "whisper", true);
@@ -32,4 +33,13 @@ public class PreKeyTest {
is(contact)); is(contact));
} }
@Test
public void serializeToJSONV2() throws Exception {
PreKeyV2 preKey = new PreKeyV2(1234, "test");
assertThat("PreKeyV2 Serialization works",
asJson(preKey),
is(equalTo(jsonFixture("fixtures/prekey_v2.json"))));
}
} }

View File

@@ -0,0 +1,182 @@
package org.whispersystems.textsecuregcm.tests.websocket;
import com.google.common.base.Optional;
import com.google.common.util.concurrent.SettableFuture;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.websocket.ConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.basic.BasicCredentials;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
public class WebSocketConnectionTest {
// private static final ObjectMapper mapper = new ObjectMapper();
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
// private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
@Test
public void testCredentials() throws Exception {
StoredMessages storedMessages = mock(StoredMessages.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
ConnectListener connectListener = new ConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.<Account>absent());
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
// when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
//
// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
}});
Optional<Account> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(Account.class)).thenReturn(account);
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
//
// controller.onWebSocketConnect(session);
// verify(session, never()).close();
// verify(session, never()).close(any(CloseStatus.class));
// verify(session, never()).close(anyInt(), anyString());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
}});
account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(Account.class)).thenReturn(account);
WebSocketClient client = mock(WebSocketClient.class);
when(sessionContext.getClient()).thenReturn(client);
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
verify(client).close(eq(4001), anyString());
}
@Test
public void testOpen() throws Exception {
StoredMessages storedMessages = mock(StoredMessages.class);
List<PendingMessage> outgoingMessages = new LinkedList<PendingMessage>() {{
add(new PendingMessage("sender1", 1111, false, "first"));
add(new PendingMessage("sender1", 2222, false, "second"));
add(new PendingMessage("sender2", 3333, false, "third"));
}};
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
final Device sender1device = mock(Device.class);
List<Device> sender1devices = new LinkedList<Device>() {{
add(sender1device);
}};
Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices);
when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.get("sender2")).thenReturn(Optional.<Account>absent());
when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId())))
.thenReturn(outgoingMessages);
final List<SettableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class)))
.thenAnswer(new Answer<SettableFuture<WebSocketResponseMessage>>() {
@Override
public SettableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
SettableFuture<WebSocketResponseMessage> future = SettableFuture.create();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages,
pubSubManager, account, device, client);
connection.onConnected();
verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection)));
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
assertTrue(futures.size() == 3);
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(1).set(response);
futures.get(0).setException(new IOException());
futures.get(2).setException(new IOException());
List<PendingMessage> pending = new LinkedList<PendingMessage>() {{
add(new PendingMessage("sender1", 1111, false, "first"));
add(new PendingMessage("sender2", 3333, false, "third"));
}};
verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(PendingMessage.class));
verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(MessageProtos.OutgoingMessageSignal.class));
connection.onConnectionLost();
verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection));
}
}

View File

@@ -0,0 +1,4 @@
{
"keyId" : 1234,
"publicKey" : "test"
}