Skip to content

Commit

Permalink
chore: move more things around in version 1 api (#198)
Browse files Browse the repository at this point in the history
* chore: move more things around in version 1 api

* fix: tests
  • Loading branch information
Lash-L authored Apr 8, 2024
1 parent 4f44d03 commit 30d2577
Show file tree
Hide file tree
Showing 11 changed files with 443 additions and 299 deletions.
236 changes: 3 additions & 233 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,125 +4,35 @@

import asyncio
import base64
import dataclasses
import hashlib
import json
import logging
import secrets
import struct
import time
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar, final
from typing import Any

from .command_cache import CacheableAttribute, CommandType, RoborockAttribute, find_cacheable_attribute, get_cache_map
from .containers import (
Consumable,
DeviceData,
ModelStatus,
RoborockBase,
S7MaxVStatus,
Status,
)
from .exceptions import (
RoborockException,
RoborockTimeout,
UnknownMethodError,
VacuumError,
)
from .protocol import Utils
from .roborock_future import RoborockFuture
from .roborock_message import (
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
ROBOROCK_DATA_STATUS_PROTOCOL,
RoborockDataProtocol,
RoborockMessage,
RoborockMessageProtocol,
)
from .roborock_typing import RoborockCommand
from .util import RepeatableTask, RoborockLoggerAdapter, get_running_loop_or_create_one
from .util import RoborockLoggerAdapter, get_running_loop_or_create_one

_LOGGER = logging.getLogger(__name__)
KEEPALIVE = 60
RT = TypeVar("RT", bound=RoborockBase)


def md5hex(message: str) -> str:
md5 = hashlib.md5()
md5.update(message.encode())
return md5.hexdigest()


EVICT_TIME = 60


class AttributeCache:
def __init__(self, attribute: RoborockAttribute, api: RoborockClient):
self.attribute = attribute
self.api = api
self.attribute = attribute
self.task = RepeatableTask(self.api.event_loop, self._async_value, EVICT_TIME)
self._value: Any = None
self._mutex = asyncio.Lock()
self.unsupported: bool = False

@property
def value(self):
return self._value

async def _async_value(self):
if self.unsupported:
return None
try:
self._value = await self.api._send_command(self.attribute.get_command)
except UnknownMethodError as err:
# Limit the amount of times we call unsupported methods
self.unsupported = True
raise err
return self._value

async def async_value(self):
async with self._mutex:
if self._value is None:
return await self.task.reset()
return self._value

def stop(self):
self.task.cancel()

async def update_value(self, params):
if self.attribute.set_command is None:
raise RoborockException(f"{self.attribute.attribute} have no set command")
response = await self.api._send_command(self.attribute.set_command, params)
await self._async_value()
return response

async def add_value(self, params):
if self.attribute.add_command is None:
raise RoborockException(f"{self.attribute.attribute} have no add command")
response = await self.api._send_command(self.attribute.add_command, params)
await self._async_value()
return response

async def close_value(self, params=None):
if self.attribute.close_command is None:
raise RoborockException(f"{self.attribute.attribute} have no close command")
response = await self.api._send_command(self.attribute.close_command, params)
await self._async_value()
return response

async def refresh_value(self):
await self._async_value()


@dataclasses.dataclass
class ListenerModel:
protocol_handlers: dict[RoborockDataProtocol, list[Callable[[Status | Consumable], None]]]
cache: dict[CacheableAttribute, AttributeCache]


class RoborockClient:
_listeners: dict[str, ListenerModel] = {}

def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int = 4) -> None:
self.event_loop = get_running_loop_or_create_one()
self.device_info = device_info
Expand All @@ -136,15 +46,9 @@ def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int =
"misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")}
}
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
self.cache: dict[CacheableAttribute, AttributeCache] = {
cacheable_attribute: AttributeCache(attr, self) for cacheable_attribute, attr in get_cache_map().items()
}
self.is_available: bool = True
self.queue_timeout = queue_timeout
self._status_type: type[Status] = ModelStatus.get(self.device_info.model, S7MaxVStatus)
if device_info.device.duid not in self._listeners:
self._listeners[device_info.device.duid] = ListenerModel({}, self.cache)
self.listener_model = self._listeners[device_info.device.duid]

def __del__(self) -> None:
self.release()
Expand All @@ -156,11 +60,9 @@ def status_type(self) -> type[Status]:

def release(self):
self.sync_disconnect()
[item.stop() for item in self.cache.values()]

async def async_release(self):
await self.async_disconnect()
[item.stop() for item in self.cache.values()]

@property
def diagnostic_data(self) -> dict:
Expand All @@ -185,95 +87,7 @@ async def async_disconnect(self) -> Any:
raise NotImplementedError

def on_message_received(self, messages: list[RoborockMessage]) -> None:
try:
self._last_device_msg_in = self.time_func()
for data in messages:
protocol = data.protocol
if data.payload and protocol in [
RoborockMessageProtocol.RPC_RESPONSE,
RoborockMessageProtocol.GENERAL_REQUEST,
]:
payload = json.loads(data.payload.decode())
for data_point_number, data_point in payload.get("dps").items():
if data_point_number == "102":
data_point_response = json.loads(data_point)
request_id = data_point_response.get("id")
queue = self._waiting_queue.get(request_id)
if queue and queue.protocol == protocol:
error = data_point_response.get("error")
if error:
queue.resolve(
(
None,
VacuumError(
error.get("code"),
error.get("message"),
),
)
)
else:
result = data_point_response.get("result")
if isinstance(result, list) and len(result) == 1:
result = result[0]
queue.resolve((result, None))
else:
try:
data_protocol = RoborockDataProtocol(int(data_point_number))
self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}")
if data_protocol in ROBOROCK_DATA_STATUS_PROTOCOL:
if data_protocol not in self.listener_model.protocol_handlers:
self._logger.debug(
f"Got status update({data_protocol.name}) before get_status was called."
)
return
value = self.listener_model.cache[CacheableAttribute.status].value
value[data_protocol.name] = data_point
status = self._status_type.from_dict(value)
for listener in self.listener_model.protocol_handlers.get(data_protocol, []):
listener(status)
elif data_protocol in ROBOROCK_DATA_CONSUMABLE_PROTOCOL:
if data_protocol not in self.listener_model.protocol_handlers:
self._logger.debug(
f"Got consumable update({data_protocol.name})"
+ "before get_consumable was called."
)
return
value = self.listener_model.cache[CacheableAttribute.consumable].value
value[data_protocol.name] = data_point
consumable = Consumable.from_dict(value)
for listener in self.listener_model.protocol_handlers.get(data_protocol, []):
listener(consumable)
return
except ValueError:
self._logger.warning(
f"Got listener data for {data_point_number}, data: {data_point}. "
f"This lets us update data quicker, please open an issue "
f"at https://github.com/humbertogontijo/python-roborock/issues"
)

pass
dps = {data_point_number: data_point}
self._logger.debug(f"Got unknown data point {dps}")
elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE:
payload = data.payload[0:24]
[endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload)
if endpoint.decode().startswith(self._endpoint):
try:
decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce)
except ValueError as err:
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
decompressed = Utils.decompress(decrypted)
queue = self._waiting_queue.get(request_id)
if queue:
if isinstance(decompressed, list):
decompressed = decompressed[0]
queue.resolve((decompressed, None))
else:
queue = self._waiting_queue.get(data.seq)
if queue:
queue.resolve((data.payload, None))
except Exception as ex:
self._logger.exception(ex)
raise NotImplementedError

def on_connection_lost(self, exc: Exception | None) -> None:
self._last_disconnection = self.time_func()
Expand Down Expand Up @@ -320,47 +134,3 @@ async def _send_command(
params: list | dict | int | None = None,
):
raise NotImplementedError

@final
async def send_command(
self,
method: RoborockCommand | str,
params: list | dict | int | None = None,
return_type: type[RT] | None = None,
) -> RT:
cacheable_attribute_result = find_cacheable_attribute(method)

cache = None
command_type = None
if cacheable_attribute_result is not None:
cache = self.cache[cacheable_attribute_result.attribute]
command_type = cacheable_attribute_result.type

response: Any = None
if cache is not None and command_type == CommandType.GET:
response = await cache.async_value()
else:
response = await self._send_command(method, params)
if cache is not None and command_type == CommandType.CHANGE:
await cache.refresh_value()

if return_type:
return return_type.from_dict(response)
return response

def add_listener(
self, protocol: RoborockDataProtocol, listener: Callable, cache: dict[CacheableAttribute, AttributeCache]
) -> None:
self.listener_model.cache = cache
if protocol not in self.listener_model.protocol_handlers:
self.listener_model.protocol_handlers[protocol] = []
self.listener_model.protocol_handlers[protocol].append(listener)

def remove_listener(self, protocol: RoborockDataProtocol, listener: Callable) -> None:
self.listener_model.protocol_handlers[protocol].remove(listener)

async def get_from_cache(self, key: CacheableAttribute) -> AttributeCache | None:
val = self.cache.get(key)
if val is not None:
return await val.async_value()
return None
9 changes: 6 additions & 3 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@
import base64
import logging
import threading
import typing
import uuid
from asyncio import Lock, Task
from typing import Any
from urllib.parse import urlparse

import paho.mqtt.client as mqtt

from .api import KEEPALIVE, RoborockClient, md5hex
from .api import KEEPALIVE, RoborockClient
from .containers import DeviceData, UserData
from .exceptions import RoborockException, VacuumError
from .protocol import MessageParser, Utils
from .protocol import MessageParser, Utils, md5hex
from .roborock_future import RoborockFuture
from .roborock_message import RoborockMessage
from .roborock_typing import RoborockCommand
from .util import RoborockLoggerAdapter

if typing.TYPE_CHECKING:
pass
_LOGGER = logging.getLogger(__name__)
CONNECT_REQUEST_ID = 0
DISCONNECT_REQUEST_ID = 1
Expand Down Expand Up @@ -78,7 +81,7 @@ def on_connect(self, *args, **kwargs):
connection_queue.resolve((True, None))

def on_message(self, *args, **kwargs):
_, __, msg = args
client, __, msg = args
try:
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
super().on_message_received(messages)
Expand Down
Loading

0 comments on commit 30d2577

Please sign in to comment.