feat: return reasonable generation and add integration test

drbh 2024-09-02 19:20:49 +00:00 committed by Daniël de Kok
parent dff1b9f795
commit 1fb9d406e7
3 changed files with 92 additions and 13 deletions

View File

@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "I'keeper services don't have real-time capabilities, however, I can guide you on how to find current weather conditions in Brooklyn, New York.\n\nTo get the most accurate",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1725304474,
+  "id": "",
+  "model": "microsoft/Phi-3.5-MoE-instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.2.1-dev0-native",
+  "usage": {
+    "completion_tokens": 40,
+    "prompt_tokens": 31,
+    "total_tokens": 71
+  }
+}
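
For context, the integration test below asserts its response against this stored snapshot. A minimal sketch of what that comparison amounts to, assuming a plain JSON-on-disk snapshot (the helper names and path here are illustrative, not the actual harness API):

import json
from pathlib import Path

# Hypothetical snapshot location; the real harness resolves this itself.
SNAPSHOT_PATH = Path("test_flash_llama_simple.json")

def load_snapshot(path: Path) -> dict:
    # Read the stored snapshot from disk.
    return json.loads(path.read_text())

def matches_snapshot(response: dict, snapshot: dict) -> bool:
    # A real harness would likely normalize volatile fields such as
    # "created" before comparing; this sketch compares verbatim.
    return response == snapshot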

View File

@@ -0,0 +1,42 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_chat_handle(launcher):
+    with launcher(
+        "microsoft/Phi-3.5-MoE-instruct",
+        num_shard=4,
+        cuda_graphs=[1, 2],
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_chat(flash_llama_chat_handle):
+    await flash_llama_chat_handle.health(300)
+    return flash_llama_chat_handle.client
+
+
+@pytest.mark.private
+async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
+    response = await flash_llama_chat.chat(
+        max_tokens=40,
+        seed=1337,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Brooklyn, New York?",
+            },
+        ],
+    )
+    print(repr(response.choices[0].message.content))
+
+    assert (
+        response.choices[0].message.content
+        == "I'keeper services don't have real-time capabilities, however, I can guide you on how to find current weather conditions in Brooklyn, New York.\n\nTo get the most accurate"
+    )
+    assert response == response_snapshot
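
Outside the test harness, the same request can be reproduced by hand. A sketch assuming a text-generation-inference server is already running locally on port 8080 (the URL and port are assumptions, not part of this diff) and serving its OpenAI-compatible chat completions route:

import requests

resp = requests.post(
    "http://localhost:8080/v1/chat/completions",  # assumed local endpoint
    json={
        "model": "microsoft/Phi-3.5-MoE-instruct",
        "max_tokens": 40,
        "seed": 1337,  # same seed as the test, for a reproducible sample
        "messages": [
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])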

View File

@@ -46,6 +46,7 @@ from text_generation_server.layers import (
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
+    FastLayerNorm,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -456,23 +457,33 @@ class FlashLlamaLayer(nn.Module):
             weights=weights,
         )

-        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
-        if self.use_moe:
+        if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
             self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            # with moe the layernorms are not rmsnorms and they have bias
+            self.input_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.post_attention_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
         else:
             self.dense = LlamaMLP(
-                prefix=f"{prefix}.mlp", config=config, weights=weights
+                prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
             )
+            self.input_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.post_attention_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
-        )
-        self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )

     def forward(
         self,
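
The heart of this change: per the diff's own comment, Phi-3.5-MoE checkpoints use classic LayerNorm (mean subtraction plus a bias term) where Llama-style models use bias-free RMSNorm, so the norm loading moves into each branch. An illustrative comparison of the two normalizations in plain PyTorch, not TGI's fused FastLayerNorm/FastRMSNorm kernels:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm (Llama path): rescale by the root mean square only;
    # no mean subtraction and no bias.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))

def layer_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5
) -> torch.Tensor:
    # LayerNorm (Phi-3.5-MoE path): subtract the mean, normalize by the
    # standard deviation, then apply the learned scale and bias.
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return weight * ((x - mean) * torch.rsqrt(var + eps)) + bias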