From feddbbc998ba39906b10e801e06d138faab23148 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 7 Mar 2023 13:33:38 +0100 Subject: [PATCH] poc --- clients/python/README.md | 1 + clients/python/pyproject.toml | 2 +- clients/python/text_generation/__init__.py | 2 + .../python/text_generation/api_inference.py | 64 ++++++ clients/python/text_generation/client.py | 210 ++++++++++++++++++ .../errors.py | 0 .../types.py | 2 +- .../text_generation_inference/__init__.py | 1 - .../text_generation_inference/async_client.py | 142 ------------ .../text_generation_inference/client.py | 0 server/.gitignore | 4 +- server/Makefile | 10 +- server/pyproject.toml | 4 +- server/tests/conftest.py | 2 +- server/tests/models/test_bloom.py | 6 +- server/tests/models/test_causal_lm.py | 4 +- server/tests/models/test_santacoder.py | 6 +- server/tests/models/test_seq2seq_lm.py | 4 +- server/tests/utils/test_convert.py | 8 +- server/tests/utils/test_hub.py | 2 +- server/tests/utils/test_tokens.py | 2 +- .../__init__.py | 0 .../cache.py | 2 +- .../cli.py | 8 +- .../interceptor.py | 0 .../models/__init__.py | 16 +- .../models/bloom.py | 8 +- .../models/causal_lm.py | 13 +- .../models/galactica.py | 8 +- .../models/gpt_neox.py | 4 +- .../models/model.py | 2 +- .../models/santacoder.py | 2 +- .../models/seq2seq_lm.py | 17 +- .../models/t5.py | 4 +- .../models/types.py | 4 +- .../pb/.gitignore | 0 .../server.py | 10 +- .../tracing.py | 0 .../utils/__init__.py | 0 .../utils/convert.py | 0 .../utils/dist.py | 0 .../utils/hub.py | 0 .../utils/tokens.py | 2 +- .../utils/watermark.py | 0 44 files changed, 362 insertions(+), 214 deletions(-) create mode 100644 clients/python/text_generation/__init__.py create mode 100644 clients/python/text_generation/api_inference.py create mode 100644 clients/python/text_generation/client.py rename clients/python/{text_generation_inference => text_generation}/errors.py (100%) rename clients/python/{text_generation_inference => text_generation}/types.py (97%) delete mode 100644 clients/python/text_generation_inference/__init__.py delete mode 100644 clients/python/text_generation_inference/async_client.py delete mode 100644 clients/python/text_generation_inference/client.py rename server/{text_generation => text_generation_server}/__init__.py (100%) rename server/{text_generation => text_generation_server}/cache.py (90%) rename server/{text_generation => text_generation_server}/cli.py (94%) rename server/{text_generation => text_generation_server}/interceptor.py (100%) rename server/{text_generation => text_generation_server}/models/__init__.py (79%) rename server/{text_generation => text_generation_server}/models/bloom.py (97%) rename server/{text_generation => text_generation_server}/models/causal_lm.py (98%) rename server/{text_generation => text_generation_server}/models/galactica.py (98%) rename server/{text_generation => text_generation_server}/models/gpt_neox.py (99%) rename server/{text_generation => text_generation_server}/models/model.py (95%) rename server/{text_generation => text_generation_server}/models/santacoder.py (97%) rename server/{text_generation => text_generation_server}/models/seq2seq_lm.py (97%) rename server/{text_generation => text_generation_server}/models/t5.py (98%) rename server/{text_generation => text_generation_server}/models/types.py (95%) rename server/{text_generation => text_generation_server}/pb/.gitignore (100%) rename server/{text_generation => text_generation_server}/server.py (92%) rename 
server/{text_generation => text_generation_server}/tracing.py (100%) rename server/{text_generation => text_generation_server}/utils/__init__.py (100%) rename server/{text_generation => text_generation_server}/utils/convert.py (100%) rename server/{text_generation => text_generation_server}/utils/dist.py (100%) rename server/{text_generation => text_generation_server}/utils/hub.py (100%) rename server/{text_generation => text_generation_server}/utils/tokens.py (98%) rename server/{text_generation => text_generation_server}/utils/watermark.py (100%) diff --git a/clients/python/README.md b/clients/python/README.md index e69de29b..1ae11928 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -0,0 +1 @@ +# Text Generation \ No newline at end of file diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index db124125..226ab0b7 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "text_generation_inference" +name = "text-generation" version = "0.3.2" description = "Text Generation Inference Python Client" authors = ["Olivier Dehaene "] diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py new file mode 100644 index 00000000..493547a2 --- /dev/null +++ b/clients/python/text_generation/__init__.py @@ -0,0 +1,2 @@ +from text_generation.client import Client, AsyncClient +from text_generation.api_inference import APIInferenceAsyncClient diff --git a/clients/python/text_generation/api_inference.py b/clients/python/text_generation/api_inference.py new file mode 100644 index 00000000..ce7937d5 --- /dev/null +++ b/clients/python/text_generation/api_inference.py @@ -0,0 +1,64 @@ +import os +import requests +import base64 +import json +import warnings + +from typing import List, Optional + +from text_generation import Client, AsyncClient +from text_generation.errors import NotSupportedError + +INFERENCE_ENDPOINT = os.environ.get( + "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co" +) + +SUPPORTED_MODELS = None + + +def get_supported_models() -> Optional[List[str]]: + global SUPPORTED_MODELS + if SUPPORTED_MODELS is not None: + return SUPPORTED_MODELS + + response = requests.get( + "https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json", + timeout=5, + ) + if response.status_code == 200: + file_content = response.json()["content"] + SUPPORTED_MODELS = json.loads(base64.b64decode(file_content).decode("utf-8")) + return SUPPORTED_MODELS + + warnings.warn("Could not retrieve list of supported models.") + return None + + +class APIInferenceClient(Client): + def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10): + # Text Generation Inference client only supports a subset of the available hub models + supported_models = get_supported_models() + if supported_models is not None and model_id not in supported_models: + raise NotSupportedError(model_id) + + headers = {} + if token is not None: + headers = {"Authorization": f"Bearer {token}"} + base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}" + + super(APIInferenceClient, self).__init__(base_url, headers, timeout) + + +class APIInferenceAsyncClient(AsyncClient): + def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10): + # Text Generation Inference client only supports a subset of the available hub models + supported_models = get_supported_models() + if supported_models is not None and model_id not in 
supported_models: + raise NotSupportedError(model_id) + + headers = {} + if token is not None: + headers = {"Authorization": f"Bearer {token}"} + base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}" + + super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout) diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py new file mode 100644 index 00000000..dbfd5dce --- /dev/null +++ b/clients/python/text_generation/client.py @@ -0,0 +1,210 @@ +import json +import requests + +from aiohttp import ClientSession, ClientTimeout +from pydantic import ValidationError +from typing import Dict, Optional, List, AsyncIterator, Iterator + +from text_generation.types import ( + StreamResponse, + Response, + Request, + Parameters, +) +from text_generation.errors import parse_error + + +class Client: + def __init__( + self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10 + ): + self.base_url = base_url + self.headers = headers + self.timeout = timeout + + def generate( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + watermark: bool = False, + ) -> Response: + parameters = Parameters( + details=True, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop if stop is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + watermark=watermark, + ) + request = Request(inputs=prompt, stream=False, parameters=parameters) + + resp = requests.post( + self.base_url, + json=request.dict(), + headers=self.headers, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return Response(**payload[0]) + + def generate_stream( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + watermark: bool = False, + ) -> Iterator[StreamResponse]: + parameters = Parameters( + details=True, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop if stop is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + watermark=watermark, + ) + request = Request(inputs=prompt, stream=True, parameters=parameters) + + resp = requests.post( + self.base_url, + json=request.dict(), + headers=self.headers, + timeout=self.timeout, + stream=True, + ) + + if resp.status_code != 200: + raise parse_error(resp.status_code, resp.json()) + + for byte_payload in resp.iter_lines(): + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + try: + response = StreamResponse(**json_payload) + except ValidationError: + raise parse_error(resp.status_code, json_payload) + yield response + + +class AsyncClient: + def __init__( + self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10 + ): + 
self.base_url = base_url + self.headers = headers + self.timeout = ClientTimeout(timeout * 60) + + async def generate( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + watermark: bool = False, + ) -> Response: + parameters = Parameters( + details=True, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop if stop is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + watermark=watermark, + ) + request = Request(inputs=prompt, stream=False, parameters=parameters) + + async with ClientSession(headers=self.headers, timeout=self.timeout) as session: + async with session.post(self.base_url, json=request.dict()) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return Response(**payload[0]) + + async def generate_stream( + self, + prompt: str, + do_sample: bool = False, + max_new_tokens: int = 20, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + watermark: bool = False, + ) -> AsyncIterator[StreamResponse]: + parameters = Parameters( + details=True, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop if stop is not None else [], + temperature=temperature, + top_k=top_k, + top_p=top_p, + watermark=watermark, + ) + request = Request(inputs=prompt, stream=True, parameters=parameters) + + async with ClientSession(headers=self.headers, timeout=self.timeout) as session: + async with session.post(self.base_url, json=request.dict()) as resp: + if resp.status != 200: + raise parse_error(resp.status, await resp.json()) + + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + try: + response = StreamResponse(**json_payload) + except ValidationError: + raise parse_error(resp.status, json_payload) + yield response diff --git a/clients/python/text_generation_inference/errors.py b/clients/python/text_generation/errors.py similarity index 100% rename from clients/python/text_generation_inference/errors.py rename to clients/python/text_generation/errors.py diff --git a/clients/python/text_generation_inference/types.py b/clients/python/text_generation/types.py similarity index 97% rename from clients/python/text_generation_inference/types.py rename to clients/python/text_generation/types.py index 4ce65f29..ef5b379a 100644 --- a/clients/python/text_generation_inference/types.py +++ b/clients/python/text_generation/types.py @@ -2,7 +2,7 @@ from enum import Enum from pydantic import BaseModel, validator from typing import Optional, List -from text_generation_inference.errors import ValidationError +from text_generation.errors import ValidationError class Parameters(BaseModel): diff --git a/clients/python/text_generation_inference/__init__.py 
b/clients/python/text_generation_inference/__init__.py deleted file mode 100644 index 874227e2..00000000 --- a/clients/python/text_generation_inference/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from text_generation_inference.async_client import AsyncClient, APIInferenceAsyncClient diff --git a/clients/python/text_generation_inference/async_client.py b/clients/python/text_generation_inference/async_client.py deleted file mode 100644 index db106455..00000000 --- a/clients/python/text_generation_inference/async_client.py +++ /dev/null @@ -1,142 +0,0 @@ -import json -import os - -from aiohttp import ClientSession, ClientTimeout -from pydantic import ValidationError -from typing import Dict, Optional, List, AsyncIterator - -from text_generation_inference import SUPPORTED_MODELS -from text_generation_inference.types import ( - StreamResponse, - Response, - Request, - Parameters, -) -from text_generation_inference.errors import parse_error, NotSupportedError - -INFERENCE_ENDPOINT = os.environ.get( - "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co" -) - - -class AsyncClient: - def __init__( - self, base_url: str, headers: Dict[str, str] = None, timeout: int = 10 - ): - self.base_url = base_url - self.headers = headers - self.timeout = ClientTimeout(timeout * 60) - - async def generate( - self, - prompt: str, - do_sample: bool = False, - max_new_tokens: int = 20, - repetition_penalty: Optional[float] = None, - return_full_text: bool = False, - seed: Optional[int] = None, - stop: Optional[List[str]] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - watermark: bool = False, - ) -> Response: - parameters = Parameters( - details=True, - do_sample=do_sample, - max_new_tokens=max_new_tokens, - repetition_penalty=repetition_penalty, - return_full_text=return_full_text, - seed=seed, - stop=stop if stop is not None else [], - temperature=temperature, - top_k=top_k, - top_p=top_p, - watermark=watermark, - ) - request = Request(inputs=prompt, stream=False, parameters=parameters) - - async with ClientSession(headers=self.headers, timeout=self.timeout) as session: - async with session.post(self.base_url, json=request.dict()) as resp: - payload = await resp.json() - if resp.status != 200: - raise parse_error(resp.status, payload) - return Response(**payload[0]) - - async def generate_stream( - self, - prompt: str, - do_sample: bool = False, - max_new_tokens: int = 20, - repetition_penalty: Optional[float] = None, - return_full_text: bool = False, - seed: Optional[int] = None, - stop: Optional[List[str]] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - watermark: bool = False, - ) -> AsyncIterator[StreamResponse]: - parameters = Parameters( - details=True, - do_sample=do_sample, - max_new_tokens=max_new_tokens, - repetition_penalty=repetition_penalty, - return_full_text=return_full_text, - seed=seed, - stop=stop if stop is not None else [], - temperature=temperature, - top_k=top_k, - top_p=top_p, - watermark=watermark, - ) - request = Request(inputs=prompt, stream=True, parameters=parameters) - - async with ClientSession(headers=self.headers, timeout=self.timeout) as session: - async with session.post(self.base_url, json=request.dict()) as resp: - if resp.status != 200: - raise parse_error(resp.status, await resp.json()) - - async for byte_payload in resp.content: - if byte_payload == b"\n": - continue - - payload = byte_payload.decode("utf-8") - - if 
payload.startswith("data:"): - json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) - try: - response = StreamResponse(**json_payload) - except ValidationError: - raise parse_error(resp.status, json_payload) - yield response - - -class APIInferenceAsyncClient(AsyncClient): - def __init__(self, model_id: str, token: Optional[str] = None, timeout: int = 10): - # Text Generation Inference client only supports a subset of the available hub models - if model_id not in SUPPORTED_MODELS: - raise NotSupportedError(model_id) - - headers = {} - if token is not None: - headers = {"Authorization": f"Bearer {token}"} - base_url = f"{INFERENCE_ENDPOINT}/models/{model_id}" - - super(APIInferenceAsyncClient, self).__init__(base_url, headers, timeout) - - -if __name__ == "__main__": - async def main(): - client = APIInferenceAsyncClient( - "bigscience/bloomz", token="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" - ) - async for token in client.generate_stream("test"): - print(token) - - print(await client.generate("test")) - - - import asyncio - - asyncio.run(main()) diff --git a/clients/python/text_generation_inference/client.py b/clients/python/text_generation_inference/client.py deleted file mode 100644 index e69de29b..00000000 diff --git a/server/.gitignore b/server/.gitignore index 5758ba92..aef74bb4 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -1,7 +1,7 @@ # Byte-compiled / optimized / DLL files __pycache__/ -text_generation/__pycache__/ -text_generation/pb/__pycache__/ +text_generation_server/__pycache__/ +text_generation_server/pb/__pycache__/ *.py[cod] *$py.class diff --git a/server/Makefile b/server/Makefile index 3926cf5d..3de13e8e 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,10 +3,10 @@ transformers_commit := 2f87dca1ca3e5663d0637da9bb037a6956e57a5e gen-server: # Compile protos pip install grpcio-tools==1.51.1 --no-cache-dir - mkdir text_generation/pb || true - python -m grpc_tools.protoc -I../proto --python_out=text_generation/pb --grpc_python_out=text_generation/pb ../proto/generate.proto - find text_generation/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; - touch text_generation/pb/__init__.py + mkdir text_generation_server/pb || true + python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb --grpc_python_out=text_generation_server/pb ../proto/generate.proto + find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \; + touch text_generation_server/pb/__init__.py install-transformers: # Install specific version of transformers with custom cuda kernels @@ -28,4 +28,4 @@ install: gen-server install-torch install-transformers pip install -e . 
--no-cache-dir run-dev: - SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file + SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded \ No newline at end of file diff --git a/server/pyproject.toml b/server/pyproject.toml index 21277939..b8e188b7 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -1,11 +1,11 @@ [tool.poetry] -name = "text-generation" +name = "text-generation-server" version = "0.3.2" description = "Text Generation Inference Python gRPC Server" authors = ["Olivier Dehaene "] [tool.poetry.scripts] -text-generation-server = 'text_generation.cli:app' +text-generation-server = 'text_generation_server.cli:app' [tool.poetry.dependencies] python = "^3.9" diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 9fae8ee1..04c909ef 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from text_generation.pb import generate_pb2 +from text_generation_server.pb import generate_pb2 @pytest.fixture diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index bc36276a..90239f95 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -4,9 +4,9 @@ import torch from copy import copy from transformers import AutoTokenizer -from text_generation.pb import generate_pb2 -from text_generation.models.causal_lm import CausalLMBatch -from text_generation.models.bloom import BloomCausalLMBatch, BLOOM +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM @pytest.fixture(scope="session") diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 723017cd..869022fa 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -4,8 +4,8 @@ import torch from copy import copy from transformers import AutoTokenizer -from text_generation.pb import generate_pb2 -from text_generation.models.causal_lm import CausalLM, CausalLMBatch +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch @pytest.fixture(scope="session") diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index 1596e413..d089def3 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -1,8 +1,8 @@ import pytest -from text_generation.pb import generate_pb2 -from text_generation.models.causal_lm import CausalLMBatch -from text_generation.models.santacoder import SantaCoder +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.models.santacoder import SantaCoder @pytest.fixture(scope="session") diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index f7173392..764f8f83 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -5,8 +5,8 @@ from copy import copy from transformers import AutoTokenizer -from text_generation.pb import generate_pb2 -from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch +from text_generation_server.pb import generate_pb2 +from 
text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch @pytest.fixture(scope="session") diff --git a/server/tests/utils/test_convert.py b/server/tests/utils/test_convert.py index 5f284be5..7dfe6a1e 100644 --- a/server/tests/utils/test_convert.py +++ b/server/tests/utils/test_convert.py @@ -1,6 +1,10 @@ -from text_generation.utils.hub import download_weights, weight_hub_files, weight_files +from text_generation_server.utils.hub import ( + download_weights, + weight_hub_files, + weight_files, +) -from text_generation.utils.convert import convert_files +from text_generation_server.utils.convert import convert_files def test_convert_files(): diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index b3120160..fac9a64d 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -1,6 +1,6 @@ import pytest -from text_generation.utils.hub import ( +from text_generation_server.utils.hub import ( weight_hub_files, download_weights, weight_files, diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 7eca482f..3883ad97 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -1,4 +1,4 @@ -from text_generation.utils.tokens import ( +from text_generation_server.utils.tokens import ( StopSequenceCriteria, StoppingCriteria, FinishReason, diff --git a/server/text_generation/__init__.py b/server/text_generation_server/__init__.py similarity index 100% rename from server/text_generation/__init__.py rename to server/text_generation_server/__init__.py diff --git a/server/text_generation/cache.py b/server/text_generation_server/cache.py similarity index 90% rename from server/text_generation/cache.py rename to server/text_generation_server/cache.py index 5a3a8d31..72dc4857 100644 --- a/server/text_generation/cache.py +++ b/server/text_generation_server/cache.py @@ -1,6 +1,6 @@ from typing import Dict, Optional, TypeVar -from text_generation.models.types import Batch +from text_generation_server.models.types import Batch B = TypeVar("B", bound=Batch) diff --git a/server/text_generation/cli.py b/server/text_generation_server/cli.py similarity index 94% rename from server/text_generation/cli.py rename to server/text_generation_server/cli.py index 678dce16..6308ef6b 100644 --- a/server/text_generation/cli.py +++ b/server/text_generation_server/cli.py @@ -6,8 +6,8 @@ from pathlib import Path from loguru import logger from typing import Optional -from text_generation import server, utils -from text_generation.tracing import setup_tracing +from text_generation_server import server, utils +from text_generation_server.tracing import setup_tracing app = typer.Typer() @@ -42,7 +42,7 @@ def serve( logger.add( sys.stdout, format="{message}", - filter="text_generation", + filter="text_generation_server", level=logger_level, serialize=json_output, backtrace=True, @@ -68,7 +68,7 @@ def download_weights( logger.add( sys.stdout, format="{message}", - filter="text_generation", + filter="text_generation_server", level=logger_level, serialize=json_output, backtrace=True, diff --git a/server/text_generation/interceptor.py b/server/text_generation_server/interceptor.py similarity index 100% rename from server/text_generation/interceptor.py rename to server/text_generation_server/interceptor.py diff --git a/server/text_generation/models/__init__.py b/server/text_generation_server/models/__init__.py similarity index 79% rename from server/text_generation/models/__init__.py rename to 
server/text_generation_server/models/__init__.py index 386b7dc9..0035c1c6 100644 --- a/server/text_generation/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -3,14 +3,14 @@ import torch from transformers import AutoConfig from typing import Optional -from text_generation.models.model import Model -from text_generation.models.causal_lm import CausalLM -from text_generation.models.bloom import BLOOM, BLOOMSharded -from text_generation.models.seq2seq_lm import Seq2SeqLM -from text_generation.models.galactica import Galactica, GalacticaSharded -from text_generation.models.santacoder import SantaCoder -from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded -from text_generation.models.t5 import T5Sharded +from text_generation_server.models.model import Model +from text_generation_server.models.causal_lm import CausalLM +from text_generation_server.models.bloom import BLOOM, BLOOMSharded +from text_generation_server.models.seq2seq_lm import Seq2SeqLM +from text_generation_server.models.galactica import Galactica, GalacticaSharded +from text_generation_server.models.santacoder import SantaCoder +from text_generation_server.models.gpt_neox import GPTNeox, GPTNeoxSharded +from text_generation_server.models.t5 import T5Sharded __all__ = [ "Model", diff --git a/server/text_generation/models/bloom.py b/server/text_generation_server/models/bloom.py similarity index 97% rename from server/text_generation/models/bloom.py rename to server/text_generation_server/models/bloom.py index 83a0d63e..0d83abe2 100644 --- a/server/text_generation/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -17,10 +17,10 @@ from transformers.models.bloom.parallel_layers import ( TensorParallelRowLinear, ) -from text_generation.models import CausalLM -from text_generation.models.causal_lm import CausalLMBatch -from text_generation.pb import generate_pb2 -from text_generation.utils import ( +from text_generation_server.models import CausalLM +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( initialize_torch_distributed, weight_files, ) diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py similarity index 98% rename from server/text_generation/models/causal_lm.py rename to server/text_generation_server/models/causal_lm.py index 23c94ddf..4cfc15b9 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -5,10 +5,15 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type -from text_generation.models import Model -from text_generation.models.types import Batch, PrefillTokens, Generation, GeneratedText -from text_generation.pb import generate_pb2 -from text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.models import Model +from text_generation_server.models.types import ( + Batch, + PrefillTokens, + Generation, + GeneratedText, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling tracer = trace.get_tracer(__name__) diff --git a/server/text_generation/models/galactica.py b/server/text_generation_server/models/galactica.py similarity index 98% rename from server/text_generation/models/galactica.py rename 
to server/text_generation_server/models/galactica.py index 9a71c5d3..b4d1c553 100644 --- a/server/text_generation/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -18,10 +18,10 @@ from transformers.models.opt.parallel_layers import ( TensorParallelRowLinear, ) -from text_generation.models import CausalLM -from text_generation.pb import generate_pb2 -from text_generation.models.causal_lm import CausalLMBatch -from text_generation.utils import ( +from text_generation_server.models import CausalLM +from text_generation_server.pb import generate_pb2 +from text_generation_server.models.causal_lm import CausalLMBatch +from text_generation_server.utils import ( NextTokenChooser, StoppingCriteria, initialize_torch_distributed, diff --git a/server/text_generation/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py similarity index 99% rename from server/text_generation/models/gpt_neox.py rename to server/text_generation_server/models/gpt_neox.py index 0197f976..5e1960f4 100644 --- a/server/text_generation/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -16,8 +16,8 @@ from transformers.models.gpt_neox.parallel_layers import ( TensorParallelRowLinear, ) -from text_generation.models import CausalLM -from text_generation.utils import ( +from text_generation_server.models import CausalLM +from text_generation_server.utils import ( initialize_torch_distributed, weight_files, ) diff --git a/server/text_generation/models/model.py b/server/text_generation_server/models/model.py similarity index 95% rename from server/text_generation/models/model.py rename to server/text_generation_server/models/model.py index 09fa6a2a..e0ce6686 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation_server/models/model.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type from transformers import PreTrainedTokenizerBase -from text_generation.models.types import Batch, GeneratedText +from text_generation_server.models.types import Batch, GeneratedText B = TypeVar("B", bound=Batch) diff --git a/server/text_generation/models/santacoder.py b/server/text_generation_server/models/santacoder.py similarity index 97% rename from server/text_generation/models/santacoder.py rename to server/text_generation_server/models/santacoder.py index 5d271c85..fe15cde0 100644 --- a/server/text_generation/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -4,7 +4,7 @@ import torch.distributed from typing import Optional, List from transformers import AutoTokenizer, AutoModelForCausalLM -from text_generation.models import CausalLM +from text_generation_server.models import CausalLM FIM_PREFIX = "" FIM_MIDDLE = "" diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py similarity index 97% rename from server/text_generation/models/seq2seq_lm.py rename to server/text_generation_server/models/seq2seq_lm.py index 4b88baec..73a10879 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -5,10 +5,15 @@ from opentelemetry import trace from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type -from text_generation.models import Model -from text_generation.models.types import GeneratedText, Batch, Generation, PrefillTokens -from text_generation.pb import generate_pb2 -from 
text_generation.utils import NextTokenChooser, StoppingCriteria, Sampling +from text_generation_server.models import Model +from text_generation_server.models.types import ( + GeneratedText, + Batch, + Generation, + PrefillTokens, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling tracer = trace.get_tracer(__name__) @@ -45,7 +50,7 @@ class Seq2SeqLMBatch(Batch): padding_right_offset: int def to_pb(self) -> generate_pb2.Batch: - """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf""" + """Convert a Seq2SeqLMBatch to a text_generation_server.v1.Batch protobuf""" return generate_pb2.Batch( id=self.batch_id, requests=self.requests, @@ -59,7 +64,7 @@ class Seq2SeqLMBatch(Batch): tokenizer: PreTrainedTokenizerBase, device: torch.device, ) -> "Seq2SeqLMBatch": - """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch""" + """Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch""" inputs = [] next_token_choosers = [] stopping_criterias = [] diff --git a/server/text_generation/models/t5.py b/server/text_generation_server/models/t5.py similarity index 98% rename from server/text_generation/models/t5.py rename to server/text_generation_server/models/t5.py index 55507539..cb4f7f22 100644 --- a/server/text_generation/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -16,8 +16,8 @@ from transformers.models.t5.parallel_layers import ( TensorParallelRowLinear, ) -from text_generation.models import Seq2SeqLM -from text_generation.utils import ( +from text_generation_server.models import Seq2SeqLM +from text_generation_server.utils import ( initialize_torch_distributed, weight_files, ) diff --git a/server/text_generation/models/types.py b/server/text_generation_server/models/types.py similarity index 95% rename from server/text_generation/models/types.py rename to server/text_generation_server/models/types.py index a3fbd6e8..93c3b9db 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation_server/models/types.py @@ -6,8 +6,8 @@ from typing import List, Optional from transformers import PreTrainedTokenizerBase -from text_generation.pb import generate_pb2 -from text_generation.pb.generate_pb2 import FinishReason +from text_generation_server.pb import generate_pb2 +from text_generation_server.pb.generate_pb2 import FinishReason class Batch(ABC): diff --git a/server/text_generation/pb/.gitignore b/server/text_generation_server/pb/.gitignore similarity index 100% rename from server/text_generation/pb/.gitignore rename to server/text_generation_server/pb/.gitignore diff --git a/server/text_generation/server.py b/server/text_generation_server/server.py similarity index 92% rename from server/text_generation/server.py rename to server/text_generation_server/server.py index f3129cb4..0b75c3c7 100644 --- a/server/text_generation/server.py +++ b/server/text_generation_server/server.py @@ -9,11 +9,11 @@ from grpc_reflection.v1alpha import reflection from pathlib import Path from typing import List, Optional -from text_generation.cache import Cache -from text_generation.interceptor import ExceptionInterceptor -from text_generation.models import Model, get_model -from text_generation.pb import generate_pb2_grpc, generate_pb2 -from text_generation.tracing import UDSOpenTelemetryAioServerInterceptor +from text_generation_server.cache import Cache +from text_generation_server.interceptor import ExceptionInterceptor +from text_generation_server.models 
import Model, get_model +from text_generation_server.pb import generate_pb2_grpc, generate_pb2 +from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): diff --git a/server/text_generation/tracing.py b/server/text_generation_server/tracing.py similarity index 100% rename from server/text_generation/tracing.py rename to server/text_generation_server/tracing.py diff --git a/server/text_generation/utils/__init__.py b/server/text_generation_server/utils/__init__.py similarity index 100% rename from server/text_generation/utils/__init__.py rename to server/text_generation_server/utils/__init__.py diff --git a/server/text_generation/utils/convert.py b/server/text_generation_server/utils/convert.py similarity index 100% rename from server/text_generation/utils/convert.py rename to server/text_generation_server/utils/convert.py diff --git a/server/text_generation/utils/dist.py b/server/text_generation_server/utils/dist.py similarity index 100% rename from server/text_generation/utils/dist.py rename to server/text_generation_server/utils/dist.py diff --git a/server/text_generation/utils/hub.py b/server/text_generation_server/utils/hub.py similarity index 100% rename from server/text_generation/utils/hub.py rename to server/text_generation_server/utils/hub.py diff --git a/server/text_generation/utils/tokens.py b/server/text_generation_server/utils/tokens.py similarity index 98% rename from server/text_generation/utils/tokens.py rename to server/text_generation_server/utils/tokens.py index 797a5634..00f4e64f 100644 --- a/server/text_generation/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -11,7 +11,7 @@ from transformers import ( ) from typing import List, Tuple, Optional -from text_generation.pb import generate_pb2 +from text_generation_server.pb import generate_pb2 from text_generation.pb.generate_pb2 import FinishReason from text_generation.utils.watermark import WatermarkLogitsProcessor diff --git a/server/text_generation/utils/watermark.py b/server/text_generation_server/utils/watermark.py similarity index 100% rename from server/text_generation/utils/watermark.py rename to server/text_generation_server/utils/watermark.py
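The diffstat above shows the Python client package being renamed from text_generation_inference to text_generation, gaining a synchronous Client alongside the AsyncClient, plus an APIInferenceAsyncClient wrapper for the hosted Inference API. What follows is a minimal usage sketch written against the signatures introduced in clients/python/text_generation/client.py and api_inference.py; the server URL and token are placeholders for illustration and are not taken from this patch (the model id "bigscience/bloomz" is borrowed from the removed smoke test in async_client.py).

import asyncio

from text_generation import Client, APIInferenceAsyncClient

# Synchronous client pointed at a self-hosted text-generation-inference
# server (the base URL here is a placeholder).
client = Client("http://localhost:8080")
print(client.generate("What is Deep Learning?", max_new_tokens=20))

# Streaming variant: yields StreamResponse objects as tokens arrive.
for stream_response in client.generate_stream("What is Deep Learning?"):
    print(stream_response)

# Async client for the hosted Inference API; the model must appear in the
# repository's supported_models.json and the token is a placeholder.
async def main():
    async_client = APIInferenceAsyncClient("bigscience/bloomz", token="<hf_token>")
    print(await async_client.generate("What is Deep Learning?"))

asyncio.run(main())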