import copy
import logging
import time
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple

import torch
from loguru import logger
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation import GenerationConfig

from optimum.neuron import NeuronModelForCausalLM
from optimum.neuron.generation import TokenSelector

from .model import get_export_kwargs_from_env
from .pb.generate_pb2 import (
    Batch,
    CachedBatch,
    FinishReason,
    GeneratedText,
    Generation,
    InfoResponse,
    Request,
    Tokens,
)

# Disable optimum-neuron warnings as it seems to block the server after a while
optimum_logger = logging.getLogger("optimum.neuron")
optimum_logger.setLevel("CRITICAL")


class Generator(ABC):
    """An abstract class to represent the workhorse behind TextGenerationService.

    Ideally, it should not rely on protobuf constructs, but in a first step it does.
    Implementations would typically need a model and a tokenizer to implement the Generator methods.
    """

    @property
    def info(self) -> InfoResponse:
        """This should simply return the expected InfoResponse"""
        raise NotImplementedError

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        raise NotImplementedError

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill is called whenever new requests need to be added.

        When this method returns successfully, a decode method will follow
        with both the current and newly prefilled batch(es).

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        raise NotImplementedError

    def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
        """Decode after a prefill or another decode."""
        raise NotImplementedError

    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch"""
        raise NotImplementedError

    def clear(self):
        """Remove all requests from the generator"""
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str]):
        """Factory method "a la transformers" """
        raise NotImplementedError
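

# The gRPC service is expected to drive a concrete Generator roughly as follows (a
# sketch of the contract implied by the docstrings above, with protobuf `Batch`
# objects built by the caller): `warmup` is called once with a maximal batch, each
# new `Batch` then goes through `prefill`, and the returned `CachedBatch` is passed
# repeatedly to `decode` until every request has finished or has been removed with
# `filter` or `clear`.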


class Slot:
    """Represents a slot in a static batch"""

    class State(Enum):
        EMPTY = 0
        PAUSE = 1
        READY = 2

    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):
        self._id = id
        self._tokenizer = tokenizer
        self.clear()

    def clear(self):
        """Clear the slot and mark it as available."""
        self._state = Slot.State.EMPTY
        self._batch_id = None
        self._request_id = None
        self._inputs = ""
        self._truncate = 0
        self._generation_config = None
        self._tokens = []
        self._mask = torch.tensor([])
        self._selector = None
        self._generated_tokens = 0
        self._next_text_token_start = 0
        self._next_text_token_end = 0
        self._generated_text = ""
        self._next_text = ""

    @property
    def id(self) -> int:
        return self._id

    @property
    def state(self) -> "Slot.State":
        return self._state

    @property
    def batch_id(self) -> int:
        return self._batch_id

    @property
    def request_id(self) -> int:
        return self._request_id

    @property
    def cached_text(self) -> str:
        return self._inputs + self._generated_text

    @property
    def generation_config(self) -> GenerationConfig:
        return self._generation_config

    @property
    def generated_tokens(self) -> int:
        return self._generated_tokens

    def assign(
        self, batch_id: int, request: Request, generation_config: GenerationConfig
    ):
        """Assign a request to a slot.

        Args:
            request (`Request`):
                The request to be assigned. Contains the inputs and tokens selection parameters.
            generation_config (`transformers.GenerationConfig`):
                The base generation config (might be modified by the request generation parameters).
        """
        self._state = Slot.State.READY
        self._batch_id = batch_id
        self._request_id = request.id
        self._inputs = request.inputs
        if request.truncate:
            self._truncate = request.truncate
        self._generation_config = copy.deepcopy(generation_config)
        # Update generation config with request parameters
        self._generation_config.do_sample = request.parameters.do_sample
        if self._generation_config.do_sample:
            if request.parameters.temperature != 0:
                self._generation_config.temperature = request.parameters.temperature
            if request.parameters.top_k != 0:
                self._generation_config.top_k = request.parameters.top_k
            if request.parameters.top_p != 0:
                self._generation_config.top_p = request.parameters.top_p
            if request.parameters.typical_p != 0:
                self._generation_config.typical_p = request.parameters.typical_p
        if request.parameters.repetition_penalty != 0:
            self._generation_config.repetition_penalty = (
                request.parameters.repetition_penalty
            )
        self.seed = request.parameters.seed
        self._generation_config.max_new_tokens = (
            request.stopping_parameters.max_new_tokens
        )
        self._max_new_tokens = self._generation_config.max_new_tokens
        stop_strings = request.stopping_parameters.stop_sequences
        if stop_strings:
            self._generation_config.stop_strings = stop_strings
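
    # Note (illustrative): because `assign` deep-copies the base generation config, a
    # request carrying e.g. do_sample=True and temperature=0.7 changes only this slot's
    # copy; zero-valued parameters (top_k=0, top_p=0, ...) leave the base defaults
    # untouched, and the model-level generation config is never mutated.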

    def reset(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        selector: TokenSelector,
    ):
        """Reset the slot for the next generation.

        Args:
            input_ids: (`torch.LongTensor`):
                The new input_ids to use to generate the next token.
            attention_mask: (`torch.LongTensor`):
                The new attention_mask to use to generate the next token.
            selector: (`optimum.neuron.generation.TokenSelector`):
                An object implementing the updated token selection logic.
        """
        self._tokens = input_ids.clone()
        self._next_text_token_start = 0
        self._next_text_token_end = torch.numel(self._tokens)
        self._next_text = ""
        self._mask = attention_mask.clone()
        self._selector = selector

    def pause(self, reset_on_pause: bool):
        """Mark the current slot as paused for generation.

        Note that the KV cache for this slot will still be filled.
        """
        if reset_on_pause:
            # Drop the last token as it will be added back when resuming the slot
            self._generated_tokens -= 1
            # Since generated tokens are now part of the prefill, we need to reevaluate
            # max_new_tokens for the next generation
            self._generation_config.max_new_tokens = (
                self._max_new_tokens - self._generated_tokens
            )
        self._state = Slot.State.PAUSE

    def resume(self):
        """Mark the slot as ready for generation."""
        self._state = Slot.State.READY

    def _decode_next_tokens(self) -> str:
        """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
        # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        new_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start :], skip_special_tokens=False
        )
        if new_text.endswith("�"):
            # utf-8 char at the end means it's a potential unfinished byte sequence
            # from byte fallback tokenization.
            return ""
        # Compare the generated text with the one using only the tokens producing the last one
        last_text = self._tokenizer.decode(
            self._tokens[self._next_text_token_start : self._next_text_token_end],
            skip_special_tokens=False,
        )
        if len(new_text) == len(last_text):
            # Nothing new was actually generated
            return ""
        # Return the decoded text and store its token offsets
        self._next_text_token_start = self._next_text_token_end
        self._next_text_token_end = torch.numel(self._tokens)
        return new_text[len(last_text) :]
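
    # Illustration of the hack above (hypothetical token ids and strings): a tokenizer
    # may decode([1001]) to "world" but decode([1000, 1001]) to " world", because the
    # leading space depends on the surrounding ids. Re-decoding from
    # `_next_text_token_start` and stripping the previously returned text keeps the
    # streamed output identical to a one-shot decode of the full sequence.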

    def append(self, next_token: int) -> str:
        """Append a new generated token to this slot

        The new token is added to the list of generated tokens, which impacts
        directly the generated_text and stopped property.

        The new token is however not added immediately to the slot inputs: it will
        be added later on when it has effectively been used to produce the next token.

        Args:
            next_token (`int`):
                The newly generated token.

        Return:
            The corresponding decoded text (if any).
        """
        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])
        self._mask = torch.cat([self._mask, torch.LongTensor([1])])
        self._generated_tokens += 1
        next_text = self._decode_next_tokens()
        # Now that a new token has been generated, we can append the previous one to the generated text
        self._generated_text += self._next_text
        self._next_text = next_text
        return next_text

    def select(
        self, input_ids: torch.LongTensor, logits: torch.Tensor
    ) -> torch.LongTensor:
        """Select the next token from the candidate logits.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation (not used in all generation modes).
            logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The logits corresponding to the generated tokens.

        Return:
            `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
        """
        return self._selector.select(input_ids, logits)[0]

    @property
    def stopped(self) -> bool:
        # Transformers stopping criteria expects a batch of input ids
        input_ids = torch.unsqueeze(self._tokens, dim=0)
        return self._selector.stopping_criteria(input_ids, None)

    @property
    def generated_text(self) -> str:
        return self._generated_text + self._next_text

    @property
    def next_token(self) -> int:
        return None if len(self._tokens) == 0 else self._tokens[-1]

    @property
    def attention_mask(self) -> torch.LongTensor:
        return self._mask

    @property
    def max_token(self) -> int:
        return self._generation_config.max_length

    @property
    def max_new_tokens(self) -> int:
        # The current value of max_new_tokens: might be different from the target max_new_tokens
        # if the slot has been paused and resumed.
        return self._generation_config.max_new_tokens

    @property
    def truncate(self) -> int:
        return self._truncate
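

# A minimal sketch (hypothetical driver code) of the Slot lifecycle implemented above:
#
#   slot = Slot(0, tokenizer)                           # state: EMPTY
#   slot.assign(batch_id, request, generation_config)   # state: READY
#   slot.reset(input_ids, attention_mask, selector)
#   text = slot.append(next_token)                      # once per prefill/decode step
#   slot.pause(reset_on_pause=True); slot.resume()      # around prefills of other requests
#   slot.clear()                                        # back to EMPTY when finished or filtered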


class NeuronGenerator(Generator):
    """A Generator for Neuron models."""

    def __init__(
        self,
        model: NeuronModelForCausalLM,
        tokenizer: PreTrainedTokenizerBase,
    ):
        self.model = model
        self.rebuild_cache_on_prefill = not self.model.continuous_batching
        # Specify padding and truncation options for decoder-only architecture
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"
        tokenizer.truncation_side = "left"
        self.tokenizer = tokenizer
        self.special_tokens = self.tokenizer.all_special_ids
        self.slots = [Slot(i, tokenizer) for i in range(self.model.batch_size)]
        self.batch_id = 0

    @property
    def info(self) -> InfoResponse:
        """Returns the expected InfoResponse."""
        dtype = getattr(self.model.config, "torch_dtype", "float32")
        return InfoResponse(
            requires_padding=True,
            dtype=str(dtype),
            device_type="xla",
        )

    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.
        """
        # Just check that the warmup request parameters match the model capacity
        batch_size = self.model.batch_size
        if len(batch.requests) > batch_size:
            raise ValueError(
                f"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI. The compiled model batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process. The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE."
            )
        self.prefill(batch)
        self.clear()
        return self.model.batch_size * self.model.max_length
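
    # Worked example (hypothetical compile-time settings, for illustration): a model
    # compiled with batch_size=4 and max_length=2048 would report 4 * 2048 = 8192,
    # i.e. the total number of tokens the static batch can hold at full occupancy.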

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Prefill new requests.

        Args:
            batch (`Batch`):
                A batch containing the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        slots = {state: [] for state in Slot.State}
        for slot in self.slots:
            slots[slot.state].append(slot)
        active_slots = slots[Slot.State.READY]
        empty_slots = slots[Slot.State.EMPTY]
        if len(empty_slots) < len(batch.requests):
            raise ValueError(
                f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
                f" Please align max_batch_size with the static batch size: {self.model.batch_size}."
            )
        # Assign each request to an empty slot
        logger.debug(
            f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)"
        )
        new_slots = []
        for request in batch.requests:
            slot = empty_slots.pop()
            slot.assign(self.batch_id, request, self.model.generation_config)
            new_slots.append(slot)
            logger.debug(
                f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
            )
        if self.rebuild_cache_on_prefill:
            # We will clear pending slots and prefill all slots
            prefill_slots = self.slots
            seq_ids = None
        else:
            # We only need to pass inputs for the new requests
            prefill_slots = new_slots
            seq_ids = torch.tensor([slot.id for slot in prefill_slots])
        # Reconstruct the full inputs (without padding) as seen by the model.
        # This comprises:
        # - the inputs for new requests,
        # - only when rebuilding the cache, the inputs and the generated text that has already
        #   been cached (i.e. excluding the last generated token) for unfinished requests.
        inputs = []
        max_length = 0
        for slot in prefill_slots:
            inputs.append(slot.cached_text)
            # Apply truncation, making sure we fit into static dimensions
            if slot.truncate == 0:
                max_length = self.model.max_length
            elif slot.truncate > max_length and slot.truncate < self.model.max_length:
                max_length = slot.truncate
        # Tokenize with padding and truncation
        padded_inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        input_ids = padded_inputs.input_ids
        attention_mask = padded_inputs.attention_mask
        # Pause previously active slots during generation
        next_tokens = []
        for slot in active_slots:
            slot.pause(reset_on_pause=self.rebuild_cache_on_prefill)
            if self.rebuild_cache_on_prefill:
                # The slot will be reset, so we need to store its next token
                next_tokens.append(slot.next_token)
        # Each slot must be reset with the padded inputs and masks
        for i, slot in enumerate(prefill_slots):
            if slot.state != Slot.State.EMPTY:
                if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:
                    # Apply per-request truncation
                    input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id
                    attention_mask[i, : -slot.truncate] = 0
                slot_input_ids = input_ids[i : i + 1, :]
                # Padded input ids are also required to set logits processors and stopping criteria
                selector = TokenSelector.create(
                    slot_input_ids,
                    slot.generation_config,
                    self.model,
                    self.model.max_length,
                    tokenizer=self.tokenizer,
                    seed=slot.seed,
                )
                slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)
                slot_attention_mask = attention_mask[i]
                slot.reset(slot_input_ids, slot_attention_mask, selector)
        # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,
        # as they have already been generated and sent back in the last decode.
        model_inputs = self.model.prepare_inputs_for_prefill(
            input_ids, attention_mask, seq_ids
        )
        logits = self.model(**model_inputs)[0]
        generation, next_batch = self._generate_token(
            prefill_slots, self.batch_id, logits, input_ids
        )
        self.batch_id += 1
        # Reactivate previously active slots for the next decode
        for i, slot in enumerate(active_slots):
            slot.resume()
            if self.rebuild_cache_on_prefill:
                # Append back the next token
                slot.append(next_tokens[i])
        logger.debug("Model ready for decoding")
        if next_batch is not None:
            logger.debug(
                f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}"
            )
        return generation, next_batch

    def decode(
        self, batches: List[CachedBatch]
    ) -> Tuple[List[Generation], CachedBatch]:
        """Decode the specified prefilled requests.

        Args:
            batches (`List[CachedBatch]`):
                A list of previous batches containing the prefilled requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.
        """
        # batches contains a list composed of:
        # - the batch id returned by the last decode,
        # - the batch id(s) returned by the last prefill(s)
        # Batches are always concatenated during prefill, so we can
        # just carry on with decoding. We adopt the id of the first
        # batch in the list as our next batch id.
        next_batch_id = batches[0].id
        request_ids = []
        for batch in batches:
            request_ids += batch.request_ids
        cleared_request_ids = []
        for slot in self.slots:
            if slot.state == Slot.State.READY and slot.request_id not in request_ids:
                cleared_request_ids.append(slot.request_id)
                slot.clear()
        if len(cleared_request_ids) > 0:
            logger.info(
                f"Clearing slot for requests {cleared_request_ids} as they are not requested."
            )
        active_slots = [slot for slot in self.slots if slot.state == Slot.State.READY]
        if len(active_slots) < len(request_ids):
            raise ValueError(
                "Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
            )
        if self.model.continuous_batching:
            decode_slots = active_slots
            seq_ids = torch.tensor([slot.id for slot in decode_slots])
        else:
            decode_slots = self.slots
            seq_ids = None
        # Reconstruct input_ids and attention_mask from decode slots
        n_slots = len(decode_slots)
        input_ids = torch.full(
            [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
        )
        max_length = 0
        for slot in decode_slots:
            max_length = max(max_length, slot.attention_mask.size(-1))
        attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)
        for i, slot in enumerate(decode_slots):
            if slot.state != Slot.State.EMPTY:
                # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
                input_ids[i, 0] = slot.next_token
                attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask
        model_inputs = self.model.prepare_inputs_for_decode(
            input_ids, attention_mask, seq_ids
        )
        logits = self.model(**model_inputs)[0]
        return self._generate_token(decode_slots, next_batch_id, logits, input_ids)

    def _generate_token(
        self,
        slots: List[Slot],
        next_batch_id: int,
        logits: torch.Tensor,
        input_ids: torch.LongTensor,
    ) -> Tuple[List[Generation], CachedBatch]:
        generations = []
        active_slots = False
        for i, slot in enumerate(slots):
            if slot.state != Slot.State.READY:
                continue
            request_id = slot.request_id
            next_token_logits = logits[i : i + 1, -1, :]
            slot_input_ids = input_ids[i : i + 1, :]
            next_token = slot.select(slot_input_ids, next_token_logits)
            next_token_text = slot.append(next_token)
            generated_text = None
            finish_reason = None
            if next_token == self.tokenizer.eos_token_id:
                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
            elif slot.stopped:
                if slot.generated_tokens == slot.max_new_tokens:
                    finish_reason = FinishReason.FINISH_REASON_LENGTH
                else:
                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
            if finish_reason is not None:
                # We must include the generated text for each finished sequence in the response
                generated_text = GeneratedText(
                    text=slot.generated_text,
                    generated_tokens=slot.generated_tokens,
                    finish_reason=finish_reason,
                )
                logger.debug(
                    f"Decode complete for request {request_id} with {slot.generated_tokens} tokens"
                )
                # mark the slot as available
                slot.clear()
            else:
                active_slots = True
            generations.append(
                Generation(
                    request_id=request_id,
                    prefill_tokens=None,
                    tokens=Tokens(
                        ids=[next_token],
                        logprobs=[0],
                        texts=[next_token_text],
                        is_special=[next_token in self.special_tokens],
                    ),
                    generated_text=generated_text,
                )
            )
        batch = None
        if active_slots:
            # Whatever initial batch these requests came from, we always return all pending requests in a single batch
            request_ids = [
                slot.request_id for slot in self.slots if slot.state == Slot.State.READY
            ]
            batch = self._cached_batch(next_batch_id, request_ids)
        else:
            logger.debug("No more pending requests")
        return generations, batch

    def _cached_batch(self, batch_id: int, request_ids: List):
        size = len(request_ids)
        max_tokens = size * self.model.max_length
        return CachedBatch(
            id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens
        )

    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:
        """Remove requests that are not listed from the specified batch

        Args:
            batch_id (`int`):
                The id of a cached batch.
            keep_request_ids (`List[int]`):
                The list of requests that must be kept.

        Return:
            A `CachedBatch` containing the pending requests.
        """
        keep_slot_ids = [
            slot.id for slot in self.slots if slot.request_id in keep_request_ids
        ]
        self._clear(keep_slot_ids)
        return self._cached_batch(batch_id, keep_request_ids)

    def clear(self, batch_id: Optional[int] = None):
        """Remove a subset or all requests from the generator"""
        keep_ids = []
        if batch_id is not None:
            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]
        return self._clear(keep_ids)

    def _clear(self, keep_slot_ids: List):
        for slot in self.slots:
            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
                logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
                slot.clear()

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str] = None):
        """Instantiate a NeuronGenerator.

        Args:
            model_id (`str`):
                A hub model id or the path to a local model. This path must also contain a Tokenizer.
            revision (`Optional[str]`, defaults to `None`):
                The revision of the model on the HuggingFace hub.

        Returns:
            A NeuronGenerator.
        """
        config = AutoConfig.from_pretrained(model_id)
        neuron_config = getattr(config, "neuron", None)
        start = time.time()
        if neuron_config is None:
            export_kwargs = get_export_kwargs_from_env()
            logger.info(f"Exporting model to neuron with config: {export_kwargs}.")
            model = NeuronModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                low_cpu_mem_usage=True,
                export=True,
                **export_kwargs,
            )
        else:
            logger.info(
                "Loading model on neuron devices (this can take a few minutes)."
            )
            model = NeuronModelForCausalLM.from_pretrained(
                model_id, low_cpu_mem_usage=True, revision=revision
            )
        end = time.time()
        logger.info(f"Model successfully loaded in {end - start:.2f} s.")
        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        return cls(model, tokenizer)
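

# A minimal usage sketch (hypothetical model id and caller-built protobuf batches,
# shown for illustration only):
#
#   generator = NeuronGenerator.from_pretrained("my-org/my-neuron-model")
#   max_tokens = generator.warmup(warmup_batch)
#   generations, cached = generator.prefill(batch)
#   while cached is not None:
#       generations, cached = generator.decode([cached])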