Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Pull out less state when handling gaps #12828

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12828.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Pull out less state when handling gaps in room DAG.
140 changes: 65 additions & 75 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ async def process_remote_join(
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state,
)

Expand Down Expand Up @@ -501,7 +503,7 @@ async def update_state_for_partial_state_event(
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event=state,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
Expand Down Expand Up @@ -765,7 +767,7 @@ async def _process_pulled_event(

async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.

This is used when we have pulled a batch of events from a remote server, and
Expand All @@ -792,8 +794,8 @@ async def _resolve_state_at_missing_prevs(
event: an event to check for missing prevs.

Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the state at `event`.
Comment on lines +797 to +798
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if we already had all the prev events, `None`. Otherwise, returns
the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the event ids of the state at `event`.

"""
room_id = event.room_id
event_id = event.event_id
Expand Down Expand Up @@ -837,13 +839,7 @@ async def _resolve_state_at_missing_prevs(
dest, room_id, p
)

remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)

for x in remote_state:
event_map[x.event_id] = x
state_maps.append(remote_state)

room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
Expand All @@ -854,19 +850,6 @@ async def _resolve_state_at_missing_prevs(
state_res_store=StateResolutionStore(self._store),
)

# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.

# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)

state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
Expand All @@ -878,14 +861,14 @@ async def _resolve_state_at_missing_prevs(
"We can't get valid state history.",
affected=event_id,
)
return state
return state_map

async def _get_state_after_missing_prev_event(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this to reflect that it now returns event IDs rather than events?

Suggested change
async def _get_state_after_missing_prev_event(
async def _get_state_ids_after_missing_prev_event(

This is probably relevant to other places too.

self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.

Args:
Expand All @@ -894,7 +877,7 @@ async def _get_state_after_missing_prev_event(
event_id: The id of the event we want the state at.

Returns:
A list of events in the state, including the event itself
The state *after* the given event.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The state *after* the given event.
The event ids of the state *after* the given event.

"""
(
state_event_ids,
Expand All @@ -913,15 +896,13 @@ async def _get_state_after_missing_prev_event(
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
have_events = await self._store.have_seen_events(room_id, desired_events)

missing_desired_events = desired_events - fetched_events.keys()
missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
len(have_events),
)

# We probably won't need most of the auth events, so let's just check which
Expand All @@ -932,7 +913,7 @@ async def _get_state_after_missing_prev_event(
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.

missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
Expand All @@ -958,62 +939,67 @@ async def _get_state_after_missing_prev_event(
destination=destination, room_id=room_id, event_ids=missing_events
)

# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
event_metadata = await self._store.get_metadata_for_events(state_event_ids)

# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
Comment on lines 944 to 947
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

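This "check for events which were in the wrong room" comment looks a bit lost now that the check itself lives in the loop below. Move it down next to the `metadata.room_id != room_id` test?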


bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

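Duplicate line? `event_metadata` is already fetched by an identical `get_metadata_for_events(state_event_ids)` call just above the comment.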


for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
state_map = {}

for state_event_id, metadata in event_metadata.items():
if metadata.room_id != room_id:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
state_event_id,
metadata.room_id,
room_id,
)
continue

if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue

del fetched_events[bad_event_id]
state_map[(metadata.event_type, metadata.state_key)] = state_event_id

# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))

# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
event_id,
failed_to_fetch,
)

remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]

if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id

return remote_state
return state_map

async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
Expand All @@ -1040,7 +1026,7 @@ async def _process_received_pdu(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general theme: I'd love it if we could change the generic state parameters to state_ids_before_event or similar, to make it easier to understand whether we are dealing with event ids or complete events, without having to go and look at the type.

Similarly let's rename methods which now return collections of event ids where they previously returned events.

backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
Expand Down Expand Up @@ -1074,7 +1060,7 @@ async def _process_received_pdu(

try:
context = await self._state_handler.compute_event_context(
event, old_state=state
event, state_ids_before_event=state
)
context = await self._check_event_auth(
origin,
Expand Down Expand Up @@ -1565,7 +1551,7 @@ async def _maybe_kick_guest_users(self, event: EventBase) -> None:
async def _check_for_soft_fail(
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
Expand Down Expand Up @@ -1602,17 +1588,21 @@ async def _check_for_soft_fail(
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.

state_sets_d = await self._state_store.get_state_groups(
state_sets_d = await self._state_store.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self._state_handler.resolve_events(
room_version, state_sets, event

current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map={},
state_res_store=StateResolutionStore(self._store),
)
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
Expand Down
21 changes: 19 additions & 2 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,8 +1021,25 @@ async def create_new_client_event(
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)

state_map = {}
for state_id in state_event_ids:
data = metadata.get(state_id)
if data is None:
raise Exception("State event not persisted %s", state_id)

if data.state_key is None:
raise Exception(
"Trying to set non-state event as state: %s", state_id
)

state_map[(data.event_type, data.state_key)] = state_id

context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map,
)
else:
context = await self.state.compute_event_context(event)

Expand Down
14 changes: 5 additions & 9 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def get_hosts_in_room_at_events(
async def compute_event_context(
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
Expand All @@ -273,12 +273,12 @@ async def compute_event_context(

Args:
event:
old_state: The state at the event if it can't be
state_ids_before_event: The state at the event if it can't be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state_ids_before_event: The state at the event if it can't be
state_ids_before_event: The event ids of the state at the event if it can't be

calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
Returns:
The event context.
"""
Expand All @@ -288,11 +288,7 @@ async def compute_event_context(
#
# first of all, figure out the state before the event
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# first of all, figure out the state before the event
# first of all, figure out the state before the event, unless we already have it.

#
if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
if state_ids_before_event:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
Expand Down
Loading