Compare commits

...

1 Commits

Author SHA1 Message Date
Erik Johnston
13da1dca0a Join policy server PoC
Adds the ability to enforce users to go through a web flow before being
able to join a room (unless invited).

Configured by specifying a "policy server" in the join rule event
content:

```
{
  "join_rule": "invite",
  "re.jki.join_policy_server": "localhost:8865"
}
```

The server will then return a 403 when a client tries to join, including
a URL that the client can redirect the user to, which eventually returns
a token (very much like an OAuth2 flow). This token then can be included
when calling `/join` again and the join will be successful.
2025-03-22 18:49:50 +00:00
6 changed files with 257 additions and 1 deletions

114
policy_server.py Normal file
View File

@@ -0,0 +1,114 @@
import secrets
import ssl
from dataclasses import dataclass, field
from aiohttp import web
from signedjson.key import decode_signing_key_base64
from signedjson.types import SigningKey
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import compute_event_signature
routes = web.RouteTableDef()
JOIN_FLOW_PAGE = """
<html>
<body>
<a href="/accept?redirect_url={redirect_url}" target="_self">Accept policy and join room</a>
</body>
</html>
"""
SIGNING_KEY = decode_signing_key_base64(
"ed25519", "p_afG2", "E+EmxfcqLYjlS20I5ZzjoYeN7oR9Qt/zitPGomU0hmA"
)
@dataclass
class PolicyServer:
server_name: str
signing_key: SigningKey
base_url: str
token_store: dict[str, str] = field(default_factory=dict)
@routes.get("/")
async def hello(request):
return web.Response(text="Hello, world")
@routes.post("/_matrix/federation/unstable/re.jki.join_policy/request_join")
async def request_join(request: web.Request) -> web.Response:
policy_server: PolicyServer = request.app["policy_server"]
return web.json_response({"url": policy_server.base_url + "/join_flow"})
@routes.post("/_matrix/federation/unstable/re.jki.join_policy/sign_join")
async def sign_join(request: web.Request) -> web.Response:
policy_server: PolicyServer = request.app["policy_server"]
json_body = await request.json()
if json_body["token"] not in policy_server.token_store:
return web.json_response({}, status=403)
room_version_id = json_body["room_version"]
event_json = json_body["event"]
room_version = KNOWN_ROOM_VERSIONS[room_version_id]
signatures = compute_event_signature(
room_version=room_version,
event_dict=event_json,
signature_name=policy_server.server_name,
signing_key=policy_server.signing_key,
)
return web.json_response({"signatures": signatures[policy_server.server_name]})
@routes.get("/join_flow")
async def join_flow(request: web.Request) -> web.Response:
redirect_url = request.query["redirect_url"]
return web.Response(
text=JOIN_FLOW_PAGE.format(redirect_url=redirect_url), content_type="text/html"
)
@routes.get("/accept")
async def accept(request: web.Request) -> web.Response:
policy_server: PolicyServer = request.app["policy_server"]
redirect_url = request.query["redirect_url"]
token = secrets.token_hex(16)
policy_server.token_store[token] = "user_id"
# TODO: Use less dodgy URL creation
if "?" in redirect_url:
redirect_url += f"&token={token}"
else:
redirect_url += f"?token={token}"
return web.Response(
text="Done!",
status=307,
headers={"location": redirect_url},
)
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(
certfile="/home/erikj/git/synapse/demo/8080/localhost:8080.tls.crt",
keyfile="/home/erikj/git/synapse/demo/8080/localhost:8080.tls.key",
)
app = web.Application()
app["policy_server"] = PolicyServer(
server_name="localhost:8865",
signing_key=SIGNING_KEY,
base_url="https://localhost:8865",
)
app.add_routes(routes)
web.run_app(app, port=8865, ssl_context=context)

View File

@@ -552,11 +552,23 @@ def _is_membership_change_allowed(
key = (EventTypes.JoinRules, "")
join_rule_event = auth_events.get(key)
join_policy_server: Optional[str] = None
if join_rule_event:
join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE)
join_policy_server = join_rule_event.content.get("re.jki.join_policy_server")
else:
join_rule = JoinRules.INVITE
if (
join_policy_server
and membership == Membership.JOIN
and not (caller_in_room or caller_invited)
):
logger.info("Checking sigs")
if not event.signatures.get(join_policy_server):
raise AuthError(403, "Not signed by join policy server")
caller_invited = True
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(target_user_id, auth_events)

View File

@@ -1960,6 +1960,43 @@ class FederationClient(FederationBase):
ip_address=ip_address,
)
async def join_policy_server_get_url(
self, policy_server: str, room_id: str, room_version: RoomVersion, user_id: str
) -> Optional[str]:
result = await self.transport_layer.join_policy_server_get_url(
policy_server=policy_server,
room_id=room_id,
room_version=room_version,
user_id=user_id,
)
url = result.get("url")
if isinstance(url, str):
return url
return None
async def join_policy_server_sign_join(
self,
policy_server: str,
room_id: str,
user_id: str,
token: str,
room_version: RoomVersion,
event: EventBase,
) -> None:
result = await self.transport_layer.join_policy_server_sign_join(
policy_server=policy_server,
room_id=room_id,
user_id=user_id,
token=token,
room_version=room_version,
event=event,
)
signatures = result.get("signatures")
if signatures:
event.signatures[policy_server] = signatures
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse:

View File

@@ -894,6 +894,46 @@ class TransportLayerClient:
ip_address=ip_address,
)
async def join_policy_server_get_url(
self, policy_server: str, room_id: str, room_version: RoomVersion, user_id: str
) -> JsonDict:
path = _create_path(
FEDERATION_UNSTABLE_PREFIX, "/re.jki.join_policy/request_join"
)
return await self.client.post_json(
policy_server,
path,
data={
"room_id": room_id,
"room_version": room_version.identifier,
"user_id": user_id,
},
ignore_backoff=True,
)
async def join_policy_server_sign_join(
self,
policy_server: str,
room_id: str,
user_id: str,
token: str,
room_version: RoomVersion,
event: EventBase,
) -> JsonDict:
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/re.jki.join_policy/sign_join")
return await self.client.post_json(
policy_server,
path,
data={
"room_id": room_id,
"user_id": user_id,
"token": token,
"room_version": room_version.identifier,
"event": event.get_pdu_json(),
},
ignore_backoff=True,
)
def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""

View File

@@ -398,6 +398,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
require_consent: bool = True,
outlier: bool = False,
origin_server_ts: Optional[int] = None,
join_policy_token: Optional[str] = None,
) -> Tuple[str, int]:
"""
Internal membership update function to get an existing event or create
@@ -491,9 +492,49 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, user_id)])
StateFilter.from_types(
[(EventTypes.Member, user_id), (EventTypes.JoinRules, "")]
)
)
if membership == Membership.JOIN:
join_rule_id = prev_state_ids.get((EventTypes.JoinRules, ""))
if join_rule_id is not None:
join_rule_event = await self.store.get_event(
join_rule_id, allow_none=True
)
if join_rule_event:
join_policy_server = join_rule_event.content.get(
"re.jki.join_policy_server"
)
if isinstance(join_policy_server, str):
if join_policy_token is None:
policy_url = await self.federation_handler.federation_client.join_policy_server_get_url(
policy_server=join_policy_server,
room_id=room_id,
room_version=event.room_version,
user_id=target.to_string(),
)
if policy_url is not None:
raise SynapseError(
403,
"Cannot join room",
errcode="RE_JKI_JOIN_POLICY_URL",
additional_fields={
"re.jki.join_policy_url": policy_url
},
)
else:
await self.federation_handler.federation_client.join_policy_server_sign_join(
policy_server=join_policy_server,
room_id=room_id,
room_version=event.room_version,
user_id=target.to_string(),
token=join_policy_token,
event=event,
)
prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, user_id), None
)
@@ -584,6 +625,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
origin_server_ts: Optional[int] = None,
join_policy_token: Optional[str] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -681,6 +723,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
state_event_ids=state_event_ids,
depth=depth,
origin_server_ts=origin_server_ts,
join_policy_token=join_policy_token,
)
return result
@@ -704,6 +747,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
origin_server_ts: Optional[int] = None,
join_policy_token: Optional[str] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -929,6 +973,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
require_consent=require_consent,
outlier=outlier,
origin_server_ts=origin_server_ts,
join_policy_token=join_policy_token,
)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
@@ -1188,6 +1233,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
require_consent=require_consent,
outlier=outlier,
origin_server_ts=origin_server_ts,
join_policy_token=join_policy_token,
)
async def check_for_any_membership_in_room(

View File

@@ -528,6 +528,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
remote_room_hosts,
)
join_policy_token = parse_string(
request, "re.jki.join_policy_token", required=False
)
logger.info("re.jki.join_policy_token: %s", join_policy_token)
await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
@@ -537,6 +543,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
remote_room_hosts=remote_room_hosts,
content=content,
third_party_signed=content.get("third_party_signed", None),
join_policy_token=join_policy_token,
)
return 200, {"room_id": room_id}