Compare commits

...

5 Commits

Author | SHA1 | Date | Message

- Erik Johnston | e08187237e | 2025-07-10 12:03:19 +01:00
  Merge remote-tracking branch 'origin/develop' into erikj/ratelimit_media_upload
- Erik Johnston | 7f5452fc29 | 2025-06-09 11:12:56 +01:00
  Add documentation
- Erik Johnston | c0645486a3 | 2025-06-09 10:52:20 +01:00
  Newsfile
- Erik Johnston | eb13e9ead6 | 2025-06-09 10:52:19 +01:00
  Add ability to limit amount uploaded by a user
  You can now configure how much media can be uploaded by a user in a given time period.
- Erik Johnston | 447910df19 | 2025-06-09 10:52:19 +01:00
  Merge create_content and update_content
  This deduplicates a bunch of logic.
11 changed files with 292 additions and 73 deletions

View File

@@ -0,0 +1 @@
Add ability to limit amount uploaded by a user in a given time period.

View File

@@ -2086,6 +2086,23 @@ Example configuration:
max_upload_size: 60M
```
---
### `media_upload_limits`
*(array)* A list of media upload limits defining how much data a given user can upload in a given time period.
An empty list means no limits are applied.
Defaults to `[]`.
Example configuration:
```yaml
media_upload_limits:
  - time_period: 1h
    max_size: 100M
  - time_period: 1w
    max_size: 500M
```
---
### `max_image_pixels`

*(byte size)* Maximum number of pixels that will be thumbnailed. Defaults to `"32M"`.
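A minimal sketch of what a client sees once one of the configured limits is hit. `HOMESERVER` and `ACCESS_TOKEN` are placeholders, and the `M_RESOURCE_LIMIT_EXCEEDED` errcode is inferred from the `Codes.RESOURCE_LIMIT_EXCEEDED` raise in the media repository change later in this diff, not from these docs.

```python
# Sketch: repeatedly upload until the example 100M / 1h limit is exceeded.
# HOMESERVER and ACCESS_TOKEN are placeholders; each chunk stays under the
# per-file max_upload_size so only the cumulative limit is exercised.
import requests

HOMESERVER = "https://example.org"  # placeholder
ACCESS_TOKEN = "syt_..."            # placeholder

for i in range(4):
    resp = requests.post(
        f"{HOMESERVER}/_matrix/media/v3/upload",
        headers={
            "Authorization": f"Bearer {ACCESS_TOKEN}",
            "Content-Type": "application/octet-stream",
        },
        data=b"\0" * (40 * 1024 * 1024),  # 40 MiB per upload
    )
    if resp.status_code == 400:
        # Expected on the third upload, once the total passes 100 MiB.
        print(i, resp.json())  # errcode should be "M_RESOURCE_LIMIT_EXCEEDED"
        break
    print(i, resp.json()["content_uri"])
```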

View File

@@ -2335,6 +2335,30 @@ properties:
    default: 50M
    examples:
      - 60M
  media_upload_limits:
    type: array
    description: >-
      A list of media upload limits defining how much data a given user can
      upload in a given time period.

      An empty list means no limits are applied.
    default: []
    items:
      time_period:
        type: "#/$defs/duration"
        description: >-
          The time period over which the limit applies. Required.
      max_size:
        type: "#/$defs/bytes"
        description: >-
          Amount of data that can be uploaded in the time period by the user.
          Required.
    examples:
      - - time_period: 1h
          max_size: 100M
        - time_period: 1w
          max_size: 500M
  max_image_pixels:
    $ref: "#/$defs/bytes"
    description: Maximum number of pixels that will be thumbnailed.

View File

@@ -119,6 +119,15 @@ def parse_thumbnail_requirements(
    }


@attr.s(auto_attribs=True, slots=True, frozen=True)
class MediaUploadLimit:
    """A limit on the amount of data a user can upload in a given time
    period."""

    max_bytes: int
    time_period_ms: int


class ContentRepositoryConfig(Config):
    section = "media"
@@ -274,6 +283,13 @@ class ContentRepositoryConfig(Config):
        self.enable_authenticated_media = config.get("enable_authenticated_media", True)

        self.media_upload_limits: List[MediaUploadLimit] = []
        for limit_config in config.get("media_upload_limits", []):
            time_period_ms = self.parse_duration(limit_config["time_period"])
            max_bytes = self.parse_size(limit_config["max_size"])
            self.media_upload_limits.append(MediaUploadLimit(max_bytes, time_period_ms))

    def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
        assert data_dir_path is not None
        media_store = os.path.join(data_dir_path, "media_store")
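As a cross-check of the parsing above, a small self-contained sketch of what the config turns into. The duration/size helpers here are simplified stand-ins for Synapse's `parse_duration`/`parse_size` (assuming 1024-based size suffixes), not the real implementations.

```python
# Standalone sketch of media_upload_limits parsing; parse_duration_ms and
# parse_size_bytes are simplified stand-ins (assumption: the real Synapse
# helpers accept more formats than the single-suffix form handled here).
from typing import List

import attr

_DURATIONS_MS = {"s": 1_000, "m": 60_000, "h": 3_600_000, "d": 86_400_000, "w": 604_800_000}
_SIZES = {"K": 1024, "M": 1024**2, "G": 1024**3}


@attr.s(auto_attribs=True, slots=True, frozen=True)
class MediaUploadLimit:
    max_bytes: int
    time_period_ms: int


def parse_duration_ms(value: str) -> int:
    return int(value[:-1]) * _DURATIONS_MS[value[-1]]


def parse_size_bytes(value: str) -> int:
    return int(value[:-1]) * _SIZES[value[-1]]


def parse_media_upload_limits(config: dict) -> List[MediaUploadLimit]:
    limits = []
    for limit_config in config.get("media_upload_limits", []):
        limits.append(
            MediaUploadLimit(
                max_bytes=parse_size_bytes(limit_config["max_size"]),
                time_period_ms=parse_duration_ms(limit_config["time_period"]),
            )
        )
    return limits


print(parse_media_upload_limits(
    {"media_upload_limits": [{"time_period": "1h", "max_size": "100M"},
                             {"time_period": "1w", "max_size": "500M"}]}
))
# -> [MediaUploadLimit(max_bytes=104857600, time_period_ms=3600000),
#     MediaUploadLimit(max_bytes=524288000, time_period_ms=604800000)]
```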

View File

@@ -824,7 +824,7 @@ class SsoHandler:
            return True

        # store it in media repository
-        avatar_mxc_url = await self._media_repo.create_content(
+        avatar_mxc_url = await self._media_repo.create_or_update_content(
            media_type=headers[b"Content-Type"][0].decode("utf-8"),
            upload_name=upload_name,
            content=picture,

View File

@@ -177,6 +177,13 @@ class MediaRepository:
        else:
            self.url_previewer = None

        # We get the media upload limits and sort them in descending order of
        # time period, so that we can apply some optimizations.
        self.media_upload_limits = hs.config.media.media_upload_limits
        self.media_upload_limits.sort(
            key=lambda limit: limit.time_period_ms, reverse=True
        )

    def _start_update_recently_accessed(self) -> Deferred:
        return run_as_background_process(
            "update_recently_accessed_media", self._update_recently_accessed
@@ -285,63 +292,16 @@ class MediaRepository:
raise NotFoundError("Media ID has expired") raise NotFoundError("Media ID has expired")
@trace @trace
async def update_content( async def create_or_update_content(
self,
media_id: str,
media_type: str,
upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: UserID,
) -> None:
"""Update the content of the given media ID.
Args:
media_id: The media ID to replace.
media_type: The content type of the file.
upload_name: The name of the file, if provided.
content: A file like object that is the content to store
content_length: The length of the content
auth_user: The user_id of the uploader
"""
file_info = FileInfo(server_name=None, file_id=media_id)
sha256reader = SHA256TransparentIOReader(content)
# This implements all of IO as it has a passthrough
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
sha256 = sha256reader.hexdigest()
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)
if should_quarantine:
logger.warning(
"Media has been automatically quarantined as it matched existing quarantined media"
)
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
try:
await self._generate_thumbnails(None, media_id, media_id, media_type)
except Exception as e:
logger.info("Failed to generate thumbnails: %s", e)
@trace
async def create_content(
self, self,
media_type: str, media_type: str,
upload_name: Optional[str], upload_name: Optional[str],
content: IO, content: IO,
content_length: int, content_length: int,
auth_user: UserID, auth_user: UserID,
media_id: Optional[str] = None,
) -> MXCUri: ) -> MXCUri:
"""Store uploaded content for a local user and return the mxc URL """Create or update the content of the given media ID.
Args: Args:
media_type: The content type of the file. media_type: The content type of the file.
@@ -349,16 +309,20 @@ class MediaRepository:
            content: A file like object that is the content to store
            content_length: The length of the content
            auth_user: The user_id of the uploader
+            media_id: The media ID to update if provided, otherwise creates
+                new media ID.
        Returns:
            The mxc url of the stored content
        """
-        media_id = random_string(24)
+        is_new_media = media_id is None
+        if media_id is None:
+            media_id = random_string(24)
        file_info = FileInfo(server_name=None, file_id=media_id)
        sha256reader = SHA256TransparentIOReader(content)
        # This implements all of IO as it has a passthrough
        fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
        sha256 = sha256reader.hexdigest()
        should_quarantine = await self.store.get_is_hash_quarantined(sha256)
@@ -370,16 +334,56 @@ class MediaRepository:
"Media has been automatically quarantined as it matched existing quarantined media" "Media has been automatically quarantined as it matched existing quarantined media"
) )
await self.store.store_local_media( # Check that the user has not exceeded any of the media upload limits.
media_id=media_id,
media_type=media_type, # This is the total size of media uploaded by the user in the last
time_now_ms=self.clock.time_msec(), # `time_period_ms` milliseconds, or None if we haven't checked yet.
upload_name=upload_name, uploaded_media_size: Optional[int] = None
media_length=content_length,
user_id=auth_user, # Note: the media upload limits are sorted so larger time periods are
sha256=sha256, # first.
quarantined_by="system" if should_quarantine else None, for limit in self.media_upload_limits:
) # We only need to check the amount of media uploaded by the user in
# this latest (smaller) time period if the amount of media uploaded
# in a previous (larger) time period is above the limit.
#
# This optimization means that in the common case where the user
# hasn't uploaded much media, we only need to query the database
# once.
if (
uploaded_media_size is None
or uploaded_media_size + content_length > limit.max_bytes
):
uploaded_media_size = await self.store.get_media_uploaded_size_for_user(
user_id=auth_user.to_string(), time_period_ms=limit.time_period_ms
)
if uploaded_media_size + content_length > limit.max_bytes:
raise SynapseError(
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
)
if is_new_media:
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
else:
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
sha256=sha256,
quarantined_by="system" if should_quarantine else None,
)
try: try:
await self._generate_thumbnails(None, media_id, media_id, media_type) await self._generate_thumbnails(None, media_id, media_id, media_type)
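The loop above is the core of the feature. A standalone sketch of the same check, with a hypothetical `get_uploaded_size` callable standing in for the store method added further down in this diff:

```python
# Standalone sketch of the limit check above: limits are sorted with the
# largest time period first, so the common case needs only one query.
# `get_uploaded_size` is a hypothetical stand-in for
# MediaRepositoryStore.get_media_uploaded_size_for_user.
from typing import Callable, List, NamedTuple, Optional


class Limit(NamedTuple):
    max_bytes: int
    time_period_ms: int


class UploadLimitExceeded(Exception):
    pass


def check_upload_limits(
    limits: List[Limit],
    content_length: int,
    get_uploaded_size: Callable[[int], int],
) -> int:
    """Raise UploadLimitExceeded if adding content_length would break a limit.

    Returns the number of lookups that were needed.
    """
    limits = sorted(limits, key=lambda limit: limit.time_period_ms, reverse=True)
    uploaded: Optional[int] = None
    queries = 0
    for limit in limits:
        # Only re-query for the smaller window if the larger window's total
        # already looks like it might exceed this limit.
        if uploaded is None or uploaded + content_length > limit.max_bytes:
            uploaded = get_uploaded_size(limit.time_period_ms)
            queries += 1
        if uploaded + content_length > limit.max_bytes:
            raise UploadLimitExceeded()
    return queries


# Example: a quiet user only costs one lookup even with two limits configured.
usage = {604_800_000: 500, 86_400_000: 200}  # bytes uploaded per window
n = check_upload_limits(
    [Limit(3 * 1024, 604_800_000), Limit(1024, 86_400_000)],
    content_length=100,
    get_uploaded_size=lambda period_ms: usage[period_ms],
)
print(n)  # 1
```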

View File

@@ -120,7 +120,7 @@ class UploadServlet(BaseUploadServlet):
        try:
            content: IO = request.content  # type: ignore
-            content_uri = await self.media_repo.create_content(
+            content_uri = await self.media_repo.create_or_update_content(
                media_type, upload_name, content, content_length, requester.user
            )
        except SpamMediaException:
@@ -170,13 +170,13 @@ class AsyncUploadServlet(BaseUploadServlet):
        try:
            content: IO = request.content  # type: ignore
-            await self.media_repo.update_content(
-                media_id,
+            await self.media_repo.create_or_update_content(
                media_type,
                upload_name,
                content,
                content_length,
                requester.user,
+                media_id=media_id,
            )
        except SpamMediaException:
            # For uploading of media we want to respond with a 400, instead of
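Both servlets now funnel into `create_or_update_content`. A rough client-side sketch of the two flows they handle; the endpoint paths follow the async-upload flow (MSC2246) as I understand it, so treat the `/v1/create` and `PUT .../upload/{server}/{media_id}` paths as assumptions and check the spec for your server version.

```python
# Sketch of the two client flows that reach create_or_update_content:
# a one-shot upload (UploadServlet) and the async flow (AsyncUploadServlet),
# where a media ID is reserved first and the content is PUT later.
# HOMESERVER and the token are placeholders; paths are assumed per MSC2246.
import requests

HOMESERVER = "https://example.org"           # placeholder
TOKEN = {"Authorization": "Bearer syt_..."}  # placeholder

# One-shot upload -> UploadServlet -> create_or_update_content(media_id=None)
r = requests.post(
    f"{HOMESERVER}/_matrix/media/v3/upload",
    headers={**TOKEN, "Content-Type": "text/plain"},
    data=b"hello",
)
print(r.json())  # {"content_uri": "mxc://example.org/..."} on success

# Async upload: reserve an ID, then upload to it
# -> AsyncUploadServlet -> create_or_update_content(media_id=<reserved id>)
r = requests.post(f"{HOMESERVER}/_matrix/media/v1/create", headers=TOKEN)
mxc = r.json()["content_uri"]  # e.g. mxc://example.org/AbCdEf
server_name, media_id = mxc[len("mxc://"):].split("/", 1)
r = requests.put(
    f"{HOMESERVER}/_matrix/media/v3/upload/{server_name}/{media_id}",
    headers={**TOKEN, "Content-Type": "text/plain"},
    data=b"hello again",
)
print(r.status_code)  # 200, or 400 with M_RESOURCE_LIMIT_EXCEEDED if over a limit
```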

View File

@@ -1034,3 +1034,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"local_media_repository", "local_media_repository",
sha256, sha256,
) )
async def get_media_uploaded_size_for_user(
self, user_id: str, time_period_ms: int
) -> int:
"""Get the total size of media uploaded by a user in the last
time_period_ms milliseconds.
Args:
user_id: The user ID to check.
time_period_ms: The time period in milliseconds to consider.
Returns:
The total size of media uploaded by the user in bytes.
"""
sql = """
SELECT COALESCE(SUM(media_length), 0)
FROM local_media_repository
WHERE user_id = ? AND created_ts > ?
"""
def _get_media_uploaded_size_for_user_txn(
txn: LoggingTransaction,
) -> int:
# Calculate the timestamp for the start of the time period
start_ts = self._clock.time_msec() - time_period_ms
txn.execute(sql, (user_id, start_ts))
row = txn.fetchone()
if row is None:
return 0
return row[0]
return await self.db_pool.runInteraction(
"get_media_uploaded_size_for_user",
_get_media_uploaded_size_for_user_txn,
)
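A quick way to sanity-check the aggregation, using an in-memory SQLite table as a minimal stand-in for `local_media_repository` (only the three columns the query touches, not the real schema):

```python
# Minimal sqlite3 check of the SUM/COALESCE query used above.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_media_repository (user_id TEXT, media_length INTEGER, created_ts INTEGER)"
)
now_ms = 1_000_000_000
rows = [
    ("@user:test", 500, now_ms - 1_000),           # inside the 1-day window
    ("@user:test", 800, now_ms - 2 * 86_400_000),  # outside the 1-day window
    ("@other:test", 999, now_ms - 1_000),          # different user
]
conn.executemany("INSERT INTO local_media_repository VALUES (?, ?, ?)", rows)

start_ts = now_ms - 86_400_000  # one day ago
(total,) = conn.execute(
    """
    SELECT COALESCE(SUM(media_length), 0)
    FROM local_media_repository
    WHERE user_id = ? AND created_ts > ?
    """,
    ("@user:test", start_ts),
).fetchone()
print(total)  # 500
```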

View File

@@ -67,7 +67,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
    def test_file_download(self) -> None:
        content = io.BytesIO(b"file_to_stream")
        content_uri = self.get_success(
-            self.media_repo.create_content(
+            self.media_repo.create_or_update_content(
                "text/plain",
                "test_upload",
                content,
@@ -110,7 +110,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
        content = io.BytesIO(SMALL_PNG)
        content_uri = self.get_success(
-            self.media_repo.create_content(
+            self.media_repo.create_or_update_content(
                "image/png",
                "test_png_upload",
                content,
@@ -152,7 +152,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
content = io.BytesIO(b"file_to_stream") content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success( content_uri = self.get_success(
self.media_repo.create_content( self.media_repo.create_or_update_content(
"text/plain", "text/plain",
"test_upload", "test_upload",
content, content,
@@ -215,7 +215,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
    def test_thumbnail_download_scaled(self) -> None:
        content = io.BytesIO(small_png.data)
        content_uri = self.get_success(
-            self.media_repo.create_content(
+            self.media_repo.create_or_update_content(
                "image/png",
                "test_png_thumbnail",
                content,
@@ -255,7 +255,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
    def test_thumbnail_download_cropped(self) -> None:
        content = io.BytesIO(small_png.data)
        content_uri = self.get_success(
-            self.media_repo.create_content(
+            self.media_repo.create_or_update_content(
                "image/png",
                "test_png_thumbnail",
                content,

View File

@@ -78,7 +78,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
        # If the meda
        random_content = bytes(random_string(24), "utf-8")
        mxc_uri: MXCUri = self.get_success(
-            media_repository.create_content(
+            media_repository.create_or_update_content(
                media_type="text/plain",
                upload_name=None,
                content=io.BytesIO(random_content),

View File

@@ -1952,7 +1952,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
    def test_file_download(self) -> None:
        content = io.BytesIO(b"file_to_stream")
        content_uri = self.get_success(
-            self.repo.create_content(
+            self.repo.create_or_update_content(
                "text/plain",
                "test_upload",
                content,
@@ -2846,3 +2846,124 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
custom_headers=[("If-None-Match", etag)], custom_headers=[("If-None-Match", etag)],
) )
self.assertEqual(channel3.code, 404) self.assertEqual(channel3.code, 404)
class MediaUploadLimits(unittest.HomeserverTestCase):
    """
    This test case simulates a homeserver with media upload limits configured.
    """

    servlets = [
        media.register_servlets,
        login.register_servlets,
        admin.register_servlets,
    ]

    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
        config = self.default_config()

        self.storage_path = self.mktemp()
        self.media_store_path = self.mktemp()
        os.mkdir(self.storage_path)
        os.mkdir(self.media_store_path)
        config["media_store_path"] = self.media_store_path

        provider_config = {
            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
            "store_local": True,
            "store_synchronous": False,
            "store_remote": True,
            "config": {"directory": self.storage_path},
        }
        config["media_storage_providers"] = [provider_config]

        # These are the limits that we are testing
        config["media_upload_limits"] = [
            {"time_period": "1d", "max_size": "1K"},
            {"time_period": "1w", "max_size": "3K"},
        ]

        return self.setup_test_homeserver(config=config)

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.repo = hs.get_media_repository()
        self.client = hs.get_federation_http_client()
        self.store = hs.get_datastores().main
        self.user = self.register_user("user", "pass")
        self.tok = self.login("user", "pass")

    def create_resource_dict(self) -> Dict[str, Resource]:
        resources = super().create_resource_dict()
        resources["/_matrix/media"] = self.hs.get_media_repository_resource()
        return resources

    def upload_media(self, size: int) -> FakeChannel:
        """Helper to upload media of a given size."""
        return self.make_request(
            "POST",
            "/_matrix/media/v3/upload",
            content=b"0" * size,
            access_token=self.tok,
            shorthand=False,
            content_type=b"text/plain",
            custom_headers=[("Content-Length", str(size))],
        )

    def test_upload_under_limit(self) -> None:
        """Test that uploading media under the limit works."""
        channel = self.upload_media(67)
        self.assertEqual(channel.code, 200)

    def test_over_day_limit(self) -> None:
        """Test that uploading media over the daily limit fails."""
        channel = self.upload_media(500)
        self.assertEqual(channel.code, 200)

        channel = self.upload_media(800)
        self.assertEqual(channel.code, 400)

    def test_under_daily_limit(self) -> None:
        """Test that uploads succeed again once the daily limit has reset."""
        channel = self.upload_media(500)
        self.assertEqual(channel.code, 200)

        self.reactor.advance(60 * 60 * 24)  # Advance by one day

        # This will succeed as the daily limit has reset
        channel = self.upload_media(800)
        self.assertEqual(channel.code, 200)

        self.reactor.advance(60 * 60 * 24)  # Advance by one day

        # ... and again
        channel = self.upload_media(800)
        self.assertEqual(channel.code, 200)

    def test_over_weekly_limit(self) -> None:
        """Test that uploading media over the weekly limit fails."""
        channel = self.upload_media(900)
        self.assertEqual(channel.code, 200)

        self.reactor.advance(60 * 60 * 24)  # Advance by one day

        channel = self.upload_media(900)
        self.assertEqual(channel.code, 200)

        self.reactor.advance(2 * 60 * 60 * 24)  # Advance by two days

        channel = self.upload_media(900)
        self.assertEqual(channel.code, 200)

        self.reactor.advance(2 * 60 * 60 * 24)  # Advance by two days

        # This will fail as the weekly limit has been exceeded
        channel = self.upload_media(900)
        self.assertEqual(channel.code, 400)

        # Reset the weekly limit by advancing a week
        self.reactor.advance(7 * 60 * 60 * 24)  # Advance by 7 days

        # This will succeed as the weekly limit has reset
        channel = self.upload_media(900)
        self.assertEqual(channel.code, 200)
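For the record, the arithmetic behind the weekly-limit assertions above, assuming the 1024-based size parsing used elsewhere in this PR (so 1K = 1024 bytes and 3K = 3072 bytes):

```python
# Running totals for test_over_weekly_limit: one 900-byte upload per listed
# day, so the daily 1024-byte limit is never hit and only the 7-day window
# matters. All earlier uploads stay inside the window for every check.
weekly_limit = 3 * 1024  # 3072 bytes over any 7-day window
upload_days = [0, 1, 3, 5]

for i, day in enumerate(upload_days):
    already_uploaded = 900 * i
    total = already_uploaded + 900
    status = 200 if total <= weekly_limit else 400
    print(f"day {day}: {already_uploaded} + 900 = {total} -> {status}")
# day 0: 0 + 900 = 900 -> 200
# day 1: 900 + 900 = 1800 -> 200
# day 3: 1800 + 900 = 2700 -> 200
# day 5: 2700 + 900 = 3600 -> 400
```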