mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-30 11:50:19 +00:00
fix(server): Fixing T5 in case the names are mixed up. (#475)
This commit is contained in:
parent
53aa9194c8
commit
c9c65ab323
@ -1001,7 +1001,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model_dim = config.d_model
|
self.model_dim = config.d_model
|
||||||
|
|
||||||
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
try:
|
||||||
|
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
|
||||||
|
except RuntimeError:
|
||||||
|
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
|
||||||
|
|
||||||
encoder_config = copy.deepcopy(config)
|
encoder_config = copy.deepcopy(config)
|
||||||
encoder_config.is_decoder = False
|
encoder_config.is_decoder = False
|
||||||
|
Loading…
Reference in New Issue
Block a user