Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/_generate_requests_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import importlib.util
import inspect
import os
from typing import Any, Optional, Type
from typing import Any

from pydantic import BaseModel

Expand All @@ -23,7 +23,7 @@ def __init__(self, name: str, lineno: int, col_offset: int, original_text: str):
self.lineno = lineno
self.col_offset = col_offset
self.original_text = original_text
self.docstring: Optional[str] = None
self.docstring: str | None = None

def add_docstring(self, docstring: str) -> None:
"""Add a docstring to the class."""
Expand All @@ -49,23 +49,23 @@ def format_docstring(docstring: str, indent: str) -> str:
return '\n'.join(indented_lines)


def get_docstring_from_parent(cls: type) -> Optional[str]:
def get_docstring_from_parent(cls: type) -> str | None:
"""Get the docstring from the parent class."""
for base in cls.__bases__:
if base.__doc__:
return base.__doc__
return None


def get_field_docstring(cls: Type[BaseModel], field_name: str) -> str:
def get_field_docstring(cls: type[BaseModel], field_name: str) -> str:
"""Get the docstring of a field from the class."""
for name, obj in inspect.getmembers(cls):
if name == field_name:
return obj.__doc__ or "No docstring provided."
return "No docstring found."


def format_type(annotation: Type[Any] | None) -> str:
def format_type(annotation: type[Any] | None) -> str:
"""Format the type to show module and class name."""
if annotation is None:
raise ValueError("Annotation cannot be None")
Expand All @@ -79,7 +79,7 @@ def format_type(annotation: Type[Any] | None) -> str:
return str(annotation) # Fallback for other types


def get_pydantic_fields(cls: Type[BaseModel]) -> str:
def get_pydantic_fields(cls: type[BaseModel]) -> str:
"""Get the fields of a Pydantic model and format them as Google-style Args."""
if not issubclass(cls, BaseModel):
return ""
Expand Down Expand Up @@ -109,7 +109,7 @@ def get_pydantic_fields(cls: Type[BaseModel]) -> str:

def parse_file(file_path: str) -> list[ClassInfo]:
"""Parse the file and extract class definitions."""
with open(file_path, 'r') as file:
with open(file_path) as file:
lines = file.readlines()
tree = ast.parse(''.join(lines))

Expand Down Expand Up @@ -146,7 +146,7 @@ def update_class_docstrings(file_path: str) -> None:
full_docstring = parent_docstring + args_section
class_info.add_docstring(full_docstring)

with open(file_path, 'r', encoding="utf-8") as file:
with open(file_path, encoding="utf-8") as file:
lines = file.readlines()

updated_lines = []
Expand Down
2 changes: 1 addition & 1 deletion examples/usb/download_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def main() -> None:
file_data = await client.download_file(file_location)
end_s = time.time()
duration = end_s - start_s
speed = round(len(file_data) / ((duration)) / 1000, 2)
speed = round(len(file_data) / duration / 1000, 2)

print(f"Speed {speed} KB/s")

Expand Down
5 changes: 3 additions & 2 deletions smpclient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import asyncio
import logging
import traceback
from collections.abc import AsyncIterator
from hashlib import sha256
from types import TracebackType
from typing import AsyncIterator, Final, Type
from typing import Final

from pydantic import ValidationError
from smp import header as smpheader
Expand Down Expand Up @@ -414,7 +415,7 @@ async def __aenter__(self) -> "SMPClient":

async def __aexit__(
self,
exc_type: Type[BaseException] | None,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
tb: TracebackType | None,
) -> None:
Expand Down
3 changes: 2 additions & 1 deletion smpclient/extensions/intercreate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Intercreate extensions of the `SMPClient`."""

from typing import AsyncIterator, Final
from collections.abc import AsyncIterator
from typing import Final

from smp import header as smpheader

Expand Down
8 changes: 4 additions & 4 deletions smpclient/generics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generics and Type Narrowing for SMP Requests and Responses."""

from typing import Protocol, Type, TypeVar, Union
from typing import Protocol, TypeVar, Union

from smp import error as smperror
from smp import header as smphdr
Expand Down Expand Up @@ -32,9 +32,9 @@ class ImageStatesRead(smpimg.ImageStatesReadRequest):
```
"""

_Response: Type[TRep]
_ErrorV1: Type[TEr1]
_ErrorV2: Type[TEr2]
_Response: type[TRep]
_ErrorV1: type[TEr1]
_ErrorV2: type[TEr2]

@property
def BYTES(self) -> bytes: # pragma: no cover
Expand Down
4 changes: 2 additions & 2 deletions smpclient/transport/ble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import logging
import re
import sys
from typing import Final, Protocol
from typing import Final, Protocol, TypeGuard
from uuid import UUID

from bleak import BleakClient, BleakGATTCharacteristic, BleakScanner
from bleak.args.winrt import WinRTClientArgs
from bleak.backends.client import BaseBleakClient
from bleak.backends.device import BLEDevice
from smp import header as smphdr
from typing_extensions import TypeGuard, override
from typing_extensions import override

from smpclient.exceptions import SMPClientException
from smpclient.transport import SMPTransport, SMPTransportDisconnected
Expand Down
2 changes: 1 addition & 1 deletion smpclient/transport/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def read_serial(self, delimiter: bytes | None = None) -> bytes:
try:
first_match, remaining_data = self._serial_buffer.split(delimiter, 1)
except ValueError:
return bytes()
return b''
self._serial_buffer = remaining_data
return bytes(first_match)

Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/analyze-mcuboot-img.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def main():
print('Initial image bytes:')
start = img_header.hdr_size
end = start + min(20, img_header.img_size)
print('\t' + ' '.join('{:02x}'.format(b) for b in contents[start:end]))
print('\t' + ' '.join(f'{b:02x}' for b in contents[start:end]))

tlv_info_offset = img_header.hdr_size + img_header.img_size
tlv_info = TLVInfo(*struct.unpack_from(TLV_INFO_FMT, contents, offset=tlv_info_offset))
Expand All @@ -117,11 +117,11 @@ def main():
tlv_num = 0
while tlv_off < tlv_end:
tlv_hdr = TLVHeader(*struct.unpack_from(TLV_HDR_FMT, contents, offset=tlv_off))
print('TLV {}:'.format(tlv_num), tlv_hdr)
print(f'TLV {tlv_num}:', tlv_hdr)
if tlv_hdr.len <= 32:
start = tlv_off + TLV_HDR_SIZE
end = start + tlv_hdr.len
print('\t' + ' '.join('{:02x}'.format(b) for b in contents[start:end]))
print('\t' + ' '.join(f'{b:02x}' for b in contents[start:end]))
tlv_off += TLV_HDR_SIZE + tlv_hdr.len
tlv_num += 1

Expand Down
8 changes: 3 additions & 5 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Test the `SMPRequest` `Protocol` implementations."""

from typing import Type

import pytest
from smp import enumeration_management as smpem
from smp import error as smperr
Expand Down Expand Up @@ -297,9 +295,9 @@ def test_requests(
test_tuple: tuple[
smpmsg.Request,
SMPRequest[TRep, TEr1, TEr2],
Type[smpmsg.Response],
Type[smperr.ErrorV1],
Type[smperr.ErrorV2],
type[smpmsg.Response],
type[smperr.ErrorV1],
type[smperr.ErrorV2],
],
) -> None:
a, b, Response, ErrorV1, ErrorV2 = test_tuple
Expand Down
18 changes: 0 additions & 18 deletions tests/test_smp_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tests for `SMPClient`."""

import sys
from hashlib import sha256
from pathlib import Path
from unittest.mock import AsyncMock, PropertyMock, call, patch
Expand Down Expand Up @@ -39,23 +38,6 @@
from smpclient.requests.os_management import ResetWrite
from smpclient.transport.serial import SMPSerialTransport

if sys.version_info < (3, 10):
from typing import Any

async def anext(iterator: Any, default: Any = None) -> Any:
try:
return await iterator.__anext__()
except StopAsyncIteration:
if default is None:
raise
return default

def aiter(iterable: Any) -> Any:
if hasattr(iterable, '__aiter__'):
return iterable.__aiter__()
else:
raise TypeError(f"{iterable} is not async iterable")


class SMPMockTransport:
"""Satisfies the `SMPTransport` `Protocol`."""
Expand Down
3 changes: 2 additions & 1 deletion tests/test_smp_serial_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from __future__ import annotations

import asyncio
from typing import Any, Generator
from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

import pytest
Expand Down