Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/infrahub_sdk/ @opsmill/backend
/tests/ @opsmill/backend
uv.lock @opsmill/backend
pyproject.toml @opsmill/backend
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.10
rev: v0.11.9
hooks:
# Run the linter.
- id: ruff
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ async def report(
git_files_changed = await check_git_files_changed(client, branch=branch_name)

proposed_changes = await client.filters(
kind=CoreProposedChange,
kind=CoreProposedChange, # type: ignore[type-abstract]
source_branch__value=branch_name,
include=["created_by"],
prefetch_relationships=True,
Expand Down
9 changes: 6 additions & 3 deletions infrahub_sdk/graphql/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ class FutureAnnotationPlugin(Plugin):
def insert_future_annotation(module: ast.Module) -> ast.Module:
# First check if the future annotation is already present
for item in module.body:
if isinstance(item, ast.ImportFrom) and item.module == "__future__":
if any(alias.name == "annotations" for alias in item.names):
return module
if (
isinstance(item, ast.ImportFrom)
and item.module == "__future__"
and any(alias.name == "annotations" for alias in item.names)
):
return module

module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0))
return module
Expand Down
3 changes: 1 addition & 2 deletions infrahub_sdk/graphql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def render(self, convert_enum: bool = False) -> str:
convert_enum=convert_enum,
)
)
lines.append(" " * self.indentation + "}")
lines.append("}")
lines.extend((" " * self.indentation + "}", "}"))

return "\n" + "\n".join(lines) + "\n"
29 changes: 18 additions & 11 deletions infrahub_sdk/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,12 @@ def _strip_unmodified_dict(data: dict, original_data: dict, variables: dict, ite
if item in original_data and isinstance(original_data[item], dict) and isinstance(data_item, dict):
for item_key in original_data[item]:
for property_name in PROPERTIES_OBJECT:
if item_key == property_name and isinstance(original_data[item][property_name], dict):
if original_data[item][property_name].get("id"):
original_data[item][property_name] = original_data[item][property_name]["id"]
if (
item_key == property_name
and isinstance(original_data[item][property_name], dict)
and original_data[item][property_name].get("id")
):
original_data[item][property_name] = original_data[item][property_name]["id"]
if item_key in data[item]:
if item_key == "id" and len(data[item].keys()) > 1:
# Related nodes typically require an ID. So the ID is only
Expand Down Expand Up @@ -355,19 +358,23 @@ def _strip_unmodified(self, data: dict, variables: dict) -> tuple[dict, dict]:
relationship_property = getattr(self, relationship)
if not relationship_property or relationship not in data:
continue
if not relationship_property.initialized and (
not isinstance(relationship_property, RelatedNodeBase) or not relationship_property.schema.optional
):
data.pop(relationship)
elif isinstance(relationship_property, RelationshipManagerBase) and not relationship_property.has_update:
if (
not relationship_property.initialized
and (
not isinstance(relationship_property, RelatedNodeBase) or not relationship_property.schema.optional
)
) or (isinstance(relationship_property, RelationshipManagerBase) and not relationship_property.has_update):
data.pop(relationship)

for item in original_data:
if item in data:
if data[item] == original_data[item]:
if attr := getattr(self, item, None): # this should never be None, just a safety default value
if not isinstance(attr, Attribute) or not attr.value_has_been_mutated:
data.pop(item)
if (
attr := getattr(self, item, None)
) and ( # this should never be None, just a safety default value
not isinstance(attr, Attribute) or not attr.value_has_been_mutated
):
data.pop(item)
continue
if isinstance(original_data[item], dict):
self._strip_unmodified_dict(data=data, original_data=original_data, variables=variables, item=item)
Expand Down
12 changes: 5 additions & 7 deletions infrahub_sdk/node/related_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any

from ..exceptions import Error
from ..protocols_base import CoreNodeBase
Expand All @@ -11,7 +11,7 @@
if TYPE_CHECKING:
from ..client import InfrahubClient, InfrahubClientSync
from ..schema import RelationshipSchemaAPI
from .node import InfrahubNode, InfrahubNodeBase, InfrahubNodeSync
from .node import InfrahubNode, InfrahubNodeSync


class RelatedNodeBase:
Expand All @@ -34,7 +34,7 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._properties_object = PROPERTIES_OBJECT
self._properties = self._properties_flag + self._properties_object

self._peer: InfrahubNodeBase | CoreNodeBase | None = None
self._peer = None
self._id: str | None = None
self._hfid: list[str] | None = None
self._display_label: str | None = None
Expand All @@ -43,10 +43,8 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._source_typename: str | None = None
self._relationship_metadata: RelationshipMetadata | None = None

# Check for InfrahubNodeBase instances using duck-typing (_schema attribute)
# to avoid circular imports, or CoreNodeBase instances
if isinstance(data, CoreNodeBase) or hasattr(data, "_schema"):
self._peer = cast("InfrahubNodeBase | CoreNodeBase", data)
if isinstance(data, (CoreNodeBase)):
self._peer = data
for prop in self._properties:
setattr(self, prop, None)
self._relationship_metadata = None
Expand Down
4 changes: 2 additions & 2 deletions infrahub_sdk/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ async def process_nodes(self, data: dict) -> None:

await self._init_client.schema.all(branch=self.branch_name)

for kind in data:
for kind, kind_data in data.items():
if kind in self._init_client.schema.cache[self.branch_name].nodes:
for result in data[kind].get("edges", []):
for result in kind_data.get("edges", []):
node = await self.infrahub_node.from_graphql(
client=self._init_client, branch=self.branch_name, data=result
)
Expand Down
27 changes: 12 additions & 15 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ class AnyAttributeOptional(Attribute):
value: float | None


class CoreNodeBase:
@runtime_checkable
class CoreNodeBase(Protocol):
_schema: MainSchemaTypes
_internal_id: str
id: str # NOTE this is incorrect, should be str | None
Expand All @@ -188,28 +189,23 @@ def get_human_friendly_id(self) -> list[str] | None: ...

def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | None: ...

def get_kind(self) -> str:
raise NotImplementedError()
def get_kind(self) -> str: ...

def get_all_kinds(self) -> list[str]:
raise NotImplementedError()
def get_all_kinds(self) -> list[str]: ...

def get_branch(self) -> str:
raise NotImplementedError()
def get_branch(self) -> str: ...

def is_ip_prefix(self) -> bool:
raise NotImplementedError()
def is_ip_prefix(self) -> bool: ...

def is_ip_address(self) -> bool:
raise NotImplementedError()
def is_ip_address(self) -> bool: ...

def is_resource_pool(self) -> bool:
raise NotImplementedError()
def is_resource_pool(self) -> bool: ...

def get_raw_graphql_data(self) -> dict | None: ...


class CoreNode(CoreNodeBase):
@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
async def save(
self,
allow_upsert: bool = False,
Expand All @@ -233,7 +229,8 @@ async def add_relationships(self, relation_to_update: str, related_nodes: list[s
async def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...


class CoreNodeSync(CoreNodeBase):
@runtime_checkable
class CoreNodeSync(CoreNodeBase, Protocol):
def save(
self,
allow_upsert: bool = False,
Expand Down
6 changes: 2 additions & 4 deletions infrahub_sdk/pytest_plugin/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from collections.abc import Iterable
from typing import Any

Expand Down Expand Up @@ -60,11 +61,8 @@ def get_resource_config(self, group: InfrahubTestGroup) -> Any | None:
resource_config = None
if resource_config_function is not None:
func = getattr(self.session.infrahub_repo_config, resource_config_function) # type:ignore[attr-defined]
try:
with contextlib.suppress(KeyError):
resource_config = func(group.resource_name)
except KeyError:
# Ignore error and just return None
pass

return resource_config

Expand Down
5 changes: 2 additions & 3 deletions infrahub_sdk/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
ValidationError,
)
from ..graphql import Mutation
from ..protocols_base import CoreNodeBase
from ..queries import SCHEMA_HASH_SYNC_STATUS
from .main import (
AttributeSchema,
Expand Down Expand Up @@ -208,14 +207,14 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
if isinstance(schema, str):
return schema

if issubclass(schema, CoreNodeBase):
if hasattr(schema, "_is_runtime_protocol") and getattr(schema, "_is_runtime_protocol", None):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

raise ValueError("schema must be a CoreNode subclass or a string")
raise ValueError("schema must be a protocol or a string")

@staticmethod
def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]:
Expand Down
22 changes: 9 additions & 13 deletions infrahub_sdk/spec/object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from enum import Enum
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -128,12 +129,10 @@ def find_matching_relationship(
if self.peer_rel and not force:
return self.peer_rel

try:
with contextlib.suppress(ValueError):
self.peer_rel = peer_schema.get_matching_relationship(
id=self.rel_schema.identifier or "", direction=self.rel_schema.direction
)
except ValueError:
pass

return self.peer_rel

Expand All @@ -158,12 +157,10 @@ async def get_relationship_info(
peer_schema = await client.schema.get(kind=info.peer_kind, branch=branch)
info.peer_human_friendly_id = peer_schema.human_friendly_id

try:
with contextlib.suppress(ValueError):
info.peer_rel = peer_schema.get_matching_relationship(
id=rel_schema.identifier or "", direction=rel_schema.direction
)
except ValueError:
pass

if rel_schema.cardinality == "one" and isinstance(value, list):
# validate the list is composed of string
Expand Down Expand Up @@ -281,14 +278,13 @@ async def validate_object(
)
)

if key in schema.attribute_names:
if not isinstance(value, (str, int, float, bool, list, dict)):
errors.append(
ObjectValidationError(
position=position + [key],
message=f"{key} must be a string, int, float, bool, list, or dict",
)
if key in schema.attribute_names and not isinstance(value, (str, int, float, bool, list, dict)):
errors.append(
ObjectValidationError(
position=position + [key],
message=f"{key} must be a string, int, float, bool, list, or dict",
)
)

if key in schema.relationship_names:
rel_info = await get_relationship_info(
Expand Down
3 changes: 1 addition & 2 deletions infrahub_sdk/spec/range_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def _extract_constants(pattern: str, re_compiled: re.Pattern) -> tuple[list[int]
cartesian_list = []
interface_constant = [0]
for match in re_compiled.finditer(pattern):
interface_constant.append(match.start())
interface_constant.append(match.end())
interface_constant.extend((match.start(), match.end()))
cartesian_list.append(_char_range_expand(match.group()[1:-1]))
return interface_constant, cartesian_list

Expand Down
4 changes: 1 addition & 3 deletions infrahub_sdk/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ def _get_by_id(self, id: str, kind: str | None = None) -> InfrahubNode | Infrahu
def _get_by_hfid(
self, hfid: str | list[str], kind: str | None = None
) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync:
if not kind:
node_kind, node_hfid = parse_human_friendly_id(hfid)
elif kind and isinstance(hfid, str) and hfid.startswith(kind):
if not kind or (kind and isinstance(hfid, str) and hfid.startswith(kind)):
node_kind, node_hfid = parse_human_friendly_id(hfid)
else:
node_kind = kind
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/testing/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def wait_for_sync_to_complete(
) -> bool:
for _ in range(retries):
repo = await client.get(
kind=CoreGenericRepository,
kind=CoreGenericRepository, # type: ignore[type-abstract]
name__value=self.name,
branch=branch or self.initial_branch,
)
Expand Down
7 changes: 3 additions & 4 deletions infrahub_sdk/transfer/importer/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,9 @@ async def remove_and_store_optional_relationships(self) -> None:
if relationship_value.peer_ids:
self.optional_relationships_by_node[node.id][relationship_name] = relationship_value
setattr(node, relationship_name, None)
elif isinstance(relationship_value, RelatedNode):
if relationship_value.id:
self.optional_relationships_by_node[node.id][relationship_name] = relationship_value
setattr(node, relationship_name, None)
elif isinstance(relationship_value, RelatedNode) and relationship_value.id:
self.optional_relationships_by_node[node.id][relationship_name] = relationship_value
setattr(node, relationship_name, None)

async def update_optional_relationships(self) -> None:
update_batch = await self.client.create_batch(return_exceptions=True)
Expand Down
13 changes: 6 additions & 7 deletions infrahub_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ def deep_merge_dict(dicta: dict, dictb: dict, path: list | None = None) -> dict:
"""
if path is None:
path = []
for key in dictb:
b_val = dictb[key]
for key, b_val in dictb.items():
if key in dicta:
a_val = dicta[key]
if isinstance(a_val, dict) and isinstance(b_val, dict):
Expand All @@ -158,11 +157,11 @@ def deep_merge_dict(dicta: dict, dictb: dict, path: list | None = None) -> dict:
else:
raise ValueError("Conflict at %s" % ".".join(path + [str(key)]))
else:
dicta[key] = dictb[key]
dicta[key] = b_val
return dicta


def str_to_bool(value: str) -> bool:
def str_to_bool(value: str | bool | int) -> bool:
"""Convert a String to a Boolean"""

if isinstance(value, bool):
Expand Down Expand Up @@ -254,9 +253,9 @@ def calculate_dict_depth(data: dict, level: int = 1) -> int:

def calculate_dict_height(data: dict, cnt: int = 0) -> int:
"""Calculate the number of fields (height) in a nested Dictionary recursively."""
for key in data:
if isinstance(data[key], dict):
cnt = calculate_dict_height(data=data[key], cnt=cnt + 1)
for value in data.values():
if isinstance(value, dict):
cnt = calculate_dict_height(data=value, cnt=cnt + 1)
else:
cnt += 1
return cnt
Expand Down
Loading