2023-11-29 08:07:09 +00:00
|
|
|
|
package openai
|
2023-11-28 10:32:26 +00:00
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bufio"
|
|
|
|
|
"bytes"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
"net/http"
|
|
|
|
|
"one-api/common"
|
|
|
|
|
"one-api/types"
|
|
|
|
|
"strings"
|
|
|
|
|
|
2023-11-29 08:07:09 +00:00
|
|
|
|
"one-api/providers/base"
|
|
|
|
|
|
2023-11-28 10:32:26 +00:00
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
)
|
|
|
|
|
|
2023-12-02 10:14:48 +00:00
|
|
|
|
type OpenAIProviderFactory struct{}
|
|
|
|
|
|
|
|
|
|
// 创建 OpenAIProvider
|
|
|
|
|
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
|
|
|
|
return CreateOpenAIProvider(c, "")
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-28 10:32:26 +00:00
|
|
|
|
type OpenAIProvider struct {
|
2023-11-29 08:07:09 +00:00
|
|
|
|
base.BaseProvider
|
|
|
|
|
IsAzure bool
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 创建 OpenAIProvider
|
2023-11-29 08:07:09 +00:00
|
|
|
|
// https://platform.openai.com/docs/api-reference/introduction
|
2023-11-28 10:32:26 +00:00
|
|
|
|
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
|
|
|
|
if baseURL == "" {
|
|
|
|
|
baseURL = "https://api.openai.com"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return &OpenAIProvider{
|
2023-11-29 08:07:09 +00:00
|
|
|
|
BaseProvider: base.BaseProvider{
|
2023-11-28 10:32:26 +00:00
|
|
|
|
BaseURL: baseURL,
|
|
|
|
|
Completions: "/v1/completions",
|
|
|
|
|
ChatCompletions: "/v1/chat/completions",
|
|
|
|
|
Embeddings: "/v1/embeddings",
|
2023-11-29 08:54:37 +00:00
|
|
|
|
Moderation: "/v1/moderations",
|
2023-11-28 10:32:26 +00:00
|
|
|
|
AudioSpeech: "/v1/audio/speech",
|
|
|
|
|
AudioTranscriptions: "/v1/audio/transcriptions",
|
|
|
|
|
AudioTranslations: "/v1/audio/translations",
|
2023-12-01 09:20:22 +00:00
|
|
|
|
ImagesGenerations: "/v1/images/generations",
|
2023-12-01 10:25:05 +00:00
|
|
|
|
ImagesEdit: "/v1/images/edits",
|
2023-12-01 09:20:22 +00:00
|
|
|
|
ImagesVariations: "/v1/images/variations",
|
2023-11-28 10:32:26 +00:00
|
|
|
|
Context: c,
|
|
|
|
|
},
|
2023-11-29 08:07:09 +00:00
|
|
|
|
IsAzure: false,
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取完整请求 URL
|
|
|
|
|
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
|
|
|
|
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
|
|
|
|
|
2023-11-29 08:07:09 +00:00
|
|
|
|
if p.IsAzure {
|
2023-12-26 08:40:50 +00:00
|
|
|
|
apiVersion := p.Channel.Other
|
2023-12-01 09:20:22 +00:00
|
|
|
|
if modelName == "dall-e-2" {
|
|
|
|
|
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
|
|
|
|
// 已经没有dall-e-2了,所以暂时写死
|
|
|
|
|
requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL)
|
|
|
|
|
} else {
|
|
|
|
|
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
|
|
|
|
}
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
2023-11-29 08:07:09 +00:00
|
|
|
|
if p.IsAzure {
|
2023-11-28 10:32:26 +00:00
|
|
|
|
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
|
|
|
|
|
} else {
|
|
|
|
|
requestURL = strings.TrimPrefix(requestURL, "/v1")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取请求头
|
|
|
|
|
func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
|
|
|
|
headers = make(map[string]string)
|
2023-11-29 08:07:09 +00:00
|
|
|
|
p.CommonRequestHeaders(headers)
|
|
|
|
|
if p.IsAzure {
|
2023-12-26 08:40:50 +00:00
|
|
|
|
headers["api-key"] = p.Channel.Key
|
2023-11-28 10:32:26 +00:00
|
|
|
|
} else {
|
2023-12-26 08:40:50 +00:00
|
|
|
|
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Channel.Key)
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return headers
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取请求体
|
2023-12-01 09:20:22 +00:00
|
|
|
|
func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
|
2023-11-28 10:32:26 +00:00
|
|
|
|
if isModelMapped {
|
|
|
|
|
jsonStr, err := json.Marshal(request)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
requestBody = bytes.NewBuffer(jsonStr)
|
|
|
|
|
} else {
|
|
|
|
|
requestBody = p.Context.Request.Body
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
2023-11-29 08:07:09 +00:00
|
|
|
|
// 发送流式请求
|
2023-12-29 07:23:05 +00:00
|
|
|
|
func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
2023-12-01 02:54:07 +00:00
|
|
|
|
defer req.Body.Close()
|
2023-11-28 10:32:26 +00:00
|
|
|
|
|
2023-12-26 10:42:39 +00:00
|
|
|
|
client := common.GetHttpClient(p.Channel.Proxy)
|
|
|
|
|
resp, err := client.Do(req)
|
2023-11-28 10:32:26 +00:00
|
|
|
|
if err != nil {
|
2023-12-01 19:28:18 +00:00
|
|
|
|
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
2023-12-26 10:42:39 +00:00
|
|
|
|
common.PutHttpClient(client)
|
2023-11-28 10:32:26 +00:00
|
|
|
|
|
|
|
|
|
if common.IsFailureStatusCode(resp) {
|
2023-11-30 05:49:35 +00:00
|
|
|
|
return common.HandleErrorResp(resp), ""
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
|
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
|
|
if atEOF && len(data) == 0 {
|
|
|
|
|
return 0, nil, nil
|
|
|
|
|
}
|
|
|
|
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
|
|
|
return i + 1, data[0:i], nil
|
|
|
|
|
}
|
|
|
|
|
if atEOF {
|
|
|
|
|
return len(data), data, nil
|
|
|
|
|
}
|
|
|
|
|
return 0, nil, nil
|
|
|
|
|
})
|
|
|
|
|
dataChan := make(chan string)
|
|
|
|
|
stopChan := make(chan bool)
|
|
|
|
|
go func() {
|
|
|
|
|
for scanner.Scan() {
|
|
|
|
|
data := scanner.Text()
|
|
|
|
|
if len(data) < 6 { // ignore blank line or wrong format
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
dataChan <- data
|
|
|
|
|
data = data[6:]
|
|
|
|
|
if !strings.HasPrefix(data, "[DONE]") {
|
|
|
|
|
err := json.Unmarshal([]byte(data), response)
|
|
|
|
|
if err != nil {
|
|
|
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
|
|
continue // just ignore the error
|
|
|
|
|
}
|
2023-11-29 08:07:09 +00:00
|
|
|
|
responseText += response.responseStreamHandler()
|
2023-11-28 10:32:26 +00:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
stopChan <- true
|
|
|
|
|
}()
|
2023-11-29 08:07:09 +00:00
|
|
|
|
common.SetEventStreamHeaders(p.Context)
|
2023-11-28 10:32:26 +00:00
|
|
|
|
p.Context.Stream(func(w io.Writer) bool {
|
|
|
|
|
select {
|
|
|
|
|
case data := <-dataChan:
|
|
|
|
|
if strings.HasPrefix(data, "data: [DONE]") {
|
|
|
|
|
data = data[:12]
|
|
|
|
|
}
|
|
|
|
|
// some implementations may add \r at the end of data
|
|
|
|
|
data = strings.TrimSuffix(data, "\r")
|
|
|
|
|
p.Context.Render(-1, common.CustomEvent{Data: data})
|
|
|
|
|
return true
|
|
|
|
|
case <-stopChan:
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return nil, responseText
|
|
|
|
|
}
|