diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 86335875..51ddad20 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -289,36 +289,25 @@ fn shard_manager(
     }

     let mut env = vec![
-        ("RANK".parse().unwrap(), rank.to_string().parse().unwrap()),
-        (
-            "WORLD_SIZE".parse().unwrap(),
-            world_size.to_string().parse().unwrap(),
-        ),
-        ("MASTER_ADDR".parse().unwrap(), master_addr.parse().unwrap()),
-        (
-            "MASTER_PORT".parse().unwrap(),
-            master_port.to_string().parse().unwrap(),
-        ),
-        (
-            "SAFETENSORS_FAST_GPU".parse().unwrap(),
-            "1".to_string().parse().unwrap(),
-        ),
+        ("RANK".into(), rank.to_string().into()),
+        ("WORLD_SIZE".into(), world_size.to_string().into()),
+        ("MASTER_ADDR".into(), master_addr.into()),
+        ("MASTER_PORT".into(), master_port.to_string().into()),
+        ("SAFETENSORS_FAST_GPU".into(), "1".into()),
     ];

     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
     // Useful when running inside a docker container
     if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
         env.push((
-            "HUGGINGFACE_HUB_CACHE".parse().unwrap(),
-            huggingface_hub_cache.parse().unwrap(),
+            "HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into(),
         ));
     };

     // If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
     if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
         env.push((
-            "CUDA_VISIBLE_DEVICES".parse().unwrap(),
-            cuda_visible_devices.parse().unwrap(),
+            "CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into(),
         ));
     };

diff --git a/router/src/batcher.rs b/router/src/batcher.rs
index 1484434c..3f2a8668 100644
--- a/router/src/batcher.rs
+++ b/router/src/batcher.rs
@@ -74,10 +74,9 @@ impl Batcher {

         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
-        match response_rx.await.unwrap() {
-            Ok(output) => Ok(output),
-            Err(err) => Err(InferError::GenerationError(err.to_string())),
-        }
+        response_rx.await.unwrap().map_err(
+            |err| InferError::GenerationError(err.to_string())
+        )
     }
 }

diff --git a/server/text_generation/models/__init__.py b/server/text_generation/models/__init__.py
index f561ad47..b615eb76 100644
--- a/server/text_generation/models/__init__.py
+++ b/server/text_generation/models/__init__.py
@@ -23,5 +23,5 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         raise ValueError("sharded is not supported for AutoModel")
     try:
         return CausalLM(model_name, quantize=quantize)
-    except Exception as e:
+    except Exception:
         return Seq2SeqLM(model_name, quantize=quantize)
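A note on the `get_model` fallback just above: constructing `CausalLM` fails for encoder-decoder checkpoints, and dropping the unused `as e` binding makes it explicit that the exception value is never inspected. A minimal sketch of the same dispatch pattern using only the `transformers` Auto classes (the server code catches `Exception` more broadly; the model name is illustrative):

    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

    def load_any(model_name: str):
        # AutoModelForCausalLM.from_pretrained raises ValueError for checkpoints
        # whose config maps to an encoder-decoder architecture (e.g. "t5-small"),
        # which is what makes the decoder-only -> seq2seq fallback work.
        try:
            return AutoModelForCausalLM.from_pretrained(model_name)
        except ValueError:
            return AutoModelForSeq2SeqLM.from_pretrained(model_name)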
diff --git a/server/text_generation/models/bloom.py b/server/text_generation/models/bloom.py
index 3561a8ea..1135e565 100644
--- a/server/text_generation/models/bloom.py
+++ b/server/text_generation/models/bloom.py
@@ -5,7 +5,7 @@ from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.bloom.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -34,7 +34,7 @@ torch.manual_seed(0)
 class BloomCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         batch = super(BloomCausalLMBatch, cls).from_pb(
             pb=pb, tokenizer=tokenizer, device=device
@@ -203,9 +203,7 @@ class BLOOMSharded(BLOOM):
             def linear(input, weight, bias):
                 size_out = input.size()[:-1] + (out_features,)
                 input = input.view(-1, in_features)
-                out = torch.empty(
-                    size_out, device=input.device, dtype=input.dtype
-                )
+                out = input.new_empty(size_out)
                 out = bnb.matmul(
                     input,
                     weight,
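The recurring `torch.empty(..., device=..., dtype=...)` to `input.new_empty(size_out)` rewrite works because the `Tensor.new_*` factory methods inherit both dtype and device from the tensor they are called on. A small illustration (a half-precision CPU tensor stands in for the CUDA tensors in the real code):

    import torch

    x = torch.ones(2, 3, dtype=torch.float16)   # stands in for a model tensor on GPU
    out = x.new_empty((4, 1))                   # inherits dtype and device from x
    assert out.dtype == x.dtype and out.device == x.device

    # Exactly equivalent to the removed three-argument form:
    out = torch.empty((4, 1), dtype=x.dtype, device=x.device)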
diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index b352eb6b..6bebcc36 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -1,17 +1,17 @@
 import torch

 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria


 @dataclass
-class CausalLMBatch:
+class CausalLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
@@ -38,7 +38,7 @@ class CausalLMBatch:
     # Past metadata
     keys_head_dim_last: bool = True

-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -47,7 +47,7 @@ class CausalLMBatch:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -130,20 +130,14 @@ class CausalLMBatch:
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
-                )
+                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids

             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_sequence_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )

             # We need to slice the attention mask to remove padding from previous steps
@@ -171,8 +165,8 @@ class CausalLMBatch:

                 if batch.keys_head_dim_last:
                     padded_past_keys_shape = padded_past_values_shape
-                # seq_length is last for BLOOM
                 else:
+                    # seq_length is last for BLOOM
                     padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
@@ -182,16 +176,8 @@ class CausalLMBatch:

                 # This will run only once per layer
                 if j == len(past_key_values):
-                    padded_past_keys = torch.zeros(
-                        padded_past_keys_shape,
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        padded_past_values_shape,
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
+                    padded_past_keys = past_keys.new_zeros(padded_past_keys_shape)
+                    padded_past_values = past_values.new_zeros(padded_past_values_shape)
                     past_key_values.append((padded_past_keys, padded_past_values))

                 # We slice the past keys and values to remove the padding from previous batches

diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index a713e69e..76a1b1ab 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -6,7 +6,7 @@ from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
 from transformers.models.opt.parallel_layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -82,7 +82,7 @@ def escape_custom_split_sequence(text):
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -278,9 +278,7 @@ class GalacticaSharded(Galactica):
             def linear(input, weight, bias):
                 size_out = input.size()[:-1] + (out_features,)
                 input = input.view(-1, in_features)
-                out = torch.empty(
-                    size_out, device=input.device, dtype=input.dtype
-                )
+                out = input.new_empty(size_out)
                 out = bnb.matmul(
                     input,
                     weight,

diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py
index 0331e193..ef6a5682 100644
--- a/server/text_generation/models/model.py
+++ b/server/text_generation/models/model.py
@@ -2,7 +2,7 @@ import torch

 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from tokenizers import Tokenizer
+from transformers import PreTrainedTokenizerBase

 from text_generation.models.types import Batch, GeneratedText

@@ -10,7 +10,7 @@ B = TypeVar("B", bound=Batch)

 class Model(ABC):
-    def __init__(self, tokenizer: Tokenizer, device: torch.device):
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
         self.device = device
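Turning `CausalLMBatch` (and `Seq2SeqLMBatch` below) into explicit subclasses of the `Batch` ABC defined in `types.py` further below is more than documentation: `abc` enforces the abstract interface at instantiation time, so a batch class that forgets `from_pb` or `to_pb` fails immediately rather than deep inside the serving loop. A toy illustration with hypothetical class names:

    from abc import ABC, abstractmethod
    from dataclasses import dataclass

    class Batch(ABC):
        @abstractmethod
        def to_pb(self): ...

    @dataclass
    class BadBatch(Batch):  # forgot to implement to_pb
        batch_id: int

    BadBatch(0)  # raises TypeError: Can't instantiate abstract class BadBatch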
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 4095db92..c561aebe 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -1,17 +1,17 @@
 import torch

 from dataclasses import dataclass
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
-from text_generation.models.types import GeneratedText
+from text_generation.models.types import GeneratedText, Batch
 from text_generation.pb import generate_pb2
 from text_generation.utils import NextTokenChooser, StoppingCriteria


 @dataclass
-class Seq2SeqLMBatch:
+class Seq2SeqLMBatch(Batch):
     batch_id: int
     requests: List[generate_pb2.Request]
@@ -41,7 +41,7 @@ class Seq2SeqLMBatch:
     max_input_length: int
     max_decoder_input_length: int

-    def to_pb(self):
+    def to_pb(self) -> generate_pb2.Batch:
         """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
@@ -51,7 +51,7 @@ class Seq2SeqLMBatch:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Seq2SeqLMBatch":
         """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
@@ -155,10 +155,8 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if input_ids is None:
-                input_ids = torch.zeros(
+                input_ids = batch.input_ids.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.input_ids.dtype,
-                    device=batch.input_ids.device,
                 )
             # Copy to correct indices
             input_ids[
@@ -167,10 +165,8 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if attention_mask is None:
-                attention_mask = torch.zeros(
+                attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_input_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
                 )
             # Copy to correct indices
             attention_mask[
@@ -179,10 +175,8 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if decoder_input_ids is None:
-                decoder_input_ids = torch.zeros(
+                decoder_input_ids = batch.decoder_input_ids.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.decoder_input_ids.dtype,
-                    device=batch.decoder_input_ids.device,
                 )
             # Copy to correct indices
             decoder_input_ids[
@@ -191,10 +185,9 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if decoder_attention_mask is None:
-                decoder_attention_mask = torch.zeros(
+                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here
+                decoder_attention_mask = batch.attention_mask.new_zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
-                    device=batch.attention_mask.device,  # we use `batch.attention_maks` for device here
                 )
             # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
             # this batch. All generations are of length `batch.max_decoder_input_length`.
@@ -210,14 +203,12 @@ class Seq2SeqLMBatch:

             # Create padded tensor
             if encoder_last_hidden_state is None:
-                encoder_last_hidden_state = torch.zeros(
+                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(
                     (
                         total_batch_size,
                         max_input_length,
                         batch.encoder_last_hidden_state.shape[-1],
                     ),
-                    dtype=batch.encoder_last_hidden_state.dtype,
-                    device=batch.encoder_last_hidden_state.device,
                 )

             # Copy to correct indices
@@ -245,9 +236,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if k == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))

                 # We slice the past keys and values to remove the padding from previous batches
                 past_key_values[j][k][
@@ -271,9 +260,7 @@ class Seq2SeqLMBatch:
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
                 if idx == len(past_key_values[j]):
-                    past_key_values[j].append(
-                        torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
-                    )
+                    past_key_values[j].append(t.new_zeros(padded_t_shape))

                 past_key_values[j][idx][
                     start_index:end_index, :, -batch.max_input_length :, :
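Every `concatenate` change in `seq2seq_lm.py` above is the same move: allocate the padded destination with `new_zeros` on a tensor from the incoming batch (so dtype and device come along for free), then copy each batch into its slice, right-aligned to preserve left padding. A skeleton of that pattern with toy shapes:

    import torch

    batches = [torch.ones(2, 3), torch.ones(1, 5)]              # [batch_size, seq_len] each
    total_batch_size, max_len = 3, 5
    merged = batches[0].new_zeros((total_batch_size, max_len))  # right dtype/device for free
    start = 0
    for b in batches:
        end = start + b.shape[0]
        merged[start:end, -b.shape[1]:] = b                     # left-pad: align to the right edge
        start = end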
diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py
index e76cf697..fa0dc9a0 100644
--- a/server/text_generation/models/types.py
+++ b/server/text_generation/models/types.py
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import List

-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase

 from text_generation.pb import generate_pb2

@@ -17,7 +17,7 @@ class Batch(ABC):
     @classmethod
     @abstractmethod
     def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, device: torch.device
     ) -> "Batch":
         raise NotImplementedError

diff --git a/server/text_generation/utils.py b/server/text_generation/utils.py
index cc779736..b0b5b072 100644
--- a/server/text_generation/utils.py
+++ b/server/text_generation/utils.py
@@ -12,7 +12,7 @@ from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
 from typing import List, Optional, Tuple
-from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerBase
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,
@@ -114,7 +114,7 @@ class StoppingCriteria:

     @classmethod
     def from_pb(
-        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: PreTrainedTokenizerBase
     ) -> "StoppingCriteria":
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
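A note on the recurring `AutoTokenizer` to `PreTrainedTokenizerBase` annotation change: `AutoTokenizer` is a factory, not a type. Its `from_pretrained` returns an instance of `PreTrainedTokenizer` or `PreTrainedTokenizerFast`, never of `AutoTokenizer` itself, and `PreTrainedTokenizerBase` is the common base class of both, so it is the annotation that actually matches what callers pass in. A quick check (model name illustrative):

    from transformers import AutoTokenizer, PreTrainedTokenizerBase

    tok = AutoTokenizer.from_pretrained("gpt2")
    assert isinstance(tok, PreTrainedTokenizerBase)  # true for slow and fast tokenizers

    def token_count(text: str, tokenizer: PreTrainedTokenizerBase) -> int:
        return len(tokenizer(text).input_ids)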