mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Add option cert param to client
This commit is contained in:
parent
3628559516
commit
ae6256a17a
@ -3,7 +3,7 @@ import requests
|
|||||||
|
|
||||||
from aiohttp import ClientSession, ClientTimeout
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from typing import Dict, Optional, List, AsyncIterator, Iterator
|
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
|
||||||
|
|
||||||
from text_generation.types import (
|
from text_generation.types import (
|
||||||
StreamResponse,
|
StreamResponse,
|
||||||
@ -41,6 +41,7 @@ class Client:
|
|||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
cookies: Optional[Dict[str, str]] = None,
|
cookies: Optional[Dict[str, str]] = None,
|
||||||
timeout: int = 10,
|
timeout: int = 10,
|
||||||
|
cert: Optional[Union[str, tuple[str, str]]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -52,11 +53,15 @@ class Client:
|
|||||||
Cookies to include in the requests
|
Cookies to include in the requests
|
||||||
timeout (`int`):
|
timeout (`int`):
|
||||||
Timeout in seconds
|
Timeout in seconds
|
||||||
|
cert (`Optional[Union[str, tuple[str, str]]]`):
|
||||||
|
If String, path to ssl client cert file (.pem).
|
||||||
|
If Tuple, ('cert', 'key') pair.
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.cookies = cookies
|
self.cookies = cookies
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
self.cert = cert
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@ -143,6 +148,7 @@ class Client:
|
|||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
cookies=self.cookies,
|
cookies=self.cookies,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
|
cert=self.cert,
|
||||||
)
|
)
|
||||||
payload = resp.json()
|
payload = resp.json()
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
@ -229,6 +235,7 @@ class Client:
|
|||||||
cookies=self.cookies,
|
cookies=self.cookies,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
cert=self.cert,
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
@ -283,6 +290,7 @@ class AsyncClient:
|
|||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
cookies: Optional[Dict[str, str]] = None,
|
cookies: Optional[Dict[str, str]] = None,
|
||||||
timeout: int = 10,
|
timeout: int = 10,
|
||||||
|
cert: Optional[Union[str, tuple[str, str]]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -294,11 +302,15 @@ class AsyncClient:
|
|||||||
Cookies to include in the requests
|
Cookies to include in the requests
|
||||||
timeout (`int`):
|
timeout (`int`):
|
||||||
Timeout in seconds
|
Timeout in seconds
|
||||||
|
cert (`Optional[Union[str, tuple[str, str]]]`):
|
||||||
|
If String, path to ssl client cert file (.pem).
|
||||||
|
If Tuple, ('cert', 'key') pair.
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.cookies = cookies
|
self.cookies = cookies
|
||||||
self.timeout = ClientTimeout(timeout * 60)
|
self.timeout = ClientTimeout(timeout * 60)
|
||||||
|
self.cert = cert
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user