mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: improve get_model_with_lora_adapters naming
This commit is contained in:
parent
59022c22b4
commit
1f3b2aeee4
@ -60,7 +60,7 @@ __all__ = [
|
|||||||
"Model",
|
"Model",
|
||||||
"CausalLM",
|
"CausalLM",
|
||||||
"Seq2SeqLM",
|
"Seq2SeqLM",
|
||||||
"get_model",
|
"get_model_with_lora_adapters",
|
||||||
]
|
]
|
||||||
|
|
||||||
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
|
||||||
@ -304,7 +304,7 @@ for data in ModelType:
|
|||||||
__GLOBALS[data.name] = data.value["type"]
|
__GLOBALS[data.name] = data.value["type"]
|
||||||
|
|
||||||
|
|
||||||
def _get_model(
|
def get_model(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
lora_adapter_ids: Optional[List[str]],
|
lora_adapter_ids: Optional[List[str]],
|
||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
@ -1124,7 +1124,7 @@ def _get_model(
|
|||||||
|
|
||||||
# get_model wraps the internal _get_model function and adds support for loading adapters
|
# get_model wraps the internal _get_model function and adds support for loading adapters
|
||||||
# this provides a post model loading hook to load adapters into the model after the model has been loaded
|
# this provides a post model loading hook to load adapters into the model after the model has been loaded
|
||||||
def get_model(
|
def get_model_with_lora_adapters(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
lora_adapters: Optional[List[AdapterInfo]],
|
lora_adapters: Optional[List[AdapterInfo]],
|
||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
@ -1137,7 +1137,7 @@ def get_model(
|
|||||||
adapter_to_index: Dict[str, int],
|
adapter_to_index: Dict[str, int],
|
||||||
):
|
):
|
||||||
lora_adapter_ids = [adapter.id for adapter in lora_adapters]
|
lora_adapter_ids = [adapter.id for adapter in lora_adapters]
|
||||||
model = _get_model(
|
model = get_model(
|
||||||
model_id,
|
model_id,
|
||||||
lora_adapter_ids,
|
lora_adapter_ids,
|
||||||
revision,
|
revision,
|
||||||
|
@ -13,7 +13,7 @@ from typing import List, Optional, Dict
|
|||||||
|
|
||||||
from text_generation_server.cache import Cache
|
from text_generation_server.cache import Cache
|
||||||
from text_generation_server.interceptor import ExceptionInterceptor
|
from text_generation_server.interceptor import ExceptionInterceptor
|
||||||
from text_generation_server.models import Model, get_model
|
from text_generation_server.models import Model, get_model_with_lora_adapters
|
||||||
from text_generation_server.utils.adapter import AdapterInfo
|
from text_generation_server.utils.adapter import AdapterInfo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -226,7 +226,7 @@ def serve(
|
|||||||
server_urls = [local_url]
|
server_urls = [local_url]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model = get_model(
|
model = get_model_with_lora_adapters(
|
||||||
model_id,
|
model_id,
|
||||||
lora_adapters,
|
lora_adapters,
|
||||||
revision,
|
revision,
|
||||||
|
Loading…
Reference in New Issue
Block a user