Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.9
rev: v0.14.10
hooks:
# Run the linter.
- id: ruff
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ async def report(
git_files_changed = await check_git_files_changed(client, branch=branch_name)

proposed_changes = await client.filters(
kind=CoreProposedChange, # type: ignore[type-abstract]
kind=CoreProposedChange,
source_branch__value=branch_name,
include=["created_by"],
prefetch_relationships=True,
Expand Down
12 changes: 7 additions & 5 deletions infrahub_sdk/node/related_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from ..exceptions import Error
from ..protocols_base import CoreNodeBase
Expand All @@ -11,7 +11,7 @@
if TYPE_CHECKING:
from ..client import InfrahubClient, InfrahubClientSync
from ..schema import RelationshipSchemaAPI
from .node import InfrahubNode, InfrahubNodeSync
from .node import InfrahubNode, InfrahubNodeBase, InfrahubNodeSync


class RelatedNodeBase:
Expand All @@ -34,7 +34,7 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._properties_object = PROPERTIES_OBJECT
self._properties = self._properties_flag + self._properties_object

self._peer = None
self._peer: InfrahubNodeBase | CoreNodeBase | None = None
self._id: str | None = None
self._hfid: list[str] | None = None
self._display_label: str | None = None
Expand All @@ -43,8 +43,10 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._source_typename: str | None = None
self._relationship_metadata: RelationshipMetadata | None = None

if isinstance(data, (CoreNodeBase)):
self._peer = data
# Check for InfrahubNodeBase instances using duck-typing (_schema attribute)
# to avoid circular imports, or CoreNodeBase instances
if isinstance(data, CoreNodeBase) or hasattr(data, "_schema"):
self._peer = cast("InfrahubNodeBase | CoreNodeBase", data)
for prop in self._properties:
setattr(self, prop, None)
self._relationship_metadata = None
Expand Down
27 changes: 15 additions & 12 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ class AnyAttributeOptional(Attribute):
value: float | None


@runtime_checkable
class CoreNodeBase(Protocol):
class CoreNodeBase:
_schema: MainSchemaTypes
_internal_id: str
id: str # NOTE this is incorrect, should be str | None
Expand All @@ -189,23 +188,28 @@ def get_human_friendly_id(self) -> list[str] | None: ...

def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | None: ...

def get_kind(self) -> str: ...
def get_kind(self) -> str:
raise NotImplementedError()

def get_all_kinds(self) -> list[str]: ...
def get_all_kinds(self) -> list[str]:
raise NotImplementedError()

def get_branch(self) -> str: ...
def get_branch(self) -> str:
raise NotImplementedError()

def is_ip_prefix(self) -> bool: ...
def is_ip_prefix(self) -> bool:
raise NotImplementedError()

def is_ip_address(self) -> bool: ...
def is_ip_address(self) -> bool:
raise NotImplementedError()

def is_resource_pool(self) -> bool: ...
def is_resource_pool(self) -> bool:
raise NotImplementedError()

def get_raw_graphql_data(self) -> dict | None: ...


@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
class CoreNode(CoreNodeBase):
async def save(
self,
allow_upsert: bool = False,
Expand All @@ -229,8 +233,7 @@ async def add_relationships(self, relation_to_update: str, related_nodes: list[s
async def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...


@runtime_checkable
class CoreNodeSync(CoreNodeBase, Protocol):
class CoreNodeSync(CoreNodeBase):
def save(
self,
allow_upsert: bool = False,
Expand Down
5 changes: 3 additions & 2 deletions infrahub_sdk/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ValidationError,
)
from ..graphql import Mutation
from ..protocols_base import CoreNodeBase
from ..queries import SCHEMA_HASH_SYNC_STATUS
from .main import (
AttributeSchema,
Expand Down Expand Up @@ -207,14 +208,14 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
if isinstance(schema, str):
return schema

if hasattr(schema, "_is_runtime_protocol") and getattr(schema, "_is_runtime_protocol", None):
if issubclass(schema, CoreNodeBase):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

raise ValueError("schema must be a protocol or a string")
raise ValueError("schema must be a CoreNode subclass or a string")

@staticmethod
def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]:
Expand Down
14 changes: 12 additions & 2 deletions infrahub_sdk/store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import inspect
import warnings
from typing import TYPE_CHECKING, Literal, overload

from infrahub_sdk.protocols_base import CoreNodeBase

from .exceptions import NodeInvalidError, NodeNotFoundError
from .node.parsers import parse_human_friendly_id

Expand All @@ -16,8 +19,15 @@ def get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str | None = Non
if isinstance(schema, str):
return schema

if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr]
return schema.__name__ # type: ignore[union-attr]
if schema is None:
return None

if issubclass(schema, CoreNodeBase):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

return None

Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/testing/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def wait_for_sync_to_complete(
) -> bool:
for _ in range(retries):
repo = await client.get(
kind=CoreGenericRepository, # type: ignore[type-abstract]
kind=CoreGenericRepository,
name__value=self.name,
branch=branch or self.initial_branch,
)
Expand Down
8 changes: 7 additions & 1 deletion tests/unit/sdk/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
from infrahub_sdk.store import NodeStore, NodeStoreSync
from infrahub_sdk.protocols import BuiltinIPAddressSync, BuiltinIPPrefix
from infrahub_sdk.store import NodeStore, NodeStoreSync, get_schema_name

if TYPE_CHECKING:
from infrahub_sdk.schema import NodeSchemaAPI
Expand Down Expand Up @@ -157,3 +158,8 @@ def test_node_store_get_with_hfid(
store.get(kind="BuiltinLocation", key="anotherkey")
with pytest.raises(NodeNotFoundError):
store.get(key="anotherkey")


def test_store_get_schema_name() -> None:
assert get_schema_name(schema=BuiltinIPPrefix) == BuiltinIPPrefix.__name__
assert get_schema_name(schema=BuiltinIPAddressSync) == BuiltinIPAddressSync.__name__[:-4]