diff --git a/fishjam/_openapi_client/client.py b/fishjam/_openapi_client/client.py index eeffd00..4b2aac4 100644 --- a/fishjam/_openapi_client/client.py +++ b/fishjam/_openapi_client/client.py @@ -4,6 +4,8 @@ import httpx from attrs import define, evolve, field +from fishjam.version import get_version + @define class Client: @@ -168,6 +170,8 @@ class AuthenticatedClient: token: The token to use for authentication prefix: The prefix to use for the Authorization header auth_header_name: The name of the Authorization header + api_prefix: The prefix to use for the api version header + api_header_name: The name of the api version header """ raise_on_unexpected_status: bool = field(default=False, kw_only=True) @@ -190,6 +194,8 @@ class AuthenticatedClient: token: str prefix: str = "Bearer" auth_header_name: str = "Authorization" + api_prefix: str = "python-server" + api_header_name: str = "x-fishjam-api-client" def with_headers(self, headers: dict[str, str]) -> "AuthenticatedClient": """Get a new client matching this one with additional headers""" @@ -229,6 +235,8 @@ def get_httpx_client(self) -> httpx.Client: self._headers[self.auth_header_name] = ( f"{self.prefix} {self.token}" if self.prefix else self.token ) + self._headers[self.api_header_name] = f"{self.api_prefix}-{get_version()}" + self._client = httpx.Client( base_url=self._base_url, cookies=self._cookies, @@ -265,6 +273,8 @@ def get_async_httpx_client(self) -> httpx.AsyncClient: self._headers[self.auth_header_name] = ( f"{self.prefix} {self.token}" if self.prefix else self.token ) + self._headers[self.api_header_name] = f"{self.api_prefix}-{get_version()}" + self._async_client = httpx.AsyncClient( base_url=self._base_url, cookies=self._cookies, diff --git a/templates/openapi/client.py.jinja b/templates/openapi/client.py.jinja new file mode 100644 index 0000000..88cdd6a --- /dev/null +++ b/templates/openapi/client.py.jinja @@ -0,0 +1,203 @@ +import ssl +from typing import Any, Union, Optional + +from attrs import define, field, evolve +import httpx + +from fishjam.version import get_version + + +{% set attrs_info = { + "raise_on_unexpected_status": namespace( + type="bool", + default="field(default=False, kw_only=True)", + docstring="Whether or not to raise an errors.UnexpectedStatus if the API returns a status code" + " that was not documented in the source OpenAPI document. Can also be provided as a keyword" + " argument to the constructor." + ), + "token": namespace(type="str", default="", docstring="The token to use for authentication"), + "prefix": namespace(type="str", default='"Bearer"', docstring="The prefix to use for the Authorization header"), + "auth_header_name": namespace(type="str", default='"Authorization"', docstring="The name of the Authorization header"), + "api_prefix": namespace(type="str", default='"python-server"', docstring="The prefix to use for the api version header"), + "api_header_name": namespace(type="str", default='"x-fishjam-api-client"', docstring="The name of the api version header"), +} %} + +{% macro attr_in_class_docstring(name) %} +{{ name }}: {{ attrs_info[name].docstring }} +{%- endmacro %} + +{% macro declare_attr(name) %} +{% set attr = attrs_info[name] %} +{{ name }}: {{ attr.type }}{% if attr.default %} = {{ attr.default }}{% endif %} +{% if attr.docstring and config.docstrings_on_attributes +%} +"""{{ attr.docstring }}""" +{%- endif %} +{% endmacro %} + +@define +class Client: + """A class for keeping track of data related to the API + +{% macro httpx_args_docstring() %} + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of a time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. +{% endmacro %} +{{ httpx_args_docstring() }} +{% if not config.docstrings_on_attributes %} + + Attributes: + {{ attr_in_class_docstring("raise_on_unexpected_status") | wordwrap(101) | indent(12) }} +{% endif %} + """ +{% macro attributes() %} + {{ declare_attr("raise_on_unexpected_status") | indent(4) }} + _base_url: str = field(alias="base_url") + _cookies: dict[str, str] = field(factory=dict, kw_only=True, alias="cookies") + _headers: dict[str, str] = field(factory=dict, kw_only=True, alias="headers") + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout") + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl") + _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects") + _httpx_args: dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args") + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) +{% endmacro %}{{ attributes() }} +{% macro builders(self) %} + def with_headers(self, headers: dict[str, str]) -> "{{ self }}": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: dict[str, str]) -> "{{ self }}": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "{{ self }}": + """Get a new client matching this one with a new timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) +{% endmacro %}{{ builders("Client") }} +{% macro httpx_stuff(name, custom_constructor=None) %} + def set_httpx_client(self, client: httpx.Client) -> "{{ name }}": + """Manually set the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + {% if custom_constructor %} + {{ custom_constructor | indent(12) }} + {% endif %} + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "{{ name }}": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "{{ name }}": + """Manually the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + {% if custom_constructor %} + {{ custom_constructor | indent(12) }} + {% endif %} + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "{{ name }}": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) +{% endmacro %}{{ httpx_stuff("Client") }} + +@define +class AuthenticatedClient: + """A Client which has been authenticated for use on secured endpoints + +{{ httpx_args_docstring() }} +{% if not config.docstrings_on_attributes %} + + Attributes: + {{ attr_in_class_docstring("raise_on_unexpected_status") | wordwrap(101) | indent(12) }} + {{ attr_in_class_docstring("token") | indent(8) }} + {{ attr_in_class_docstring("prefix") | indent(8) }} + {{ attr_in_class_docstring("auth_header_name") | indent(8) }} + {{ attr_in_class_docstring("api_prefix") | indent(8) }} + {{ attr_in_class_docstring("api_header_name") | indent(8) }} +{% endif %} + """ + +{{ attributes() }} + {{ declare_attr("token") | indent(4) }} + {{ declare_attr("prefix") | indent(4) }} + {{ declare_attr("auth_header_name") | indent(4) }} + {{ declare_attr("api_prefix") | indent(4) }} + {{ declare_attr("api_header_name") | indent(4) }} + +{{ builders("AuthenticatedClient") }} +{% set auth_constructor %} +self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token +self._headers[self.api_header_name] = f"{self.api_prefix}-{get_version()}" +{% endset %} +{{ httpx_stuff("AuthenticatedClient", auth_constructor) }} diff --git a/tests/test_room_api.py b/tests/test_room_api.py index 3130bbb..83ed50a 100644 --- a/tests/test_room_api.py +++ b/tests/test_room_api.py @@ -54,6 +54,36 @@ def test_valid_token(self): assert room in all_rooms +class TestApiVersionHeaders: + def test_client_sets_sdk_header_sync(self): + client = FishjamClient(FISHJAM_ID, MANAGEMENT_TOKEN) + httpx_client = client.get_httpx_client() + try: + headers = httpx_client.headers + + assert ( + headers[client.api_header_name] + == f"{client.api_prefix}-{client.get_sdk_version()}" + ) + finally: + httpx_client.close() + + def test_client_sets_sdk_header_async(self): + client = FishjamClient(FISHJAM_ID, MANAGEMENT_TOKEN) + async_client = client.get_async_httpx_client() + try: + headers = async_client.headers + + assert ( + headers[client.api_header_name] + == f"{client.api_prefix}-{client.get_sdk_version()}" + ) + finally: + import asyncio + + asyncio.run(async_client.aclose()) + + @pytest.fixture def room_api(): return FishjamClient(FISHJAM_ID, MANAGEMENT_TOKEN)