mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

poc

This commit is contained in:
parent 6d1ad06471
commit feddbbc998
@@ -0,0 +1 @@
# Text Generation
@@ -1,5 +1,5 @@
[tool.poetry]
-name = "text_generation_inference"
+name = "text-generation"
version = "0.3.2"
description = "Text Generation Inference Python Client"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
2  clients/python/text_generation/__init__.py  Normal file
@@ -0,0 +1,2 @@
from text_generation.client import Client, AsyncClient
from text_generation.api_inference import APIInferenceAsyncClient
64  clients/python/text_generation/api_inference.py  Normal file
@@ -0,0 +1,64 @@
import os
import requests
import base64
import json
import warnings

from typing import List, Optional

from text_generation import Client, AsyncClient
from text_generation.errors import NotSupportedError

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)

SUPPORTED_MODELS = None


def get_supported_models() -> Optional[List[str]]:
    global SUPPORTED_MODELS
    if SUPPORTED_MODELS is not None:
        return SUPPORTED_MODELS

    response = requests.get(
        "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
        timeout=5,
    )
    if response.status_code == 200:
        file_content = response.json()["content"]
        SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
        return SUPPORTED_MODELS

    warnings.warn("Could not retrieve list of supported models.")
    return None


class APIInferenceClient(Client):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        supported_models = get_supported_models()
        if supported_models is not None and model_id not in supported_models:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceClient, self).__init__(base_url, headers, timeout)


class APIInferenceAsyncClient(AsyncClient):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        supported_models = get_supported_models()
        if supported_models is not None and model_id not in supported_models:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
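For orientation, a minimal usage sketch of the hub-backed client added above; it mirrors the `__main__` demo removed from `async_client.py` later in this diff. The model id and the `HF_API_TOKEN` environment variable are illustrative assumptions, not part of the commit.

import asyncio
import os

from text_generation import APIInferenceAsyncClient


async def main():
    # HF_API_TOKEN is an assumed variable name holding a Hugging Face access token.
    client = APIInferenceAsyncClient(
        "bigscience/bloomz", token=os.environ.get("HF_API_TOKEN")
    )

    # Streaming: one StreamResponse per generated token.
    async for token in client.generate_stream("test"):
        print(token)

    # Non-streaming: a single Response built from the JSON payload.
    print(await client.generate("test"))


asyncio.run(main())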
210  clients/python/text_generation/client.py  Normal file
@@ -0,0 +1,210 @@
import json
import requests

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator

from text_generation.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation.errors import parse_error


class Client:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = timeout

    def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            timeout=self.timeout,
        )
        payload = resp.json()
        if resp.status_code != 200:
            raise parse_error(resp.status_code, payload)
        return Response(**payload[0])

    def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            timeout=self.timeout,
            stream=True,
        )

        if resp.status_code != 200:
            raise parse_error(resp.status_code, resp.json())

        for byte_payload in resp.iter_lines():
            if byte_payload == b"\n":
                continue

            payload = byte_payload.decode("utf-8")

            if payload.startswith("data:"):
                json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
                try:
                    response = StreamResponse(**json_payload)
                except ValidationError:
                    raise parse_error(resp.status_code, json_payload)
                yield response


class AsyncClient:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = ClientTimeout(timeout * 60)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)
                        yield response
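A similar sketch for the synchronous `Client` above. The base URL and token below are placeholders following the `{INFERENCE_ENDPOINT}/models/{model_id}` pattern used by `APIInferenceClient`; treat them as assumptions about deployment, not something this commit configures.

from text_generation import Client

# Placeholder endpoint and token; APIInferenceClient builds the same kind of URL
# on top of this class in api_inference.py.
client = Client(
    "https://api-inference.huggingface.co/models/bigscience/bloomz",
    headers={"Authorization": "Bearer <hf_token>"},
)

# Blocking call: returns one Response parsed from the JSON payload.
print(client.generate("test", max_new_tokens=20))

# Server-sent-events stream: yields one StreamResponse per token.
for stream_response in client.generate_stream("test", max_new_tokens=20):
    print(stream_response)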
@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List

-from text_generation_inference.errors import ValidationError
+from text_generation.errors import ValidationError


class Parameters(BaseModel):
@@ -1 +0,0 @@
from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient
@@ -1,142 +0,0 @@
import json
import os

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator

from text_generation_inference import SUPPORTED_MODELS
from text_generation_inference.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation_inference.errors import parse_error, NotSupportedError

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)


class AsyncClient:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = ClientTimeout(timeout * 60)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)
                        yield response


class APIInferenceAsyncClient(AsyncClient):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        if model_id not in SUPPORTED_MODELS:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)


if __name__ == "__main__":
    async def main():
        client = APIInferenceAsyncClient(
            "bigscience/bloomz", token="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"
        )
        async for token in client.generate_stream("test"):
            print(token)

        print(await client.generate("test"))


    import asyncio

    asyncio.run(main())
4  server/.gitignore  vendored
@@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
gen-server:
	# Compile protos
	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py

install-transformers:
	# Install specific version of transformers with custom cuda kernels
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
	pip install -e . --no-cache-dir

run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
@@ -1,11 +1,11 @@
[tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
version = "0.3.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'

[tool.poetry.dependencies]
python = "^3.9"
@@ -1,6 +1,6 @@
import pytest

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2


@pytest.fixture
@@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM


@pytest.fixture(scope="session")
@@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch


@pytest.fixture(scope="session")
@@ -1,8 +1,8 @@
import pytest

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder


@pytest.fixture(scope="session")
@@ -5,8 +5,8 @@ from copy import copy

from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


@pytest.fixture(scope="session")
@@ -1,6 +1,10 @@
-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)

-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.convert import convert_files


def test_convert_files():
@@ -1,6 +1,6 @@
import pytest

-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,
@@ -1,4 +1,4 @@
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,
@@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar

-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch

B = TypeVar("B", bound=Batch)

@@ -6,8 +6,8 @@ from pathlib import Path
from loguru import logger
from typing import Optional

-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing

app = typer.Typer()

@@ -42,7 +42,7 @@ def serve(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
@@ -68,7 +68,7 @@ def download_weights(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
@@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional

-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded

__all__ = [
    "Model",
@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,
@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)

@@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM

-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
@@ -6,8 +6,8 @@ from typing import List, Optional

from transformers import PreTrainedTokenizerBase

-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason


class Batch(ABC):
@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
@@ -11,7 +11,7 @@ from transformers import (
)
from typing import List, Tuple, Optional

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor
