Add option cert param to client

This commit is contained in:
Matt Haynes 2023-07-13 14:23:28 +01:00
parent 3628559516
commit ae6256a17a

View File

@ -3,7 +3,7 @@ import requests
from aiohttp import ClientSession, ClientTimeout from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
from text_generation.types import ( from text_generation.types import (
StreamResponse, StreamResponse,
@ -41,6 +41,7 @@ class Client:
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None,
timeout: int = 10, timeout: int = 10,
cert: Optional[Union[str, tuple[str, str]]] = None,
): ):
""" """
Args: Args:
@ -52,11 +53,15 @@ class Client:
Cookies to include in the requests Cookies to include in the requests
timeout (`int`): timeout (`int`):
Timeout in seconds Timeout in seconds
cert (`Optional[Union[str, tuple[str, str]]]`):
If String, path to ssl client cert file (.pem).
If Tuple, ('cert', 'key') pair.
""" """
self.base_url = base_url self.base_url = base_url
self.headers = headers self.headers = headers
self.cookies = cookies self.cookies = cookies
self.timeout = timeout self.timeout = timeout
self.cert = cert
def generate( def generate(
self, self,
@ -143,6 +148,7 @@ class Client:
headers=self.headers, headers=self.headers,
cookies=self.cookies, cookies=self.cookies,
timeout=self.timeout, timeout=self.timeout,
cert=self.cert,
) )
payload = resp.json() payload = resp.json()
if resp.status_code != 200: if resp.status_code != 200:
@ -229,6 +235,7 @@ class Client:
cookies=self.cookies, cookies=self.cookies,
timeout=self.timeout, timeout=self.timeout,
stream=True, stream=True,
cert=self.cert,
) )
if resp.status_code != 200: if resp.status_code != 200:
@ -283,6 +290,7 @@ class AsyncClient:
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None, cookies: Optional[Dict[str, str]] = None,
timeout: int = 10, timeout: int = 10,
cert: Optional[Union[str, tuple[str, str]]] = None,
): ):
""" """
Args: Args:
@ -294,11 +302,15 @@ class AsyncClient:
Cookies to include in the requests Cookies to include in the requests
timeout (`int`): timeout (`int`):
Timeout in seconds Timeout in seconds
cert (`Optional[Union[str, tuple[str, str]]]`):
If String, path to ssl client cert file (.pem).
If Tuple, ('cert', 'key') pair.
""" """
self.base_url = base_url self.base_url = base_url
self.headers = headers self.headers = headers
self.cookies = cookies self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60) self.timeout = ClientTimeout(timeout * 60)
self.cert = cert
async def generate( async def generate(
self, self,