diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 2aca35a..4208b5c 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.5.0" + ".": "0.6.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 78b59ef..737a29d 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 7 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/meta%2Fllama-api-edf0a308dd29bea2feb29f2e7f04eec4dbfb130ffe52511641783958168f60a4.yml openapi_spec_hash: 23af966c58151516aaef00e0af602c01 -config_hash: 431a8aed31c3576451a36d2db8f48c25 +config_hash: 416c3d950e58dbdb47588eaf29fa9fa5 diff --git a/CHANGELOG.md b/CHANGELOG.md index c2f7667..4724323 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,41 @@ # Changelog +## 0.6.0 (2025-12-18) + +Full Changelog: [v0.5.0...v0.6.0](https://github.com/meta-llama/llama-api-python/compare/v0.5.0...v0.6.0) + +### Features + +* **api:** manual updates ([345a9f0](https://github.com/meta-llama/llama-api-python/commit/345a9f0554fe30991f936078a3242e44c0cca302)) +* **api:** manual updates ([a87c139](https://github.com/meta-llama/llama-api-python/commit/a87c139abe8dc412688855d2ea8226e02d3d1376)) +* **api:** manual updates ([2f2aed5](https://github.com/meta-llama/llama-api-python/commit/2f2aed50d52ca6d1e23523a6d4a2469445feb088)) + + +### Bug Fixes + +* **client:** close streams without requiring full consumption ([cb29768](https://github.com/meta-llama/llama-api-python/commit/cb29768aa1a4c3d9e8e94ef28cfa40a856618fb4)) +* compat with Python 3.14 ([009ca0d](https://github.com/meta-llama/llama-api-python/commit/009ca0d914ec813285e1f195d645871f9cd3d6df)) +* **compat:** update signatures of `model_dump` and `model_dump_json` for Pydantic v1 ([eb03850](https://github.com/meta-llama/llama-api-python/commit/eb03850c5905d443da89b71fe8306af8cf5d7062)) +* ensure streams are always closed ([4b8f9b7](https://github.com/meta-llama/llama-api-python/commit/4b8f9b7b7f63e0d72daf9bd24c3f12c424040c6d)) +* **types:** allow pyright to infer TypedDict types within SequenceNotStr ([aab06ad](https://github.com/meta-llama/llama-api-python/commit/aab06adc22ed41bd16af636c3bc94e08b9bf2c82)) +* use async_to_httpx_files in patch method ([7a5d301](https://github.com/meta-llama/llama-api-python/commit/7a5d3019d53edd2a3b92c8aa91971aa3421ae758)) + + +### Chores + +* add missing docstrings ([edaf4a2](https://github.com/meta-llama/llama-api-python/commit/edaf4a2677b2c2a5d4b89a96ac1de289430c6957)) +* bump `httpx-aiohttp` version to 0.1.9 ([29874b0](https://github.com/meta-llama/llama-api-python/commit/29874b0abe332ac6c10d44fa93088bd13a4b793f)) +* **deps:** mypy 1.18.1 has a regression, pin to 1.17 ([8c1ce31](https://github.com/meta-llama/llama-api-python/commit/8c1ce316a22980fd33b11bacb6f23d3166322f13)) +* **docs:** use environment variables for authentication in code snippets ([e04fade](https://github.com/meta-llama/llama-api-python/commit/e04fade6f4b6c8ce315e337a7f58b7b72a981a28)) +* **internal/tests:** avoid race condition with implicit client cleanup ([1b7f280](https://github.com/meta-llama/llama-api-python/commit/1b7f2809275a7c39377d5841dd77e52bad1476ed)) +* **internal:** add missing files argument to base client ([9223e75](https://github.com/meta-llama/llama-api-python/commit/9223e753f32938fc39e0866ce8f86d4fbcef37ec)) +* **internal:** codegen related update ([3b9a132](https://github.com/meta-llama/llama-api-python/commit/3b9a132f8271d76851669aedab9ee880806a83e8)) +* 
**internal:** detect missing future annotations with ruff ([b4cccf1](https://github.com/meta-llama/llama-api-python/commit/b4cccf14ca10e965177a0904692646da0c2892f0)) +* **internal:** grammar fix (it's -> its) ([46f738d](https://github.com/meta-llama/llama-api-python/commit/46f738d2b1bf8c210607abe2d1acc46f8361895d)) +* **package:** drop Python 3.8 support ([785c446](https://github.com/meta-llama/llama-api-python/commit/785c4468206841d6c9d172b2733b0bcf053dcce6)) +* speedup initial import ([cc25ae3](https://github.com/meta-llama/llama-api-python/commit/cc25ae390cb3e10aa517638ca25ba60b9e6b8b07)) +* update lockfile ([2b01e25](https://github.com/meta-llama/llama-api-python/commit/2b01e25f540a589b7bedef48a3f18e0d4bba8d7d)) + ## 0.5.0 (2025-10-01) Full Changelog: [v0.4.0...v0.5.0](https://github.com/meta-llama/llama-api-python/compare/v0.4.0...v0.5.0) diff --git a/README.md b/README.md index 94c2456..6b4ff79 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![PyPI version](https://img.shields.io/pypi/v/llama_api_client.svg?label=pypi%20(stable))](https://pypi.org/project/llama_api_client/) -The Llama API Client Python library provides convenient access to the Llama API Client REST API from any Python 3.8+ +The Llama API Client Python library provides convenient access to the Llama API Client REST API from any Python 3.9+ application. The library includes type definitions for all request params and response fields, and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). @@ -94,6 +94,7 @@ pip install 'llama_api_client[aiohttp] @ git+ssh://git@github.com/meta-llama/lla Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`: ```python +import os import asyncio from llama_api_client import DefaultAioHttpClient from llama_api_client import AsyncLlamaAPIClient @@ -101,7 +102,7 @@ from llama_api_client import AsyncLlamaAPIClient async def main() -> None: async with AsyncLlamaAPIClient( - api_key="My API Key", + api_key=os.environ.get("LLAMA_API_KEY"), # This is the default and can be omitted http_client=DefaultAioHttpClient(), ) as client: create_chat_completion_response = await client.chat.completions.create( @@ -281,7 +282,7 @@ client.with_options(max_retries=5).chat.completions.create( ### Timeouts -By default requests time out after 1 minute. You can configure this with a `timeout` option, +By default requests time out after 10 minutes. You can configure this with a `timeout` option, which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/timeouts/#fine-tuning-the-configuration) object: ```python @@ -289,7 +290,7 @@ from llama_api_client import LlamaAPIClient # Configure the default for all requests: client = LlamaAPIClient( - # 20 seconds (default is 1 minute) + # 20 seconds (default is 10 minutes) timeout=20.0, ) @@ -490,7 +491,7 @@ print(llama_api_client.__version__) ## Requirements -Python 3.8 or higher. +Python 3.9 or higher. 
## Contributing diff --git a/pyproject.toml b/pyproject.toml index 2c175b5..936914d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,30 +1,32 @@ [project] name = "llama_api_client" -version = "0.5.0" +version = "0.6.0" description = "The official Python library for the llama-api-client API" dynamic = ["readme"] license = "MIT" authors = [ { name = "Llama API Client", email = "support@llama.developer.meta.com" }, ] + dependencies = [ - "httpx>=0.23.0, <1", - "pydantic>=1.9.0, <3", - "typing-extensions>=4.10, <5", - "anyio>=3.5.0, <5", - "distro>=1.7.0, <2", - "sniffio", + "httpx>=0.23.0, <1", + "pydantic>=1.9.0, <3", + "typing-extensions>=4.10, <5", + "anyio>=3.5.0, <5", + "distro>=1.7.0, <2", + "sniffio", ] -requires-python = ">= 3.8" + +requires-python = ">= 3.9" classifiers = [ "Typing :: Typed", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Operating System :: POSIX", "Operating System :: MacOS", @@ -39,14 +41,14 @@ Homepage = "https://github.com/meta-llama/llama-api-python" Repository = "https://github.com/meta-llama/llama-api-python" [project.optional-dependencies] -aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.8"] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] [tool.rye] managed = true # version pins are in requirements-dev.lock dev-dependencies = [ "pyright==1.1.399", - "mypy", + "mypy==1.17", "respx", "pytest", "pytest-asyncio", @@ -141,7 +143,7 @@ filterwarnings = [ # there are a couple of flags that are still disabled by # default in strict mode as they are experimental and niche. typeCheckingMode = "strict" -pythonVersion = "3.8" +pythonVersion = "3.9" exclude = [ "_dev", @@ -224,6 +226,8 @@ select = [ "B", # remove unused imports "F401", + # check for missing future annotations + "FA102", # bare except statements "E722", # unused arguments @@ -246,6 +250,8 @@ unfixable = [ "T203", ] +extend-safe-fixes = ["FA102"] + [tool.ruff.lint.flake8-tidy-imports.banned-api] "functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" diff --git a/requirements-dev.lock b/requirements-dev.lock index b071722..da47431 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -12,40 +12,45 @@ -e file:. 
aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.12.8 +aiohttp==3.13.2 # via httpx-aiohttp # via llama-api-client -aiosignal==1.3.2 +aiosignal==1.4.0 # via aiohttp -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.4.0 +anyio==4.12.0 # via httpx # via llama-api-client -argcomplete==3.1.2 +argcomplete==3.6.3 # via nox async-timeout==5.0.1 # via aiohttp -attrs==25.3.0 +attrs==25.4.0 # via aiohttp -certifi==2023.7.22 + # via nox +backports-asyncio-runner==1.2.0 + # via pytest-asyncio +certifi==2025.11.12 # via httpcore # via httpx -colorlog==6.7.0 +colorlog==6.10.1 + # via nox +dependency-groups==1.3.1 # via nox -dirty-equals==0.6.0 -distlib==0.3.7 +dirty-equals==0.11 +distlib==0.4.0 # via virtualenv -distro==1.8.0 +distro==1.9.0 # via llama-api-client -exceptiongroup==1.2.2 +exceptiongroup==1.3.1 # via anyio # via pytest -execnet==2.1.1 +execnet==2.1.2 # via pytest-xdist -filelock==3.12.4 +filelock==3.19.1 # via virtualenv -frozenlist==1.6.2 +frozenlist==1.8.0 # via aiohttp # via aiosignal h11==0.16.0 @@ -56,82 +61,89 @@ httpx==0.28.1 # via httpx-aiohttp # via llama-api-client # via respx -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via llama-api-client -idna==3.4 +humanize==4.13.0 + # via nox +idna==3.11 # via anyio # via httpx # via yarl -importlib-metadata==7.0.0 -iniconfig==2.0.0 +importlib-metadata==8.7.0 +iniconfig==2.1.0 # via pytest markdown-it-py==3.0.0 # via rich mdurl==0.1.2 # via markdown-it-py -multidict==6.4.4 +multidict==6.7.0 # via aiohttp # via yarl -mypy==1.14.1 -mypy-extensions==1.0.0 +mypy==1.17.0 +mypy-extensions==1.1.0 # via mypy -nodeenv==1.8.0 +nodeenv==1.9.1 # via pyright -nox==2023.4.22 -packaging==23.2 +nox==2025.11.12 +packaging==25.0 + # via dependency-groups # via nox # via pytest -platformdirs==3.11.0 +pathspec==0.12.1 + # via mypy +platformdirs==4.4.0 # via virtualenv -pluggy==1.5.0 +pluggy==1.6.0 # via pytest -propcache==0.3.1 +propcache==0.4.1 # via aiohttp # via yarl -pydantic==2.11.9 +pydantic==2.12.5 # via llama-api-client -pydantic-core==2.33.2 +pydantic-core==2.41.5 # via pydantic -pygments==2.18.0 +pygments==2.19.2 + # via pytest # via rich pyright==1.1.399 -pytest==8.3.3 +pytest==8.4.2 # via pytest-asyncio # via pytest-xdist -pytest-asyncio==0.24.0 -pytest-xdist==3.7.0 -python-dateutil==2.8.2 +pytest-asyncio==1.2.0 +pytest-xdist==3.8.0 +python-dateutil==2.9.0.post0 # via time-machine -pytz==2023.3.post1 - # via dirty-equals respx==0.22.0 -rich==13.7.1 -ruff==0.9.4 -setuptools==68.2.2 - # via nodeenv -six==1.16.0 +rich==14.2.0 +ruff==0.14.7 +six==1.17.0 # via python-dateutil -sniffio==1.3.0 - # via anyio +sniffio==1.3.1 # via llama-api-client -time-machine==2.9.0 -tomli==2.0.2 +time-machine==2.19.0 +tomli==2.3.0 + # via dependency-groups # via mypy + # via nox # via pytest -typing-extensions==4.12.2 +typing-extensions==4.15.0 + # via aiosignal # via anyio + # via exceptiongroup # via llama-api-client # via multidict # via mypy # via pydantic # via pydantic-core # via pyright + # via pytest-asyncio # via typing-inspection -typing-inspection==0.4.1 + # via virtualenv +typing-inspection==0.4.2 # via pydantic -virtualenv==20.24.5 +virtualenv==20.35.4 # via nox -yarl==1.20.0 +yarl==1.22.0 # via aiohttp -zipp==3.17.0 +zipp==3.23.0 # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index d31d6a3..3c3614f 100644 --- a/requirements.lock +++ b/requirements.lock @@ -12,28 +12,28 @@ -e file:. 
aiohappyeyeballs==2.6.1 # via aiohttp -aiohttp==3.12.8 +aiohttp==3.13.2 # via httpx-aiohttp # via llama-api-client -aiosignal==1.3.2 +aiosignal==1.4.0 # via aiohttp -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.4.0 +anyio==4.12.0 # via httpx # via llama-api-client async-timeout==5.0.1 # via aiohttp -attrs==25.3.0 +attrs==25.4.0 # via aiohttp -certifi==2023.7.22 +certifi==2025.11.12 # via httpcore # via httpx -distro==1.8.0 +distro==1.9.0 # via llama-api-client -exceptiongroup==1.2.2 +exceptiongroup==1.3.1 # via anyio -frozenlist==1.6.2 +frozenlist==1.8.0 # via aiohttp # via aiosignal h11==0.16.0 @@ -43,33 +43,34 @@ httpcore==1.0.9 httpx==0.28.1 # via httpx-aiohttp # via llama-api-client -httpx-aiohttp==0.1.8 +httpx-aiohttp==0.1.9 # via llama-api-client -idna==3.4 +idna==3.11 # via anyio # via httpx # via yarl -multidict==6.4.4 +multidict==6.7.0 # via aiohttp # via yarl -propcache==0.3.1 +propcache==0.4.1 # via aiohttp # via yarl -pydantic==2.11.9 +pydantic==2.12.5 # via llama-api-client -pydantic-core==2.33.2 +pydantic-core==2.41.5 # via pydantic -sniffio==1.3.0 - # via anyio +sniffio==1.3.1 # via llama-api-client -typing-extensions==4.12.2 +typing-extensions==4.15.0 + # via aiosignal # via anyio + # via exceptiongroup # via llama-api-client # via multidict # via pydantic # via pydantic-core # via typing-inspection -typing-inspection==0.4.1 +typing-inspection==0.4.2 # via pydantic -yarl==1.20.0 +yarl==1.22.0 # via aiohttp diff --git a/src/llama_api_client/_base_client.py b/src/llama_api_client/_base_client.py index 7701c0d..0a31082 100644 --- a/src/llama_api_client/_base_client.py +++ b/src/llama_api_client/_base_client.py @@ -1247,9 +1247,12 @@ def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, files=to_httpx_files(files), **options + ) return self.request(cast_to, opts) def put( @@ -1767,9 +1770,12 @@ async def patch( *, cast_to: Type[ResponseT], body: Body | None = None, + files: RequestFiles | None = None, options: RequestOptions = {}, ) -> ResponseT: - opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, files=await async_to_httpx_files(files), **options + ) return await self.request(cast_to, opts) async def put( diff --git a/src/llama_api_client/_client.py b/src/llama_api_client/_client.py index f7bb718..5b11e67 100644 --- a/src/llama_api_client/_client.py +++ b/src/llama_api_client/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping from typing_extensions import Self, override import httpx @@ -20,8 +20,8 @@ not_given, ) from ._utils import is_given, get_async_library +from ._compat import cached_property from ._version import __version__ -from .resources import models, uploads, moderations from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import APIStatusError, LlamaAPIClientError from ._base_client import ( @@ -29,7 +29,13 @@ SyncAPIClient, AsyncAPIClient, ) -from .resources.chat import chat + +if TYPE_CHECKING: + from .resources import chat, models, uploads, moderations + from .resources.models import ModelsResource, 
AsyncModelsResource + from .resources.uploads import UploadsResource, AsyncUploadsResource + from .resources.chat.chat import ChatResource, AsyncChatResource + from .resources.moderations import ModerationsResource, AsyncModerationsResource __all__ = [ "Timeout", @@ -44,13 +50,6 @@ class LlamaAPIClient(SyncAPIClient): - chat: chat.ChatResource - models: models.ModelsResource - uploads: uploads.UploadsResource - moderations: moderations.ModerationsResource - with_raw_response: LlamaAPIClientWithRawResponse - with_streaming_response: LlamaAPIClientWithStreamedResponse - # client options api_key: str @@ -105,12 +104,37 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.chat = chat.ChatResource(self) - self.models = models.ModelsResource(self) - self.uploads = uploads.UploadsResource(self) - self.moderations = moderations.ModerationsResource(self) - self.with_raw_response = LlamaAPIClientWithRawResponse(self) - self.with_streaming_response = LlamaAPIClientWithStreamedResponse(self) + @cached_property + def chat(self) -> ChatResource: + from .resources.chat import ChatResource + + return ChatResource(self) + + @cached_property + def models(self) -> ModelsResource: + from .resources.models import ModelsResource + + return ModelsResource(self) + + @cached_property + def uploads(self) -> UploadsResource: + from .resources.uploads import UploadsResource + + return UploadsResource(self) + + @cached_property + def moderations(self) -> ModerationsResource: + from .resources.moderations import ModerationsResource + + return ModerationsResource(self) + + @cached_property + def with_raw_response(self) -> LlamaAPIClientWithRawResponse: + return LlamaAPIClientWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> LlamaAPIClientWithStreamedResponse: + return LlamaAPIClientWithStreamedResponse(self) @property @override @@ -218,13 +242,6 @@ def _make_status_error( class AsyncLlamaAPIClient(AsyncAPIClient): - chat: chat.AsyncChatResource - models: models.AsyncModelsResource - uploads: uploads.AsyncUploadsResource - moderations: moderations.AsyncModerationsResource - with_raw_response: AsyncLlamaAPIClientWithRawResponse - with_streaming_response: AsyncLlamaAPIClientWithStreamedResponse - # client options api_key: str @@ -279,12 +296,37 @@ def __init__( _strict_response_validation=_strict_response_validation, ) - self.chat = chat.AsyncChatResource(self) - self.models = models.AsyncModelsResource(self) - self.uploads = uploads.AsyncUploadsResource(self) - self.moderations = moderations.AsyncModerationsResource(self) - self.with_raw_response = AsyncLlamaAPIClientWithRawResponse(self) - self.with_streaming_response = AsyncLlamaAPIClientWithStreamedResponse(self) + @cached_property + def chat(self) -> AsyncChatResource: + from .resources.chat import AsyncChatResource + + return AsyncChatResource(self) + + @cached_property + def models(self) -> AsyncModelsResource: + from .resources.models import AsyncModelsResource + + return AsyncModelsResource(self) + + @cached_property + def uploads(self) -> AsyncUploadsResource: + from .resources.uploads import AsyncUploadsResource + + return AsyncUploadsResource(self) + + @cached_property + def moderations(self) -> AsyncModerationsResource: + from .resources.moderations import AsyncModerationsResource + + return AsyncModerationsResource(self) + + @cached_property + def with_raw_response(self) -> AsyncLlamaAPIClientWithRawResponse: + return AsyncLlamaAPIClientWithRawResponse(self) + + @cached_property + def 
with_streaming_response(self) -> AsyncLlamaAPIClientWithStreamedResponse: + return AsyncLlamaAPIClientWithStreamedResponse(self) @property @override @@ -392,35 +434,127 @@ def _make_status_error( class LlamaAPIClientWithRawResponse: + _client: LlamaAPIClient + def __init__(self, client: LlamaAPIClient) -> None: - self.chat = chat.ChatResourceWithRawResponse(client.chat) - self.models = models.ModelsResourceWithRawResponse(client.models) - self.uploads = uploads.UploadsResourceWithRawResponse(client.uploads) - self.moderations = moderations.ModerationsResourceWithRawResponse(client.moderations) + self._client = client + + @cached_property + def chat(self) -> chat.ChatResourceWithRawResponse: + from .resources.chat import ChatResourceWithRawResponse + + return ChatResourceWithRawResponse(self._client.chat) + + @cached_property + def models(self) -> models.ModelsResourceWithRawResponse: + from .resources.models import ModelsResourceWithRawResponse + + return ModelsResourceWithRawResponse(self._client.models) + + @cached_property + def uploads(self) -> uploads.UploadsResourceWithRawResponse: + from .resources.uploads import UploadsResourceWithRawResponse + + return UploadsResourceWithRawResponse(self._client.uploads) + + @cached_property + def moderations(self) -> moderations.ModerationsResourceWithRawResponse: + from .resources.moderations import ModerationsResourceWithRawResponse + + return ModerationsResourceWithRawResponse(self._client.moderations) class AsyncLlamaAPIClientWithRawResponse: + _client: AsyncLlamaAPIClient + def __init__(self, client: AsyncLlamaAPIClient) -> None: - self.chat = chat.AsyncChatResourceWithRawResponse(client.chat) - self.models = models.AsyncModelsResourceWithRawResponse(client.models) - self.uploads = uploads.AsyncUploadsResourceWithRawResponse(client.uploads) - self.moderations = moderations.AsyncModerationsResourceWithRawResponse(client.moderations) + self._client = client + + @cached_property + def chat(self) -> chat.AsyncChatResourceWithRawResponse: + from .resources.chat import AsyncChatResourceWithRawResponse + + return AsyncChatResourceWithRawResponse(self._client.chat) + + @cached_property + def models(self) -> models.AsyncModelsResourceWithRawResponse: + from .resources.models import AsyncModelsResourceWithRawResponse + + return AsyncModelsResourceWithRawResponse(self._client.models) + + @cached_property + def uploads(self) -> uploads.AsyncUploadsResourceWithRawResponse: + from .resources.uploads import AsyncUploadsResourceWithRawResponse + + return AsyncUploadsResourceWithRawResponse(self._client.uploads) + + @cached_property + def moderations(self) -> moderations.AsyncModerationsResourceWithRawResponse: + from .resources.moderations import AsyncModerationsResourceWithRawResponse + + return AsyncModerationsResourceWithRawResponse(self._client.moderations) class LlamaAPIClientWithStreamedResponse: + _client: LlamaAPIClient + def __init__(self, client: LlamaAPIClient) -> None: - self.chat = chat.ChatResourceWithStreamingResponse(client.chat) - self.models = models.ModelsResourceWithStreamingResponse(client.models) - self.uploads = uploads.UploadsResourceWithStreamingResponse(client.uploads) - self.moderations = moderations.ModerationsResourceWithStreamingResponse(client.moderations) + self._client = client + + @cached_property + def chat(self) -> chat.ChatResourceWithStreamingResponse: + from .resources.chat import ChatResourceWithStreamingResponse + + return ChatResourceWithStreamingResponse(self._client.chat) + + @cached_property + def models(self) -> 
models.ModelsResourceWithStreamingResponse: + from .resources.models import ModelsResourceWithStreamingResponse + + return ModelsResourceWithStreamingResponse(self._client.models) + + @cached_property + def uploads(self) -> uploads.UploadsResourceWithStreamingResponse: + from .resources.uploads import UploadsResourceWithStreamingResponse + + return UploadsResourceWithStreamingResponse(self._client.uploads) + + @cached_property + def moderations(self) -> moderations.ModerationsResourceWithStreamingResponse: + from .resources.moderations import ModerationsResourceWithStreamingResponse + + return ModerationsResourceWithStreamingResponse(self._client.moderations) class AsyncLlamaAPIClientWithStreamedResponse: + _client: AsyncLlamaAPIClient + def __init__(self, client: AsyncLlamaAPIClient) -> None: - self.chat = chat.AsyncChatResourceWithStreamingResponse(client.chat) - self.models = models.AsyncModelsResourceWithStreamingResponse(client.models) - self.uploads = uploads.AsyncUploadsResourceWithStreamingResponse(client.uploads) - self.moderations = moderations.AsyncModerationsResourceWithStreamingResponse(client.moderations) + self._client = client + + @cached_property + def chat(self) -> chat.AsyncChatResourceWithStreamingResponse: + from .resources.chat import AsyncChatResourceWithStreamingResponse + + return AsyncChatResourceWithStreamingResponse(self._client.chat) + + @cached_property + def models(self) -> models.AsyncModelsResourceWithStreamingResponse: + from .resources.models import AsyncModelsResourceWithStreamingResponse + + return AsyncModelsResourceWithStreamingResponse(self._client.models) + + @cached_property + def uploads(self) -> uploads.AsyncUploadsResourceWithStreamingResponse: + from .resources.uploads import AsyncUploadsResourceWithStreamingResponse + + return AsyncUploadsResourceWithStreamingResponse(self._client.uploads) + + @cached_property + def moderations(self) -> moderations.AsyncModerationsResourceWithStreamingResponse: + from .resources.moderations import AsyncModerationsResourceWithStreamingResponse + + return AsyncModerationsResourceWithStreamingResponse(self._client.moderations) Client = LlamaAPIClient diff --git a/src/llama_api_client/_constants.py b/src/llama_api_client/_constants.py index 6ddf2c7..0d2274d 100644 --- a/src/llama_api_client/_constants.py +++ b/src/llama_api_client/_constants.py @@ -5,8 +5,8 @@ RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to" -# default timeout is 1 minute -DEFAULT_TIMEOUT = httpx.Timeout(timeout=60, connect=5.0) +# default timeout is 10 minutes +DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0) DEFAULT_MAX_RETRIES = 2 DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) diff --git a/src/llama_api_client/_models.py b/src/llama_api_client/_models.py index 6a3cd1d..ca9500b 100644 --- a/src/llama_api_client/_models.py +++ b/src/llama_api_client/_models.py @@ -2,6 +2,7 @@ import os import inspect +import weakref from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast from datetime import date, datetime from typing_extensions import ( @@ -256,15 +257,16 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, + context: Any | None = None, by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, 
warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, - serialize_as_any: bool = False, fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -272,16 +274,24 @@ def model_dump( Args: mode: The mode in which `to_python` should run. - If mode is 'json', the dictionary will only contain JSON serializable types. - If mode is 'python', the dictionary may contain any Python objects. - include: A list of fields to include in the output. - exclude: A list of fields to exclude from the output. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + include: A set of fields to include in the output. + exclude: A set of fields to exclude from the output. + context: Additional context to pass to the serializer. by_alias: Whether to use the field's alias in the dictionary key if defined. - exclude_unset: Whether to exclude fields that are unset or None from the output. - exclude_defaults: Whether to exclude fields that are set to their default value from the output. - exclude_none: Whether to exclude fields that have a value of `None` from the output. - round_trip: Whether to enable serialization and deserialization round-trip support. - warnings: Whether to log warnings when invalid fields are encountered. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + exclude_computed_fields: Whether to exclude computed fields. + While this can be useful for round-tripping, it is usually recommended to use the dedicated + `round_trip` parameter instead. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, + "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + fallback: A function to call when an unknown value is encountered. If not provided, + a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. Returns: A dictionary representation of the model. 
@@ -298,6 +308,8 @@ def model_dump( raise ValueError("serialize_as_any is only supported in Pydantic v2") if fallback is not None: raise ValueError("fallback is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, @@ -314,15 +326,17 @@ def model_dump_json( self, *, indent: int | None = None, + ensure_ascii: bool = False, include: IncEx | None = None, exclude: IncEx | None = None, + context: Any | None = None, by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, + exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, - context: dict[str, Any] | None = None, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: @@ -354,6 +368,10 @@ def model_dump_json( raise ValueError("serialize_as_any is only supported in Pydantic v2") if fallback is not None: raise ValueError("fallback is only supported in Pydantic v2") + if ensure_ascii != False: + raise ValueError("ensure_ascii is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, @@ -573,6 +591,9 @@ class CachedDiscriminatorType(Protocol): __discriminator__: DiscriminatorDetails +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() + + class DiscriminatorDetails: field_name: str """The name of the discriminator field in the variant class, e.g. @@ -615,8 +636,9 @@ def __init__( def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ + cached = DISCRIMINATOR_CACHE.get(union) + if cached is not None: + return cached discriminator_field_name: str | None = None @@ -669,7 +691,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, discriminator_field=discriminator_field_name, discriminator_alias=discriminator_alias, ) - cast(CachedDiscriminatorType, union).__discriminator__ = details + DISCRIMINATOR_CACHE.setdefault(union, details) return details diff --git a/src/llama_api_client/_streaming.py b/src/llama_api_client/_streaming.py index 65b2083..6f96434 100644 --- a/src/llama_api_client/_streaming.py +++ b/src/llama_api_client/_streaming.py @@ -55,18 +55,18 @@ def __stream__(self) -> Iterator[_T]: process_data = self._client._process_response_data iterator = self._iter_events() - for sse in iterator: - if sse.event == "error": - raise APIError( - message=sse.data, - request=response.request, - body=sse.json(), - ) - yield process_data(data=sse.json(), cast_to=cast_to, response=response) - - # Ensure the entire stream is consumed - for _sse in iterator: - ... 
+ try: + for sse in iterator: + if sse.event == "error": + raise APIError( + message=sse.data, + request=response.request, + body=sse.json(), + ) + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + response.close() def __enter__(self) -> Self: return self @@ -125,18 +125,18 @@ async def __stream__(self) -> AsyncIterator[_T]: process_data = self._client._process_response_data iterator = self._iter_events() - async for sse in iterator: - if sse.event == "error": - raise APIError( - message=sse.data, - request=response.request, - body=sse.json(), - ) - yield process_data(data=sse.json(), cast_to=cast_to, response=response) - - # Ensure the entire stream is consumed - async for _sse in iterator: - ... + try: + async for sse in iterator: + if sse.event == "error": + raise APIError( + message=sse.data, + request=response.request, + body=sse.json(), + ) + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + await response.aclose() async def __aenter__(self) -> Self: return self diff --git a/src/llama_api_client/_types.py b/src/llama_api_client/_types.py index 2e16e90..2772ab3 100644 --- a/src/llama_api_client/_types.py +++ b/src/llama_api_client/_types.py @@ -243,6 +243,9 @@ class HttpxSendArgs(TypedDict, total=False): if TYPE_CHECKING: # This works because str.__contains__ does not accept object (either in typeshed or at runtime) # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285 + # + # Note: index() and count() methods are intentionally omitted to allow pyright to properly + # infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr. class SequenceNotStr(Protocol[_T_co]): @overload def __getitem__(self, index: SupportsIndex, /) -> _T_co: ... @@ -251,8 +254,6 @@ def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ... def __contains__(self, value: object, /) -> bool: ... def __len__(self) -> int: ... def __iter__(self) -> Iterator[_T_co]: ... - def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ... - def count(self, value: Any, /) -> int: ... def __reversed__(self) -> Iterator[_T_co]: ... else: # just point this to a normal `Sequence` at runtime to avoid having to special case diff --git a/src/llama_api_client/_utils/_sync.py b/src/llama_api_client/_utils/_sync.py index ad7ec71..f6027c1 100644 --- a/src/llama_api_client/_utils/_sync.py +++ b/src/llama_api_client/_utils/_sync.py @@ -1,10 +1,8 @@ from __future__ import annotations -import sys import asyncio import functools -import contextvars -from typing import Any, TypeVar, Callable, Awaitable +from typing import TypeVar, Callable, Awaitable from typing_extensions import ParamSpec import anyio @@ -15,34 +13,11 @@ T_ParamSpec = ParamSpec("T_ParamSpec") -if sys.version_info >= (3, 9): - _asyncio_to_thread = asyncio.to_thread -else: - # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - # for Python 3.8 support - async def _asyncio_to_thread( - func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs - ) -> Any: - """Asynchronously run function *func* in a separate thread. - - Any *args and **kwargs supplied for this function are directly passed - to *func*. 
Also, the current :class:`contextvars.Context` is propagated, - allowing context variables from the main thread to be accessed in the - separate thread. - - Returns a coroutine that can be awaited to get the eventual result of *func*. - """ - loop = asyncio.events.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - async def to_thread( func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> T_Retval: if sniffio.current_async_library() == "asyncio": - return await _asyncio_to_thread(func, *args, **kwargs) + return await asyncio.to_thread(func, *args, **kwargs) return await anyio.to_thread.run_sync( functools.partial(func, *args, **kwargs), @@ -53,10 +28,7 @@ async def to_thread( def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ Take a blocking function and create an async one that receives the same - positional and keyword arguments. For python version 3.9 and above, it uses - asyncio.to_thread to run the function in a separate thread. For python version - 3.8, it uses locally defined copy of the asyncio.to_thread function which was - introduced in python 3.9. + positional and keyword arguments. Usage: diff --git a/src/llama_api_client/_utils/_utils.py b/src/llama_api_client/_utils/_utils.py index 50d5926..eec7f4a 100644 --- a/src/llama_api_client/_utils/_utils.py +++ b/src/llama_api_client/_utils/_utils.py @@ -133,7 +133,7 @@ def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: # Type safe methods for narrowing types with TypeVars. # The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], # however this cause Pyright to rightfully report errors. As we know we don't -# care about the contained types we can safely use `object` in it's place. +# care about the contained types we can safely use `object` in its place. # # There are two separate functions defined, `is_*` and `is_*_t` for different use cases. # `is_*` is for when you're dealing with an unknown input diff --git a/src/llama_api_client/_version.py b/src/llama_api_client/_version.py index 83f206c..662b89a 100644 --- a/src/llama_api_client/_version.py +++ b/src/llama_api_client/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
__title__ = "llama_api_client" -__version__ = "0.5.0" # x-release-please-version +__version__ = "0.6.0" # x-release-please-version diff --git a/src/llama_api_client/types/chat/completion_create_params.py b/src/llama_api_client/types/chat/completion_create_params.py index 16a53fd..0798ea6 100644 --- a/src/llama_api_client/types/chat/completion_create_params.py +++ b/src/llama_api_client/types/chat/completion_create_params.py @@ -84,6 +84,8 @@ class CompletionCreateParamsBase(TypedDict, total=False): class ResponseFormatJsonSchemaResponseFormatJsonSchema(TypedDict, total=False): + """The JSON schema the response should conform to.""" + name: Required[str] """The name of the response format.""" @@ -95,6 +97,8 @@ class ResponseFormatJsonSchemaResponseFormatJsonSchema(TypedDict, total=False): class ResponseFormatJsonSchemaResponseFormat(TypedDict, total=False): + """Configuration for JSON schema-guided response generation.""" + json_schema: Required[ResponseFormatJsonSchemaResponseFormatJsonSchema] """The JSON schema the response should conform to.""" @@ -103,6 +107,8 @@ class ResponseFormatJsonSchemaResponseFormat(TypedDict, total=False): class ResponseFormatTextResponseFormat(TypedDict, total=False): + """Configuration for text-guided response generation.""" + type: Required[Literal["text"]] """The type of response format being defined. Always `text`.""" @@ -116,6 +122,11 @@ class ToolChoiceChatCompletionNamedToolChoiceFunction(TypedDict, total=False): class ToolChoiceChatCompletionNamedToolChoice(TypedDict, total=False): + """Specifies a tool the model should use. + + Use to force the model to call a specific function. + """ + function: Required[ToolChoiceChatCompletionNamedToolChoiceFunction] type: Required[Literal["function"]] diff --git a/src/llama_api_client/types/completion_message.py b/src/llama_api_client/types/completion_message.py index ba0bb05..c4f22f2 100644 --- a/src/llama_api_client/types/completion_message.py +++ b/src/llama_api_client/types/completion_message.py @@ -12,6 +12,8 @@ class ToolCallFunction(BaseModel): + """The function that the model called.""" + arguments: str """ The arguments to call the function with, as generated by the model in JSON @@ -33,6 +35,8 @@ class ToolCall(BaseModel): class CompletionMessage(BaseModel): + """A message containing the model's (assistant) response in a chat conversation.""" + role: Literal["assistant"] """Must be "assistant" to identify this as the model's response""" diff --git a/src/llama_api_client/types/completion_message_param.py b/src/llama_api_client/types/completion_message_param.py index ddde292..dd3f9d0 100644 --- a/src/llama_api_client/types/completion_message_param.py +++ b/src/llama_api_client/types/completion_message_param.py @@ -13,6 +13,8 @@ class ToolCallFunction(TypedDict, total=False): + """The function that the model called.""" + arguments: Required[str] """ The arguments to call the function with, as generated by the model in JSON @@ -34,6 +36,8 @@ class ToolCall(TypedDict, total=False): class CompletionMessageParam(TypedDict, total=False): + """A message containing the model's (assistant) response in a chat conversation.""" + role: Required[Literal["assistant"]] """Must be "assistant" to identify this as the model's response""" diff --git a/src/llama_api_client/types/create_chat_completion_response.py b/src/llama_api_client/types/create_chat_completion_response.py index c79bf91..17a3410 100644 --- a/src/llama_api_client/types/create_chat_completion_response.py +++ 
b/src/llama_api_client/types/create_chat_completion_response.py @@ -17,6 +17,8 @@ class Metric(BaseModel): class CreateChatCompletionResponse(BaseModel): + """Response from a chat completion request.""" + completion_message: CompletionMessage """The complete response message""" diff --git a/src/llama_api_client/types/create_chat_completion_response_stream_chunk.py b/src/llama_api_client/types/create_chat_completion_response_stream_chunk.py index d9a2d54..afad319 100644 --- a/src/llama_api_client/types/create_chat_completion_response_stream_chunk.py +++ b/src/llama_api_client/types/create_chat_completion_response_stream_chunk.py @@ -59,6 +59,8 @@ class EventMetric(BaseModel): class Event(BaseModel): + """The event containing the new content""" + delta: EventDelta """Content generated since last event. @@ -80,6 +82,8 @@ class Event(BaseModel): class CreateChatCompletionResponseStreamChunk(BaseModel): + """A chunk of a streamed chat completion response.""" + event: Event """The event containing the new content""" diff --git a/src/llama_api_client/types/message_image_content_item_param.py b/src/llama_api_client/types/message_image_content_item_param.py index 35a065f..90ad9a4 100644 --- a/src/llama_api_client/types/message_image_content_item_param.py +++ b/src/llama_api_client/types/message_image_content_item_param.py @@ -8,11 +8,15 @@ class ImageURL(TypedDict, total=False): + """Contains either an image URL or a data URL for a base64 encoded image.""" + url: Required[str] """Either a URL of the image or the base64 encoded image data.""" class MessageImageContentItemParam(TypedDict, total=False): + """A image content item""" + image_url: Required[ImageURL] """Contains either an image URL or a data URL for a base64 encoded image.""" diff --git a/src/llama_api_client/types/message_text_content_item.py b/src/llama_api_client/types/message_text_content_item.py index 5ea0208..bde2ecd 100644 --- a/src/llama_api_client/types/message_text_content_item.py +++ b/src/llama_api_client/types/message_text_content_item.py @@ -8,6 +8,8 @@ class MessageTextContentItem(BaseModel): + """A text content item""" + text: str """Text content""" diff --git a/src/llama_api_client/types/message_text_content_item_param.py b/src/llama_api_client/types/message_text_content_item_param.py index 933c7bb..b93becb 100644 --- a/src/llama_api_client/types/message_text_content_item_param.py +++ b/src/llama_api_client/types/message_text_content_item_param.py @@ -8,6 +8,8 @@ class MessageTextContentItemParam(TypedDict, total=False): + """A text content item""" + text: Required[str] """Text content""" diff --git a/src/llama_api_client/types/system_message_param.py b/src/llama_api_client/types/system_message_param.py index 758b6c0..cacaa24 100644 --- a/src/llama_api_client/types/system_message_param.py +++ b/src/llama_api_client/types/system_message_param.py @@ -11,6 +11,8 @@ class SystemMessageParam(TypedDict, total=False): + """A system message providing instructions or context to the model.""" + content: Required[Union[str, Iterable[MessageTextContentItemParam]]] """The content of the system message.""" diff --git a/src/llama_api_client/types/tool_response_message_param.py b/src/llama_api_client/types/tool_response_message_param.py index 22f1609..06a6840 100644 --- a/src/llama_api_client/types/tool_response_message_param.py +++ b/src/llama_api_client/types/tool_response_message_param.py @@ -11,6 +11,8 @@ class ToolResponseMessageParam(TypedDict, total=False): + """A message representing the result of a tool invocation.""" + content: 
Required[Union[str, Iterable[MessageTextContentItemParam]]] """The content of the user message, which can include text and other media.""" diff --git a/src/llama_api_client/types/user_message_param.py b/src/llama_api_client/types/user_message_param.py index 99ea643..3c250b5 100644 --- a/src/llama_api_client/types/user_message_param.py +++ b/src/llama_api_client/types/user_message_param.py @@ -14,6 +14,8 @@ class UserMessageParam(TypedDict, total=False): + """A message from the user in a chat conversation.""" + content: Required[Union[str, Iterable[ContentArrayOfContentItem]]] """The content of the user message, which can include text and other media.""" diff --git a/tests/test_client.py b/tests/test_client.py index 9750c2c..057c4a4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -64,51 +64,49 @@ def _get_open_connections(client: LlamaAPIClient | AsyncLlamaAPIClient) -> int: class TestLlamaAPIClient: - client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - def test_raw_response(self, respx_mock: MockRouter) -> None: + def test_raw_response(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = self.client.post("/foo", cast_to=httpx.Response) + response = client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, client: LlamaAPIClient) -> None: + copied = client.copy() + assert id(copied) != id(client) - copied = self.client.copy(api_key="another My API Key") + copied = client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, client: LlamaAPIClient) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(client.timeout, httpx.Timeout) def test_copy_default_headers(self) -> None: client = LlamaAPIClient( @@ -143,6 +141,7 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and 
`set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() def test_copy_default_query(self) -> None: client = LlamaAPIClient( @@ -180,13 +179,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + client.close() + + def test_copy_signature(self, client: LlamaAPIClient) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -197,12 +198,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, client: LlamaAPIClient) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. build_request(options) @@ -259,14 +260,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + def test_request_timeout(self, client: LlamaAPIClient) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( - FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) - ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(100.0) @@ -279,6 +278,8 @@ def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + client.close() + def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used with httpx.Client(timeout=None) as http_client: @@ -290,6 +291,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + client.close() + # no timeout given to the httpx client should not use the httpx default with httpx.Client() as http_client: client = LlamaAPIClient( @@ -300,6 +303,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + client.close() + # explicitly passing the default timeout currently results in it being ignored with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as 
http_client: client = LlamaAPIClient( @@ -310,6 +315,8 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + client.close() + async def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): async with httpx.AsyncClient() as http_client: @@ -321,14 +328,14 @@ async def test_invalid_http_client(self) -> None: ) def test_default_headers_option(self) -> None: - client = LlamaAPIClient( + test_client = LlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = LlamaAPIClient( + test_client2 = LlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -337,10 +344,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + test_client.close() + test_client2.close() + def test_validate_headers(self) -> None: client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -369,8 +379,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + client.close() + + def test_request_extra_json(self, client: LlamaAPIClient) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -381,7 +393,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -392,7 +404,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -403,8 +415,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: LlamaAPIClient) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -414,7 +426,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", 
@@ -425,8 +437,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: LlamaAPIClient) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -439,7 +451,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -453,7 +465,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -496,7 +508,7 @@ def test_multipart_repeating_array(self, client: LlamaAPIClient) -> None: ] @pytest.mark.respx(base_url=base_url) - def test_basic_union_response(self, respx_mock: MockRouter) -> None: + def test_basic_union_response(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: class Model1(BaseModel): name: str @@ -505,12 +517,12 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + def test_union_response_different_types(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -521,18 +533,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, client: LlamaAPIClient + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -548,7 +562,7 @@ class Model(BaseModel): ) ) - response = self.client.get("/foo", cast_to=Model) + response = client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 @@ -562,6 +576,8 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" + client.close() + def test_base_url_env(self) -> None: with update_env(LLAMA_API_CLIENT_BASE_URL="http://localhost:5000/from/env"): client = LlamaAPIClient(api_key=api_key, _strict_response_validation=True) @@ -591,6 +607,7 @@ def test_base_url_trailing_slash(self, client: 
LlamaAPIClient) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -616,6 +633,7 @@ def test_base_url_no_trailing_slash(self, client: LlamaAPIClient) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + client.close() @pytest.mark.parametrize( "client", @@ -641,35 +659,36 @@ def test_absolute_request_url(self, client: LlamaAPIClient) -> None: ), ) assert request.url == "https://myapi.com/foo" + client.close() def test_copied_client_does_not_close_http(self) -> None: - client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied - assert not client.is_closed() + assert not test_client.is_closed() def test_client_context_manager(self) -> None: - client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - with client as c2: - assert c2 is client + test_client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + def test_client_response_validation_error(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - self.client.get("/foo", cast_to=Model) + client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -691,11 +710,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): strict_client.get("/foo", cast_to=Model) - client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = client.get("/foo", cast_to=Model) + response = non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: ignore[unreachable] + strict_client.close() + non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -718,9 +740,9 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: LlamaAPIClient + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) @@ -742,7 +764,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien model="model", ).__enter__() - assert _get_open_connections(self.client) 
== 0 + assert _get_open_connections(client) == 0 @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -759,7 +781,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client ], model="model", ).__enter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -887,83 +909,77 @@ def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - def test_follow_redirects(self, respx_mock: MockRouter) -> None: + def test_follow_redirects(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - self.client.post( - "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response - ) + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" class TestAsyncLlamaAPIClient: - client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response(self, respx_mock: MockRouter) -> None: + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient) -> None: respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient) -> None: respx_mock.post("/foo").mock( return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') ) - response = await self.client.post("/foo", cast_to=httpx.Response) + response = await async_client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) assert response.json() == {"foo": "bar"} - def 
test_copy(self) -> None: - copied = self.client.copy() - assert id(copied) != id(self.client) + def test_copy(self, async_client: AsyncLlamaAPIClient) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) - copied = self.client.copy(api_key="another My API Key") + copied = async_client.copy(api_key="another My API Key") assert copied.api_key == "another My API Key" - assert self.client.api_key == "My API Key" + assert async_client.api_key == "My API Key" - def test_copy_default_options(self) -> None: + def test_copy_default_options(self, async_client: AsyncLlamaAPIClient) -> None: # options that have a default are overridden correctly - copied = self.client.copy(max_retries=7) + copied = async_client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 2 + assert async_client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 assert copied.max_retries == 7 # timeout - assert isinstance(self.client.timeout, httpx.Timeout) - copied = self.client.copy(timeout=None) + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) assert copied.timeout is None - assert isinstance(self.client.timeout, httpx.Timeout) + assert isinstance(async_client.timeout, httpx.Timeout) - def test_copy_default_headers(self) -> None: + async def test_copy_default_headers(self) -> None: client = AsyncLlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) @@ -996,8 +1012,9 @@ def test_copy_default_headers(self) -> None: match="`default_headers` and `set_default_headers` arguments are mutually exclusive", ): client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() - def test_copy_default_query(self) -> None: + async def test_copy_default_query(self) -> None: client = AsyncLlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} ) @@ -1033,13 +1050,15 @@ def test_copy_default_query(self) -> None: ): client.copy(set_default_query={}, default_query={"foo": "Bar"}) - def test_copy_signature(self) -> None: + await client.close() + + def test_copy_signature(self, async_client: AsyncLlamaAPIClient) -> None: # ensure the same parameters that can be passed to the client are defined in the `.copy()` method init_signature = inspect.signature( # mypy doesn't like that we access the `__init__` property. - self.client.__init__, # type: ignore[misc] + async_client.__init__, # type: ignore[misc] ) - copy_signature = inspect.signature(self.client.copy) + copy_signature = inspect.signature(async_client.copy) exclude_params = {"transport", "proxies", "_strict_response_validation"} for name in init_signature.parameters.keys(): @@ -1050,12 +1069,12 @@ def test_copy_signature(self) -> None: assert copy_param is not None, f"copy() signature is missing the {name} param" @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") - def test_copy_build_request(self) -> None: + def test_copy_build_request(self, async_client: AsyncLlamaAPIClient) -> None: options = FinalRequestOptions(method="get", url="/foo") def build_request(options: FinalRequestOptions) -> None: - client = self.client.copy() - client._build_request(options) + client_copy = async_client.copy() + client_copy._build_request(options) # ensure that the machinery is warmed up before tracing starts. 
build_request(options) @@ -1112,12 +1131,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic print(frame) raise AssertionError() - async def test_request_timeout(self) -> None: - request = self.client._build_request(FinalRequestOptions(method="get", url="/foo")) + async def test_request_timeout(self, async_client: AsyncLlamaAPIClient) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT - request = self.client._build_request( + request = async_client._build_request( FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) ) timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore @@ -1132,6 +1151,8 @@ async def test_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(0) + await client.close() + async def test_http_client_timeout_option(self) -> None: # custom timeout given to the httpx client should be used async with httpx.AsyncClient(timeout=None) as http_client: @@ -1143,6 +1164,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == httpx.Timeout(None) + await client.close() + # no timeout given to the httpx client should not use the httpx default async with httpx.AsyncClient() as http_client: client = AsyncLlamaAPIClient( @@ -1153,6 +1176,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT + await client.close() + # explicitly passing the default timeout currently results in it being ignored async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: client = AsyncLlamaAPIClient( @@ -1163,6 +1188,8 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + await client.close() + def test_invalid_http_client(self) -> None: with pytest.raises(TypeError, match="Invalid `http_client` arg"): with httpx.Client() as http_client: @@ -1173,15 +1200,15 @@ def test_invalid_http_client(self) -> None: http_client=cast(Any, http_client), ) - def test_default_headers_option(self) -> None: - client = AsyncLlamaAPIClient( + async def test_default_headers_option(self) -> None: + test_client = AsyncLlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} ) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "bar" assert request.headers.get("x-stainless-lang") == "python" - client2 = AsyncLlamaAPIClient( + test_client2 = AsyncLlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, @@ -1190,10 +1217,13 @@ def test_default_headers_option(self) -> None: "X-Stainless-Lang": "my-overriding-header", }, ) - request = client2._build_request(FinalRequestOptions(method="get", url="/foo")) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" + await test_client.close() + 
await test_client2.close() + def test_validate_headers(self) -> None: client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) request = client._build_request(FinalRequestOptions(method="get", url="/foo")) @@ -1204,7 +1234,7 @@ def test_validate_headers(self) -> None: client2 = AsyncLlamaAPIClient(base_url=base_url, api_key=None, _strict_response_validation=True) _ = client2 - def test_default_query_option(self) -> None: + async def test_default_query_option(self) -> None: client = AsyncLlamaAPIClient( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} ) @@ -1222,8 +1252,10 @@ def test_default_query_option(self) -> None: url = httpx.URL(request.url) assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} - def test_request_extra_json(self) -> None: - request = self.client._build_request( + await client.close() + + def test_request_extra_json(self, client: LlamaAPIClient) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1234,7 +1266,7 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": False} - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1245,7 +1277,7 @@ def test_request_extra_json(self) -> None: assert data == {"baz": False} # `extra_json` takes priority over `json_data` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1256,8 +1288,8 @@ def test_request_extra_json(self) -> None: data = json.loads(request.content.decode("utf-8")) assert data == {"foo": "bar", "baz": None} - def test_request_extra_headers(self) -> None: - request = self.client._build_request( + def test_request_extra_headers(self, client: LlamaAPIClient) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1267,7 +1299,7 @@ def test_request_extra_headers(self) -> None: assert request.headers.get("X-Foo") == "Foo" # `extra_headers` takes priority over `default_headers` when keys clash - request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request( + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1278,8 +1310,8 @@ def test_request_extra_headers(self) -> None: ) assert request.headers.get("X-Bar") == "false" - def test_request_extra_query(self) -> None: - request = self.client._build_request( + def test_request_extra_query(self, client: LlamaAPIClient) -> None: + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1292,7 +1324,7 @@ def test_request_extra_query(self) -> None: assert params == {"my_query_param": "Foo"} # if both `query` and `extra_query` are given, they are merged - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1306,7 +1338,7 @@ def test_request_extra_query(self) -> None: assert params == {"bar": "1", "foo": "2"} # `extra_query` takes priority over `query` when keys clash - request = self.client._build_request( + request = client._build_request( FinalRequestOptions( method="post", url="/foo", @@ -1349,7 +1381,7 @@ def test_multipart_repeating_array(self, async_client: AsyncLlamaAPIClient) -> N ] @pytest.mark.respx(base_url=base_url) - async 
def test_basic_union_response(self, respx_mock: MockRouter) -> None: + async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient) -> None: class Model1(BaseModel): name: str @@ -1358,12 +1390,14 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" @pytest.mark.respx(base_url=base_url) - async def test_union_response_different_types(self, respx_mock: MockRouter) -> None: + async def test_union_response_different_types( + self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient + ) -> None: """Union of objects with the same field name using a different type""" class Model1(BaseModel): @@ -1374,18 +1408,20 @@ class Model2(BaseModel): respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model2) assert response.foo == "bar" respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) - response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) assert isinstance(response, Model1) assert response.foo == 1 @pytest.mark.respx(base_url=base_url) - async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None: + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient + ) -> None: """ Response that sets Content-Type to something other than application/json but returns json data """ @@ -1401,11 +1437,11 @@ class Model(BaseModel): ) ) - response = await self.client.get("/foo", cast_to=Model) + response = await async_client.get("/foo", cast_to=Model) assert isinstance(response, Model) assert response.foo == 2 - def test_base_url_setter(self) -> None: + async def test_base_url_setter(self) -> None: client = AsyncLlamaAPIClient( base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True ) @@ -1415,7 +1451,9 @@ def test_base_url_setter(self) -> None: assert client.base_url == "https://example.com/from_setter/" - def test_base_url_env(self) -> None: + await client.close() + + async def test_base_url_env(self) -> None: with update_env(LLAMA_API_CLIENT_BASE_URL="http://localhost:5000/from/env"): client = AsyncLlamaAPIClient(api_key=api_key, _strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" @@ -1435,7 +1473,7 @@ def test_base_url_env(self) -> None: ], ids=["standard", "custom http client"], ) - def test_base_url_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: + async def test_base_url_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1444,6 +1482,7 @@ def test_base_url_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1460,7 +1499,7 @@ def test_base_url_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: 
], ids=["standard", "custom http client"], ) - def test_base_url_no_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: + async def test_base_url_no_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1469,6 +1508,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: ), ) assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() @pytest.mark.parametrize( "client", @@ -1485,7 +1525,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncLlamaAPIClient) -> None: ], ids=["standard", "custom http client"], ) - def test_absolute_request_url(self, client: AsyncLlamaAPIClient) -> None: + async def test_absolute_request_url(self, client: AsyncLlamaAPIClient) -> None: request = client._build_request( FinalRequestOptions( method="post", @@ -1494,37 +1534,39 @@ def test_absolute_request_url(self, client: AsyncLlamaAPIClient) -> None: ), ) assert request.url == "https://myapi.com/foo" + await client.close() async def test_copied_client_does_not_close_http(self) -> None: - client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - assert not client.is_closed() + test_client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() - copied = client.copy() - assert copied is not client + copied = test_client.copy() + assert copied is not test_client del copied await asyncio.sleep(0.2) - assert not client.is_closed() + assert not test_client.is_closed() async def test_client_context_manager(self) -> None: - client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - async with client as c2: - assert c2 is client + test_client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client assert not c2.is_closed() - assert not client.is_closed() - assert client.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio - async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None: + async def test_client_response_validation_error( + self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient + ) -> None: class Model(BaseModel): foo: str respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) with pytest.raises(APIResponseValidationError) as exc: - await self.client.get("/foo", cast_to=Model) + await async_client.get("/foo", cast_to=Model) assert isinstance(exc.value.__cause__, ValidationError) @@ -1535,7 +1577,6 @@ async def test_client_max_retries_validation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): name: str @@ -1547,11 +1588,14 @@ class Model(BaseModel): with pytest.raises(APIResponseValidationError): await strict_client.get("/foo", cast_to=Model) - client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=False) + non_strict_client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=False) - response = await client.get("/foo", cast_to=Model) + response = await non_strict_client.get("/foo", cast_to=Model) assert isinstance(response, str) # type: 
ignore[unreachable] + await strict_client.close() + await non_strict_client.close() + @pytest.mark.parametrize( "remaining_retries,retry_after,timeout", [ @@ -1574,13 +1618,12 @@ class Model(BaseModel): ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) - @pytest.mark.asyncio - async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None: - client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) - + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncLlamaAPIClient + ) -> None: headers = httpx.Headers({"retry-after": retry_after}) options = FinalRequestOptions(method="get", url="/foo", max_retries=3) - calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @@ -1601,7 +1644,7 @@ async def test_retrying_timeout_errors_doesnt_leak( model="model", ).__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @@ -1620,12 +1663,11 @@ async def test_retrying_status_errors_doesnt_leak( ], model="model", ).__aenter__() - assert _get_open_connections(self.client) == 0 + assert _get_open_connections(async_client) == 0 @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( self, @@ -1665,7 +1707,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_omit_retry_count_header( self, async_client: AsyncLlamaAPIClient, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1698,7 +1739,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) - @pytest.mark.asyncio async def test_overwrite_retry_count_header( self, async_client: AsyncLlamaAPIClient, failures_before_success: int, respx_mock: MockRouter ) -> None: @@ -1755,26 +1795,26 @@ async def test_default_client_creation(self) -> None: ) @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient) -> None: # Test that the default follow_redirects=True allows following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) - response = 
await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) assert response.status_code == 200 assert response.json() == {"status": "ok"} @pytest.mark.respx(base_url=base_url) - async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient) -> None: # Test that follow_redirects=False prevents following redirects respx_mock.post("/redirect").mock( return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) ) with pytest.raises(APIStatusError) as exc_info: - await self.client.post( + await async_client.post( "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response ) diff --git a/tests/test_models.py b/tests/test_models.py index e65a5dc..7a728aa 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from llama_api_client._utils import PropertyInfo from llama_api_client._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from llama_api_client._models import BaseModel, construct_type +from llama_api_client._models import DISCRIMINATOR_CACHE, BaseModel, construct_type class BasicModel(BaseModel): @@ -809,7 +809,7 @@ class B(BaseModel): UnionType = cast(Any, Union[A, B]) - assert not hasattr(UnionType, "__discriminator__") + assert not DISCRIMINATOR_CACHE.get(UnionType) m = construct_type( value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) @@ -818,7 +818,7 @@ class B(BaseModel): assert m.type == "b" assert m.data == "foo" # type: ignore[comparison-overlap] - discriminator = UnionType.__discriminator__ + discriminator = DISCRIMINATOR_CACHE.get(UnionType) assert discriminator is not None m = construct_type( @@ -830,7 +830,7 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache - assert UnionType.__discriminator__ is discriminator + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")
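Throughout the `tests/test_client.py` hunks above, the shared class attribute `self.client` is replaced by `client` / `async_client` pytest fixture parameters, and tests that build their own clients now close them explicitly, so no client relies on implicit garbage-collection cleanup. The fixtures themselves are defined outside this diff (in the suite's `conftest.py`, which is not shown). A minimal sketch of what such fixtures could look like, assuming function scope and an async-capable fixture runner such as pytest-asyncio:

```python
# Illustrative sketch only: the real fixtures live in conftest.py, which is not part
# of this diff. Names, scope, base_url value, and the pytest-asyncio setup are assumptions.
from typing import AsyncIterator, Iterator

import pytest

from llama_api_client import AsyncLlamaAPIClient, LlamaAPIClient

base_url = "http://test.example"  # placeholder; the test module defines its own base_url constant
api_key = "My API Key"


@pytest.fixture
def client() -> Iterator[LlamaAPIClient]:
    # The context manager guarantees the underlying httpx client is closed after each test.
    with LlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True) as client:
        yield client


@pytest.fixture
async def async_client() -> AsyncIterator[AsyncLlamaAPIClient]:
    # Assumes an async fixture runner (e.g. pytest-asyncio in auto mode); closed on exit.
    async with AsyncLlamaAPIClient(
        base_url=base_url, api_key=api_key, _strict_response_validation=True
    ) as client:
        yield client
```

Tests that still construct ad-hoc clients (the `default_headers`, `default_query`, and timeout cases, for example) call `close()` / `await client.close()` themselves, which is exactly the pattern the added lines in this diff follow.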
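The `tests/test_models.py` hunk switches its assertions from the `__discriminator__` attribute previously stored on the union type to a lookup in `DISCRIMINATOR_CACHE`. The cache's actual implementation is not part of this diff; the sketch below only illustrates the dict-style memoization contract the updated assertions rely on (`get()` returns `None` before the first `construct_type` call, then keeps returning the same details object), with every internal name assumed:

```python
# Hypothetical stand-in for llama_api_client._models.DISCRIMINATOR_CACHE; the real cache
# and its value type may differ. Only the behaviour asserted by the tests is modelled.
from typing import Any, Dict, Optional


class DiscriminatorDetails:
    """Placeholder for parsed discriminator metadata (field name, variant mapping, ...)."""


DISCRIMINATOR_CACHE: Dict[Any, DiscriminatorDetails] = {}


def _details_for(union_type: Any) -> DiscriminatorDetails:
    # Compute once, then reuse: a second call returns the identical object, which is
    # what the `DISCRIMINATOR_CACHE.get(UnionType) is discriminator` assertion checks.
    details: Optional[DiscriminatorDetails] = DISCRIMINATOR_CACHE.get(union_type)
    if details is None:
        details = DiscriminatorDetails()
        DISCRIMINATOR_CACHE[union_type] = details
    return details
```

Keying an external mapping on the union type, rather than setting an attribute on it, is a plausible motivation for the change, since it avoids mutating shared `typing` objects; the tests only verify the observable caching behaviour.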