import inspect
import torch

from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase, PretrainedConfig

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
from text_generation_server.utils.adapter import (
    load_and_merge_adapters,
    AdapterParameters,
    AdapterSource,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import CUDA_GRAPHS
import os

from loguru import logger

BASE_MODEL_ADAPTER_ID = "__base_model__"

B = TypeVar("B", bound=Batch)
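
# A concrete model pins `B` to its own `Batch` subclass so `batch_type` and
# `generate_token` agree on types. Illustrative sketch (hypothetical `MyBatch`
# and `MyModel` names, not part of this module):
#
#     class MyModel(Model):
#         @property
#         def batch_type(self) -> Type[MyBatch]:
#             return MyBatch
#
#         def generate_token(
#             self, batch: MyBatch
#         ) -> Tuple[List[Generation], Optional[MyBatch], Tuple[int, int]]:
#             ...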


class Model(ABC):
    def __init__(
        self,
        model_id: str,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
    ):
        self.model_id = model_id
        self.model = model.eval()
        self.tokenizer = tokenizer

        # all_special_ids is not set correctly if the rust tokenizer is unpacked
        # TODO report this to transformers.
        other_special_ids = {
            id for id, token in tokenizer.added_tokens_decoder.items() if token.special
        }
        self.all_special_ids = set(tokenizer.all_special_ids)
        self.all_special_ids.update(other_special_ids)
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device
        self.rank = rank
        self.world_size = world_size
        # A sliding_window of -1 is treated as "no sliding window".
        self.sliding_window = sliding_window if sliding_window != -1 else None

        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.target_to_layer = None
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

        if speculate is None:
            speculate = get_speculate()
        self.speculate = speculate

        # Inspect the forward signature once, so that `position_ids` are only
        # passed to models whose forward actually accepts them.
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
        )

        self.check_initialized()

    @property
    def info(self) -> InfoResponse:
        if self.requires_padding and self.sliding_window is not None:
            raise NotImplementedError("sliding_window is not implemented with padding")

        return InfoResponse(
            requires_padding=self.requires_padding,
            dtype=str(self.dtype),
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
        )

    @property
    @abstractmethod
    def batch_type(self) -> Type[B]:
        raise NotImplementedError

    @abstractmethod
    def generate_token(
        self, batch: B
    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
        raise NotImplementedError

    def warmup(self, batch: B) -> Optional[int]:
        if SYSTEM == "rocm" and (
            os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
            or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
        ):
            logger.info(
                f"ROCm: TunableOp is enabled (PYTORCH_TUNABLEOP_ENABLED is unset or set to 1), but it is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
            )
            torch.cuda.tunable.tuning_enable(False)
            torch.cuda.tunable.enable(False)

        self.generate_token(batch)

        if CUDA_GRAPHS:
            logger.info(
                f"Got CUDA_GRAPHS={CUDA_GRAPHS} but CUDA graphs are not supported for {self.model_id} (instance of {self.__class__.__name__}). CUDA graphs will not be used."
            )

        return None
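
    # Subclasses that do support these features typically override `warmup` and
    # return the maximum number of total tokens they can handle. Hedged sketch
    # (hypothetical attribute name, illustrative only):
    #
    #     def warmup(self, batch: MyBatch) -> Optional[int]:
    #         self.generate_token(batch)
    #         return self.max_supported_total_tokens  # hypothetical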

    def decode_token(
        self,
        all_input_ids: List[int],
        prefix_offset: int = 0,
        read_offset: int = 0,
        skip_special_tokens: bool = False,
    ) -> Tuple[str, int, int]:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""

        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )

        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # A replacement character ("�") at the end means it's a potential
            # unfinished byte sequence from byte fallback tokenization.
            # If it's in the middle, it's probably a real invalid id generated
            # by the model.
            new_text = new_text[len(prefix_text) :]
            return new_text, read_offset, len(all_input_ids)
        else:
            return "", prefix_offset, read_offset

    def check_initialized(self):
        uninitialized_parameters = []
        for n, p in self.model.named_parameters():
            if p.data.device == torch.device("meta"):
                uninitialized_parameters.append(n)
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}"
            )
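
    # A parameter can remain on the "meta" device when the model skeleton is
    # built without storage, e.g. (illustrative):
    #
    #     with torch.device("meta"):
    #         model = SomeModel(config)  # weights must be materialized later
    #
    # Checking here fails fast instead of erroring mid-forward.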

    @property
    def supports_adapter_loading(self) -> bool:
        return False

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        return {}

    @property
    def adapter_layers(self) -> List[str]:
        return []

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        return []

    def get_num_layers_for_type(self, layer_type: str) -> int:
        return 0

    def is_row_parallel(self, layer_type: str) -> bool:
        return False
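
    # Adapter-capable models override the hooks above. Hedged sketch
    # (hypothetical layer names, illustrative only):
    #
    #     @property
    #     def supports_adapter_loading(self) -> bool:
    #         return True
    #
    #     @property
    #     def adapter_layers(self) -> List[str]:
    #         return ["q_proj", "k_proj", "v_proj", "o_proj"]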

    @property
    def max_speculative_tokens(self) -> int:
        return max(
            [
                weights.max_speculative_tokens
                for weights in self.layer_to_adapter_weights.values()
            ],
            default=0,
        )

    def load_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
        api_token: str,
        dynamic: bool = True,
    ):
        """Loads adapter weights from disk / host memory on the GPU.

        adapter_id must be `BASE_MODEL_ADAPTER_ID` if the adapter is statically
        loaded into the model. Otherwise, the adapter weights are applied during
        the forward pass and stored separately from the base model parameters.
        """
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        # `dynamic_adapter_loading_enabled` is expected to be set by
        # adapter-capable models; it is not defined in this base class.
        if dynamic and not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                f"and therefore does not support dynamic adapter loading. "
                f"Please initialize a new model instance from the base model in "
                f"order to use the dynamic adapter loading feature."
            )

        logger.info(
            f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}"
        )
        weight_names = tuple([v[0] for v in self.target_to_layer.values()])
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            self.model_id,
            adapter_parameters,
            adapter_source,
            adapter_index,
            weight_names,
            api_token,
            False,
        )

        unused_weight_names = adapter_weight_names.copy()
        for layer_name in self.adapter_layers:
            adapter_weights = adapter_config.load_batched_adapter_weights(
                self,
                module_map,
                layer_name,
                unused_weight_names,
                dynamic,
            )

            if adapter_weights is None:
                continue

            layer_weights = self.layer_to_adapter_weights[layer_name]
            layer_weights.add_adapter(adapter_index, adapter_weights)

        if len(unused_weight_names) > 0:
            logger.warning(
                f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            # `tokenizers` is assumed to be provided by models that track
            # per-adapter tokenizers.
            self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        self.loaded_adapters.add(adapter_index)
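
    # Hedged usage sketch (hypothetical adapter id and token; only
    # `adapter_ids` is used from AdapterParameters in this module):
    #
    #     model.load_adapter(
    #         adapter_parameters=AdapterParameters(adapter_ids=["my-org/my-lora"]),
    #         adapter_source=adapter_source,
    #         adapter_index=0,
    #         api_token=api_token,
    #         dynamic=True,
    #     )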

    def offload_adapter(
        self,
        adapter_parameters: AdapterParameters,
        adapter_source: AdapterSource,
        adapter_index: int,
    ):
        """Offloads the adapter weights from GPU to CPU or disk."""
        if adapter_index not in self.loaded_adapters:
            # Adapter already offloaded
            return

        if not self.supports_adapter_loading:
            raise ValueError("This model does not support adapter loading.")

        if not self.dynamic_adapter_loading_enabled:
            raise ValueError(
                f"This model was initialized with the adapter {self.static_adapter_id} "
                f"and therefore does not support dynamic adapter loading. "
                f"Please initialize a new model instance from the base model in "
                f"order to use the dynamic adapter loading feature."
            )

        for layer_name in self.adapter_layers:
            if layer_name in self.layer_to_adapter_weights:
                self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)

        self.loaded_adapters.remove(adapter_index)
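
    # Offloading mirrors the sketch above (hypothetical values):
    #
    #     model.offload_adapter(
    #         adapter_parameters=AdapterParameters(adapter_ids=["my-org/my-lora"]),
    #         adapter_source=adapter_source,
    #         adapter_index=0,
    #     )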