Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Initial implementation of MSC3981: recursive relations API #15315

Merged
merged 6 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15315.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support to recursively provide relations per [MSC3981](https://github.com/matrix-org/matrix-spec-proposals/pull/3981).
5 changes: 5 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,8 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:

# MSC2659: Application service ping endpoint
self.msc2659_enabled = experimental.get("msc2659_enabled", False)

# MSC3981: Recurse relations
self.msc3981_recurse_relations = experimental.get(
"msc3981_recurse_relations", False
)
3 changes: 3 additions & 0 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def get_relations(
event_id: str,
room_id: str,
pagin_config: PaginationConfig,
recurse: bool,
include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
Expand All @@ -98,6 +99,7 @@ async def get_relations(
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
pagin_config: The pagination config rules to apply, if any.
recurse: Whether to recursively find relations.
include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
Expand Down Expand Up @@ -132,6 +134,7 @@ async def get_relations(
direction=pagin_config.direction,
from_token=pagin_config.from_token,
to_token=pagin_config.to_token,
recurse=recurse,
)

events = await self._main_store.get_events_as_list(
Expand Down
10 changes: 9 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
self._support_recurse = hs.config.experimental.msc3981_recurse_relations

async def on_GET(
self,
Expand All @@ -63,6 +64,12 @@ async def on_GET(
pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
if self._support_recurse:
recurse = parse_boolean(
request, "org.matrix.msc3981.recurse", default=False
)
else:
recurse = False

# The unstable version of this API returns an extra field for client
# compatibility, see https://github.com/matrix-org/synapse/issues/12930.
Expand All @@ -75,6 +82,7 @@ async def on_GET(
event_id=parent_id,
room_id=room_id,
pagin_config=pagination_config,
recurse=recurse,
include_original_event=include_original_event,
relation_type=relation_type,
event_type=event_type,
Expand Down
52 changes: 39 additions & 13 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def get_relations_for_event(
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
recurse: bool = False,
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.

Expand All @@ -186,6 +187,7 @@ async def get_relations_for_event(
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
recurse: Whether to recursively find relations.

Returns:
A tuple of:
Expand All @@ -200,7 +202,7 @@ async def get_relations_for_event(
# Ensure bad limits aren't being passed in.
assert limit >= 0

where_clause = ["relates_to_id = ?", "room_id = ?"]
where_clause = ["room_id = ?"]
where_args: List[Union[str, int]] = [event.event_id, room_id]
clokep marked this conversation as resolved.
Show resolved Hide resolved
is_redacted = event.internal_metadata.is_redacted()

Expand Down Expand Up @@ -229,18 +231,42 @@ async def get_relations_for_event(
if pagination_clause:
where_clause.append(pagination_clause)

sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)
# TODO This needs to be a recursive query that maybe still matches some of the times above.
clokep marked this conversation as resolved.
Show resolved Hide resolved
if recurse:
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relation_type, relates_to_id, 0 depth
clokep marked this conversation as resolved.
Show resolved Hide resolved
FROM event_relations
WHERE relates_to_id = ?
UNION SELECT e.event_id, e.relation_type, e.relates_to_id, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.event_id = e.relates_to_id
WHERE depth <= 3
)
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM related_events
INNER JOIN events USING (event_id)
WHERE %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?;
Comment on lines +251 to +256
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SELECT portion of these queries are the same -- would be better / clearer to structure this as a preamble & table name to query and only have one copy of the shared bit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably fine for now

""" % (
" AND ".join(where_clause),
order,
order,
)
else:
sql = """
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND %s
ORDER BY topological_ordering %s, stream_ordering %s
LIMIT ?
""" % (
" AND ".join(where_clause),
order,
order,
)

def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
Expand Down
120 changes: 120 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config


class BaseRelationsTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -949,6 +950,125 @@ def test_pagination_from_sync_and_messages(self) -> None:
)


class RecursiveRelationTestCase(BaseRelationsTestCase):
@override_config({"experimental_features": {"msc3981_recurse_relations": True}})
def test_recursive_relations(self) -> None:
"""Generate a complex, multi-level relationship tree and query it."""
# Create a thread with a few messages in it.
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_1 = channel.json_body["event_id"]

channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_2 = channel.json_body["event_id"]

# Add annotations.
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
)
annotation_1 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
)
annotation_2 = channel.json_body["event_id"]

# Add a reference to part of the thread, then edit the reference and annotate it.
channel = self._send_relation(
RelationTypes.REFERENCE, "m.room.test", parent_id=thread_2
)
reference_1 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "c", parent_id=reference_1
)
annotation_3 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.test",
parent_id=reference_1,
)
edit = channel.json_body["event_id"]

# Also more events off the root.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "d")
annotation_4 = channel.json_body["event_id"]

channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}"
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# The above events should be returned in creation order.
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(
event_ids,
[
thread_1,
thread_2,
annotation_1,
annotation_2,
reference_1,
annotation_3,
edit,
annotation_4,
],
)

@override_config({"experimental_features": {"msc3981_recurse_relations": True}})
def test_recursive_relations_with_filter(self) -> None:
"""The event_type and rel_type still apply."""
# Create a thread with a few messages in it.
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_1 = channel.json_body["event_id"]

# Add annotations.
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=thread_1
)
annotation_1 = channel.json_body["event_id"]

# Add a reference to part of the thread, then edit the reference and annotate it.
channel = self._send_relation(
RelationTypes.REFERENCE, "m.room.test", parent_id=thread_1
)
reference_1 = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.ANNOTATION, "org.matrix.reaction", "c", parent_id=reference_1
)
annotation_2 = channel.json_body["event_id"]

# Fetch only annotations, but recursively.
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}"
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# The above events should be returned in creation order.
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(event_ids, [annotation_1, annotation_2])

# Fetch only m.reactions, but recursively.
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}/{RelationTypes.ANNOTATION}/m.reaction"
"?dir=f&limit=20&org.matrix.msc3981.recurse=true",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# The above events should be returned in creation order.
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(event_ids, [annotation_1])


class BundledAggregationsTestCase(BaseRelationsTestCase):
"""
See RelationsTestCase.test_edit for a similar test for edits.
Expand Down