diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json index c1cd24cd..3b683b43 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json @@ -2,7 +2,7 @@ "details": { "best_of_sequences": null, "finish_reason": "eos_token", - "generated_tokens": 5, + "generated_tokens": 7, "prefill": [ { "id": 0, @@ -10,29 +10,41 @@ "text": "" } ], - "seed": 0, + "seed": 1, "tokens": [ { - "id": 926, - "logprob": -4.3554688, + "id": 609, + "logprob": -4.1875, "special": false, - "text": " To" + "text": " it" }, { - "id": 18295, - "logprob": -7.7734375, + "id": 259, + "logprob": -1.9609375, "special": false, - "text": " sell" + "text": " " }, { - "id": 7868, - "logprob": -3.9257812, + "id": 277, + "logprob": -3.15625, "special": false, - "text": " things" + "text": "'" + }, + { + "id": 263, + "logprob": 0.0, + "special": false, + "text": "s" + }, + { + "id": 16017, + "logprob": -2.0, + "special": false, + "text": " blue" }, { "id": 260, - "logprob": -2.4179688, + "logprob": -1.9638672, "special": false, "text": "." }, @@ -42,7 +54,8 @@ "special": true, "text": "" } - ] + ], + "top_tokens": null }, - "generated_text": "To sell things." + "generated_text": "it's blue." } diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json index 024823d0..dd090c2b 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json @@ -10,11 +10,11 @@ "text": "" } ], - "seed": 0, + "seed": 1, "tokens": [ { "id": 16017, - "logprob": -0.30908203, + "logprob": -0.3112793, "special": false, "text": " blue" }, @@ -24,39 +24,39 @@ "special": false, "text": " sky" }, + { + "id": 305, + "logprob": -1.8779297, + "special": false, + "text": " and" + }, { "id": 259, - "logprob": -0.28271484, + "logprob": -0.91503906, "special": false, "text": " " }, { - "id": 15484, - "logprob": -1.7929688, + "id": 262, + "logprob": -1.1533203, "special": false, - "text": "appear" + "text": "a" }, { - "id": 345, - "logprob": -0.8935547, + "id": 35622, + "logprob": -0.47705078, "special": false, - "text": "ed" + "text": " cloud" }, { - "id": 281, + "id": 276, "logprob": 0.0, "special": false, - "text": " in" - }, - { - "id": 287, - "logprob": 0.0, - "special": false, - "text": " the" + "text": "y" }, { "id": 20495, - "logprob": -0.32299805, + "logprob": 0.0, "special": false, "text": " sky" }, @@ -66,7 +66,8 @@ "special": true, "text": "" } - ] + ], + "top_tokens": null }, - "generated_text": "Why is the sky blue?blue sky appeared in the sky" + "generated_text": "Why is the sky blue?blue sky and a cloudy sky" } diff --git a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json index c0834ae1..5887b193 100644 --- a/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json +++ b/integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json @@ -49,7 +49,8 @@ "special": true, "text": "" } - ] + ], + "top_tokens": null }, "generated_text": "Because it is blue" }, @@ -103,7 +104,8 @@ "special": true, "text": "" } - ] + ], + "top_tokens": null }, "generated_text": "Because it is blue" }, @@ -157,7 +159,8 @@ "special": true, "text": "" } - ] + ], + "top_tokens": null }, "generated_text": "Because it is blue" }, @@ -211,7 +214,8 @@ "special": true, "text": "" } - ] + ], + "top_tokens": null }, "generated_text": "Because it is blue" } diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py index d167b813..a13ec3bf 100644 --- a/integration-tests/models/test_flash_gemma.py +++ b/integration-tests/models/test_flash_gemma.py @@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle): return flash_gemma_handle.client -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma(flash_gemma, response_snapshot): @@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot): assert response == response_snapshot -@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_all_params(flash_gemma, response_snapshot): @@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot): assert response == response_snapshot -# @pytest.mark.skip @pytest.mark.asyncio @pytest.mark.private async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 777a55ba..203379fa 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -550,7 +550,7 @@ class Seq2SeqLM(Model): revision=revision, torch_dtype=dtype, device_map=( - "auto" + device if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None ),