feat: support ali's llm (close #326)

This commit is contained in:
JustSong 2023-07-28 23:45:08 +08:00
parent d1b6f492b6
commit e92da7928b
7 changed files with 321 additions and 19 deletions

View File

@ -63,6 +63,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
2. 支持配置镜像以及众多第三方代理服务: 2. 支持配置镜像以及众多第三方代理服务:
+ [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj) + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)

View File

@ -156,24 +156,26 @@ const (
ChannelTypeAnthropic = 14 ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15 ChannelTypeBaidu = 15
ChannelTypeZhipu = 16 ChannelTypeZhipu = 16
ChannelTypeAli = 17
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
"", // 0 "", // 0
"https://api.openai.com", // 1 "https://api.openai.com", // 1
"https://oa.api2d.net", // 2 "https://oa.api2d.net", // 2
"", // 3 "", // 3
"https://api.closeai-proxy.xyz", // 4 "https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5 "https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6 "https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7 "https://api.ohmygpt.com", // 7
"", // 8 "", // 8
"https://api.caipacity.com", // 9 "https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10 "https://api.aiproxy.io", // 10
"", // 11 "", // 11
"https://api.api2gpt.com", // 12 "https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13 "https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14 "https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15 "https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16 "https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
} }

View File

@ -46,6 +46,8 @@ var ModelRatio = map[string]float64{
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
"qwen-plus-v1": 0.5715, // Same as above
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {

View File

@ -324,6 +324,24 @@ func init() {
Root: "chatglm_lite", Root: "chatglm_lite",
Parent: nil, Parent: nil,
}, },
{
Id: "qwen-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-v1",
Parent: nil,
},
{
Id: "qwen-plus-v1",
Object: "model",
Created: 1677649963,
OwnedBy: "ali",
Permission: permission,
Root: "qwen-plus-v1",
Parent: nil,
},
} }
openAIModelsMap = make(map[string]OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range openAIModels {

240
controller/relay-ali.go Normal file
View File

@ -0,0 +1,240 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type AliMessage struct {
User string `json:"user"`
Bot string `json:"bot"`
}
type AliInput struct {
Prompt string `json:"prompt"`
History []AliMessage `json:"history"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type AliOutput struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type AliChatResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
messages := make([]AliMessage, 0, len(request.Messages))
prompt := ""
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
if message.Role == "system" {
messages = append(messages, AliMessage{
User: message.Content,
Bot: "Okay",
})
continue
} else {
if i == len(request.Messages)-1 {
prompt = message.Content
break
}
messages = append(messages, AliMessage{
User: message.Content,
Bot: request.Messages[i+1].Content,
})
i++
}
}
return &AliChatRequest{
Model: request.Model,
Input: AliInput{
Prompt: prompt,
History: messages,
},
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
// TopP: request.TopP,
// TopK: 50,
// //Seed: 0,
// //EnableSearch: false,
//},
}
}
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []OpenAITextResponseChoice{choice},
Usage: Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
choice.FinishReason = aliResponse.Output.FinishReason
response := ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
usage.PromptTokens += aliResponse.Usage.InputTokens
usage.CompletionTokens += aliResponse.Usage.OutputTokens
usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var aliResponse AliChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@ -20,6 +20,7 @@ const (
APITypePaLM APITypePaLM
APITypeBaidu APITypeBaidu
APITypeZhipu APITypeZhipu
APITypeAli
) )
var httpClient *http.Client var httpClient *http.Client
@ -94,6 +95,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypePaLM apiType = APITypePaLM
case common.ChannelTypeZhipu: case common.ChannelTypeZhipu:
apiType = APITypeZhipu apiType = APITypeZhipu
case common.ChannelTypeAli:
apiType = APITypeAli
} }
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
@ -153,6 +157,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
method = "sse-invoke" method = "sse-invoke"
} }
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
} }
var promptTokens int var promptTokens int
var completionTokens int var completionTokens int
@ -226,6 +232,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
} }
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
case APITypeAli:
aliRequest := requestOpenAI2Ali(textRequest)
jsonStr, err := json.Marshal(aliRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} }
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
@ -250,6 +263,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case APITypeZhipu: case APITypeZhipu:
token := getZhipuToken(apiKey) token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token) req.Header.Set("Authorization", token)
case APITypeAli:
req.Header.Set("Authorization", "Bearer "+apiKey)
if textRequest.Stream {
req.Header.Set("X-DashScope-SSE", "enable")
}
} }
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@ -280,7 +298,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if strings.HasPrefix(textRequest.Model, "gpt-4") { if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2 completionRatio = 2
} }
if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu { if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu && apiType != APITypeAli {
completionTokens = countTokenText(streamResponseText, textRequest.Model) completionTokens = countTokenText(streamResponseText, textRequest.Model)
} else { } else {
promptTokens = textResponse.Usage.PromptTokens promptTokens = textResponse.Usage.PromptTokens
@ -415,6 +433,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
return nil return nil
} }
case APITypeAli:
if isStream {
err, usage := aliStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := aliHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default: default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
} }

View File

@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 2, text: '代理API2D', value: 2, color: 'blue' }, { key: 2, text: '代理API2D', value: 2, color: 'blue' },
@ -14,5 +15,5 @@ export const CHANNEL_OPTIONS = [
{ key: 6, text: '代理OpenAI Max', value: 6, color: 'violet' }, { key: 6, text: '代理OpenAI Max', value: 6, color: 'violet' },
{ key: 9, text: '代理AI.LS', value: 9, color: 'yellow' }, { key: 9, text: '代理AI.LS', value: 9, color: 'yellow' },
{ key: 12, text: '代理API2GPT', value: 12, color: 'blue' }, { key: 12, text: '代理API2GPT', value: 12, color: 'blue' },
{ key: 13, text: '代理AIGC2D', value: 13, color: 'purple' } { key: 13, text: '代理AIGC2D', value: 13, color: 'purple' },
]; ];