Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sambacc: complete type annotations #47

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ quiet = true

[tool.mypy]
disallow_incomplete_defs = true

[[tool.mypy.overrides]]
module = "sambacc.*"
disallow_untyped_defs = true

[[tool.mypy.overrides]]
module = "sambacc.commands.*"
disallow_untyped_defs = false
4 changes: 2 additions & 2 deletions sambacc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def shares(self) -> typing.Iterable[ShareConfig]:


class GlobalConfig:
def __init__(self, source=None):
self.data = {}
def __init__(self, source: typing.Optional[typing.IO] = None) -> None:
self.data: dict[str, typing.Any] = {}
if source is not None:
self.load(source)

Expand Down
41 changes: 29 additions & 12 deletions sambacc/container_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>
#

from __future__ import annotations

import json
import subprocess
import typing
Expand All @@ -27,18 +29,22 @@


class HostState:
def __init__(self, ref="", items=[]):
self.ref = ref
self.items = items
T = typing.TypeVar("T", bound="HostState")

def __init__(
self, ref: str = "", items: typing.Optional[list[HostInfo]] = None
) -> None:
self.ref: str = ref
self.items: list[HostInfo] = items or []

@classmethod
def from_dict(cls, d):
def from_dict(cls: typing.Type[T], d: dict[str, typing.Any]) -> T:
return cls(
ref=d["ref"],
items=[HostInfo.from_dict(i) for i in d.get("items", [])],
)

def __eq__(self, other):
def __eq__(self, other: typing.Any) -> bool:
return (
self.ref == other.ref
and len(self.items) == len(other.items)
Expand All @@ -47,32 +53,36 @@ def __eq__(self, other):


class HostInfo:
def __init__(self, name="", ipv4_addr="", target=""):
T = typing.TypeVar("T", bound="HostInfo")

def __init__(
self, name: str = "", ipv4_addr: str = "", target: str = ""
) -> None:
self.name = name
self.ipv4_addr = ipv4_addr
self.target = target

@classmethod
def from_dict(cls, d):
def from_dict(cls: typing.Type[T], d: dict[str, typing.Any]) -> T:
return cls(
name=d["name"],
ipv4_addr=d["ipv4"],
target=d.get("target", ""),
)

def __eq__(self, other):
def __eq__(self, other: typing.Any) -> bool:
return (
self.name == other.name
and self.ipv4_addr == other.ipv4_addr
and self.target == other.target
)


def parse(fh):
def parse(fh: typing.IO) -> HostState:
return HostState.from_dict(json.load(fh))


def parse_file(path):
def parse_file(path: str) -> HostState:
with open(path) as fh:
return parse(fh)

Expand Down Expand Up @@ -115,13 +125,20 @@ def parse_and_update(


# TODO: replace this with the common version added to simple_waiter
def watch(domain, source, update_func, pause_func, print_func=None):
def watch(
domain: str,
source: str,
update_func: typing.Callable,
pause_func: typing.Callable,
print_func: typing.Optional[typing.Callable],
) -> None:
previous = None
while True:
try:
previous, updated = update_func(domain, source, previous)
except FileNotFoundError:
print_func(f"Source file [{source}] not found")
if print_func:
print_func(f"Source file [{source}] not found")
updated = False
previous = None
if updated and print_func:
Expand Down
6 changes: 5 additions & 1 deletion sambacc/ctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ def next_state(state: NodeState) -> NodeState:


class NodeNotPresent(KeyError):
def __init__(self, identity, pnn=None):
def __init__(
self,
identity: typing.Any,
pnn: typing.Optional[typing.Union[str, int]] = None,
) -> None:
super().__init__(identity)
self.identity = identity
self.pnn = pnn
Expand Down
4 changes: 2 additions & 2 deletions sambacc/inotify_waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def acted(self) -> None:
def wait(self) -> None:
next(self._wait())

def _get_events(self):
def _get_events(self) -> list[typing.Any]:
timeout = 1000 * self.timeout
self._print("waiting {}ms for activity...".format(timeout))
events = self._inotify.read(timeout=timeout)
Expand All @@ -85,7 +85,7 @@ def _get_events(self):
and ((event.mask & _inotify.flags.CLOSE_WRITE) != 0)
]

def _wait(self):
def _wait(self) -> typing.Iterator[None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why no just have a single wait function? What is the advantage of using next(typing.Iterator[None]) ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the time I must have felt implementing the function in terms of an iterator was easier or clearer. This PR only adds the proper type annotation to the function, it doesn't change the runtime behavior. Let's leave that for another time.

while True:
for event in self._get_events():
if event is None:
Expand Down
11 changes: 7 additions & 4 deletions sambacc/jfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,23 @@
import fcntl
import json
import os
import typing

OPEN_RO = os.O_RDONLY
OPEN_RW = os.O_CREAT | os.O_RDWR


def open(path, flags, mode=0o644):
def open(path: str, flags: int, mode: int = 0o644) -> typing.IO:
"""A wrapper around open to open JSON files for read or read/write.
`flags` must be os.open type flags. Use `OPEN_RO` and `OPEN_RW` for
convenience.
"""
return os.fdopen(os.open(path, flags, mode), "r+")


def load(fh, default=None):
def load(
fh: typing.IO, default: typing.Optional[dict[str, typing.Any]] = None
) -> typing.Any:
"""Similar to json.load, but returns the `default` value if fh refers to an
empty file. fh must be seekable."""
if fh.read(4) == "":
Expand All @@ -46,7 +49,7 @@ def load(fh, default=None):
return data


def dump(data, fh):
def dump(data: typing.Any, fh: typing.IO) -> None:
"""Similar to json.dump, but truncates the file before writing in order
to avoid appending data to the file. fh must be seekable.
"""
Expand All @@ -55,6 +58,6 @@ def dump(data, fh):
json.dump(data, fh)


def flock(fh):
def flock(fh: typing.IO) -> None:
"""A simple wrapper around flock."""
fcntl.flock(fh.fileno(), fcntl.LOCK_EX)
18 changes: 12 additions & 6 deletions sambacc/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@


class JoinError(Exception):
def __init__(self, v):
def __init__(self, v: typing.Any) -> None:
super().__init__(v)
self.errors = []
self.errors: list[typing.Any] = []


_PROMPT = object()
_PT = typing.TypeVar("_PT")
_PW = typing.Union[str, _PT]


class JoinBy(enum.Enum):
Expand All @@ -47,9 +49,13 @@ class UserPass:
"""Encapsulate a username/password pair."""

username: str = "Administrator"
password: typing.Optional[str] = None
password: typing.Optional[_PW] = None

def __init__(self, username=None, password=None):
def __init__(
self,
username: typing.Optional[str] = None,
password: typing.Optional[_PW] = None,
) -> None:
if username is not None:
self.username = username
if password is not None:
Expand All @@ -65,8 +71,8 @@ class Joiner:

_net_ads_join = samba_cmds.net["ads", "join"]

def __init__(self, marker=None):
self._sources = []
def __init__(self, marker: typing.Optional[str] = None) -> None:
self._sources: list[tuple[JoinBy, typing.Any]] = []
self.marker = marker

def add_source(
Expand Down
4 changes: 3 additions & 1 deletion sambacc/netcmd_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def template_config(
class NetCmdLoader:
_net_conf = samba_cmds.net["conf"]

def _cmd(self, *args, **kwargs):
def _cmd(
self, *args: str, **kwargs: typing.Any
) -> tuple[list[str], typing.Any]:
cmd = list(self._net_conf[args])
return cmd, subprocess.Popen(cmd, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions sambacc/nsswitch_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@


class NameServiceSwitchLoader(TextFileLoader):
def __init__(self, path):
def __init__(self, path: str) -> None:
super().__init__(path)
self.lines = []
self.idx = {}
self.lines: list[str] = []
self.idx: dict[str, int] = {}

def loadlines(self, lines: typing.Iterable[str]) -> None:
"""Load in the lines from the text source."""
Expand Down
6 changes: 4 additions & 2 deletions sambacc/passdb_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>
#

import typing

from sambacc import config

# Do the samba python bindings not export any useful constants?
Expand All @@ -24,15 +26,15 @@
ACB_PWNOEXP = 0x00000200


def _samba_modules():
def _samba_modules() -> tuple[typing.Any, typing.Any]:
from samba.samba3 import param # type: ignore
from samba.samba3 import passdb # type: ignore

return param, passdb


class PassDBLoader:
def __init__(self, smbconf=None):
def __init__(self, smbconf: typing.Any = None) -> None:
param, passdb = _samba_modules()
lp = param.get_context()
if smbconf is None:
Expand Down
12 changes: 6 additions & 6 deletions sambacc/passwd_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@


class LineFileLoader(TextFileLoader):
def __init__(self, path):
def __init__(self, path: str) -> None:
super().__init__(path)
self.lines = []
self.lines: list[str] = []

def loadlines(self, lines: typing.Iterable[str]) -> None:
"""Load in the lines from the text source."""
Expand All @@ -43,9 +43,9 @@ def dumplines(self) -> typing.Iterable[str]:


class PasswdFileLoader(LineFileLoader):
def __init__(self, path="/etc/passwd"):
def __init__(self, path: str = "/etc/passwd") -> None:
super().__init__(path)
self._usernames = set()
self._usernames: set[str] = set()

def readfp(self, fp: typing.IO) -> None:
super().readfp(fp)
Expand All @@ -66,9 +66,9 @@ def add_user(self, user_entry: config.UserEntry) -> None:


class GroupFileLoader(LineFileLoader):
def __init__(self, path="/etc/group"):
def __init__(self, path: str = "/etc/group") -> None:
super().__init__(path)
self._groupnames = set()
self._groupnames: set[str] = set()

def readfp(self, fp: typing.IO) -> None:
super().readfp(fp)
Expand Down
8 changes: 4 additions & 4 deletions sambacc/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,22 +124,22 @@ def path(self) -> str:
def _full_path(self) -> str:
return os.path.join(self._root, self._path.lstrip("/"))

def has_status(self):
def has_status(self) -> bool:
try:
self._get_status()
return True
except KeyError:
return False

def status_ok(self):
def status_ok(self) -> bool:
try:
sval = self._get_status()
except KeyError:
return False
curr_prefix = sval.split("/")[0]
return curr_prefix == self._prefix

def update(self):
def update(self) -> None:
if self.status_ok():
return
self._set_perms()
Expand Down Expand Up @@ -182,6 +182,6 @@ class AlwaysPosixPermsHandler(InitPosixPermsHandler):
May be useful for testing and debugging.
"""

def update(self):
def update(self) -> None:
self._set_perms()
self._set_status()
6 changes: 3 additions & 3 deletions sambacc/samba_cmds.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,19 @@ def __repr__(self) -> str:
samba_dc = SambaCommand("/usr/sbin/samba")


def smbd_foreground():
def smbd_foreground() -> SambaCommand:
return smbd[
"--foreground", _daemon_stdout_opt("smbd"), "--no-process-group"
]
Comment on lines +162 to 165
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, now that we have explicit return type, those lines look a bit enigmatic. I would consider (in a separate PR) to have more explicit code (also in other places below):

def smbd_foreground() -> SambaCommand:
    return SambaCommand("/usr/sbin/smbd",
        [ "--foreground", _daemon_stdout_opt("smbd"), "--no-process-group" ])

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll think about it. On the other hand, the square-braces after a SambaCommand object gives you a new SambaCommand object with additional arguments was there from the beginning as a convenient, don't need to repeat yourself, construct.



def winbindd_foreground():
def winbindd_foreground() -> SambaCommand:
return winbindd[
"--foreground", _daemon_stdout_opt("winbindd"), "--no-process-group"
]


def samba_dc_foreground():
def samba_dc_foreground() -> SambaCommand:
return samba_dc["--foreground"]


Expand Down
Loading