Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added typehints to midimessage and init #52

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions adafruit_midi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
https://github.com/adafruit/circuitpython/releases

"""
try:
from typing import Union, Tuple, Any, List, Optional, Dict, BinaryIO
except ImportError:
pass

from .midi_message import MIDIMessage

Expand Down Expand Up @@ -54,13 +58,13 @@ class MIDI:

def __init__(
self,
midi_in=None,
midi_out=None,
midi_in: Optional[BinaryIO] = None,
midi_out: Optional[BinaryIO] = None,
*,
in_channel=None,
out_channel=0,
in_buf_size=30,
debug=False
in_channel: Optional[Union[int, Tuple[int, ...]]] = None,
out_channel: int = 0,
in_buf_size: int = 30,
debug: bool = False
):
if midi_in is None and midi_out is None:
raise ValueError("No midi_in or midi_out provided")
Expand All @@ -78,7 +82,7 @@ def __init__(
self._skipped_bytes = 0

@property
def in_channel(self):
def in_channel(self) -> Optional[Union[int, Tuple[int, ...]]]:
"""The incoming MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
``in_channel = 3`` will listen on MIDI channel 4.
Can also listen on multiple channels, e.g. ``in_channel = (0,1,2)``
Expand All @@ -87,7 +91,7 @@ def in_channel(self):
return self._in_channel

@in_channel.setter
def in_channel(self, channel):
def in_channel(self, channel: Optional[Union[str, int, Tuple[int, ...]]]) -> None:
if channel is None or channel == "ALL":
self._in_channel = tuple(range(16))
elif isinstance(channel, int) and 0 <= channel <= 15:
Expand All @@ -98,19 +102,19 @@ def in_channel(self, channel):
raise RuntimeError("Invalid input channel")

@property
def out_channel(self):
def out_channel(self) -> int:
"""The outgoing MIDI channel. Must be 0-15. Correlates to MIDI channels 1-16, e.g.
``out_channel = 3`` will send to MIDI channel 4. Default is 0 (MIDI channel 1).
"""
return self._out_channel

@out_channel.setter
def out_channel(self, channel):
def out_channel(self, channel: int) -> None:
if not 0 <= channel <= 15:
raise RuntimeError("Invalid output channel")
self._out_channel = channel

def receive(self):
def receive(self) -> Optional[MIDIMessage]:
"""Read messages from MIDI port, store them in internal read buffer, then parse that data
and return the first MIDI message (event).
This maintains the blocking characteristics of the midi_in port.
Expand Down Expand Up @@ -141,7 +145,7 @@ def receive(self):
# msg could still be None at this point, e.g. in middle of monster SysEx
return msg

def send(self, msg, channel=None):
def send(self, msg: MIDIMessage, channel: Optional[int] = None) -> None:
"""Sends a MIDI message.

:param msg: Either a MIDIMessage object or a sequence (list) of MIDIMessage objects.
Expand All @@ -165,7 +169,7 @@ def send(self, msg, channel=None):

self._send(data, len(data))

def _send(self, packet, num):
def _send(self, packet: bytes, num: int) -> None:
if self._debug:
print("Sending: ", [hex(i) for i in packet[:num]])
self._midi_out.write(packet, num)
61 changes: 41 additions & 20 deletions adafruit_midi/midi_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,19 @@
__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_MIDI.git"

try:
from typing import Union, Tuple, Any, List, Optional
except ImportError:
pass

# From C3 - A and B are above G
# Semitones A B C D E F G
NOTE_OFFSET = [21, 23, 12, 14, 16, 17, 19]


def channel_filter(channel, channel_spec):
def channel_filter(
channel: int, channel_spec: Optional[Union[int, Tuple[int, ...]]]
) -> bool:
"""
Utility function to return True iff the given channel matches channel_spec.
"""
Expand All @@ -41,13 +48,12 @@ def channel_filter(channel, channel_spec):
raise ValueError("Incorrect type for channel_spec" + str(type(channel_spec)))


def note_parser(note):
def note_parser(note: Union[int, str]) -> int:
"""If note is a string then it will be parsed and converted to a MIDI note (key) number, e.g.
"C4" will return 60, "C#4" will return 61. If note is not a string it will simply be returned.

:param note: Either 0-127 int or a str representing the note, e.g. "C#4"
"""
midi_note = note
if isinstance(note, str):
if len(note) < 2:
raise ValueError("Bad note format")
Expand All @@ -61,7 +67,8 @@ def note_parser(note):
sharpen = -1
# int may throw exception here
midi_note = int(note[1 + abs(sharpen) :]) * 12 + NOTE_OFFSET[noteidx] + sharpen

elif isinstance(note, int):
midi_note = note
return midi_note


Expand All @@ -82,40 +89,43 @@ class MIDIMessage:
This is an *abstract* class.
"""

_STATUS = None
_STATUS: Optional[int] = None
_STATUSMASK = None
LENGTH = None
LENGTH: Optional[int] = None
CHANNELMASK = 0x0F
ENDSTATUS = None

# Commonly used exceptions to save memory
@staticmethod
def _raise_valueerror_oor():
def _raise_valueerror_oor() -> None:
raise ValueError("Out of range")

# Each element is ((status, mask), class)
# order is more specific masks first
_statusandmask_to_class = []
# Add better type hints for status, mask, class referenced above
_statusandmask_to_class: List[
Tuple[Tuple[Optional[bytes], Optional[int]], "MIDIMessage"]
] = []

def __init__(self, *, channel=None):
def __init__(self, *, channel: Optional[int] = None) -> None:
self._channel = channel # dealing with pylint inadequacy
self.channel = channel

@property
def channel(self):
def channel(self) -> Optional[int]:
"""The channel number of the MIDI message where appropriate.
This is *updated* by MIDI.send() method.
"""
return self._channel

@channel.setter
def channel(self, channel):
def channel(self, channel: int) -> None:
if channel is not None and not 0 <= channel <= 15:
raise ValueError("Channel must be 0-15 or None")
self._channel = channel

@classmethod
def register_message_type(cls):
def register_message_type(cls) -> None:
"""Register a new message by its status value and mask.
This is called automagically at ``import`` time for each message.
"""
Expand All @@ -132,7 +142,14 @@ def register_message_type(cls):

# pylint: disable=too-many-arguments
@classmethod
def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endidx):
def _search_eom_status(
cls,
buf: bytearray,
eom_status: Optional[int],
msgstartidx: int,
msgendidxplusone: int,
endidx: int,
) -> Tuple[int, bool, bool]:
good_termination = False
bad_termination = False

Expand All @@ -155,7 +172,9 @@ def _search_eom_status(cls, buf, eom_status, msgstartidx, msgendidxplusone, endi
return (msgendidxplusone, good_termination, bad_termination)

@classmethod
def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):
def _match_message_status(
cls, buf: bytearray, msgstartidx: int, msgendidxplusone: int, endidx: int
) -> Tuple[Optional[Any], int, bool, bool, bool, int]:
msgclass = None
status = buf[msgstartidx]
known_msg = False
Expand Down Expand Up @@ -198,7 +217,9 @@ def _match_message_status(cls, buf, msgstartidx, msgendidxplusone, endidx):

# pylint: disable=too-many-locals,too-many-branches
@classmethod
def from_message_bytes(cls, midibytes, channel_in):
def from_message_bytes(
cls, midibytes: bytearray, channel_in: Optional[Union[int, Tuple[int, ...]]]
) -> Tuple[Optional["MIDIMessage"], int, int]:
"""Create an appropriate object of the correct class for the
first message found in some MIDI bytes filtered by channel_in.

Expand Down Expand Up @@ -270,7 +291,7 @@ def from_message_bytes(cls, midibytes, channel_in):

# A default method for constructing wire messages with no data.
# Returns an (immutable) bytes with just the status code in.
def __bytes__(self):
def __bytes__(self) -> bytes:
"""Return the ``bytes`` wire protocol representation of the object
with channel number applied where appropriate."""
return bytes([self._STATUS])
Expand All @@ -280,12 +301,12 @@ def __bytes__(self):
# Returns the new object.
# pylint: disable=unused-argument
@classmethod
def from_bytes(cls, msg_bytes):
def from_bytes(cls, msg_bytes: bytes) -> "MIDIMessage":
"""Creates an object from the byte stream of the wire protocol
representation of the MIDI message."""
return cls()

def __str__(self):
def __str__(self) -> str:
"""Print an instance"""
cls = self.__class__
if slots := getattr(cls, "_message_slots", None):
Expand Down Expand Up @@ -313,7 +334,7 @@ class MIDIUnknownEvent(MIDIMessage):
_message_slots = ["status"]
LENGTH = -1

def __init__(self, status):
def __init__(self, status: int):
self.status = status
super().__init__()

Expand All @@ -333,7 +354,7 @@ class MIDIBadEvent(MIDIMessage):

_message_slots = ["msg_bytes", "exception"]

def __init__(self, msg_bytes, exception):
def __init__(self, msg_bytes: bytearray, exception: Exception):
self.data = bytes(msg_bytes)
self.exception_text = repr(exception)
super().__init__()