mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 07:22:07 +00:00

Don't set rope_scaling for unsupported models (#115)

Co-authored-by: Karol Damaszke <kdamaszke@habana.ai>

parent 3e28d7aa42
commit b0de25a285
@@ -243,6 +243,8 @@ class CausalLMBatch(Batch):
     logits = None
     past = None
 
+    keys_head_dim_last: bool = True
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
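The new keys_head_dim_last flag records whether this batch's cached key tensors keep head_dim as their last axis. A minimal sketch of the two layouts it distinguishes, assuming the usual transformers KV-cache conventions (shapes and names here are illustrative, not taken from this file):

    import torch

    # Most decoder-only models cache keys as [batch, num_heads, seq_len, head_dim],
    # so the sequence axis is -2.
    head_dim_last_keys = torch.zeros(1, 8, 16, 64)

    # BLOOM-style models cache keys as [batch * num_heads, head_dim, seq_len],
    # so the sequence axis is -1.
    head_dim_first_keys = torch.zeros(8, 64, 16)

    def key_seq_dim(keys_head_dim_last: bool) -> int:
        # Pick the sequence axis of a key cache without model-specific branching.
        return -2 if keys_head_dim_last else -1

    assert key_seq_dim(True) == -2 and key_seq_dim(False) == -1

Defaulting to True matches the majority layout; a BLOOM-style batch would presumably flip it to False.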
@@ -286,7 +288,7 @@ class CausalLMBatch(Batch):
     def get_tensor_groups(self):
         past_keys, past_values = self.detach_kv_cache()
         seq_dim = -1
-        key_dim = -2 # TODO: Add case for Bloom and other models
+        key_dim = -2 if self.keys_head_dim_last else -1
         value_dim = -2
         tensors = [[self.input_ids], [self.attention_mask], [self.position_ids], past_keys, past_values]
         # We don't need to align position_ids
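get_tensor_groups now reads the flag instead of hard-coding the head_dim-last layout, resolving the old TODO for BLOOM and similar models. Dims like key_dim are typically consumed when cache tensors are padded or merged along their sequence axis; a rough sketch of that use, with pad_right as a hypothetical helper rather than code from this file:

    import torch
    import torch.nn.functional as F

    def pad_right(tensor: torch.Tensor, target_len: int, dim: int) -> torch.Tensor:
        # Right-pad `tensor` along `dim` (a negative index) up to target_len.
        missing = target_len - tensor.size(dim)
        if missing <= 0:
            return tensor
        # F.pad takes (left, right) pairs starting from the last dimension.
        pads = [0, 0] * (-dim - 1) + [0, missing]
        return F.pad(tensor, pads)

    # key_dim is -2 for head_dim-last caches, -1 for BLOOM-style ones:
    padded = pad_right(torch.zeros(1, 8, 16, 64), 20, -2)
    assert padded.shape == (1, 8, 20, 64)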
@@ -592,12 +594,20 @@ class CausalLM(Model):
             model = self.prepare_model_for_quantization(model)
         else:
             get_repo_root(model_id)
-            rope_scaling = self.get_rope_scaling()
+
+            # Check support for rope scaling
+            model_kwargs = {}
+            config = AutoConfig.from_pretrained(
+                model_id
+            )
+            if hasattr(config, "rope_scaling"):
+                model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
-                rope_scaling=rope_scaling
+                **model_kwargs
             )
             model = self.prepare_model_for_quantization(model)
         model = model.eval().to(device)
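The reworked non-DeepSpeed path probes the config before forwarding any RoPE override: only RoPE-based architectures (LLaMA, Falcon, and similar) define a rope_scaling field, whereas the old code passed rope_scaling=rope_scaling to every model. The same guard in isolation, with an illustrative model id and scaling dict:

    from transformers import AutoConfig, AutoModelForCausalLM

    model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any causal LM id works

    model_kwargs = {}
    config = AutoConfig.from_pretrained(model_id)
    if hasattr(config, "rope_scaling"):
        # transformers expects a dict like {"type": "linear" or "dynamic", "factor": float}
        model_kwargs["rope_scaling"] = {"type": "dynamic", "factor": 2.0}

    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

For a BLOOM or GPT-2 checkpoint the hasattr check is False, model_kwargs stays empty, and no rope_scaling is set, which is exactly what the commit title promises.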
@@ -673,8 +683,7 @@ class CausalLM(Model):
 
         world_size, rank, local_rank = initialize_distributed_hpu()
         model_kwargs = {
-            "revision": revision,
-            'rope_scaling': self.get_rope_scaling()
+            "revision": revision
         }
 
         # Initialize process(es) for DeepSpeed
@@ -685,6 +694,11 @@ class CausalLM(Model):
         config = AutoConfig.from_pretrained(model_id, **model_kwargs)
         load_to_meta = model_on_meta(config)
 
+        # Check support for rope scaling
+        if hasattr(config, "rope_scaling"):
+            config.rope_scaling = self.get_rope_scaling()
+            model_kwargs["rope_scaling"] = self.get_rope_scaling()
+
         if load_to_meta:
             # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
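In the DeepSpeed path the override moves out of the shared model_kwargs (previous hunk) and behind the same hasattr guard, stamped onto both the already-loaded config and the kwargs used later for checkpoint loading. get_rope_scaling itself is defined elsewhere in this file; a hypothetical stand-in showing the shape of value it would have to return (not the fork's actual implementation):

    import os
    from typing import Dict, Optional

    def get_rope_scaling() -> Optional[Dict]:
        # Hypothetical: build a transformers-style rope_scaling dict from env vars.
        rope_type = os.getenv("ROPE_SCALING")  # e.g. "linear" or "dynamic"
        if rope_type is None:
            return None
        return {"type": rope_type, "factor": float(os.getenv("ROPE_FACTOR", "1.0"))}

If the helper returns None, config.rope_scaling is simply set to None, which transformers interprets as no scaling.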