You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

269 lines
12 KiB

  1. import ssl
  2. from typing import Any, Optional, Union
  3. import httpx
  4. from attrs import define, evolve, field
  5. @define
  6. class Client:
  7. """A class for keeping track of data related to the API
  8. The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
  9. ``base_url``: The base URL for the API, all requests are made to a relative path to this URL
  10. ``cookies``: A dictionary of cookies to be sent with every request
  11. ``headers``: A dictionary of headers to be sent with every request
  12. ``timeout``: The maximum amount of a time a request can take. API functions will raise
  13. httpx.TimeoutException if this is exceeded.
  14. ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
  15. but can be set to False for testing purposes.
  16. ``follow_redirects``: Whether or not to follow redirects. Default value is False.
  17. ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
  18. Attributes:
  19. raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
  20. status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
  21. argument to the constructor.
  22. """
  23. raise_on_unexpected_status: bool = field(default=False, kw_only=True)
  24. _base_url: str = field(alias="base_url")
  25. _cookies: dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
  26. _headers: dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
  27. _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout")
  28. _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl")
  29. _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects")
  30. _httpx_args: dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
  31. _client: Optional[httpx.Client] = field(default=None, init=False)
  32. _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
  33. def with_headers(self, headers: dict[str, str]) -> "Client":
  34. """Get a new client matching this one with additional headers"""
  35. if self._client is not None:
  36. self._client.headers.update(headers)
  37. if self._async_client is not None:
  38. self._async_client.headers.update(headers)
  39. return evolve(self, headers={**self._headers, **headers})
  40. def with_cookies(self, cookies: dict[str, str]) -> "Client":
  41. """Get a new client matching this one with additional cookies"""
  42. if self._client is not None:
  43. self._client.cookies.update(cookies)
  44. if self._async_client is not None:
  45. self._async_client.cookies.update(cookies)
  46. return evolve(self, cookies={**self._cookies, **cookies})
  47. def with_timeout(self, timeout: httpx.Timeout) -> "Client":
  48. """Get a new client matching this one with a new timeout (in seconds)"""
  49. if self._client is not None:
  50. self._client.timeout = timeout
  51. if self._async_client is not None:
  52. self._async_client.timeout = timeout
  53. return evolve(self, timeout=timeout)
  54. def set_httpx_client(self, client: httpx.Client) -> "Client":
  55. """Manually set the underlying httpx.Client
  56. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  57. """
  58. self._client = client
  59. return self
  60. def get_httpx_client(self) -> httpx.Client:
  61. """Get the underlying httpx.Client, constructing a new one if not previously set"""
  62. if self._client is None:
  63. self._client = httpx.Client(
  64. base_url=self._base_url,
  65. cookies=self._cookies,
  66. headers=self._headers,
  67. timeout=self._timeout,
  68. verify=self._verify_ssl,
  69. follow_redirects=self._follow_redirects,
  70. **self._httpx_args,
  71. )
  72. return self._client
  73. def __enter__(self) -> "Client":
  74. """Enter a context manager for self.client—you cannot enter twice (see httpx docs)"""
  75. self.get_httpx_client().__enter__()
  76. return self
  77. def __exit__(self, *args: Any, **kwargs: Any) -> None:
  78. """Exit a context manager for internal httpx.Client (see httpx docs)"""
  79. self.get_httpx_client().__exit__(*args, **kwargs)
  80. def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "Client":
  81. """Manually the underlying httpx.AsyncClient
  82. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  83. """
  84. self._async_client = async_client
  85. return self
  86. def get_async_httpx_client(self) -> httpx.AsyncClient:
  87. """Get the underlying httpx.AsyncClient, constructing a new one if not previously set"""
  88. if self._async_client is None:
  89. self._async_client = httpx.AsyncClient(
  90. base_url=self._base_url,
  91. cookies=self._cookies,
  92. headers=self._headers,
  93. timeout=self._timeout,
  94. verify=self._verify_ssl,
  95. follow_redirects=self._follow_redirects,
  96. **self._httpx_args,
  97. )
  98. return self._async_client
  99. async def __aenter__(self) -> "Client":
  100. """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)"""
  101. await self.get_async_httpx_client().__aenter__()
  102. return self
  103. async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
  104. """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)"""
  105. await self.get_async_httpx_client().__aexit__(*args, **kwargs)
  106. @define
  107. class AuthenticatedClient:
  108. """A Client which has been authenticated for use on secured endpoints
  109. The following are accepted as keyword arguments and will be used to construct httpx Clients internally:
  110. ``base_url``: The base URL for the API, all requests are made to a relative path to this URL
  111. ``cookies``: A dictionary of cookies to be sent with every request
  112. ``headers``: A dictionary of headers to be sent with every request
  113. ``timeout``: The maximum amount of a time a request can take. API functions will raise
  114. httpx.TimeoutException if this is exceeded.
  115. ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production,
  116. but can be set to False for testing purposes.
  117. ``follow_redirects``: Whether or not to follow redirects. Default value is False.
  118. ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor.
  119. Attributes:
  120. raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a
  121. status code that was not documented in the source OpenAPI document. Can also be provided as a keyword
  122. argument to the constructor.
  123. token: The token to use for authentication
  124. prefix: The prefix to use for the Authorization header
  125. auth_header_name: The name of the Authorization header
  126. """
  127. raise_on_unexpected_status: bool = field(default=False, kw_only=True)
  128. _base_url: str = field(alias="base_url")
  129. _cookies: dict[str, str] = field(factory=dict, kw_only=True, alias="cookies")
  130. _headers: dict[str, str] = field(factory=dict, kw_only=True, alias="headers")
  131. _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout")
  132. _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl")
  133. _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects")
  134. _httpx_args: dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args")
  135. _client: Optional[httpx.Client] = field(default=None, init=False)
  136. _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False)
  137. token: str
  138. prefix: str = "Bearer"
  139. auth_header_name: str = "Authorization"
  140. def with_headers(self, headers: dict[str, str]) -> "AuthenticatedClient":
  141. """Get a new client matching this one with additional headers"""
  142. if self._client is not None:
  143. self._client.headers.update(headers)
  144. if self._async_client is not None:
  145. self._async_client.headers.update(headers)
  146. return evolve(self, headers={**self._headers, **headers})
  147. def with_cookies(self, cookies: dict[str, str]) -> "AuthenticatedClient":
  148. """Get a new client matching this one with additional cookies"""
  149. if self._client is not None:
  150. self._client.cookies.update(cookies)
  151. if self._async_client is not None:
  152. self._async_client.cookies.update(cookies)
  153. return evolve(self, cookies={**self._cookies, **cookies})
  154. def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient":
  155. """Get a new client matching this one with a new timeout (in seconds)"""
  156. if self._client is not None:
  157. self._client.timeout = timeout
  158. if self._async_client is not None:
  159. self._async_client.timeout = timeout
  160. return evolve(self, timeout=timeout)
  161. def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient":
  162. """Manually set the underlying httpx.Client
  163. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  164. """
  165. self._client = client
  166. return self
  167. def get_httpx_client(self) -> httpx.Client:
  168. """Get the underlying httpx.Client, constructing a new one if not previously set"""
  169. if self._client is None:
  170. self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token
  171. self._client = httpx.Client(
  172. base_url=self._base_url,
  173. cookies=self._cookies,
  174. headers=self._headers,
  175. timeout=self._timeout,
  176. verify=self._verify_ssl,
  177. follow_redirects=self._follow_redirects,
  178. **self._httpx_args,
  179. )
  180. return self._client
  181. def __enter__(self) -> "AuthenticatedClient":
  182. """Enter a context manager for self.client—you cannot enter twice (see httpx docs)"""
  183. self.get_httpx_client().__enter__()
  184. return self
  185. def __exit__(self, *args: Any, **kwargs: Any) -> None:
  186. """Exit a context manager for internal httpx.Client (see httpx docs)"""
  187. self.get_httpx_client().__exit__(*args, **kwargs)
  188. def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "AuthenticatedClient":
  189. """Manually the underlying httpx.AsyncClient
  190. **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout.
  191. """
  192. self._async_client = async_client
  193. return self
  194. def get_async_httpx_client(self) -> httpx.AsyncClient:
  195. """Get the underlying httpx.AsyncClient, constructing a new one if not previously set"""
  196. if self._async_client is None:
  197. self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token
  198. self._async_client = httpx.AsyncClient(
  199. base_url=self._base_url,
  200. cookies=self._cookies,
  201. headers=self._headers,
  202. timeout=self._timeout,
  203. verify=self._verify_ssl,
  204. follow_redirects=self._follow_redirects,
  205. **self._httpx_args,
  206. )
  207. return self._async_client
  208. async def __aenter__(self) -> "AuthenticatedClient":
  209. """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)"""
  210. await self.get_async_httpx_client().__aenter__()
  211. return self
  212. async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
  213. """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)"""
  214. await self.get_async_httpx_client().__aexit__(*args, **kwargs)