Skip to content

Commit

Permalink
Merge pull request #447 from alisaifee/fix-concurrent-access-error-fo…
Browse files Browse the repository at this point in the history
…r-default-dicts

Replace use of default dict for exemptions
  • Loading branch information
alisaifee authored May 19, 2024
2 parents 7a125a6 + 5ff2943 commit 5c12e59
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 13 deletions.
15 changes: 13 additions & 2 deletions flask_limiter/commands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import itertools
import time
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from urllib.parse import urlparse

import click
Expand Down Expand Up @@ -180,7 +191,7 @@ def get_filtered_endpoint(
filter_endpoint, _ = adapter.match(
parsed.path, method=method, query_args=parsed.query
)
return filter_endpoint
return cast(str, filter_endpoint)
except NotFound:
console.print(
f"[error]Error: {path} could not be matched to an endpoint[/error]"
Expand Down
8 changes: 2 additions & 6 deletions flask_limiter/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,8 @@ def __init__(
self._in_memory_fallback_enabled = in_memory_fallback_enabled or (
in_memory_fallback and len(in_memory_fallback) > 0
)
self._route_exemptions: Dict[str, ExemptionScope] = defaultdict(
lambda: ExemptionScope.NONE
)
self._blueprint_exemptions: Dict[str, ExemptionScope] = defaultdict(
lambda: ExemptionScope.NONE
)
self._route_exemptions: Dict[str, ExemptionScope] = {}
self._blueprint_exemptions: Dict[str, ExemptionScope] = {}
self._request_filters: List[Callable[[], bool]] = []

self._headers_enabled = headers_enabled
Expand Down
17 changes: 12 additions & 5 deletions flask_limiter/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def exemption_scope(
) -> ExemptionScope:
view_func = app.view_functions.get(endpoint or "", None)
name = get_qualified_name(view_func) if view_func else ""
route_exemption_scope = self._route_exemptions[name]
route_exemption_scope = self._route_exemptions.get(name, ExemptionScope.NONE)
blueprint_instance = app.blueprints.get(blueprint) if blueprint else None

if not blueprint_instance:
Expand All @@ -161,7 +161,7 @@ def exemption_scope(

def decorated_limits(self, callable_name: str) -> List[Limit]:
limits = []
if not self._route_exemptions[callable_name]:
if not self._route_exemptions.get(callable_name, ExemptionScope.NONE):
if callable_name in self._decorated_limits:
for group in self._decorated_limits[callable_name]:
try:
Expand Down Expand Up @@ -206,7 +206,9 @@ def blueprint_limits(self, app: flask.Flask, blueprint: str) -> List[Limit]:
limit.override_defaults for limit in blueprint_self_limits
)
)
and not self._blueprint_exemptions[blueprint_name]
and not self._blueprint_exemptions.get(
blueprint_name, ExemptionScope.NONE
)
& ExemptionScope.ANCESTORS
else blueprint_self_limits
)
Expand Down Expand Up @@ -242,7 +244,9 @@ def _blueprint_exemption_scope(
self, app: flask.Flask, blueprint_name: str
) -> Tuple[ExemptionScope, Dict[str, ExemptionScope]]:
name = app.blueprints[blueprint_name].name
exemption = self._blueprint_exemptions[name] & ~(ExemptionScope.ANCESTORS)
exemption = self._blueprint_exemptions.get(name, ExemptionScope.NONE) & ~(
ExemptionScope.ANCESTORS
)

ancestory = set(blueprint_name.split("."))
ancestor_exemption = {
Expand All @@ -251,4 +255,7 @@ def _blueprint_exemption_scope(
if f & ExemptionScope.DESCENDENTS
}.intersection(ancestory)

return exemption, {k: self._blueprint_exemptions[k] for k in ancestor_exemption}
return exemption, {
k: self._blueprint_exemptions.get(k, ExemptionScope.NONE)
for k in ancestor_exemption
}

0 comments on commit 5c12e59

Please sign in to comment.