Compare commits

...

34 Commits
v0.35 ... v0.49

Author SHA1 Message Date
Moxie Marlinspike
c6810d7460 Bump version to 0.49
// FREEBIE
2015-05-13 14:31:38 -07:00
Moxie Marlinspike
4c1e7e7c2f Rate limit messages on source+destination rather than just src.
// FREEBIE
2015-04-24 16:25:59 -07:00
Moxie Marlinspike
931081752a Updated sample config
// FREEBIE

Closes #38
2015-04-21 19:45:31 -07:00
Moxie Marlinspike
424e98e67e Bump version to 0.48
// FREEBIE
2015-04-21 19:44:20 -07:00
Moxie Marlinspike
7cfa93f5f8 Tone down websocket logging for bad federated responses.
// FREEBIE
2015-04-21 19:44:02 -07:00
Moxie Marlinspike
fd8e8d1475 Catch WebApplicationException inside WebsocketConnection.
// FREEBIE
2015-04-16 11:33:16 -07:00
Moxie Marlinspike
7ed5eb22ec Additional WebsocketConnection test.
// FREEBIE
2015-04-16 10:45:24 -07:00
Moxie Marlinspike
fa1c275904 Bump version to 0.46
// FREEBIE
2015-04-15 16:44:22 -07:00
Moxie Marlinspike
558c72bbb7 Make pending messages indexable by sender and timestamp.
Rather than just timestamp.

// FREEBIE
2015-04-15 16:43:44 -07:00
Moxie Marlinspike
37976455bc Bump version to 0.45
// FREEBIE
2015-04-15 16:20:25 -07:00
Moxie Marlinspike
db6ee8f687 Make stored messages REST accessible.
Add REST endpoints for retrieving and acknowledging pending
messges, beyond the WebSocket.

// FREEBIE
2015-04-15 16:19:07 -07:00
Moxie Marlinspike
53e7ffa311 Bump version to 0.44 2015-03-30 10:31:35 -07:00
Moxie Marlinspike
e0f7ff325a Add voip push support in communication with push server.
// FREEBIE
2015-03-25 15:59:52 -07:00
Moxie Marlinspike
1fcd1e33c5 Log keepalives for unsubscribed channels.
// FREEBIE
2015-03-24 14:13:32 -07:00
Moxie Marlinspike
843b16c1f0 Correctly serialize provisioning addresses.
// FREEBIE
2015-03-24 12:10:59 -07:00
Moxie Marlinspike
a58f3f0fe3 Check subscription status on websocket keepalive.
// FREEBIE
2015-03-24 10:48:14 -07:00
Moxie Marlinspike
e69e395b25 Support for receiving "canonical id" update events from pushserver.
// FREEBIE
2015-03-24 10:47:45 -07:00
Moxie Marlinspike
456164fc24 Support registering a 'voip' APN ID.
// FREEBIE
2015-03-24 10:47:20 -07:00
Moxie Marlinspike
9b7f61a09d Bump version to 0.43 2015-03-20 07:57:00 -07:00
Moxie Marlinspike
2de9adb7ae Make dispatch subscription/unsubscription synchronized. 2015-03-19 14:37:10 -07:00
Moxie Marlinspike
c7e0cc1158 Use a custom redis pubsub implementation rather than Jedis.
// FREEBIE
2015-03-17 13:30:51 -07:00
Moxie Marlinspike
e79861c30a Add connection duration stats.
// FREEBIE
2015-03-15 10:21:24 -07:00
Moxie Marlinspike
407f596b61 Bump version to 0.41
// FREEBIE
2015-03-13 14:38:39 -07:00
Moxie Marlinspike
41d30fc8dc Resubscribe listeners when subscription link breaks.
// FREEBIE
2015-03-13 10:56:48 -07:00
Moxie Marlinspike
2e429c5b35 Loop voice verification prompt 3 times.
Fixes #18
Closes #32

// FREEBIE
2015-03-13 10:09:40 -07:00
Moxie Marlinspike
28afe3470b Upgrade jedis to 2.6.2
// FREEBIE
2015-03-13 10:06:44 -07:00
Moxie Marlinspike
f623b24196 Fix string formatting.
// FREEBIE
2015-03-12 16:29:49 -07:00
Moxie Marlinspike
2016c17894 Bump version to 0.40 2015-03-10 18:29:14 -07:00
Moxie Marlinspike
2d28077010 Make idle timeout 90s
// FREEBIE
2015-03-10 17:27:43 -07:00
Moxie Marlinspike
7011f3c3c7 Fix 500 on validation error 2015-03-10 16:00:20 -07:00
Moxie Marlinspike
1403dbd5dd Handle pubsub callbacks from a cached thread pool.
...implement some belt and suspenders dead letter handling.

...implement some belt and suspenders redis pubsub queue handling.

// FREEBIE
2015-03-10 12:45:05 -07:00
Moxie Marlinspike
6ef3845a34 Don't consider empty relays present on receipt delivery. 2015-03-09 15:41:50 -07:00
Moxie Marlinspike
de2f0914f0 Reenable websocket notifications for Android.
// FREEBIE
2015-03-09 09:39:09 -07:00
Moxie Marlinspike
080ae0985f Support URL-safe Base64 encoding for directory tokens.
// FREEBIE
2015-03-09 09:34:02 -07:00
51 changed files with 1961 additions and 504 deletions

View File

@@ -1,32 +1,16 @@
twilio:
accountId:
twilio: # Twilio SMS gateway configuration
accountId:
accountToken:
number:
localDomain: # The domain Twilio can call back to.
international: # Boolean specifying Twilio for international delivery
# Optional. If specified, Nexmo will be used for non-US SMS and
# voice verification if twilio.international is false. Otherwise,
# Nexmo, if specified, Nexmo will only be used as a fallback
# for failed Twilio deliveries.
nexmo:
apiKey:
apiSecret:
number:
push: # GCM/APN push server configuration
host:
port:
username:
password:
gcm:
senderId:
apiKey:
# Optional. Only if iOS clients are supported.
apn:
# In PEM format.
certificate:
# In PEM format.
key:
s3:
s3: # AWS S3 configuration
accessKey:
accessSecret:
@@ -35,13 +19,37 @@ s3:
# correct permissions.
attachmentsBucket:
memcache:
servers:
user:
password:
directory: # Redis server configuration for TS directory
url:
redis:
url:
cache: # Redis server configuration for general purpose caching
url:
websocket:
enabled: true
messageStore: # Postgres database configuration for message store
driverClass: org.postgresql.Driver
user:
password:
url:
database: # Postgres database configuration for account store
# the name of your JDBC driver
driverClass: org.postgresql.Driver
# the username
user:
# the password
password:
# the JDBC URL
url: jdbc:postgresql://somehost:somport/somedb
# any properties specific to your JDBC driver:
properties:
charSet: UTF-8
federation:
name:
@@ -52,24 +60,3 @@ federation:
authenticationToken: foo
certificate: in pem format
# Optional address of graphite server to report metrics
graphite:
host:
port:
database:
# the name of your JDBC driver
driverClass: org.postgresql.Driver
# the username
user:
# the password
password:
# the JDBC URL
url: jdbc:postgresql://somehost:somport/somedb
# any properties specific to your JDBC driver:
properties:
charSet: UTF-8

View File

@@ -9,7 +9,7 @@
<groupId>org.whispersystems.textsecure</groupId>
<artifactId>TextSecureServer</artifactId>
<version>0.35</version>
<version>0.49</version>
<properties>
<dropwizard.version>0.7.1</dropwizard.version>
@@ -106,7 +106,7 @@
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>2.6.1</version>
<version>2.6.2</version>
<type>jar</type>
<scope>compile</scope>
</dependency>
@@ -129,10 +129,9 @@
<dependency>
<groupId>org.whispersystems</groupId>
<artifactId>websocket-resources</artifactId>
<version>0.2.1</version>
<version>0.2.3</version>
</dependency>
</dependencies>
<dependencyManagement>

View File

@@ -0,0 +1,7 @@
package org.whispersystems.dispatch;
public interface DispatchChannel {
public void onDispatchMessage(String channel, byte[] message);
public void onDispatchSubscribed(String channel);
public void onDispatchUnsubscribed(String channel);
}

View File

@@ -0,0 +1,172 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
public class DispatchManager extends Thread {
private final Logger logger = LoggerFactory.getLogger(DispatchManager.class);
private final Executor executor = Executors.newCachedThreadPool();
private final Map<String, DispatchChannel> subscriptions = new ConcurrentHashMap<>();
private final Optional<DispatchChannel> deadLetterChannel;
private final RedisPubSubConnectionFactory redisPubSubConnectionFactory;
private PubSubConnection pubSubConnection;
private volatile boolean running;
public DispatchManager(RedisPubSubConnectionFactory redisPubSubConnectionFactory,
Optional<DispatchChannel> deadLetterChannel)
{
this.redisPubSubConnectionFactory = redisPubSubConnectionFactory;
this.deadLetterChannel = deadLetterChannel;
}
@Override
public void start() {
this.pubSubConnection = redisPubSubConnectionFactory.connect();
this.running = true;
super.start();
}
public void shutdown() {
this.running = false;
this.pubSubConnection.close();
}
public synchronized void subscribe(String name, DispatchChannel dispatchChannel) {
Optional<DispatchChannel> previous = Optional.fromNullable(subscriptions.get(name));
subscriptions.put(name, dispatchChannel);
try {
pubSubConnection.subscribe(name);
} catch (IOException e) {
logger.warn("Subscription error", e);
}
if (previous.isPresent()) {
dispatchUnsubscription(name, previous.get());
}
}
public synchronized void unsubscribe(String name, DispatchChannel channel) {
Optional<DispatchChannel> subscription = Optional.fromNullable(subscriptions.get(name));
if (subscription.isPresent() && subscription.get() == channel) {
subscriptions.remove(name);
try {
pubSubConnection.unsubscribe(name);
} catch (IOException e) {
logger.warn("Unsubscribe error", e);
}
dispatchUnsubscription(name, subscription.get());
}
}
public boolean hasSubscription(String name) {
return subscriptions.containsKey(name);
}
@Override
public void run() {
while (running) {
try {
PubSubReply reply = pubSubConnection.read();
switch (reply.getType()) {
case UNSUBSCRIBE: break;
case SUBSCRIBE: dispatchSubscribe(reply); break;
case MESSAGE: dispatchMessage(reply); break;
default: throw new AssertionError("Unknown pubsub reply type! " + reply.getType());
}
} catch (IOException e) {
logger.warn("***** PubSub Connection Error *****", e);
if (running) {
this.pubSubConnection.close();
this.pubSubConnection = redisPubSubConnectionFactory.connect();
resubscribeAll();
}
}
}
logger.warn("DispatchManager Shutting Down...");
}
private void dispatchSubscribe(final PubSubReply reply) {
Optional<DispatchChannel> subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
if (subscription.isPresent()) {
dispatchSubscription(reply.getChannel(), subscription.get());
} else {
logger.warn("Received subscribe event for non-existing channel: " + reply.getChannel());
}
}
private void dispatchMessage(PubSubReply reply) {
Optional<DispatchChannel> subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
if (subscription.isPresent()) {
dispatchMessage(reply.getChannel(), subscription.get(), reply.getContent().get());
} else if (deadLetterChannel.isPresent()) {
dispatchMessage(reply.getChannel(), deadLetterChannel.get(), reply.getContent().get());
} else {
logger.warn("Received message for non-existing channel, with no dead letter handler: " + reply.getChannel());
}
}
private void resubscribeAll() {
new Thread() {
@Override
public void run() {
synchronized (DispatchManager.this) {
try {
for (String name : subscriptions.keySet()) {
pubSubConnection.subscribe(name);
}
} catch (IOException e) {
logger.warn("***** RESUBSCRIPTION ERROR *****", e);
}
}
}
}.start();
}
private void dispatchMessage(final String name, final DispatchChannel channel, final byte[] message) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchMessage(name, message);
}
});
}
private void dispatchSubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchSubscribed(name);
}
});
}
private void dispatchUnsubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchUnsubscribed(name);
}
});
}
}

View File

@@ -0,0 +1,64 @@
package org.whispersystems.dispatch.io;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
public class RedisInputStream {
private static final byte CR = 0x0D;
private static final byte LF = 0x0A;
private final InputStream inputStream;
public RedisInputStream(InputStream inputStream) {
this.inputStream = inputStream;
}
public String readLine() throws IOException {
ByteArrayOutputStream boas = new ByteArrayOutputStream();
boolean foundCr = false;
while (true) {
int character = inputStream.read();
if (character == -1) {
throw new IOException("Stream closed!");
}
boas.write(character);
if (foundCr && character == LF) break;
else if (character == CR) foundCr = true;
else if (foundCr) foundCr = false;
}
byte[] data = boas.toByteArray();
return new String(data, 0, data.length-2);
}
public byte[] readFully(int size) throws IOException {
byte[] result = new byte[size];
int offset = 0;
int remaining = result.length;
while (remaining > 0) {
int read = inputStream.read(result, offset, remaining);
if (read < 0) {
throw new IOException("Stream closed!");
}
offset += read;
remaining -= read;
}
return result;
}
public void close() throws IOException {
inputStream.close();
}
}

View File

@@ -0,0 +1,9 @@
package org.whispersystems.dispatch.io;
import org.whispersystems.dispatch.redis.PubSubConnection;
public interface RedisPubSubConnectionFactory {
public PubSubConnection connect();
}

View File

@@ -0,0 +1,119 @@
package org.whispersystems.dispatch.redis;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisInputStream;
import org.whispersystems.dispatch.redis.protocol.ArrayReplyHeader;
import org.whispersystems.dispatch.redis.protocol.IntReply;
import org.whispersystems.dispatch.redis.protocol.StringReplyHeader;
import org.whispersystems.dispatch.util.Util;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
public class PubSubConnection {
private final Logger logger = LoggerFactory.getLogger(PubSubConnection.class);
private static final byte[] UNSUBSCRIBE_TYPE = {'u', 'n', 's', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
private static final byte[] SUBSCRIBE_TYPE = {'s', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
private static final byte[] MESSAGE_TYPE = {'m', 'e', 's', 's', 'a', 'g', 'e' };
private static final byte[] SUBSCRIBE_COMMAND = {'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' ' };
private static final byte[] UNSUBSCRIBE_COMMAND = {'U', 'N', 'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' '};
private static final byte[] CRLF = {'\r', '\n' };
private final OutputStream outputStream;
private final RedisInputStream inputStream;
private final Socket socket;
private final AtomicBoolean closed;
public PubSubConnection(Socket socket) throws IOException {
this.socket = socket;
this.outputStream = socket.getOutputStream();
this.inputStream = new RedisInputStream(new BufferedInputStream(socket.getInputStream()));
this.closed = new AtomicBoolean(false);
}
public void subscribe(String channelName) throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
byte[] command = Util.combine(SUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
outputStream.write(command);
}
public void unsubscribe(String channelName) throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
byte[] command = Util.combine(UNSUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
outputStream.write(command);
}
public PubSubReply read() throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
ArrayReplyHeader replyHeader = new ArrayReplyHeader(inputStream.readLine());
if (replyHeader.getElementCount() != 3) {
throw new IOException("Received array reply header with strange count: " + replyHeader.getElementCount());
}
StringReplyHeader replyTypeHeader = new StringReplyHeader(inputStream.readLine());
byte[] replyType = inputStream.readFully(replyTypeHeader.getStringLength());
inputStream.readLine();
if (Arrays.equals(SUBSCRIBE_TYPE, replyType)) return readSubscribeReply();
else if (Arrays.equals(UNSUBSCRIBE_TYPE, replyType)) return readUnsubscribeReply();
else if (Arrays.equals(MESSAGE_TYPE, replyType)) return readMessageReply();
else throw new IOException("Unknown reply type: " + new String(replyType));
}
public void close() {
try {
this.closed.set(true);
this.inputStream.close();
this.outputStream.close();
this.socket.close();
} catch (IOException e) {
logger.warn("Exception while closing", e);
}
}
private PubSubReply readMessageReply() throws IOException {
StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
inputStream.readLine();
StringReplyHeader messageHeader = new StringReplyHeader(inputStream.readLine());
byte[] message = inputStream.readFully(messageHeader.getStringLength());
inputStream.readLine();
return new PubSubReply(PubSubReply.Type.MESSAGE, new String(channelName), Optional.of(message));
}
private PubSubReply readUnsubscribeReply() throws IOException {
String channelName = readSubscriptionReply();
return new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, channelName, Optional.<byte[]>absent());
}
private PubSubReply readSubscribeReply() throws IOException {
String channelName = readSubscriptionReply();
return new PubSubReply(PubSubReply.Type.SUBSCRIBE, channelName, Optional.<byte[]>absent());
}
private String readSubscriptionReply() throws IOException {
StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
inputStream.readLine();
IntReply subscriptionCount = new IntReply(inputStream.readLine());
return new String(channelName);
}
}

View File

@@ -0,0 +1,35 @@
package org.whispersystems.dispatch.redis;
import com.google.common.base.Optional;
public class PubSubReply {
public enum Type {
MESSAGE,
SUBSCRIBE,
UNSUBSCRIBE
}
private final Type type;
private final String channel;
private final Optional<byte[]> content;
public PubSubReply(Type type, String channel, Optional<byte[]> content) {
this.type = type;
this.channel = channel;
this.content = content;
}
public Type getType() {
return type;
}
public String getChannel() {
return channel;
}
public Optional<byte[]> getContent() {
return content;
}
}

View File

@@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class ArrayReplyHeader {
private final int elementCount;
public ArrayReplyHeader(String header) throws IOException {
if (header == null || header.length() < 2 || header.charAt(0) != '*') {
throw new IOException("Invalid array reply header: " + header);
}
try {
this.elementCount = Integer.parseInt(header.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getElementCount() {
return elementCount;
}
}

View File

@@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class IntReply {
private final int value;
public IntReply(String reply) throws IOException {
if (reply == null || reply.length() < 2 || reply.charAt(0) != ':') {
throw new IOException("Invalid int reply: " + reply);
}
try {
this.value = Integer.parseInt(reply.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getValue() {
return value;
}
}

View File

@@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class StringReplyHeader {
private final int stringLength;
public StringReplyHeader(String header) throws IOException {
if (header == null || header.length() < 2 || header.charAt(0) != '$') {
throw new IOException("Invalid string reply header: " + header);
}
try {
this.stringLength = Integer.parseInt(header.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getStringLength() {
return stringLength;
}
}

View File

@@ -0,0 +1,36 @@
package org.whispersystems.dispatch.util;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
public class Util {
public static byte[] combine(byte[]... elements) {
try {
int sum = 0;
for (byte[] element : elements) {
sum += element.length;
}
ByteArrayOutputStream baos = new ByteArrayOutputStream(sum);
for (byte[] element : elements) {
baos.write(element);
}
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError(e);
}
}
public static void sleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
}

View File

@@ -24,6 +24,8 @@ import com.sun.jersey.api.client.Client;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
@@ -56,6 +58,7 @@ import org.whispersystems.textsecuregcm.providers.TimeProvider;
import org.whispersystems.textsecuregcm.push.FeedbackHandler;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.PushServiceClient;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebsocketSender;
import org.whispersystems.textsecuregcm.sms.NexmoSmsSender;
import org.whispersystems.textsecuregcm.sms.SmsSender;
@@ -75,6 +78,7 @@ import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.UrlSigner;
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
import org.whispersystems.textsecuregcm.websocket.DeadLetterHandler;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.DirectoryCommand;
@@ -146,10 +150,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
Keys keys = database.onDemand(Keys.class);
Messages messages = messagedb.onDemand(Messages.class);
JedisPool cacheClient = new RedisClientFactory(config.getCacheConfiguration().getUrl()).getRedisClientPool();
JedisPool directoryClient = new RedisClientFactory(config.getDirectoryConfiguration().getUrl()).getRedisClientPool();
Client httpClient = new JerseyClientBuilder(environment).using(config.getJerseyClientConfiguration())
.build(getName());
RedisClientFactory cacheClientFactory = new RedisClientFactory(config.getCacheConfiguration().getUrl());
JedisPool cacheClient = cacheClientFactory.getRedisClientPool();
JedisPool directoryClient = new RedisClientFactory(config.getDirectoryConfiguration().getUrl()).getRedisClientPool();
Client httpClient = new JerseyClientBuilder(environment).using(config.getJerseyClientConfiguration())
.build(getName());
DirectoryManager directory = new DirectoryManager(directoryClient);
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, cacheClient);
@@ -157,7 +162,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountsManager accountsManager = new AccountsManager(accounts, directory, cacheClient);
FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration());
MessagesManager messagesManager = new MessagesManager(messages);
PubSubManager pubSubManager = new PubSubManager(cacheClient);
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(messagesManager);
DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.<DispatchChannel>of(deadLetterHandler));
PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager);
PushServiceClient pushServiceClient = new PushServiceClient(httpClient, config.getPushConfiguration());
WebsocketSender websocketSender = new WebsocketSender(messagesManager, pubSubManager);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
@@ -168,15 +175,17 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
SmsSender smsSender = new SmsSender(twilioSmsSender, nexmoSmsSender, config.getTwilioConfiguration().isInternational());
UrlSigner urlSigner = new UrlSigner(config.getS3Configuration());
PushSender pushSender = new PushSender(pushServiceClient, websocketSender);
ReceiptSender receiptSender = new ReceiptSender(accountsManager, pushSender, federatedClientManager);
FeedbackHandler feedbackHandler = new FeedbackHandler(pushServiceClient, accountsManager);
Optional<byte[]> authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey();
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(feedbackHandler);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
KeysControllerV1 keysControllerV1 = new KeysControllerV1(rateLimiters, keys, accountsManager, federatedClientManager);
KeysControllerV2 keysControllerV2 = new KeysControllerV2(rateLimiters, keys, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager);
environment.jersey().register(new MultiBasicAuthProvider<>(new FederatedPeerAuthenticator(config.getFederationConfiguration()),
FederatedPeer.class,
@@ -188,7 +197,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new DirectoryController(rateLimiters, directory));
environment.jersey().register(new FederationControllerV1(accountsManager, attachmentController, messageController, keysControllerV1));
environment.jersey().register(new FederationControllerV2(accountsManager, attachmentController, messageController, keysControllerV2));
environment.jersey().register(new ReceiptController(accountsManager, federatedClientManager, pushSender));
environment.jersey().register(new ReceiptController(receiptSender));
environment.jersey().register(new ProvisioningController(rateLimiters, pushSender));
environment.jersey().register(attachmentController);
environment.jersey().register(keysControllerV1);
@@ -196,14 +205,14 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(messageController);
if (config.getWebsocketConfiguration().isEnabled()) {
WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment(environment, config);
WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment(environment, config, 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(deviceAuthenticator));
webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(accountsManager, pushSender, messagesManager, pubSubManager));
webSocketEnvironment.jersey().register(new KeepAliveController());
webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, messagesManager, pubSubManager));
webSocketEnvironment.jersey().register(new KeepAliveController(pubSubManager));
WebSocketEnvironment provisioningEnvironment = new WebSocketEnvironment(environment, config);
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
provisioningEnvironment.jersey().register(new KeepAliveController());
provisioningEnvironment.jersey().register(new KeepAliveController(pubSubManager));
WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory(webSocketEnvironment );
WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory(provisioningEnvironment);
@@ -261,5 +270,4 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
public static void main(String[] args) throws Exception {
new WhisperServerService().run(args);
}
}

View File

@@ -200,6 +200,7 @@ public class AccountController {
public void setGcmRegistrationId(@Auth Account account, @Valid GcmRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(null);
device.setVoipApnId(null);
device.setGcmId(registrationId.getGcmRegistrationId());
if (registrationId.isWebSocketChannel()) device.setFetchesMessages(true);
@@ -225,6 +226,7 @@ public class AccountController {
public void setApnRegistrationId(@Auth Account account, @Valid ApnRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get();
device.setApnId(registrationId.getApnRegistrationId());
device.setVoipApnId(registrationId.getVoipRegistrationId());
device.setGcmId(null);
device.setFetchesMessages(true);
accounts.update(account);

View File

@@ -74,7 +74,7 @@ public class DirectoryController {
rateLimiters.getContactsLimiter().validate(account.getNumber());
try {
Optional<ClientContact> contact = directory.get(Base64.decodeWithoutPadding(token));
Optional<ClientContact> contact = directory.get(decodeToken(token));
if (contact.isPresent()) return Response.ok().entity(contact.get()).build();
else return Response.status(404).build();
@@ -100,7 +100,7 @@ public class DirectoryController {
List<byte[]> tokens = new LinkedList<>();
for (String encodedContact : contacts.getContacts()) {
tokens.add(Base64.decodeWithoutPadding(encodedContact));
tokens.add(decodeToken(encodedContact));
}
List<ClientContact> intersection = directory.get(tokens);
@@ -110,4 +110,8 @@ public class DirectoryController {
throw new WebApplicationException(Response.status(400).build());
}
}
private byte[] decodeToken(String encoded) throws IOException {
return Base64.decodeWithoutPadding(encoded.replace('-', '+').replace('_', '/'));
}
}

View File

@@ -1,19 +1,47 @@
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
import io.dropwizard.auth.Auth;
@Path("/v1/keepalive")
public class KeepAliveController {
private final Logger logger = LoggerFactory.getLogger(KeepAliveController.class);
private final PubSubManager pubSubManager;
public KeepAliveController(PubSubManager pubSubManager) {
this.pubSubManager = pubSubManager;
}
@Timed
@GET
public Response getKeepAlive() {
public Response getKeepAlive(@Auth(required = false) Account account,
@WebSocketSession WebSocketSessionContext context)
{
if (account != null) {
WebsocketAddress address = new WebsocketAddress(account.getNumber(),
account.getAuthenticatedDevice().get().getId());
if (!pubSubManager.hasLocalSubscription(address)) {
logger.warn("***** No local subscription found for: " + address);
context.getClient().close(1000, "OK");
}
}
return Response.ok().build();
}

View File

@@ -26,6 +26,8 @@ import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.MessageResponse;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClient;
@@ -34,15 +36,19 @@ import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Util;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
@@ -66,17 +72,23 @@ public class MessageController {
private final RateLimiters rateLimiters;
private final PushSender pushSender;
private final ReceiptSender receiptSender;
private final FederatedClientManager federatedClientManager;
private final AccountsManager accountsManager;
private final MessagesManager messagesManager;
public MessageController(RateLimiters rateLimiters,
PushSender pushSender,
ReceiptSender receiptSender,
AccountsManager accountsManager,
MessagesManager messagesManager,
FederatedClientManager federatedClientManager)
{
this.rateLimiters = rateLimiters;
this.pushSender = pushSender;
this.receiptSender = receiptSender;
this.accountsManager = accountsManager;
this.messagesManager = messagesManager;
this.federatedClientManager = federatedClientManager;
}
@@ -90,7 +102,7 @@ public class MessageController {
@Valid IncomingMessageList messages)
throws IOException, RateLimitExceededException
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
rateLimiters.getMessagesLimiter().validate(source.getNumber() + "__" + destinationName);
try {
boolean isSyncMessage = source.getNumber().equals(destinationName);
@@ -137,6 +149,38 @@ public class MessageController {
}
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public OutgoingMessageEntityList getPendingMessages(@Auth Account account) {
return new OutgoingMessageEntityList(messagesManager.getMessagesForDevice(account.getNumber(),
account.getAuthenticatedDevice()
.get().getId()));
}
@Timed
@DELETE
@Path("/{source}/{timestamp}")
public void removePendingMessage(@Auth Account account,
@PathParam("source") String source,
@PathParam("timestamp") long timestamp)
throws IOException
{
try {
Optional<OutgoingMessageEntity> message = messagesManager.delete(account.getNumber(), source, timestamp);
if (message.isPresent() && message.get().getType() != OutgoingMessageSignal.Type.RECEIPT_VALUE) {
receiptSender.sendReceipt(account,
message.get().getSource(),
message.get().getTimestamp(),
Optional.fromNullable(message.get().getRelay()));
}
} catch (NoSuchUserException | NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("Sending delivery receipt", e);
}
}
private void sendLocalMessage(Account source,
String destinationName,
IncomingMessageList messages,

View File

@@ -15,6 +15,7 @@ import javax.ws.rs.Consumes;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
@@ -37,6 +38,7 @@ public class ProvisioningController {
@Path("/{destination}")
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void sendProvisioningMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid ProvisioningMessage message)
@@ -44,7 +46,7 @@ public class ProvisioningController {
{
rateLimiters.getMessagesLimiter().validate(source.getNumber());
if (!websocketSender.sendProvisioningMessage(new ProvisioningAddress(destinationName),
if (!websocketSender.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
Base64.decode(message.getBody())))
{
throw new WebApplicationException(Response.Status.NOT_FOUND);

View File

@@ -2,14 +2,10 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
@@ -18,25 +14,16 @@ import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Response;
import java.io.IOException;
import java.util.Set;
import io.dropwizard.auth.Auth;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
@Path("/v1/receipt")
public class ReceiptController {
private final AccountsManager accountManager;
private final PushSender pushSender;
private final FederatedClientManager federatedClientManager;
private final ReceiptSender receiptSender;
public ReceiptController(AccountsManager accountManager,
FederatedClientManager federatedClientManager,
PushSender pushSender)
{
this.accountManager = accountManager;
this.federatedClientManager = federatedClientManager;
this.pushSender = pushSender;
public ReceiptController(ReceiptSender receiptSender) {
this.receiptSender = receiptSender;
}
@Timed
@@ -49,8 +36,7 @@ public class ReceiptController {
throws IOException
{
try {
if (relay.isPresent()) sendRelayedReceipt(source, destination, messageId, relay.get());
else sendDirectReceipt(source, destination, messageId);
receiptSender.sendReceipt(source, destination, messageId, relay);
} catch (NoSuchUserException | NotPushRegisteredException e) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
} catch (TransientPushFailureException e) {
@@ -58,51 +44,4 @@ public class ReceiptController {
}
}
private void sendRelayedReceipt(Account source, String destination, long messageId, String relay)
throws NoSuchUserException, IOException
{
try {
federatedClientManager.getClient(relay)
.sendDeliveryReceipt(source.getNumber(),
source.getAuthenticatedDevice().get().getId(),
destination, messageId);
} catch (NoSuchPeerException e) {
throw new NoSuchUserException(e);
}
}
private void sendDirectReceipt(Account source, String destination, long messageId)
throws NotPushRegisteredException, TransientPushFailureException, NoSuchUserException
{
Account destinationAccount = getDestinationAccount(destination);
Set<Device> destinationDevices = destinationAccount.getDevices();
OutgoingMessageSignal.Builder message =
OutgoingMessageSignal.newBuilder()
.setSource(source.getNumber())
.setSourceDevice((int) source.getAuthenticatedDevice().get().getId())
.setTimestamp(messageId)
.setType(OutgoingMessageSignal.Type.RECEIPT_VALUE);
if (source.getRelay().isPresent()) {
message.setRelay(source.getRelay().get());
}
for (Device destinationDevice : destinationDevices) {
pushSender.sendMessage(destinationAccount, destinationDevice, message.build());
}
}
private Account getDestinationAccount(String destination)
throws NoSuchUserException
{
Optional<Account> account = accountManager.get(destination);
if (!account.isPresent()) {
throw new NoSuchUserException(destination);
}
return account.get();
}
}

View File

@@ -5,6 +5,7 @@ import com.google.common.annotations.VisibleForTesting;
import org.hibernate.validator.constraints.NotEmpty;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
public class ApnMessage {
@JsonProperty
@@ -23,12 +24,17 @@ public class ApnMessage {
@NotEmpty
private String message;
@JsonProperty
@NotNull
private boolean voip;
public ApnMessage() {}
public ApnMessage(String apnId, String number, int deviceId, String message) {
public ApnMessage(String apnId, String number, int deviceId, String message, boolean voip) {
this.apnId = apnId;
this.number = number;
this.deviceId = deviceId;
this.message = message;
this.voip = voip;
}
}

View File

@@ -25,7 +25,14 @@ public class ApnRegistrationId {
@NotEmpty
private String apnRegistrationId;
@JsonProperty
private String voipRegistrationId;
public String getApnRegistrationId() {
return apnRegistrationId;
}
public String getVoipRegistrationId() {
return voipRegistrationId;
}
}

View File

@@ -0,0 +1,70 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
public class OutgoingMessageEntity {
@JsonIgnore
private long id;
@JsonProperty
private int type;
@JsonProperty
private String relay;
@JsonProperty
private long timestamp;
@JsonProperty
private String source;
@JsonProperty
private int sourceDevice;
@JsonProperty
private byte[] message;
public OutgoingMessageEntity() {}
public OutgoingMessageEntity(long id, int type, String relay, long timestamp,
String source, int sourceDevice, byte[] message)
{
this.id = id;
this.type = type;
this.relay = relay;
this.timestamp = timestamp;
this.source = source;
this.sourceDevice = sourceDevice;
this.message = message;
}
public int getType() {
return type;
}
public String getRelay() {
return relay;
}
public long getTimestamp() {
return timestamp;
}
public String getSource() {
return source;
}
public int getSourceDevice() {
return sourceDevice;
}
public byte[] getMessage() {
return message;
}
public long getId() {
return id;
}
}

View File

@@ -0,0 +1,23 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import java.util.List;
public class OutgoingMessageEntityList {
@JsonProperty
private List<OutgoingMessageEntity> messages;
public OutgoingMessageEntityList() {}
public OutgoingMessageEntityList(List<OutgoingMessageEntity> messages) {
this.messages = messages;
}
@VisibleForTesting
public List<OutgoingMessageEntity> getMessages() {
return messages;
}
}

View File

@@ -11,6 +11,9 @@ public class UnregisteredEvent {
@NotEmpty
private String registrationId;
@JsonProperty
private String canonicalId;
@JsonProperty
@NotEmpty
private String number;
@@ -26,6 +29,10 @@ public class UnregisteredEvent {
return registrationId;
}
public String getCanonicalId() {
return canonicalId;
}
public String getNumber() {
return number;
}

View File

@@ -40,6 +40,6 @@ public class NonLimitedAccount extends Account {
@Override
public Optional<Device> getAuthenticatedDevice() {
return Optional.of(new Device(deviceId, null, null, null, null, null, false, 0, null, System.currentTimeMillis()));
return Optional.of(new Device(deviceId, null, null, null, null, null, null, false, 0, null, System.currentTimeMillis()));
}
}

View File

@@ -16,8 +16,14 @@
*/
package org.whispersystems.textsecuregcm.providers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
@@ -25,29 +31,40 @@ import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.Protocol;
public class RedisClientFactory {
public class RedisClientFactory implements RedisPubSubConnectionFactory {
private final Logger logger = LoggerFactory.getLogger(RedisClientFactory.class);
private final String host;
private final int port;
private final JedisPool jedisPool;
public RedisClientFactory(String url) throws URISyntaxException {
JedisPoolConfig poolConfig = new JedisPoolConfig();
poolConfig.setTestOnBorrow(true);
URI redisURI = new URI(url);
String redisHost = redisURI.getHost();
int redisPort = redisURI.getPort();
String redisPassword = null;
URI redisURI = new URI(url);
if (!Util.isEmpty(redisURI.getUserInfo())) {
redisPassword = redisURI.getUserInfo().split(":",2)[1];
}
this.jedisPool = new JedisPool(poolConfig, redisHost, redisPort,
Protocol.DEFAULT_TIMEOUT, redisPassword);
this.host = redisURI.getHost();
this.port = redisURI.getPort();
this.jedisPool = new JedisPool(poolConfig, host, port,
Protocol.DEFAULT_TIMEOUT, null);
}
public JedisPool getRedisClientPool() {
return jedisPool;
}
@Override
public PubSubConnection connect() {
while (true) {
try {
Socket socket = new Socket(host, port);
return new PubSubConnection(socket);
} catch (IOException e) {
logger.warn("Error connecting", e);
Util.sleep(200);
}
}
}
}

View File

@@ -76,7 +76,13 @@ public class FeedbackHandler implements Managed, Runnable {
event.getTimestamp() > device.get().getPushTimestamp())
{
logger.info("GCM Unregister Timestamp matches!");
device.get().setGcmId(null);
if (event.getCanonicalId() != null && !event.getCanonicalId().isEmpty()) {
logger.info("It's a canonical ID update...");
device.get().setGcmId(event.getCanonicalId());
} else {
device.get().setGcmId(null);
}
accountsManager.update(account.get());
}
}

View File

@@ -25,6 +25,7 @@ import org.whispersystems.textsecuregcm.entities.GcmMessage;
import org.whispersystems.textsecuregcm.push.WebsocketSender.DeliveryStatus;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Util;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
@@ -58,9 +59,8 @@ public class PushSender {
private void sendGcmMessage(Account account, Device device, OutgoingMessageSignal message)
throws TransientPushFailureException, NotPushRegisteredException
{
// if (device.getFetchesMessages()) sendNotificationGcmMessage(account, device, message);
// else sendPayloadGcmMessage(account, device, message);
sendPayloadGcmMessage(account, device, message);
if (device.getFetchesMessages()) sendNotificationGcmMessage(account, device, message);
else sendPayloadGcmMessage(account, device, message);
}
private void sendPayloadGcmMessage(Account account, Device device, OutgoingMessageSignal message)
@@ -91,10 +91,7 @@ public class PushSender {
(int)device.getId(), "", false, true);
pushServiceClient.send(gcmMessage);
} else {
logger.warn("Delivered!");
}
}
private void sendApnMessage(Account account, Device device, OutgoingMessageSignal outgoingMessage)
@@ -103,8 +100,12 @@ public class PushSender {
DeliveryStatus deliveryStatus = webSocketSender.sendMessage(account, device, outgoingMessage, WebsocketSender.Type.APN);
if (!deliveryStatus.isDelivered() && outgoingMessage.getType() != OutgoingMessageSignal.Type.RECEIPT_VALUE) {
ApnMessage apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), (int)device.getId(),
String.format(APN_PAYLOAD, deliveryStatus.getMessageQueueDepth()));
String apnId = Util.isEmpty(device.getVoipApnId()) ? device.getApnId() : device.getVoipApnId();
boolean isVoip = !Util.isEmpty(device.getVoipApnId());
ApnMessage apnMessage = new ApnMessage(apnId, account.getNumber(), (int)device.getId(),
String.format(APN_PAYLOAD, deliveryStatus.getMessageQueueDepth()),
isVoip);
pushServiceClient.send(apnMessage);
}
}

View File

@@ -0,0 +1,89 @@
package org.whispersystems.textsecuregcm.push;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.NoSuchPeerException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import java.io.IOException;
import java.util.Set;
public class ReceiptSender {
private final PushSender pushSender;
private final FederatedClientManager federatedClientManager;
private final AccountsManager accountManager;
public ReceiptSender(AccountsManager accountManager,
PushSender pushSender,
FederatedClientManager federatedClientManager)
{
this.federatedClientManager = federatedClientManager;
this.accountManager = accountManager;
this.pushSender = pushSender;
}
public void sendReceipt(Account source, String destination,
long messageId, Optional<String> relay)
throws IOException, NoSuchUserException,
NotPushRegisteredException, TransientPushFailureException
{
if (relay.isPresent() && !relay.get().isEmpty()) {
sendRelayedReceipt(source, destination, messageId, relay.get());
} else {
sendDirectReceipt(source, destination, messageId);
}
}
private void sendRelayedReceipt(Account source, String destination, long messageId, String relay)
throws NoSuchUserException, IOException
{
try {
federatedClientManager.getClient(relay)
.sendDeliveryReceipt(source.getNumber(),
source.getAuthenticatedDevice().get().getId(),
destination, messageId);
} catch (NoSuchPeerException e) {
throw new NoSuchUserException(e);
}
}
private void sendDirectReceipt(Account source, String destination, long messageId)
throws NotPushRegisteredException, TransientPushFailureException, NoSuchUserException
{
Account destinationAccount = getDestinationAccount(destination);
Set<Device> destinationDevices = destinationAccount.getDevices();
MessageProtos.OutgoingMessageSignal.Builder message =
MessageProtos.OutgoingMessageSignal.newBuilder()
.setSource(source.getNumber())
.setSourceDevice((int) source.getAuthenticatedDevice().get().getId())
.setTimestamp(messageId)
.setType(MessageProtos.OutgoingMessageSignal.Type.RECEIPT_VALUE);
if (source.getRelay().isPresent()) {
message.setRelay(source.getRelay().get());
}
for (Device destinationDevice : destinationDevices) {
pushSender.sendMessage(destinationAccount, destinationDevice, message.build());
}
}
private Account getDestinationAccount(String destination)
throws NoSuchUserException
{
Optional<Account> account = accountManager.get(destination);
if (!account.isPresent()) {
throw new NoSuchUserException(destination);
}
return account.get();
}
}

View File

@@ -40,7 +40,7 @@ public class TwilioSmsSender {
public static final String SAY_TWIML = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n" +
"<Response>\n" +
" <Say voice=\"woman\" language=\"en\">" + SmsSender.VOX_VERIFICATION_TEXT + "%s</Say>\n" +
" <Say voice=\"woman\" language=\"en\" loop=\"3\">" + SmsSender.VOX_VERIFICATION_TEXT + "%s.</Say>\n" +
"</Response>";
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);

View File

@@ -46,6 +46,9 @@ public class Device {
@JsonProperty
private String apnId;
@JsonProperty
private String voipApnId;
@JsonProperty
private long pushTimestamp;
@@ -65,8 +68,8 @@ public class Device {
public Device(long id, String authToken, String salt,
String signalingKey, String gcmId, String apnId,
boolean fetchesMessages, int registrationId,
SignedPreKey signedPreKey, long lastSeen)
String voipApnId, boolean fetchesMessages,
int registrationId, SignedPreKey signedPreKey, long lastSeen)
{
this.id = id;
this.authToken = authToken;
@@ -74,6 +77,7 @@ public class Device {
this.signalingKey = signalingKey;
this.gcmId = gcmId;
this.apnId = apnId;
this.voipApnId = voipApnId;
this.fetchesMessages = fetchesMessages;
this.registrationId = registrationId;
this.signedPreKey = signedPreKey;
@@ -92,6 +96,14 @@ public class Device {
}
}
public String getVoipApnId() {
return voipApnId;
}
public void setVoipApnId(String voipApnId) {
this.voipApnId = voipApnId;
}
public void setLastSeen(long lastSeen) {
this.lastSeen = lastSeen;
}

View File

@@ -1,6 +1,5 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import org.skife.jdbi.v2.SQLStatement;
import org.skife.jdbi.v2.StatementContext;
import org.skife.jdbi.v2.sqlobject.Bind;
@@ -12,7 +11,7 @@ import org.skife.jdbi.v2.sqlobject.SqlUpdate;
import org.skife.jdbi.v2.sqlobject.customizers.Mapper;
import org.skife.jdbi.v2.tweak.ResultSetMapper;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
@@ -44,11 +43,16 @@ public abstract class Messages {
@Mapper(MessageMapper.class)
@SqlQuery("SELECT * FROM messages WHERE " + DESTINATION + " = :destination AND " + DESTINATION_DEVICE + " = :destination_device ORDER BY " + TIMESTAMP + " ASC")
abstract List<Pair<Long, OutgoingMessageSignal>> load(@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice);
abstract List<OutgoingMessageEntity> load(@Bind("destination") String destination,
@Bind("destination_device") long destinationDevice);
@SqlUpdate("DELETE FROM messages WHERE " + ID + " = :id")
abstract void remove(@Bind("id") long id);
@Mapper(MessageMapper.class)
@SqlQuery("DELETE FROM messages WHERE " + ID + " IN (SELECT " + ID + " FROM messages WHERE " + DESTINATION + " = :destination AND " + SOURCE + " = :source AND " + TIMESTAMP + " = :timestamp ORDER BY " + ID + " LIMIT 1) RETURNING *")
abstract OutgoingMessageEntity remove(@Bind("destination") String destination, @Bind("source") String source, @Bind("timestamp") long timestamp);
@Mapper(MessageMapper.class)
@SqlUpdate("DELETE FROM messages WHERE " + ID + " = :id AND " + DESTINATION + " = :destination")
abstract void remove(@Bind("destination") String destination, @Bind("id") long id);
@SqlUpdate("DELETE FROM messages WHERE " + DESTINATION + " = :destination")
abstract void clear(@Bind("destination") String destination);
@@ -56,20 +60,18 @@ public abstract class Messages {
@SqlUpdate("VACUUM messages")
public abstract void vacuum();
public static class MessageMapper implements ResultSetMapper<Pair<Long, OutgoingMessageSignal>> {
public static class MessageMapper implements ResultSetMapper<OutgoingMessageEntity> {
@Override
public Pair<Long, OutgoingMessageSignal> map(int i, ResultSet resultSet, StatementContext statementContext)
public OutgoingMessageEntity map(int i, ResultSet resultSet, StatementContext statementContext)
throws SQLException
{
return new Pair<>(resultSet.getLong(ID),
OutgoingMessageSignal.newBuilder()
.setType(resultSet.getInt(TYPE))
.setRelay(resultSet.getString(RELAY))
.setTimestamp(resultSet.getLong(TIMESTAMP))
.setSource(resultSet.getString(SOURCE))
.setSourceDevice(resultSet.getInt(SOURCE_DEVICE))
.setMessage(ByteString.copyFrom(resultSet.getBytes(MESSAGE)))
.build());
return new OutgoingMessageEntity(resultSet.getLong(ID),
resultSet.getInt(TYPE),
resultSet.getString(RELAY),
resultSet.getLong(TIMESTAMP),
resultSet.getString(SOURCE),
resultSet.getInt(SOURCE_DEVICE),
resultSet.getBytes(MESSAGE));
}
}

View File

@@ -1,7 +1,9 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.base.Optional;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.util.Pair;
import java.util.List;
@@ -18,7 +20,7 @@ public class MessagesManager {
return this.messages.store(message, destination, destinationDevice) + 1;
}
public List<Pair<Long, OutgoingMessageSignal>> getMessagesForDevice(String destination, long destinationDevice) {
public List<OutgoingMessageEntity> getMessagesForDevice(String destination, long destinationDevice) {
return this.messages.load(destination, destinationDevice);
}
@@ -26,7 +28,11 @@ public class MessagesManager {
this.messages.clear(destination);
}
public void delete(long id) {
this.messages.remove(id);
public Optional<OutgoingMessageEntity> delete(String destination, String source, long timestamp) {
return Optional.fromNullable(this.messages.remove(destination, source, timestamp));
}
public void delete(String destination, long id) {
this.messages.remove(destination, id);
}
}

View File

@@ -1,9 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public interface PubSubListener {
public void onPubSubMessage(PubSubMessage outgoingMessage);
}

View File

@@ -1,129 +1,88 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import io.dropwizard.lifecycle.Managed;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import redis.clients.jedis.BinaryJedisPubSub;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
public class PubSubManager {
public class PubSubManager implements Managed {
private static final byte[] KEEPALIVE_CHANNEL = "KEEPALIVE".getBytes();
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<String, PubSubListener> listeners = new HashMap<>();
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final DispatchManager dispatchManager;
private final JedisPool jedisPool;
private final JedisPool jedisPool;
private boolean subscribed = false;
public PubSubManager(JedisPool jedisPool) {
this.jedisPool = jedisPool;
initializePubSubWorker();
waitForSubscription();
public PubSubManager(JedisPool jedisPool, DispatchManager dispatchManager) {
this.dispatchManager = dispatchManager;
this.jedisPool = jedisPool;
}
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
String serializedAddress = address.serialize();
@Override
public void start() throws Exception {
this.dispatchManager.start();
listeners.put(serializedAddress, listener);
baseListener.subscribe(serializedAddress.getBytes());
}
KeepaliveDispatchChannel keepaliveDispatchChannel = new KeepaliveDispatchChannel();
this.dispatchManager.subscribe(KEEPALIVE_CHANNEL, keepaliveDispatchChannel);
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
String serializedAddress = address.serialize();
if (listeners.get(serializedAddress) == listener) {
listeners.remove(serializedAddress);
baseListener.unsubscribe(serializedAddress.getBytes());
synchronized (this) {
while (!subscribed) wait(0);
}
new KeepaliveSender().start();
}
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
@Override
public void stop() throws Exception {
dispatchManager.shutdown();
}
public void subscribe(WebsocketAddress address, DispatchChannel channel) {
String serializedAddress = address.serialize();
dispatchManager.subscribe(serializedAddress, channel);
}
public void unsubscribe(WebsocketAddress address, DispatchChannel dispatchChannel) {
String serializedAddress = address.serialize();
dispatchManager.unsubscribe(serializedAddress, dispatchChannel);
}
public boolean hasLocalSubscription(WebsocketAddress address) {
return dispatchManager.hasSubscription(address.serialize());
}
public boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize().getBytes(), message);
}
private synchronized boolean publish(byte[] channel, PubSubMessage message) {
private boolean publish(byte[] channel, PubSubMessage message) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.publish(channel, message.toByteArray()) != 0;
}
}
private synchronized void waitForSubscription() {
try {
while (!subscribed) {
wait();
}
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
private void initializePubSubWorker() {
new Thread("PubSubListener") {
@Override
public void run() {
for (;;) {
try (Jedis jedis = jedisPool.getResource()) {
jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
logger.warn("**** Unsubscribed from holding channel!!! ******");
}
}
}
}.start();
new Thread("PubSubKeepAlive") {
@Override
public void run() {
for (;;) {
try {
Thread.sleep(20000);
publish(KEEPALIVE_CHANNEL, PubSubMessage.newBuilder()
.setType(PubSubMessage.Type.KEEPALIVE)
.build());
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
}
}.start();
}
private class SubscriptionListener extends BinaryJedisPubSub {
private class KeepaliveDispatchChannel implements DispatchChannel {
@Override
public void onMessage(byte[] channel, byte[] message) {
try {
PubSubListener listener;
synchronized (PubSubManager.this) {
listener = listeners.get(new String(channel));
}
if (listener != null) {
listener.onPubSubMessage(PubSubMessage.parseFrom(message));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Error parsing PubSub protobuf", e);
}
public void onDispatchMessage(String channel, byte[] message) {
// Good
}
@Override
public void onPMessage(byte[] s, byte[] s2, byte[] s3) {
logger.warn("Received PMessage!");
}
@Override
public void onSubscribe(byte[] channel, int count) {
if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) {
public void onDispatchSubscribed(String channel) {
if (KEEPALIVE_CHANNEL.equals(channel)) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
@@ -132,12 +91,24 @@ public class PubSubManager {
}
@Override
public void onUnsubscribe(byte[] s, int i) {}
public void onDispatchUnsubscribed(String channel) {
logger.warn("***** KEEPALIVE CHANNEL UNSUBSCRIBED *****");
}
}
private class KeepaliveSender extends Thread {
@Override
public void onPUnsubscribe(byte[] s, int i) {}
@Override
public void onPSubscribe(byte[] s, int i) {}
public void run() {
while (true) {
try {
Thread.sleep(20000);
publish(KEEPALIVE_CHANNEL.getBytes(), PubSubMessage.newBuilder()
.setType(PubSubMessage.Type.KEEPALIVE)
.build());
} catch (Throwable e) {
logger.warn("***** KEEPALIVE EXCEPTION ******", e);
}
}
}
}
}

View File

@@ -1,55 +1,66 @@
package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration"));
private final AccountsManager accountsManager;
private final PushSender pushSender;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final PubSubManager pubSubManager;
public AuthenticatedConnectListener(AccountsManager accountsManager, PushSender pushSender,
MessagesManager messagesManager, PubSubManager pubSubManager)
ReceiptSender receiptSender, MessagesManager messagesManager,
PubSubManager pubSubManager)
{
this.accountsManager = accountsManager;
this.pushSender = pushSender;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.pubSubManager = pubSubManager;
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
Account account = context.getAuthenticated(Account.class).get();
Device device = account.getAuthenticatedDevice().get();
final Account account = context.getAuthenticated(Account.class).get();
final Device device = account.getAuthenticatedDevice().get();
final long connectTime = System.currentTimeMillis();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
messagesManager, account, device,
context.getClient());
updateLastSeen(account, device);
closeExistingDeviceConnection(account, device);
final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
messagesManager, pubSubManager,
account, device,
context.getClient());
connection.onConnected();
pubSubManager.subscribe(address, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
connection.onConnectionLost();
pubSubManager.unsubscribe(address, connection);
durationHistogram.update(System.currentTimeMillis() - connectTime);
}
});
}
@@ -60,12 +71,5 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
accountsManager.update(account);
}
}
private void closeExistingDeviceConnection(Account account, Device device) {
pubSubManager.publish(new WebsocketAddress(account.getNumber(), device.getId()),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.CLOSE)
.build());
}
}

View File

@@ -0,0 +1,52 @@
package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public class DeadLetterHandler implements DispatchChannel {
private final Logger logger = LoggerFactory.getLogger(DeadLetterHandler.class);
private final MessagesManager messagesManager;
public DeadLetterHandler(MessagesManager messagesManager) {
this.messagesManager = messagesManager;
}
@Override
public void onDispatchMessage(String channel, byte[] data) {
try {
logger.warn("Handling dead letter to: " + channel);
WebsocketAddress address = new WebsocketAddress(channel);
PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data);
switch (pubSubMessage.getType().getNumber()) {
case PubSubMessage.Type.DELIVER_VALUE:
OutgoingMessageSignal message = OutgoingMessageSignal.parseFrom(pubSubMessage.getContent());
messagesManager.insert(address.getNumber(), address.getDeviceId(), message);
break;
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Bad pubsub message", e);
} catch (InvalidWebsocketAddressException e) {
logger.warn("Invalid websocket address", e);
}
}
@Override
public void onDispatchSubscribed(String channel) {
logger.warn("DeadLetterHandler subscription notice! " + channel);
}
@Override
public void onDispatchUnsubscribed(String channel) {
logger.warn("DeadLetterHandler unsubscribe notice! " + channel);
}
}

View File

@@ -7,15 +7,16 @@ import java.security.SecureRandom;
public class ProvisioningAddress extends WebsocketAddress {
private final String address;
public ProvisioningAddress(String address, int id) throws InvalidWebsocketAddressException {
super(address, id);
}
public ProvisioningAddress(String address) throws InvalidWebsocketAddressException {
super(address, 0);
this.address = address;
public ProvisioningAddress(String serialized) throws InvalidWebsocketAddressException {
super(serialized);
}
public String getAddress() {
return address;
return getNumber();
}
public static ProvisioningAddress generate() {
@@ -24,7 +25,7 @@ public class ProvisioningAddress extends WebsocketAddress {
SecureRandom.getInstance("SHA1PRNG").nextBytes(random);
return new ProvisioningAddress(Base64.encodeBytesWithoutPadding(random)
.replace('+', '-').replace('/', '_'));
.replace('+', '-').replace('/', '_'), 0);
} catch (NoSuchAlgorithmException | InvalidWebsocketAddressException e) {
throw new AssertionError(e);
}

View File

@@ -14,13 +14,15 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final ProvisioningConnection connection = new ProvisioningConnection(pubSubManager, context.getClient());
connection.onConnected();
final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
pubSubManager.subscribe(provisioningAddress, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
connection.onConnectionLost();
pubSubManager.unsubscribe(provisioningAddress, connection);
}
});
}

View File

@@ -4,58 +4,68 @@ import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
import org.whispersystems.textsecuregcm.storage.PubSubListener;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
public class ProvisioningConnection implements PubSubListener {
public class ProvisioningConnection implements DispatchChannel {
private final PubSubManager pubSubManager;
private final ProvisioningAddress provisioningAddress;
private final WebSocketClient client;
private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
public ProvisioningConnection(PubSubManager pubSubManager, WebSocketClient client) {
this.pubSubManager = pubSubManager;
this.client = client;
this.provisioningAddress = ProvisioningAddress.generate();
private final WebSocketClient client;
public ProvisioningConnection(WebSocketClient client) {
this.client = client;
}
@Override
public void onPubSubMessage(PubSubMessage outgoingMessage) {
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/v1/message", body);
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this);
client.close(1001, "All you get.");
}
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/v1/message", body);
@Override
public void onFailure(Throwable throwable) {
pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this);
client.close(1001, "That's all!");
}
});
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
client.close(1001, "All you get.");
}
@Override
public void onFailure(Throwable throwable) {
client.close(1001, "That's all!");
}
});
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Protobuf Error: ", e);
}
}
public void onConnected() {
this.pubSubManager.subscribe(provisioningAddress, this);
this.client.sendRequest("PUT", "/v1/address", Optional.of(ProvisioningUuid.newBuilder()
.setUuid(provisioningAddress.getAddress())
.build()
.toByteArray()));
@Override
public void onDispatchSubscribed(String channel) {
try {
ProvisioningAddress address = new ProvisioningAddress(channel);
this.client.sendRequest("PUT", "/v1/address", Optional.of(ProvisioningUuid.newBuilder()
.setUuid(address.getAddress())
.build()
.toByteArray()));
} catch (InvalidWebsocketAddressException e) {
logger.warn("Badly formatted address", e);
this.client.close(1001, "Server Error");
}
}
public void onConnectionLost() {
this.pubSubManager.unsubscribe(provisioningAddress, this);
this.client.close(1001, "Done");
@Override
public void onDispatchUnsubscribed(String channel) {
this.client.close(1001, "Closed");
}
}

View File

@@ -4,75 +4,66 @@ import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubListener;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.ws.rs.WebApplicationException;
import java.io.IOException;
import java.util.List;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public class WebSocketConnection implements PubSubListener {
public class WebSocketConnection implements DispatchChannel {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final AccountsManager accountsManager;
private final ReceiptSender receiptSender;
private final PushSender pushSender;
private final MessagesManager messagesManager;
private final PubSubManager pubSubManager;
private final Account account;
private final Device device;
private final WebsocketAddress address;
private final WebSocketClient client;
public WebSocketConnection(AccountsManager accountsManager,
PushSender pushSender,
public WebSocketConnection(PushSender pushSender,
ReceiptSender receiptSender,
MessagesManager messagesManager,
PubSubManager pubSubManager,
Account account,
Device device,
WebSocketClient client)
{
this.accountsManager = accountsManager;
this.pushSender = pushSender;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.pubSubManager = pubSubManager;
this.account = account;
this.device = device;
this.client = client;
this.address = new WebsocketAddress(account.getNumber(), device.getId());
}
public void onConnected() {
pubSubManager.subscribe(address, this);
processStoredMessages();
}
public void onConnectionLost() {
pubSubManager.unsubscribe(address, this);
}
@Override
public void onPubSubMessage(PubSubMessage pubSubMessage) {
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage pubSubMessage = PubSubMessage.parseFrom(message);
switch (pubSubMessage.getType().getNumber()) {
case PubSubMessage.Type.QUERY_DB_VALUE:
processStoredMessages();
@@ -80,10 +71,6 @@ public class WebSocketConnection implements PubSubListener {
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()), Optional.<Long>absent());
break;
case PubSubMessage.Type.CLOSE_VALUE:
client.close(1000, "OK");
pubSubManager.unsubscribe(address, this);
break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}
@@ -92,6 +79,15 @@ public class WebSocketConnection implements PubSubListener {
}
}
@Override
public void onDispatchUnsubscribed(String channel) {
client.close(1000, "OK");
}
public void onDispatchSubscribed(String channel) {
processStoredMessages();
}
private void sendMessage(final OutgoingMessageSignal message,
final Optional<Long> storedMessageId)
{
@@ -106,9 +102,9 @@ public class WebSocketConnection implements PubSubListener {
boolean isReceipt = message.getType() == OutgoingMessageSignal.Type.RECEIPT_VALUE;
if (isSuccessResponse(response)) {
if (storedMessageId.isPresent()) messagesManager.delete(storedMessageId.get());
if (storedMessageId.isPresent()) messagesManager.delete(account.getNumber(), storedMessageId.get());
if (!isReceipt) sendDeliveryReceiptFor(message);
} else if (!isSuccessResponse(response) & !storedMessageId.isPresent()) {
} else if (!isSuccessResponse(response) && !storedMessageId.isPresent()) {
requeueMessage(message);
}
}
@@ -138,34 +134,34 @@ public class WebSocketConnection implements PubSubListener {
private void sendDeliveryReceiptFor(OutgoingMessageSignal message) {
try {
Optional<Account> source = accountsManager.get(message.getSource());
if (!source.isPresent()) {
logger.warn("Source account disappeared? (%s)", message.getSource());
return;
}
OutgoingMessageSignal.Builder receipt =
OutgoingMessageSignal.newBuilder()
.setSource(account.getNumber())
.setSourceDevice((int) device.getId())
.setTimestamp(message.getTimestamp())
.setType(OutgoingMessageSignal.Type.RECEIPT_VALUE);
for (Device device : source.get().getDevices()) {
pushSender.sendMessage(source.get(), device, receipt.build());
}
} catch (NotPushRegisteredException | TransientPushFailureException e) {
logger.warn("sendDeliveryReceiptFor", "Delivery receipet", e);
receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp(),
message.hasRelay() ? Optional.of(message.getRelay()) :
Optional.<String>absent());
} catch (IOException | NoSuchUserException | TransientPushFailureException | NotPushRegisteredException e) {
logger.warn("sendDeliveryReceiptFor", e);
} catch (WebApplicationException e) {
logger.warn("Bad federated response for receipt: " + e.getResponse().getStatus());
}
}
private void processStoredMessages() {
List<Pair<Long, OutgoingMessageSignal>> messages = messagesManager.getMessagesForDevice(account.getNumber(),
device.getId());
List<OutgoingMessageEntity> messages = messagesManager.getMessagesForDevice(account.getNumber(), device.getId());
for (Pair<Long, OutgoingMessageSignal> message : messages) {
sendMessage(message.second(), Optional.of(message.first()));
for (OutgoingMessageEntity message : messages) {
OutgoingMessageSignal.Builder builder = OutgoingMessageSignal.newBuilder()
.setType(message.getType())
.setMessage(ByteString.copyFrom(message.getMessage()))
.setSourceDevice(message.getSourceDevice())
.setSource(message.getSource())
.setTimestamp(message.getTimestamp());
if (message.getRelay() != null && !message.getRelay().isEmpty()) {
builder.setRelay(message.getRelay());
}
sendMessage(builder.build(), Optional.of(message.getId()));
}
}
}

View File

@@ -10,6 +10,29 @@ public class WebsocketAddress {
this.deviceId = deviceId;
}
public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException {
try {
String[] parts = serialized.split(":", 2);
if (parts.length != 2) {
throw new InvalidWebsocketAddressException("Bad address: " + serialized);
}
this.number = parts[0];
this.deviceId = Long.parseLong(parts[1]);
} catch (NumberFormatException e) {
throw new InvalidWebsocketAddressException(e);
}
}
public String getNumber() {
return number;
}
public long getDeviceId() {
return deviceId;
}
public String serialize() {
return number + ":" + deviceId;
}

View File

@@ -0,0 +1,127 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExternalResource;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
public class DispatchManagerTest {
private PubSubConnection pubSubConnection;
private RedisPubSubConnectionFactory socketFactory;
private DispatchManager dispatchManager;
private PubSubReplyInputStream pubSubReplyInputStream;
@Rule
public ExternalResource resource = new ExternalResource() {
@Override
protected void before() throws Throwable {
pubSubConnection = mock(PubSubConnection.class );
socketFactory = mock(RedisPubSubConnectionFactory.class);
pubSubReplyInputStream = new PubSubReplyInputStream();
when(socketFactory.connect()).thenReturn(pubSubConnection);
when(pubSubConnection.read()).thenAnswer(new Answer<PubSubReply>() {
@Override
public PubSubReply answer(InvocationOnMock invocationOnMock) throws Throwable {
return pubSubReplyInputStream.read();
}
});
dispatchManager = new DispatchManager(socketFactory, Optional.<DispatchChannel>absent());
dispatchManager.start();
}
@Override
protected void after() {
}
};
@Test
public void testConnect() {
verify(socketFactory).connect();
}
@Test
public void testSubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
}
@Test
public void testSubscribeUnsubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
dispatchManager.unsubscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchUnsubscribed(eq("foo"));
}
@Test
public void testMessages() throws IOException {
DispatchChannel fooChannel = mock(DispatchChannel.class);
DispatchChannel barChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", fooChannel);
dispatchManager.subscribe("bar", barChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "bar", Optional.<byte[]>absent()));
verify(fooChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
verify(barChannel, timeout(1000)).onDispatchSubscribed(eq("bar"));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "foo", Optional.of("hello".getBytes())));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "bar", Optional.of("there".getBytes())));
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(fooChannel, timeout(1000)).onDispatchMessage(eq("foo"), captor.capture());
assertArrayEquals("hello".getBytes(), captor.getValue());
verify(barChannel, timeout(1000)).onDispatchMessage(eq("bar"), captor.capture());
assertArrayEquals("there".getBytes(), captor.getValue());
}
private static class PubSubReplyInputStream {
private final List<PubSubReply> pubSubReplyList = new LinkedList<>();
public synchronized PubSubReply read() {
try {
while (pubSubReplyList.isEmpty()) wait();
return pubSubReplyList.remove(0);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
public synchronized void write(PubSubReply pubSubReply) {
pubSubReplyList.add(pubSubReply);
notifyAll();
}
}
}

View File

@@ -0,0 +1,263 @@
package org.whispersystems.dispatch.redis;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.*;
public class PubSubConnectionTest {
private static final String REPLY = "*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"abcde\r\n" +
":1\r\n" +
"*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"fghij\r\n" +
":2\r\n" +
"*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"klmno\r\n" +
":2\r\n" +
"*3\r\n" +
"$7\r\n" +
"message\r\n" +
"$5\r\n" +
"abcde\r\n" +
"$10\r\n" +
"1234567890\r\n" +
"*3\r\n" +
"$7\r\n" +
"message\r\n" +
"$5\r\n" +
"klmno\r\n" +
"$10\r\n" +
"0987654321\r\n";
@Test
public void testSubscribe() throws IOException {
// ByteChannel byteChannel = mock(ByteChannel.class);
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
PubSubConnection connection = new PubSubConnection(socket);
connection.subscribe("foobar");
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(outputStream).write(captor.capture());
assertArrayEquals(captor.getValue(), "SUBSCRIBE foobar\r\n".getBytes());
}
@Test
public void testUnsubscribe() throws IOException {
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
PubSubConnection connection = new PubSubConnection(socket);
connection.unsubscribe("bazbar");
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(outputStream).write(captor.capture());
assertArrayEquals(captor.getValue(), "UNSUBSCRIBE bazbar\r\n".getBytes());
}
@Test
public void testTricklyResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new TrickleInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
@Test
public void testFullResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new FullInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
@Test
public void testRandomLengthResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new RandomInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
private InputStream mockInputStreamFor(final MockInputStream stub) throws IOException {
InputStream result = mock(InputStream.class);
when(result.read()).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
return stub.read();
}
});
when(result.read(any(byte[].class))).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
byte[] buffer = (byte[])invocationOnMock.getArguments()[0];
return stub.read(buffer, 0, buffer.length);
}
});
when(result.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
byte[] buffer = (byte[]) invocationOnMock.getArguments()[0];
int offset = (int) invocationOnMock.getArguments()[1];
int length = (int) invocationOnMock.getArguments()[2];
return stub.read(buffer, offset, length);
}
});
return result;
}
private void readResponses(PubSubConnection pubSubConnection) throws Exception {
PubSubReply reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "abcde");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "fghij");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "klmno");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
assertEquals(reply.getChannel(), "abcde");
assertArrayEquals(reply.getContent().get(), "1234567890".getBytes());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
assertEquals(reply.getChannel(), "klmno");
assertArrayEquals(reply.getContent().get(), "0987654321".getBytes());
}
private interface MockInputStream {
public int read();
public int read(byte[] input, int offset, int length);
}
private static class TrickleInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private TrickleInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
input[offset] = data[index++];
return 1;
}
}
private static class FullInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private FullInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
int amount = Math.min(data.length - index, length);
System.arraycopy(data, index, input, offset, amount);
index += length;
return amount;
}
}
private static class RandomInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private RandomInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
try {
int maxCopy = Math.min(data.length - index, length);
int randomCopy = SecureRandom.getInstance("SHA1PRNG").nextInt(maxCopy) + 1;
int copyAmount = Math.min(maxCopy, randomCopy);
System.arraycopy(data, index, input, offset, copyAmount);
index += copyAmount;
return copyAmount;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
}
}

View File

@@ -0,0 +1,51 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class ArrayReplyHeaderTest {
@Test(expected = IOException.class)
public void testNull() throws IOException {
new ArrayReplyHeader(null);
}
@Test(expected = IOException.class)
public void testBadPrefix() throws IOException {
new ArrayReplyHeader(":3");
}
@Test(expected = IOException.class)
public void testEmpty() throws IOException {
new ArrayReplyHeader("");
}
@Test(expected = IOException.class)
public void testTruncated() throws IOException {
new ArrayReplyHeader("*");
}
@Test(expected = IOException.class)
public void testBadNumber() throws IOException {
new ArrayReplyHeader("*ABC");
}
@Test
public void testValid() throws IOException {
assertEquals(4, new ArrayReplyHeader("*4").getElementCount());
}
}

View File

@@ -0,0 +1,36 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class IntReplyHeaderTest {
@Test(expected = IOException.class)
public void testNull() throws IOException {
new IntReply(null);
}
@Test(expected = IOException.class)
public void testEmpty() throws IOException {
new IntReply("");
}
@Test(expected = IOException.class)
public void testBadNumber() throws IOException {
new IntReply(":A");
}
@Test(expected = IOException.class)
public void testBadFormat() throws IOException {
new IntReply("*");
}
@Test
public void testValid() throws IOException {
assertEquals(23, new IntReply(":23").getValue());
}
}

View File

@@ -0,0 +1,47 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class StringReplyHeaderTest {
@Test
public void testNull() {
try {
new StringReplyHeader(null);
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testBadNumber() {
try {
new StringReplyHeader("$100A");
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testBadPrefix() {
try {
new StringReplyHeader("*");
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testValid() throws IOException {
assertEquals(1000, new StringReplyHeader("$1000").getStringLength());
}
}

View File

@@ -4,7 +4,6 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import org.hamcrest.CoreMatchers;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -21,15 +20,16 @@ import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import javax.ws.rs.core.MediaType;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import io.dropwizard.testing.junit.ResourceTestRule;
@@ -47,8 +47,10 @@ public class FederatedControllerTest {
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private PushSender pushSender = mock(PushSender.class );
private ReceiptSender receiptSender = mock(ReceiptSender.class);
private FederatedClientManager federatedClientManager = mock(FederatedClientManager.class);
private AccountsManager accountsManager = mock(AccountsManager.class );
private MessagesManager messagesManager = mock(MessagesManager.class);
private RateLimiters rateLimiters = mock(RateLimiters.class );
private RateLimiter rateLimiter = mock(RateLimiter.class );
@@ -57,7 +59,7 @@ public class FederatedControllerTest {
private final ObjectMapper mapper = new ObjectMapper();
private final MessageController messageController = new MessageController(rateLimiters, pushSender, accountsManager, federatedClientManager);
private final MessageController messageController = new MessageController(rateLimiters, pushSender, receiptSender, accountsManager, messagesManager, federatedClientManager);
private final KeysControllerV2 keysControllerV2 = mock(KeysControllerV2.class);
@Rule
@@ -72,12 +74,12 @@ public class FederatedControllerTest {
@Before
public void setup() throws Exception {
Set<Device> singleDeviceList = new HashSet<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111, null, System.currentTimeMillis()));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, null, false, 111, null, System.currentTimeMillis()));
}};
Set<Device> multiDeviceList = new HashSet<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222, null, System.currentTimeMillis()));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333, null, System.currentTimeMillis()));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, null, false, 222, null, System.currentTimeMillis()));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, null, false, 333, null, System.currentTimeMillis()));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);

View File

@@ -3,28 +3,34 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
import com.sun.jersey.api.client.ClientResponse;
import org.hamcrest.CoreMatchers;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import javax.ws.rs.core.MediaType;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
@@ -32,6 +38,7 @@ import io.dropwizard.testing.junit.ResourceTestRule;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
@@ -44,8 +51,10 @@ public class MessageControllerTest {
private static final String MULTI_DEVICE_RECIPIENT = "+14152222222";
private final PushSender pushSender = mock(PushSender.class );
private final ReceiptSender receiptSender = mock(ReceiptSender.class);
private final FederatedClientManager federatedClientManager = mock(FederatedClientManager.class);
private final AccountsManager accountsManager = mock(AccountsManager.class );
private final MessagesManager messagesManager = mock(MessagesManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class );
private final RateLimiter rateLimiter = mock(RateLimiter.class );
@@ -54,21 +63,21 @@ public class MessageControllerTest {
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthenticator())
.addResource(new MessageController(rateLimiters, pushSender, accountsManager,
federatedClientManager))
.addResource(new MessageController(rateLimiters, pushSender, receiptSender, accountsManager,
messagesManager, federatedClientManager))
.build();
@Before
public void setup() throws Exception {
Set<Device> singleDeviceList = new HashSet<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111, null, System.currentTimeMillis()));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, null, false, 111, null, System.currentTimeMillis()));
}};
Set<Device> multiDeviceList = new HashSet<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis()));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis()));
add(new Device(3, "foo", "bar", "baz", "isgcm", null, false, 444, null, System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, null, false, 222, new SignedPreKey(111, "foo", "bar"), System.currentTimeMillis()));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, null, false, 333, new SignedPreKey(222, "oof", "rab"), System.currentTimeMillis()));
add(new Device(3, "foo", "bar", "baz", "isgcm", null, null, false, 444, null, System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31)));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);
@@ -177,4 +186,75 @@ public class MessageControllerTest {
}
@Test
public synchronized void testGetMessages() throws Exception {
final long timestampOne = 313377;
final long timestampTwo = 313388;
List<OutgoingMessageEntity> messages = new LinkedList<OutgoingMessageEntity>() {{
add(new OutgoingMessageEntity(1L, MessageProtos.OutgoingMessageSignal.Type.CIPHERTEXT_VALUE, null, timestampOne, "+14152222222", 2, "hi there".getBytes()));
add(new OutgoingMessageEntity(2L, MessageProtos.OutgoingMessageSignal.Type.RECEIPT_VALUE, null, timestampTwo, "+14152222222", 2, null));
}};
when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(1L))).thenReturn(messages);
OutgoingMessageEntityList response =
resources.client().resource("/v1/messages/")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.accept(MediaType.APPLICATION_JSON_TYPE)
.get(OutgoingMessageEntityList.class);
assertEquals(response.getMessages().size(), 2);
assertEquals(response.getMessages().get(0).getId(), 0);
assertEquals(response.getMessages().get(1).getId(), 0);
assertEquals(response.getMessages().get(0).getTimestamp(), timestampOne);
assertEquals(response.getMessages().get(1).getTimestamp(), timestampTwo);
}
@Test
public synchronized void testDeleteMessages() throws Exception {
long timestamp = System.currentTimeMillis();
when(messagesManager.delete(AuthHelper.VALID_NUMBER, "+14152222222", 31337))
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L,
MessageProtos.OutgoingMessageSignal.Type.CIPHERTEXT_VALUE,
null, timestamp,
"+14152222222", 1, "hi".getBytes())));
when(messagesManager.delete(AuthHelper.VALID_NUMBER, "+14152222222", 31338))
.thenReturn(Optional.of(new OutgoingMessageEntity(31337L,
MessageProtos.OutgoingMessageSignal.Type.RECEIPT_VALUE,
null, System.currentTimeMillis(),
"+14152222222", 1, null)));
when(messagesManager.delete(AuthHelper.VALID_NUMBER, "+14152222222", 31339))
.thenReturn(Optional.<OutgoingMessageEntity>absent());
ClientResponse response = resources.client().resource(String.format("/v1/messages/%s/%d", "+14152222222", 31337))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.delete(ClientResponse.class);
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(any(Account.class), eq("+14152222222"), eq(timestamp), eq(Optional.<String>absent()));
response = resources.client().resource(String.format("/v1/messages/%s/%d", "+14152222222", 31338))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.delete(ClientResponse.class);
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verifyNoMoreInteractions(receiptSender);
response = resources.client().resource(String.format("/v1/messages/%s/%d", "+14152222222", 31339))
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.delete(ClientResponse.class);
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verifyNoMoreInteractions(receiptSender);
}
}

View File

@@ -6,28 +6,21 @@ import com.sun.jersey.api.client.ClientResponse;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.ReceiptController;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import io.dropwizard.testing.junit.ResourceTestRule;
import static org.fest.assertions.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.core.IsEqual.equalTo;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
@@ -40,23 +33,25 @@ public class ReceiptControllerTest {
private final FederatedClientManager federatedClientManager = mock(FederatedClientManager.class);
private final AccountsManager accountsManager = mock(AccountsManager.class );
private final ReceiptSender receiptSender = new ReceiptSender(accountsManager, pushSender, federatedClientManager);
private final ObjectMapper mapper = new ObjectMapper();
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
.addProvider(AuthHelper.getAuthenticator())
.addResource(new ReceiptController(accountsManager, federatedClientManager, pushSender))
.addResource(new ReceiptController(receiptSender))
.build();
@Before
public void setup() throws Exception {
Set<Device> singleDeviceList = new HashSet<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 111, null, System.currentTimeMillis()));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, null, false, 111, null, System.currentTimeMillis()));
}};
Set<Device> multiDeviceList = new HashSet<Device>() {{
add(new Device(1, "foo", "bar", "baz", "isgcm", null, false, 222, null, System.currentTimeMillis()));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, false, 333, null, System.currentTimeMillis()));
add(new Device(1, "foo", "bar", "baz", "isgcm", null, null, false, 222, null, System.currentTimeMillis()));
add(new Device(2, "foo", "bar", "baz", "isgcm", null, null, false, 333, null, System.currentTimeMillis()));
}};
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, false, singleDeviceList);

View File

@@ -9,7 +9,9 @@ import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -44,50 +46,26 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMe
public class WebSocketConnectionTest {
// private static final ObjectMapper mapper = new ObjectMapper();
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
// private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
@Test
public void testCloseExisting() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class );
WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
Account account = mock(Account.class );
Device device = mock(Device.class );
when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14157777777");
when(device.getId()).thenReturn(1L);
connectListener.onWebSocketConnect(sessionContext);
ArgumentCaptor<PubSubProtos.PubSubMessage> message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture());
assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE);
}
private static final ReceiptSender receiptSender = mock(ReceiptSender.class);
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, receiptSender, storedMessages, pubSubManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@@ -98,10 +76,6 @@ public class WebSocketConnectionTest {
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
// when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
//
// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
@@ -114,13 +88,6 @@ public class WebSocketConnectionTest {
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
//
// controller.onWebSocketConnect(session);
// verify(session, never()).close();
// verify(session, never()).close(any(CloseStatus.class));
// verify(session, never()).close(anyInt(), anyString());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
@@ -128,25 +95,16 @@ public class WebSocketConnectionTest {
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.isPresent());
// when(sessionContext.getAuthenticated(Account.class)).thenReturn(account);
//
// WebSocketClient client = mock(WebSocketClient.class);
// when(sessionContext.getClient()).thenReturn(client);
//
// connectListener.onWebSocketConnect(sessionContext);
//
// verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
// verify(client).close(eq(4001), anyString());
}
@Test
public void testOpen() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
List<Pair<Long, OutgoingMessageSignal>> outgoingMessages = new LinkedList<Pair<Long, OutgoingMessageSignal>> () {{
add(new Pair<>(1L, createMessage("sender1", 1111, false, "first")));
add(new Pair<>(2L, createMessage("sender1", 2222, false, "second")));
add(new Pair<>(3L, createMessage("sender2", 3333, false, "third")));
List<OutgoingMessageEntity> outgoingMessages = new LinkedList<OutgoingMessageEntity> () {{
add(createMessage(1L, "sender1", 1111, false, "first"));
add(createMessage(2L, "sender1", 2222, false, "second"));
add(createMessage(3L, "sender2", 3333, false, "third"));
}};
when(device.getId()).thenReturn(2L);
@@ -183,12 +141,11 @@ public class WebSocketConnectionTest {
}
});
WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages,
pubSubManager, account, device, client);
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
account, device, client);
connection.onConnected();
verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection)));
connection.onDispatchSubscribed(websocketAddress.serialize());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
assertTrue(futures.size() == 3);
@@ -200,26 +157,102 @@ public class WebSocketConnectionTest {
futures.get(0).setException(new IOException());
futures.get(2).setException(new IOException());
List<OutgoingMessageSignal> pending = new LinkedList<OutgoingMessageSignal>() {{
add(createMessage("sender1", 1111, false, "first"));
add(createMessage("sender2", 3333, false, "third"));
}};
verify(storedMessages, times(1)).delete(eq(account.getNumber()), eq(2L));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L), eq(Optional.<String>absent()));
// verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(OutgoingMessageSignal.class));
verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(OutgoingMessageSignal.class));
connection.onConnectionLost();
verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
}
private OutgoingMessageSignal createMessage(String sender, long timestamp, boolean receipt, String content) {
return OutgoingMessageSignal.newBuilder()
.setSource(sender)
.setSourceDevice(1)
.setType(receipt ? OutgoingMessageSignal.Type.RECEIPT_VALUE : OutgoingMessageSignal.Type.CIPHERTEXT_VALUE)
.setTimestamp(timestamp)
.setMessage(ByteString.copyFrom(content.getBytes()))
.build();
@Test
public void testOnlineSend() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
OutgoingMessageSignal firstMessage = OutgoingMessageSignal.newBuilder()
.setMessage(ByteString.copyFrom("first".getBytes()))
.setSource("sender1")
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(1)
.setType(OutgoingMessageSignal.Type.CIPHERTEXT_VALUE)
.build();
OutgoingMessageSignal secondMessage = OutgoingMessageSignal.newBuilder()
.setMessage(ByteString.copyFrom("second".getBytes()))
.setSource("sender2")
.setTimestamp(System.currentTimeMillis())
.setSourceDevice(2)
.setType(OutgoingMessageSignal.Type.CIPHERTEXT_VALUE)
.build();
List<OutgoingMessageEntity> pendingMessages = new LinkedList<>();
when(device.getId()).thenReturn(2L);
when(device.getSignalingKey()).thenReturn(Base64.encodeBytes(new byte[52]));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
final Device sender1device = mock(Device.class);
Set<Device> sender1devices = new HashSet<Device>() {{
add(sender1device);
}};
Account sender1 = mock(Account.class);
when(sender1.getDevices()).thenReturn(sender1devices);
when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.get("sender2")).thenReturn(Optional.<Account>absent());
when(storedMessages.getMessagesForDevice(account.getNumber(), device.getId()))
.thenReturn(pendingMessages);
final List<SettableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
final WebSocketClient client = mock(WebSocketClient.class);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class)))
.thenAnswer(new Answer<SettableFuture<WebSocketResponseMessage>>() {
@Override
public SettableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
SettableFuture<WebSocketResponseMessage> future = SettableFuture.create();
futures.add(future);
return future;
}
});
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
account, device, client);
connection.onDispatchSubscribed(websocketAddress.serialize());
connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
.setContent(ByteString.copyFrom(firstMessage.toByteArray()))
.build().toByteArray());
connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.DELIVER)
.setContent(ByteString.copyFrom(secondMessage.toByteArray()))
.build().toByteArray());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
assertEquals(futures.size(), 2);
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
futures.get(1).set(response);
futures.get(0).setException(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()), eq(Optional.<String>absent()));
verify(pushSender, times(1)).sendMessage(eq(account), eq(device), any(OutgoingMessageSignal.class));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
}
private OutgoingMessageEntity createMessage(long id, String sender, long timestamp, boolean receipt, String content) {
return new OutgoingMessageEntity(id, receipt ? OutgoingMessageSignal.Type.RECEIPT_VALUE : OutgoingMessageSignal.Type.CIPHERTEXT_VALUE,
null, timestamp, sender, 1, content.getBytes());
}
}