Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00
fix: setting the rotary base from the config for the grouped query models.

parent 46ada47963
commit c41573c67c
@@ -481,6 +481,7 @@ class FlashRWLayer(nn.Module):
         return mlp_output, residual


 class FlashRWLayerNorm(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -602,7 +603,7 @@ class FlashRWLargeLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"

         self.ln_layer = FlashRWLayerNorm(config, prefix, weights)

         self.self_attention = FlashRWLargeAttention(
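
The modified lines themselves are cut off in the rendered diff above, but the commit title indicates that the rotary embedding base for the grouped-query RW/Falcon models is now taken from the model config rather than being fixed. A minimal self-contained sketch of that idea follows; the attribute name `rope_theta` and the 10000.0 fallback are illustrative assumptions, not values confirmed by this diff:

import torch

def rotary_inv_freq(head_size: int, config, default_base: float = 10000.0) -> torch.Tensor:
    # Read the rotary base from the config when present; `rope_theta` is an
    # assumed attribute name, since the truncated diff does not show the real one.
    base = getattr(config, "rope_theta", default_base)
    # Standard RoPE inverse frequencies: base^(-(2i)/d) over the even dimensions.
    exponents = torch.arange(0, head_size, 2, dtype=torch.float32) / head_size
    return 1.0 / (base ** exponents)

With a config that defines the base attribute, the rotary frequencies scale accordingly; otherwise the conventional RoPE default is used, matching the behavior a hard-coded base would have produced.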