From 7e810e76287a004c5a976b95838e9739d59ffb4f Mon Sep 17 00:00:00 2001
From: drbh
Date: Fri, 19 Jul 2024 17:18:33 +0000
Subject: [PATCH] fix: update client exports and adjust after rebase

---
 clients/python/text_generation/__init__.py                    | 13 +++++++++++++
 server/text_generation_server/layers/rotary.py                |  1 -
 .../custom_modeling/flash_deepseek_v2_modeling.py             |  6 ++++++
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/clients/python/text_generation/__init__.py b/clients/python/text_generation/__init__.py
index 79d5a0c3..57359f22 100644
--- a/clients/python/text_generation/__init__.py
+++ b/clients/python/text_generation/__init__.py
@@ -12,9 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation.client import Client, AsyncClient
+from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
+
+
 __version__ = "0.7.0"
 
 DEPRECATION_WARNING = (
     "`text_generation` clients are deprecated and will be removed in the near future. "
     "Please use the `InferenceClient` from the `huggingface_hub` package instead."
 )
+
+
+__all__ = [
+    "Client",
+    "AsyncClient",
+    "InferenceAPIClient",
+    "InferenceAPIAsyncClient",
+    "DEPRECATION_WARNING",
+]
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index db9a1484..8a1e9261 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -2,7 +2,6 @@ import os
 import math
 import torch
 from torch import nn
-from loguru import logger
 
 # Inverse dim formula to find dim based on number of rotations
 import math
diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index f5b2ba0e..3e84b4a8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -39,6 +39,12 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 
+if SYSTEM == "rocm":
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
 
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
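
Note: with the re-exports above, `Client` and `AsyncClient` can be imported straight from the package root. A minimal usage sketch, assuming a TGI server is already listening on the hypothetical local endpoint below (the `generate` call follows the existing `text_generation` client API; URL and prompt are illustrative, not part of this patch):

    # Sketch only: endpoint URL and prompt are placeholders.
    from text_generation import Client

    client = Client("http://127.0.0.1:8080")
    response = client.generate("What is Deep Learning?", max_new_tokens=20)
    print(response.generated_text)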