Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-16 14:22:08 +00:00)
fix: read vocab size from tokenizer and add hacky patch for qwen2b
Parent: 55d82d4654
Commit: b32cd97b71

@@ -13,7 +13,7 @@
       "usage": null
     }
   ],
-  "created": 1745337456,
+  "created": 1746486174,
   "id": "",
   "model": "Qwen/Qwen2-VL-2B-Instruct",
   "object": "chat.completion",

@@ -13,7 +13,7 @@
       "usage": null
     }
   ],
-  "created": 1745337878,
+  "created": 1746486174,
   "id": "",
   "model": "Qwen/Qwen2-VL-2B-Instruct",
   "object": "chat.completion",

@@ -11,7 +11,7 @@
       "logprobs": null
     }
   ],
-  "created": 1745337495,
+  "created": 1746486174,
   "id": "",
   "model": "Qwen/Qwen2-VL-2B-Instruct",
   "object": "chat.completion.chunk",

@@ -1267,6 +1267,15 @@ class FlashCausalLM(Model):
 
         prefix = None
         model = model_class(prefix, config, weights)
 
+        if model.config.vocab_size != tokenizer.vocab_size:
+            logger.warning(
+                f"Tokenizer vocab size {tokenizer.vocab_size} does not match model vocab size {model.config.vocab_size}. Updating tokenizer vocab size."
+            )
+            # TODO: HUGE HACK! This is a workaround for the fact that Qwen2TokenizerFast
+            # returns the incorrect vocab size for the 2B model.
+            tokenizer._vocab_size = model.config.vocab_size
+
         torch.distributed.barrier(group=self.process_group)
 
         # VLM models define the config we care about in their text_config

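The hunk above monkey-patches a private `_vocab_size` attribute onto the tokenizer whenever the model config disagrees with what the tokenizer reports, so downstream consumers can size tensors by the model's real vocabulary. A minimal standalone sketch of that reconciliation, using stand-in classes and illustrative sizes rather than the real transformers objects:

    # Sketch only: DummyConfig/DummyTokenizer stand in for the real
    # transformers objects; the sizes are illustrative, not authoritative.
    class DummyConfig:
        vocab_size = 151936      # what the model's embedding table actually holds

    class DummyTokenizer:
        vocab_size = 151643      # what the tokenizer (under-)reports

    config = DummyConfig()
    tokenizer = DummyTokenizer()

    if config.vocab_size != tokenizer.vocab_size:
        # Stash the authoritative size on a private attribute; downstream
        # code reads it back with getattr(tokenizer, "_vocab_size", ...).
        tokenizer._vocab_size = config.vocab_size

    assert getattr(tokenizer, "_vocab_size", tokenizer.vocab_size) == 151936
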
@@ -641,7 +641,8 @@ class LogitBiasProcessor(LogitsProcessor):
     ):
         assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"
 
-        vocab_size = len(tokenizer)
+        # use _vocab_size or fallback to tokenizer.vocab_size if not available
+        self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size)
 
         # Convert keys to integers and values to a list
         token_ids = torch.tensor(

@@ -650,7 +651,7 @@ class LogitBiasProcessor(LogitsProcessor):
         bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
 
         # Create a tensor and directly copy bias values at the corresponding indices
-        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor = torch.zeros(self.vocab_size, dtype=torch.float)
         self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:

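For context, a self-contained sketch of what `LogitBiasProcessor` effectively computes once the bias tensor is sized by the corrected vocab size. The toy `vocab_size` and `logit_biases` values are made up, but the tensor calls mirror the ones in the two hunks above:

    import torch

    # logit_biases maps token id -> bias, as in the OpenAI-style API;
    # vocab_size would come from the getattr fallback shown above.
    vocab_size = 8
    logit_biases = {2: 5.0, 5: -100.0}

    token_ids = torch.tensor(list(logit_biases.keys()), dtype=torch.long)
    bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)

    # Dense bias vector sized to the *model's* vocab, so adding it to the
    # logits cannot go out of bounds when the tokenizer under-reports.
    bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
    bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

    scores = torch.zeros(1, vocab_size)   # fake logits for one sequence
    scores = scores + bias_tensor         # what __call__ effectively does
    print(scores)                         # bias applied at ids 2 and 5
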
@@ -669,10 +670,13 @@ class HeterogeneousLogitBiasProcessor(LogitsProcessor):
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"
+
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        # import ipdb; ipdb.set_trace()
-        self.vocab_size = len(tokenizer)
+        # use _vocab_size or fallback to tokenizer.vocab_size if not available
+        self.vocab_size = getattr(tokenizer, "_vocab_size", tokenizer.vocab_size)
 
         # Create batch_size x vocab_size bias matrix
         self.bias_matrix = torch.zeros(
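The heterogeneous variant builds one bias row per sequence in the batch, again sized by the corrected vocab size. A hedged sketch of that batch construction, with made-up sizes and per-request biases (the real class also moves the matrix to `device`):

    import torch

    # One entry per sequence in the batch; None means no biases requested.
    vocab_size = 8
    logit_biases = [{1: 2.0}, None, {3: -1.0, 4: 0.5}]

    bias_matrix = torch.zeros(len(logit_biases), vocab_size, dtype=torch.float)
    for row, biases in enumerate(logit_biases):
        if biases:                        # skip sequences without biases
            ids = torch.tensor(list(biases.keys()), dtype=torch.long)
            vals = torch.tensor(list(biases.values()), dtype=torch.float)
            bias_matrix[row].index_put_((ids,), vals, accumulate=True)

    scores = torch.zeros(len(logit_biases), vocab_size)  # fake batch logits
    scores = scores + bias_matrix                        # row-wise bias add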