From 443a22b75d43d8e7e4fd1bf944accf73dd6dec82 Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 15 May 2023 10:48:52 +0800 Subject: [PATCH] feat: able to test channels now (#59) --- controller/channel.go | 99 +++++++++++++++++++++++++++++ controller/relay.go | 13 ++++ router/api-router.go | 1 + web/src/components/ChannelsTable.js | 21 +++++- 4 files changed, 133 insertions(+), 1 deletion(-) diff --git a/controller/channel.go b/controller/channel.go index 25b5d405..ed41b5ef 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,12 +1,17 @@ package controller import ( + "bytes" + "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + "time" ) func GetAllChannels(c *gin.Context) { @@ -153,3 +158,97 @@ func UpdateChannel(c *gin.Context) { }) return } + +func testChannel(channel *model.Channel, request *ChatRequest) error { + if request.Model == "" { + request.Model = "gpt-3.5-turbo" + if channel.Type == common.ChannelTypeAzure { + request.Model = "gpt-35-turbo" + } + } + requestURL := common.ChannelBaseURLs[channel.Type] + if channel.Type == common.ChannelTypeAzure { + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + } else { + if channel.Type == common.ChannelTypeCustom { + requestURL = channel.BaseURL + } + requestURL += "/v1/chat/completions" + } + + jsonData, err := json.Marshal(request) + if err != nil { + return err + } + req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + return err + } + if channel.Type == common.ChannelTypeAzure { + req.Header.Set("api-key", channel.Key) + } else { + req.Header.Set("Authorization", "Bearer "+channel.Key) + } + req.Header.Set("Content-Type", "application/json") + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + var response TextResponse + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return err + } + if response.Error.Type != "" { + return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + } + return nil +} + +func TestChannel(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + model_ := c.Query("model") + chatRequest := &ChatRequest{ + Model: model_, + } + testMessage := Message{ + Role: "user", + Content: "echo hi", + } + chatRequest.Messages = append(chatRequest.Messages, testMessage) + tik := time.Now() + err = testChannel(channel, chatRequest) + tok := time.Now() + consumedTime := float64(tok.Sub(tik).Milliseconds()) / 1000.0 + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + "time": consumedTime, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "time": consumedTime, + }) + return +} diff --git a/controller/relay.go b/controller/relay.go index 81c2dc03..8a670734 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -19,6 +19,11 @@ type Message struct { Content string `json:"content"` } +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` +} + type TextRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -32,8 +37,16 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } +type OpenAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code string `json:"code"` +} + type TextResponse struct { Usage `json:"usage"` + Error OpenAIError `json:"error"` } type StreamResponse struct { diff --git a/router/api-router.go b/router/api-router.go index 09646b9e..0e249dc9 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/:id", controller.GetChannel) + channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.POST("/", controller.AddChannel) channelRoute.PUT("/", controller.UpdateChannel) channelRoute.DELETE("/:id", controller.DeleteChannel) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index e31d48d9..c04b6d8c 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; +import { API, copy, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; @@ -139,6 +139,16 @@ const ChannelsTable = () => { setSearching(false); }; + const testChannel = async (id, name) => { + const res = await API.get(`/api/channel/test/${id}/`); + const { success, message, time } = res.data; + if (success) { + showInfo(`通道 ${name} 测试成功,耗时 ${time} 秒。`); + } else { + showError(message); + } + } + const handleKeywordChange = async (e, { value }) => { setSearchKeyword(value.trim()); }; @@ -244,6 +254,15 @@ const ChannelsTable = () => { {renderTimestamp(channel.accessed_time)}
+