From 1f3b2aeee46ee3d220580adb5503031887d7ee2b Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 24 Jul 2024 13:25:24 +0000 Subject: [PATCH] fix: improve get_model_with_lora_adapters naming --- server/text_generation_server/models/__init__.py | 8 ++++---- server/text_generation_server/server.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f1777744..fe5b5e2c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -60,7 +60,7 @@ __all__ = [ "Model", "CausalLM", "Seq2SeqLM", - "get_model", + "get_model_with_lora_adapters", ] FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." @@ -304,7 +304,7 @@ for data in ModelType: __GLOBALS[data.name] = data.value["type"] -def _get_model( +def get_model( model_id: str, lora_adapter_ids: Optional[List[str]], revision: Optional[str], @@ -1124,7 +1124,7 @@ def _get_model( # get_model wraps the internal _get_model function and adds support for loading adapters # this provides a post model loading hook to load adapters into the model after the model has been loaded -def get_model( +def get_model_with_lora_adapters( model_id: str, lora_adapters: Optional[List[AdapterInfo]], revision: Optional[str], @@ -1137,7 +1137,7 @@ def get_model( adapter_to_index: Dict[str, int], ): lora_adapter_ids = [adapter.id for adapter in lora_adapters] - model = _get_model( + model = get_model( model_id, lora_adapter_ids, revision, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index fc199e34..7ac54603 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -13,7 +13,7 @@ from typing import List, Optional, Dict from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor -from text_generation_server.models import Model, get_model +from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.utils.adapter import AdapterInfo try: @@ -226,7 +226,7 @@ def serve( server_urls = [local_url] try: - model = get_model( + model = get_model_with_lora_adapters( model_id, lora_adapters, revision,