Using both values from the config, as either one alone might not be correct.

Nicolas Patry 2024-12-10 10:53:33 +01:00
parent a2d878fa0f
commit b91f0c02c6


@@ -1304,6 +1304,7 @@ class FlashCausalLM(Model):
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ class FlashCausalLM(Model):
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
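
For context, a minimal sketch (not the repository's exact code) of the bound this change introduces. HuggingFace tokenizers fall back to a huge sentinel for model_max_length when the tokenizer config leaves it unset, so also feeding config.max_position_embeddings into the min() keeps max_total_tokens within the model's real context window. BLOCK_SIZE and the sample numbers below are assumptions for illustration only.

    # Minimal sketch of the new clamping logic, not the repository's exact code.
    BLOCK_SIZE = 16  # assumed KV-cache block size

    def compute_max_total_tokens(num_blocks: int, model_max_length: int,
                                 max_position_embeddings: int) -> int:
        # Take the tightest of the three limits: available KV-cache capacity,
        # the tokenizer's declared maximum, and the model's context window.
        return min(num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings)

    # An unset tokenizer limit defaults to a huge sentinel (int(1e30) in
    # transformers), which would otherwise never constrain the min().
    print(compute_max_total_tokens(8192, int(1e30), 4096))  # -> 4096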