Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
Using both values from the config as they might not be correct.
This commit is contained in:
parent a2d878fa0f
commit b91f0c02c6
@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
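
For context, the new cap takes the minimum of three limits: the KV-cache capacity (num_blocks * BLOCK_SIZE), the tokenizer's model_max_length, and the config's max_position_embeddings, since the tokenizer value can be an unreliable sentinel. Below is a minimal standalone sketch of that min() computation; every number is a hypothetical placeholder chosen for illustration, not taken from the repository.

# Sketch of the capping logic introduced by this commit.
# All values are assumed placeholders, not real TGI defaults.
BLOCK_SIZE = 16                 # tokens per KV-cache block (assumed)
num_blocks = 8192               # KV-cache blocks available (assumed)
model_max_length = 10**30       # tokenizer.model_max_length may be a huge sentinel
max_position_embeddings = 4096  # config.max_position_embeddings (assumed)

max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings)
print(max_total_tokens)  # -> 4096 with the placeholder values above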