Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-21 23:12:07 +00:00
Don't set rope_scaling for unsupported models (#115)
Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>
This commit is contained in:
parent 3e28d7aa42
commit b0de25a285
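In short: `rope_scaling` used to be passed unconditionally when loading models, which breaks models whose config does not define that attribute. The diff below guards the kwarg with `hasattr(config, "rope_scaling")` on both the regular and the DeepSpeed load paths, and also adds a `keys_head_dim_last` flag to `CausalLMBatch` so KV-cache alignment targets the right key axis for models that lay out keys differently.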
@@ -243,6 +243,8 @@ class CausalLMBatch(Batch):
     logits = None
     past = None
 
+    keys_head_dim_last: bool = True
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -286,7 +288,7 @@ class CausalLMBatch(Batch):
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
         seq_dim = -1
-        key_dim = -2 # TODO: Add case for Bloom and other models
+        key_dim = -2 if self.keys_head_dim_last else -1
         value_dim = -2
         tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
         # We don't need to align position_ids
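The new flag controls which axis of the past-key cache is treated as the sequence axis during alignment. A minimal sketch of why the axis flips; the Bloom-style layout shown is an assumption drawn from transformers' cache format, not from this diff:

import torch

batch, heads, seq, head_dim = 2, 4, 16, 64

# Most models keep head_dim last, so the key cache's sequence axis is -2.
keys_default = torch.zeros(batch, heads, seq, head_dim)
# Bloom-style caches store keys as [batch * heads, head_dim, seq],
# moving the sequence axis to -1 (assumed layout, for illustration).
keys_bloom = torch.zeros(batch * heads, head_dim, seq)

for keys, keys_head_dim_last in ((keys_default, True), (keys_bloom, False)):
    key_dim = -2 if keys_head_dim_last else -1
    assert keys.shape[key_dim] == seq  # padding/alignment must target this axis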
@@ -592,12 +594,20 @@ class CausalLM(Model):
             model = self.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
-            rope_scaling = self.get_rope_scaling()
+
+            # Check support for rope scaling
+            model_kwargs = {}
+            config = AutoConfig.from_pretrained(
+                model_id
+            )
+            if hasattr(config, "rope_scaling"):
+                model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                rope_scaling=rope_scaling
+                **model_kwargs
             )
             model = self.prepare_model_for_quantization(model)
             model = model.eval().to(device)
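The guard can be reproduced standalone. A minimal sketch, assuming the standard transformers AutoConfig/AutoModelForCausalLM API; `get_rope_scaling()` below is a hypothetical stand-in for the class helper the diff calls, and the model id is only an example:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

def get_rope_scaling():
    # Hypothetical stand-in for CausalLM.get_rope_scaling(); in the server the
    # dict comes from configuration, e.g. linear scaling by a factor of 2.
    return {"type": "linear", "factor": 2.0}

model_id = "meta-llama/Llama-2-7b-hf"  # example only; any causal LM id works

model_kwargs = {}
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "rope_scaling"):
    # Only configs that support rope scaling define this attribute; the guard
    # keeps the kwarg away from models that don't support it, which is the
    # point of this commit.
    model_kwargs["rope_scaling"] = get_rope_scaling()

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
)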
@@ -673,8 +683,7 @@ class CausalLM(Model):
 
         world_size, rank, local_rank = initialize_distributed_hpu()
         model_kwargs = {
-            "revision": revision,
-            'rope_scaling': self.get_rope_scaling()
+            "revision": revision
         }
 
         # Initialize process(es) for DeepSpeed
@@ -685,6 +694,11 @@ class CausalLM(Model):
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
 
+        # Check support for rope scaling
+        if hasattr(config, "rope_scaling"):
+            config.rope_scaling = self.get_rope_scaling()
+            model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
         if load_to_meta:
             # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
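On the DeepSpeed path the config has already been instantiated, so the guard is applied after the fact: the scaling dict is written both onto the live `config` and back into `model_kwargs`, presumably so that any later `from_pretrained`/`from_config` call made during checkpoint loading sees the same value.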