OlivierDehaene 2023-03-07 13:33:38 +01:00
parent 6d1ad06471
commit feddbbc998
44 changed files with 362 additions and 214 deletions

View File

@@ -0,0 +1 @@
# Text Generation

View File

@@ -1,5 +1,5 @@
[tool.poetry]
-name = "text_generation_inference"
+name = "text-generation"
version = "0.3.2"
description = "Text Generation Inference Python Client"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -0,0 +1,2 @@
from text_generation.client import Client, AsyncClient
from text_generation.api_inference import APIInferenceAsyncClient

View File

@@ -0,0 +1,64 @@
import os
import requests
import base64
import json
import warnings

from typing import List, Optional

from text_generation import Client, AsyncClient
from text_generation.errors import NotSupportedError

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)

SUPPORTED_MODELS = None


def get_supported_models() -> Optional[List[str]]:
    global SUPPORTED_MODELS
    if SUPPORTED_MODELS is not None:
        return SUPPORTED_MODELS

    response = requests.get(
        "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
        timeout=5,
    )
    if response.status_code == 200:
        file_content = response.json()["content"]
        SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
        return SUPPORTED_MODELS

    warnings.warn("Could not retrieve list of supported models.")
    return None


class APIInferenceClient(Client):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        supported_models = get_supported_models()
        if supported_models is not None and model_id not in supported_models:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceClient, self).__init__(base_url, headers, timeout)


class APIInferenceAsyncClient(AsyncClient):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        supported_models = get_supported_models()
        if supported_models is not None and model_id not in supported_models:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
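
For orientation, a minimal usage sketch of the async API Inference client defined above; the model id, prompt, and token handling are illustrative assumptions, not part of this commit:

import asyncio

from text_generation import APIInferenceAsyncClient


async def main():
    # Hypothetical model id; it must appear in supported_models.json for the check above to pass
    client = APIInferenceAsyncClient("bigscience/bloomz", token=None)

    # Stream responses as tokens are produced
    async for stream_response in client.generate_stream("What is Deep Learning?"):
        print(stream_response)

    # Or wait for the complete generation
    print(await client.generate("What is Deep Learning?"))


asyncio.run(main())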

View File

@@ -0,0 +1,210 @@
import json
import requests

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator

from text_generation.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation.errors import parse_error


class Client:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = timeout

    def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            timeout=self.timeout,
        )
        payload = resp.json()
        if resp.status_code != 200:
            raise parse_error(resp.status_code, payload)
        return Response(**payload[0])

    def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            timeout=self.timeout,
            stream=True,
        )

        if resp.status_code != 200:
            raise parse_error(resp.status_code, resp.json())

        for byte_payload in resp.iter_lines():
            if byte_payload == b"\n":
                continue

            payload = byte_payload.decode("utf-8")

            if payload.startswith("data:"):
                json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
                try:
                    response = StreamResponse(**json_payload)
                except ValidationError:
                    raise parse_error(resp.status_code, json_payload)
                yield response


class AsyncClient:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = ClientTimeout(timeout * 60)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)
                        yield response
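
As a quick orientation for the new Client above, a minimal synchronous usage sketch; the local endpoint URL, prompt, and the generated_text/token field names are assumptions based on the accompanying types module, not values taken from this commit:

from text_generation import Client

# Assumes a text-generation-inference server is listening locally on port 8080
client = Client("http://127.0.0.1:8080")

# Blocking call returning a full Response (details are included because details=True is always set)
response = client.generate("What is Deep Learning?", max_new_tokens=17)
print(response.generated_text)

# Streaming call yielding one StreamResponse per generated token
for stream_response in client.generate_stream("What is Deep Learning?", max_new_tokens=17):
    print(stream_response.token)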

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List

-from text_generation_inference.errors import ValidationError
+from text_generation.errors import ValidationError


class Parameters(BaseModel):

View File

@@ -1 +0,0 @@
from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient

View File

@@ -1,142 +0,0 @@
import json
import os

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator

from text_generation_inference import SUPPORTED_MODELS
from text_generation_inference.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
)
from text_generation_inference.errors import parse_error, NotSupportedError

INFERENCE_ENDPOINT = os.environ.get(
    "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)


class AsyncClient:
    def __init__(
        self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
    ):
        self.base_url = base_url
        self.headers = headers
        self.timeout = ClientTimeout(timeout * 60)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> Response:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()
                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        parameters = Parameters(
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop if stop is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                async for byte_payload in resp.content:
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    if payload.startswith("data:"):
                        json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            raise parse_error(resp.status, json_payload)
                        yield response


class APIInferenceAsyncClient(AsyncClient):
    def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
        # Text Generation Inference client only supports a subset of the available hub models
        if model_id not in SUPPORTED_MODELS:
            raise NotSupportedError(model_id)

        headers = {}
        if token is not None:
            headers = {"Authorization": f"Bearer {token}"}
        base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"

        super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)


if __name__ == "__main__":
    async def main():
        client = APIInferenceAsyncClient(
            "bigscience/bloomz", token="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"
        )
        async for token in client.generate_stream("test"):
            print(token)

        print(await client.generate("test"))

    import asyncio

    asyncio.run(main())

server/.gitignore
View File

@@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
-text_generation/__pycache__/
-text_generation/pb/__pycache__/
+text_generation_server/__pycache__/
+text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class

View File

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
gen-server:
	# Compile protos
	pip install grpcio-tools==1.51.1 --no-cache-dir
-	mkdir text_generation/pb || true
-	python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
-	find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
-	touch text_generation/pb/__init__.py
+	mkdir text_generation_server/pb || true
+	python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
+	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
+	touch text_generation_server/pb/__init__.py

install-transformers:
	# Install specific version of transformers with custom cuda kernels
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
	pip install -e . --no-cache-dir

run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

View File

@@ -1,11 +1,11 @@
[tool.poetry]
-name = "text-generation"
+name = "text-generation-server"
version = "0.3.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

[tool.poetry.scripts]
-text-generation-server = 'text_generation.cli:app'
+text-generation-server = 'text_generation_server.cli:app'

[tool.poetry.dependencies]
python = "^3.9"

View File

@@ -1,6 +1,6 @@
import pytest

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2


@pytest.fixture

View File

@@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM


@pytest.fixture(scope="session")

View File

@@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch


@pytest.fixture(scope="session")

View File

@@ -1,8 +1,8 @@
import pytest

-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.models.santacoder import SantaCoder
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.models.santacoder import SantaCoder


@pytest.fixture(scope="session")

View File

@@ -5,8 +5,8 @@ from copy import copy
from transformers import AutoTokenizer

-from text_generation.pb import generate_pb2
-from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


@pytest.fixture(scope="session")

View File

@@ -1,6 +1,10 @@
-from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
-from text_generation.utils.convert import convert_files
+from text_generation_server.utils.hub import (
+    download_weights,
+    weight_hub_files,
+    weight_files,
+)
+from text_generation_server.utils.convert import convert_files


def test_convert_files():

View File

@@ -1,6 +1,6 @@
import pytest

-from text_generation.utils.hub import (
+from text_generation_server.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,

View File

@@ -1,4 +1,4 @@
-from text_generation.utils.tokens import (
+from text_generation_server.utils.tokens import (
    StopSequenceCriteria,
    StoppingCriteria,
    FinishReason,

View File

@@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar

-from text_generation.models.types import Batch
+from text_generation_server.models.types import Batch

B = TypeVar("B", bound=Batch)

View File

@@ -6,8 +6,8 @@ from pathlib import Path
from loguru import logger
from typing import Optional

-from text_generation import server, utils
-from text_generation.tracing import setup_tracing
+from text_generation_server import server, utils
+from text_generation_server.tracing import setup_tracing

app = typer.Typer()
@@ -42,7 +42,7 @@ def serve(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,
@@ -68,7 +68,7 @@ def download_weights(
    logger.add(
        sys.stdout,
        format="{message}",
-        filter="text_generation",
+        filter="text_generation_server",
        level=logger_level,
        serialize=json_output,
        backtrace=True,

View File

@@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional

-from text_generation.models.model import Model
-from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOM, BLOOMSharded
-from text_generation.models.seq2seq_lm import Seq2SeqLM
-from text_generation.models.galactica import Galactica, GalacticaSharded
-from text_generation.models.santacoder import SantaCoder
-from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
-from text_generation.models.t5 import T5Sharded
+from text_generation_server.models.model import Model
+from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.bloom import BLOOM, BLOOMSharded
+from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.galactica import Galactica, GalacticaSharded
+from text_generation_server.models.santacoder import SantaCoder
+from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
+from text_generation_server.models.t5 import T5Sharded

__all__ = [
    "Model",

View File

@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.pb import generate_pb2
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

View File

@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    Batch,
+    PrefillTokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)

View File

@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.pb import generate_pb2
-from text_generation.models.causal_lm import CausalLMBatch
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.pb import generate_pb2
+from text_generation_server.models.causal_lm import CausalLMBatch
+from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    initialize_torch_distributed,

View File

@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import CausalLM
-from text_generation.utils import (
+from text_generation_server.models import CausalLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase

-from text_generation.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, GeneratedText

B = TypeVar("B", bound=Batch)

View File

@@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM

-from text_generation.models import CausalLM
+from text_generation_server.models import CausalLM

FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"

View File

@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type

-from text_generation.models import Model
-from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
-from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.models import Model
+from text_generation_server.models.types import (
+    GeneratedText,
+    Batch,
+    Generation,
+    PrefillTokens,
+)
+from text_generation_server.pb import generate_pb2
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling

tracer = trace.get_tracer(__name__)
@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
    padding_right_offset: int

    def to_pb(self) -> generate_pb2.Batch:
-        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
+        """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "Seq2SeqLMBatch":
-        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
+        """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
        inputs = []
        next_token_choosers = []
        stopping_criterias = []

View File

@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
    TensorParallelRowLinear,
)

-from text_generation.models import Seq2SeqLM
-from text_generation.utils import (
+from text_generation_server.models import Seq2SeqLM
+from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

View File

@@ -6,8 +6,8 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase

-from text_generation.pb import generate_pb2
-from text_generation.pb.generate_pb2 import FinishReason
+from text_generation_server.pb import generate_pb2
+from text_generation_server.pb.generate_pb2 import FinishReason


class Batch(ABC):

View File

@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional

-from text_generation.cache import Cache
-from text_generation.interceptor import ExceptionInterceptor
-from text_generation.models import Model, get_model
-from text_generation.pb import generate_pb2_grpc, generate_pb2
-from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.cache import Cache
+from text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models import Model, get_model
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

View File

@@ -11,7 +11,7 @@ from transformers import (
)
from typing import List, Tuple, Optional

-from text_generation.pb import generate_pb2
+from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor