From b1ef381f68f11e5148fb31d96a8028ce9d617f18 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 10 Oct 2025 17:19:25 -0700 Subject: [PATCH 01/16] wip --- python/lib/sift_client/client.py | 11 +++- python/lib/sift_client/transport/__init__.py | 3 +- .../sift_client/transport/base_connection.py | 6 +- .../sift_client/transport/grpc_transport.py | 60 ++++++++++++++++++- python/lib/sift_py/grpc/transport.py | 58 ++++++++++++++++-- python/pyproject.toml | 1 + 6 files changed, 127 insertions(+), 12 deletions(-) diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index 2a2252ef8..b6c94a6fc 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -23,6 +23,7 @@ TestResultsAPIAsync, ) from sift_client.transport import ( + CacheConfig, GrpcClient, GrpcConfig, RestClient, @@ -109,6 +110,7 @@ def __init__( grpc_url: str | None = None, rest_url: str | None = None, connection_config: SiftConnectionConfig | None = None, + cache_config: CacheConfig | None = None, ): """Initialize the SiftClient with specific connection parameters or a connection_config. @@ -117,6 +119,7 @@ def __init__( grpc_url: The Sift gRPC API URL. rest_url: The Sift REST API URL. connection_config: A SiftConnectionConfig object to configure the connection behavior of the SiftClient. + cache_config: Optional cache configuration for gRPC responses. If provided, enables response caching. """ if not (api_key and grpc_url and rest_url) and not connection_config: raise ValueError( @@ -124,10 +127,14 @@ def __init__( ) if connection_config: - grpc_client = GrpcClient(connection_config.get_grpc_config()) + grpc_config = connection_config.get_grpc_config() + # Override cache_config if provided directly to SiftClient + if cache_config is not None: + grpc_config.cache_config = cache_config + grpc_client = GrpcClient(grpc_config) rest_client = RestClient(connection_config.get_rest_config()) elif api_key and grpc_url and rest_url: - grpc_client = GrpcClient(GrpcConfig(grpc_url, api_key)) + grpc_client = GrpcClient(GrpcConfig(grpc_url, api_key, cache_config=cache_config)) rest_client = RestClient(RestConfig(rest_url, api_key)) else: raise ValueError( diff --git a/python/lib/sift_client/transport/__init__.py b/python/lib/sift_client/transport/__init__.py index 249d9bc7e..280f19e2c 100644 --- a/python/lib/sift_client/transport/__init__.py +++ b/python/lib/sift_client/transport/__init__.py @@ -3,10 +3,11 @@ WithGrpcClient, WithRestClient, ) -from sift_client.transport.grpc_transport import GrpcClient, GrpcConfig +from sift_client.transport.grpc_transport import CacheConfig, GrpcClient, GrpcConfig from sift_client.transport.rest_transport import RestClient, RestConfig __all__ = [ + "CacheConfig", "GrpcClient", "GrpcConfig", "RestClient", diff --git a/python/lib/sift_client/transport/base_connection.py b/python/lib/sift_client/transport/base_connection.py index 02f0e096e..2b94fa52b 100644 --- a/python/lib/sift_client/transport/base_connection.py +++ b/python/lib/sift_client/transport/base_connection.py @@ -3,7 +3,7 @@ from abc import ABC from typing import TYPE_CHECKING -from sift_client.transport.grpc_transport import GrpcClient, GrpcConfig +from sift_client.transport.grpc_transport import CacheConfig, GrpcClient, GrpcConfig from sift_client.transport.rest_transport import RestClient, RestConfig if TYPE_CHECKING: @@ -24,6 +24,7 @@ def __init__( api_key: str, use_ssl: bool = True, cert_via_openssl: bool = False, + cache_config: CacheConfig | None = None, ): """Initialize the connection 
configuration. @@ -33,12 +34,14 @@ def __init__( api_key: The API key for authentication. use_ssl: Whether to use SSL/TLS for secure connections. cert_via_openssl: Whether to use OpenSSL for certificate validation. + cache_config: Optional cache configuration for gRPC responses. """ self.api_key = api_key self.grpc_url = grpc_url self.rest_url = rest_url self.use_ssl = use_ssl self.cert_via_openssl = cert_via_openssl + self.cache_config = cache_config def get_grpc_config(self): """Create and return a GrpcConfig with the current settings. @@ -51,6 +54,7 @@ def get_grpc_config(self): api_key=self.api_key, use_ssl=self.use_ssl, cert_via_openssl=self.cert_via_openssl, + cache_config=self.cache_config, ) def get_rest_config(self): diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index b27ce8fc1..6da09d62a 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -9,10 +9,13 @@ import asyncio import atexit import logging +import tempfile import threading +from pathlib import Path from typing import Any from sift_py.grpc.transport import ( + CacheConfig as SiftCacheConfig, SiftChannelConfig, use_sift_async_channel, ) @@ -34,6 +37,51 @@ def _suppress_blocking_io(loop, context): loop.default_exception_handler(context) +class CacheConfig: + """Configuration for gRPC response caching. + + Attributes: + enabled: Whether to enable caching. Default is False. + ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). + cache_path: Path to the cache directory. Default is system temp directory + 'sift_grpc_cache'. + size_limit: Maximum size of the cache in bytes. Default is 1GB. + """ + + def __init__( + self, + enabled: bool = False, + ttl: int = 3600, + cache_path: str | None = None, + size_limit: int = 1024 * 1024 * 1024, # 1GB + ): + """Initialize the cache configuration. + + Args: + enabled: Whether to enable caching. + ttl: Time-to-live for cached entries in seconds. + cache_path: Path to the cache directory. If None, uses system temp directory. + size_limit: Maximum size of the cache in bytes. + """ + self.enabled = enabled + self.ttl = ttl + self.cache_path = cache_path or str( + Path(tempfile.gettempdir()) / "sift_grpc_cache" + ) + self.size_limit = size_limit + + def to_sift_cache_config(self) -> SiftCacheConfig: + """Convert to a SiftCacheConfig for use with sift_py.grpc.transport. + + Returns: + A SiftCacheConfig dictionary. + """ + return { + "ttl": self.ttl, + "cache_path": self.cache_path, + "size_limit": self.size_limit, + } + + class GrpcConfig: """Configuration for gRPC API clients.""" @@ -44,6 +92,7 @@ def __init__( use_ssl: bool = True, cert_via_openssl: bool = False, metadata: dict[str, str] | None = None, + cache_config: CacheConfig | None = None, ): """Initialize the gRPC configuration. @@ -52,14 +101,15 @@ def __init__( api_key: The API key for authentication. use_ssl: Whether to use SSL/TLS. cert_via_openssl: Whether to use OpenSSL for SSL/TLS. - use_async: Whether to use async gRPC client. metadata: Additional metadata to include in all requests. + cache_config: Optional cache configuration. If None, caching is disabled. """ self.uri = url self.api_key = api_key self.use_ssl = use_ssl self.cert_via_openssl = cert_via_openssl self.metadata = metadata or {} + self.cache_config = cache_config def _to_sift_channel_config(self) -> SiftChannelConfig: """Convert to a SiftChannelConfig. 
@@ -67,12 +117,18 @@ def _to_sift_channel_config(self) -> SiftChannelConfig: Returns: A SiftChannelConfig. """ - return { + config: SiftChannelConfig = { "uri": self.uri, "apikey": self.api_key, "use_ssl": self.use_ssl, "cert_via_openssl": self.cert_via_openssl, } + + # Add cache config if enabled + if self.cache_config and self.cache_config.enabled: + config["cache_config"] = self.cache_config.to_sift_cache_config() + + return config class GrpcClient: diff --git a/python/lib/sift_py/grpc/transport.py b/python/lib/sift_py/grpc/transport.py index 07d13f667..51b02c0c4 100644 --- a/python/lib/sift_py/grpc/transport.py +++ b/python/lib/sift_py/grpc/transport.py @@ -15,8 +15,10 @@ from typing_extensions import NotRequired, TypeAlias from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor +from sift_py.grpc._async_interceptors.caching import CachingAsyncInterceptor from sift_py.grpc._async_interceptors.metadata import MetadataAsyncInterceptor from sift_py.grpc._interceptors.base import ClientInterceptor +from sift_py.grpc._interceptors.caching import CachingInterceptor from sift_py.grpc._interceptors.metadata import Metadata, MetadataInterceptor from sift_py.grpc._retry import RetryPolicy from sift_py.grpc.keepalive import DEFAULT_KEEPALIVE_CONFIG, KeepaliveConfig @@ -130,17 +132,45 @@ def _compute_sift_interceptors( """ Initialized all interceptors here. """ - return [ - _metadata_interceptor(config, metadata), - ] + interceptors: List[ClientInterceptor] = [] + + # Add caching interceptor if enabled + cache_config = config.get("cache_config") + if cache_config: + interceptors.append( + CachingInterceptor( + ttl=cache_config.get("ttl", 3600), + cache_path=cache_config.get("cache_path", ".grpc_cache"), + size_limit=cache_config.get("size_limit", 1024 * 1024 * 1024), + ) + ) + + # Metadata interceptor should be last to ensure metadata is always added + interceptors.append(_metadata_interceptor(config, metadata)) + + return interceptors def _compute_sift_async_interceptors( config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None ) -> List[grpc_aio.ClientInterceptor]: - return [ - _metadata_async_interceptor(config, metadata), - ] + interceptors: List[grpc_aio.ClientInterceptor] = [] + + # Add caching interceptor if enabled + cache_config = config.get("cache_config") + if cache_config: + interceptors.append( + CachingAsyncInterceptor( + ttl=cache_config.get("ttl", 3600), + cache_path=cache_config.get("cache_path", ".grpc_cache"), + size_limit=cache_config.get("size_limit", 1024 * 1024 * 1024), + ) + ) + + # Metadata interceptor should be last to ensure metadata is always added + interceptors.append(_metadata_async_interceptor(config, metadata)) + + return interceptors def _compute_channel_options(opts: SiftChannelConfig) -> List[Tuple[str, Any]]: @@ -229,6 +259,19 @@ def _compute_keep_alive_channel_opts(config: KeepaliveConfig) -> List[Tuple[str, ] +class CacheConfig(TypedDict): + """ + Configuration for gRPC response caching. + - `ttl`: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). + - `cache_path`: Path to the cache directory. Default is ".grpc_cache". + - `size_limit`: Maximum size of the cache in bytes. Default is 1GB. + """ + + ttl: NotRequired[int] + cache_path: NotRequired[str] + size_limit: NotRequired[int] + + class SiftChannelConfig(TypedDict): """ Config class used to instantiate a `SiftChannel` via `use_sift_channel`. 
@@ -241,6 +284,8 @@ class SiftChannelConfig(TypedDict): Run `pip install sift-stack-py[openssl]` to install the dependencies required to use this option. This works around this issue with grpc loading SSL certificates: https://github.com/grpc/grpc/issues/29682. Default is False. + - `cache_config`: Optional configuration for response caching. If provided, caching will be enabled. + Use metadata flags to control caching on a per-request basis. """ uri: str @@ -248,3 +293,4 @@ class SiftChannelConfig(TypedDict): enable_keepalive: NotRequired[Union[bool, KeepaliveConfig]] use_ssl: NotRequired[bool] cert_via_openssl: NotRequired[bool] + cache_config: NotRequired[CacheConfig] diff --git a/python/pyproject.toml b/python/pyproject.toml index 993b4236b..1a30e27bf 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "types-protobuf>=4.0", "typing-extensions~=4.6", "types-requests~=2.25", + "diskcache~=5.6" ] [project.urls] From 9c79f576336876d89496cd7068267f8c6ee2b58b Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 13 Oct 2025 11:40:17 -0700 Subject: [PATCH 02/16] updating defaults --- .../sift_client/transport/grpc_transport.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 6da09d62a..876771694 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -16,6 +16,8 @@ from sift_py.grpc.transport import ( CacheConfig as SiftCacheConfig, +) +from sift_py.grpc.transport import ( SiftChannelConfig, use_sift_async_channel, ) @@ -37,11 +39,16 @@ def _suppress_blocking_io(loop, context): loop.default_exception_handler(context) +DEFAULT_CACHE_TTL_SECONDS = 7 * 24 * 60 * 60 # 1 week +DEFAULT_CACHE_FOLDER = Path(tempfile.gettempdir()) / "sift_client" +DEFAULT_CACHE_SIZE_LIMIT_BYTES = 5 * 1024 * 1024 * 1024 # 5GB + + class CacheConfig: """Configuration for gRPC response caching. - + Attributes: - enabled: Whether to enable caching. Default is False. + enabled: Whether to enable caching. Default is True. ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). cache_path: Path to the cache directory. Default is system temp directory + 'sift_grpc_cache'. size_limit: Maximum size of the cache in bytes. Default is 1GB. @@ -49,22 +56,22 @@ class CacheConfig: def __init__( self, - enabled: bool = False, - ttl: int = 3600, - cache_path: str | None = None, - size_limit: int = 1024 * 1024 * 1024, # 1GB + enabled: bool = True, + ttl: int = DEFAULT_CACHE_TTL_SECONDS, + cache_folder: Path | str | None = DEFAULT_CACHE_FOLDER, + size_limit: int = DEFAULT_CACHE_SIZE_LIMIT_BYTES, # 1GB ): """Initialize the cache configuration. Args: enabled: Whether to enable caching. ttl: Time-to-live for cached entries in seconds. - cache_path: Path to the cache directory. If None, uses system temp directory. + cache_folder: Path to the cache directory. If None, uses system temp directory. size_limit: Maximum size of the cache in bytes. 
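+
+        Example (illustrative sketch; ``url`` and ``api_key`` are placeholders):
+            # Cache responses for one day instead of the default TTL
+            config = CacheConfig(enabled=True, ttl=24 * 60 * 60)
+            client = GrpcClient(GrpcConfig(url, api_key, cache_config=config))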
""" self.enabled = enabled self.ttl = ttl - self.cache_path = cache_path or str( + self.cache_path = cache_folder or str( Path(tempfile.gettempdir()) / "sift_grpc_cache" ) self.size_limit = size_limit @@ -123,11 +130,11 @@ def _to_sift_channel_config(self) -> SiftChannelConfig: "use_ssl": self.use_ssl, "cert_via_openssl": self.cert_via_openssl, } - + # Add cache config if enabled if self.cache_config and self.cache_config.enabled: config["cache_config"] = self.cache_config.to_sift_cache_config() - + return config @@ -146,7 +153,9 @@ def __init__(self, config: GrpcConfig): self._config = config # map each asyncio loop to its async channel and stub dict self._channels_async: dict[asyncio.AbstractEventLoop, Any] = {} - self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = {} + self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = ( + {} + ) # default loop for sync API self._default_loop = asyncio.new_event_loop() atexit.register(self.close_sync) @@ -208,7 +217,9 @@ def close_sync(self): """Close the sync channel and all async channels.""" try: for ch in self._channels_async.values(): - asyncio.run_coroutine_threadsafe(ch.close(), self._default_loop).result() + asyncio.run_coroutine_threadsafe( + ch.close(), self._default_loop + ).result() self._default_loop.call_soon_threadsafe(self._default_loop.stop) self._default_loop_thread.join(timeout=1.0) except ValueError: From d4504e3f80f84cd16d71f4acc3ca168c04b38b97 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 13 Oct 2025 11:46:06 -0700 Subject: [PATCH 03/16] add some missing files --- python/examples/caching_example.py | 91 +++++ .../grpc/_async_interceptors/caching.py | 222 ++++++++++++ .../lib/sift_py/grpc/_interceptors/caching.py | 216 ++++++++++++ .../lib/sift_py/grpc/_tests/test_caching.py | 329 ++++++++++++++++++ python/lib/sift_py/grpc/cache.py | 118 +++++++ 5 files changed, 976 insertions(+) create mode 100644 python/examples/caching_example.py create mode 100644 python/lib/sift_py/grpc/_async_interceptors/caching.py create mode 100644 python/lib/sift_py/grpc/_interceptors/caching.py create mode 100644 python/lib/sift_py/grpc/_tests/test_caching.py create mode 100644 python/lib/sift_py/grpc/cache.py diff --git a/python/examples/caching_example.py b/python/examples/caching_example.py new file mode 100644 index 000000000..34b3565c4 --- /dev/null +++ b/python/examples/caching_example.py @@ -0,0 +1,91 @@ +"""Example demonstrating gRPC response caching with the Sift client. + +This example shows how to: +1. Enable caching via SiftClient configuration +2. Use cache control metadata to control caching behavior +3. 
Measure the performance improvement from caching + +Requirements: + pip install sift-stack-py[cache] +""" + +import time +from sift_client import CacheConfig, SiftClient +from sift_py.grpc.cache import with_cache, with_force_refresh, clear_cache_for + +# Configure caching +cache_config = CacheConfig( + enabled=True, # Enable caching + ttl=3600, # Cache for 1 hour + cache_path=None, # Uses system temp directory by default + size_limit=1024 * 1024 * 1024, # 1GB max +) + +# Initialize client with caching enabled +client = SiftClient( + api_key="your-api-key-here", + grpc_url="api.siftstack.com", + rest_url="https://api.siftstack.com", + cache_config=cache_config, # Pass cache config directly +) + +# Example 1: Basic caching +print("Example 1: Basic Caching") +print("-" * 50) + +# First call - cache miss (fetches from server) +start = time.time() +response = client.ping.ping() # Note: Need to add metadata support to high-level APIs +elapsed_first = time.time() - start +print(f"First call (cache miss): {elapsed_first:.3f}s") + +# Second call - cache hit (returns cached response) +start = time.time() +response = client.ping.ping() +elapsed_second = time.time() - start +print(f"Second call (cache hit): {elapsed_second:.3f}s") +print(f"Speedup: {elapsed_first / elapsed_second:.1f}x faster") + +# Example 2: Force refresh +print("\nExample 2: Force Refresh") +print("-" * 50) + +# Force refresh - bypasses cache and fetches fresh data +start = time.time() +response = client.ping.ping() # with force_refresh metadata +elapsed = time.time() - start +print(f"Force refresh: {elapsed:.3f}s") + +# Example 3: Clear cache +print("\nExample 3: Clear Cache") +print("-" * 50) + +# Clear the cache for this specific request +response = client.ping.ping() # with clear_cache_for metadata +print("Cache cleared for this request") + +# Example 4: Conditional caching +print("\nExample 4: Conditional Caching") +print("-" * 50) + + +def get_data(use_cache: bool = False): + """Helper function that conditionally uses caching.""" + if use_cache: + # Use cache + return client.ping.ping() # with with_cache metadata + else: + # Skip cache + return client.ping.ping() # without cache metadata + + +# Use cache in production +response = get_data(use_cache=True) +print("Called with caching enabled") + +# Skip cache in development +response = get_data(use_cache=False) +print("Called without caching") + +print("\nNote: This example requires integration with the high-level API") +print("to pass cache control metadata. See the documentation for details.") diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py new file mode 100644 index 000000000..9b57e21b8 --- /dev/null +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -0,0 +1,222 @@ +"""Async gRPC caching interceptor for transparent local response caching. + +This module provides an async caching interceptor that can be used to cache gRPC +unary-unary responses locally using diskcache. The cache is persistent across runs +and supports TTL expiration and per-request control via metadata. 
+ +Usage: + from sift_py.grpc._async_interceptors.caching import CachingAsyncInterceptor + + # Create interceptor with 1 hour TTL + cache_interceptor = CachingAsyncInterceptor(ttl=3600, cache_path=".grpc_cache") + + # Use with metadata to control caching: + metadata = [ + ("use-cache", "true"), # Enable caching for this call + # ("force-refresh", "true"), # Bypass cache and store fresh result + # ("clear-cache", "true"), # Delete cached entry before request + ] +""" + +from __future__ import annotations + +import hashlib +import logging +from pathlib import Path +from typing import Any, Optional + +from grpc import aio as grpc_aio + +from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor + +logger = logging.getLogger(__name__) + +# Metadata keys for cache control +METADATA_USE_CACHE = "use-cache" +METADATA_FORCE_REFRESH = "force-refresh" +METADATA_CLEAR_CACHE = "clear-cache" + + +class CachingAsyncInterceptor(ClientAsyncInterceptor): + """Async interceptor that caches unary-unary gRPC responses locally. + + This interceptor uses diskcache for persistent storage with TTL support. + Cache keys are generated deterministically based on the gRPC method name + and serialized request payload. + + Note: diskcache operations are synchronous, but the overhead is minimal + for most use cases. For high-throughput scenarios, consider using an + async-native cache backend. + + Attributes: + ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). + cache_path: Path to the cache directory. Default is ".grpc_cache". + size_limit: Maximum size of the cache in bytes. Default is 1GB. + """ + + def __init__( + self, + ttl: int = 3600, + cache_path: str = ".grpc_cache", + size_limit: int = 1024 * 1024 * 1024, # 1GB + ): + """Initialize the async caching interceptor. + + Args: + ttl: Time-to-live for cached entries in seconds. + cache_path: Path to the cache directory. + size_limit: Maximum size of the cache in bytes. + """ + try: + import diskcache + except ImportError: + raise ImportError( + "diskcache is required for caching. Install it with: pip install diskcache" + ) + + self.ttl = ttl + self.cache_path = Path(cache_path) + self.size_limit = size_limit + + # Create cache directory if it doesn't exist + self.cache_path.mkdir(parents=True, exist_ok=True) + + # Initialize diskcache + self._cache = diskcache.Cache(str(self.cache_path), size_limit=size_limit) + + logger.debug( + f"Initialized CachingAsyncInterceptor with ttl={ttl}s, " + f"cache_path={cache_path}, size_limit={size_limit} bytes" + ) + + async def intercept( + self, + method: Any, + request_or_iterator: Any, + client_call_details: grpc_aio.ClientCallDetails, + ) -> Any: + """Intercept the async gRPC call and apply caching logic. + + Args: + method: The continuation to call for the actual RPC. + request_or_iterator: The request object or iterator. + client_call_details: The call details including method name and metadata. + + Returns: + The response from the cache or the actual RPC call. 
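+
+        Example (sketch; ``stub`` and ``request`` are placeholder names):
+            # Opt in to caching for a single unary-unary call
+            metadata = [("use-cache", "true")]
+            response = await stub.GetData(request, metadata=metadata)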
+ """ + # Extract metadata flags + metadata_dict = self._extract_metadata(client_call_details.metadata) + use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true" + force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true" + clear_cache = metadata_dict.get(METADATA_CLEAR_CACHE, "false").lower() == "true" + + # If caching is not enabled, just pass through + if not use_cache and not clear_cache and not force_refresh: + return await method(request_or_iterator, client_call_details) + + # Generate cache key + cache_key = self._generate_cache_key(client_call_details.method, request_or_iterator) + + # Handle clear-cache flag + if clear_cache: + logger.debug(f"Clearing cache for key: {cache_key}") + self._cache.delete(cache_key) + # Continue with the request after clearing + + # Handle force-refresh flag + if force_refresh: + logger.debug(f"Force refresh for key: {cache_key}") + call = await method(request_or_iterator, client_call_details) + # For async, we need to await the response + response = await call + # Cache the fresh result + self._cache_response(cache_key, response) + return response + + # Try to get from cache if use-cache is enabled + if use_cache: + cached_response = self._cache.get(cache_key) + if cached_response is not None: + logger.debug(f"Cache hit for key: {cache_key}") + return cached_response + + logger.debug(f"Cache miss for key: {cache_key}") + + # Make the actual RPC call + call = await method(request_or_iterator, client_call_details) + response = await call + + # Cache the response if use-cache is enabled + if use_cache: + self._cache_response(cache_key, response) + + return response + + def _generate_cache_key(self, method_name: str, request: Any) -> str: + """Generate a deterministic cache key from method name and request. + + Args: + method_name: The gRPC method name. + request: The request object. + + Returns: + A SHA256 hash of the method name and serialized request. + """ + try: + # Serialize the request using protobuf's SerializeToString + request_bytes = request.SerializeToString() + except AttributeError: + # If the request doesn't have SerializeToString, fall back to str + logger.warning( + f"Request for {method_name} doesn't have SerializeToString, using str() instead" + ) + request_bytes = str(request).encode() + + # Create a deterministic hash + key_material = method_name.encode() + request_bytes + cache_key = hashlib.sha256(key_material).hexdigest() + + return cache_key + + def _cache_response(self, cache_key: str, response: Any) -> None: + """Store a response in the cache with TTL. + + Args: + cache_key: The cache key. + response: The response to cache. + """ + try: + self._cache.set(cache_key, response, expire=self.ttl) + logger.debug(f"Cached response for key: {cache_key} with TTL: {self.ttl}s") + except Exception as e: + logger.error(f"Failed to cache response for key {cache_key}: {e}") + + def _extract_metadata(self, metadata: Optional[tuple[tuple[str, str], ...]]) -> dict[str, str]: + """Extract metadata into a dictionary. + + Args: + metadata: The metadata tuple. + + Returns: + A dictionary of metadata key-value pairs. 
+ """ + if metadata is None: + return {} + return dict(metadata) + + def clear_all(self) -> None: + """Clear all cached entries.""" + logger.info("Clearing all cached entries") + self._cache.clear() + + def close(self) -> None: + """Close the cache and release resources.""" + logger.debug("Closing cache") + self._cache.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/python/lib/sift_py/grpc/_interceptors/caching.py b/python/lib/sift_py/grpc/_interceptors/caching.py new file mode 100644 index 000000000..d8dee6986 --- /dev/null +++ b/python/lib/sift_py/grpc/_interceptors/caching.py @@ -0,0 +1,216 @@ +"""gRPC caching interceptor for transparent local response caching. + +This module provides a caching interceptor that can be used to cache gRPC unary-unary +responses locally using diskcache. The cache is persistent across runs and supports TTL +expiration and per-request control via metadata. + +Usage: + from sift_py.grpc._interceptors.caching import CachingInterceptor + + # Create interceptor with 1 hour TTL + cache_interceptor = CachingInterceptor(ttl=3600, cache_path=".grpc_cache") + + # Use with metadata to control caching: + metadata = [ + ("use-cache", "true"), # Enable caching for this call + # ("force-refresh", "true"), # Bypass cache and store fresh result + # ("clear-cache", "true"), # Delete cached entry before request + ] +""" + +from __future__ import annotations + +import hashlib +import logging +from pathlib import Path +from typing import Any, Optional + +import grpc + +from sift_py.grpc._interceptors.base import ClientInterceptor, Continuation +from sift_py.grpc._interceptors.context import ClientCallDetails + +logger = logging.getLogger(__name__) + +# Metadata keys for cache control +METADATA_USE_CACHE = "use-cache" +METADATA_FORCE_REFRESH = "force-refresh" +METADATA_CLEAR_CACHE = "clear-cache" + + +class CachingInterceptor(ClientInterceptor): + """Interceptor that caches unary-unary gRPC responses locally. + + This interceptor uses diskcache for persistent storage with TTL support. + Cache keys are generated deterministically based on the gRPC method name + and serialized request payload. + + Attributes: + ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). + cache_path: Path to the cache directory. Default is ".grpc_cache". + size_limit: Maximum size of the cache in bytes. Default is 1GB. + """ + + def __init__( + self, + ttl: int = 3600, + cache_path: str = ".grpc_cache", + size_limit: int = 1024 * 1024 * 1024, # 1GB + ): + """Initialize the caching interceptor. + + Args: + ttl: Time-to-live for cached entries in seconds. + cache_path: Path to the cache directory. + size_limit: Maximum size of the cache in bytes. + """ + try: + import diskcache + except ImportError: + raise ImportError( + "diskcache is required for caching. Install it with: pip install diskcache" + ) + + self.ttl = ttl + self.cache_path = Path(cache_path) + self.size_limit = size_limit + + # Create cache directory if it doesn't exist + self.cache_path.mkdir(parents=True, exist_ok=True) + + # Initialize diskcache + self._cache = diskcache.Cache(str(self.cache_path), size_limit=size_limit) + + logger.debug( + f"Initialized CachingInterceptor with ttl={ttl}s, " + f"cache_path={cache_path}, size_limit={size_limit} bytes" + ) + + def intercept( + self, + method: Continuation, + request_or_iterator: Any, + client_call_details: grpc.ClientCallDetails, + ): + """Intercept the gRPC call and apply caching logic. 
+ + Args: + method: The continuation to call for the actual RPC. + request_or_iterator: The request object or iterator. + client_call_details: The call details including method name and metadata. + + Returns: + The response from the cache or the actual RPC call. + """ + # Extract metadata flags + metadata_dict = self._extract_metadata(client_call_details.metadata) + use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true" + force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true" + clear_cache = metadata_dict.get(METADATA_CLEAR_CACHE, "false").lower() == "true" + + # If caching is not enabled, just pass through + if not use_cache and not clear_cache and not force_refresh: + return method(request_or_iterator, client_call_details) + + # Generate cache key + cache_key = self._generate_cache_key(client_call_details.method, request_or_iterator) + + # Handle clear-cache flag + if clear_cache: + logger.debug(f"Clearing cache for key: {cache_key}") + self._cache.delete(cache_key) + # Continue with the request after clearing + + # Handle force-refresh flag + if force_refresh: + logger.debug(f"Force refresh for key: {cache_key}") + response = method(request_or_iterator, client_call_details) + # Cache the fresh result + self._cache_response(cache_key, response) + return response + + # Try to get from cache if use-cache is enabled + if use_cache: + cached_response = self._cache.get(cache_key) + if cached_response is not None: + logger.debug(f"Cache hit for key: {cache_key}") + return cached_response + + logger.debug(f"Cache miss for key: {cache_key}") + + # Make the actual RPC call + response = method(request_or_iterator, client_call_details) + + # Cache the response if use-cache is enabled + if use_cache: + self._cache_response(cache_key, response) + + return response + + def _generate_cache_key(self, method_name: str, request: Any) -> str: + """Generate a deterministic cache key from method name and request. + + Args: + method_name: The gRPC method name. + request: The request object. + + Returns: + A SHA256 hash of the method name and serialized request. + """ + try: + # Serialize the request using protobuf's SerializeToString + request_bytes = request.SerializeToString() + except AttributeError: + # If the request doesn't have SerializeToString, fall back to str + logger.warning( + f"Request for {method_name} doesn't have SerializeToString, using str() instead" + ) + request_bytes = str(request).encode() + + # Create a deterministic hash + key_material = method_name.encode() + request_bytes + cache_key = hashlib.sha256(key_material).hexdigest() + + return cache_key + + def _cache_response(self, cache_key: str, response: Any) -> None: + """Store a response in the cache with TTL. + + Args: + cache_key: The cache key. + response: The response to cache. + """ + try: + self._cache.set(cache_key, response, expire=self.ttl) + logger.debug(f"Cached response for key: {cache_key} with TTL: {self.ttl}s") + except Exception as e: + logger.error(f"Failed to cache response for key {cache_key}: {e}") + + def _extract_metadata(self, metadata: Optional[tuple[tuple[str, str], ...]]) -> dict[str, str]: + """Extract metadata into a dictionary. + + Args: + metadata: The metadata tuple. + + Returns: + A dictionary of metadata key-value pairs. 
+ """ + if metadata is None: + return {} + return dict(metadata) + + def clear_all(self) -> None: + """Clear all cached entries.""" + logger.info("Clearing all cached entries") + self._cache.clear() + + def close(self) -> None: + """Close the cache and release resources.""" + logger.debug("Closing cache") + self._cache.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/python/lib/sift_py/grpc/_tests/test_caching.py b/python/lib/sift_py/grpc/_tests/test_caching.py new file mode 100644 index 000000000..c0641b85a --- /dev/null +++ b/python/lib/sift_py/grpc/_tests/test_caching.py @@ -0,0 +1,329 @@ +"""Tests for gRPC caching interceptor.""" + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock + +import grpc +import pytest + +from sift_py.grpc._interceptors.caching import CachingInterceptor +from sift_py.grpc._interceptors.context import ClientCallDetails +from sift_py.grpc.cache import ( + clear_cache_for, + with_cache, + with_force_refresh, + without_cache, +) + + +class MockRequest: + """Mock protobuf request for testing.""" + + def __init__(self, data: str): + self.data = data + + def SerializeToString(self) -> bytes: + return self.data.encode() + + +class MockResponse: + """Mock protobuf response for testing.""" + + def __init__(self, value: str): + self.value = value + + +@pytest.fixture +def temp_cache_dir(): + """Create a temporary cache directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture +def interceptor(temp_cache_dir): + """Create a caching interceptor with a temporary cache directory.""" + return CachingInterceptor(ttl=60, cache_path=temp_cache_dir) + + +def test_cache_miss_and_hit(interceptor): + """Test that cache miss fetches from server and cache hit returns cached response.""" + # Setup + request = MockRequest("test-data") + response = MockResponse("test-response") + method_name = "/test.Service/TestMethod" + + # Create mock continuation + continuation = Mock(return_value=response) + + # Create call details with cache enabled + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_cache(), + credentials=None, + wait_for_ready=None, + ) + + # First call - cache miss + result1 = interceptor.intercept(continuation, request, call_details) + assert result1 == response + assert continuation.call_count == 1 + + # Second call - cache hit + result2 = interceptor.intercept(continuation, request, call_details) + assert result2 == response + assert continuation.call_count == 1 # Should not call continuation again + + +def test_cache_disabled_by_default(interceptor): + """Test that caching is disabled by default without metadata.""" + # Setup + request = MockRequest("test-data") + response = MockResponse("test-response") + method_name = "/test.Service/TestMethod" + + # Create mock continuation + continuation = Mock(return_value=response) + + # Create call details without cache metadata + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + ) + + # First call + result1 = interceptor.intercept(continuation, request, call_details) + assert result1 == response + assert continuation.call_count == 1 + + # Second call - should call continuation again (no caching) + result2 = interceptor.intercept(continuation, request, call_details) + assert result2 == response + assert continuation.call_count == 2 + + +def 
test_force_refresh(interceptor): + """Test that force refresh bypasses cache and stores fresh result.""" + # Setup + request = MockRequest("test-data") + response1 = MockResponse("response-1") + response2 = MockResponse("response-2") + method_name = "/test.Service/TestMethod" + + # Create mock continuation that returns different responses + continuation = Mock(side_effect=[response1, response2]) + + # First call with cache + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_cache(), + credentials=None, + wait_for_ready=None, + ) + result1 = interceptor.intercept(continuation, request, call_details) + assert result1 == response1 + assert continuation.call_count == 1 + + # Second call with force refresh + call_details_refresh = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_force_refresh(), + credentials=None, + wait_for_ready=None, + ) + result2 = interceptor.intercept(continuation, request, call_details_refresh) + assert result2 == response2 + assert continuation.call_count == 2 # Should call continuation again + + +def test_clear_cache(interceptor): + """Test that clear cache deletes the cached entry.""" + # Setup + request = MockRequest("test-data") + response1 = MockResponse("response-1") + response2 = MockResponse("response-2") + method_name = "/test.Service/TestMethod" + + # Create mock continuation + continuation = Mock(side_effect=[response1, response2]) + + # First call with cache + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_cache(), + credentials=None, + wait_for_ready=None, + ) + result1 = interceptor.intercept(continuation, request, call_details) + assert result1 == response1 + assert continuation.call_count == 1 + + # Second call with clear cache + call_details_clear = ClientCallDetails( + method=method_name, + timeout=None, + metadata=clear_cache_for(), + credentials=None, + wait_for_ready=None, + ) + result2 = interceptor.intercept(continuation, request, call_details_clear) + assert result2 == response2 + assert continuation.call_count == 2 # Should call continuation again + + +def test_different_requests_different_cache_keys(interceptor): + """Test that different requests generate different cache keys.""" + # Setup + request1 = MockRequest("data-1") + request2 = MockRequest("data-2") + response1 = MockResponse("response-1") + response2 = MockResponse("response-2") + method_name = "/test.Service/TestMethod" + + # Create mock continuation + continuation = Mock(side_effect=[response1, response2]) + + # First call with request1 + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_cache(), + credentials=None, + wait_for_ready=None, + ) + result1 = interceptor.intercept(continuation, request1, call_details) + assert result1 == response1 + assert continuation.call_count == 1 + + # Second call with request2 (different request) + result2 = interceptor.intercept(continuation, request2, call_details) + assert result2 == response2 + assert continuation.call_count == 2 # Should call continuation for different request + + +def test_cache_key_generation(interceptor): + """Test that cache key generation is deterministic.""" + request = MockRequest("test-data") + method_name = "/test.Service/TestMethod" + + key1 = interceptor._generate_cache_key(method_name, request) + key2 = interceptor._generate_cache_key(method_name, request) + + assert key1 == key2 + assert len(key1) == 64 # SHA256 hex digest length + + +def test_without_cache_helper(): + 
"""Test the without_cache helper function.""" + metadata = without_cache() + assert len(metadata) == 0 + + # Test with existing metadata + existing = [("key", "value")] + metadata = without_cache(existing) + assert ("key", "value") in metadata + assert ("use-cache", "true") not in metadata + + +def test_with_cache_helper(): + """Test the with_cache helper function.""" + metadata = with_cache() + assert ("use-cache", "true") in metadata + + +def test_with_force_refresh_helper(): + """Test the with_force_refresh helper function.""" + metadata = with_force_refresh() + assert ("force-refresh", "true") in metadata + assert ("use-cache", "true") in metadata + + +def test_clear_cache_for_helper(): + """Test the clear_cache_for helper function.""" + metadata = clear_cache_for() + assert ("clear-cache", "true") in metadata + + +def test_clear_all(interceptor): + """Test clearing all cached entries.""" + # Setup + request = MockRequest("test-data") + response = MockResponse("test-response") + method_name = "/test.Service/TestMethod" + + # Create mock continuation + continuation = Mock(return_value=response) + + # Create call details with cache enabled + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_cache(), + credentials=None, + wait_for_ready=None, + ) + + # First call - cache miss + result1 = interceptor.intercept(continuation, request, call_details) + assert result1 == response + assert continuation.call_count == 1 + + # Clear all cache + interceptor.clear_all() + + # Second call - should be cache miss again + result2 = interceptor.intercept(continuation, request, call_details) + assert result2 == response + assert continuation.call_count == 2 + + +def test_context_manager(temp_cache_dir): + """Test that the interceptor works as a context manager.""" + with CachingInterceptor(ttl=60, cache_path=temp_cache_dir) as interceptor: + assert interceptor is not None + # Cache should be usable within the context + request = MockRequest("test") + key = interceptor._generate_cache_key("/test", request) + assert key is not None + + +def test_cache_persistence(temp_cache_dir): + """Test that cache persists across interceptor instances.""" + request = MockRequest("test-data") + response = MockResponse("test-response") + method_name = "/test.Service/TestMethod" + + # Create first interceptor and cache a response + interceptor1 = CachingInterceptor(ttl=60, cache_path=temp_cache_dir) + continuation = Mock(return_value=response) + call_details = ClientCallDetails( + method=method_name, + timeout=None, + metadata=with_cache(), + credentials=None, + wait_for_ready=None, + ) + result1 = interceptor1.intercept(continuation, request, call_details) + assert result1 == response + assert continuation.call_count == 1 + interceptor1.close() + + # Create second interceptor with same cache path + interceptor2 = CachingInterceptor(ttl=60, cache_path=temp_cache_dir) + continuation2 = Mock(return_value=response) + result2 = interceptor2.intercept(continuation2, request, call_details) + assert result2 == response + assert continuation2.call_count == 0 # Should use cached response + interceptor2.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py new file mode 100644 index 000000000..919938751 --- /dev/null +++ b/python/lib/sift_py/grpc/cache.py @@ -0,0 +1,118 @@ +"""Utilities for controlling gRPC response caching. 
+ +This module provides helper functions and constants for working with the gRPC +caching interceptor. Use these utilities to control caching behavior on a +per-request basis via metadata. + +Example: + from sift_py.grpc.cache import with_cache, with_force_refresh, clear_cache_for + + # Enable caching for a request + metadata = with_cache() + response = stub.GetData(request, metadata=metadata) + + # Force refresh (bypass cache and store fresh result) + metadata = with_force_refresh() + response = stub.GetData(request, metadata=metadata) + + # Clear cache for a specific request + metadata = clear_cache_for() + response = stub.GetData(request, metadata=metadata) +""" + +from typing import List, Tuple + +# Metadata keys for cache control +METADATA_USE_CACHE = "use-cache" +METADATA_FORCE_REFRESH = "force-refresh" +METADATA_CLEAR_CACHE = "clear-cache" + + +def with_cache( + existing_metadata: List[Tuple[str, str]] | None = None, +) -> List[Tuple[str, str]]: + """Add cache control metadata to enable caching for a request. + + Args: + existing_metadata: Optional existing metadata to extend. + + Returns: + Metadata list with cache enabled. + + Example: + metadata = with_cache() + response = stub.GetData(request, metadata=metadata) + """ + metadata = list(existing_metadata) if existing_metadata else [] + metadata.append((METADATA_USE_CACHE, "true")) + return metadata + + +def with_force_refresh( + existing_metadata: List[Tuple[str, str]] | None = None, +) -> List[Tuple[str, str]]: + """Add cache control metadata to force refresh (bypass cache and store fresh result). + + Args: + existing_metadata: Optional existing metadata to extend. + + Returns: + Metadata list with force refresh enabled. + + Example: + metadata = with_force_refresh() + response = stub.GetData(request, metadata=metadata) + """ + metadata = list(existing_metadata) if existing_metadata else [] + metadata.append((METADATA_FORCE_REFRESH, "true")) + metadata.append((METADATA_USE_CACHE, "true")) # Also enable caching + return metadata + + +def clear_cache_for( + existing_metadata: List[Tuple[str, str]] | None = None, +) -> List[Tuple[str, str]]: + """Add cache control metadata to clear the cache for a specific request. + + This will delete the cached entry before making the request. + + Args: + existing_metadata: Optional existing metadata to extend. + + Returns: + Metadata list with clear cache enabled. + + Example: + metadata = clear_cache_for() + response = stub.GetData(request, metadata=metadata) + """ + metadata = list(existing_metadata) if existing_metadata else [] + metadata.append((METADATA_CLEAR_CACHE, "true")) + return metadata + + +def without_cache( + existing_metadata: List[Tuple[str, str]] | None = None, +) -> List[Tuple[str, str]]: + """Explicitly disable caching for a request. + + This is the default behavior, so this function is mainly for clarity. + + Args: + existing_metadata: Optional existing metadata to extend. + + Returns: + Metadata list without cache flags. 
+ + Example: + metadata = without_cache() + response = stub.GetData(request, metadata=metadata) + """ + metadata = list(existing_metadata) if existing_metadata else [] + # Remove any cache-related metadata + metadata = [ + (k, v) + for k, v in metadata + if k not in (METADATA_USE_CACHE, METADATA_FORCE_REFRESH, METADATA_CLEAR_CACHE) + ] + return metadata From 2c94c66b3910e203b9b8eb97019f19ec95a1ce4c Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Tue, 14 Oct 2025 09:54:15 -0700 Subject: [PATCH 04/16] add async grpc caching --- python/examples/caching_example.py | 12 +- python/lib/sift_client/client.py | 13 +- python/lib/sift_client/transport/__init__.py | 3 +- .../sift_client/transport/grpc_transport.py | 62 ++-- .../grpc/_async_interceptors/caching.py | 69 ++-- .../lib/sift_py/grpc/_interceptors/caching.py | 216 ------------ .../lib/sift_py/grpc/_tests/test_caching.py | 329 ------------------ python/lib/sift_py/grpc/cache.py | 42 ++- python/lib/sift_py/grpc/transport.py | 36 +- 9 files changed, 151 insertions(+), 631 deletions(-) delete mode 100644 python/lib/sift_py/grpc/_interceptors/caching.py delete mode 100644 python/lib/sift_py/grpc/_tests/test_caching.py diff --git a/python/examples/caching_example.py b/python/examples/caching_example.py index 34b3565c4..c864817c3 100644 --- a/python/examples/caching_example.py +++ b/python/examples/caching_example.py @@ -11,7 +11,7 @@ import time from sift_client import CacheConfig, SiftClient -from sift_py.grpc.cache import with_cache, with_force_refresh, clear_cache_for +from sift_py.grpc.cache import with_cache, with_force_refresh, ignore_cache # Configure caching cache_config = CacheConfig( @@ -56,13 +56,13 @@ elapsed = time.time() - start print(f"Force refresh: {elapsed:.3f}s") -# Example 3: Clear cache -print("\nExample 3: Clear Cache") +# Example 3: Ignore cache +print("\nExample 3: Ignore Cache") print("-" * 50) -# Clear the cache for this specific request -response = client.ping.ping() # with clear_cache_for metadata -print("Cache cleared for this request") +# Bypass cache without clearing it +response = client.ping.ping() # with ignore_cache metadata +print("Cache bypassed for this request (entry still exists)") # Example 4: Conditional caching print("\nExample 4: Conditional Caching") diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index b6c94a6fc..f93ae5477 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -32,10 +32,21 @@ WithGrpcClient, WithRestClient, ) +from sift_client.transport.grpc_transport import ( + DEFAULT_CACHE_FOLDER, + DEFAULT_CACHE_SIZE_LIMIT_BYTES, + DEFAULT_CACHE_TTL_SECONDS, +) from sift_client.util.util import AsyncAPIs _sift_client_experimental_warning() +DEFAULT_CACHE_CONFIG = CacheConfig( + ttl=DEFAULT_CACHE_TTL_SECONDS, + cache_folder=DEFAULT_CACHE_FOLDER, + size_limit=DEFAULT_CACHE_SIZE_LIMIT_BYTES, +) + class SiftClient( WithGrpcClient, @@ -119,7 +130,7 @@ def __init__( grpc_url: The Sift gRPC API URL. rest_url: The Sift REST API URL. connection_config: A SiftConnectionConfig object to configure the connection behavior of the SiftClient. - cache_config: Optional cache configuration for gRPC responses. If provided, enables response caching. + cache_config: Optional cache configuration override for gRPC responses. 
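+
+        Example (sketch; URLs and key are placeholders, and ``CacheMode`` is
+        assumed imported from ``sift_client.transport``):
+            client = SiftClient(
+                api_key="my-api-key",
+                grpc_url="grpc-api.example.com",
+                rest_url="https://api.example.com",
+                cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT),
+            )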
""" if not (api_key and grpc_url and rest_url) and not connection_config: raise ValueError( diff --git a/python/lib/sift_client/transport/__init__.py b/python/lib/sift_client/transport/__init__.py index 280f19e2c..a7bb2a9ca 100644 --- a/python/lib/sift_client/transport/__init__.py +++ b/python/lib/sift_client/transport/__init__.py @@ -3,11 +3,12 @@ WithGrpcClient, WithRestClient, ) -from sift_client.transport.grpc_transport import CacheConfig, GrpcClient, GrpcConfig +from sift_client.transport.grpc_transport import CacheConfig, CacheMode, GrpcClient, GrpcConfig from sift_client.transport.rest_transport import RestClient, RestConfig __all__ = [ "CacheConfig", + "CacheMode", "GrpcClient", "GrpcConfig", "RestClient", diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 876771694..1e4400d6e 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -8,6 +8,7 @@ import asyncio import atexit +import enum import logging import tempfile import threading @@ -41,40 +42,62 @@ def _suppress_blocking_io(loop, context): DEFAULT_CACHE_TTL_SECONDS = 7 * 24 * 60 * 60 # 1 week DEFAULT_CACHE_FOLDER = Path(tempfile.gettempdir()) / "sift_client" -DEFAULT_CACHE_SIZE_LIMIT_BYTES = 5 * 1024 * 1024 * 1024 # 5GB +DEFAULT_CACHE_SIZE_LIMIT_BYTES = 5 * 1024**3 # 5GB + + +class CacheMode(str, enum.Enum): + """Cache behavior modes. + + - ENABLED: Cache is enabled and persists across sessions + - DISABLED: Cache is completely disabled + - CLEAR_ON_INIT: Cache is cleared when client is initialized (useful for notebooks) + """ + + ENABLED = "enabled" + DISABLED = "disabled" + CLEAR_ON_INIT = "clear_on_init" class CacheConfig: """Configuration for gRPC response caching. Attributes: - enabled: Whether to enable caching. Default is True. - ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). - cache_path: Path to the cache directory. Default is system temp directory + 'sift_grpc_cache'. - size_limit: Maximum size of the cache in bytes. Default is 1GB. + mode: Cache behavior mode (enabled, disabled, clear_on_init). + ttl: Time-to-live for cached entries in seconds. Default is 1 week. + cache_folder: Path to the cache directory. Default is system temp directory. + size_limit: Maximum size of the cache in bytes. Default is 5GB. """ def __init__( self, - enabled: bool = True, + mode: str = CacheMode.ENABLED, ttl: int = DEFAULT_CACHE_TTL_SECONDS, - cache_folder: Path | str | None = DEFAULT_CACHE_FOLDER, - size_limit: int = DEFAULT_CACHE_SIZE_LIMIT_BYTES, # 1GB + cache_folder: Path | str = DEFAULT_CACHE_FOLDER, + size_limit: int = DEFAULT_CACHE_SIZE_LIMIT_BYTES, ): """Initialize the cache configuration. Args: - enabled: Whether to enable caching. + mode: Cache behavior mode (use CacheMode constants). ttl: Time-to-live for cached entries in seconds. - cache_folder: Path to the cache directory. If None, uses system temp directory. + cache_folder: Path to the cache directory. size_limit: Maximum size of the cache in bytes. 
""" - self.enabled = enabled + self.mode = mode self.ttl = ttl - self.cache_path = cache_folder or str( - Path(tempfile.gettempdir()) / "sift_grpc_cache" - ) + self.cache_path = str(Path(cache_folder) / "grpc_cache") self.size_limit = size_limit + self._should_clear_on_init = mode == CacheMode.CLEAR_ON_INIT + + @property + def is_enabled(self) -> bool: + """Check if caching is enabled.""" + return self.mode != CacheMode.DISABLED + + @property + def should_clear_on_init(self) -> bool: + """Check if cache should be cleared on initialization.""" + return self._should_clear_on_init def to_sift_cache_config(self) -> SiftCacheConfig: """Convert to a SiftCacheConfig for use with sift_py.grpc.transport. @@ -86,6 +109,7 @@ def to_sift_cache_config(self) -> SiftCacheConfig: "ttl": self.ttl, "cache_path": self.cache_path, "size_limit": self.size_limit, + "clear_on_init": self.should_clear_on_init, } @@ -132,7 +156,7 @@ def _to_sift_channel_config(self) -> SiftChannelConfig: } # Add cache config if enabled - if self.cache_config and self.cache_config.enabled: + if self.cache_config and self.cache_config.is_enabled: config["cache_config"] = self.cache_config.to_sift_cache_config() return config @@ -153,9 +177,7 @@ def __init__(self, config: GrpcConfig): self._config = config # map each asyncio loop to its async channel and stub dict self._channels_async: dict[asyncio.AbstractEventLoop, Any] = {} - self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = ( - {} - ) + self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = {} # default loop for sync API self._default_loop = asyncio.new_event_loop() atexit.register(self.close_sync) @@ -217,9 +239,7 @@ def close_sync(self): """Close the sync channel and all async channels.""" try: for ch in self._channels_async.values(): - asyncio.run_coroutine_threadsafe( - ch.close(), self._default_loop - ).result() + asyncio.run_coroutine_threadsafe(ch.close(), self._default_loop).result() self._default_loop.call_soon_threadsafe(self._default_loop.stop) self._default_loop_thread.join(timeout=1.0) except ValueError: diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py index 9b57e21b8..8094c76c2 100644 --- a/python/lib/sift_py/grpc/_async_interceptors/caching.py +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -14,7 +14,7 @@ metadata = [ ("use-cache", "true"), # Enable caching for this call # ("force-refresh", "true"), # Bypass cache and store fresh result - # ("clear-cache", "true"), # Delete cached entry before request + # ("ignore-cache", "true"), # Bypass cache without clearing ] """ @@ -25,6 +25,7 @@ from pathlib import Path from typing import Any, Optional +import diskcache from grpc import aio as grpc_aio from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor @@ -34,7 +35,8 @@ # Metadata keys for cache control METADATA_USE_CACHE = "use-cache" METADATA_FORCE_REFRESH = "force-refresh" -METADATA_CLEAR_CACHE = "clear-cache" +METADATA_IGNORE_CACHE = "ignore-cache" +METADATA_CACHE_TTL = "cache-ttl" class CachingAsyncInterceptor(ClientAsyncInterceptor): @@ -58,7 +60,8 @@ def __init__( self, ttl: int = 3600, cache_path: str = ".grpc_cache", - size_limit: int = 1024 * 1024 * 1024, # 1GB + size_limit: int = 1024**3, # 1GB + clear_on_init: bool = False, ): """Initialize the async caching interceptor. @@ -66,14 +69,8 @@ def __init__( ttl: Time-to-live for cached entries in seconds. cache_path: Path to the cache directory. 
             size_limit: Maximum size of the cache in bytes.
+            clear_on_init: Whether to clear the cache on initialization.
         """
-        try:
-            import diskcache
-        except ImportError:
-            raise ImportError(
-                "diskcache is required for caching. Install it with: pip install diskcache"
-            )
-
         self.ttl = ttl
         self.cache_path = Path(cache_path)
         self.size_limit = size_limit
@@ -84,9 +81,19 @@ def __init__(
         # Initialize diskcache
         self._cache = diskcache.Cache(str(self.cache_path), size_limit=size_limit)
 
+        # Clear cache if requested
+        if clear_on_init:
+            logger.info(f"Clearing cache on initialization: {cache_path}")
+            self._cache.clear()
+
+        logger.info(
+            f"gRPC cache initialized at {self.cache_path.absolute()!r} "
+            f"with size {self._cache.volume() / (1024**3):.2f} GB"
+        )
+
         logger.debug(
             f"Initialized CachingAsyncInterceptor with ttl={ttl}s, "
-            f"cache_path={cache_path}, size_limit={size_limit} bytes"
+            f"cache_path={cache_path}, size_limit={size_limit} bytes, clear_on_init={clear_on_init}"
         )
 
     async def intercept(
@@ -109,29 +116,38 @@ async def intercept(
         metadata_dict = self._extract_metadata(client_call_details.metadata)
         use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true"
         force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true"
-        clear_cache = metadata_dict.get(METADATA_CLEAR_CACHE, "false").lower() == "true"
+        ignore_cache = metadata_dict.get(METADATA_IGNORE_CACHE, "false").lower() == "true"
+        custom_ttl_str = metadata_dict.get(METADATA_CACHE_TTL)
+
+        # Parse custom TTL if provided
+        custom_ttl = None
+        if custom_ttl_str:
+            try:
+                custom_ttl = int(custom_ttl_str)
+            except ValueError:
+                logger.warning(f"Invalid cache TTL value: {custom_ttl_str}, using default")
+
+        # If ignore_cache is set, bypass cache without clearing
+        if ignore_cache:
+            logger.debug("Ignoring cache for this request")
+            return await method(request_or_iterator, client_call_details)
 
         # If caching is not enabled, just pass through
-        if not use_cache and not clear_cache and not force_refresh:
+        if not use_cache and not force_refresh:
            return await method(request_or_iterator, client_call_details)
 
         # Generate cache key
         cache_key = self._generate_cache_key(client_call_details.method, request_or_iterator)
 
-        # Handle clear-cache flag
-        if clear_cache:
-            logger.debug(f"Clearing cache for key: {cache_key}")
-            self._cache.delete(cache_key)
-            # Continue with the request after clearing
-
         # Handle force-refresh flag
         if force_refresh:
             logger.debug(f"Force refresh for key: {cache_key}")
             call = await method(request_or_iterator, client_call_details)
             # For async, we need to await the response
             response = await call
-            # Cache the fresh result
-            self._cache_response(cache_key, response)
+            # Cache the fresh result with custom TTL if provided
+            ttl = custom_ttl if custom_ttl is not None else self.ttl
+            self._cache_response(cache_key, response, ttl)
             return response
 
         # Try to get from cache if use-cache is enabled
@@ -149,7 +165,8 @@ async def intercept(
         # Cache the response if use-cache is enabled
         if use_cache:
-            self._cache_response(cache_key, response)
+            ttl = custom_ttl if custom_ttl is not None else self.ttl
+            self._cache_response(cache_key, response, ttl)
 
         return response
 
@@ -179,16 +196,18 @@ def _generate_cache_key(self, method_name: str, request: Any) -> str:
 
         return cache_key
 
-    def _cache_response(self, cache_key: str, response: Any) -> None:
+    def _cache_response(self, cache_key: str, response: Any, ttl: int | None = None) -> None:
         """Store a response in the cache with TTL.
 
         Args:
             cache_key: The cache key.
response: The response to cache. + ttl: Optional custom TTL. If None, uses the default TTL. """ try: - self._cache.set(cache_key, response, expire=self.ttl) - logger.debug(f"Cached response for key: {cache_key} with TTL: {self.ttl}s") + effective_ttl = ttl if ttl is not None else self.ttl + self._cache.set(cache_key, response, expire=effective_ttl) + logger.debug(f"Cached response for key: {cache_key} with TTL: {effective_ttl}s") except Exception as e: logger.error(f"Failed to cache response for key {cache_key}: {e}") diff --git a/python/lib/sift_py/grpc/_interceptors/caching.py b/python/lib/sift_py/grpc/_interceptors/caching.py deleted file mode 100644 index d8dee6986..000000000 --- a/python/lib/sift_py/grpc/_interceptors/caching.py +++ /dev/null @@ -1,216 +0,0 @@ -"""gRPC caching interceptor for transparent local response caching. - -This module provides a caching interceptor that can be used to cache gRPC unary-unary -responses locally using diskcache. The cache is persistent across runs and supports TTL -expiration and per-request control via metadata. - -Usage: - from sift_py.grpc._interceptors.caching import CachingInterceptor - - # Create interceptor with 1 hour TTL - cache_interceptor = CachingInterceptor(ttl=3600, cache_path=".grpc_cache") - - # Use with metadata to control caching: - metadata = [ - ("use-cache", "true"), # Enable caching for this call - # ("force-refresh", "true"), # Bypass cache and store fresh result - # ("clear-cache", "true"), # Delete cached entry before request - ] -""" - -from __future__ import annotations - -import hashlib -import logging -from pathlib import Path -from typing import Any, Optional - -import grpc - -from sift_py.grpc._interceptors.base import ClientInterceptor, Continuation -from sift_py.grpc._interceptors.context import ClientCallDetails - -logger = logging.getLogger(__name__) - -# Metadata keys for cache control -METADATA_USE_CACHE = "use-cache" -METADATA_FORCE_REFRESH = "force-refresh" -METADATA_CLEAR_CACHE = "clear-cache" - - -class CachingInterceptor(ClientInterceptor): - """Interceptor that caches unary-unary gRPC responses locally. - - This interceptor uses diskcache for persistent storage with TTL support. - Cache keys are generated deterministically based on the gRPC method name - and serialized request payload. - - Attributes: - ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). - cache_path: Path to the cache directory. Default is ".grpc_cache". - size_limit: Maximum size of the cache in bytes. Default is 1GB. - """ - - def __init__( - self, - ttl: int = 3600, - cache_path: str = ".grpc_cache", - size_limit: int = 1024 * 1024 * 1024, # 1GB - ): - """Initialize the caching interceptor. - - Args: - ttl: Time-to-live for cached entries in seconds. - cache_path: Path to the cache directory. - size_limit: Maximum size of the cache in bytes. - """ - try: - import diskcache - except ImportError: - raise ImportError( - "diskcache is required for caching. 
Install it with: pip install diskcache" - ) - - self.ttl = ttl - self.cache_path = Path(cache_path) - self.size_limit = size_limit - - # Create cache directory if it doesn't exist - self.cache_path.mkdir(parents=True, exist_ok=True) - - # Initialize diskcache - self._cache = diskcache.Cache(str(self.cache_path), size_limit=size_limit) - - logger.debug( - f"Initialized CachingInterceptor with ttl={ttl}s, " - f"cache_path={cache_path}, size_limit={size_limit} bytes" - ) - - def intercept( - self, - method: Continuation, - request_or_iterator: Any, - client_call_details: grpc.ClientCallDetails, - ): - """Intercept the gRPC call and apply caching logic. - - Args: - method: The continuation to call for the actual RPC. - request_or_iterator: The request object or iterator. - client_call_details: The call details including method name and metadata. - - Returns: - The response from the cache or the actual RPC call. - """ - # Extract metadata flags - metadata_dict = self._extract_metadata(client_call_details.metadata) - use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true" - force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true" - clear_cache = metadata_dict.get(METADATA_CLEAR_CACHE, "false").lower() == "true" - - # If caching is not enabled, just pass through - if not use_cache and not clear_cache and not force_refresh: - return method(request_or_iterator, client_call_details) - - # Generate cache key - cache_key = self._generate_cache_key(client_call_details.method, request_or_iterator) - - # Handle clear-cache flag - if clear_cache: - logger.debug(f"Clearing cache for key: {cache_key}") - self._cache.delete(cache_key) - # Continue with the request after clearing - - # Handle force-refresh flag - if force_refresh: - logger.debug(f"Force refresh for key: {cache_key}") - response = method(request_or_iterator, client_call_details) - # Cache the fresh result - self._cache_response(cache_key, response) - return response - - # Try to get from cache if use-cache is enabled - if use_cache: - cached_response = self._cache.get(cache_key) - if cached_response is not None: - logger.debug(f"Cache hit for key: {cache_key}") - return cached_response - - logger.debug(f"Cache miss for key: {cache_key}") - - # Make the actual RPC call - response = method(request_or_iterator, client_call_details) - - # Cache the response if use-cache is enabled - if use_cache: - self._cache_response(cache_key, response) - - return response - - def _generate_cache_key(self, method_name: str, request: Any) -> str: - """Generate a deterministic cache key from method name and request. - - Args: - method_name: The gRPC method name. - request: The request object. - - Returns: - A SHA256 hash of the method name and serialized request. - """ - try: - # Serialize the request using protobuf's SerializeToString - request_bytes = request.SerializeToString() - except AttributeError: - # If the request doesn't have SerializeToString, fall back to str - logger.warning( - f"Request for {method_name} doesn't have SerializeToString, using str() instead" - ) - request_bytes = str(request).encode() - - # Create a deterministic hash - key_material = method_name.encode() + request_bytes - cache_key = hashlib.sha256(key_material).hexdigest() - - return cache_key - - def _cache_response(self, cache_key: str, response: Any) -> None: - """Store a response in the cache with TTL. - - Args: - cache_key: The cache key. - response: The response to cache. 
- """ - try: - self._cache.set(cache_key, response, expire=self.ttl) - logger.debug(f"Cached response for key: {cache_key} with TTL: {self.ttl}s") - except Exception as e: - logger.error(f"Failed to cache response for key {cache_key}: {e}") - - def _extract_metadata(self, metadata: Optional[tuple[tuple[str, str], ...]]) -> dict[str, str]: - """Extract metadata into a dictionary. - - Args: - metadata: The metadata tuple. - - Returns: - A dictionary of metadata key-value pairs. - """ - if metadata is None: - return {} - return dict(metadata) - - def clear_all(self) -> None: - """Clear all cached entries.""" - logger.info("Clearing all cached entries") - self._cache.clear() - - def close(self) -> None: - """Close the cache and release resources.""" - logger.debug("Closing cache") - self._cache.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() diff --git a/python/lib/sift_py/grpc/_tests/test_caching.py b/python/lib/sift_py/grpc/_tests/test_caching.py deleted file mode 100644 index c0641b85a..000000000 --- a/python/lib/sift_py/grpc/_tests/test_caching.py +++ /dev/null @@ -1,329 +0,0 @@ -"""Tests for gRPC caching interceptor.""" - -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, Mock - -import grpc -import pytest - -from sift_py.grpc._interceptors.caching import CachingInterceptor -from sift_py.grpc._interceptors.context import ClientCallDetails -from sift_py.grpc.cache import ( - clear_cache_for, - with_cache, - with_force_refresh, - without_cache, -) - - -class MockRequest: - """Mock protobuf request for testing.""" - - def __init__(self, data: str): - self.data = data - - def SerializeToString(self) -> bytes: - return self.data.encode() - - -class MockResponse: - """Mock protobuf response for testing.""" - - def __init__(self, value: str): - self.value = value - - -@pytest.fixture -def temp_cache_dir(): - """Create a temporary cache directory.""" - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - -@pytest.fixture -def interceptor(temp_cache_dir): - """Create a caching interceptor with a temporary cache directory.""" - return CachingInterceptor(ttl=60, cache_path=temp_cache_dir) - - -def test_cache_miss_and_hit(interceptor): - """Test that cache miss fetches from server and cache hit returns cached response.""" - # Setup - request = MockRequest("test-data") - response = MockResponse("test-response") - method_name = "/test.Service/TestMethod" - - # Create mock continuation - continuation = Mock(return_value=response) - - # Create call details with cache enabled - call_details = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_cache(), - credentials=None, - wait_for_ready=None, - ) - - # First call - cache miss - result1 = interceptor.intercept(continuation, request, call_details) - assert result1 == response - assert continuation.call_count == 1 - - # Second call - cache hit - result2 = interceptor.intercept(continuation, request, call_details) - assert result2 == response - assert continuation.call_count == 1 # Should not call continuation again - - -def test_cache_disabled_by_default(interceptor): - """Test that caching is disabled by default without metadata.""" - # Setup - request = MockRequest("test-data") - response = MockResponse("test-response") - method_name = "/test.Service/TestMethod" - - # Create mock continuation - continuation = Mock(return_value=response) - - # Create call details without cache metadata - call_details = ClientCallDetails( - 
method=method_name, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - ) - - # First call - result1 = interceptor.intercept(continuation, request, call_details) - assert result1 == response - assert continuation.call_count == 1 - - # Second call - should call continuation again (no caching) - result2 = interceptor.intercept(continuation, request, call_details) - assert result2 == response - assert continuation.call_count == 2 - - -def test_force_refresh(interceptor): - """Test that force refresh bypasses cache and stores fresh result.""" - # Setup - request = MockRequest("test-data") - response1 = MockResponse("response-1") - response2 = MockResponse("response-2") - method_name = "/test.Service/TestMethod" - - # Create mock continuation that returns different responses - continuation = Mock(side_effect=[response1, response2]) - - # First call with cache - call_details = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_cache(), - credentials=None, - wait_for_ready=None, - ) - result1 = interceptor.intercept(continuation, request, call_details) - assert result1 == response1 - assert continuation.call_count == 1 - - # Second call with force refresh - call_details_refresh = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_force_refresh(), - credentials=None, - wait_for_ready=None, - ) - result2 = interceptor.intercept(continuation, request, call_details_refresh) - assert result2 == response2 - assert continuation.call_count == 2 # Should call continuation again - - -def test_clear_cache(interceptor): - """Test that clear cache deletes the cached entry.""" - # Setup - request = MockRequest("test-data") - response1 = MockResponse("response-1") - response2 = MockResponse("response-2") - method_name = "/test.Service/TestMethod" - - # Create mock continuation - continuation = Mock(side_effect=[response1, response2]) - - # First call with cache - call_details = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_cache(), - credentials=None, - wait_for_ready=None, - ) - result1 = interceptor.intercept(continuation, request, call_details) - assert result1 == response1 - assert continuation.call_count == 1 - - # Second call with clear cache - call_details_clear = ClientCallDetails( - method=method_name, - timeout=None, - metadata=clear_cache_for(), - credentials=None, - wait_for_ready=None, - ) - result2 = interceptor.intercept(continuation, request, call_details_clear) - assert result2 == response2 - assert continuation.call_count == 2 # Should call continuation again - - -def test_different_requests_different_cache_keys(interceptor): - """Test that different requests generate different cache keys.""" - # Setup - request1 = MockRequest("data-1") - request2 = MockRequest("data-2") - response1 = MockResponse("response-1") - response2 = MockResponse("response-2") - method_name = "/test.Service/TestMethod" - - # Create mock continuation - continuation = Mock(side_effect=[response1, response2]) - - # First call with request1 - call_details = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_cache(), - credentials=None, - wait_for_ready=None, - ) - result1 = interceptor.intercept(continuation, request1, call_details) - assert result1 == response1 - assert continuation.call_count == 1 - - # Second call with request2 (different request) - result2 = interceptor.intercept(continuation, request2, call_details) - assert result2 == response2 - assert continuation.call_count == 2 # Should call 
continuation for different request - - -def test_cache_key_generation(interceptor): - """Test that cache key generation is deterministic.""" - request = MockRequest("test-data") - method_name = "/test.Service/TestMethod" - - key1 = interceptor._generate_cache_key(method_name, request) - key2 = interceptor._generate_cache_key(method_name, request) - - assert key1 == key2 - assert len(key1) == 64 # SHA256 hex digest length - - -def test_without_cache_helper(): - """Test the without_cache helper function.""" - metadata = without_cache() - assert len(metadata) == 0 - - # Test with existing metadata - existing = [("key", "value")] - metadata = without_cache(existing) - assert ("key", "value") in metadata - assert ("use-cache", "true") not in metadata - - -def test_with_cache_helper(): - """Test the with_cache helper function.""" - metadata = with_cache() - assert ("use-cache", "true") in metadata - - -def test_with_force_refresh_helper(): - """Test the with_force_refresh helper function.""" - metadata = with_force_refresh() - assert ("force-refresh", "true") in metadata - assert ("use-cache", "true") in metadata - - -def test_clear_cache_for_helper(): - """Test the clear_cache_for helper function.""" - metadata = clear_cache_for() - assert ("clear-cache", "true") in metadata - - -def test_clear_all(interceptor): - """Test clearing all cached entries.""" - # Setup - request = MockRequest("test-data") - response = MockResponse("test-response") - method_name = "/test.Service/TestMethod" - - # Create mock continuation - continuation = Mock(return_value=response) - - # Create call details with cache enabled - call_details = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_cache(), - credentials=None, - wait_for_ready=None, - ) - - # First call - cache miss - result1 = interceptor.intercept(continuation, request, call_details) - assert result1 == response - assert continuation.call_count == 1 - - # Clear all cache - interceptor.clear_all() - - # Second call - should be cache miss again - result2 = interceptor.intercept(continuation, request, call_details) - assert result2 == response - assert continuation.call_count == 2 - - -def test_context_manager(temp_cache_dir): - """Test that the interceptor works as a context manager.""" - with CachingInterceptor(ttl=60, cache_path=temp_cache_dir) as interceptor: - assert interceptor is not None - # Cache should be usable within the context - request = MockRequest("test") - key = interceptor._generate_cache_key("/test", request) - assert key is not None - - -def test_cache_persistence(temp_cache_dir): - """Test that cache persists across interceptor instances.""" - request = MockRequest("test-data") - response = MockResponse("test-response") - method_name = "/test.Service/TestMethod" - - # Create first interceptor and cache a response - interceptor1 = CachingInterceptor(ttl=60, cache_path=temp_cache_dir) - continuation = Mock(return_value=response) - call_details = ClientCallDetails( - method=method_name, - timeout=None, - metadata=with_cache(), - credentials=None, - wait_for_ready=None, - ) - result1 = interceptor1.intercept(continuation, request, call_details) - assert result1 == response - assert continuation.call_count == 1 - interceptor1.close() - - # Create second interceptor with same cache path - interceptor2 = CachingInterceptor(ttl=60, cache_path=temp_cache_dir) - continuation2 = Mock(return_value=response) - result2 = interceptor2.intercept(continuation2, request, call_details) - assert result2 == response - assert 
continuation2.call_count == 0 # Should use cached response - interceptor2.close() - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py index 919938751..5af87f83a 100644 --- a/python/lib/sift_py/grpc/cache.py +++ b/python/lib/sift_py/grpc/cache.py @@ -5,7 +5,7 @@ per-request basis via metadata. Example: - from sift_py.grpc.cache import with_cache, with_force_refresh, clear_cache_for + from sift_py.grpc.cache import with_cache, with_force_refresh, ignore_cache # Enable caching for a request metadata = with_cache() @@ -15,8 +15,8 @@ metadata = with_force_refresh() response = stub.GetData(request, metadata=metadata) - # Clear cache for a specific request - metadata = clear_cache_for() + # Ignore cache without clearing + metadata = ignore_cache() response = stub.GetData(request, metadata=metadata) """ @@ -25,69 +25,89 @@ # Metadata keys for cache control METADATA_USE_CACHE = "use-cache" METADATA_FORCE_REFRESH = "force-refresh" -METADATA_CLEAR_CACHE = "clear-cache" +METADATA_IGNORE_CACHE = "ignore-cache" +METADATA_CACHE_TTL = "cache-ttl" def with_cache( existing_metadata: List[Tuple[str, str]] | None = None, + ttl: int | None = None, ) -> List[Tuple[str, str]]: """Add cache control metadata to enable caching for a request. Args: existing_metadata: Optional existing metadata to extend. + ttl: Optional custom TTL in seconds for this specific request. Returns: Metadata list with cache enabled. Example: + # Use default TTL metadata = with_cache() response = stub.GetData(request, metadata=metadata) + + # Use custom TTL (5 minutes) + metadata = with_cache(ttl=300) + response = stub.GetData(request, metadata=metadata) """ metadata = list(existing_metadata) if existing_metadata else [] metadata.append((METADATA_USE_CACHE, "true")) + if ttl is not None: + metadata.append((METADATA_CACHE_TTL, str(ttl))) return metadata def with_force_refresh( existing_metadata: List[Tuple[str, str]] | None = None, + ttl: int | None = None, ) -> List[Tuple[str, str]]: """Add cache control metadata to force refresh (bypass cache and store fresh result). Args: existing_metadata: Optional existing metadata to extend. + ttl: Optional custom TTL in seconds for the refreshed entry. Returns: Metadata list with force refresh enabled. Example: + # Force refresh with default TTL metadata = with_force_refresh() response = stub.GetData(request, metadata=metadata) + + # Force refresh with custom TTL + metadata = with_force_refresh(ttl=600) + response = stub.GetData(request, metadata=metadata) """ metadata = list(existing_metadata) if existing_metadata else [] metadata.append((METADATA_FORCE_REFRESH, "true")) metadata.append((METADATA_USE_CACHE, "true")) # Also enable caching + if ttl is not None: + metadata.append((METADATA_CACHE_TTL, str(ttl))) return metadata -def clear_cache_for( +def ignore_cache( existing_metadata: List[Tuple[str, str]] | None = None, ) -> List[Tuple[str, str]]: - """Add cache control metadata to clear the cache for a specific request. + """Add metadata to ignore cache for this request without clearing it. - This will delete the cached entry before making the request. + This is useful when you want to bypass the cache for a specific call + but don't want to clear the cached entry. Args: existing_metadata: Optional existing metadata to extend. Returns: - Metadata list with clear cache enabled. + Metadata list with ignore cache flag. 
Example: - metadata = clear_cache_for() + metadata = ignore_cache() response = stub.GetData(request, metadata=metadata) """ metadata = list(existing_metadata) if existing_metadata else [] - metadata.append((METADATA_CLEAR_CACHE, "true")) + metadata.append((METADATA_IGNORE_CACHE, "true")) return metadata @@ -113,6 +133,6 @@ def without_cache( metadata = [ (k, v) for k, v in metadata - if k not in (METADATA_USE_CACHE, METADATA_FORCE_REFRESH, METADATA_CLEAR_CACHE) + if k not in (METADATA_USE_CACHE, METADATA_FORCE_REFRESH, METADATA_IGNORE_CACHE, METADATA_CACHE_TTL) ] return metadata diff --git a/python/lib/sift_py/grpc/transport.py b/python/lib/sift_py/grpc/transport.py index 51b02c0c4..1b3b8bc90 100644 --- a/python/lib/sift_py/grpc/transport.py +++ b/python/lib/sift_py/grpc/transport.py @@ -18,7 +18,6 @@ from sift_py.grpc._async_interceptors.caching import CachingAsyncInterceptor from sift_py.grpc._async_interceptors.metadata import MetadataAsyncInterceptor from sift_py.grpc._interceptors.base import ClientInterceptor -from sift_py.grpc._interceptors.caching import CachingInterceptor from sift_py.grpc._interceptors.metadata import Metadata, MetadataInterceptor from sift_py.grpc._retry import RetryPolicy from sift_py.grpc.keepalive import DEFAULT_KEEPALIVE_CONFIG, KeepaliveConfig @@ -133,21 +132,10 @@ def _compute_sift_interceptors( Initialized all interceptors here. """ interceptors: List[ClientInterceptor] = [] - - # Add caching interceptor if enabled - cache_config = config.get("cache_config") - if cache_config: - interceptors.append( - CachingInterceptor( - ttl=cache_config.get("ttl", 3600), - cache_path=cache_config.get("cache_path", ".grpc_cache"), - size_limit=cache_config.get("size_limit", 1024 * 1024 * 1024), - ) - ) - + # Metadata interceptor should be last to ensure metadata is always added interceptors.append(_metadata_interceptor(config, metadata)) - + return interceptors @@ -155,21 +143,25 @@ def _compute_sift_async_interceptors( config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None ) -> List[grpc_aio.ClientInterceptor]: interceptors: List[grpc_aio.ClientInterceptor] = [] - + # Add caching interceptor if enabled cache_config = config.get("cache_config") - if cache_config: + if cache_config and all( + field in ["ttl", "cache_path", "size_limit", "clear_on_init"] + for field in cache_config.keys() + ): interceptors.append( CachingAsyncInterceptor( - ttl=cache_config.get("ttl", 3600), - cache_path=cache_config.get("cache_path", ".grpc_cache"), - size_limit=cache_config.get("size_limit", 1024 * 1024 * 1024), + ttl=cache_config["ttl"], + cache_path=cache_config["cache_path"], + size_limit=cache_config["size_limit"], + clear_on_init=cache_config["clear_on_init"], ) ) - + # Metadata interceptor should be last to ensure metadata is always added interceptors.append(_metadata_async_interceptor(config, metadata)) - + return interceptors @@ -265,11 +257,13 @@ class CacheConfig(TypedDict): - `ttl`: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). - `cache_path`: Path to the cache directory. Default is ".grpc_cache". - `size_limit`: Maximum size of the cache in bytes. Default is 1GB. + - `clear_on_init`: Whether to clear the cache on initialization. Default is False. 
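+
+    Example (illustrative sketch; the values shown are arbitrary, not enforced defaults):
+        cache_config: CacheConfig = {
+            "ttl": 3600,
+            "cache_path": ".grpc_cache",
+            "size_limit": 1024**3,
+            "clear_on_init": False,
+        }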
""" ttl: NotRequired[int] cache_path: NotRequired[str] size_limit: NotRequired[int] + clear_on_init: NotRequired[bool] class SiftChannelConfig(TypedDict): From 9eca9a62fc0d8983fdd1569954d5d9b4dcd73f65 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Wed, 15 Oct 2025 09:39:57 -0700 Subject: [PATCH 05/16] further iteration --- .../_internal/low_level_wrappers/base.py | 70 ++- .../_internal/low_level_wrappers/data.py | 37 +- python/lib/sift_client/transport/cache.py | 0 .../sift_client/transport/grpc_transport.py | 31 +- .../grpc/_async_interceptors/caching.py | 266 ++++------ .../grpc/_async_interceptors/metadata.py | 9 +- python/lib/sift_py/grpc/cache.py | 247 ++++++--- python/lib/sift_py/grpc/cache_test.py | 477 ++++++++++++++++++ python/lib/sift_py/grpc/transport.py | 50 +- 9 files changed, 910 insertions(+), 277 deletions(-) create mode 100644 python/lib/sift_client/transport/cache.py create mode 100644 python/lib/sift_py/grpc/cache_test.py diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index d349e51a2..d3dfde827 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -1,7 +1,11 @@ from __future__ import annotations from abc import ABC -from typing import Any, Callable +from typing import Any, Callable, TypeVar + +from sift_py.grpc.cache import ignore_cache, with_cache, with_force_refresh + +T = TypeVar("T") class LowLevelClientBase(ABC): @@ -50,3 +54,67 @@ async def _handle_pagination( if max_results and len(results) > max_results: results = results[:max_results] return results + + @staticmethod + async def _call_with_cache( + stub_method: Callable[[Any, tuple[tuple[str, str], ...]], T], + request: Any, + *, + use_cache: bool = True, + force_refresh: bool = False, + ttl: int | None = None, + ) -> T: + """Call a gRPC stub method with cache control. + + This is a convenience method for low-level wrappers to easily enable caching + on their gRPC calls without manually constructing metadata. + + Args: + stub_method: The gRPC stub method to call (e.g., stub.GetData). + request: The protobuf request object. + use_cache: Whether to enable caching for this request. Default: True. + force_refresh: Whether to force refresh the cache. Default: False. + ttl: Optional custom TTL in seconds. If not provided, uses the default TTL. + + Returns: + The response from the gRPC call. 
+ + Example: + # Enable caching + response = await self._call_with_cache( + stub.GetData, + request, + use_cache=True, + ) + + # Force refresh + response = await self._call_with_cache( + stub.GetData, + request, + force_refresh=True, + ) + + # With custom TTL + response = await self._call_with_cache( + stub.GetData, + request, + use_cache=True, + ttl=7200, # 2 hours + ) + + # Ignore cache + response = await self._call_with_cache( + stub.GetData, + request, + use_cache=False, + ) + """ + + if force_refresh: + metadata = with_force_refresh(ttl=ttl) + elif use_cache: + metadata = with_cache(ttl=ttl) + else: + metadata = ignore_cache() + + return await stub_method(request, metadata=metadata) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/data.py b/python/lib/sift_client/_internal/low_level_wrappers/data.py index e5370bbe7..752c05ec6 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/data.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/data.py @@ -74,8 +74,6 @@ def _update_name_id_map(self, channels: list[Channel]): ) self.channel_cache.name_id_map[channel.name] = str(channel.id_) - # TODO: Cache calls. Only read cache if end_time is more than 30 min in the past. - # Also, consider manually caching full channel data and evaluating start/end times while ignoring pagination. Do this ful caching at a higher level though to handle case where pagination fails. async def _get_data_impl( self, *, @@ -86,8 +84,27 @@ async def _get_data_impl( page_size: int | None = None, page_token: str | None = None, order_by: str | None = None, + use_cache: bool = False, + force_refresh: bool = False, + cache_ttl: int | None = None, ) -> tuple[list[Any], str | None]: - """Get the data for a channel during a run.""" + """Get the data for a channel during a run. + + Args: + channel_ids: List of channel IDs to fetch data for. + run_id: Optional run ID to filter data. + start_time: Optional start time for the data range. + end_time: End time for the data range. + page_size: Number of results per page. + page_token: Token for pagination. + order_by: Field to order results by. + use_cache: Whether to enable caching for this request. Default: False. + force_refresh: Whether to force refresh the cache. Default: False. + cache_ttl: Optional custom TTL in seconds for cached responses. + + Returns: + Tuple of (data list, next page token). 
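+
+        Example (illustrative; the channel and run IDs are placeholders):
+            data, next_token = await self._get_data_impl(
+                channel_ids=["example-channel-id"],
+                run_id="example-run-id",
+                end_time=end_time,
+                use_cache=True,
+                cache_ttl=1800,  # cache this page for 30 minutes
+            )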
+ """ queries = [ Query(channel=ChannelQuery(channel_id=channel_id, run_id=run_id)) for channel_id in channel_ids @@ -102,7 +119,19 @@ async def _get_data_impl( } request = GetDataRequest(**request_kwargs) - response = await self._grpc_client.get_stub(DataServiceStub).GetData(request) + + # Use cache helper if caching is enabled + if use_cache or force_refresh: + response = await self._call_with_cache( + self._grpc_client.get_stub(DataServiceStub).GetData, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=cache_ttl, + ) + else: + response = await self._grpc_client.get_stub(DataServiceStub).GetData(request) + response = cast("GetDataResponse", response) return response.data, response.next_page_token # type: ignore # mypy doesn't know RepeatedCompositeFieldContainer can be treated like a list diff --git a/python/lib/sift_client/transport/cache.py b/python/lib/sift_client/transport/cache.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 1e4400d6e..50cc7a51c 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -15,10 +15,9 @@ from pathlib import Path from typing import Any +from sift_py.grpc.cache import GrpcCache from sift_py.grpc.transport import ( - CacheConfig as SiftCacheConfig, -) -from sift_py.grpc.transport import ( + SiftCacheConfig, SiftChannelConfig, use_sift_async_channel, ) @@ -178,6 +177,10 @@ def __init__(self, config: GrpcConfig): # map each asyncio loop to its async channel and stub dict self._channels_async: dict[asyncio.AbstractEventLoop, Any] = {} self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = {} + + # Initialize cache if caching is enabled + self._cache = self._init_cache() + # default loop for sync API self._default_loop = asyncio.new_event_loop() atexit.register(self.close_sync) @@ -203,6 +206,24 @@ def _run_default_loop(): self._channels_async[self._default_loop] = channel self._stubs_async_map[self._default_loop] = {} + def _init_cache(self) -> GrpcCache | None: + """Initialize the GrpcCache instance if caching is enabled.""" + if not self._config.cache_config or not self._config.cache_config.is_enabled: + return None + + try: + cache_config = self._config.cache_config + sift_cache_config: SiftCacheConfig = { + "ttl": cache_config.ttl, + "cache_path": cache_config.cache_path, + "size_limit": cache_config.size_limit, + "clear_on_init": cache_config.mode == CacheMode.CLEAR_ON_INIT, + } + return GrpcCache(sift_cache_config) + except Exception as e: + logger.warning(f"Failed to initialize cache: {e}") + return None + @property def default_loop(self) -> asyncio.AbstractEventLoop: """Return the default event loop used for synchronous API operations. 
@@ -225,7 +246,7 @@ def get_stub(self, stub_class: type[Any]) -> Any:
 
         if loop not in self._channels_async:
             channel = use_sift_async_channel(
-                self._config._to_sift_channel_config(), self._config.metadata
+                self._config._to_sift_channel_config(), self._config.metadata, self._cache
             )
             self._channels_async[loop] = channel
             self._stubs_async_map[loop] = {}
@@ -268,4 +289,4 @@ async def _create_async_channel(
         self, cfg: SiftChannelConfig, metadata: dict[str, str] | None
     ) -> Any:
         """Helper to create async channel on default loop."""
-        return use_sift_async_channel(cfg, metadata)
+        return use_sift_async_channel(cfg, metadata, self._cache)
diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py
index 8094c76c2..609e4bafa 100644
--- a/python/lib/sift_py/grpc/_async_interceptors/caching.py
+++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py
@@ -1,14 +1,17 @@
 """Async gRPC caching interceptor for transparent local response caching.
 
 This module provides an async caching interceptor that can be used to cache gRPC
-unary-unary responses locally using diskcache. The cache is persistent across runs
-and supports TTL expiration and per-request control via metadata.
+unary-unary responses locally using diskcache. The cache is initialized at the
+GrpcClient level and passed to the interceptor.
 
-Usage:
-    from sift_py.grpc._async_interceptors.caching import CachingAsyncInterceptor
+Note: Cache initialization is handled by GrpcClient, not by this interceptor.
 
-    # Create interceptor with 1 hour TTL
-    cache_interceptor = CachingAsyncInterceptor(ttl=3600, cache_path=".grpc_cache")
+Usage:
+    # Cache is initialized at GrpcClient level from a SiftCacheConfig
+    cache = GrpcCache(cache_config)
+
+    # Create interceptor with the cache instance
+    cache_interceptor = CachingAsyncInterceptor(cache)
 
     # Use with metadata to control caching:
     metadata = [
@@ -20,81 +23,48 @@
 
 from __future__ import annotations
 
-import hashlib
+import importlib
 import logging
-from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import diskcache
+from google.protobuf import message
 from grpc import aio as grpc_aio
 
 from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor
+from sift_py.grpc.cache import GrpcCache
 
 logger = logging.getLogger(__name__)
 
-# Metadata keys for cache control
-METADATA_USE_CACHE = "use-cache"
-METADATA_FORCE_REFRESH = "force-refresh"
-METADATA_IGNORE_CACHE = "ignore-cache"
-METADATA_CACHE_TTL = "cache-ttl"
-
-
 class CachingAsyncInterceptor(ClientAsyncInterceptor):
     """Async interceptor that caches unary-unary gRPC responses locally.
 
-    This interceptor uses diskcache for persistent storage with TTL support.
+    This interceptor uses a diskcache instance for persistent storage with TTL support.
+    The cache instance must be provided during initialization (typically from GrpcClient).
     Cache keys are generated deterministically based on the gRPC method name
     and serialized request payload.
 
+    Responses are serialized to bytes before caching to avoid pickling issues with
+    async objects.
+
     Note: diskcache operations are synchronous, but the overhead is minimal
     for most use cases. For high-throughput scenarios, consider using an
    async-native cache backend.
 
     Attributes:
-        ttl: Time-to-live for cached entries in seconds. Default is 3600 (1 hour).
-        cache_path: Path to the cache directory. Default is ".grpc_cache".
-        size_limit: Maximum size of the cache in bytes. 
+ _cache: The GrpcCache instance provided during initialization. """ def __init__( self, - ttl: int = 3600, - cache_path: str = ".grpc_cache", - size_limit: int = 1024**3, # 1GB - clear_on_init: bool = False, + cache: GrpcCache, ): """Initialize the async caching interceptor. Args: - ttl: Time-to-live for cached entries in seconds. - cache_path: Path to the cache directory. - size_limit: Maximum size of the cache in bytes. - clear_on_init: Whether to clear the cache on initialization. + cache: Pre-initialized GrpcCache instance (required). """ - self.ttl = ttl - self.cache_path = Path(cache_path) - self.size_limit = size_limit - - # Create cache directory if it doesn't exist - self.cache_path.mkdir(parents=True, exist_ok=True) - - # Initialize diskcache - self._cache = diskcache.Cache(str(self.cache_path), size_limit=size_limit) - - # Clear cache if requested - if clear_on_init: - logger.info(f"Clearing cache on initialization: {cache_path}") - self._cache.clear() - - logger.info( - f"gRPC cache initialized at {self.cache_path.absolute()!r} " - f"with size {self._cache.volume() / (1024**3):.2f} MB" - ) - - logger.debug( - f"Initialized CachingAsyncInterceptor with ttl={ttl}s, " - f"cache_path={cache_path}, size_limit={size_limit} bytes, clear_on_init={clear_on_init}" - ) + self.cache = cache async def intercept( self, @@ -104,6 +74,8 @@ async def intercept( ) -> Any: """Intercept the async gRPC call and apply caching logic. + Uses GrpcCache.resolve_cache_metadata() to determine caching behavior. + Args: method: The continuation to call for the actual RPC. request_or_iterator: The request object or iterator. @@ -112,130 +84,104 @@ async def intercept( Returns: The response from the cache or the actual RPC call. """ - # Extract metadata flags - metadata_dict = self._extract_metadata(client_call_details.metadata) - use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true" - force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true" - ignore_cache = metadata_dict.get(METADATA_IGNORE_CACHE, "false").lower() == "true" - custom_ttl_str = metadata_dict.get(METADATA_CACHE_TTL) - - # Parse custom TTL if provided - custom_ttl = None - if custom_ttl_str: - try: - custom_ttl = int(custom_ttl_str) - except ValueError: - logger.warning(f"Invalid cache TTL value: {custom_ttl_str}, using default") - - # If ignore_cache is set, bypass cache without clearing - if ignore_cache: - logger.debug("Ignoring cache for this request") - return await method(request_or_iterator, client_call_details) - - # If caching is not enabled, just pass through - if not use_cache and not force_refresh: - return await method(request_or_iterator, client_call_details) + # Resolve cache metadata to determine behavior + cache_settings = self.cache.resolve_cache_metadata(client_call_details.metadata) # Generate cache key - cache_key = self._generate_cache_key(client_call_details.method, request_or_iterator) - - # Handle force-refresh flag - if force_refresh: - logger.debug(f"Force refresh for key: {cache_key}") - call = await method(request_or_iterator, client_call_details) - # For async, we need to await the response - response = await call - # Cache the fresh result with custom TTL if provided - ttl = custom_ttl if custom_ttl is not None else self.ttl - self._cache_response(cache_key, response, ttl) - return response - - # Try to get from cache if use-cache is enabled - if use_cache: - cached_response = self._cache.get(cache_key) - if cached_response is not None: - logger.debug(f"Cache hit 
for key: {cache_key}") - return cached_response + key = self.cache.key_from_proto_message( + method_name=client_call_details.method, request=request_or_iterator + ) - logger.debug(f"Cache miss for key: {cache_key}") + # Try to read from cache if allowed + if cache_settings.use_cache and not cache_settings.force_refresh: + try: + cached_data = self.cache.get(key) + if cached_data is not None: + logger.debug(f"Cache hit for `{key}`") + # Cached data is a tuple of (response_type_name, response_bytes) + response_type_name, response_bytes = cached_data + # Reconstruct the response from bytes + response = self._deserialize_response(response_type_name, response_bytes) + return response + except diskcache.Timeout as e: + logger.debug(f"Cache read timeout for `{key}`: {e}") + except Exception as e: + logger.warning(f"Failed to deserialize cached response for `{key}`: {e}") + + # Force refresh if requested + if cache_settings.force_refresh: + logger.debug(f"Forcing refresh for `{key}`") + self.cache.delete(key) # Make the actual RPC call call = await method(request_or_iterator, client_call_details) + + # The call is a UnaryUnaryCall object, we need to await it to get the actual response response = await call - # Cache the response if use-cache is enabled - if use_cache: - ttl = custom_ttl if custom_ttl is not None else self.ttl - self._cache_response(cache_key, response, ttl) + # Cache the response if allowed + if cache_settings.use_cache: + try: + # Serialize the protobuf response to bytes before caching + if isinstance(response, message.Message): + response_bytes = response.SerializeToString() + response_type_name = type(response).DESCRIPTOR.full_name + # Store both the type name and the serialized bytes + cached_data = (response_type_name, response_bytes) + self.cache.set_with_default_ttl(key, cached_data, expire=cache_settings.custom_ttl) + logger.debug(f"Cached response for `{key}`") + else: + logger.warning(f"Response is not a protobuf message, skipping cache for `{key}`") + logger.warning(f"Response type: {type(response)}") + except diskcache.Timeout as e: + logger.warning(f"Failed to cache response for `{key}`: {e}") + except Exception as e: + logger.warning(f"Failed to serialize response for caching `{key}`: {e}") return response - def _generate_cache_key(self, method_name: str, request: Any) -> str: - """Generate a deterministic cache key from method name and request. + def _deserialize_response(self, response_type_name: str, response_bytes: bytes) -> message.Message: + """Deserialize a cached response from bytes. Args: - method_name: The gRPC method name. - request: The request object. + response_type_name: The full protobuf type name (e.g., 'sift.data.v2.GetDataResponse') + response_bytes: The serialized protobuf bytes Returns: - A SHA256 hash of the method name and serialized request. - """ - try: - # Serialize the request using protobuf's SerializeToString - request_bytes = request.SerializeToString() - except AttributeError: - # If the request doesn't have SerializeToString, fall back to str - logger.warning( - f"Request for {method_name} doesn't have SerializeToString, using str() instead" - ) - request_bytes = str(request).encode() - - # Create a deterministic hash - key_material = method_name.encode() + request_bytes - cache_key = hashlib.sha256(key_material).hexdigest() + The deserialized protobuf message - return cache_key - - def _cache_response(self, cache_key: str, response: Any, ttl: int | None = None) -> None: - """Store a response in the cache with TTL. 
- - Args: - cache_key: The cache key. - response: The response to cache. - ttl: Optional custom TTL. If None, uses the default TTL. + Raises: + ImportError: If the response type cannot be imported + Exception: If deserialization fails """ - try: - effective_ttl = ttl if ttl is not None else self.ttl - self._cache.set(cache_key, response, expire=effective_ttl) - logger.debug(f"Cached response for key: {cache_key} with TTL: {effective_ttl}s") - except Exception as e: - logger.error(f"Failed to cache response for key {cache_key}: {e}") - - def _extract_metadata(self, metadata: Optional[tuple[tuple[str, str], ...]]) -> dict[str, str]: - """Extract metadata into a dictionary. + # Import the response type dynamically + # Convert 'sift.data.v2.GetDataResponse' to module and class + parts = response_type_name.rsplit('.', 1) + if len(parts) != 2: + raise ValueError(f"Invalid response type name: {response_type_name}") + + package_name, class_name = parts + + # Protobuf generates Python modules with _pb2 suffix + # e.g., 'sift.data.v2' -> 'sift.data.v2.data_pb2' + # Extract the service name from the package (last part before version) + package_parts = package_name.split('.') + if len(package_parts) >= 2: + # Get the service name (e.g., 'data' from 'sift.data.v2') + service_name = package_parts[-2] + python_module = f"{package_name}.{service_name}_pb2" + else: + # Fallback: just append _pb2 + python_module = f"{package_name}_pb2" - Args: - metadata: The metadata tuple. - - Returns: - A dictionary of metadata key-value pairs. - """ - if metadata is None: - return {} - return dict(metadata) - - def clear_all(self) -> None: - """Clear all cached entries.""" - logger.info("Clearing all cached entries") - self._cache.clear() - - def close(self) -> None: - """Close the cache and release resources.""" - logger.debug("Closing cache") - self._cache.close() - - def __enter__(self): - return self + try: + # Import the module + module = importlib.import_module(python_module) + response_class = getattr(module, class_name) - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + # Deserialize + response = response_class() + response.ParseFromString(response_bytes) + return response + except (ImportError, AttributeError) as e: + raise ImportError(f"Failed to import response type {response_type_name} from {python_module}: {e}") diff --git a/python/lib/sift_py/grpc/_async_interceptors/metadata.py b/python/lib/sift_py/grpc/_async_interceptors/metadata.py index 0592c3648..d506266f7 100644 --- a/python/lib/sift_py/grpc/_async_interceptors/metadata.py +++ b/python/lib/sift_py/grpc/_async_interceptors/metadata.py @@ -26,10 +26,17 @@ async def intercept( client_call_details: grpc_aio.ClientCallDetails, ): call_details = cast(grpc_aio.ClientCallDetails, client_call_details) + + # Merge existing metadata with interceptor metadata + # Existing metadata from the call takes precedence + merged_metadata = list(self.metadata) + if call_details.metadata: + merged_metadata.extend(call_details.metadata) + new_details = grpc_aio.ClientCallDetails( call_details.method, call_details.timeout, - self.metadata, + merged_metadata, call_details.credentials, call_details.wait_for_ready, ) diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py index 5af87f83a..03aac393f 100644 --- a/python/lib/sift_py/grpc/cache.py +++ b/python/lib/sift_py/grpc/cache.py @@ -20,119 +20,214 @@ response = stub.GetData(request, metadata=metadata) """ -from typing import List, Tuple +from __future__ import annotations + +import 
hashlib +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, NamedTuple + +import diskcache +from google.protobuf import json_format, message + +if TYPE_CHECKING: + + + from sift_py.grpc.transport import SiftCacheConfig + +logger = logging.getLogger(__name__) + + +class CacheSettings(NamedTuple): + """Resolved cache metadata from gRPC request.""" + + use_cache: bool + force_refresh: bool + custom_ttl: float | None # Metadata keys for cache control METADATA_USE_CACHE = "use-cache" METADATA_FORCE_REFRESH = "force-refresh" -METADATA_IGNORE_CACHE = "ignore-cache" METADATA_CACHE_TTL = "cache-ttl" -def with_cache( - existing_metadata: List[Tuple[str, str]] | None = None, - ttl: int | None = None, -) -> List[Tuple[str, str]]: - """Add cache control metadata to enable caching for a request. - +class GrpcCache(diskcache.Cache): + """Subclass of diskcache.Cache for gRPC response caching.""" + + def __init__(self, config: SiftCacheConfig): + """Initialize the cache from configuration. + + Args: + config: Cache configuration with ttl, cache_path, size_limit, clear_on_init. + """ + self.default_ttl = config["ttl"] + self.cache_path = Path(config["cache_path"]) + self.size_limit = config["size_limit"] + + # Create cache directory if it doesn't exist + self.cache_path.mkdir(parents=True, exist_ok=True) + + # Initialize parent diskcache.Cache + super().__init__(str(self.cache_path), size_limit=self.size_limit) + + # Clear cache if requested + if config.get("clear_on_init", False): + logger.debug(f"Clearing cache on initialization: {self.cache_path}") + self.clear() + + logger.debug( + f"Cache initialized at {self.cache_path.absolute()!r} " + f"with size {self.volume() / (1024**2):.2f} MB" + ) + + def set_with_default_ttl(self, key: str, value: Any, expire: float | None = None, **kwargs) -> bool: + expire_time = expire if expire is not None else self.default_ttl + return super().set(key, value, expire=expire_time, **kwargs) + + @staticmethod + def key_from_proto_message(method_name: str | bytes, request: message.Message) -> str: + # Serialize the request to bytes + request_json = json_format.MessageToJson(request).encode("utf-8") + + if isinstance(method_name, str): + method_name = method_name.encode("utf-8") + + # Create a hash of method name + request + hasher = hashlib.sha256() + hasher.update(method_name) + hasher.update(request_json) + + return hasher.hexdigest() + + @staticmethod + def resolve_cache_metadata( + metadata: tuple[tuple[str, str], ...] | None + ) -> CacheSettings: + """Extract and resolve cache-related metadata fields. + + Args: + metadata: The gRPC request metadata tuple. 
+
+        Returns:
+            CacheSettings named tuple with resolved cache control fields:
+            - use_cache: bool - Whether to use caching
+            - force_refresh: bool - Whether to force refresh
+            - custom_ttl: int | None - Custom TTL if specified, otherwise None
+
+        Example:
+            settings = cache.resolve_cache_metadata(metadata)
+            if settings.use_cache and not settings.force_refresh:
+                cached = cache.get(key)
+            if settings.use_cache:
+                cache.set_with_default_ttl(key, response, expire=settings.custom_ttl)
+        """
+        if not metadata:
+            metadata_dict = {}
+        else:
+            # Handle both tuple and grpc.aio.Metadata types
+            metadata_dict = {}
+            for key, value in metadata:
+                metadata_dict[key] = value
+
+        use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true"
+
+        if not use_cache:
+            return CacheSettings(use_cache=False, force_refresh=False, custom_ttl=None)
+
+        force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true"
+        custom_ttl_str = metadata_dict.get(METADATA_CACHE_TTL)
+
+        # Parse custom TTL if provided
+        custom_ttl = None
+        if custom_ttl_str:
+            try:
+                custom_ttl = int(custom_ttl_str)
+            except ValueError:
+                logger.warning(f"Invalid cache TTL value: {custom_ttl_str}, using default")
+
+        return CacheSettings(
+            use_cache=use_cache,
+            force_refresh=force_refresh,
+            custom_ttl=custom_ttl,
+        )
+
+
+def with_cache(ttl: int | None = None) -> tuple[tuple[str, str], ...]:
+    """Enable caching for a gRPC request.
+
     Args:
-        existing_metadata: Optional existing metadata to extend.
-        ttl: Optional custom TTL in seconds for this specific request.
-
+        ttl: Optional custom TTL in seconds. If not provided, uses the default TTL.
+
     Returns:
-        Metadata list with cache enabled.
-
+        Metadata tuple to pass to the gRPC stub method.
+
     Example:
-        # Use default TTL
         metadata = with_cache()
         response = stub.GetData(request, metadata=metadata)
-
-        # Use custom TTL (5 minutes)
-        metadata = with_cache(ttl=300)
+
+        # With custom TTL
+        metadata = with_cache(ttl=7200)  # 2 hours
         response = stub.GetData(request, metadata=metadata)
     """
-    metadata = list(existing_metadata) if existing_metadata else []
-    metadata.append((METADATA_USE_CACHE, "true"))
+    metadata = [(METADATA_USE_CACHE, "true")]
     if ttl is not None:
         metadata.append((METADATA_CACHE_TTL, str(ttl)))
-    return metadata
-
+    return tuple(metadata)
 
-def with_force_refresh(
-    existing_metadata: List[Tuple[str, str]] | None = None,
-    ttl: int | None = None,
-) -> List[Tuple[str, str]]:
-    """Add cache control metadata to force refresh (bypass cache and store fresh result).
+def with_force_refresh(ttl: int | None = None) -> tuple[tuple[str, str], ...]:
+    """Force refresh the cache for a gRPC request.
+
+    Bypasses the cache, fetches fresh data from the server, and stores the result.
+
     Args:
-        existing_metadata: Optional existing metadata to extend.
-        ttl: Optional custom TTL in seconds for the refreshed entry.
-
+        ttl: Optional custom TTL in seconds. If not provided, uses the default TTL.
+
     Returns:
-        Metadata list with force refresh enabled.
-
+        Metadata tuple to pass to the gRPC stub method. 
+ Example: - # Force refresh with default TTL metadata = with_force_refresh() response = stub.GetData(request, metadata=metadata) - - # Force refresh with custom TTL - metadata = with_force_refresh(ttl=600) - response = stub.GetData(request, metadata=metadata) """ - metadata = list(existing_metadata) if existing_metadata else [] - metadata.append((METADATA_FORCE_REFRESH, "true")) - metadata.append((METADATA_USE_CACHE, "true")) # Also enable caching + metadata = [ + (METADATA_USE_CACHE, "true"), + (METADATA_FORCE_REFRESH, "true"), + ] if ttl is not None: metadata.append((METADATA_CACHE_TTL, str(ttl))) - return metadata + return tuple(metadata) -def ignore_cache( - existing_metadata: List[Tuple[str, str]] | None = None, -) -> List[Tuple[str, str]]: - """Add metadata to ignore cache for this request without clearing it. - - This is useful when you want to bypass the cache for a specific call - but don't want to clear the cached entry. - - Args: - existing_metadata: Optional existing metadata to extend. - +def ignore_cache() -> tuple[tuple[str, str], ...]: + """Ignore the cache for a gRPC request without clearing it. + + Bypasses the cache for this request but doesn't invalidate the cached entry. + The response from this request will not be cached. + Returns: - Metadata list with ignore cache flag. - + Metadata tuple to pass to the gRPC stub method. + Example: metadata = ignore_cache() response = stub.GetData(request, metadata=metadata) """ - metadata = list(existing_metadata) if existing_metadata else [] - metadata.append((METADATA_IGNORE_CACHE, "true")) - return metadata - - -def without_cache( - existing_metadata: List[Tuple[str, str]] | None = None, -) -> List[Tuple[str, str]]: - """Explicitly disable caching for a request. - - This is the default behavior, so this function is mainly for clarity. + return tuple() - Args: - existing_metadata: Optional existing metadata to extend. +def without_cache() -> tuple[tuple[str, str], ...]: + """Explicitly disable caching for a gRPC request. + + This is the default behavior when no cache metadata is provided. + Returns: - Metadata list without cache flags. - + Empty metadata tuple. 
+ Example: metadata = without_cache() response = stub.GetData(request, metadata=metadata) """ - metadata = list(existing_metadata) if existing_metadata else [] - # Remove any cache-related metadata - metadata = [ - (k, v) - for k, v in metadata - if k not in (METADATA_USE_CACHE, METADATA_FORCE_REFRESH, METADATA_IGNORE_CACHE, METADATA_CACHE_TTL) - ] - return metadata + return tuple() diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py new file mode 100644 index 000000000..59f1141b0 --- /dev/null +++ b/python/lib/sift_py/grpc/cache_test.py @@ -0,0 +1,477 @@ +# ruff: noqa: N802 + +import logging +import tempfile +from concurrent import futures +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Callable, Iterator, cast + +import grpc +import pytest +from pytest_mock import MockFixture, MockType +from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse +from sift.data.v2.data_pb2_grpc import ( + DataServiceServicer, + DataServiceStub, + add_DataServiceServicer_to_server, +) + +from sift_py._internal.test_util.server_interceptor import ServerInterceptor +from sift_py.grpc.cache import ( + GrpcCache, + ignore_cache, + with_cache, + with_force_refresh, + without_cache, +) +from sift_py.grpc.transport import SiftChannelConfig, use_sift_async_channel + +# Enable debug logging for cache-related modules +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logging.getLogger('sift_py').setLevel(logging.DEBUG) + +class DataService(DataServiceServicer): + """Mock data service that returns a unique response each time.""" + + call_count: int + + def __init__(self): + self.call_count = 0 + + def GetData(self, request: GetDataRequest, context: grpc.ServicerContext): + self.call_count += 1 + # Return a unique token each time to verify caching + return GetDataResponse(next_page_token=f"token-{self.call_count}") + + +class AuthInterceptor(ServerInterceptor): + """Simple auth interceptor that checks for Bearer token.""" + + def intercept( + self, + method: Callable, + request_or_iterator: Any, + context: grpc.ServicerContext, + method_name: str, + ) -> Any: + authenticated = False + for metadata in context.invocation_metadata(): + if metadata.key == "authorization": + if metadata.value.startswith("Bearer "): + authenticated = True + break + + if authenticated: + return method(request_or_iterator, context) + else: + context.set_code(grpc.StatusCode.UNAUTHENTICATED) + context.set_details("Invalid or missing API key") + raise + + +@contextmanager +def server_with_service(mocker: MockFixture) -> Iterator[tuple[MockType, DataService, int]]: + """Create a test server with a spy on the DataService. 
+ + Returns: + Tuple of (spy, data_service, port) + """ + server = grpc.server( + thread_pool=futures.ThreadPoolExecutor(max_workers=1), + interceptors=[AuthInterceptor()], + ) + + data_service = DataService() + spy = mocker.spy(data_service, "GetData") + + add_DataServiceServicer_to_server(data_service, server) + # Use port 0 to let the OS assign an available port + port = server.add_insecure_port("[::]:0") + server.start() + try: + yield spy, data_service, port + finally: + server.stop(None) + server.wait_for_termination() + + +def test_cache_helper_functions(): + """Test the cache metadata helper functions.""" + # Test with_cache + metadata = with_cache() + assert metadata == (("use-cache", "true"),) + + # Test with_cache with custom TTL + metadata = with_cache(ttl=7200) + assert metadata == (("use-cache", "true"), ("cache-ttl", "7200")) + + # Test with_force_refresh + metadata = with_force_refresh() + assert metadata == (("use-cache", "true"), ("force-refresh", "true")) + + # Test with_force_refresh with custom TTL + metadata = with_force_refresh(ttl=3600) + assert metadata == (("use-cache", "true"), ("force-refresh", "true"), ("cache-ttl", "3600")) + + # Test ignore_cache + metadata = ignore_cache() + assert metadata == () + + # Test without_cache + metadata = without_cache() + assert metadata == () + + +def test_grpc_cache_initialization(): + """Test GrpcCache initialization and configuration.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_config = { + "ttl": 1800, + "cache_path": str(Path(tmpdir) / "test_cache"), + "size_limit": 1024 * 1024, # 1MB + "clear_on_init": False, + } + + cache = GrpcCache(cache_config) + assert cache.default_ttl == 1800 + assert cache.cache_path == Path(tmpdir) / "test_cache" + assert cache.size_limit == 1024 * 1024 + assert cache.cache_path.exists() + + # Test clear_on_init + cache.set("test-key", "test-value") + assert cache.get("test-key") == "test-value" + + cache_config["clear_on_init"] = True + cache2 = GrpcCache(cache_config) + assert cache2.get("test-key") is None + + +def test_cache_key_generation(): + """Test deterministic cache key generation.""" + request1 = GetDataRequest(page_size=100) + request2 = GetDataRequest(page_size=100) + request3 = GetDataRequest(page_size=200) + + key1 = GrpcCache.key_from_proto_message("/sift.data.v2.DataService/GetData", request1) + key2 = GrpcCache.key_from_proto_message("/sift.data.v2.DataService/GetData", request2) + key3 = GrpcCache.key_from_proto_message("/sift.data.v2.DataService/GetData", request3) + + # Same request should generate same key + assert key1 == key2 + + # Different request should generate different key + assert key1 != key3 + + # Keys should be SHA256 hashes (64 hex characters) + assert len(key1) == 64 + assert all(c in "0123456789abcdef" for c in key1) + + +def test_cache_metadata_resolution(): + """Test cache metadata resolution logic.""" + # No metadata + settings = GrpcCache.resolve_cache_metadata(None) + assert settings.use_cache is False + assert settings.force_refresh is False + assert settings.custom_ttl is None + + # use-cache enabled + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"),)) + assert settings.use_cache is True + assert settings.force_refresh is False + assert settings.custom_ttl is None + + # force-refresh enabled + settings = GrpcCache.resolve_cache_metadata( + (("use-cache", "true"), ("force-refresh", "true")) + ) + assert settings.use_cache is True + assert settings.force_refresh is True + assert settings.custom_ttl is None + + # Custom 
TTL + settings = GrpcCache.resolve_cache_metadata( + (("use-cache", "true"), ("cache-ttl", "7200")) + ) + assert settings.use_cache is True + assert settings.force_refresh is False + assert settings.custom_ttl == 7200 + + # Invalid TTL (should be ignored) + settings = GrpcCache.resolve_cache_metadata( + (("use-cache", "true"), ("cache-ttl", "invalid")) + ) + assert settings.use_cache is True + assert settings.custom_ttl is None + + +@pytest.mark.asyncio +async def test_basic_caching(mocker: MockFixture): + """Test basic cache hit and miss scenarios.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + }, + } + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # First call without cache - should hit server + res1 = cast(GetDataResponse, await stub.GetData(request)) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Second call without cache - should hit server again + res2 = cast(GetDataResponse, await stub.GetData(request)) + assert res2.next_page_token == "token-2" + assert data_service.call_count == 2 + + # Third call WITH cache - should hit server + res3 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res3.next_page_token == "token-3" + assert data_service.call_count == 3 + + # Fourth call WITH cache - should use cached response + res4 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res4.next_page_token == "token-3" # Same as res3! 
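+                # The interceptor resolved ("use-cache", "true") from the call
+                # metadata, recomputed the same request key, and served the stored
+                # response without another round trip to the server.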
+ assert data_service.call_count == 3 # No new call + + # Fifth call WITH cache - should still use cached response + res5 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res5.next_page_token == "token-3" + assert data_service.call_count == 3 + + +@pytest.mark.asyncio +async def test_force_refresh(mocker: MockFixture): + """Test force refresh bypasses cache and updates it.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + }, + } + + async with use_sift_async_channel(config) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # First call with cache + res1 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Second call with cache - should use cached + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Force refresh - should hit server and update cache + res3 = cast( + GetDataResponse, await stub.GetData(request, metadata=with_force_refresh()) + ) + assert res3.next_page_token == "token-2" + assert data_service.call_count == 2 + + # Next call with cache should use the refreshed value + res4 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res4.next_page_token == "token-2" + assert data_service.call_count == 2 + + +@pytest.mark.asyncio +async def test_ignore_cache(mocker: MockFixture): + """Test ignore_cache bypasses cache without updating it.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + }, + } + + async with use_sift_async_channel(config) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # First call with cache + res1 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Call with ignore_cache - should hit server + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=ignore_cache())) + assert res2.next_page_token == "token-2" + assert data_service.call_count == 2 + + # Call with cache again - should still have original cached value + res3 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res3.next_page_token == "token-1" # Original cached value + assert data_service.call_count == 2 + + +@pytest.mark.asyncio +async def test_different_requests_different_cache_keys(mocker: MockFixture): + """Test that different requests use different cache entries.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, 
+ }, + } + + async with use_sift_async_channel(config) as channel: + stub = DataServiceStub(channel) + request1 = GetDataRequest(page_size=100) + request2 = GetDataRequest(page_size=200) + + # First request with cache + res1 = cast(GetDataResponse, await stub.GetData(request1, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Different request with cache - should hit server + res2 = cast(GetDataResponse, await stub.GetData(request2, metadata=with_cache())) + assert res2.next_page_token == "token-2" + assert data_service.call_count == 2 + + # First request again - should use cache + res3 = cast(GetDataResponse, await stub.GetData(request1, metadata=with_cache())) + assert res3.next_page_token == "token-1" + assert data_service.call_count == 2 + + # Second request again - should use cache + res4 = cast(GetDataResponse, await stub.GetData(request2, metadata=with_cache())) + assert res4.next_page_token == "token-2" + assert data_service.call_count == 2 + + +@pytest.mark.asyncio +async def test_cache_persists_across_channels(mocker: MockFixture): + """Test that cache persists across different channel instances.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_path = str(Path(tmpdir) / "cache") + + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": cache_path, + "size_limit": 1024 * 1024, + }, + } + + # First channel - populate cache + async with use_sift_async_channel(config) as channel1: + stub1 = DataServiceStub(channel1) + request = GetDataRequest(page_size=100) + res1 = cast(GetDataResponse, await stub1.GetData(request, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Second channel - should use cached value + async with use_sift_async_channel(config) as channel2: + stub2 = DataServiceStub(channel2) + request = GetDataRequest(page_size=100) + res2 = cast(GetDataResponse, await stub2.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" # Same as first call + assert data_service.call_count == 1 # No new server call + + +@pytest.mark.asyncio +async def test_custom_ttl(mocker: MockFixture): + """Test custom TTL parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, # Default TTL + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + }, + } + + async with use_sift_async_channel(config) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # Call with custom TTL + res1 = cast( + GetDataResponse, await stub.GetData(request, metadata=with_cache(ttl=7200)) + ) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Verify it's cached + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" + assert data_service.call_count == 1 + + +@pytest.mark.asyncio +async def test_metadata_merging(mocker: MockFixture): + """Test that cache metadata is properly merged with API key metadata.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, 
data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", # This adds authorization metadata + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + }, + } + + async with use_sift_async_channel(config) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # This should work - cache metadata should be merged with auth metadata + res = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Verify cache works + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" + assert data_service.call_count == 1 diff --git a/python/lib/sift_py/grpc/transport.py b/python/lib/sift_py/grpc/transport.py index 1b3b8bc90..b868b146c 100644 --- a/python/lib/sift_py/grpc/transport.py +++ b/python/lib/sift_py/grpc/transport.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast from urllib.parse import ParseResult, urlparse +import diskcache import grpc import grpc.aio as grpc_aio from typing_extensions import NotRequired, TypeAlias @@ -79,7 +80,7 @@ def use_sift_channel( def use_sift_async_channel( - config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None + config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None, cache: Any = None ) -> SiftAsyncChannel: """ Like `use_sift_channel` but returns a channel meant to be used within the context @@ -89,13 +90,13 @@ def use_sift_async_channel( cert_via_openssl = config.get("cert_via_openssl", False) if not use_ssl: - return _use_insecure_sift_async_channel(config, metadata) + return _use_insecure_sift_async_channel(config, metadata, cache) return grpc_aio.secure_channel( target=_clean_uri(config["uri"], use_ssl), credentials=get_ssl_credentials(cert_via_openssl), options=_compute_channel_options(config), - interceptors=_compute_sift_async_interceptors(config, metadata), + interceptors=_compute_sift_async_interceptors(config, metadata, cache), ) @@ -113,7 +114,7 @@ def _use_insecure_sift_channel( def _use_insecure_sift_async_channel( - config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None + config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None, cache: Any = None ) -> SiftAsyncChannel: """ FOR DEVELOPMENT PURPOSES ONLY @@ -121,7 +122,7 @@ def _use_insecure_sift_async_channel( return grpc_aio.insecure_channel( target=_clean_uri(config["uri"], False), options=_compute_channel_options(config), - interceptors=_compute_sift_async_interceptors(config, metadata), + interceptors=_compute_sift_async_interceptors(config, metadata, cache), ) @@ -140,24 +141,13 @@ def _compute_sift_interceptors( def _compute_sift_async_interceptors( - config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None + config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None, cache: Any = None ) -> List[grpc_aio.ClientInterceptor]: interceptors: List[grpc_aio.ClientInterceptor] = [] - # Add caching interceptor if enabled - cache_config = config.get("cache_config") - if cache_config and all( - field in ["ttl", "cache_path", "size_limit", "clear_on_init"] - for field in cache_config.keys() - ): - interceptors.append( - CachingAsyncInterceptor( - ttl=cache_config["ttl"], - cache_path=cache_config["cache_path"], - size_limit=cache_config["size_limit"], - 
clear_on_init=cache_config["clear_on_init"], - ) - ) + # Add caching interceptor if cache instance is provided + if cache is not None: + interceptors.append(CachingAsyncInterceptor(cache=cache)) # Metadata interceptor should be last to ensure metadata is always added interceptors.append(_metadata_async_interceptor(config, metadata)) @@ -251,19 +241,19 @@ def _compute_keep_alive_channel_opts(config: KeepaliveConfig) -> List[Tuple[str, ] -class CacheConfig(TypedDict): +class SiftCacheConfig(TypedDict): """ Configuration for gRPC response caching. - - `ttl`: Time-to-live for cached entries in seconds. Default is 3600 (1 hour). - - `cache_path`: Path to the cache directory. Default is ".grpc_cache". - - `size_limit`: Maximum size of the cache in bytes. Default is 1GB. - - `clear_on_init`: Whether to clear the cache on initialization. Default is False. + - `ttl`: Time-to-live for cached entries in seconds. + - `cache_path`: Path to the cache directory. + - `size_limit`: Maximum size of the cache in bytes. + - `clear_on_init`: Whether to clear the cache on initialization. """ - ttl: NotRequired[int] - cache_path: NotRequired[str] - size_limit: NotRequired[int] - clear_on_init: NotRequired[bool] + ttl: int + cache_path: str + size_limit: int + clear_on_init: bool class SiftChannelConfig(TypedDict): @@ -287,4 +277,4 @@ class SiftChannelConfig(TypedDict): enable_keepalive: NotRequired[Union[bool, KeepaliveConfig]] use_ssl: NotRequired[bool] cert_via_openssl: NotRequired[bool] - cache_config: NotRequired[CacheConfig] + cache_config: NotRequired[SiftCacheConfig] From 61a83713e01f211cc65caec579d790d67e7ec98f Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Wed, 15 Oct 2025 09:57:09 -0700 Subject: [PATCH 06/16] wip --- .../grpc/_async_interceptors/caching.py | 64 ++++--------------- 1 file changed, 12 insertions(+), 52 deletions(-) diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py index 609e4bafa..e5ab291fe 100644 --- a/python/lib/sift_py/grpc/_async_interceptors/caching.py +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -28,7 +28,7 @@ from typing import Any import diskcache -from google.protobuf import message +from google.protobuf import message, symbol_database from grpc import aio as grpc_aio from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor @@ -123,65 +123,25 @@ async def intercept( if cache_settings.use_cache: try: # Serialize the protobuf response to bytes before caching - if isinstance(response, message.Message): - response_bytes = response.SerializeToString() - response_type_name = type(response).DESCRIPTOR.full_name - # Store both the type name and the serialized bytes - cached_data = (response_type_name, response_bytes) + cached_data = self._serialize_response(response) + if cached_data is not None: self.cache.set_with_default_ttl(key, cached_data, expire=cache_settings.custom_ttl) logger.debug(f"Cached response for `{key}`") - else: - logger.warning(f"Response is not a protobuf message, skipping cache for `{key}`") - logger.warning(f"Response type: {type(response)}") except diskcache.Timeout as e: logger.warning(f"Failed to cache response for `{key}`: {e}") - except Exception as e: - logger.warning(f"Failed to serialize response for caching `{key}`: {e}") return response - def _deserialize_response(self, response_type_name: str, response_bytes: bytes) -> message.Message: - """Deserialize a cached response from bytes. 
- - Args: - response_type_name: The full protobuf type name (e.g., 'sift.data.v2.GetDataResponse') - response_bytes: The serialized protobuf bytes - - Returns: - The deserialized protobuf message - - Raises: - ImportError: If the response type cannot be imported - Exception: If deserialization fails - """ - # Import the response type dynamically - # Convert 'sift.data.v2.GetDataResponse' to module and class - parts = response_type_name.rsplit('.', 1) - if len(parts) != 2: - raise ValueError(f"Invalid response type name: {response_type_name}") + @staticmethod + def _serialize_response(response: message.Message) -> tuple[Any, bytes] | None: + if isinstance(response, message.Message): + return (response.DESCRIPTOR.full_name, response.SerializeToString()) + else: + logger.warning(f"Response is not a protobuf message: {type(response)}") + return None - package_name, class_name = parts + @staticmethod + def _deserialize_response(response: tuple[Any, bytes]) -> message.Message: - # Protobuf generates Python modules with _pb2 suffix - # e.g., 'sift.data.v2' -> 'sift.data.v2.data_pb2' - # Extract the service name from the package (last part before version) - package_parts = package_name.split('.') - if len(package_parts) >= 2: - # Get the service name (e.g., 'data' from 'sift.data.v2') - service_name = package_parts[-2] - python_module = f"{package_name}.{service_name}_pb2" - else: - # Fallback: just append _pb2 - python_module = f"{package_name}_pb2" - - try: - # Import the module - module = importlib.import_module(python_module) - response_class = getattr(module, class_name) - - # Deserialize - response = response_class() - response.ParseFromString(response_bytes) - return response except (ImportError, AttributeError) as e: raise ImportError(f"Failed to import response type {response_type_name} from {python_module}: {e}") From 7f73b63f2ba1e8e15088fbdf2fe38c7b28c5801e Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Wed, 15 Oct 2025 11:21:38 -0700 Subject: [PATCH 07/16] fix tests --- .../grpc/_async_interceptors/caching.py | 31 ++++++++++------- python/lib/sift_py/grpc/cache_test.py | 34 +++++++++++++------ 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py index e5ab291fe..efe140eb3 100644 --- a/python/lib/sift_py/grpc/_async_interceptors/caching.py +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -23,7 +23,6 @@ from __future__ import annotations -import importlib import logging from typing import Any @@ -65,6 +64,7 @@ def __init__( cache: Pre-initialized GrpcCache instance (required). 
""" self.cache = cache + self.symbol_db = symbol_database.Default() async def intercept( self, @@ -98,11 +98,12 @@ async def intercept( cached_data = self.cache.get(key) if cached_data is not None: logger.debug(f"Cache hit for `{key}`") - # Cached data is a tuple of (response_type_name, response_bytes) - response_type_name, response_bytes = cached_data - # Reconstruct the response from bytes - response = self._deserialize_response(response_type_name, response_bytes) - return response + # Reconstruct the response + response = self._deserialize_response(cached_data) + if response is not None: + return response + else: + logger.warning(f"Failed to deserialize cached response for `{key}`") except diskcache.Timeout as e: logger.debug(f"Cache read timeout for `{key}`: {e}") except Exception as e: @@ -135,13 +136,19 @@ async def intercept( @staticmethod def _serialize_response(response: message.Message) -> tuple[Any, bytes] | None: if isinstance(response, message.Message): - return (response.DESCRIPTOR.full_name, response.SerializeToString()) + return response.DESCRIPTOR.full_name, response.SerializeToString() else: logger.warning(f"Response is not a protobuf message: {type(response)}") return None - @staticmethod - def _deserialize_response(response: tuple[Any, bytes]) -> message.Message: - - except (ImportError, AttributeError) as e: - raise ImportError(f"Failed to import response type {response_type_name} from {python_module}: {e}") + def _deserialize_response(self, response: tuple[Any, bytes]) -> message.Message | None: + response_type, data = response + try: + response_type_cls = self.symbol_db.GetSymbol(response_type) + message = response_type_cls() + message.ParseFromString(data) + return message + except Exception as e: + logger.warning(f"Failed to deserialize response: {e}") + return None + diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py index 59f1141b0..4bdaf0898 100644 --- a/python/lib/sift_py/grpc/cache_test.py +++ b/python/lib/sift_py/grpc/cache_test.py @@ -223,6 +223,7 @@ async def test_basic_caching(mocker: MockFixture): "ttl": 3600, "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": True, }, } cache = GrpcCache(config["cache_config"]) @@ -270,10 +271,12 @@ async def test_force_refresh(mocker: MockFixture): "ttl": 3600, "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": True, }, } + cache = GrpcCache(config["cache_config"]) - async with use_sift_async_channel(config) as channel: + async with use_sift_async_channel(config, cache=cache) as channel: stub = DataServiceStub(channel) request = GetDataRequest(page_size=100) @@ -313,10 +316,12 @@ async def test_ignore_cache(mocker: MockFixture): "ttl": 3600, "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": True, }, } + cache = GrpcCache(config["cache_config"]) - async with use_sift_async_channel(config) as channel: + async with use_sift_async_channel(config, cache=cache) as channel: stub = DataServiceStub(channel) request = GetDataRequest(page_size=100) @@ -349,10 +354,12 @@ async def test_different_requests_different_cache_keys(mocker: MockFixture): "ttl": 3600, "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": True, }, } + cache = GrpcCache(config["cache_config"]) - async with use_sift_async_channel(config) as channel: + async with use_sift_async_channel(config, cache=cache) as channel: stub = DataServiceStub(channel) request1 = 
GetDataRequest(page_size=100) request2 = GetDataRequest(page_size=200) @@ -391,13 +398,15 @@ async def test_cache_persists_across_channels(mocker: MockFixture): "use_ssl": False, "cache_config": { "ttl": 3600, - "cache_path": cache_path, + "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": False, }, } + cache = GrpcCache(config["cache_config"]) # First channel - populate cache - async with use_sift_async_channel(config) as channel1: + async with use_sift_async_channel(config, cache=cache) as channel1: stub1 = DataServiceStub(channel1) request = GetDataRequest(page_size=100) res1 = cast(GetDataResponse, await stub1.GetData(request, metadata=with_cache())) @@ -405,7 +414,7 @@ async def test_cache_persists_across_channels(mocker: MockFixture): assert data_service.call_count == 1 # Second channel - should use cached value - async with use_sift_async_channel(config) as channel2: + async with use_sift_async_channel(config, cache=cache) as channel2: stub2 = DataServiceStub(channel2) request = GetDataRequest(page_size=100) res2 = cast(GetDataResponse, await stub2.GetData(request, metadata=with_cache())) @@ -423,13 +432,16 @@ async def test_custom_ttl(mocker: MockFixture): "apikey": "test-token", "use_ssl": False, "cache_config": { - "ttl": 3600, # Default TTL + "ttl": 3600, "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": True, }, } - async with use_sift_async_channel(config) as channel: + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: stub = DataServiceStub(channel) request = GetDataRequest(page_size=100) @@ -453,16 +465,18 @@ async def test_metadata_merging(mocker: MockFixture): with server_with_service(mocker) as (get_data_spy, data_service, port): config: SiftChannelConfig = { "uri": f"localhost:{port}", - "apikey": "test-token", # This adds authorization metadata + "apikey": "test-token", "use_ssl": False, "cache_config": { "ttl": 3600, "cache_path": str(Path(tmpdir) / "cache"), "size_limit": 1024 * 1024, + "clear_on_init": True, }, } + cache = GrpcCache(config["cache_config"]) - async with use_sift_async_channel(config) as channel: + async with use_sift_async_channel(config, cache=cache) as channel: stub = DataServiceStub(channel) request = GetDataRequest(page_size=100) From 5b9dfdeb11c4197413eb0da353b8fe04597d25ec Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Wed, 15 Oct 2025 16:05:43 -0700 Subject: [PATCH 08/16] wip --- .../_internal/low_level_wrappers/runs.py | 119 +++++++++++++++--- python/lib/sift_py/grpc/cache.py | 31 ----- 2 files changed, 104 insertions(+), 46 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/runs.py b/python/lib/sift_client/_internal/low_level_wrappers/runs.py index 38c020454..947925bc1 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/runs.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/runs.py @@ -41,11 +41,19 @@ def __init__(self, grpc_client: GrpcClient): """ super().__init__(grpc_client) - async def get_run(self, run_id: str) -> Run: + async def get_run( + self, + run_id: str, + *, + use_cache: bool = True, + force_refresh: bool = False, + ttl: int | None = None, + ) -> Run: """Get a run by run_id. Args: run_id: The run ID to get. + metadata: Optional gRPC metadata including cache control. Returns: The Run. @@ -54,7 +62,14 @@ async def get_run(self, run_id: str) -> Run: ValueError: If run_id is not provided. 
""" request = GetRunRequest(run_id=run_id) - response = await self._grpc_client.get_stub(RunServiceStub).GetRun(request) + stub = self._grpc_client.get_stub(RunServiceStub) + response = await self._call_with_cache( + stub.GetRun, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, + ) grpc_run = cast("GetRunResponse", response).run return Run._from_proto(grpc_run) @@ -65,6 +80,9 @@ async def list_runs( page_token: str | None = None, query_filter: str | None = None, order_by: str | None = None, + use_cache: bool = True, + force_refresh: bool = False, + ttl: int | None = None, ) -> tuple[list[Run], str]: """List runs with optional filtering and pagination. @@ -88,7 +106,14 @@ async def list_runs( request_kwargs["order_by"] = order_by request = ListRunsRequest(**request_kwargs) - response = await self._grpc_client.get_stub(RunServiceStub).ListRuns(request) + stub = self._grpc_client.get_stub(RunServiceStub) + response = await self._call_with_cache( + stub.ListRuns, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, + ) response = cast("ListRunsResponse", response) runs = [Run._from_proto(run) for run in response.runs] @@ -100,6 +125,9 @@ async def list_all_runs( query_filter: str | None = None, order_by: str | None = None, max_results: int | None = None, + use_cache: bool = True, + force_refresh: bool = False, + ttl: int | None = None, ) -> list[Run]: """List all runs with optional filtering. @@ -112,26 +140,68 @@ async def list_all_runs( A list of all matching runs. """ return await self._handle_pagination( - self.list_runs, - kwargs={"query_filter": query_filter}, + lambda **k: self.list_runs( + **k, + query_filter=query_filter, + order_by=order_by, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, + ), + kwargs={}, order_by=order_by, max_results=max_results, ) - async def create_run(self, *, create: RunCreate) -> Run: + async def create_run( + self, + *, + create: RunCreate, + use_cache: bool = False, # Default to False for write operations + force_refresh: bool = False, + ttl: int | None = None, + ) -> Run: request_proto = create.to_proto() - response = await self._grpc_client.get_stub(RunServiceStub).CreateRun(request_proto) + stub = self._grpc_client.get_stub(RunServiceStub) + response = await self._call_with_cache( + stub.CreateRun, + request_proto, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, + ) grpc_run = cast("CreateRunResponse", response).run return Run._from_proto(grpc_run) - async def update_run(self, update: RunUpdate) -> Run: + async def update_run( + self, + update: RunUpdate, + *, + use_cache: bool = False, # Default to False for write operations + force_refresh: bool = False, + ttl: int | None = None, + ) -> Run: grpc_run, update_mask = update.to_proto_with_mask() request = UpdateRunRequest(run=grpc_run, update_mask=update_mask) - response = await self._grpc_client.get_stub(RunServiceStub).UpdateRun(request) + stub = self._grpc_client.get_stub(RunServiceStub) + response = await self._call_with_cache( + stub.UpdateRun, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, + ) updated_grpc_run = cast("UpdateRunResponse", response).run return Run._from_proto(updated_grpc_run) - async def stop_run(self, run_id: str) -> None: + async def stop_run( + self, + run_id: str, + *, + use_cache: bool = False, # Default to False for write operations + force_refresh: bool = False, + ttl: int | None = None, + ) -> None: """Stop a run by setting its stop time to the current time. 
Args: @@ -144,10 +214,23 @@ async def stop_run(self, run_id: str) -> None: raise ValueError("run_id must be provided") request = StopRunRequest(run_id=run_id) - await self._grpc_client.get_stub(RunServiceStub).StopRun(request) + stub = self._grpc_client.get_stub(RunServiceStub) + await self._call_with_cache( + stub.StopRun, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, + ) async def create_automatic_run_association_for_assets( - self, run_id: str, asset_names: list[str] + self, + run_id: str, + asset_names: list[str], + *, + use_cache: bool = False, # Default to False for write operations + force_refresh: bool = False, + ttl: int | None = None, ) -> None: """Associate assets with a run for automatic data ingestion. @@ -164,8 +247,14 @@ async def create_automatic_run_association_for_assets( raise ValueError("asset_names must be provided") request = CreateAutomaticRunAssociationForAssetsRequest( - run_id=run_id, asset_names=asset_names + run_id=run_id, + asset_names=asset_names ) - await self._grpc_client.get_stub(RunServiceStub).CreateAutomaticRunAssociationForAssets( - request + stub = self._grpc_client.get_stub(RunServiceStub) + await self._call_with_cache( + stub.CreateAutomaticRunAssociationForAssets, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=ttl, ) diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py index 03aac393f..c7a537bf8 100644 --- a/python/lib/sift_py/grpc/cache.py +++ b/python/lib/sift_py/grpc/cache.py @@ -200,34 +200,3 @@ def with_force_refresh(ttl: int | None = None) -> tuple[tuple[str, str], ...]: if ttl is not None: metadata.append((METADATA_CACHE_TTL, str(ttl))) return tuple(metadata) - - -def ignore_cache() -> tuple[tuple[str, str], ...]: - """Ignore the cache for a gRPC request without clearing it. - - Bypasses the cache for this request but doesn't invalidate the cached entry. - The response from this request will not be cached. - - Returns: - Metadata tuple to pass to the gRPC stub method. - - Example: - metadata = ignore_cache() - response = stub.GetData(request, metadata=metadata) - """ - return tuple() - - -def without_cache() -> tuple[tuple[str, str], ...]: - """Explicitly disable caching for a gRPC request. - - This is the default behavior when no cache metadata is provided. - - Returns: - Empty metadata tuple. 
- - Example: - metadata = without_cache() - response = stub.GetData(request, metadata=metadata) - """ - return tuple() From c674c82238a6aa34a176a1f6ab572f337d92e887 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 17 Oct 2025 14:52:52 -0700 Subject: [PATCH 09/16] clean up call_with_cache --- .../sift_client/_internal/low_level_wrappers/base.py | 10 +++++----- python/lib/sift_client/transport/grpc_transport.py | 10 +++++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index d3dfde827..c5a074c55 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -3,7 +3,7 @@ from abc import ABC from typing import Any, Callable, TypeVar -from sift_py.grpc.cache import ignore_cache, with_cache, with_force_refresh +from sift_py.grpc.cache import with_cache, with_force_refresh T = TypeVar("T") @@ -56,7 +56,7 @@ async def _handle_pagination( return results @staticmethod - async def _call_with_cache( + async def call_with_cache( stub_method: Callable[[Any, tuple[tuple[str, str], ...]], T], request: Any, *, @@ -109,12 +109,12 @@ async def _call_with_cache( use_cache=False, ) """ + if not use_cache: + return await stub_method(request) if force_refresh: metadata = with_force_refresh(ttl=ttl) - elif use_cache: - metadata = with_cache(ttl=ttl) else: - metadata = ignore_cache() + metadata = with_cache(ttl=ttl) return await stub_method(request, metadata=metadata) diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 50cc7a51c..50665c2e4 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -179,7 +179,7 @@ def __init__(self, config: GrpcConfig): self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = {} # Initialize cache if caching is enabled - self._cache = self._init_cache() + self.cache = self._init_cache() # default loop for sync API self._default_loop = asyncio.new_event_loop() @@ -224,6 +224,10 @@ def _init_cache(self) -> GrpcCache | None: logger.warning(f"Failed to initialize cache: {e}") return None + @property + def has_cache(self): + return self.cache is not None + @property def default_loop(self) -> asyncio.AbstractEventLoop: """Return the default event loop used for synchronous API operations. 
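The simplification above leaves two metadata producers, with_cache and with_force_refresh, and one consumer, GrpcCache.resolve_cache_metadata. A round-trip sketch mirroring the assertions in cache_test.py earlier in this series (illustrative, not part of the patch):

from sift_py.grpc.cache import GrpcCache, with_cache, with_force_refresh

# What the producers emit, per test_cache_helper_functions:
assert with_cache() == (("use-cache", "true"),)
assert with_cache(ttl=7200) == (("use-cache", "true"), ("cache-ttl", "7200"))
assert with_force_refresh() == (("use-cache", "true"), ("force-refresh", "true"))

# ...and what the interceptor side recovers from them:
settings = GrpcCache.resolve_cache_metadata(with_cache(ttl=7200))
assert settings.use_cache is True
assert settings.force_refresh is False
assert settings.custom_ttl == 7200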
@@ -246,7 +250,7 @@ def get_stub(self, stub_class: type[Any]) -> Any: if loop not in self._channels_async: channel = use_sift_async_channel( - self._config._to_sift_channel_config(), self._config.metadata, self._cache + self._config._to_sift_channel_config(), self._config.metadata, self.cache ) self._channels_async[loop] = channel self._stubs_async_map[loop] = {} @@ -289,4 +293,4 @@ async def _create_async_channel( self, cfg: SiftChannelConfig, metadata: dict[str, str] | None ) -> Any: """Helper to create async channel on default loop.""" - return use_sift_async_channel(cfg, metadata, self._cache) + return use_sift_async_channel(cfg, metadata, self.cache) From 15c302b82b0b988798ce84be02b014bf8b2aef9f Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Fri, 17 Oct 2025 14:53:12 -0700 Subject: [PATCH 10/16] add run for demo --- .../_internal/low_level_wrappers/runs.py | 94 +++---------------- 1 file changed, 15 insertions(+), 79 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/runs.py b/python/lib/sift_client/_internal/low_level_wrappers/runs.py index 947925bc1..920b6ff40 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/runs.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/runs.py @@ -42,12 +42,10 @@ def __init__(self, grpc_client: GrpcClient): super().__init__(grpc_client) async def get_run( - self, + self, run_id: str, *, - use_cache: bool = True, force_refresh: bool = False, - ttl: int | None = None, ) -> Run: """Get a run by run_id. @@ -63,12 +61,11 @@ async def get_run( """ request = GetRunRequest(run_id=run_id) stub = self._grpc_client.get_stub(RunServiceStub) - response = await self._call_with_cache( + response = await self.call_with_cache( stub.GetRun, request, - use_cache=use_cache, + use_cache=self._grpc_client.has_cache, force_refresh=force_refresh, - ttl=ttl, ) grpc_run = cast("GetRunResponse", response).run return Run._from_proto(grpc_run) @@ -80,9 +77,7 @@ async def list_runs( page_token: str | None = None, query_filter: str | None = None, order_by: str | None = None, - use_cache: bool = True, force_refresh: bool = False, - ttl: int | None = None, ) -> tuple[list[Run], str]: """List runs with optional filtering and pagination. @@ -107,12 +102,11 @@ async def list_runs( request = ListRunsRequest(**request_kwargs) stub = self._grpc_client.get_stub(RunServiceStub) - response = await self._call_with_cache( + response = await self.call_with_cache( stub.ListRuns, request, - use_cache=use_cache, + use_cache=self._grpc_client.has_cache, force_refresh=force_refresh, - ttl=ttl, ) response = cast("ListRunsResponse", response) @@ -125,9 +119,7 @@ async def list_all_runs( query_filter: str | None = None, order_by: str | None = None, max_results: int | None = None, - use_cache: bool = True, force_refresh: bool = False, - ttl: int | None = None, ) -> list[Run]: """List all runs with optional filtering. 
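For context, the per-loop channel created in get_stub above threads one shared GrpcCache into every connection. At the sift_py layer the same wiring looks like the tests earlier in the series; a condensed sketch with placeholder URI, token, and cache path:

from sift.data.v2.data_pb2 import GetDataRequest
from sift.data.v2.data_pb2_grpc import DataServiceStub

from sift_py.grpc.cache import GrpcCache, with_cache
from sift_py.grpc.transport import SiftChannelConfig, use_sift_async_channel

config: SiftChannelConfig = {
    "uri": "localhost:50051",    # placeholder
    "apikey": "test-token",      # placeholder
    "use_ssl": False,
    "cache_config": {
        "ttl": 3600,
        "cache_path": "/tmp/sift_grpc_cache_demo",
        "size_limit": 1024 * 1024,
        "clear_on_init": True,
    },
}
cache = GrpcCache(config["cache_config"])  # one cache, reused across channels

async def fetch_once():
    async with use_sift_async_channel(config, cache=cache) as channel:
        stub = DataServiceStub(channel)
        # Opt in per call; an identical request next time resolves from disk.
        return await stub.GetData(GetDataRequest(page_size=100), metadata=with_cache())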
@@ -144,64 +136,27 @@ async def list_all_runs( **k, query_filter=query_filter, order_by=order_by, - use_cache=use_cache, force_refresh=force_refresh, - ttl=ttl, ), kwargs={}, order_by=order_by, max_results=max_results, ) - async def create_run( - self, - *, - create: RunCreate, - use_cache: bool = False, # Default to False for write operations - force_refresh: bool = False, - ttl: int | None = None, - ) -> Run: + async def create_run(self, *, create: RunCreate) -> Run: request_proto = create.to_proto() - stub = self._grpc_client.get_stub(RunServiceStub) - response = await self._call_with_cache( - stub.CreateRun, - request_proto, - use_cache=use_cache, - force_refresh=force_refresh, - ttl=ttl, - ) + response = await self._grpc_client.get_stub(RunServiceStub).CreateRun(request_proto) grpc_run = cast("CreateRunResponse", response).run return Run._from_proto(grpc_run) - async def update_run( - self, - update: RunUpdate, - *, - use_cache: bool = False, # Default to False for write operations - force_refresh: bool = False, - ttl: int | None = None, - ) -> Run: + async def update_run(self, update: RunUpdate) -> Run: grpc_run, update_mask = update.to_proto_with_mask() request = UpdateRunRequest(run=grpc_run, update_mask=update_mask) - stub = self._grpc_client.get_stub(RunServiceStub) - response = await self._call_with_cache( - stub.UpdateRun, - request, - use_cache=use_cache, - force_refresh=force_refresh, - ttl=ttl, - ) + response = await self._grpc_client.get_stub(RunServiceStub).UpdateRun(request) updated_grpc_run = cast("UpdateRunResponse", response).run return Run._from_proto(updated_grpc_run) - async def stop_run( - self, - run_id: str, - *, - use_cache: bool = False, # Default to False for write operations - force_refresh: bool = False, - ttl: int | None = None, - ) -> None: + async def stop_run(self, run_id: str) -> None: """Stop a run by setting its stop time to the current time. Args: @@ -214,23 +169,10 @@ async def stop_run( raise ValueError("run_id must be provided") request = StopRunRequest(run_id=run_id) - stub = self._grpc_client.get_stub(RunServiceStub) - await self._call_with_cache( - stub.StopRun, - request, - use_cache=use_cache, - force_refresh=force_refresh, - ttl=ttl, - ) + await self._grpc_client.get_stub(RunServiceStub).StopRun(request) async def create_automatic_run_association_for_assets( - self, - run_id: str, - asset_names: list[str], - *, - use_cache: bool = False, # Default to False for write operations - force_refresh: bool = False, - ttl: int | None = None, + self, run_id: str, asset_names: list[str] ) -> None: """Associate assets with a run for automatic data ingestion. 
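Reads opt in above while every write path goes back to a plain stub call, which follows from how entries are keyed: method name plus serialized request and nothing else, so a replayed write would be answered from disk instead of reaching the server. One plausible key derivation consistent with the assertions in cache_test.py (the real key_from_proto_message may differ in detail):

import hashlib

from sift.data.v2.data_pb2 import GetDataRequest

def cache_key(method_name: str, request) -> str:
    # SHA-256 over the RPC path and the request's wire bytes: deterministic
    # for equal requests, distinct for different ones, 64 hex characters.
    hasher = hashlib.sha256()
    hasher.update(method_name.encode("utf-8"))
    hasher.update(request.SerializeToString())
    return hasher.hexdigest()

k1 = cache_key("/sift.data.v2.DataService/GetData", GetDataRequest(page_size=100))
k2 = cache_key("/sift.data.v2.DataService/GetData", GetDataRequest(page_size=200))
assert len(k1) == 64 and k1 != k2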
@@ -247,14 +189,8 @@ async def create_automatic_run_association_for_assets( raise ValueError("asset_names must be provided") request = CreateAutomaticRunAssociationForAssetsRequest( - run_id=run_id, - asset_names=asset_names + run_id=run_id, asset_names=asset_names ) - stub = self._grpc_client.get_stub(RunServiceStub) - await self._call_with_cache( - stub.CreateAutomaticRunAssociationForAssets, - request, - use_cache=use_cache, - force_refresh=force_refresh, - ttl=ttl, + await self._grpc_client.get_stub(RunServiceStub).CreateAutomaticRunAssociationForAssets( + request ) From 99fc42f1a1d25d284622750770adc92016cc0bc9 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 20 Oct 2025 17:15:18 -0700 Subject: [PATCH 11/16] move cache testing to ping --- python/examples/caching_example.py | 91 ---------- .../_internal/low_level_wrappers/base.py | 16 +- .../_internal/low_level_wrappers/data.py | 8 +- .../_internal/low_level_wrappers/ping.py | 15 +- .../_internal/low_level_wrappers/runs.py | 15 +- python/lib/sift_client/_tests/conftest.py | 2 + .../sift_client/_tests/resources/test_ping.py | 165 ++++++++++++++++++ python/lib/sift_client/client.py | 15 +- .../sift_client/transport/base_connection.py | 9 +- .../sift_client/transport/grpc_transport.py | 5 +- .../grpc/_async_interceptors/caching.py | 8 +- python/lib/sift_py/grpc/cache.py | 27 ++- python/lib/sift_py/grpc/cache_test.py | 20 +-- python/lib/sift_py/grpc/transport.py | 1 - 14 files changed, 235 insertions(+), 162 deletions(-) delete mode 100644 python/examples/caching_example.py diff --git a/python/examples/caching_example.py b/python/examples/caching_example.py deleted file mode 100644 index c864817c3..000000000 --- a/python/examples/caching_example.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Example demonstrating gRPC response caching with the Sift client. - -This example shows how to: -1. Enable caching via SiftClient configuration -2. Use cache control metadata to control caching behavior -3. 
Measure the performance improvement from caching - -Requirements: - pip install sift-stack-py[cache] -""" - -import time -from sift_client import CacheConfig, SiftClient -from sift_py.grpc.cache import with_cache, with_force_refresh, ignore_cache - -# Configure caching -cache_config = CacheConfig( - enabled=True, # Enable caching - ttl=3600, # Cache for 1 hour - cache_path=None, # Uses system temp directory by default - size_limit=1024 * 1024 * 1024, # 1GB max -) - -# Initialize client with caching enabled -client = SiftClient( - api_key="your-api-key-here", - grpc_url="api.siftstack.com", - rest_url="https://api.siftstack.com", - cache_config=cache_config, # Pass cache config directly -) - -# Example 1: Basic caching -print("Example 1: Basic Caching") -print("-" * 50) - -# First call - cache miss (fetches from server) -start = time.time() -response = client.ping.ping() # Note: Need to add metadata support to high-level APIs -elapsed_first = time.time() - start -print(f"First call (cache miss): {elapsed_first:.3f}s") - -# Second call - cache hit (returns cached response) -start = time.time() -response = client.ping.ping() -elapsed_second = time.time() - start -print(f"Second call (cache hit): {elapsed_second:.3f}s") -print(f"Speedup: {elapsed_first / elapsed_second:.1f}x faster") - -# Example 2: Force refresh -print("\nExample 2: Force Refresh") -print("-" * 50) - -# Force refresh - bypasses cache and fetches fresh data -start = time.time() -response = client.ping.ping() # with force_refresh metadata -elapsed = time.time() - start -print(f"Force refresh: {elapsed:.3f}s") - -# Example 3: Ignore cache -print("\nExample 3: Ignore Cache") -print("-" * 50) - -# Bypass cache without clearing it -response = client.ping.ping() # with ignore_cache metadata -print("Cache bypassed for this request (entry still exists)") - -# Example 4: Conditional caching -print("\nExample 4: Conditional Caching") -print("-" * 50) - - -def get_data(use_cache: bool = False): - """Helper function that conditionally uses caching.""" - if use_cache: - # Use cache - return client.ping.ping() # with with_cache metadata - else: - # Skip cache - return client.ping.ping() # without cache metadata - - -# Use cache in production -response = get_data(use_cache=True) -print("Called with caching enabled") - -# Skip cache in development -response = get_data(use_cache=False) -print("Called without caching") - -print("\nNote: This example requires integration with the high-level API") -print("to pass cache control metadata. See the documentation for details.") diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index c5a074c55..e23ce65b4 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -56,7 +56,7 @@ async def _handle_pagination( return results @staticmethod - async def call_with_cache( + async def _call_with_cache( stub_method: Callable[[Any, tuple[tuple[str, str], ...]], T], request: Any, *, @@ -65,20 +65,20 @@ async def call_with_cache( ttl: int | None = None, ) -> T: """Call a gRPC stub method with cache control. - + This is a convenience method for low-level wrappers to easily enable caching on their gRPC calls without manually constructing metadata. - + Args: stub_method: The gRPC stub method to call (e.g., stub.GetData). request: The protobuf request object. use_cache: Whether to enable caching for this request. Default: True. 
force_refresh: Whether to force refresh the cache. Default: False. ttl: Optional custom TTL in seconds. If not provided, uses the default TTL. - + Returns: The response from the gRPC call. - + Example: # Enable caching response = await self._call_with_cache( @@ -86,14 +86,14 @@ async def call_with_cache( request, use_cache=True, ) - + # Force refresh response = await self._call_with_cache( stub.GetData, request, force_refresh=True, ) - + # With custom TTL response = await self._call_with_cache( stub.GetData, @@ -101,7 +101,7 @@ async def call_with_cache( use_cache=True, ttl=7200, # 2 hours ) - + # Ignore cache response = await self._call_with_cache( stub.GetData, diff --git a/python/lib/sift_client/_internal/low_level_wrappers/data.py b/python/lib/sift_client/_internal/low_level_wrappers/data.py index 752c05ec6..2f3e4699d 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/data.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/data.py @@ -89,7 +89,7 @@ async def _get_data_impl( cache_ttl: int | None = None, ) -> tuple[list[Any], str | None]: """Get the data for a channel during a run. - + Args: channel_ids: List of channel IDs to fetch data for. run_id: Optional run ID to filter data. @@ -101,7 +101,7 @@ async def _get_data_impl( use_cache: Whether to enable caching for this request. Default: False. force_refresh: Whether to force refresh the cache. Default: False. cache_ttl: Optional custom TTL in seconds for cached responses. - + Returns: Tuple of (data list, next page token). """ @@ -119,7 +119,7 @@ async def _get_data_impl( } request = GetDataRequest(**request_kwargs) - + # Use cache helper if caching is enabled if use_cache or force_refresh: response = await self._call_with_cache( @@ -131,7 +131,7 @@ async def _get_data_impl( ) else: response = await self._grpc_client.get_stub(DataServiceStub).GetData(request) - + response = cast("GetDataResponse", response) return response.data, response.next_page_token # type: ignore # mypy doesn't know RepeatedCompositeFieldContainer can be treated like a list diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ping.py b/python/lib/sift_client/_internal/low_level_wrappers/ping.py index 650f2d44a..562cbee24 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ping.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ping.py @@ -33,6 +33,10 @@ class PingLowLevelClient(LowLevelClientBase, WithGrpcClient): It handles common concerns like error handling and retries. """ + + _cache_results: bool + """Whether to cache the results of the ping request. Used for testing.""" + def __init__(self, grpc_client: GrpcClient): """Initialize the PingLowLevelClient. @@ -40,11 +44,18 @@ def __init__(self, grpc_client: GrpcClient): grpc_client: The gRPC client to use for making API calls. 
""" super().__init__(grpc_client=grpc_client) + self._cache_results = False - async def ping(self) -> str: + async def ping(self, _force_refresh: bool = False) -> str: """Send a ping request to the server in the current event loop.""" # get stub bound to this loop stub = self._grpc_client.get_stub(PingServiceStub) request = PingRequest() - response = await stub.Ping(request) + response = await self._call_with_cache( + stub.Ping, + request, + use_cache=self._cache_results, + force_refresh=_force_refresh, + ttl=1 + ) return cast("PingResponse", response).response diff --git a/python/lib/sift_client/_internal/low_level_wrappers/runs.py b/python/lib/sift_client/_internal/low_level_wrappers/runs.py index 920b6ff40..65a542e2d 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/runs.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/runs.py @@ -61,7 +61,7 @@ async def get_run( """ request = GetRunRequest(run_id=run_id) stub = self._grpc_client.get_stub(RunServiceStub) - response = await self.call_with_cache( + response = await self._call_with_cache( stub.GetRun, request, use_cache=self._grpc_client.has_cache, @@ -102,7 +102,7 @@ async def list_runs( request = ListRunsRequest(**request_kwargs) stub = self._grpc_client.get_stub(RunServiceStub) - response = await self.call_with_cache( + response = await self._call_with_cache( stub.ListRuns, request, use_cache=self._grpc_client.has_cache, @@ -132,13 +132,8 @@ async def list_all_runs( A list of all matching runs. """ return await self._handle_pagination( - lambda **k: self.list_runs( - **k, - query_filter=query_filter, - order_by=order_by, - force_refresh=force_refresh, - ), - kwargs={}, + self.list_runs, + kwargs={"query_filter": query_filter, "force_refresh": force_refresh}, order_by=order_by, max_results=max_results, ) @@ -172,7 +167,7 @@ async def stop_run(self, run_id: str) -> None: await self._grpc_client.get_stub(RunServiceStub).StopRun(request) async def create_automatic_run_association_for_assets( - self, run_id: str, asset_names: list[str] + self, run_id: str, asset_names: list[str] ) -> None: """Associate assets with a run for automatic data ingestion. 
diff --git a/python/lib/sift_client/_tests/conftest.py b/python/lib/sift_client/_tests/conftest.py index 397848d7e..7be7c2ade 100644 --- a/python/lib/sift_client/_tests/conftest.py +++ b/python/lib/sift_client/_tests/conftest.py @@ -6,6 +6,7 @@ import pytest from sift_client import SiftClient, SiftConnectionConfig +from sift_client.transport import CacheConfig, CacheMode from sift_client.util.util import AsyncAPIs @@ -26,6 +27,7 @@ def sift_client() -> SiftClient: grpc_url=grpc_url, rest_url=rest_url, use_ssl=True, + cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT) ) ) diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py index 587d8a7f0..8c51391cc 100644 --- a/python/lib/sift_client/_tests/resources/test_ping.py +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -3,9 +3,13 @@ These tests demonstrate and validate the usage of the Ping API including: - Basic ping functionality - Connection health checks +- Cache behavior and performance - Error handling and edge cases """ +import asyncio +import time + import pytest from sift_client import SiftClient @@ -60,3 +64,164 @@ def test_basic_ping(self, ping_api_sync): # Verify response is not empty assert len(response) > 0 + + +class TestPingCacheBehavior: + """Test suite for ping cache behavior.""" + + @pytest.mark.asyncio + async def test_cache_enabled(self, ping_api_async): + """Test that caching can be enabled for ping requests.""" + # Enable caching on the low-level client + ping_api_async._low_level_client._cache_results = True + + # Measure time for first ping - should hit the server (slower) + start1 = time.perf_counter() + response1 = await ping_api_async.ping() + duration1 = time.perf_counter() - start1 + assert isinstance(response1, str) + assert len(response1) > 0 + + # Measure time for second ping - should use cache (much faster) + start2 = time.perf_counter() + response2 = await ping_api_async.ping() + duration2 = time.perf_counter() - start2 + assert response2 == response1 + + # Print timing info + print(f"\nFirst ping (server): {duration1*1000:.2f}ms") + print(f"Second ping (cache): {duration2*1000:.2f}ms") + print(f"Speedup: {duration1/duration2:.2f}x") + + # Cached call should be significantly faster (at least 5x) + assert duration2 < duration1 / 5, ( + f"Cached ping should be much faster. 
" + f"First: {duration1*1000:.2f}ms, Second: {duration2*1000:.2f}ms" + ) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_force_refresh_bypasses_cache(self, ping_api_async): + """Test that force_refresh bypasses the cache.""" + # Enable caching + ping_api_async._low_level_client._cache_results = True + + # First ping - populate cache + start1 = time.perf_counter() + response1 = await ping_api_async._low_level_client.ping() + duration1 = time.perf_counter() - start1 + assert isinstance(response1, str) + + # Second ping without force_refresh - should use cache (fast) + start2 = time.perf_counter() + response2 = await ping_api_async._low_level_client.ping(_force_refresh=False) + duration2 = time.perf_counter() - start2 + assert isinstance(response2, str) + + # Third ping with force_refresh - should bypass cache (slow, like first call) + start3 = time.perf_counter() + response3 = await ping_api_async._low_level_client.ping(_force_refresh=True) + duration3 = time.perf_counter() - start3 + assert isinstance(response3, str) + + # Print timing info + print(f"\nFirst ping (server): {duration1*1000:.2f}ms") + print(f"Second ping (cache): {duration2*1000:.2f}ms") + print(f"Third ping (force_refresh, server): {duration3*1000:.2f}ms") + + # Cached call should be much faster than both server calls + assert duration2 < duration1 / 5, ( + f"Cached ping should be much faster than first ping. " + f"First: {duration1*1000:.2f}ms, Cached: {duration2*1000:.2f}ms" + ) + assert duration2 < duration3 / 5, ( + f"Cached ping should be much faster than force_refresh ping. " + f"Force refresh: {duration3*1000:.2f}ms, Cached: {duration2*1000:.2f}ms" + ) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_cache_ttl_expiration(self, ping_api_async): + """Test that cache entries expire after TTL.""" + # Enable caching with very short TTL (1 second) + ping_api_async._low_level_client._cache_results = True + + # First ping - populate cache with 1 second TTL + response1 = await ping_api_async._low_level_client.ping() + assert isinstance(response1, str) + + # Immediate second ping - should use cache + response2 = await ping_api_async._low_level_client.ping() + assert isinstance(response2, str) + + # Wait for TTL to expire (1 second + buffer) + await asyncio.sleep(1.5) + + # Third ping - cache should have expired, will fetch fresh + response3 = await ping_api_async._low_level_client.ping() + assert isinstance(response3, str) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_cache_performance(self, ping_api_async): + """Test that cached ping requests are faster than uncached ones.""" + num_iterations = 10 + + # Enable caching + ping_api_async._low_level_client._cache_results = True + + # Measure uncached performance (force_refresh=True) + start_time = time.perf_counter() + for _ in range(num_iterations): + await ping_api_async._low_level_client.ping(_force_refresh=True) + uncached_duration = time.perf_counter() - start_time + + # Warm up cache + await ping_api_async._low_level_client.ping() + + # Measure cached performance + start_time = time.perf_counter() + for _ in range(num_iterations): + await ping_api_async._low_level_client.ping(_force_refresh=False) + cached_duration = time.perf_counter() - start_time + + # Print performance metrics + print(f"\n{'='*60}") + print(f"Ping 
Cache Performance ({num_iterations} iterations)") + print(f"{'='*60}") + print(f"Cached duration: {cached_duration:.4f}s ({cached_duration/num_iterations*1000:.2f}ms per call)") + print(f"Uncached duration: {uncached_duration:.4f}s ({uncached_duration/num_iterations*1000:.2f}ms per call)") + print(f"Speedup: {uncached_duration / cached_duration:.2f}x") + print(f"Time saved: {uncached_duration - cached_duration:.4f}s") + print(f"{'='*60}\n") + + # Assert that cached is faster + assert cached_duration < uncached_duration, ( + f"Cached pings should be faster. " + f"Cached: {cached_duration:.4f}s, Uncached: {uncached_duration:.4f}s" + ) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_cache_disabled_by_default(self, ping_api_async): + """Test that caching is disabled by default for ping.""" + # Verify cache is disabled by default + assert ping_api_async._low_level_client._cache_results is False + + # Multiple pings should all hit the server (no caching) + response1 = await ping_api_async.ping() + response2 = await ping_api_async.ping() + response3 = await ping_api_async.ping() + + # All should succeed + assert isinstance(response1, str) + assert isinstance(response2, str) + assert isinstance(response3, str) diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index f93ae5477..94d61a17c 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -32,22 +32,11 @@ WithGrpcClient, WithRestClient, ) -from sift_client.transport.grpc_transport import ( - DEFAULT_CACHE_FOLDER, - DEFAULT_CACHE_SIZE_LIMIT_BYTES, - DEFAULT_CACHE_TTL_SECONDS, -) +from sift_client.transport.grpc_transport import DEFAULT_CACHE_CONFIG from sift_client.util.util import AsyncAPIs _sift_client_experimental_warning() -DEFAULT_CACHE_CONFIG = CacheConfig( - ttl=DEFAULT_CACHE_TTL_SECONDS, - cache_folder=DEFAULT_CACHE_FOLDER, - size_limit=DEFAULT_CACHE_SIZE_LIMIT_BYTES, -) - - class SiftClient( WithGrpcClient, WithRestClient, @@ -121,7 +110,7 @@ def __init__( grpc_url: str | None = None, rest_url: str | None = None, connection_config: SiftConnectionConfig | None = None, - cache_config: CacheConfig | None = None, + cache_config: CacheConfig | None = DEFAULT_CACHE_CONFIG, ): """Initialize the SiftClient with specific connection parameters or a connection_config. diff --git a/python/lib/sift_client/transport/base_connection.py b/python/lib/sift_client/transport/base_connection.py index 2b94fa52b..c6b764d20 100644 --- a/python/lib/sift_client/transport/base_connection.py +++ b/python/lib/sift_client/transport/base_connection.py @@ -3,7 +3,12 @@ from abc import ABC from typing import TYPE_CHECKING -from sift_client.transport.grpc_transport import CacheConfig, GrpcClient, GrpcConfig +from sift_client.transport.grpc_transport import ( + DEFAULT_CACHE_CONFIG, + CacheConfig, + GrpcClient, + GrpcConfig, +) from sift_client.transport.rest_transport import RestClient, RestConfig if TYPE_CHECKING: @@ -24,7 +29,7 @@ def __init__( api_key: str, use_ssl: bool = True, cert_via_openssl: bool = False, - cache_config: CacheConfig | None = None, + cache_config: CacheConfig | None = DEFAULT_CACHE_CONFIG, ): """Initialize the connection configuration. 
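With DEFAULT_CACHE_CONFIG now the default in both constructors, callers get caching without extra arguments, and overriding it stays a one-liner. A sketch with placeholder credentials, matching the conftest.py change above and the CacheMode enum in the next diff:

from sift_client import SiftClient
from sift_client.transport import CacheConfig, CacheMode

client = SiftClient(
    api_key="my-api-key",                 # placeholder
    grpc_url="grpc.example.com",          # placeholder
    rest_url="https://rest.example.com",  # placeholder
    # Omit cache_config to accept DEFAULT_CACHE_CONFIG (CacheMode.ENABLED);
    # pass an explicit config to tune the TTL or start from an empty cache.
    cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT, ttl=600),
)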
diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 50665c2e4..03ea8328e 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -69,7 +69,7 @@ class CacheConfig: def __init__( self, - mode: str = CacheMode.ENABLED, + mode: CacheMode = CacheMode.ENABLED, ttl: int = DEFAULT_CACHE_TTL_SECONDS, cache_folder: Path | str = DEFAULT_CACHE_FOLDER, size_limit: int = DEFAULT_CACHE_SIZE_LIMIT_BYTES, @@ -112,6 +112,9 @@ def to_sift_cache_config(self) -> SiftCacheConfig: } +DEFAULT_CACHE_CONFIG = CacheConfig() + + class GrpcConfig: """Configuration for gRPC API clients.""" diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py index efe140eb3..3b45ed095 100644 --- a/python/lib/sift_py/grpc/_async_interceptors/caching.py +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -9,7 +9,7 @@ Usage: # Cache is initialized at GrpcClient level cache = diskcache.Cache(".grpc_cache", size_limit=1024**3) - + # Create interceptor with cache instance cache_interceptor = CachingAsyncInterceptor(ttl=3600, cache_instance=cache) @@ -35,6 +35,7 @@ logger = logging.getLogger(__name__) + class CachingAsyncInterceptor(ClientAsyncInterceptor): """Async interceptor that caches unary-unary gRPC responses locally. @@ -126,7 +127,9 @@ async def intercept( # Serialize the protobuf response to bytes before caching cached_data = self._serialize_response(response) if cached_data is not None: - self.cache.set_with_default_ttl(key, cached_data, expire=cache_settings.custom_ttl) + self.cache.set_with_default_ttl( + key, cached_data, expire=cache_settings.custom_ttl + ) logger.debug(f"Cached response for `{key}`") except diskcache.Timeout as e: logger.warning(f"Failed to cache response for `{key}`: {e}") @@ -151,4 +154,3 @@ def _deserialize_response(self, response: tuple[Any, bytes]) -> message.Message except Exception as e: logger.warning(f"Failed to deserialize response: {e}") return None - diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py index c7a537bf8..6b3f0c1fa 100644 --- a/python/lib/sift_py/grpc/cache.py +++ b/python/lib/sift_py/grpc/cache.py @@ -31,8 +31,6 @@ from google.protobuf import json_format, message if TYPE_CHECKING: - - from sift_py.grpc.transport import SiftCacheConfig logger = logging.getLogger(__name__) @@ -45,6 +43,7 @@ class CacheSettings(NamedTuple): force_refresh: bool custom_ttl: float | None + # Metadata keys for cache control METADATA_USE_CACHE = "use-cache" METADATA_FORCE_REFRESH = "force-refresh" @@ -80,7 +79,9 @@ def __init__(self, config: SiftCacheConfig): f"with size {self.volume() / (1024**2):.2f} MB" ) - def set_with_default_ttl(self, key: str, value: Any, expire: float | None = None, **kwargs) -> bool: + def set_with_default_ttl( + self, key: str, value: Any, expire: float | None = None, **kwargs + ) -> bool: expire_time = expire if expire is not None else self.default_ttl return super().set(key, value, expire=expire_time, **kwargs) @@ -100,9 +101,7 @@ def key_from_proto_message(method_name: str | bytes, request: message.Message) - return hasher.hexdigest() @staticmethod - def resolve_cache_metadata( - metadata: tuple[tuple[str, str], ...] | None - ) -> CacheSettings: + def resolve_cache_metadata(metadata: tuple[tuple[str, str], ...] | None) -> CacheSettings: """Extract and resolve cache-related metadata fields. 
Args: @@ -157,17 +156,17 @@ def resolve_cache_metadata( def with_cache(ttl: int | None = None) -> tuple[tuple[str, str], ...]: """Enable caching for a gRPC request. - + Args: ttl: Optional custom TTL in seconds. If not provided, uses the default TTL. - + Returns: Metadata tuple to pass to the gRPC stub method. - + Example: metadata = with_cache() response = stub.GetData(request, metadata=metadata) - + # With custom TTL metadata = with_cache(ttl=7200) # 2 hours response = stub.GetData(request, metadata=metadata) @@ -180,15 +179,15 @@ def with_cache(ttl: int | None = None) -> tuple[tuple[str, str], ...]: def with_force_refresh(ttl: int | None = None) -> tuple[tuple[str, str], ...]: """Force refresh the cache for a gRPC request. - + Bypasses the cache, fetches fresh data from the server, and stores the result. - + Args: ttl: Optional custom TTL in seconds. If not provided, uses the default TTL. - + Returns: Metadata tuple to pass to the gRPC stub method. - + Example: metadata = with_force_refresh() response = stub.GetData(request, metadata=metadata) diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py index 4bdaf0898..8714ca123 100644 --- a/python/lib/sift_py/grpc/cache_test.py +++ b/python/lib/sift_py/grpc/cache_test.py @@ -29,10 +29,10 @@ # Enable debug logging for cache-related modules logging.basicConfig( - level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) -logging.getLogger('sift_py').setLevel(logging.DEBUG) +logging.getLogger("sift_py").setLevel(logging.DEBUG) + class DataService(DataServiceServicer): """Mock data service that returns a unique response each time.""" @@ -76,7 +76,7 @@ def intercept( @contextmanager def server_with_service(mocker: MockFixture) -> Iterator[tuple[MockType, DataService, int]]: """Create a test server with a spy on the DataService. 
- + Returns: Tuple of (spy, data_service, port) """ @@ -187,25 +187,19 @@ def test_cache_metadata_resolution(): assert settings.custom_ttl is None # force-refresh enabled - settings = GrpcCache.resolve_cache_metadata( - (("use-cache", "true"), ("force-refresh", "true")) - ) + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"), ("force-refresh", "true"))) assert settings.use_cache is True assert settings.force_refresh is True assert settings.custom_ttl is None # Custom TTL - settings = GrpcCache.resolve_cache_metadata( - (("use-cache", "true"), ("cache-ttl", "7200")) - ) + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"), ("cache-ttl", "7200"))) assert settings.use_cache is True assert settings.force_refresh is False assert settings.custom_ttl == 7200 # Invalid TTL (should be ignored) - settings = GrpcCache.resolve_cache_metadata( - (("use-cache", "true"), ("cache-ttl", "invalid")) - ) + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"), ("cache-ttl", "invalid"))) assert settings.use_cache is True assert settings.custom_ttl is None diff --git a/python/lib/sift_py/grpc/transport.py b/python/lib/sift_py/grpc/transport.py index b868b146c..b8b065130 100644 --- a/python/lib/sift_py/grpc/transport.py +++ b/python/lib/sift_py/grpc/transport.py @@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast from urllib.parse import ParseResult, urlparse -import diskcache import grpc import grpc.aio as grpc_aio from typing_extensions import NotRequired, TypeAlias From a6247972edf754b4ce2432e65baaa400d27e6082 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 20 Oct 2025 17:19:39 -0700 Subject: [PATCH 12/16] remove caching from LLW runs --- .../_internal/low_level_wrappers/runs.py | 28 +++---------------- .../sift_client/transport/grpc_transport.py | 4 --- python/lib/sift_py/grpc/cache_test.py | 2 -- 3 files changed, 4 insertions(+), 30 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/runs.py b/python/lib/sift_client/_internal/low_level_wrappers/runs.py index 65a542e2d..38c020454 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/runs.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/runs.py @@ -41,17 +41,11 @@ def __init__(self, grpc_client: GrpcClient): """ super().__init__(grpc_client) - async def get_run( - self, - run_id: str, - *, - force_refresh: bool = False, - ) -> Run: + async def get_run(self, run_id: str) -> Run: """Get a run by run_id. Args: run_id: The run ID to get. - metadata: Optional gRPC metadata including cache control. Returns: The Run. @@ -60,13 +54,7 @@ async def get_run( ValueError: If run_id is not provided. """ request = GetRunRequest(run_id=run_id) - stub = self._grpc_client.get_stub(RunServiceStub) - response = await self._call_with_cache( - stub.GetRun, - request, - use_cache=self._grpc_client.has_cache, - force_refresh=force_refresh, - ) + response = await self._grpc_client.get_stub(RunServiceStub).GetRun(request) grpc_run = cast("GetRunResponse", response).run return Run._from_proto(grpc_run) @@ -77,7 +65,6 @@ async def list_runs( page_token: str | None = None, query_filter: str | None = None, order_by: str | None = None, - force_refresh: bool = False, ) -> tuple[list[Run], str]: """List runs with optional filtering and pagination. 
@@ -101,13 +88,7 @@ async def list_runs( request_kwargs["order_by"] = order_by request = ListRunsRequest(**request_kwargs) - stub = self._grpc_client.get_stub(RunServiceStub) - response = await self._call_with_cache( - stub.ListRuns, - request, - use_cache=self._grpc_client.has_cache, - force_refresh=force_refresh, - ) + response = await self._grpc_client.get_stub(RunServiceStub).ListRuns(request) response = cast("ListRunsResponse", response) runs = [Run._from_proto(run) for run in response.runs] @@ -119,7 +100,6 @@ async def list_all_runs( query_filter: str | None = None, order_by: str | None = None, max_results: int | None = None, - force_refresh: bool = False, ) -> list[Run]: """List all runs with optional filtering. @@ -133,7 +113,7 @@ async def list_all_runs( """ return await self._handle_pagination( self.list_runs, - kwargs={"query_filter": query_filter, "force_refresh": force_refresh}, + kwargs={"query_filter": query_filter}, order_by=order_by, max_results=max_results, ) diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 03ea8328e..9a602a1b0 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -227,10 +227,6 @@ def _init_cache(self) -> GrpcCache | None: logger.warning(f"Failed to initialize cache: {e}") return None - @property - def has_cache(self): - return self.cache is not None - @property def default_loop(self) -> asyncio.AbstractEventLoop: """Return the default event loop used for synchronous API operations. diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py index 8714ca123..e0f8227a3 100644 --- a/python/lib/sift_py/grpc/cache_test.py +++ b/python/lib/sift_py/grpc/cache_test.py @@ -383,8 +383,6 @@ async def test_different_requests_different_cache_keys(mocker: MockFixture): async def test_cache_persists_across_channels(mocker: MockFixture): """Test that cache persists across different channel instances.""" with tempfile.TemporaryDirectory() as tmpdir: - cache_path = str(Path(tmpdir) / "cache") - with server_with_service(mocker) as (get_data_spy, data_service, port): config: SiftChannelConfig = { "uri": f"localhost:{port}", From 143707dd9beb21240e3fbbad2aeacc6c4e58942c Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 20 Oct 2025 17:41:00 -0700 Subject: [PATCH 13/16] move around cache defaults and cache folder --- .../sift_client/_tests/resources/test_ping.py | 65 ++++++++++++++++++- python/lib/sift_client/client.py | 9 +-- .../sift_client/transport/grpc_transport.py | 4 +- python/pyproject.toml | 3 +- 4 files changed, 70 insertions(+), 11 deletions(-) diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py index 8c51391cc..8a1887249 100644 --- a/python/lib/sift_client/_tests/resources/test_ping.py +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -8,15 +8,38 @@ """ import asyncio +import os import time import pytest -from sift_client import SiftClient +from sift_client import SiftClient, SiftConnectionConfig from sift_client.resources import PingAPI, PingAPIAsync +from sift_client.transport import CacheConfig, CacheMode pytestmark = pytest.mark.integration +# We reimplement this here so that the cache is cleared each time we instantiate +@pytest.fixture +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing. 
+
+    This fixture overrides the shared conftest fixture and is function-scoped
+    so that the cache is cleared for each test that uses it.
+    """
+    grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051")
+    rest_url = os.getenv("SIFT_REST_URI", "localhost:8080")
+    api_key = os.getenv("SIFT_API_KEY", "")
+
+    return SiftClient(
+        connection_config=SiftConnectionConfig(
+            api_key=api_key,
+            grpc_url=grpc_url,
+            rest_url=rest_url,
+            cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT)
+        )
+    )
+
 
 def test_client_binding(sift_client):
     assert sift_client.ping
@@ -225,3 +248,43 @@ async def test_cache_disabled_by_default(self, ping_api_async):
         assert isinstance(response1, str)
         assert isinstance(response2, str)
         assert isinstance(response3, str)
+
+    @pytest.mark.asyncio
+    async def test_ping_without_grpc_cache(self):
+        """Test that ping works when GrpcCache is not enabled on the SiftClient."""
+        import os
+
+        from sift_client import SiftClient, SiftConnectionConfig
+
+        # Create a client with caching explicitly disabled
+        grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051")
+        rest_url = os.getenv("SIFT_REST_URI", "localhost:8080")
+        api_key = os.getenv("SIFT_API_KEY", "")
+
+        client = SiftClient(
+            connection_config=SiftConnectionConfig(
+                api_key=api_key,
+                grpc_url=grpc_url,
+                rest_url=rest_url,
+                use_ssl=True,
+                cache_config=None
+            )
+        )
+
+        # Verify cache is not initialized
+        assert client.grpc_client.cache is None
+
+        # Ping should still work without cache
+        response1 = await client.async_.ping.ping()
+        assert isinstance(response1, str)
+        assert len(response1) > 0
+
+        # Multiple pings should work
+        response2 = await client.async_.ping.ping()
+        assert isinstance(response2, str)
+
+        response3 = await client.async_.ping.ping()
+        assert isinstance(response3, str)
+
+        print(f"\nPing without cache successful: {response1}")
+
diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py
index 94d61a17c..82b02e745 100644
--- a/python/lib/sift_client/client.py
+++ b/python/lib/sift_client/client.py
@@ -23,7 +23,6 @@
     TestResultsAPIAsync,
 )
 from sift_client.transport import (
-    CacheConfig,
     GrpcClient,
     GrpcConfig,
     RestClient,
@@ -110,7 +109,6 @@ def __init__(
         grpc_url: str | None = None,
         rest_url: str | None = None,
         connection_config: SiftConnectionConfig | None = None,
-        cache_config: CacheConfig | None = DEFAULT_CACHE_CONFIG,
     ):
         """Initialize the SiftClient with specific connection parameters or a connection_config.
 
@@ -118,8 +116,7 @@ def __init__(
             api_key: The Sift API key for authentication.
             grpc_url: The Sift gRPC API URL.
             rest_url: The Sift REST API URL.
-            connection_config: A SiftConnectionConfig object to configure the connection behavior of the SiftClient.
-            cache_config: Optional cache configuration override for gRPC responses.
+            connection_config: A SiftConnectionConfig object to configure the connection and cache behavior of the SiftClient.
""" if not (api_key and grpc_url and rest_url) and not connection_config: raise ValueError( @@ -129,12 +126,10 @@ def __init__( if connection_config: grpc_config = connection_config.get_grpc_config() # Override cache_config if provided directly to SiftClient - if cache_config is not None: - grpc_config.cache_config = cache_config grpc_client = GrpcClient(grpc_config) rest_client = RestClient(connection_config.get_rest_config()) elif api_key and grpc_url and rest_url: - grpc_client = GrpcClient(GrpcConfig(grpc_url, api_key, cache_config=cache_config)) + grpc_client = GrpcClient(GrpcConfig(grpc_url, api_key, cache_config=DEFAULT_CACHE_CONFIG)) rest_client = RestClient(RestConfig(rest_url, api_key)) else: raise ValueError( diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py index 9a602a1b0..cd99d2880 100644 --- a/python/lib/sift_client/transport/grpc_transport.py +++ b/python/lib/sift_client/transport/grpc_transport.py @@ -10,11 +10,11 @@ import atexit import enum import logging -import tempfile import threading from pathlib import Path from typing import Any +from platformdirs import user_cache_dir from sift_py.grpc.cache import GrpcCache from sift_py.grpc.transport import ( SiftCacheConfig, @@ -40,7 +40,7 @@ def _suppress_blocking_io(loop, context): DEFAULT_CACHE_TTL_SECONDS = 7 * 24 * 60 * 60 # 1 week -DEFAULT_CACHE_FOLDER = Path(tempfile.gettempdir()) / "sift_client" +DEFAULT_CACHE_FOLDER = Path(user_cache_dir("sift_client")) DEFAULT_CACHE_SIZE_LIMIT_BYTES = 5 * 1024**3 # 5GB diff --git a/python/pyproject.toml b/python/pyproject.toml index 1a30e27bf..78da390fb 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -35,7 +35,8 @@ dependencies = [ "types-protobuf>=4.0", "typing-extensions~=4.6", "types-requests~=2.25", - "diskcache~=5.6" + "diskcache~=5.6", + "platformdirs>=4.5.0" ] [project.urls] From e572bc25bdeaed2734a4ae2abfb8f516dbb70e62 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 20 Oct 2025 17:54:46 -0700 Subject: [PATCH 14/16] linting --- .../_internal/low_level_wrappers/base.py | 8 ++-- .../_internal/low_level_wrappers/ping.py | 7 +--- python/lib/sift_client/_tests/conftest.py | 2 +- .../sift_client/_tests/resources/test_ping.py | 38 ++++++++++--------- python/lib/sift_client/client.py | 5 ++- .../grpc/_async_interceptors/caching.py | 10 ++--- python/lib/sift_py/grpc/cache.py | 1 + python/lib/sift_py/grpc/cache_test.py | 14 +------ python/pyproject.toml | 6 +++ 9 files changed, 43 insertions(+), 48 deletions(-) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index e23ce65b4..169734d25 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -1,12 +1,10 @@ from __future__ import annotations from abc import ABC -from typing import Any, Callable, TypeVar +from typing import Any, Awaitable, Callable from sift_py.grpc.cache import with_cache, with_force_refresh -T = TypeVar("T") - class LowLevelClientBase(ABC): @staticmethod @@ -57,13 +55,13 @@ async def _handle_pagination( @staticmethod async def _call_with_cache( - stub_method: Callable[[Any, tuple[tuple[str, str], ...]], T], + stub_method: Callable[..., Awaitable[Any]], request: Any, *, use_cache: bool = True, force_refresh: bool = False, ttl: int | None = None, - ) -> T: + ) -> Any: """Call a gRPC stub method with cache control. 
This is a convenience method for low-level wrappers to easily enable caching diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ping.py b/python/lib/sift_client/_internal/low_level_wrappers/ping.py index 562cbee24..b28b733f3 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ping.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ping.py @@ -33,7 +33,6 @@ class PingLowLevelClient(LowLevelClientBase, WithGrpcClient): It handles common concerns like error handling and retries. """ - _cache_results: bool """Whether to cache the results of the ping request. Used for testing.""" @@ -52,10 +51,6 @@ async def ping(self, _force_refresh: bool = False) -> str: stub = self._grpc_client.get_stub(PingServiceStub) request = PingRequest() response = await self._call_with_cache( - stub.Ping, - request, - use_cache=self._cache_results, - force_refresh=_force_refresh, - ttl=1 + stub.Ping, request, use_cache=self._cache_results, force_refresh=_force_refresh, ttl=1 ) return cast("PingResponse", response).response diff --git a/python/lib/sift_client/_tests/conftest.py b/python/lib/sift_client/_tests/conftest.py index 7be7c2ade..c2c9e5c4e 100644 --- a/python/lib/sift_client/_tests/conftest.py +++ b/python/lib/sift_client/_tests/conftest.py @@ -27,7 +27,7 @@ def sift_client() -> SiftClient: grpc_url=grpc_url, rest_url=rest_url, use_ssl=True, - cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT) + cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT), ) ) diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py index 8a1887249..cc519f0e4 100644 --- a/python/lib/sift_client/_tests/resources/test_ping.py +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -19,6 +19,7 @@ pytestmark = pytest.mark.integration + # We reimplement this here so that the cache is cleared each time we instantiate @pytest.fixture def sift_client() -> SiftClient: @@ -36,7 +37,7 @@ def sift_client() -> SiftClient: api_key=api_key, grpc_url=grpc_url, rest_url=rest_url, - cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT) + cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT), ) ) @@ -112,14 +113,14 @@ async def test_cache_enabled(self, ping_api_async): assert response2 == response1 # Print timing info - print(f"\nFirst ping (server): {duration1*1000:.2f}ms") - print(f"Second ping (cache): {duration2*1000:.2f}ms") - print(f"Speedup: {duration1/duration2:.2f}x") + print(f"\nFirst ping (server): {duration1 * 1000:.2f}ms") + print(f"Second ping (cache): {duration2 * 1000:.2f}ms") + print(f"Speedup: {duration1 / duration2:.2f}x") # Cached call should be significantly faster (at least 5x) assert duration2 < duration1 / 5, ( f"Cached ping should be much faster. 
" - f"First: {duration1*1000:.2f}ms, Second: {duration2*1000:.2f}ms" + f"First: {duration1 * 1000:.2f}ms, Second: {duration2 * 1000:.2f}ms" ) # Disable caching for cleanup @@ -150,18 +151,18 @@ async def test_force_refresh_bypasses_cache(self, ping_api_async): assert isinstance(response3, str) # Print timing info - print(f"\nFirst ping (server): {duration1*1000:.2f}ms") - print(f"Second ping (cache): {duration2*1000:.2f}ms") - print(f"Third ping (force_refresh, server): {duration3*1000:.2f}ms") + print(f"\nFirst ping (server): {duration1 * 1000:.2f}ms") + print(f"Second ping (cache): {duration2 * 1000:.2f}ms") + print(f"Third ping (force_refresh, server): {duration3 * 1000:.2f}ms") # Cached call should be much faster than both server calls assert duration2 < duration1 / 5, ( f"Cached ping should be much faster than first ping. " - f"First: {duration1*1000:.2f}ms, Cached: {duration2*1000:.2f}ms" + f"First: {duration1 * 1000:.2f}ms, Cached: {duration2 * 1000:.2f}ms" ) assert duration2 < duration3 / 5, ( f"Cached ping should be much faster than force_refresh ping. " - f"Force refresh: {duration3*1000:.2f}ms, Cached: {duration2*1000:.2f}ms" + f"Force refresh: {duration3 * 1000:.2f}ms, Cached: {duration2 * 1000:.2f}ms" ) # Disable caching for cleanup @@ -215,14 +216,18 @@ async def test_cache_performance(self, ping_api_async): cached_duration = time.perf_counter() - start_time # Print performance metrics - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"Ping Cache Performance ({num_iterations} iterations)") - print(f"{'='*60}") - print(f"Cached duration: {cached_duration:.4f}s ({cached_duration/num_iterations*1000:.2f}ms per call)") - print(f"Uncached duration: {uncached_duration:.4f}s ({uncached_duration/num_iterations*1000:.2f}ms per call)") + print(f"{'=' * 60}") + print( + f"Cached duration: {cached_duration:.4f}s ({cached_duration / num_iterations * 1000:.2f}ms per call)" + ) + print( + f"Uncached duration: {uncached_duration:.4f}s ({uncached_duration / num_iterations * 1000:.2f}ms per call)" + ) print(f"Speedup: {uncached_duration / cached_duration:.2f}x") print(f"Time saved: {uncached_duration - cached_duration:.4f}s") - print(f"{'='*60}\n") + print(f"{'=' * 60}\n") # Assert that cached is faster assert cached_duration < uncached_duration, ( @@ -267,7 +272,7 @@ async def test_ping_without_grpc_cache(self): grpc_url=grpc_url, rest_url=rest_url, use_ssl=True, - cache_config=None + cache_config=None, ) ) @@ -287,4 +292,3 @@ async def test_ping_without_grpc_cache(self): assert isinstance(response3, str) print(f"\nPing without cache successful: {response1}") - diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index 82b02e745..c3a75045a 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -36,6 +36,7 @@ _sift_client_experimental_warning() + class SiftClient( WithGrpcClient, WithRestClient, @@ -129,7 +130,9 @@ def __init__( grpc_client = GrpcClient(grpc_config) rest_client = RestClient(connection_config.get_rest_config()) elif api_key and grpc_url and rest_url: - grpc_client = GrpcClient(GrpcConfig(grpc_url, api_key, cache_config=DEFAULT_CACHE_CONFIG)) + grpc_client = GrpcClient( + GrpcConfig(grpc_url, api_key, cache_config=DEFAULT_CACHE_CONFIG) + ) rest_client = RestClient(RestConfig(rest_url, api_key)) else: raise ValueError( diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py index 3b45ed095..2c3a866c7 100644 --- 
a/python/lib/sift_py/grpc/_async_interceptors/caching.py +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -122,14 +122,12 @@ async def intercept( response = await call # Cache the response if allowed - if cache_settings.use_cache: + if cache_settings.use_cache and response is not None: try: # Serialize the protobuf response to bytes before caching - cached_data = self._serialize_response(response) - if cached_data is not None: - self.cache.set_with_default_ttl( - key, cached_data, expire=cache_settings.custom_ttl - ) + new_data = self._serialize_response(response) + if new_data is not None: + self.cache.set_with_default_ttl(key, new_data, expire=cache_settings.custom_ttl) logger.debug(f"Cached response for `{key}`") except diskcache.Timeout as e: logger.warning(f"Failed to cache response for `{key}`: {e}") diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py index 6b3f0c1fa..f327a6cbc 100644 --- a/python/lib/sift_py/grpc/cache.py +++ b/python/lib/sift_py/grpc/cache.py @@ -123,6 +123,7 @@ def resolve_cache_metadata(metadata: tuple[tuple[str, str], ...] | None) -> Cach if cache_info.should_cache: cache.set_with_default_ttl(key, response, expire=cache_info.custom_ttl) """ + metadata_dict: dict[str, str] if not metadata: metadata_dict = {} else: diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py index e0f8227a3..35a06b161 100644 --- a/python/lib/sift_py/grpc/cache_test.py +++ b/python/lib/sift_py/grpc/cache_test.py @@ -20,10 +20,8 @@ from sift_py._internal.test_util.server_interceptor import ServerInterceptor from sift_py.grpc.cache import ( GrpcCache, - ignore_cache, with_cache, with_force_refresh, - without_cache, ) from sift_py.grpc.transport import SiftChannelConfig, use_sift_async_channel @@ -117,14 +115,6 @@ def test_cache_helper_functions(): metadata = with_force_refresh(ttl=3600) assert metadata == (("use-cache", "true"), ("force-refresh", "true"), ("cache-ttl", "3600")) - # Test ignore_cache - metadata = ignore_cache() - assert metadata == () - - # Test without_cache - metadata = without_cache() - assert metadata == () - def test_grpc_cache_initialization(): """Test GrpcCache initialization and configuration.""" @@ -324,8 +314,8 @@ async def test_ignore_cache(mocker: MockFixture): assert res1.next_page_token == "token-1" assert data_service.call_count == 1 - # Call with ignore_cache - should hit server - res2 = cast(GetDataResponse, await stub.GetData(request, metadata=ignore_cache())) + # Call with no metadata - should hit server + res2 = cast(GetDataResponse, await stub.GetData(request)) assert res2.next_page_token == "token-2" assert data_service.call_count == 2 diff --git a/python/pyproject.toml b/python/pyproject.toml index 78da390fb..9fe3be0bf 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -151,6 +151,12 @@ module = "requests_toolbelt" ignore_missing_imports = true ignore_errors = true +[[tool.mypy.overrides]] +module = "diskcache" +ignore_missing_imports = true +ignore_errors = true + + [tool.setuptools.packages.find] where = ["lib"] From 02f3525d0774dfee8e506c4462709cbdfdf838a1 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 20 Oct 2025 17:57:37 -0700 Subject: [PATCH 15/16] fix dependency --- python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 9fe3be0bf..bbbaccc71 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "typing-extensions~=4.6", 
"types-requests~=2.25", "diskcache~=5.6", - "platformdirs>=4.5.0" + "platformdirs~=4.0" ] [project.urls] From 8a48a918df4d9b385caadf1c51d7c94f9df30e87 Mon Sep 17 00:00:00 2001 From: Alex Luck Date: Mon, 20 Oct 2025 18:01:02 -0700 Subject: [PATCH 16/16] fix test --- python/lib/sift_py/grpc/cache_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py index 35a06b161..e060a6ff3 100644 --- a/python/lib/sift_py/grpc/cache_test.py +++ b/python/lib/sift_py/grpc/cache_test.py @@ -5,11 +5,11 @@ from concurrent import futures from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Iterator, cast +from typing import Any, Callable, cast import grpc import pytest -from pytest_mock import MockFixture, MockType +from pytest_mock import MockFixture from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse from sift.data.v2.data_pb2_grpc import ( DataServiceServicer, @@ -72,7 +72,7 @@ def intercept( @contextmanager -def server_with_service(mocker: MockFixture) -> Iterator[tuple[MockType, DataService, int]]: +def server_with_service(mocker: MockFixture): """Create a test server with a spy on the DataService. Returns: