pre-commit

This commit is contained in:
erikkaum 2024-08-01 13:37:45 +02:00
parent 98e790e32a
commit f5a6691d0e

View File

@ -1,7 +1,7 @@
from typing import List
import pytest import pytest
import requests import requests
def bloom_560_handle(launcher): def bloom_560_handle(launcher):
with launcher("bigscience/bloom-560m") as handle: with launcher("bigscience/bloom-560m") as handle:
yield handle yield handle
@ -12,6 +12,7 @@ async def bloom_560(bloom_560_handle):
await bloom_560_handle.health(240) await bloom_560_handle.health(240)
return bloom_560_handle.client return bloom_560_handle.client
@pytest.mark.release @pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m(bloom_560): async def test_bloom_560m(bloom_560):
@ -20,30 +21,34 @@ async def test_bloom_560m(bloom_560):
prompt = "The cat sat on the mat. The cat" prompt = "The cat sat on the mat. The cat"
repeated_2grams_control = await call_model(base_url, prompt, 0) repeated_2grams_control = await call_model(base_url, prompt, 0)
assert len(repeated_2grams_control) > 0, "Expected to find repeated bi-grams in control case" assert (
len(repeated_2grams_control) > 0
), "Expected to find repeated bi-grams in control case"
repeated_2grams_test = await call_model(base_url, prompt, 2) repeated_2grams_test = await call_model(base_url, prompt, 2)
assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}" assert (
len(repeated_2grams_test) == 0
), f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"
async def call_model(base_url, prompt, n_grams): async def call_model(base_url, prompt, n_grams):
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters" : { "parameters": {
"max_new_tokens": 20, "max_new_tokens": 20,
"seed": 42, "seed": 42,
"no_repeat_ngram_size": n_grams, "no_repeat_ngram_size": n_grams,
"details": True "details": True,
} },
} }
res = requests.post(f"{base_url}/generate", json=data) res = requests.post(f"{base_url}/generate", json=data)
res = res.json() res = res.json()
tokens = res['details']['tokens'] tokens = res["details"]["tokens"]
token_texts = [token['text'] for token in tokens] token_texts = [token["text"] for token in tokens]
# find repeated 2grams # find repeated 2grams
ngrams = [tuple(token_texts[i:i+2]) for i in range(len(token_texts)-2+1)] ngrams = [tuple(token_texts[i : i + 2]) for i in range(len(token_texts) - 2 + 1)]
ngram_counts = {} ngram_counts = {}
for ngram in ngrams: for ngram in ngrams:
if ngram in ngram_counts: if ngram in ngram_counts:
@ -54,4 +59,3 @@ async def call_model(base_url, prompt, n_grams):
repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1] repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]
return repeated return repeated