OlivierDehaene 2023-03-07 13:33:38 +01:00
parent 6d1ad06471
commit feddbbc998
44 changed files with 362 additions and 214 deletions

View File

@@ -0,0 +1 @@
# Text Generation

View File

@@ -1,5 +1,5 @@
[tool.poetry]
name = "text_generation_inference"
name = "text-generation"
version = "0.3.2"
description = "Text Generation Inference Python Client"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

View File

@@ -0,0 +1,2 @@
from text_generation.client import Client, AsyncClient
from text_generation.api_inference import APIInferenceAsyncClient
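For a quick sanity check, the public surface declared above can be imported directly (a sketch, assuming the package is installed under the "text-generation" name set in the pyproject change above):

# Sketch: the three client classes re-exported by the new text_generation package.
from text_generation import Client, AsyncClient, APIInferenceAsyncClient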

View File

@@ -0,0 +1,64 @@
import os
import requests
import base64
import json
import warnings
from typing import List, Optional
from text_generation import Client, AsyncClient
from text_generation.errors import NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
SUPPORTED_MODELS = None
def get_supported_models() -> Optional[List[str]]:
global SUPPORTED_MODELS
if SUPPORTED_MODELS is not None:
return SUPPORTED_MODELS
response = requests.get(
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json",
timeout=5,
)
if response.status_code == 200:
file_content = response.json()["content"]
SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8"))
return SUPPORTED_MODELS
warnings.warn("Could not retrieve list of supported models.")
return None
class APIInferenceClient(Client):
def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and model_id not in supported_models:
raise NotSupportedError(model_id)
headers = {}
if token is not None:
headers = {"Authorization": f"Bearer {token}"}
base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
super(APIInferenceClient, self).__init__(base_url, headers, timeout)
class APIInferenceAsyncClient(AsyncClient):
def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
# Text Generation Inference client only supports a subset of the available hub models
supported_models = get_supported_models()
if supported_models is not None and model_id not in supported_models:
raise NotSupportedError(model_id)
headers = {}
if token is not None:
headers = {"Authorization": f"Bearer {token}"}
base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
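A minimal usage sketch for the async API Inference client defined above; the model id mirrors the example in the removed async_client.py further down, and the token value is a placeholder:

import asyncio

from text_generation import APIInferenceAsyncClient


async def main():
    # Placeholder token; the model must be listed in supported_models.json.
    client = APIInferenceAsyncClient("bigscience/bloomz", token="<hf_api_token>")

    # Stream tokens as they are generated.
    async for token in client.generate_stream("test"):
        print(token)

    # Or block until the full response (with generation details) is ready.
    print(await client.generate("test"))


asyncio.run(main())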

View File

@@ -0,0 +1,210 @@
import json
import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator
from text_generation.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation.errors import parse_error
class Client:
def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
):
self.base_url = base_url
self.headers = headers
self.timeout = timeout
def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return Response(**payload[0])
def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> Iterator[StreamResponse]:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
resp = requests.post(
self.base_url,
json=request.dict(),
headers=self.headers,
timeout=self.timeout,
stream=True,
)
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = StreamResponse(**json_payload)
except ValidationError:
raise parse_error(resp.status_code, json_payload)
yield response
class AsyncClient:
def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10
):
self.base_url = base_url
self.headers = headers
self.timeout = ClientTimeout(timeout * 60)
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> AsyncIterator[StreamResponse]:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = StreamResponse(**json_payload)
except ValidationError:
raise parse_error(resp.status, json_payload)
yield response
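The base Client can also be pointed at a self-hosted text-generation-inference endpoint; a sketch, with the URL being an assumption (use whatever address the server is actually listening on):

from text_generation import Client

# Assumed local endpoint for a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Blocking request; returns a parsed Response object.
print(client.generate("What is Deep Learning?", max_new_tokens=20))

# Streaming request; yields StreamResponse objects as tokens arrive.
for stream_response in client.generate_stream("What is Deep Learning?", max_new_tokens=20):
    print(stream_response)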

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List
from text_generation_inference.errors import ValidationError
from text_generation.errors import ValidationError
class Parameters(BaseModel):

View File

@@ -1 +0,0 @@
from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient

View File

@@ -1,142 +0,0 @@
import json
import os
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator
from text_generation_inference import SUPPORTED_MODELS
from text_generation_inference.types import (
StreamResponse,
Response,
Request,
Parameters,
)
from text_generation_inference.errors import parse_error, NotSupportedError
INFERENCE_ENDPOINT = os.environ.get(
"HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co"
)
class AsyncClient:
def __init__(
self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10
):
self.base_url = base_url
self.headers = headers
self.timeout = ClientTimeout(timeout * 60)
async def generate(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> Response:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return Response(**payload[0])
async def generate_stream(
self,
prompt: str,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
watermark: bool = False,
) -> AsyncIterator[StreamResponse]:
parameters = Parameters(
details=True,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop=stop if stop is not None else [],
temperature=temperature,
top_k=top_k,
top_p=top_p,
watermark=watermark,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200:
raise parse_error(resp.status, await resp.json())
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
try:
response = StreamResponse(**json_payload)
except ValidationError:
raise parse_error(resp.status, json_payload)
yield response
class APIInferenceAsyncClient(AsyncClient):
def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10):
# Text Generation Inference client only supports a subset of the available hub models
if model_id not in SUPPORTED_MODELS:
raise NotSupportedError(model_id)
headers = {}
if token is not None:
headers = {"Authorization": f"Bearer {token}"}
base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}"
super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout)
if __name__ == "__main__":
async def main():
client = APIInferenceAsyncClient(
"bigscience/bloomz", token="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"
)
async for token in client.generate_stream("test"):
print(token)
print(await client.generate("test"))
import asyncio
asyncio.run(main())

server/.gitignore
View File

@@ -1,7 +1,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
text_generation/__pycache__/
text_generation/pb/__pycache__/
text_generation_server/__pycache__/
text_generation_server/pb/__pycache__/
*.py[cod]
*$py.class

View File

@@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e
gen-server:
# Compile protos
pip install grpcio-tools==1.51.1 --no-cache-dir
mkdir text_generation/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto
find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation/pb/__init__.py
mkdir text_generation_server/pb || true
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
install-transformers:
# Install specific version of transformers with custom cuda kernels
@@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers
pip install -e . --no-cache-dir
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

View File

@@ -1,11 +1,11 @@
[tool.poetry]
name = "text-generation"
name = "text-generation-server"
version = "0.3.2"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
[tool.poetry.scripts]
text-generation-server = 'text_generation.cli:app'
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"

View File

@@ -1,6 +1,6 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation_server.pb import generate_pb2
@pytest.fixture

View File

@@ -4,9 +4,9 @@ import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
@pytest.fixture(scope="session")

View File

@@ -4,8 +4,8 @@ import torch
from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")

View File

@@ -1,8 +1,8 @@
import pytest
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.santacoder import SantaCoder
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.santacoder import SantaCoder
@pytest.fixture(scope="session")

View File

@@ -5,8 +5,8 @@ from copy import copy
from transformers import AutoTokenizer
from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
@pytest.fixture(scope="session")

View File

@@ -1,6 +1,10 @@
from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation_server.utils.hub import (
download_weights,
weight_hub_files,
weight_files,
)
from text_generation.utils.convert import convert_files
from text_generation_server.utils.convert import convert_files
def test_convert_files():

View File

@@ -1,6 +1,6 @@
import pytest
from text_generation.utils.hub import (
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
weight_files,

View File

@@ -1,4 +1,4 @@
from text_generation.utils.tokens import (
from text_generation_server.utils.tokens import (
StopSequenceCriteria,
StoppingCriteria,
FinishReason,

View File

@@ -1,6 +1,6 @@
from typing import Dict, Optional, TypeVar
from text_generation.models.types import Batch
from text_generation_server.models.types import Batch
B = TypeVar("B", bound=Batch)

View File

@@ -6,8 +6,8 @@ from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation import server, utils
from text_generation.tracing import setup_tracing
from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing
app = typer.Typer()
@@ -42,7 +42,7 @@ def serve(
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
@@ -68,7 +68,7 @@ def download_weights(
logger.add(
sys.stdout,
format="{message}",
filter="text_generation",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,

View File

@@ -3,14 +3,14 @@ import torch
from transformers import AutoConfig
from typing import Optional
from text_generation.models.model import Model
from text_generation.models.causal_lm import CausalLM
from text_generation.models.bloom import BLOOM, BLOOMSharded
from text_generation.models.seq2seq_lm import Seq2SeqLM
from text_generation.models.galactica import Galactica, GalacticaSharded
from text_generation.models.santacoder import SantaCoder
from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded
from text_generation.models.t5 import T5Sharded
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
__all__ = [
"Model",

View File

@@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.pb import generate_pb2
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)

View File

@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models import Model
from text_generation_server.models.types import (
Batch,
PrefillTokens,
Generation,
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)

View File

@@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,

View File

@@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import CausalLM
from text_generation.utils import (
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from text_generation.models.types import Batch, GeneratedText
from text_generation_server.models.types import Batch, GeneratedText
B = TypeVar("B", bound=Batch)

View File

@@ -4,7 +4,7 @@ import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation.models import CausalLM
from text_generation_server.models import CausalLM
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"

View File

@@ -5,10 +5,15 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type
from text_generation.models import Model
from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling
from text_generation_server.models import Model
from text_generation_server.models.types import (
GeneratedText,
Batch,
Generation,
PrefillTokens,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
tracer = trace.get_tracer(__name__)
@@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch):
padding_right_offset: int
def to_pb(self) -> generate_pb2.Batch:
"""Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
"""Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf"""
return generate_pb2.Batch(
id=self.batch_id,
requests=self.requests,
@@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch):
tokenizer: PreTrainedTokenizerBase,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
inputs = []
next_token_choosers = []
stopping_criterias = []

View File

@@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
)
from text_generation.models import Seq2SeqLM
from text_generation.utils import (
from text_generation_server.models import Seq2SeqLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
)

View File

@@ -6,8 +6,8 @@ from typing import List, Optional
from transformers import PreTrainedTokenizerBase
from text_generation.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
class Batch(ABC):

View File

@@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation.cache import Cache
from text_generation.interceptor import ExceptionInterceptor
from text_generation.models import Model, get_model
from text_generation.pb import generate_pb2_grpc, generate_pb2
from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

View File

@@ -11,7 +11,7 @@ from transformers import (
)
from typing import List, Tuple, Optional
from text_generation.pb import generate_pb2
from text_generation_server.pb import generate_pb2
from text_generation.pb.generate_pb2 import FinishReason
from text_generation.utils.watermark import WatermarkLogitsProcessor