fix: improve get_model_with_lora_adapters naming
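
Rename the adapter-aware wrapper from get_model to get_model_with_lora_adapters so the name
reflects that it loads adapters after model loading, and promote the internal _get_model to
get_model. Update the call site in server.py accordingly.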

drbh 2024-07-24 13:25:24 +00:00
parent 59022c22b4
commit 1f3b2aeee4
2 changed files with 6 additions and 6 deletions

server/text_generation_server/models/__init__.py

@@ -60,7 +60,7 @@ __all__ = [
     "Model",
     "CausalLM",
     "Seq2SeqLM",
-    "get_model",
+    "get_model_with_lora_adapters",
 ]
 
 FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
@@ -304,7 +304,7 @@ for data in ModelType:
     __GLOBALS[data.name] = data.value["type"]
 
 
-def _get_model(
+def get_model(
     model_id: str,
     lora_adapter_ids: Optional[List[str]],
     revision: Optional[str],
@@ -1124,7 +1124,7 @@ def _get_model(
 
 # get_model wraps the internal _get_model function and adds support for loading adapters
 # this provides a post model loading hook to load adapters into the model after the model has been loaded
-def get_model(
+def get_model_with_lora_adapters(
     model_id: str,
     lora_adapters: Optional[List[AdapterInfo]],
     revision: Optional[str],
@@ -1137,7 +1137,7 @@ def get_model(
     adapter_to_index: Dict[str, int],
 ):
     lora_adapter_ids = [adapter.id for adapter in lora_adapters]
-    model = _get_model(
+    model = get_model(
         model_id,
         lora_adapter_ids,
         revision,

server/text_generation_server/server.py

@@ -13,7 +13,7 @@ from typing import List, Optional, Dict
 
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
-from text_generation_server.models import Model, get_model
+from text_generation_server.models import Model, get_model_with_lora_adapters
 from text_generation_server.utils.adapter import AdapterInfo
 
 try:
try:
@@ -226,7 +226,7 @@ def serve(
         server_urls = [local_url]
 
         try:
-            model = get_model(
+            model = get_model_with_lora_adapters(
                 model_id,
                 lora_adapters,
                 revision,
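
For context, the rename separates the plain loader (get_model) from the adapter-aware wrapper
(get_model_with_lora_adapters) that attaches LoRA adapters after the model is loaded. Below is a
minimal sketch of that wrapper pattern, not the library's actual implementation: the real TGI
signatures take many more parameters, and the AdapterInfo and Model classes here are simplified
stand-ins for the real types.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class AdapterInfo:
    # Simplified stand-in for text_generation_server.utils.adapter.AdapterInfo
    id: str
    path: Optional[str] = None


class Model:
    # Simplified stand-in for the TGI Model base class
    def __init__(self, model_id: str, lora_adapter_ids: List[str]):
        self.model_id = model_id
        self.lora_adapter_ids = lora_adapter_ids


def get_model(model_id: str, lora_adapter_ids: Optional[List[str]]) -> Model:
    # Plain loader: builds the model and knows nothing about adapter weights.
    return Model(model_id, lora_adapter_ids or [])


def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    adapter_to_index: Dict[str, int],
) -> Model:
    # Wrapper: extracts the adapter ids (as in the diff above), loads the
    # model via the plain loader, then hooks in the adapters post-load.
    lora_adapter_ids = [adapter.id for adapter in lora_adapters or []]
    model = get_model(model_id, lora_adapter_ids)
    # ... post-load LoRA adapter loading would happen here ...
    return model


if __name__ == "__main__":
    adapters = [AdapterInfo(id="my-org/my-lora")]
    model = get_model_with_lora_adapters("my-org/my-base-model", adapters, {"my-org/my-lora": 0})
    print(model.lora_adapter_ids)
```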