mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-07 01:20:16 +00:00
Compare commits
3 Commits
anoa/updat
...
anoa/raise
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50b74bac93 | ||
|
|
d912558c4f | ||
|
|
4e333c310a |
@@ -188,6 +188,22 @@ class Auth(Protocol):
|
||||
request
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_ip_address_from_request(request: Request) -> str:
|
||||
"""
|
||||
Extract the IPv4 or IPv6 address from a client request.
|
||||
|
||||
Args:
|
||||
request: The request to process.
|
||||
|
||||
Returns:
|
||||
The IPv4 or IPv6 address of the client.
|
||||
|
||||
Raises:
|
||||
SynapseError: If an IP address could not be extracted from the
|
||||
request.
|
||||
"""
|
||||
|
||||
async def check_user_in_room_or_world_readable(
|
||||
self, room_id: str, requester: Requester, allow_departed_users: bool = False
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
|
||||
@@ -19,10 +19,12 @@
|
||||
#
|
||||
#
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from netaddr import IPAddress
|
||||
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse import event_auth
|
||||
@@ -31,6 +33,7 @@ from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
MissingClientTokenError,
|
||||
SynapseError,
|
||||
UnstableSpecAuthError,
|
||||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
@@ -291,6 +294,37 @@ class BaseAuth:
|
||||
|
||||
return query_params[0].decode("ascii")
|
||||
|
||||
@staticmethod
|
||||
def get_ip_address_from_request(request: Request) -> str:
|
||||
"""
|
||||
Extract the IPv4 or IPv6 address from a client request.
|
||||
|
||||
Args:
|
||||
request: The request to process.
|
||||
|
||||
Returns:
|
||||
The IPv4 or IPv6 address of the client.
|
||||
|
||||
Raises:
|
||||
SynapseError: If an IP address could not be extracted from the
|
||||
request.
|
||||
"""
|
||||
client_address = request.getClientAddress()
|
||||
if not isinstance(client_address, IPv4Address) and not isinstance(
|
||||
client_address, IPv6Address
|
||||
):
|
||||
logger.error(
|
||||
"Unable to view IP address of the requester. " \
|
||||
"Check that you are setting the X-Forwarded-For header correctly in your reverse proxy."
|
||||
)
|
||||
raise SynapseError(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
"Unable to read client IP address",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
return client_address.host
|
||||
|
||||
@cancellable
|
||||
async def get_appservice_user(
|
||||
self, request: Request, access_token: str
|
||||
@@ -326,7 +360,8 @@ class BaseAuth:
|
||||
return None
|
||||
|
||||
if app_service.ip_range_whitelist:
|
||||
ip_address = IPAddress(request.getClientAddress().host)
|
||||
ip_address_str = self.get_ip_address_from_request(request)
|
||||
ip_address = IPAddress(ip_address_str)
|
||||
if ip_address not in app_service.ip_range_whitelist:
|
||||
return None
|
||||
|
||||
|
||||
@@ -567,7 +567,7 @@ class AuthHandler:
|
||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||
|
||||
user_agent = get_request_user_agent(request)
|
||||
clientip = request.getClientAddress().host
|
||||
clientip = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||
session.session_id, user_agent, clientip
|
||||
|
||||
@@ -57,6 +57,7 @@ id_server_scheme = "https://"
|
||||
|
||||
class IdentityHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
# An HTTP client for contacting trusted URLs.
|
||||
self.http_client = SimpleHttpClient(hs)
|
||||
@@ -97,9 +98,8 @@ class IdentityHandler:
|
||||
address: The actual threepid ID, e.g. the phone number or email address
|
||||
"""
|
||||
|
||||
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||
None, (medium, request.getClientAddress().host)
|
||||
)
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
await self._3pid_validation_ratelimiter_ip.ratelimit(None, (medium, ip_address))
|
||||
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||
None, (medium, address)
|
||||
)
|
||||
|
||||
@@ -205,6 +205,7 @@ class SsoHandler:
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._error_template = hs.config.sso.sso_error_template
|
||||
@@ -505,12 +506,13 @@ class SsoHandler:
|
||||
auth_provider_session_id,
|
||||
)
|
||||
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
user_id = await self._register_mapped_user(
|
||||
attributes,
|
||||
auth_provider_id,
|
||||
remote_user_id,
|
||||
get_request_user_agent(request),
|
||||
request.getClientAddress().host,
|
||||
ip_address,
|
||||
)
|
||||
new_user = True
|
||||
elif self._sso_update_profile_information:
|
||||
@@ -1080,6 +1082,8 @@ class SsoHandler:
|
||||
if session.use_avatar:
|
||||
attributes.picture = session.avatar_url
|
||||
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
|
||||
# the following will raise a 400 error if the username has been taken in the
|
||||
# meantime.
|
||||
user_id = await self._register_mapped_user(
|
||||
@@ -1087,7 +1091,7 @@ class SsoHandler:
|
||||
session.auth_provider_id,
|
||||
session.remote_user_id,
|
||||
get_request_user_agent(request),
|
||||
request.getClientAddress().host,
|
||||
ip_address,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -134,6 +134,8 @@ class AuthRestServlet(RestServlet):
|
||||
if not session:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
if stagetype == LoginType.RECAPTCHA:
|
||||
response = parse_string(request, "g-recaptcha-response")
|
||||
|
||||
@@ -144,7 +146,9 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
|
||||
LoginType.RECAPTCHA,
|
||||
authdict,
|
||||
ip_address,
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
@@ -164,7 +168,9 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.TERMS, authdict, request.getClientAddress().host
|
||||
LoginType.TERMS,
|
||||
authdict,
|
||||
ip_address,
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
@@ -195,7 +201,7 @@ class AuthRestServlet(RestServlet):
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.REGISTRATION_TOKEN,
|
||||
authdict,
|
||||
request.getClientAddress().host,
|
||||
ip_address,
|
||||
)
|
||||
except LoginError as e:
|
||||
html = self.registration_token_template.render(
|
||||
|
||||
@@ -205,6 +205,7 @@ class LoginRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
request_info = request.request_info()
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||
@@ -224,9 +225,7 @@ class LoginRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
if appservice.is_rate_limited():
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
|
||||
result = await self._do_appservice_login(
|
||||
login_submission,
|
||||
@@ -238,27 +237,21 @@ class LoginRestServlet(RestServlet):
|
||||
self.jwt_enabled
|
||||
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
):
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
result = await self._do_jwt_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
result = await self._do_token_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
else:
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
result = await self._do_other_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
|
||||
@@ -192,7 +192,8 @@ class ThumbnailResource(RestServlet):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
remote_resp_function = (
|
||||
self.thumbnailer.select_or_generate_remote_thumbnail
|
||||
if self.dynamic_thumbnails
|
||||
@@ -263,7 +264,8 @@ class DownloadResource(RestServlet):
|
||||
request, media_id, file_name, max_timeout_ms
|
||||
)
|
||||
else:
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.media_repo.get_remote_media(
|
||||
request,
|
||||
server_name,
|
||||
|
||||
@@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self._auth = hs.get_auth()
|
||||
self.server_name = hs.hostname
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self.ratelimiter = FederationRateLimiter(
|
||||
@@ -361,7 +362,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||
if self.inhibit_user_in_use_error:
|
||||
return 200, {"available": True}
|
||||
|
||||
ip = request.getClientAddress().host
|
||||
ip = self._auth.get_ip_address_from_request(request)
|
||||
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
||||
await wait_deferred
|
||||
|
||||
@@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self._auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
@@ -403,7 +405,8 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
await self.ratelimiter.ratelimit(None, (ip_address,))
|
||||
|
||||
if not self.hs.config.registration.enable_registration:
|
||||
raise SynapseError(
|
||||
@@ -456,7 +459,7 @@ class RegisterRestServlet(RestServlet):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientAddress().host
|
||||
client_addr = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||
|
||||
@@ -916,7 +919,7 @@ class RegisterAppServiceOnlyRestServlet(RestServlet):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientAddress().host
|
||||
client_addr = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class DownloadResource(RestServlet):
|
||||
|
||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
self._auth = hs.get_auth()
|
||||
self.media_repo = media_repo
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
@@ -97,7 +98,7 @@ class DownloadResource(RestServlet):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
await self.media_repo.get_remote_media(
|
||||
request,
|
||||
server_name,
|
||||
|
||||
@@ -58,6 +58,7 @@ class ThumbnailResource(RestServlet):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
@@ -120,7 +121,7 @@ class ThumbnailResource(RestServlet):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
remote_resp_function = (
|
||||
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
||||
if self.dynamic_thumbnails
|
||||
|
||||
@@ -23,6 +23,7 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.api.auth.internal import InternalAuth
|
||||
@@ -118,7 +119,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
@@ -137,7 +138,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "192.168.10.10"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="192.168.10.10", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
@@ -156,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "131.111.8.42"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="131.111.8.42", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
f = self.get_failure(
|
||||
@@ -209,7 +210,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
@@ -231,7 +232,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
@@ -261,7 +262,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_device = AsyncMock(return_value={"hidden": False})
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||
@@ -296,7 +297,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.get_device = AsyncMock(return_value=None)
|
||||
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.args[b"user_id"] = [masquerading_user_id]
|
||||
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
|
||||
@@ -320,7 +321,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
|
||||
self.store.get_user_locked_status = AsyncMock(return_value=False)
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_success(self.auth.get_user_by_req(request))
|
||||
@@ -341,7 +342,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.store.insert_client_ip = AsyncMock(return_value=None)
|
||||
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
|
||||
request = Mock(args={})
|
||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
request.args[b"access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_success(self.auth.get_user_by_req(request))
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.handlers.cas import CasResponse
|
||||
@@ -234,6 +235,7 @@ def _mock_request() -> Mock:
|
||||
"write",
|
||||
]
|
||||
)
|
||||
mock.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
# `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
|
||||
mock._disconnected = False
|
||||
return mock
|
||||
|
||||
@@ -25,6 +25,7 @@ from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.handlers.sso import MappingException
|
||||
@@ -1684,5 +1685,5 @@ def _build_callback_request(
|
||||
request.args = {}
|
||||
request.args[b"code"] = [code.encode("utf-8")]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
request.getClientAddress.return_value.host = ip_address
|
||||
request.getClientAddress.return_value = IPv4Address(type="TCP", host=ip_address, port=12345)
|
||||
return request
|
||||
|
||||
@@ -24,6 +24,7 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
@@ -424,4 +425,5 @@ def _mock_request() -> Mock:
|
||||
)
|
||||
# `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
|
||||
mock._disconnected = False
|
||||
mock.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
|
||||
return mock
|
||||
|
||||
Reference in New Issue
Block a user