Compare commits

...

3 Commits

Author SHA1 Message Date
Andrew Morgan
50b74bac93 Update mocks in tests to work with new utility function 2025-10-01 17:28:05 +01:00
Andrew Morgan
d912558c4f Replace instances of request.getClientAddress with new method 2025-10-01 16:29:14 +01:00
Andrew Morgan
4e333c310a Add a new get_ip_address_from_request method
This method raises a `SynapseException` if Synapse is unable to extract the IP address of a client from an incoming request. This
typically indicates that there is an invalid configuration in one's reverse proxy.

Raise an exception rather than returning a dummy IP address, as it's typically better to fail loudly in this case.
2025-10-01 16:28:34 +01:00
15 changed files with 107 additions and 40 deletions

View File

@@ -188,6 +188,22 @@ class Auth(Protocol):
request
"""
@staticmethod
def get_ip_address_from_request(request: Request) -> str:
"""
Extract the IPv4 or IPv6 address from a client request.
Args:
request: The request to process.
Returns:
The IPv4 or IPv6 address of the client.
Raises:
SynapseError: If an IP address could not be extracted from the
request.
"""
async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]:

View File

@@ -19,10 +19,12 @@
#
#
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from netaddr import IPAddress
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.web.server import Request
from synapse import event_auth
@@ -31,6 +33,7 @@ from synapse.api.errors import (
AuthError,
Codes,
MissingClientTokenError,
SynapseError,
UnstableSpecAuthError,
)
from synapse.appservice import ApplicationService
@@ -291,6 +294,37 @@ class BaseAuth:
return query_params[0].decode("ascii")
@staticmethod
def get_ip_address_from_request(request: Request) -> str:
"""
Extract the IPv4 or IPv6 address from a client request.
Args:
request: The request to process.
Returns:
The IPv4 or IPv6 address of the client.
Raises:
SynapseError: If an IP address could not be extracted from the
request.
"""
client_address = request.getClientAddress()
if not isinstance(client_address, IPv4Address) and not isinstance(
client_address, IPv6Address
):
logger.error(
"Unable to view IP address of the requester. " \
"Check that you are setting the X-Forwarded-For header correctly in your reverse proxy."
)
raise SynapseError(
HTTPStatus.INTERNAL_SERVER_ERROR,
"Unable to read client IP address",
Codes.UNKNOWN,
)
return client_address.host
@cancellable
async def get_appservice_user(
self, request: Request, access_token: str
@@ -326,7 +360,8 @@ class BaseAuth:
return None
if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientAddress().host)
ip_address_str = self.get_ip_address_from_request(request)
ip_address = IPAddress(ip_address_str)
if ip_address not in app_service.ip_range_whitelist:
return None

View File

@@ -567,7 +567,7 @@ class AuthHandler:
await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = get_request_user_agent(request)
clientip = request.getClientAddress().host
clientip = self.auth.get_ip_address_from_request(request)
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip

View File

@@ -57,6 +57,7 @@ id_server_scheme = "https://"
class IdentityHandler:
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self.store = hs.get_datastores().main
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
@@ -97,9 +98,8 @@ class IdentityHandler:
address: The actual threepid ID, e.g. the phone number or email address
"""
await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientAddress().host)
)
ip_address = self._auth.get_ip_address_from_request(request)
await self._3pid_validation_ratelimiter_ip.ratelimit(None, (medium, ip_address))
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
)

View File

@@ -205,6 +205,7 @@ class SsoHandler:
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._registration_handler = hs.get_registration_handler()
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._error_template = hs.config.sso.sso_error_template
@@ -505,12 +506,13 @@ class SsoHandler:
auth_provider_session_id,
)
ip_address = self._auth.get_ip_address_from_request(request)
user_id = await self._register_mapped_user(
attributes,
auth_provider_id,
remote_user_id,
get_request_user_agent(request),
request.getClientAddress().host,
ip_address,
)
new_user = True
elif self._sso_update_profile_information:
@@ -1080,6 +1082,8 @@ class SsoHandler:
if session.use_avatar:
attributes.picture = session.avatar_url
ip_address = self._auth.get_ip_address_from_request(request)
# the following will raise a 400 error if the username has been taken in the
# meantime.
user_id = await self._register_mapped_user(
@@ -1087,7 +1091,7 @@ class SsoHandler:
session.auth_provider_id,
session.remote_user_id,
get_request_user_agent(request),
request.getClientAddress().host,
ip_address,
)
logger.info(

View File

@@ -134,6 +134,8 @@ class AuthRestServlet(RestServlet):
if not session:
raise SynapseError(400, "No session supplied")
ip_address = self.auth.get_ip_address_from_request(request)
if stagetype == LoginType.RECAPTCHA:
response = parse_string(request, "g-recaptcha-response")
@@ -144,7 +146,9 @@ class AuthRestServlet(RestServlet):
try:
await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
LoginType.RECAPTCHA,
authdict,
ip_address,
)
except LoginError as e:
# Authentication failed, let user try again
@@ -164,7 +168,9 @@ class AuthRestServlet(RestServlet):
try:
await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientAddress().host
LoginType.TERMS,
authdict,
ip_address,
)
except LoginError as e:
# Authentication failed, let user try again
@@ -195,7 +201,7 @@ class AuthRestServlet(RestServlet):
await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN,
authdict,
request.getClientAddress().host,
ip_address,
)
except LoginError as e:
html = self.registration_token_template.render(

View File

@@ -205,6 +205,7 @@ class LoginRestServlet(RestServlet):
)
request_info = request.request_info()
ip_address = self.auth.get_ip_address_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
@@ -224,9 +225,7 @@ class LoginRestServlet(RestServlet):
)
if appservice.is_rate_limited():
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_appservice_login(
login_submission,
@@ -238,27 +237,21 @@ class LoginRestServlet(RestServlet):
self.jwt_enabled
and login_submission["type"] == LoginRestServlet.JWT_TYPE
):
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,

View File

@@ -192,7 +192,8 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
ip_address = self.auth.get_ip_address_from_request(request)
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@@ -263,7 +264,8 @@ class DownloadResource(RestServlet):
request, media_id, file_name, max_timeout_ms
)
else:
ip_address = request.getClientAddress().host
ip_address = self.auth.get_ip_address_from_request(request)
await self.media_repo.get_remote_media(
request,
server_name,

View File

@@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self._auth = hs.get_auth()
self.server_name = hs.hostname
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
@@ -361,7 +362,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
if self.inhibit_user_in_use_error:
return 200, {"available": True}
ip = request.getClientAddress().host
ip = self._auth.get_ip_address_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self._auth = hs.get_auth()
self.store = hs.get_datastores().main
self.ratelimiter = Ratelimiter(
store=self.store,
@@ -403,7 +405,8 @@ class RegistrationTokenValidityRestServlet(RestServlet):
)
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
ip_address = self._auth.get_ip_address_from_request(request)
await self.ratelimiter.ratelimit(None, (ip_address,))
if not self.hs.config.registration.enable_registration:
raise SynapseError(
@@ -456,7 +459,7 @@ class RegisterRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientAddress().host
client_addr = self.auth.get_ip_address_from_request(request)
await self.ratelimiter.ratelimit(None, client_addr, update=False)
@@ -916,7 +919,7 @@ class RegisterAppServiceOnlyRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientAddress().host
client_addr = self.auth.get_ip_address_from_request(request)
await self.ratelimiter.ratelimit(None, client_addr, update=False)

View File

@@ -49,6 +49,7 @@ class DownloadResource(RestServlet):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self._auth = hs.get_auth()
self.media_repo = media_repo
self._is_mine_server_name = hs.is_mine_server_name
@@ -97,7 +98,7 @@ class DownloadResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
ip_address = self._auth.get_ip_address_from_request(request)
await self.media_repo.get_remote_media(
request,
server_name,

View File

@@ -58,6 +58,7 @@ class ThumbnailResource(RestServlet):
):
super().__init__()
self._auth = hs.get_auth()
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
@@ -120,7 +121,7 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
ip_address = self._auth.get_ip_address_from_request(request)
remote_resp_function = (
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails

View File

@@ -23,6 +23,7 @@ from unittest.mock import AsyncMock, Mock
import pymacaroons
from twisted.internet.address import IPv4Address
from twisted.internet.testing import MemoryReactor
from synapse.api.auth.internal import InternalAuth
@@ -118,7 +119,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
@@ -137,7 +138,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "192.168.10.10"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="192.168.10.10", port=12345)
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
@@ -156,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "131.111.8.42"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="131.111.8.42", port=12345)
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
f = self.get_failure(
@@ -209,7 +210,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -231,7 +232,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -261,7 +262,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_device = AsyncMock(return_value={"hidden": False})
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@@ -296,7 +297,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_device = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
@@ -320,7 +321,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = AsyncMock(return_value=False)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))
@@ -341,7 +342,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1"
request.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))

View File

@@ -21,6 +21,7 @@
from typing import Any, Dict
from unittest.mock import AsyncMock, Mock
from twisted.internet.address import IPv4Address
from twisted.internet.testing import MemoryReactor
from synapse.handlers.cas import CasResponse
@@ -234,6 +235,7 @@ def _mock_request() -> Mock:
"write",
]
)
mock.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
# `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
mock._disconnected = False
return mock

View File

@@ -25,6 +25,7 @@ from urllib.parse import parse_qs, urlparse
import pymacaroons
from twisted.internet.address import IPv4Address
from twisted.internet.testing import MemoryReactor
from synapse.handlers.sso import MappingException
@@ -1684,5 +1685,5 @@ def _build_callback_request(
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientAddress.return_value.host = ip_address
request.getClientAddress.return_value = IPv4Address(type="TCP", host=ip_address, port=12345)
return request

View File

@@ -24,6 +24,7 @@ from unittest.mock import AsyncMock, Mock
import attr
from twisted.internet.address import IPv4Address
from twisted.internet.testing import MemoryReactor
from synapse.api.errors import RedirectException
@@ -424,4 +425,5 @@ def _mock_request() -> Mock:
)
# `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
mock._disconnected = False
mock.getClientAddress.return_value = IPv4Address(type="TCP", host="127.0.0.1", port=12345)
return mock