diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..e929f165 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,4 @@ +/infrahub_sdk/ @opsmill/backend +/tests/ @opsmill/backend +uv.lock @opsmill/backend +pyproject.toml @opsmill/backend diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c393efd..537c7969 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.14.10 + rev: v0.11.9 hooks: # Run the linter. - id: ruff diff --git a/infrahub_sdk/ctl/branch.py b/infrahub_sdk/ctl/branch.py index d169cb91..60d67e86 100644 --- a/infrahub_sdk/ctl/branch.py +++ b/infrahub_sdk/ctl/branch.py @@ -293,7 +293,7 @@ async def report( git_files_changed = await check_git_files_changed(client, branch=branch_name) proposed_changes = await client.filters( - kind=CoreProposedChange, + kind=CoreProposedChange, # type: ignore[type-abstract] source_branch__value=branch_name, include=["created_by"], prefetch_relationships=True, diff --git a/infrahub_sdk/graphql/plugin.py b/infrahub_sdk/graphql/plugin.py index d00b32f0..fc5442e4 100644 --- a/infrahub_sdk/graphql/plugin.py +++ b/infrahub_sdk/graphql/plugin.py @@ -14,9 +14,12 @@ class FutureAnnotationPlugin(Plugin): def insert_future_annotation(module: ast.Module) -> ast.Module: # First check if the future annotation is already present for item in module.body: - if isinstance(item, ast.ImportFrom) and item.module == "__future__": - if any(alias.name == "annotations" for alias in item.names): - return module + if ( + isinstance(item, ast.ImportFrom) + and item.module == "__future__" + and any(alias.name == "annotations" for alias in item.names) + ): + return module module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0)) return module diff --git a/infrahub_sdk/graphql/query.py b/infrahub_sdk/graphql/query.py index 54078b76..36a61b2b 100644 --- a/infrahub_sdk/graphql/query.py +++ b/infrahub_sdk/graphql/query.py @@ -71,7 +71,6 @@ def render(self, convert_enum: bool = False) -> str: convert_enum=convert_enum, ) ) - lines.append(" " * self.indentation + "}") - lines.append("}") + lines.extend((" " * self.indentation + "}", "}")) return "\n" + "\n".join(lines) + "\n" diff --git a/infrahub_sdk/node/node.py b/infrahub_sdk/node/node.py index 74f68063..25d9d191 100644 --- a/infrahub_sdk/node/node.py +++ b/infrahub_sdk/node/node.py @@ -313,9 +313,12 @@ def _strip_unmodified_dict(data: dict, original_data: dict, variables: dict, ite if item in original_data and isinstance(original_data[item], dict) and isinstance(data_item, dict): for item_key in original_data[item]: for property_name in PROPERTIES_OBJECT: - if item_key == property_name and isinstance(original_data[item][property_name], dict): - if original_data[item][property_name].get("id"): - original_data[item][property_name] = original_data[item][property_name]["id"] + if ( + item_key == property_name + and isinstance(original_data[item][property_name], dict) + and original_data[item][property_name].get("id") + ): + original_data[item][property_name] = original_data[item][property_name]["id"] if item_key in data[item]: if item_key == "id" and len(data[item].keys()) > 1: # Related nodes typically require an ID. So the ID is only @@ -355,19 +358,23 @@ def _strip_unmodified(self, data: dict, variables: dict) -> tuple[dict, dict]: relationship_property = getattr(self, relationship) if not relationship_property or relationship not in data: continue - if not relationship_property.initialized and ( - not isinstance(relationship_property, RelatedNodeBase) or not relationship_property.schema.optional - ): - data.pop(relationship) - elif isinstance(relationship_property, RelationshipManagerBase) and not relationship_property.has_update: + if ( + not relationship_property.initialized + and ( + not isinstance(relationship_property, RelatedNodeBase) or not relationship_property.schema.optional + ) + ) or (isinstance(relationship_property, RelationshipManagerBase) and not relationship_property.has_update): data.pop(relationship) for item in original_data: if item in data: if data[item] == original_data[item]: - if attr := getattr(self, item, None): # this should never be None, just a safety default value - if not isinstance(attr, Attribute) or not attr.value_has_been_mutated: - data.pop(item) + if ( + attr := getattr(self, item, None) + ) and ( # this should never be None, just a safety default value + not isinstance(attr, Attribute) or not attr.value_has_been_mutated + ): + data.pop(item) continue if isinstance(original_data[item], dict): self._strip_unmodified_dict(data=data, original_data=original_data, variables=variables, item=item) diff --git a/infrahub_sdk/node/related_node.py b/infrahub_sdk/node/related_node.py index 67171f99..5b46a8f7 100644 --- a/infrahub_sdk/node/related_node.py +++ b/infrahub_sdk/node/related_node.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from ..exceptions import Error from ..protocols_base import CoreNodeBase @@ -11,7 +11,7 @@ if TYPE_CHECKING: from ..client import InfrahubClient, InfrahubClientSync from ..schema import RelationshipSchemaAPI - from .node import InfrahubNode, InfrahubNodeBase, InfrahubNodeSync + from .node import InfrahubNode, InfrahubNodeSync class RelatedNodeBase: @@ -34,7 +34,7 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict, self._properties_object = PROPERTIES_OBJECT self._properties = self._properties_flag + self._properties_object - self._peer: InfrahubNodeBase | CoreNodeBase | None = None + self._peer = None self._id: str | None = None self._hfid: list[str] | None = None self._display_label: str | None = None @@ -43,10 +43,8 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict, self._source_typename: str | None = None self._relationship_metadata: RelationshipMetadata | None = None - # Check for InfrahubNodeBase instances using duck-typing (_schema attribute) - # to avoid circular imports, or CoreNodeBase instances - if isinstance(data, CoreNodeBase) or hasattr(data, "_schema"): - self._peer = cast("InfrahubNodeBase | CoreNodeBase", data) + if isinstance(data, (CoreNodeBase)): + self._peer = data for prop in self._properties: setattr(self, prop, None) self._relationship_metadata = None diff --git a/infrahub_sdk/operation.py b/infrahub_sdk/operation.py index ed0bf19a..7983fac5 100644 --- a/infrahub_sdk/operation.py +++ b/infrahub_sdk/operation.py @@ -64,9 +64,9 @@ async def process_nodes(self, data: dict) -> None: await self._init_client.schema.all(branch=self.branch_name) - for kind in data: + for kind, kind_data in data.items(): if kind in self._init_client.schema.cache[self.branch_name].nodes: - for result in data[kind].get("edges", []): + for result in kind_data.get("edges", []): node = await self.infrahub_node.from_graphql( client=self._init_client, branch=self.branch_name, data=result ) diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py index ba920552..8a841b5b 100644 --- a/infrahub_sdk/protocols_base.py +++ b/infrahub_sdk/protocols_base.py @@ -171,7 +171,8 @@ class AnyAttributeOptional(Attribute): value: float | None -class CoreNodeBase: +@runtime_checkable +class CoreNodeBase(Protocol): _schema: MainSchemaTypes _internal_id: str id: str # NOTE this is incorrect, should be str | None @@ -188,28 +189,23 @@ def get_human_friendly_id(self) -> list[str] | None: ... def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | None: ... - def get_kind(self) -> str: - raise NotImplementedError() + def get_kind(self) -> str: ... - def get_all_kinds(self) -> list[str]: - raise NotImplementedError() + def get_all_kinds(self) -> list[str]: ... - def get_branch(self) -> str: - raise NotImplementedError() + def get_branch(self) -> str: ... - def is_ip_prefix(self) -> bool: - raise NotImplementedError() + def is_ip_prefix(self) -> bool: ... - def is_ip_address(self) -> bool: - raise NotImplementedError() + def is_ip_address(self) -> bool: ... - def is_resource_pool(self) -> bool: - raise NotImplementedError() + def is_resource_pool(self) -> bool: ... def get_raw_graphql_data(self) -> dict | None: ... -class CoreNode(CoreNodeBase): +@runtime_checkable +class CoreNode(CoreNodeBase, Protocol): async def save( self, allow_upsert: bool = False, @@ -233,7 +229,8 @@ async def add_relationships(self, relation_to_update: str, related_nodes: list[s async def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... -class CoreNodeSync(CoreNodeBase): +@runtime_checkable +class CoreNodeSync(CoreNodeBase, Protocol): def save( self, allow_upsert: bool = False, diff --git a/infrahub_sdk/pytest_plugin/loader.py b/infrahub_sdk/pytest_plugin/loader.py index 33040293..590a4ed7 100644 --- a/infrahub_sdk/pytest_plugin/loader.py +++ b/infrahub_sdk/pytest_plugin/loader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from collections.abc import Iterable from typing import Any @@ -60,11 +61,8 @@ def get_resource_config(self, group: InfrahubTestGroup) -> Any | None: resource_config = None if resource_config_function is not None: func = getattr(self.session.infrahub_repo_config, resource_config_function) # type:ignore[attr-defined] - try: + with contextlib.suppress(KeyError): resource_config = func(group.resource_name) - except KeyError: - # Ignore error and just return None - pass return resource_config diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index febb204b..3e61ad2a 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -21,7 +21,6 @@ ValidationError, ) from ..graphql import Mutation -from ..protocols_base import CoreNodeBase from ..queries import SCHEMA_HASH_SYNC_STATUS from .main import ( AttributeSchema, @@ -208,14 +207,14 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str: if isinstance(schema, str): return schema - if issubclass(schema, CoreNodeBase): + if hasattr(schema, "_is_runtime_protocol") and getattr(schema, "_is_runtime_protocol", None): if inspect.iscoroutinefunction(schema.save): return schema.__name__ if schema.__name__[-4:] == "Sync": return schema.__name__[:-4] return schema.__name__ - raise ValueError("schema must be a CoreNode subclass or a string") + raise ValueError("schema must be a protocol or a string") @staticmethod def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]: diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py index 456c9f1e..e21c8e12 100644 --- a/infrahub_sdk/spec/object.py +++ b/infrahub_sdk/spec/object.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from enum import Enum from typing import TYPE_CHECKING, Any @@ -128,12 +129,10 @@ def find_matching_relationship( if self.peer_rel and not force: return self.peer_rel - try: + with contextlib.suppress(ValueError): self.peer_rel = peer_schema.get_matching_relationship( id=self.rel_schema.identifier or "", direction=self.rel_schema.direction ) - except ValueError: - pass return self.peer_rel @@ -158,12 +157,10 @@ async def get_relationship_info( peer_schema = await client.schema.get(kind=info.peer_kind, branch=branch) info.peer_human_friendly_id = peer_schema.human_friendly_id - try: + with contextlib.suppress(ValueError): info.peer_rel = peer_schema.get_matching_relationship( id=rel_schema.identifier or "", direction=rel_schema.direction ) - except ValueError: - pass if rel_schema.cardinality == "one" and isinstance(value, list): # validate the list is composed of string @@ -281,14 +278,13 @@ async def validate_object( ) ) - if key in schema.attribute_names: - if not isinstance(value, (str, int, float, bool, list, dict)): - errors.append( - ObjectValidationError( - position=position + [key], - message=f"{key} must be a string, int, float, bool, list, or dict", - ) + if key in schema.attribute_names and not isinstance(value, (str, int, float, bool, list, dict)): + errors.append( + ObjectValidationError( + position=position + [key], + message=f"{key} must be a string, int, float, bool, list, or dict", ) + ) if key in schema.relationship_names: rel_info = await get_relationship_info( diff --git a/infrahub_sdk/spec/range_expansion.py b/infrahub_sdk/spec/range_expansion.py index 99638b06..5a522f0d 100644 --- a/infrahub_sdk/spec/range_expansion.py +++ b/infrahub_sdk/spec/range_expansion.py @@ -60,8 +60,7 @@ def _extract_constants(pattern: str, re_compiled: re.Pattern) -> tuple[list[int] cartesian_list = [] interface_constant = [0] for match in re_compiled.finditer(pattern): - interface_constant.append(match.start()) - interface_constant.append(match.end()) + interface_constant.extend((match.start(), match.end())) cartesian_list.append(_char_range_expand(match.group()[1:-1])) return interface_constant, cartesian_list diff --git a/infrahub_sdk/store.py b/infrahub_sdk/store.py index 6420495b..479badd1 100644 --- a/infrahub_sdk/store.py +++ b/infrahub_sdk/store.py @@ -165,9 +165,7 @@ def _get_by_id(self, id: str, kind: str | None = None) -> InfrahubNode | Infrahu def _get_by_hfid( self, hfid: str | list[str], kind: str | None = None ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync: - if not kind: - node_kind, node_hfid = parse_human_friendly_id(hfid) - elif kind and isinstance(hfid, str) and hfid.startswith(kind): + if not kind or (kind and isinstance(hfid, str) and hfid.startswith(kind)): node_kind, node_hfid = parse_human_friendly_id(hfid) else: node_kind = kind diff --git a/infrahub_sdk/testing/repository.py b/infrahub_sdk/testing/repository.py index d07d2b1a..9e974164 100644 --- a/infrahub_sdk/testing/repository.py +++ b/infrahub_sdk/testing/repository.py @@ -98,7 +98,7 @@ async def wait_for_sync_to_complete( ) -> bool: for _ in range(retries): repo = await client.get( - kind=CoreGenericRepository, + kind=CoreGenericRepository, # type: ignore[type-abstract] name__value=self.name, branch=branch or self.initial_branch, ) diff --git a/infrahub_sdk/transfer/importer/json.py b/infrahub_sdk/transfer/importer/json.py index d9e6ad13..9c0b7ab9 100644 --- a/infrahub_sdk/transfer/importer/json.py +++ b/infrahub_sdk/transfer/importer/json.py @@ -115,10 +115,9 @@ async def remove_and_store_optional_relationships(self) -> None: if relationship_value.peer_ids: self.optional_relationships_by_node[node.id][relationship_name] = relationship_value setattr(node, relationship_name, None) - elif isinstance(relationship_value, RelatedNode): - if relationship_value.id: - self.optional_relationships_by_node[node.id][relationship_name] = relationship_value - setattr(node, relationship_name, None) + elif isinstance(relationship_value, RelatedNode) and relationship_value.id: + self.optional_relationships_by_node[node.id][relationship_name] = relationship_value + setattr(node, relationship_name, None) async def update_optional_relationships(self) -> None: update_batch = await self.client.create_batch(return_exceptions=True) diff --git a/infrahub_sdk/utils.py b/infrahub_sdk/utils.py index 8556b012..f62513d3 100644 --- a/infrahub_sdk/utils.py +++ b/infrahub_sdk/utils.py @@ -141,8 +141,7 @@ def deep_merge_dict(dicta: dict, dictb: dict, path: list | None = None) -> dict: """ if path is None: path = [] - for key in dictb: - b_val = dictb[key] + for key, b_val in dictb.items(): if key in dicta: a_val = dicta[key] if isinstance(a_val, dict) and isinstance(b_val, dict): @@ -158,11 +157,11 @@ def deep_merge_dict(dicta: dict, dictb: dict, path: list | None = None) -> dict: else: raise ValueError("Conflict at %s" % ".".join(path + [str(key)])) else: - dicta[key] = dictb[key] + dicta[key] = b_val return dicta -def str_to_bool(value: str) -> bool: +def str_to_bool(value: str | bool | int) -> bool: """Convert a String to a Boolean""" if isinstance(value, bool): @@ -254,9 +253,9 @@ def calculate_dict_depth(data: dict, level: int = 1) -> int: def calculate_dict_height(data: dict, cnt: int = 0) -> int: """Calculate the number of fields (height) in a nested Dictionary recursively.""" - for key in data: - if isinstance(data[key], dict): - cnt = calculate_dict_height(data=data[key], cnt=cnt + 1) + for value in data.values(): + if isinstance(value, dict): + cnt = calculate_dict_height(data=value, cnt=cnt + 1) else: cnt += 1 return cnt diff --git a/pyproject.toml b/pyproject.toml index 854fc39d..78a7a943 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -240,13 +240,11 @@ ignore = [ "B008", # Do not perform function call `typer.Option` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling "FURB110", # Replace ternary `if` expression with `or` operator - "FURB113", # Use `lines.extend((" " * self.indentation + "}", "}"))` instead of repeatedly calling `lines.append()` "INP001", # File declares a package, but is nested under an implicit namespace package. "N802", # Function name should be lowercase "N806", # Variable in function should be lowercase "PERF203", # `try`-`except` within a loop incurs performance overhead "PERF401", # Use a list comprehension to create a transformed list - "PLC0206", # Extracting value from dictionary without calling `.items()` "PLR0912", # Too many branches "PLR0913", # Too many arguments in function definition "PLR0917", # Too many positional arguments @@ -258,11 +256,8 @@ ignore = [ "RUF029", # Function is declared `async`, but doesn't `await` or use `async` features. "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "S701", # By default, jinja2 sets `autoescape` to `False`. Consider using `autoescape=True` - "SIM102", # Use a single `if` statement instead of nested `if` statements - "SIM105", # Use `contextlib.suppress(KeyError)` instead of `try`-`except`-`pass` "SIM108", # Use ternary operator `key_str = f"{value[ALIAS_KEY]}: {key}" if ALIAS_KEY in value and value[ALIAS_KEY] else key` instead of `if`-`else`-block "SIM110", # Use `return any(getattr(item, resource_field) == resource_id for item in getattr(self, RESOURCE_MAP[resource_type]))` instead of `for` loop - "SIM114", # Combine `if` branches using logical `or` operator "TC003", # Move standard library import `collections.abc.Iterable` into a type-checking block "UP031", # Use format specifiers instead of percent format ] diff --git a/tests/integration/test_repository.py b/tests/integration/test_repository.py index 942e3bc4..72300ae1 100644 --- a/tests/integration/test_repository.py +++ b/tests/integration/test_repository.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING +from dulwich.objects import Commit + from infrahub_sdk.testing.docker import TestInfrahubDockerClient from infrahub_sdk.testing.repository import GitRepo from infrahub_sdk.utils import get_fixtures_dir @@ -17,6 +19,7 @@ async def test_add_repository(self, client: InfrahubClient, remote_repos_dir: Pa src_directory = get_fixtures_dir() / "integration/mock_repo" repo = GitRepo(name="mock_repo", src_directory=src_directory, dst_directory=remote_repos_dir) commit = repo._repo.git[repo._repo.git.head()] + assert isinstance(commit, Commit) assert len(list(repo._repo.git.get_walker())) == 1 assert commit.message.decode("utf-8") == "First commit" diff --git a/tests/unit/ctl/test_graphql_app.py b/tests/unit/ctl/test_graphql_app.py index 285507ae..aca9ee95 100644 --- a/tests/unit/ctl/test_graphql_app.py +++ b/tests/unit/ctl/test_graphql_app.py @@ -6,6 +6,7 @@ import pytest from ariadne_codegen.schema import get_graphql_schema_from_path +from graphql import OperationDefinitionNode from typer.testing import CliRunner from infrahub_sdk.ctl.graphql import app, find_gql_files, get_graphql_query @@ -76,7 +77,10 @@ def test_get_graphql_query_valid(self) -> None: definitions = get_graphql_query(query_file, schema) assert len(definitions) == 1 - assert definitions[0].name.value == "GetTags" + definition = definitions[0] + assert isinstance(definition, OperationDefinitionNode) + assert definition.name is not None + assert definition.name.value == "GetTags" def test_get_graphql_query_invalid(self) -> None: """Test that invalid query raises ValueError.""" diff --git a/tests/unit/sdk/test_repository.py b/tests/unit/sdk/test_repository.py index a3c7f6eb..756b2e36 100644 --- a/tests/unit/sdk/test_repository.py +++ b/tests/unit/sdk/test_repository.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from dulwich.objects import Commit from dulwich.repo import Repo from infrahub_sdk.repository import GitRepoManager @@ -65,4 +66,5 @@ def test_gitrepo_init(temp_dir: str) -> None: repo = GitRepo(name="mock_repo", src_directory=src_directory, dst_directory=Path(temp_dir)) assert len(list(repo._repo.git.get_walker())) == 1 commit = repo._repo.git[repo._repo.git.head()] + assert isinstance(commit, Commit) assert commit.message.decode("utf-8") == "First commit" diff --git a/tests/unit/sdk/test_utils.py b/tests/unit/sdk/test_utils.py index eae23150..97ea6d27 100644 --- a/tests/unit/sdk/test_utils.py +++ b/tests/unit/sdk/test_utils.py @@ -7,7 +7,7 @@ from unittest.mock import Mock import pytest -from graphql import parse +from graphql import OperationDefinitionNode, parse from whenever import Instant from infrahub_sdk.exceptions import JsonDecodeError @@ -16,6 +16,7 @@ base16encode, base36decode, base36encode, + calculate_dict_height, calculate_time_diff, compare_lists, decode_json, @@ -119,6 +120,17 @@ def test_deep_merge_dict() -> None: assert deep_merge_dict(f, g) == {"keyA": "foo", "keyB": "bar"} +def test_calculate_dict_height() -> None: + assert calculate_dict_height({}) == 0 + assert calculate_dict_height({"a": 1}) == 1 + assert calculate_dict_height({"a": 1, "b": 2}) == 2 + assert calculate_dict_height({"a": 1, "b": 2, "c": 3}) == 3 + assert calculate_dict_height({"a": {"b": 1}}) == 2 + assert calculate_dict_height({"a": {"b": 1, "c": 2}}) == 3 + assert calculate_dict_height({"a": {"b": {"c": 1}}}) == 3 + assert calculate_dict_height({"a": 1, "b": {"c": 2, "d": {"e": 3}}}) == 5 + + def test_str_to_bool() -> None: assert str_to_bool(True) is True assert str_to_bool(False) is False @@ -179,7 +191,9 @@ async def test_extract_fields(query_01: str) -> None: }, }, } - assert await extract_fields(document.definitions[0].selection_set) == expected_response + definition = document.definitions[0] + assert isinstance(definition, OperationDefinitionNode) + assert await extract_fields(definition.selection_set) == expected_response async def test_extract_fields_fragment(query_02: str) -> None: @@ -207,7 +221,9 @@ async def test_extract_fields_fragment(query_02: str) -> None: }, } - assert await extract_fields(document.definitions[0].selection_set) == expected_response + definition = document.definitions[0] + assert isinstance(definition, OperationDefinitionNode) + assert await extract_fields(definition.selection_set) == expected_response def test_write_to_file() -> None: