Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Fix (flash) Gemma prefix and enable tests
commit 9231098f3a
parent d32e33bd48
integration-tests/models/test_flash_gemma.py

@@ -3,7 +3,7 @@ import pytest
 
 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
 
 
@@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot
 
 
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
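Taken together, these test hunks re-enable the flash Gemma integration tests (the @pytest.mark.skip markers are dropped) and point the launcher at the public google/gemma-2b Hub repository instead of the pre-release gg-hf staging one. For context, the tests follow the repository's module-scoped fixture pattern: one server is launched per test module and each async test drives the shared client. A minimal sketch of that pattern, assuming the launcher and response_snapshot fixtures from the repo's conftest; the prompt and parameters below are illustrative:

import pytest


@pytest.fixture(scope="module")
def flash_gemma_handle(launcher):
    # One server process for the whole module; torn down after the last test.
    with launcher("google/gemma-2b", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma(flash_gemma_handle):
    # Wait (up to 300s) for the server to report healthy before any test runs.
    await flash_gemma_handle.health(300)
    return flash_gemma_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
    # Illustrative request; the real tests pin outputs to stored snapshots.
    response = await flash_gemma.generate("Test request", max_new_tokens=10)
    assert response == response_snapshot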
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py

@@ -423,7 +423,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
-        if prefix is None:
+        if not prefix:
             prefix = "model"
         else:
             prefix = f"{prefix}.model"
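This modeling change is the heart of the fix: the loader below now passes an empty string as the prefix, and under the old `is None` check an empty string fell through to the else branch, producing the malformed weight prefix ".model". A minimal sketch of the resolution logic after this commit (the standalone helper name resolve_prefix is illustrative, not part of the codebase):

from typing import Optional


def resolve_prefix(prefix: Optional[str]) -> str:
    # Standalone Gemma: tensors live under "model.*". Treating "" like None
    # is the fix; the old `prefix is None` check mapped "" to ".model".
    if not prefix:
        return "model"
    # Nested use (e.g. as the text tower of a wrapper model): one level deeper.
    return f"{prefix}.model"


assert resolve_prefix(None) == "model"
assert resolve_prefix("") == "model"  # previously produced the broken ".model"
assert resolve_prefix("language_model") == "language_model.model"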
server/text_generation_server/models/flash_gemma.py

@@ -57,7 +57,7 @@ class FlashGemma(FlashCausalLM):
         weights._set_gptq_params(model_id, revision)
 
         # TODO hardcoded
-        prefix = "language_model"
+        prefix = ""
         model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
 
         torch.distributed.barrier(group=self.process_group)
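With the modeling fix above in place, hardcoding prefix = "" here makes the standalone loader resolve tensors under "model.*", while the nested spelling stays available to wrapper models. A hypothetical pair of call sites, to illustrate the resulting tensor names (the wrapper usage is not part of this commit):

# Standalone loader (this file): names resolve as "model.embed_tokens.weight",
# "model.layers.0...", etc.
model = FlashGemmaForCausalLM("", config, weights, causal=True)

# Hypothetical wrapper that stores its text tower under "language_model":
# names resolve as "language_model.model.layers.0...", which is what the old
# hardcoded prefix assumed for every checkpoint.
text_model = FlashGemmaForCausalLM("language_model", config, weights, causal=True)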