diff --git a/controller/channel-test.go b/controller/channel-test.go
index 908ab669..ec865b23 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -67,8 +67,8 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
 func buildTestRequest(c *gin.Context) *ChatRequest {
     model_ := c.Query("model")
     testRequest := &ChatRequest{
-        Model: model_,
-        //MaxTokens: 1,
+        Model:     model_,
+        MaxTokens: 1,
     }
     testMessage := Message{
         Role: "user",
diff --git a/controller/relay-palm.go b/controller/relay-palm.go
index a4b75432..ae739ca0 100644
--- a/controller/relay-palm.go
+++ b/controller/relay-palm.go
@@ -51,7 +51,7 @@ func relayPaLM(openAIRequest GeneralOpenAIRequest, c *gin.Context) *OpenAIErrorW
         Temperature:    openAIRequest.Temperature,
         CandidateCount: openAIRequest.N,
         TopP:           openAIRequest.TopP,
-        //TopK: openAIRequest.MaxTokens,
+        TopK:           openAIRequest.MaxTokens,
     }
     // TODO: forward request to PaLM & convert response
     fmt.Print(request)
diff --git a/controller/relay.go b/controller/relay.go
index f25e6bea..6e248d67 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -30,28 +30,28 @@ const (
 
 // https://platform.openai.com/docs/api-reference/chat
 
 type GeneralOpenAIRequest struct {
-    Model       string    `json:"model"`
-    Messages    []Message `json:"messages"`
-    Prompt      any       `json:"prompt"`
-    Stream      bool      `json:"stream"`
-    //MaxTokens int `json:"max_tokens"`
-    Temperature float64 `json:"temperature"`
-    TopP        float64 `json:"top_p"`
-    N           int     `json:"n"`
-    Input       any     `json:"input"`
+    Model       string    `json:"model"`
+    Messages    []Message `json:"messages"`
+    Prompt      any       `json:"prompt"`
+    Stream      bool      `json:"stream"`
+    MaxTokens   int       `json:"max_tokens"`
+    Temperature float64   `json:"temperature"`
+    TopP        float64   `json:"top_p"`
+    N           int       `json:"n"`
+    Input       any       `json:"input"`
 }
 
 type ChatRequest struct {
-    Model    string    `json:"model"`
-    Messages []Message `json:"messages"`
-    //MaxTokens int `json:"max_tokens"`
+    Model     string    `json:"model"`
+    Messages  []Message `json:"messages"`
+    MaxTokens int       `json:"max_tokens"`
 }
 
 type TextRequest struct {
-    Model    string    `json:"model"`
-    Messages []Message `json:"messages"`
-    Prompt   string    `json:"prompt"`
-    //MaxTokens int `json:"max_tokens"`
+    Model     string    `json:"model"`
+    Messages  []Message `json:"messages"`
+    Prompt    string    `json:"prompt"`
+    MaxTokens int       `json:"max_tokens"`
     //Stream bool `json:"stream"`
 }
@@ -183,7 +183,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
             err := relayPaLM(textRequest, c)
             return err
         } else {
-            // 强制使用0613模型
+            // 强制使用 0613 模型
             textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0301")
             textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0314")
             textRequest.Model = strings.TrimSuffix(textRequest.Model, "-0613")
@@ -199,9 +199,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
         promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
     }
     preConsumedTokens := common.PreConsumedQuota
-    //if textRequest.MaxTokens != 0 {
-    //	preConsumedTokens = promptTokens + textRequest.MaxTokens
-    //}
+    if textRequest.MaxTokens != 0 {
+        preConsumedTokens = promptTokens + textRequest.MaxTokens
+    }
     modelRatio := common.GetModelRatio(textRequest.Model)
     groupRatio := common.GetGroupRatio(group)
     ratio := modelRatio * groupRatio