diff --git a/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_llama_simple.json b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_llama_simple.json
new file mode 100644
index 00000000..41e5f02d
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_llama_simple.json
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "I'keeper services don't have real-time capabilities, however, I can guide you on how to find current weather conditions in Brooklyn, New York.\n\nTo get the most accurate",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1725304474,
+  "id": "",
+  "model": "microsoft/Phi-3.5-MoE-instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.2.1-dev0-native",
+  "usage": {
+    "completion_tokens": 40,
+    "prompt_tokens": 31,
+    "total_tokens": 71
+  }
+}
diff --git a/integration-tests/models/test_flash_phi35_moe.py b/integration-tests/models/test_flash_phi35_moe.py
new file mode 100644
index 00000000..e3a9eff3
--- /dev/null
+++ b/integration-tests/models/test_flash_phi35_moe.py
@@ -0,0 +1,42 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_chat_handle(launcher):
+    with launcher(
+        "microsoft/Phi-3.5-MoE-instruct",
+        num_shard=4,
+        cuda_graphs=[1, 2]
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_chat(flash_llama_chat_handle):
+    await flash_llama_chat_handle.health(300)
+    return flash_llama_chat_handle.client
+
+
+@pytest.mark.private
+async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
+    response = await flash_llama_chat.chat(
+        max_tokens=40,
+        seed=1337,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Brooklyn, New York?",
+            },
+        ],
+    )
+
+    print(repr(response.choices[0].message.content))
+    assert (
+        response.choices[0].message.content
+        == "I'keeper services don't have real-time capabilities, however, I can guide you on how to find current weather conditions in Brooklyn, New York.\n\nTo get the most accurate"
+    )
+    assert response == response_snapshot
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index d6eb8080..522d9b43 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -46,6 +46,7 @@ from text_generation_server.layers import (
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
+    FastLayerNorm,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -456,23 +457,33 @@ class FlashLlamaLayer(nn.Module):
             weights=weights,
         )
 
-        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
-
-        if self.use_moe:
+        if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
             self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            # with moe the layernorms are not rmsnorms and they have bias
+            self.input_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.post_attention_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
         else:
             self.dense = LlamaMLP(
-                prefix=f"{prefix}.mlp", config=config, weights=weights
+                prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
+            )
+            self.input_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.post_attention_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
             )
-
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
-        )
-        self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
 
     def forward(
         self,
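
A note on the modeling change above: per the in-code comment, Phi-3.5-MoE's checkpoints store classic layer norms with a learned bias rather than RMS norms, which is why the MoE branch loads FastLayerNorm instead of FastRMSNorm and why the layernorm loading moves inside the branch. The following is a minimal sketch of that distinction in plain PyTorch; it is illustrative only, not TGI's fused FastLayerNorm/FastRMSNorm kernels, and the shapes and eps value are assumptions:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm: rescale by the root mean square over the hidden dimension.
    # No mean subtraction and no bias -- the checkpoint stores a weight only.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Classic LayerNorm: subtract the mean as well, then apply weight *and* bias,
    # so the loader must read a bias tensor in addition to the weight.
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return (x - mean) * torch.rsqrt(var + eps) * weight + bias

# Usage sketch (hidden size chosen arbitrarily):
# h = torch.randn(2, 4096)
# w, b = torch.ones(4096), torch.zeros(4096)
# print(rms_norm(h, w).shape, layer_norm(h, w, b).shape)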