Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Do not allow cross-room relations, per MSC2674. #11516

Merged
merged 7 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11516.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event.
11 changes: 7 additions & 4 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,23 +454,26 @@ async def _injected_bundled_aggregations(
return

event_id = event.event_id
room_id = event.room_id

# The bundled aggregations to include.
aggregations = {}

annotations = await self.store.get_aggregation_groups_for_event(event_id)
annotations = await self.store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()

references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()

edit = None
if event.type == EventTypes.Message:
edit = await self.store.get_applicable_edit(event_id)
edit = await self.store.get_applicable_edit(event_id, room_id)

if edit:
# If there is an edit replace the content, preserving existing
Expand Down Expand Up @@ -503,7 +506,7 @@ async def _injected_bundled_aggregations(
(
thread_count,
latest_thread_event,
) = await self.store.get_thread_summary(event_id)
) = await self.store.get_thread_summary(event_id, room_id)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
Expand Down
7 changes: 6 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ async def on_GET(

pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
Expand Down Expand Up @@ -317,6 +318,7 @@ async def on_GET(

pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
room_id=room_id,
event_type=event_type,
limit=limit,
from_token=from_token,
Expand Down Expand Up @@ -383,7 +385,9 @@ async def on_GET(

# This checks that a) the event exists and b) the user is allowed to
# view it.
await self.event_handler.get_event(requester.user, room_id, parent_id)
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")

if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
Expand All @@ -402,6 +406,7 @@ async def on_GET(

result = await self.store.get_relations_for_event(
event_id=parent_id,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,10 +1780,14 @@ def _handle_event_relations(
)

if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
txn.call_after(
self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
)

if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)

def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
Expand Down
36 changes: 26 additions & 10 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_relations_for_event(
self,
event_id: str,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
Expand All @@ -49,6 +50,7 @@ async def get_relations_for_event(

Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
Expand All @@ -63,8 +65,8 @@ async def get_relations_for_event(
the form `{"event_id": "..."}`.
"""

where_clause = ["relates_to_id = ?"]
where_args: List[Union[str, int]] = [event_id]
where_clause = ["relates_to_id = ?", "room_id = ?"]
where_args: List[Union[str, int]] = [event_id, room_id]

if relation_type is not None:
where_clause.append("relation_type = ?")
Expand Down Expand Up @@ -199,6 +201,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool:
async def get_aggregation_groups_for_event(
self,
event_id: str,
room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
Expand All @@ -213,6 +216,7 @@ async def get_aggregation_groups_for_event(

Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
Expand All @@ -225,8 +229,12 @@ async def get_aggregation_groups_for_event(
`type`, `key` and `count` fields.
"""

where_clause = ["relates_to_id = ?", "relation_type = ?"]
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
where_args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]

if event_type:
where_clause.append("type = ?")
Expand Down Expand Up @@ -288,14 +296,17 @@ def _get_aggregation_groups_for_event_txn(
)

@cached()
async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
async def get_applicable_edit(
self, event_id: str, room_id: str
) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.

Correctly handles checking whether edits were allowed to happen.

Args:
event_id: The original event ID
room_id: The original event's room ID

Returns:
The most recent edit, if any.
Expand All @@ -317,13 +328,14 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
WHERE
relates_to_id = ?
AND relation_type = ?
AND edit.room_id = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""

def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE))
txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone()
if row:
return row[0]
Expand All @@ -340,13 +352,14 @@ def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:

@cached()
async def get_thread_summary(
self, event_id: str
self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.

Args:
event_id: The original event ID
event_id: Summarize the thread related to this event ID.
room_id: The room the event belongs to.

Returns:
The number of items in the thread and the most recent response, if any.
Expand All @@ -363,12 +376,13 @@ def _get_thread_summary_txn(
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""

txn.execute(sql, (event_id, RelationTypes.THREAD))
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
Expand All @@ -378,11 +392,13 @@ def _get_thread_summary_txn(
sql = """
SELECT COALESCE(COUNT(event_id), 0)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
"""
txn.execute(sql, (event_id, RelationTypes.THREAD))
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = txn.fetchone()[0] # type: ignore[index]

return count, latest_event_id
Expand Down
115 changes: 115 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
import itertools
import urllib.parse
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch

from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync

from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event


class RelationsTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -651,6 +654,118 @@ def test_aggregation_get_event_for_thread(self):
},
)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self):
"""Test that we ignore invalid relations over federation."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why single out federation here, sorry? Because the CS API should reject this kind of event if a local client tries to send it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why single out federation here, sorry? Because the CS API should reject this kind of event if a local client tries to send it?

Yes, exactly. We're somewhat assuming our server isn't generating "bad" events, but they could have before C-S had validation, but it is more likely to be from a buggy / malicious server over federation.

# Create another room and send a message in it.
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
res = self.helper.send(room2, body="Hi!", tok=self.user_token)
parent_id = res["event_id"]

# Disable the validation to pretend this came over federation.
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None),
):
# Generate a various relations from a different room.
self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.reaction",
sender=self.user_id,
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": parent_id,
"key": "A",
}
},
)
)

self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": RelationTypes.REFERENCE,
"event_id": parent_id,
},
},
)
)

self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": parent_id,
},
},
)
)

self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"new_content": {
"body": "new content",
"msgtype": "m.text",
},
"m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": parent_id,
},
},
)
)

# They should be ignored when fetching relations.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])

# And when fetching aggregations.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])

# And for bundled aggregations.
channel = self.make_request(
"GET",
f"/rooms/{room2}/event/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])

def test_edit(self):
"""Test that a simple edit works."""

Expand Down