// ai-gateway/providers/openai/base.go

package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/types"
"strings"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
2023-12-02 10:14:48 +00:00
type OpenAIProviderFactory struct{}
// 创建 OpenAIProvider
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
return CreateOpenAIProvider(c, "")
}
2023-11-28 10:32:26 +00:00
type OpenAIProvider struct {
base.BaseProvider
IsAzure bool
2023-11-28 10:32:26 +00:00
}
// 创建 OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction
2023-11-28 10:32:26 +00:00
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
if baseURL == "" {
baseURL = "https://api.openai.com"
}
return &OpenAIProvider{
BaseProvider: base.BaseProvider{
2023-11-28 10:32:26 +00:00
BaseURL: baseURL,
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Moderation: "/v1/moderations",
2023-11-28 10:32:26 +00:00
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",
2023-12-01 09:20:22 +00:00
ImagesGenerations: "/v1/images/generations",
ImagesEdit: "/v1/images/edits",
2023-12-01 09:20:22 +00:00
ImagesVariations: "/v1/images/variations",
2023-11-28 10:32:26 +00:00
Context: c,
},
IsAzure: false,
2023-11-28 10:32:26 +00:00
}
}
// 获取完整请求 URL
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
if p.IsAzure {
2023-11-28 10:32:26 +00:00
apiVersion := p.Context.GetString("api_version")
2023-12-01 09:20:22 +00:00
if modelName == "dall-e-2" {
// 因为dall-e-3需要api-version=2023-12-01-preview但是该版本
// 已经没有dall-e-2了所以暂时写死
requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL)
} else {
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
}
2023-11-28 10:32:26 +00:00
}
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
if p.IsAzure {
2023-11-28 10:32:26 +00:00
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
} else {
requestURL = strings.TrimPrefix(requestURL, "/v1")
}
}
return fmt.Sprintf("%s%s", baseURL, requestURL)
}
// 获取请求头
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
if p.IsAzure {
2023-11-28 10:32:26 +00:00
headers["api-key"] = p.Context.GetString("api_key")
} else {
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key"))
}
return headers
}
// 获取请求体
2023-12-01 09:20:22 +00:00
func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
2023-11-28 10:32:26 +00:00
if isModelMapped {
jsonStr, err := json.Marshal(request)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = p.Context.Request.Body
}
return
}
// 发送流式请求
2023-11-28 10:32:26 +00:00
func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
2023-12-01 02:54:07 +00:00
defer req.Body.Close()
2023-11-28 10:32:26 +00:00
resp, err := common.HttpClient.Do(req)
if err != nil {
2023-12-01 19:28:18 +00:00
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
2023-11-28 10:32:26 +00:00
}
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
2023-11-28 10:32:26 +00:00
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
err := json.Unmarshal([]byte(data), response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
responseText += response.responseStreamHandler()
2023-11-28 10:32:26 +00:00
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
2023-11-28 10:32:26 +00:00
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
p.Context.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
return nil, responseText
}