Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Additional type hints for config module. (#11465)
Browse files Browse the repository at this point in the history
This adds some misc. type hints to helper methods used
in the `synapse.config` module.
  • Loading branch information
clokep committed Dec 1, 2021
1 parent a265fbd commit f44d729
Show file tree
Hide file tree
Showing 15 changed files with 129 additions and 99 deletions.
1 change: 1 addition & 0 deletions changelog.d/11465.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `synapse.config` module.
3 changes: 2 additions & 1 deletion synapse/app/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Iterable,
List,
NoReturn,
Optional,
Tuple,
cast,
)
Expand Down Expand Up @@ -129,7 +130,7 @@ def start_worker_reactor(
def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Tuple[int, int, int],
gc_thresholds: Optional[Tuple[int, int, int]],
pid_file: str,
daemonize: bool,
print_pidfile: bool,
Expand Down
3 changes: 2 additions & 1 deletion synapse/config/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import List

from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig


def main(args):
def main(args: List[str]) -> None:
action = args[1] if len(args) > 1 and args[1] == "read" else None
# If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]`
# will be the key to read.
Expand Down
23 changes: 14 additions & 9 deletions synapse/config/appservice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,14 +14,14 @@
# limitations under the License.

import logging
from typing import Dict
from typing import Dict, List
from urllib import parse as urlparse

import yaml
from netaddr import IPSet

from synapse.appservice import ApplicationService
from synapse.types import UserID
from synapse.types import JsonDict, UserID

from ._base import Config, ConfigError

Expand All @@ -30,12 +31,12 @@
class AppServiceConfig(Config):
section = "appservice"

def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)

def generate_config_section(cls, **kwargs):
def generate_config_section(cls, **kwargs) -> str:
return """\
# A list of application service config files to use
#
Expand All @@ -50,7 +51,9 @@ def generate_config_section(cls, **kwargs):
"""


def load_appservices(hostname, config_files):
def load_appservices(
hostname: str, config_files: List[str]
) -> List[ApplicationService]:
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning("Expected %s to be a list of AS config files.", config_files)
Expand Down Expand Up @@ -93,7 +96,9 @@ def load_appservices(hostname, config_files):
return appservices


def _load_appservice(hostname, as_info, config_filename):
def _load_appservice(
hostname: str, as_info: JsonDict, config_filename: str
) -> ApplicationService:
required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields:
if not isinstance(as_info.get(field), str):
Expand All @@ -115,9 +120,9 @@ def _load_appservice(hostname, as_info, config_filename):
user_id = user.to_string()

# Rate limiting for users of this AS is on by default (excludes sender)
rate_limited = True
if isinstance(as_info.get("rate_limited"), bool):
rate_limited = as_info.get("rate_limited")
rate_limited = as_info.get("rate_limited")
if not isinstance(rate_limited, bool):
rate_limited = True

# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
Expand Down
26 changes: 14 additions & 12 deletions synapse/config/cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Matrix.org Foundation C.I.C.
# Copyright 2019-2021 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
import threading
from typing import Callable, Dict, Optional

import attr

from synapse.python_dependencies import DependencyException, check_requirements

from ._base import Config, ConfigError
Expand All @@ -34,13 +36,13 @@
_DEFAULT_EVENT_CACHE_SIZE = "10K"


@attr.s(slots=True, auto_attribs=True)
class CacheProperties:
def __init__(self):
# The default factor size for all caches
self.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
self.resize_all_caches_func = None
# The default factor size for all caches
default_factor_size: float = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
)
resize_all_caches_func: Optional[Callable[[], None]] = None


properties = CacheProperties()
Expand All @@ -62,7 +64,7 @@ def _canonicalise_cache_name(cache_name: str) -> str:

def add_resizable_cache(
cache_name: str, cache_resize_callback: Callable[[float], None]
):
) -> None:
"""Register a cache that's size can dynamically change
Args:
Expand Down Expand Up @@ -91,7 +93,7 @@ class CacheConfig(Config):
_environ = os.environ

@staticmethod
def reset():
def reset() -> None:
"""Resets the caches to their defaults. Used for tests."""
properties.default_factor_size = float(
os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE)
Expand All @@ -100,7 +102,7 @@ def reset():
with _CACHES_LOCK:
_CACHES.clear()

def generate_config_section(self, **kwargs):
def generate_config_section(self, **kwargs) -> str:
return """\
## Caching ##
Expand Down Expand Up @@ -162,7 +164,7 @@ def generate_config_section(self, **kwargs):
#sync_response_cache_duration: 2m
"""

def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
)
Expand Down Expand Up @@ -232,7 +234,7 @@ def read_config(self, config, **kwargs):
# needing an instance of Config
properties.resize_all_caches_func = self.resize_all_caches

def resize_all_caches(self):
def resize_all_caches(self) -> None:
"""Ensure all cache sizes are up to date
For each cache, run the mapped callback function with either
Expand Down
5 changes: 3 additions & 2 deletions synapse/config/cas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,7 +29,7 @@ class CasConfig(Config):

section = "cas"

def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
cas_config = config.get("cas_config", None)
self.cas_enabled = cas_config and cas_config.get("enabled", True)

Expand All @@ -51,7 +52,7 @@ def read_config(self, config, **kwargs):
self.cas_displayname_attribute = None
self.cas_required_attributes = []

def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
return """\
# Enable Central Authentication Service (CAS) for registration and login.
#
Expand Down
13 changes: 7 additions & 6 deletions synapse/config/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os

Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(self, *args, **kwargs):

self.databases = []

def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
# We *experimentally* support specifying multiple databases via the
# `databases` key. This is a map from a label to database config in the
# same format as the `database` config option, plus an extra
Expand Down Expand Up @@ -163,12 +164,12 @@ def read_config(self, config, **kwargs):
self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(database_path)

def generate_config_section(self, data_dir_path, **kwargs):
def generate_config_section(self, data_dir_path, **kwargs) -> str:
return DEFAULT_CONFIG % {
"database_path": os.path.join(data_dir_path, "homeserver.db")
}

def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
"""
Cases for the cli input:
- If no databases are configured and no database_path is set, raise.
Expand All @@ -194,15 +195,15 @@ def read_arguments(self, args):
else:
logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)

def set_databasepath(self, database_path):
def set_databasepath(self, database_path: str) -> None:

if database_path != ":memory:":
database_path = self.abspath(database_path)

self.databases[0].config["args"]["database"] = database_path

@staticmethod
def add_arguments(parser):
def add_arguments(parser: argparse.ArgumentParser) -> None:
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d",
Expand Down
24 changes: 14 additions & 10 deletions synapse/config/logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,7 +19,7 @@
import sys
import threading
from string import Template
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional

import yaml
from zope.interface import implementer
Expand All @@ -40,6 +41,7 @@
from ._base import Config, ConfigError

if TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig
from synapse.server import HomeServer

DEFAULT_LOG_CONFIG = Template(
Expand Down Expand Up @@ -141,13 +143,13 @@
class LoggingConfig(Config):
section = "logging"

def read_config(self, config, **kwargs):
def read_config(self, config, **kwargs) -> None:
if config.get("log_file"):
raise ConfigError(LOG_FILE_ERROR)
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)

def generate_config_section(self, config_dir_path, server_name, **kwargs):
def generate_config_section(self, config_dir_path, server_name, **kwargs) -> str:
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return (
"""\
Expand All @@ -161,14 +163,14 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs):
% locals()
)

def read_arguments(self, args):
def read_arguments(self, args: argparse.Namespace) -> None:
if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio
if args.log_file is not None:
raise ConfigError(LOG_FILE_ERROR)

@staticmethod
def add_arguments(parser):
def add_arguments(parser: argparse.ArgumentParser) -> None:
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"-n",
Expand Down Expand Up @@ -197,7 +199,9 @@ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))


def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None:
def _setup_stdlib_logging(
config: "HomeServerConfig", log_config_path: Optional[str], logBeginner: LogBeginner
) -> None:
"""
Set up Python standard library logging.
"""
Expand Down Expand Up @@ -230,7 +234,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
log_metadata_filter = MetadataFilter({"server_name": config.server.server_name})
old_factory = logging.getLogRecordFactory()

def factory(*args, **kwargs):
def factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
record = old_factory(*args, **kwargs)
log_context_filter.filter(record)
log_metadata_filter.filter(record)
Expand Down Expand Up @@ -297,7 +301,7 @@ def _load_logging_config(log_config_path: str) -> None:
logging.config.dictConfig(log_config)


def _reload_logging_config(log_config_path):
def _reload_logging_config(log_config_path: Optional[str]) -> None:
"""
Reload the log configuration from the file and apply it.
"""
Expand All @@ -311,8 +315,8 @@ def _reload_logging_config(log_config_path):

def setup_logging(
hs: "HomeServer",
config,
use_worker_options=False,
config: "HomeServerConfig",
use_worker_options: bool = False,
logBeginner: LogBeginner = globalLogBeginner,
) -> None:
"""
Expand Down
Loading

0 comments on commit f44d729

Please sign in to comment.