Don't set rope_scaling for unsupported models (#115)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
Karol Damaszke, 2024-04-02 12:12:02 +02:00 (committed by GitHub)
commit b0de25a285
parent 3e28d7aa42


@@ -243,6 +243,8 @@ class CausalLMBatch(Batch):
     logits = None
     past = None
 
+    keys_head_dim_last: bool = True
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -286,7 +288,7 @@ class CausalLMBatch(Batch):
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
         seq_dim = -1
-        key_dim = -2  # TODO: Add case for Bloom and other models
+        key_dim = -2 if self.keys_head_dim_last else -1
         value_dim = -2
         tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
         # We don't need to align position_ids
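
Why the flag matters: most decoder-only models keep their key cache as [batch, num_heads, seq_len, head_dim], but Bloom stores keys as [batch * num_heads, head_dim, seq_len], so the sequence axis of the keys moves. A minimal sketch of the axis selection that `keys_head_dim_last` performs (shapes are illustrative, following the transformers KV-cache layouts):

import torch

# GPT/Llama-style key cache: head_dim last -> sequence axis is -2
gpt_key = torch.zeros(2, 12, 16, 64)     # [batch, heads, seq, head_dim]
# Bloom-style key cache: head_dim NOT last -> sequence axis is -1
bloom_key = torch.zeros(2 * 12, 64, 16)  # [batch * heads, head_dim, seq]

def key_seq_dim(keys_head_dim_last: bool) -> int:
    # The axis the cache is padded/merged along depends on the layout.
    return -2 if keys_head_dim_last else -1

assert gpt_key.shape[key_seq_dim(True)] == 16
assert bloom_key.shape[key_seq_dim(False)] == 16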
@@ -592,12 +594,20 @@ class CausalLM(Model):
             model = self.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
-            rope_scaling = self.get_rope_scaling()
+
+            # Check support for rope scaling
+            model_kwargs = {}
+            config = AutoConfig.from_pretrained(
+                model_id
+            )
+            if hasattr(config, "rope_scaling"):
+                model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                rope_scaling=rope_scaling
+                **model_kwargs
             )
             model = self.prepare_model_for_quantization(model)
         model = model.eval().to(device)
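
The guard forwards `rope_scaling` only when the model's config actually defines that attribute, i.e. RoPE-based architectures such as Llama; models like Bloom have no such attribute and previously received an unexpected kwarg. A minimal sketch of the pattern (the model id and the scaling dict are illustrative; in the diff the dict comes from `self.get_rope_scaling()`):

from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any hub id works

model_kwargs = {}
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "rope_scaling"):
    # Only RoPE-based configs (e.g. Llama) carry this attribute.
    model_kwargs["rope_scaling"] = {"type": "dynamic", "factor": 2.0}

# For non-RoPE models, model_kwargs stays empty and no stray kwarg is passed.
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)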
@@ -673,8 +683,7 @@ class CausalLM(Model):
         world_size, rank, local_rank = initialize_distributed_hpu()
 
         model_kwargs = {
-            "revision": revision,
-            'rope_scaling': self.get_rope_scaling()
+            "revision": revision
         }
 
         # Initialize process(es) for DeepSpeed
@@ -685,6 +694,11 @@ class CausalLM(Model):
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
 
+        # Check support for rope scaling
+        if hasattr(config, "rope_scaling"):
+            config.rope_scaling = self.get_rope_scaling()
+            model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
         if load_to_meta:
             # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
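
On the DeepSpeed path the guard is applied twice: on the `config` object, because the meta-device construction builds the model from `config` rather than via `from_pretrained`, and in `model_kwargs` for the non-meta fallback. A minimal sketch of the meta-device flow (model id, dtype, and scaling values are illustrative):

import deepspeed
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative
if hasattr(config, "rope_scaling"):
    config.rope_scaling = {"type": "dynamic", "factor": 2.0}

# Weights are created as fake meta tensors; the real ones are loaded later
# from the DeepSpeed inference checkpoint, but the config (including
# rope_scaling) is already baked into the model structure here.
with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)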