Fix the bug

This commit is contained in:
Sadra Barikbin 2024-08-07 23:48:15 +03:30
parent 133015f408
commit 22d9249c4a
2 changed files with 28 additions and 12 deletions

View File

@ -0,0 +1,19 @@
import pytest
@pytest.fixture(scope="module")
def opt_sharded_handle(launcher):
with launcher("facebook/opt-6.7b", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def opt_sharded(opt_sharded_handle):
await opt_sharded_handle.health(300)
return opt_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio
async def test_opt(opt_sharded):
pass

View File

@ -98,7 +98,7 @@ class OPTLearnedPositionalEmbedding(nn.Module):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
weights.get_tensor(f"{prefix and prefix + '.'}decoder.embed_positions.weight")
)
def forward(
@ -315,7 +315,7 @@ class OPTDecoderLayer(nn.Module):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"{prefix}.decoder.layers.{layer_id}"
prefix = f"{prefix and prefix + '.'}decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
@ -437,15 +437,17 @@ class OPTDecoder(OPTPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
prefix = prefix and prefix + '.'
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
prefix=f"{prefix}decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config,
prefix=f"{prefix}.decoder.project_out",
prefix=f"{prefix}decoder.project_out",
weights=weights,
bias=False,
)
@ -455,7 +457,7 @@ class OPTDecoder(OPTPreTrainedModel):
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config,
prefix=f"{prefix}.decoder.project_in",
prefix=f"{prefix}decoder.project_in",
weights=weights,
bias=False,
)
@ -467,7 +469,7 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
@ -752,15 +754,10 @@ class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, prefix, config, weights):
super().__init__(config)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
config, prefix=f"{prefix and prefix + '.'}decoder.embed_tokens", weights=weights
)
def forward(