From ae6256a17a9e7737f7e4d66b354d1b3787abba90 Mon Sep 17 00:00:00 2001 From: Matt Haynes Date: Thu, 13 Jul 2023 14:23:28 +0100 Subject: [PATCH] Add option cert param to client --- clients/python/text_generation/client.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index bf045d47..947bca37 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -3,7 +3,7 @@ import requests from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError -from typing import Dict, Optional, List, AsyncIterator, Iterator +from typing import Dict, Optional, List, AsyncIterator, Iterator, Union from text_generation.types import ( StreamResponse, @@ -41,6 +41,7 @@ class Client: headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, timeout: int = 10, + cert: Optional[Union[str, tuple[str, str]]] = None, ): """ Args: @@ -52,11 +53,15 @@ class Client: Cookies to include in the requests timeout (`int`): Timeout in seconds + cert (`Optional[Union[str, tuple[str, str]]]`): + If String, path to ssl client cert file (.pem). + If Tuple, ('cert', 'key') pair. """ self.base_url = base_url self.headers = headers self.cookies = cookies self.timeout = timeout + self.cert = cert def generate( self, @@ -143,6 +148,7 @@ class Client: headers=self.headers, cookies=self.cookies, timeout=self.timeout, + cert=self.cert, ) payload = resp.json() if resp.status_code != 200: @@ -229,6 +235,7 @@ class Client: cookies=self.cookies, timeout=self.timeout, stream=True, + cert=self.cert, ) if resp.status_code != 200: @@ -283,6 +290,7 @@ class AsyncClient: headers: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None, timeout: int = 10, + cert: Optional[Union[str, tuple[str, str]]] = None, ): """ Args: @@ -294,11 +302,15 @@ class AsyncClient: Cookies to include in the requests timeout (`int`): Timeout in seconds + cert (`Optional[Union[str, tuple[str, str]]]`): + If String, path to ssl client cert file (.pem). + If Tuple, ('cert', 'key') pair. """ self.base_url = base_url self.headers = headers self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) + self.cert = cert async def generate( self,