Mirror of https://github.com/huggingface/text-generation-inference.git

commit feddbbc998 (parent 6d1ad06471)

    poc
@@ -0,0 +1 @@
+# Text Generation

@@ -1,5 +1,5 @@
[tool.poetry]
-name = "text_generation_inference"
+name = "text-generation"
version = "0.3.2"
description = "Text Generation Inference Python Client"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
clients/python/text_generation/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from text_generation.client import Client, AsyncClient
+from text_generation.api_inference import APIInferenceAsyncClient
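Usage sketch (editorial, not part of the commit): after this rename the client package is imported as text_generation rather than text_generation_inference, and the names re-exported above are available at the package root.

    # Illustrative imports only; these mirror the __init__.py added above.
    from text_generation import Client, AsyncClient, APIInferenceAsyncClient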
clients/python/text_generation/api_inference.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import os
+import requests
+import base64
+import json
+import warnings
+
+from typing import List, Optional
+
+from text_generation import Client, AsyncClient
+from text_generation.errors import NotSupportedError
+
+INFERENCE_ENDPOINT = os.environ.get(
+    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
+)
+
+SUPPORTED_MODELS = None
+
+
+def get_supported_models() -> Optional[List[str]]:
+    global SUPPORTED_MODELS
+    if SUPPORTED_MODELS is not None:
+        return SUPPORTED_MODELS
+
+    response = requests.get(
+        "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
+        timeout=5,
+    )
+    if response.status_code == 200:
+        file_content = response.json()["content"]
+        SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
+        return SUPPORTED_MODELS
+
+    warnings.warn("Could not retrieve list of supported models.")
+    return None
+
+
+class APIInferenceClient(Client):
+    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
+        # Text Generation Inference client only supports a subset of the available hub models
+        supported_models = get_supported_models()
+        if supported_models is not None and model_id not in supported_models:
+            raise NotSupportedError(model_id)
+
+        headers = {}
+        if token is not None:
+            headers = {"Authorization": f"Bearer {token}"}
+        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
+
+        super(APIInferenceClient, self).__init__(base_url, headers, timeout)
+
+
+class APIInferenceAsyncClient(AsyncClient):
+    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
+        # Text Generation Inference client only supports a subset of the available hub models
+        supported_models = get_supported_models()
+        if supported_models is not None and model_id not in supported_models:
+            raise NotSupportedError(model_id)
+
+        headers = {}
+        if token is not None:
+            headers = {"Authorization": f"Bearer {token}"}
+        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
+
+        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
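Usage sketch (editorial, not part of the commit): how the API-inference wrappers above might be driven. The prompt and the token placeholder are assumptions for illustration; generate and generate_stream are inherited from the Client/AsyncClient base classes added in client.py below.

    import asyncio

    from text_generation import APIInferenceAsyncClient


    async def demo():
        # Placeholder token; a real Hugging Face API token would go here.
        client = APIInferenceAsyncClient("bigscience/bloomz", token="<hf_api_token>")

        # Stream one StreamResponse per generated token.
        async for token in client.generate_stream("What is Deep Learning?"):
            print(token)

        # Or block until the full Response is available.
        print(await client.generate("What is Deep Learning?"))


    asyncio.run(demo())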
clients/python/text_generation/client.py (new file, 210 lines)
@@ -0,0 +1,210 @@
+import json
+import requests
+
+from aiohttp import ClientSession, ClientTimeout
+from pydantic import ValidationError
+from typing import Dict, Optional, List, AsyncIterator, Iterator
+
+from text_generation.types import (
+    StreamResponse,
+    Response,
+    Request,
+    Parameters,
+)
+from text_generation.errors import parse_error
+
+
+class Client:
+    def __init__(
+        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
+    ):
+        self.base_url = base_url
+        self.headers = headers
+        self.timeout = timeout
+
+    def generate(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> Response:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=False, parameters=parameters)
+
+        resp = requests.post(
+            self.base_url,
+            json=request.dict(),
+            headers=self.headers,
+            timeout=self.timeout,
+        )
+        payload = resp.json()
+        if resp.status_code != 200:
+            raise parse_error(resp.status_code, payload)
+        return Response(**payload[0])
+
+    def generate_stream(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> Iterator[StreamResponse]:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=True, parameters=parameters)
+
+        resp = requests.post(
+            self.base_url,
+            json=request.dict(),
+            headers=self.headers,
+            timeout=self.timeout,
+            stream=True,
+        )
+
+        if resp.status_code != 200:
+            raise parse_error(resp.status_code, resp.json())
+
+        for byte_payload in resp.iter_lines():
+            if byte_payload == b"\n":
+                continue
+
+            payload = byte_payload.decode("utf-8")
+
+            if payload.startswith("data:"):
+                json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+                try:
+                    response = StreamResponse(**json_payload)
+                except ValidationError:
+                    raise parse_error(resp.status_code, json_payload)
+                yield response
+
+
+class AsyncClient:
+    def __init__(
+        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
+    ):
+        self.base_url = base_url
+        self.headers = headers
+        self.timeout = ClientTimeout(timeout * 60)
+
+    async def generate(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> Response:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=False, parameters=parameters)
+
+        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
+            async with session.post(self.base_url, json=request.dict()) as resp:
+                payload = await resp.json()
+                if resp.status != 200:
+                    raise parse_error(resp.status, payload)
+                return Response(**payload[0])
+
+    async def generate_stream(
+        self,
+        prompt: str,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        watermark: bool = False,
+    ) -> AsyncIterator[StreamResponse]:
+        parameters = Parameters(
+            details=True,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop if stop is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            watermark=watermark,
+        )
+        request = Request(inputs=prompt, stream=True, parameters=parameters)
+
+        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
+            async with session.post(self.base_url, json=request.dict()) as resp:
+                if resp.status != 200:
+                    raise parse_error(resp.status, await resp.json())
+
+                async for byte_payload in resp.content:
+                    if byte_payload == b"\n":
+                        continue
+
+                    payload = byte_payload.decode("utf-8")
+
+                    if payload.startswith("data:"):
+                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+                        try:
+                            response = StreamResponse(**json_payload)
+                        except ValidationError:
+                            raise parse_error(resp.status, json_payload)
+                        yield response
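Usage sketch (editorial, not part of the commit): the synchronous Client above pointed at a self-hosted text-generation-inference endpoint. The base URL is an assumption (use whatever URL your deployment exposes for this payload format), and the generated_text / token attributes are assumed from the Response and StreamResponse models in types.py, which this diff only shows in part.

    from text_generation import Client

    # Assumed local deployment; substitute your own endpoint.
    client = Client("http://127.0.0.1:8080")

    # Blocking request: a single Response parsed from the JSON payload.
    response = client.generate("What is Deep Learning?", max_new_tokens=17, temperature=0.5)
    print(response.generated_text)

    # Streaming request: one StreamResponse per server-sent "data:" event.
    for stream_response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
        print(stream_response.token)

One detail worth noting in the stream parsing above: str.lstrip("data:") and str.rstrip("/n") strip sets of characters rather than exact prefixes and suffixes, and "/n" in particular reads as if it were meant to be "\n".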
@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List

-from text_generation_inference.errors import ValidationError
+from text_generation.errors import ValidationError


class Parameters(BaseModel):

@@ -1 +0,0 @@
-from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient
@@ -1,142 +0,0 @@
-import json
-import os
-
-from aiohttp import ClientSession, ClientTimeout
-from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator
-
-from text_generation_inference import SUPPORTED_MODELS
-from text_generation_inference.types import (
-    StreamResponse,
-    Response,
-    Request,
-    Parameters,
-)
-from text_generation_inference.errors import parse_error, NotSupportedError
-
-INFERENCE_ENDPOINT = os.environ.get(
-    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
-)
-
-
-class AsyncClient:
-    def __init__(
-        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
-    ):
-        self.base_url = base_url
-        self.headers = headers
-        self.timeout = ClientTimeout(timeout * 60)
-
-    async def generate(
-        self,
-        prompt: str,
-        do_sample: bool = False,
-        max_new_tokens: int = 20,
-        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
-        seed: Optional[int] = None,
-        stop: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        watermark: bool = False,
-    ) -> Response:
-        parameters = Parameters(
-            details=True,
-            do_sample=do_sample,
-            max_new_tokens=max_new_tokens,
-            repetition_penalty=repetition_penalty,
-            return_full_text=return_full_text,
-            seed=seed,
-            stop=stop if stop is not None else [],
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            watermark=watermark,
-        )
-        request = Request(inputs=prompt, stream=False, parameters=parameters)
-
-        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
-            async with session.post(self.base_url, json=request.dict()) as resp:
-                payload = await resp.json()
-                if resp.status != 200:
-                    raise parse_error(resp.status, payload)
-                return Response(**payload[0])
-
-    async def generate_stream(
-        self,
-        prompt: str,
-        do_sample: bool = False,
-        max_new_tokens: int = 20,
-        repetition_penalty: Optional[float] = None,
-        return_full_text: bool = False,
-        seed: Optional[int] = None,
-        stop: Optional[List[str]] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        watermark: bool = False,
-    ) -> AsyncIterator[StreamResponse]:
-        parameters = Parameters(
-            details=True,
-            do_sample=do_sample,
-            max_new_tokens=max_new_tokens,
-            repetition_penalty=repetition_penalty,
-            return_full_text=return_full_text,
-            seed=seed,
-            stop=stop if stop is not None else [],
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            watermark=watermark,
-        )
-        request = Request(inputs=prompt, stream=True, parameters=parameters)
-
-        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
-            async with session.post(self.base_url, json=request.dict()) as resp:
-                if resp.status != 200:
-                    raise parse_error(resp.status, await resp.json())
-
-                async for byte_payload in resp.content:
-                    if byte_payload == b"\n":
-                        continue
-
-                    payload = byte_payload.decode("utf-8")
-
-                    if payload.startswith("data:"):
-                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
-                        try:
-                            response = StreamResponse(**json_payload)
-                        except ValidationError:
-                            raise parse_error(resp.status, json_payload)
-                        yield response
-
-
-class APIInferenceAsyncClient(AsyncClient):
-    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
-        # Text Generation Inference client only supports a subset of the available hub models
-        if model_id not in SUPPORTED_MODELS:
-            raise NotSupportedError(model_id)
-
-        headers = {}
-        if token is not None:
-            headers = {"Authorization": f"Bearer {token}"}
-        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
-
-        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
-
-
-if __name__ == "__main__":
-    async def main():
-        client = APIInferenceAsyncClient(
-            "bigscience/bloomz", token="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"
-        )
-        async for token in client.generate_stream("test"):
-            print(token)
-
-        print(await client.generate("test"))
-
-
-    import asyncio
-
-    asyncio.run(main())
server/.gitignore (vendored, 4 changed lines)
@@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
gen-server:
	# Compile protos
	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py

install-transformers:
	# Install specific version of transformers with custom cuda kernels

@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
	pip install -e . --no-cache-dir

run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
@@ -1,11 +1,11 @@
[tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
version = "0.3.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'

[tool.poetry.dependencies]
python = "^3.9"

@@ -1,6 +1,6 @@
import pytest

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2


@pytest.fixture
@@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM


@pytest.fixture(scope="session")

@@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch


@pytest.fixture(scope="session")

@@ -1,8 +1,8 @@
import pytest

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder


@pytest.fixture(scope="session")

@@ -5,8 +5,8 @@ from copy import copy

from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


@pytest.fixture(scope="session")
@@ -1,6 +1,10 @@
-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)

-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.convert import convert_files


def test_convert_files():

@@ -1,6 +1,6 @@
import pytest

-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,

@@ -1,4 +1,4 @@
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,
@@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar

-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch

B = TypeVar("B", bound=Batch)

@@ -6,8 +6,8 @@ from pathlib import Path
from loguru import logger
from typing import Optional

-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing

app = typer.Typer()

@@ -42,7 +42,7 @@ def serve(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,

@@ -68,7 +68,7 @@ def download_weights(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
@@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional

-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded

__all__ = [
    "Model",

@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,

@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)

@@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM

-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,

@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)
@@ -6,8 +6,8 @@ from typing import List, Optional

from transformers import PreTrainedTokenizerBase

-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason


class Batch(ABC):

@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

@@ -11,7 +11,7 @@ from transformers import (
)
from typing import List, Tuple, Optional

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor