feat: able to test channels now (#59)

This commit is contained in:
JustSong 2023-05-15 10:48:52 +08:00
parent b44f0519a0
commit 443a22b75d
4 changed files with 133 additions and 1 deletions

View File

@ -1,12 +1,17 @@
package controller package controller
import ( import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
"time"
) )
func GetAllChannels(c *gin.Context) { func GetAllChannels(c *gin.Context) {
@ -153,3 +158,97 @@ func UpdateChannel(c *gin.Context) {
}) })
return return
} }
func testChannel(channel *model.Channel, request *ChatRequest) error {
if request.Model == "" {
request.Model = "gpt-3.5-turbo"
if channel.Type == common.ChannelTypeAzure {
request.Model = "gpt-35-turbo"
}
}
requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.Type == common.ChannelTypeCustom {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err
}
if response.Error.Type != "" {
return errors.New(fmt.Sprintf("type %s, code %s, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
}
return nil
}
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
model_ := c.Query("model")
chatRequest := &ChatRequest{
Model: model_,
}
testMessage := Message{
Role: "user",
Content: "echo hi",
}
chatRequest.Messages = append(chatRequest.Messages, testMessage)
tik := time.Now()
err = testChannel(channel, chatRequest)
tok := time.Now()
consumedTime := float64(tok.Sub(tik).Milliseconds()) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
})
return
}

View File

@ -19,6 +19,11 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
} }
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
}
type TextRequest struct { type TextRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@ -32,8 +37,16 @@ type Usage struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
} }
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code string `json:"code"`
}
type TextResponse struct { type TextResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
Error OpenAIError `json:"error"`
} }
type StreamResponse struct { type StreamResponse struct {

View File

@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.POST("/", controller.AddChannel) channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel) channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel) channelRoute.DELETE("/:id", controller.DeleteChannel)

View File

@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { API, copy, showError, showSuccess, timestamp2string } from '../helpers'; import { API, copy, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
@ -139,6 +139,16 @@ const ChannelsTable = () => {
setSearching(false); setSearching(false);
}; };
const testChannel = async (id, name) => {
const res = await API.get(`/api/channel/test/${id}/`);
const { success, message, time } = res.data;
if (success) {
showInfo(`通道 ${name} 测试成功,耗时 ${time} 秒。`);
} else {
showError(message);
}
}
const handleKeywordChange = async (e, { value }) => { const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim()); setSearchKeyword(value.trim());
}; };
@ -244,6 +254,15 @@ const ChannelsTable = () => {
<Table.Cell>{renderTimestamp(channel.accessed_time)}</Table.Cell> <Table.Cell>{renderTimestamp(channel.accessed_time)}</Table.Cell>
<Table.Cell> <Table.Cell>
<div> <div>
<Button
size={'small'}
positive
onClick={() => {
testChannel(channel.id, channel.name);
}}
>
测试
</Button>
<Popup <Popup
trigger={ trigger={
<Button size='small' negative> <Button size='small' negative>