Compare commits

...

14 Commits

Author SHA1 Message Date
Erik Johnston
a36b38c3df Explicitly add Create event as auth event 2015-10-02 13:14:10 +01:00
Erik Johnston
0ef17169c4 Explicitly add Create event as auth event 2015-10-02 13:11:49 +01:00
Erik Johnston
9434ad729a Merge branch 'develop' of github.com:matrix-org/synapse into erikj/login_token 2015-10-01 09:21:50 +01:00
Erik Johnston
97b494a655 Don't enable client_addr by default 2015-09-28 18:46:50 +01:00
Erik Johnston
ce38f09ac3 Add missing file 2015-09-28 18:15:07 +01:00
Erik Johnston
6d0e02a140 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/login_token 2015-09-28 17:45:14 +01:00
Erik Johnston
64afabd0bf Move QR code to client API 2015-09-28 17:45:00 +01:00
Erik Johnston
24b8c58fb2 s/nonce/txn_id/ 2015-09-28 17:35:06 +01:00
Erik Johnston
448e525ed1 Unused import 2015-09-28 17:16:40 +01:00
Erik Johnston
3db9a4a26c s/nonce/txn_id/ 2015-09-28 16:43:35 +01:00
Erik Johnston
64abb765dd Needs to be dict, not string 2015-09-26 18:03:16 +01:00
Erik Johnston
1ca673a876 Support nonces 2015-09-26 17:38:40 +01:00
Erik Johnston
d01ef0c848 Return correct number of params 2015-09-25 11:21:34 +01:00
Erik Johnston
936cdac6aa Add support for logging in via token. Also add QR code to server up token. 2015-09-25 11:18:15 +01:00
6 changed files with 255 additions and 7 deletions

View File

@@ -117,6 +117,11 @@ class ServerConfig(Config):
self.content_addr = content_addr self.content_addr = content_addr
client_addr = config.get("client_addr")
if not client_addr:
client_addr = self.content_addr
self.client_addr = client_addr
def default_config(self, server_name, **kwargs): def default_config(self, server_name, **kwargs):
if ":" in server_name: if ":" in server_name:
bind_port = int(server_name.split(":")[1]) bind_port = int(server_name.split(":")[1])
@@ -140,6 +145,9 @@ class ServerConfig(Config):
# Whether to serve a web client from the HTTP/HTTPS root resource. # Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True web_client: True
# URL clients can use to talk to the server.
# client_addr: "https://%(server_name)s:%(bind_port)s"
# Set the soft limit on the number of file descriptors synapse can use # Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the # Zero is used to indicate synapse should set the soft limit to the
# hard limit. # hard limit.

View File

@@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@@ -33,6 +33,8 @@ import synapse.util.stringutils as stringutils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MACAROON_TYPE_LOGIN_TOKEN = "st_login"
class AuthHandler(BaseHandler): class AuthHandler(BaseHandler):
@@ -46,6 +48,22 @@ class AuthHandler(BaseHandler):
} }
self.sessions = {} self.sessions = {}
self._nonces = {}
self.clock.looping_call(self._prune_nonce, 60 * 1000)
def _prune_nonce(self):
now = self.clock.time_msec()
self._nonces = {
user_id: {
nonce: nonce_dict
for nonce, nonce_dict in user_dict.items()
if nonce_dict.get("expiry", 0) < now - 60 * 1000
}
for user_id, user_dict in self._nonces.items()
if user_dict
}
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
""" """
@@ -290,11 +308,105 @@ class AuthHandler(BaseHandler):
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash) self._check_password(user_id, password, password_hash)
res = yield self._issue_tokens(user_id)
defer.returnValue(res)
@defer.inlineCallbacks
def _issue_tokens(self, user_id):
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def do_short_term_token_login(self, token, user_id, txn_id):
macaroon_exact_caveats = [
"gen = 1",
"type = %s" % (MACAROON_TYPE_LOGIN_TOKEN,),
"user_id = %s" % (user_id,)
]
macaroon_general_caveats = [
self._verify_macaroon_expiry,
lambda c: self._verify_nonce(c, user_id, txn_id)
]
try:
macaroon = pymacaroons.Macaroon.deserialize(token)
v = pymacaroons.Verifier()
for exact_caveat in macaroon_exact_caveats:
v.satisfy_exact(exact_caveat)
for general_caveat in macaroon_general_caveats:
v.satisfy_general(general_caveat)
verified = v.verify(macaroon, self.hs.config.macaroon_secret_key)
if not verified:
raise LoginError(403, "Invalid token", errcode=Codes.FORBIDDEN)
user_id, access_token, refresh_token = yield self._issue_tokens(
user_id=user_id,
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
defer.returnValue(result)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError) as e:
logger.info("Invalid token: %s", e.message)
raise LoginError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _verify_macaroon_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
def _verify_nonce(self, caveat, user_id, txn_id):
prefix = "nonce = "
if not caveat.startswith(prefix):
return False
user_dict = self._nonces.get(user_id, {})
nonce = caveat[len(prefix):]
does_match = (
nonce in user_dict
and user_dict[nonce].get("txn_id", None) in (None, txn_id)
)
if does_match:
user_dict.setdefault(nonce, {})["txn_id"] = txn_id
return does_match
def make_short_term_token(self, user_id, nonce):
user_nonces = self._nonces.setdefault(user_id, {})
if user_nonces.get(nonce, {}).get("txn_id", None) is not None:
raise SynapseError(400, "nonce already used")
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = %s" % (MACAROON_TYPE_LOGIN_TOKEN,))
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
user_nonces[nonce] = {
"txn_id": None,
"expiry": expiry,
}
return macaroon.serialize()
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case

View File

@@ -150,7 +150,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids or e.type == EventTypes.Create
} }
event_infos.append({ event_infos.append({
"event": e, "event": e,
@@ -660,7 +660,7 @@ class FederationHandler(BaseHandler):
"event": e, "event": e,
"auth_events": { "auth_events": {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids or e.type == EventTypes.Create
} }
}) })
@@ -669,7 +669,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in event.auth_events] auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = { auth_events = {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids or e.type == EventTypes.Create
} }
_, event_stream_id, max_stream_id = yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -1166,7 +1166,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in remote_auth_chain (e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids or e.type == EventTypes.Create
} }
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
@@ -1284,6 +1284,7 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e (e.type, e.state_key): e
for e in result["auth_chain"] for e in result["auth_chain"]
if e.event_id in auth_ids if e.event_id in auth_ids
or event.type == EventTypes.Create
} }
ev.internal_metadata.outlier = True ev.internal_metadata.outlier = True

View File

@@ -15,7 +15,7 @@
from . import ( from . import (
room, events, register, login, profile, presence, initial_sync, directory, room, events, register, login, profile, presence, initial_sync, directory,
voip, admin, pusher, push_rule voip, admin, pusher, push_rule, login_qr
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@@ -42,3 +42,4 @@ class ClientV1RestResource(JsonResource):
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource)
login_qr.register_servlets(hs, client_resource)

View File

@@ -35,6 +35,7 @@ class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
TOKEN_TYPE = "m.login.token"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
@@ -42,7 +43,10 @@ class LoginRestServlet(ClientV1RestServlet):
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
def on_GET(self, request): def on_GET(self, request):
flows = [{"type": LoginRestServlet.PASS_TYPE}] flows = [
{"type": LoginRestServlet.PASS_TYPE},
{"type": LoginRestServlet.TOKEN_TYPE}
]
if self.saml2_enabled: if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE}) flows.append({"type": LoginRestServlet.SAML2_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
@@ -67,6 +71,15 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
auth_handler = self.handlers.auth_handler
token = login_submission["token"]
user_id = login_submission["user"]
txn_id = login_submission["txn_id"]
result = yield auth_handler.do_short_term_token_login(
token, user_id, txn_id
)
defer.returnValue((200, result))
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@@ -100,6 +113,15 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
def _verify_macaroon_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
class LoginFallbackRestServlet(ClientV1RestServlet): class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$") PATTERN = client_path_pattern("/login/fallback$")

View File

@@ -0,0 +1,104 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, threads
from synapse.api.errors import CodeMessageException
from synapse.util.stringutils import random_string
from base import ClientV1RestServlet, client_path_pattern
import simplejson
import logging
from unpaddedbase64 import encode_base64
from hashlib import sha256
from OpenSSL import crypto
logger = logging.getLogger(__name__)
class LoginQRResource(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/make_qr/(?P<nonce>[^/]*)$")
def __init__(self, hs):
super(LoginQRResource, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
self.config = hs.get_config()
@defer.inlineCallbacks
def on_GET(self, request, nonce):
try:
auth_user, _ = yield self.auth.get_user_by_req(request)
if not nonce:
nonce = random_string(10)
image = yield self.make_short_term_qr_code(
auth_user.to_string(), nonce
)
request.setHeader(b"Content-Type", b"image/png")
image.save(request)
request.finish()
except CodeMessageException as e:
logger.info("Returning: %s", e)
request.setResponseCode(e.code)
request.write("%s: %s" % (e.code, e.message))
request.finish()
except Exception:
logger.exception("Exception while generating token")
request.setResponseCode(500)
request.write("Internal server error")
request.finish()
@defer.inlineCallbacks
def make_short_term_qr_code(self, user_id, nonce):
h = self.handlers.auth_handler
token = h.make_short_term_token(user_id, nonce)
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.config.tls_certificate
)
sha256_fingerprint = sha256(x509_certificate_bytes).digest()
def gen():
import qrcode
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=5,
)
qr.add_data(simplejson.dumps({
"user_id": user_id,
"token": token,
"homeserver_url": self.config.client_addr,
"fingerprints": [{
"hash_type": "SHA256",
"bytes": encode_base64(sha256_fingerprint),
}],
}))
qr.make(fit=True)
return qr.make_image()
res = yield threads.deferToThread(gen)
defer.returnValue(res)
def register_servlets(hs, http_server):
LoginQRResource(hs).register(http_server)