Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
feat: return reasonable generation and add integration test
commit 1fb9d406e7 (parent dff1b9f795)
@@ -0,0 +1,26 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "I'keeper services don't have real-time capabilities, however, I can guide you on how to find current weather conditions in Brooklyn, New York.\n\nTo get the most accurate",
+        "name": null,
+        "role": "assistant",
+        "tool_calls": null
+      },
+      "usage": null
+    }
+  ],
+  "created": 1725304474,
+  "id": "",
+  "model": "microsoft/Phi-3.5-MoE-instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.2.1-dev0-native",
+  "usage": {
+    "completion_tokens": 40,
+    "prompt_tokens": 31,
+    "total_tokens": 71
+  }
+}
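This snapshot pins the full chat-completion payload, so the integration test below can compare the live response against it structurally. A minimal sketch of the idea, using a hypothetical helper rather than the actual response_snapshot fixture:

import json

def matches_snapshot(response: dict, snapshot_path: str) -> bool:
    # Load the stored payload and compare it to the live response; any
    # drift in fields like finish_reason or content fails the check.
    # Hypothetical helper for illustration, not the fixture the test uses.
    with open(snapshot_path) as f:
        expected = json.load(f)
    return response == expected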
integration-tests/models/test_flash_phi35_moe.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import pytest
+
+
+@pytest.fixture(scope="module")
+def flash_llama_chat_handle(launcher):
+    with launcher(
+        "microsoft/Phi-3.5-MoE-instruct",
+        num_shard=4,
+        cuda_graphs=[1, 2]
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_chat(flash_llama_chat_handle):
+    await flash_llama_chat_handle.health(300)
+    return flash_llama_chat_handle.client
+
+
+@pytest.mark.private
+async def test_flash_llama_simple(flash_llama_chat, response_snapshot):
+    response = await flash_llama_chat.chat(
+        max_tokens=40,
+        seed=1337,
+        messages=[
+            {
+                "role": "system",
+                "content": "Youre a helpful assistant! Answer the users question best you can.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather like in Brooklyn, New York?",
+            },
+        ],
+    )
+
+    print(repr(response.choices[0].message.content))
+    assert (
+        response.choices[0].message.content
+        == "I'keeper services don't have real-time capabilities, however, I can guide you on how to find current weather conditions in Brooklyn, New York.\n\nTo get the most accurate"
+    )
+    assert response == response_snapshot
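The chat call above goes through TGI's OpenAI-compatible /v1/chat/completions route. As a rough standalone sketch of the same request, assuming a server already listening on localhost:8080 and the requests library (both assumptions, not part of this diff):

import requests

# Same parameters the test passes via the client fixture, sent directly
# to the OpenAI-compatible chat endpoint; host and port are assumed.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "microsoft/Phi-3.5-MoE-instruct",
        "max_tokens": 40,
        "seed": 1337,
        "messages": [
            {"role": "system", "content": "Youre a helpful assistant! Answer the users question best you can."},
            {"role": "user", "content": "What is the weather like in Brooklyn, New York?"},
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])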
@@ -46,6 +46,7 @@ from text_generation_server.layers import (
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
+    FastLayerNorm,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -456,23 +457,33 @@ class FlashLlamaLayer(nn.Module):
             weights=weights,
         )
 
-        self.use_moe = config._name_or_path == "microsoft/Phi-3.5-MoE-instruct"
-
-        if self.use_moe:
+        if config._name_or_path == "microsoft/Phi-3.5-MoE-instruct":
             self.dense = BlockSparseMoE(f"{prefix}.block_sparse_moe", config, weights)
+            # with moe the layernorms are not rmsnorms and they have bias
+            self.input_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.post_attention_layernorm = FastLayerNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
         else:
             self.dense = LlamaMLP(
-                prefix=f"{prefix}.mlp", config=config, weights=weights
+                prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
             )
-
-        self.input_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
-        )
-        self.post_attention_layernorm = FastRMSNorm.load(
-            prefix=f"{prefix}.post_attention_layernorm",
-            weights=weights,
-            eps=config.rms_norm_eps,
-        )
+            self.input_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.input_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
+            self.post_attention_layernorm = FastRMSNorm.load(
+                prefix=f"{prefix}.post_attention_layernorm",
+                weights=weights,
+                eps=config.rms_norm_eps,
+            )
 
     def forward(
         self,
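The comment in the hunk captures the key difference: the Phi-3.5-MoE checkpoint ships classic layernorm weights with a bias term, while the Llama-style path uses RMSNorm, which skips mean subtraction and has no bias. A minimal sketch of the two normalizations in plain PyTorch (standalone illustrations, not the FastLayerNorm/FastRMSNorm kernels loaded here):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm: scale by the root mean square; no mean subtraction, no bias.
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * (x * rms)

def layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Classic LayerNorm: center on the mean, normalize by the standard
    # deviation, then apply the learned scale and bias.
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return weight * ((x - mean) * torch.rsqrt(var + eps)) + bias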