mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Compare commits
5 Commits
quenting/l
...
erikj/rate
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e08187237e | ||
|
|
7f5452fc29 | ||
|
|
c0645486a3 | ||
|
|
eb13e9ead6 | ||
|
|
447910df19 |
1
changelog.d/18527.feature
Normal file
1
changelog.d/18527.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add the ability to limit the amount of media a user can upload in a given time period.
|
||||
@@ -2086,6 +2086,23 @@ Example configuration:
|
||||
max_upload_size: 60M
|
||||
```
|
||||
---
|
||||
### `media_upload_limits`
|
||||
|
||||
*(array)* A list of media upload limits defining how much data a given user can upload in a given time period.
|
||||
|
||||
An empty list means no limits are applied.
|
||||
|
||||
Defaults to `[]`.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
media_upload_limits:
|
||||
- time_period: 1h
|
||||
max_size: 100M
|
||||
- time_period: 1w
|
||||
max_size: 500M
|
||||
```
|
||||
---
|
||||
### `max_image_pixels`
|
||||
|
||||
*(byte size)* Maximum number of pixels that will be thumbnailed. Defaults to `"32M"`.
|
||||
|
||||
@@ -2335,6 +2335,30 @@ properties:
|
||||
default: 50M
|
||||
examples:
|
||||
- 60M
|
||||
media_upload_limits:
|
||||
type: array
|
||||
description: >-
|
||||
A list of media upload limits defining how much data a given user can
|
||||
upload in a given time period.
|
||||
|
||||
|
||||
An empty list means no limits are applied.
|
||||
default: []
|
||||
items:
|
||||
time_period:
|
||||
type: "#/$defs/duration"
|
||||
description: >-
|
||||
The time period over which the limit applies. Required.
|
||||
max_size:
|
||||
type: "#/$defs/bytes"
|
||||
description: >-
|
||||
Amount of data that can be uploaded in the time period by the user.
|
||||
Required.
|
||||
examples:
|
||||
- - time_period: 1h
|
||||
max_size: 100M
|
||||
- time_period: 1w
|
||||
max_size: 500M
|
||||
max_image_pixels:
|
||||
$ref: "#/$defs/bytes"
|
||||
description: Maximum number of pixels that will be thumbnailed.
|
||||
|
||||
@@ -119,6 +119,15 @@ def parse_thumbnail_requirements(
|
||||
}
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||
class MediaUploadLimit:
|
||||
"""A limit on the amount of data a user can upload in a given time
|
||||
period."""
|
||||
|
||||
max_bytes: int
|
||||
time_period_ms: int
|
||||
|
||||
|
||||
class ContentRepositoryConfig(Config):
|
||||
section = "media"
|
||||
|
||||
@@ -274,6 +283,13 @@ class ContentRepositoryConfig(Config):
|
||||
|
||||
self.enable_authenticated_media = config.get("enable_authenticated_media", True)
|
||||
|
||||
self.media_upload_limits: List[MediaUploadLimit] = []
|
||||
for limit_config in config.get("media_upload_limits", []):
|
||||
time_period_ms = self.parse_duration(limit_config["time_period"])
|
||||
max_bytes = self.parse_size(limit_config["max_size"])
|
||||
|
||||
self.media_upload_limits.append(MediaUploadLimit(max_bytes, time_period_ms))
|
||||
|
||||
def generate_config_section(self, data_dir_path: str, **kwargs: Any) -> str:
|
||||
assert data_dir_path is not None
|
||||
media_store = os.path.join(data_dir_path, "media_store")
|
||||
|
||||
@@ -824,7 +824,7 @@ class SsoHandler:
|
||||
return True
|
||||
|
||||
# store it in media repository
|
||||
avatar_mxc_url = await self._media_repo.create_content(
|
||||
avatar_mxc_url = await self._media_repo.create_or_update_content(
|
||||
media_type=headers[b"Content-Type"][0].decode("utf-8"),
|
||||
upload_name=upload_name,
|
||||
content=picture,
|
||||
|
||||
@@ -177,6 +177,13 @@ class MediaRepository:
|
||||
else:
|
||||
self.url_previewer = None
|
||||
|
||||
# We get the media upload limits and sort them in descending order of
|
||||
# time period, so that we can apply some optimizations.
|
||||
self.media_upload_limits = hs.config.media.media_upload_limits
|
||||
self.media_upload_limits.sort(
|
||||
key=lambda limit: limit.time_period_ms, reverse=True
|
||||
)
|
||||
|
||||
def _start_update_recently_accessed(self) -> Deferred:
|
||||
return run_as_background_process(
|
||||
"update_recently_accessed_media", self._update_recently_accessed
|
||||
@@ -285,63 +292,16 @@ class MediaRepository:
|
||||
raise NotFoundError("Media ID has expired")
|
||||
|
||||
@trace
|
||||
async def update_content(
|
||||
self,
|
||||
media_id: str,
|
||||
media_type: str,
|
||||
upload_name: Optional[str],
|
||||
content: IO,
|
||||
content_length: int,
|
||||
auth_user: UserID,
|
||||
) -> None:
|
||||
"""Update the content of the given media ID.
|
||||
|
||||
Args:
|
||||
media_id: The media ID to replace.
|
||||
media_type: The content type of the file.
|
||||
upload_name: The name of the file, if provided.
|
||||
content: A file like object that is the content to store
|
||||
content_length: The length of the content
|
||||
auth_user: The user_id of the uploader
|
||||
"""
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
sha256reader = SHA256TransparentIOReader(content)
|
||||
# This implements all of IO as it has a passthrough
|
||||
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
|
||||
sha256 = sha256reader.hexdigest()
|
||||
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
if should_quarantine:
|
||||
logger.warning(
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
await self.store.update_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
except Exception as e:
|
||||
logger.info("Failed to generate thumbnails: %s", e)
|
||||
|
||||
@trace
|
||||
async def create_content(
|
||||
async def create_or_update_content(
|
||||
self,
|
||||
media_type: str,
|
||||
upload_name: Optional[str],
|
||||
content: IO,
|
||||
content_length: int,
|
||||
auth_user: UserID,
|
||||
media_id: Optional[str] = None,
|
||||
) -> MXCUri:
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
"""Create or update the content of the given media ID.
|
||||
|
||||
Args:
|
||||
media_type: The content type of the file.
|
||||
@@ -349,16 +309,20 @@ class MediaRepository:
|
||||
content: A file like object that is the content to store
|
||||
content_length: The length of the content
|
||||
auth_user: The user_id of the uploader
|
||||
media_id: The media ID to update if provided, otherwise creates
|
||||
new media ID.
|
||||
|
||||
Returns:
|
||||
The mxc url of the stored content
|
||||
"""
|
||||
|
||||
media_id = random_string(24)
|
||||
is_new_media = media_id is None
|
||||
if media_id is None:
|
||||
media_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
# This implements all of IO as it has a passthrough
|
||||
sha256reader = SHA256TransparentIOReader(content)
|
||||
# This implements all of IO as it has a passthrough
|
||||
fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
|
||||
sha256 = sha256reader.hexdigest()
|
||||
should_quarantine = await self.store.get_is_hash_quarantined(sha256)
|
||||
@@ -370,16 +334,56 @@ class MediaRepository:
|
||||
"Media has been automatically quarantined as it matched existing quarantined media"
|
||||
)
|
||||
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
# Check that the user has not exceeded any of the media upload limits.
|
||||
|
||||
# This is the total size of media uploaded by the user in the last
|
||||
# `time_period_ms` milliseconds, or None if we haven't checked yet.
|
||||
uploaded_media_size: Optional[int] = None
|
||||
|
||||
# Note: the media upload limits are sorted so larger time periods are
|
||||
# first.
|
||||
for limit in self.media_upload_limits:
|
||||
# We only need to check the amount of media uploaded by the user in
|
||||
# this latest (smaller) time period if the amount of media uploaded
|
||||
# in a previous (larger) time period is above the limit.
|
||||
#
|
||||
# This optimization means that in the common case where the user
|
||||
# hasn't uploaded much media, we only need to query the database
|
||||
# once.
|
||||
if (
|
||||
uploaded_media_size is None
|
||||
or uploaded_media_size + content_length > limit.max_bytes
|
||||
):
|
||||
uploaded_media_size = await self.store.get_media_uploaded_size_for_user(
|
||||
user_id=auth_user.to_string(), time_period_ms=limit.time_period_ms
|
||||
)
|
||||
|
||||
if uploaded_media_size + content_length > limit.max_bytes:
|
||||
raise SynapseError(
|
||||
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
|
||||
)
|
||||
|
||||
if is_new_media:
|
||||
await self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
else:
|
||||
await self.store.update_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
upload_name=upload_name,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
sha256=sha256,
|
||||
quarantined_by="system" if should_quarantine else None,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
|
||||
@@ -120,7 +120,7 @@ class UploadServlet(BaseUploadServlet):
|
||||
|
||||
try:
|
||||
content: IO = request.content # type: ignore
|
||||
content_uri = await self.media_repo.create_content(
|
||||
content_uri = await self.media_repo.create_or_update_content(
|
||||
media_type, upload_name, content, content_length, requester.user
|
||||
)
|
||||
except SpamMediaException:
|
||||
@@ -170,13 +170,13 @@ class AsyncUploadServlet(BaseUploadServlet):
|
||||
|
||||
try:
|
||||
content: IO = request.content # type: ignore
|
||||
await self.media_repo.update_content(
|
||||
media_id,
|
||||
await self.media_repo.create_or_update_content(
|
||||
media_type,
|
||||
upload_name,
|
||||
content,
|
||||
content_length,
|
||||
requester.user,
|
||||
media_id=media_id,
|
||||
)
|
||||
except SpamMediaException:
|
||||
# For uploading of media we want to respond with a 400, instead of
|
||||
|
||||
@@ -1034,3 +1034,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
"local_media_repository",
|
||||
sha256,
|
||||
)
|
||||
|
||||
async def get_media_uploaded_size_for_user(
|
||||
self, user_id: str, time_period_ms: int
|
||||
) -> int:
|
||||
"""Get the total size of media uploaded by a user in the last
|
||||
time_period_ms milliseconds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check.
|
||||
time_period_ms: The time period in milliseconds to consider.
|
||||
|
||||
Returns:
|
||||
The total size of media uploaded by the user in bytes.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT COALESCE(SUM(media_length), 0)
|
||||
FROM local_media_repository
|
||||
WHERE user_id = ? AND created_ts > ?
|
||||
"""
|
||||
|
||||
def _get_media_uploaded_size_for_user_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> int:
|
||||
# Calculate the timestamp for the start of the time period
|
||||
start_ts = self._clock.time_msec() - time_period_ms
|
||||
txn.execute(sql, (user_id, start_ts))
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
return 0
|
||||
return row[0]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_media_uploaded_size_for_user",
|
||||
_get_media_uploaded_size_for_user_txn,
|
||||
)
|
||||
|
||||
@@ -67,7 +67,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
def test_file_download(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
@@ -110,7 +110,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
content = io.BytesIO(SMALL_PNG)
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"image/png",
|
||||
"test_png_upload",
|
||||
content,
|
||||
@@ -152,7 +152,7 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
@@ -215,7 +215,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
|
||||
def test_thumbnail_download_scaled(self) -> None:
|
||||
content = io.BytesIO(small_png.data)
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"image/png",
|
||||
"test_png_thumbnail",
|
||||
content,
|
||||
@@ -255,7 +255,7 @@ class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
|
||||
def test_thumbnail_download_cropped(self) -> None:
|
||||
content = io.BytesIO(small_png.data)
|
||||
content_uri = self.get_success(
|
||||
self.media_repo.create_content(
|
||||
self.media_repo.create_or_update_content(
|
||||
"image/png",
|
||||
"test_png_thumbnail",
|
||||
content,
|
||||
|
||||
@@ -78,7 +78,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
|
||||
# If the media
|
||||
random_content = bytes(random_string(24), "utf-8")
|
||||
mxc_uri: MXCUri = self.get_success(
|
||||
media_repository.create_content(
|
||||
media_repository.create_or_update_content(
|
||||
media_type="text/plain",
|
||||
upload_name=None,
|
||||
content=io.BytesIO(random_content),
|
||||
|
||||
@@ -1952,7 +1952,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
|
||||
def test_file_download(self) -> None:
|
||||
content = io.BytesIO(b"file_to_stream")
|
||||
content_uri = self.get_success(
|
||||
self.repo.create_content(
|
||||
self.repo.create_or_update_content(
|
||||
"text/plain",
|
||||
"test_upload",
|
||||
content,
|
||||
@@ -2846,3 +2846,124 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
|
||||
custom_headers=[("If-None-Match", etag)],
|
||||
)
|
||||
self.assertEqual(channel3.code, 404)
|
||||
|
||||
|
||||
class MediaUploadLimits(unittest.HomeserverTestCase):
|
||||
"""
|
||||
This test case simulates a homeserver with media upload limits configured.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
media.register_servlets,
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
config["media_store_path"] = self.media_store_path
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
|
||||
# These are the limits that we are testing
|
||||
config["media_upload_limits"] = [
|
||||
{"time_period": "1d", "max_size": "1K"},
|
||||
{"time_period": "1w", "max_size": "3K"},
|
||||
]
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.repo = hs.get_media_repository()
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.store = hs.get_datastores().main
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
resources = super().create_resource_dict()
|
||||
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
|
||||
return resources
|
||||
|
||||
def upload_media(self, size: int) -> FakeChannel:
|
||||
"""Helper to upload media of a given size."""
|
||||
return self.make_request(
|
||||
"POST",
|
||||
"/_matrix/media/v3/upload",
|
||||
content=b"0" * size,
|
||||
access_token=self.tok,
|
||||
shorthand=False,
|
||||
content_type=b"text/plain",
|
||||
custom_headers=[("Content-Length", str(size))],
|
||||
)
|
||||
|
||||
def test_upload_under_limit(self) -> None:
|
||||
"""Test that uploading media under the limit works."""
|
||||
channel = self.upload_media(67)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
def test_over_day_limit(self) -> None:
|
||||
"""Test that uploading media over the daily limit fails."""
|
||||
channel = self.upload_media(500)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
channel = self.upload_media(800)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
def test_under_daily_limit(self) -> None:
|
||||
"""Test that uploads staying under the daily limit succeed once each daily window has elapsed."""
|
||||
channel = self.upload_media(500)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(60 * 60 * 24) # Advance by one day
|
||||
|
||||
# This will succeed as the daily limit has reset
|
||||
channel = self.upload_media(800)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(60 * 60 * 24) # Advance by one day
|
||||
|
||||
# ... and again
|
||||
channel = self.upload_media(800)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
def test_over_weekly_limit(self) -> None:
|
||||
"""Test that uploading media over the weekly limit fails."""
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(60 * 60 * 24) # Advance by one day
|
||||
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(2 * 60 * 60 * 24) # Advance by two days
|
||||
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.reactor.advance(2 * 60 * 60 * 24) # Advance by two days
|
||||
|
||||
# This will fail as the weekly limit has been exceeded
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
# Reset the weekly limit by advancing a week
|
||||
self.reactor.advance(7 * 60 * 60 * 24) # Advance by 7 days
|
||||
|
||||
# This will succeed as the weekly limit has reset
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
Reference in New Issue
Block a user