mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
pre-commit
This commit is contained in:
parent
98e790e32a
commit
f5a6691d0e
@ -1,7 +1,7 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def bloom_560_handle(launcher):
    """Launch a bigscience/bloom-560m server and yield its handle to the test.

    The launcher returns a context manager that owns the server's lifetime;
    yielding from inside the ``with`` keeps the server running for the
    duration of the consuming test and tears it down afterwards.
    """
    server_ctx = launcher("bigscience/bloom-560m")
    with server_ctx as server_handle:
        yield server_handle
|
||||
@ -12,6 +12,7 @@ async def bloom_560(bloom_560_handle):
|
||||
await bloom_560_handle.health(240)
|
||||
return bloom_560_handle.client
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
async def test_bloom_560m(bloom_560):
|
||||
@ -20,10 +21,14 @@ async def test_bloom_560m(bloom_560):
|
||||
prompt = "The cat sat on the mat. The cat"
|
||||
|
||||
repeated_2grams_control = await call_model(base_url, prompt, 0)
|
||||
assert len(repeated_2grams_control) > 0, "Expected to find repeated bi-grams in control case"
|
||||
assert (
|
||||
len(repeated_2grams_control) > 0
|
||||
), "Expected to find repeated bi-grams in control case"
|
||||
|
||||
repeated_2grams_test = await call_model(base_url, prompt, 2)
|
||||
assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"
|
||||
assert (
|
||||
len(repeated_2grams_test) == 0
|
||||
), f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"
|
||||
|
||||
|
||||
async def call_model(base_url, prompt, n_grams):
|
||||
@ -33,14 +38,14 @@ async def call_model(base_url, prompt, n_grams):
|
||||
"max_new_tokens": 20,
|
||||
"seed": 42,
|
||||
"no_repeat_ngram_size": n_grams,
|
||||
"details": True
|
||||
}
|
||||
"details": True,
|
||||
},
|
||||
}
|
||||
res = requests.post(f"{base_url}/generate", json=data)
|
||||
res = res.json()
|
||||
|
||||
tokens = res['details']['tokens']
|
||||
token_texts = [token['text'] for token in tokens]
|
||||
tokens = res["details"]["tokens"]
|
||||
token_texts = [token["text"] for token in tokens]
|
||||
|
||||
# find repeated 2grams
|
||||
ngrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 2 + 1)]
|
||||
@ -54,4 +59,3 @@ async def call_model(base_url, prompt, n_grams):
|
||||
repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]
|
||||
|
||||
return repeated
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user