From b0de25a285b88a928badcf331ceb6bb7709744b9 Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Tue, 2 Apr 2024 12:12:02 +0200
Subject: [PATCH] Don't set rope_scaling for unsupported models (#115)

Co-authored-by: Karol Damaszke
---
 .../models/causal_lm.py | 28 ++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 9083b786..c7d82711 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -243,6 +243,8 @@ class CausalLMBatch(Batch):
     logits = None
     past = None
 
+    keys_head_dim_last: bool = True
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -286,7 +288,7 @@ class CausalLMBatch(Batch):
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
         seq_dim = -1
-        key_dim = -2 # TODO: Add case for Bloom and other models
+        key_dim = -2 if self.keys_head_dim_last else -1
         value_dim = -2
         tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
         # We don't need to align position_ids
@@ -471,7 +473,7 @@ class CausalLMBatch(Batch):
         if input_len < max_input_length and PAD_SEQUENCE_TO_MULTIPLE_OF != 0:
             assert PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length, "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
             rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
-            if rounded_seq_len <= max_input_length: 
+            if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
                 bucket_size = max_input_length - 1
@@ -592,12 +594,20 @@ class CausalLM(Model):
             model = self.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
-            rope_scaling = self.get_rope_scaling()
+
+            # Check support for rope scaling
+            model_kwargs = {}
+            config = AutoConfig.from_pretrained(
+                model_id
+            )
+            if hasattr(config, "rope_scaling"):
+                model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                rope_scaling=rope_scaling
+                **model_kwargs
             )
             model = self.prepare_model_for_quantization(model)
             model = model.eval().to(device)
@@ -673,8 +683,7 @@ class CausalLM(Model):
         world_size, rank, local_rank = initialize_distributed_hpu()
 
         model_kwargs = {
-            "revision": revision,
-            'rope_scaling': self.get_rope_scaling()
+            "revision": revision
         }
 
         # Initialize process(es) for DeepSpeed
@@ -685,6 +694,11 @@ class CausalLM(Model):
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
 
+        # Check support for rope scaling
+        if hasattr(config, "rope_scaling"):
+            config.rope_scaling = self.get_rope_scaling()
+            model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
         if load_to_meta:
             # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
@@ -1056,7 +1070,7 @@ class CausalLM(Model):
         # if decode bs is 1 warmup ends here
         if len(batches) == 0:
             return
-        
+
         # prefill
         _, prefill_batch = self.generate_token([batches.pop(0)])
         # concatenate and decode
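
Note: the core of the patch is a guard that only forwards rope_scaling to from_pretrained when the model's AutoConfig actually exposes that attribute, so architectures without RoPE support no longer receive the kwarg. The standalone sketch below illustrates the same pattern using only the public transformers API; load_rope_aware_model is a hypothetical helper for illustration, not code from this repository.

    from typing import Optional

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    def load_rope_aware_model(model_id: str, dtype: torch.dtype, rope_scaling: Optional[dict] = None):
        # Only pass rope_scaling through when the config defines it; models whose
        # configs lack the attribute are loaded without the kwarg instead of
        # receiving an option they do not support.
        model_kwargs = {}
        config = AutoConfig.from_pretrained(model_id)
        if rope_scaling is not None and hasattr(config, "rope_scaling"):
            model_kwargs["rope_scaling"] = rope_scaling
        return AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=dtype,
            **model_kwargs,
        )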