feat: Handle errors, validate model names, and calculate quota usage (#978)

- Improved error handling in various modules for better stability and responsiveness.
- Optimized code in several files for improved efficiency and readability.
- Enhanced user experience by providing more detailed error responses in the controller.
- Strengthened security by ignoring sensitive files in `.gitignore`.
This commit is contained in:
Laisky.Cai 2024-02-12 21:35:40 +08:00 committed by GitHub
parent 2cd1a78203
commit d548a01c59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 24 additions and 18 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ build
*.db-journal *.db-journal
logs logs
data data
/web/node_modules

View File

@ -15,10 +15,7 @@ type embedFileSystem struct {
func (e embedFileSystem) Exists(prefix string, path string) bool { func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path) _, err := e.Open(path)
if err != nil { return err == nil
return false
}
return true
} }
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {

View File

@ -107,13 +107,13 @@ func Seconds2Time(num int) (time string) {
} }
func Interface2String(inter interface{}) string { func Interface2String(inter interface{}) string {
switch inter.(type) { switch inter := inter.(type) {
case string: case string:
return inter.(string) return inter
case int: case int:
return fmt.Sprintf("%d", inter.(int)) return fmt.Sprintf("%d", inter)
case float64: case float64:
return fmt.Sprintf("%f", inter.(float64)) return fmt.Sprintf("%f", inter)
} }
return "Not Implemented" return "Not Implemented"
} }

View File

@ -68,15 +68,15 @@ func Error(ctx context.Context, msg string) {
} }
func Infof(ctx context.Context, format string, a ...any) { func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a)) Info(ctx, fmt.Sprintf(format, a...))
} }
func Warnf(ctx context.Context, format string, a ...any) { func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a)) Warn(ctx, fmt.Sprintf(format, a...))
} }
func Errorf(ctx context.Context, format string, a ...any) { func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a)) Error(ctx, fmt.Sprintf(format, a...))
} }
func logHelper(ctx context.Context, level string, msg string) { func logHelper(ctx context.Context, level string, msg string) {

View File

@ -22,7 +22,9 @@ func GetSubscription(c *gin.Context) {
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId) remainQuota, err = model.GetUserQuota(userId)
usedQuota, err = model.GetUserUsedQuota(userId) if err != nil {
usedQuota, err = model.GetUserUsedQuota(userId)
}
} }
if expiredTime <= 0 { if expiredTime <= 0 {
expiredTime = 0 expiredTime = 0

View File

@ -90,7 +90,7 @@ func testChannel(channel *model.Channel, request openai.ChatRequest) (err error,
if response.Error.Message == "" { if response.Error.Message == "" {
response.Error.Message = "补全 tokens 非预期返回 0" response.Error.Message = "补全 tokens 非预期返回 0"
} }
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error return fmt.Errorf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message), &response.Error
} }
return nil, nil return nil, nil
} }

View File

@ -8,7 +8,7 @@ import (
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio { for groupName := range common.GroupRatio {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -189,5 +189,8 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }

View File

@ -27,7 +27,7 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
for model, _ := range common.ModelRatio { for model := range common.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {

View File

@ -191,6 +191,9 @@ func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode,
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse) _, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage return nil, &fullTextResponse.Usage
} }
@ -224,7 +227,7 @@ func GetSign(req ChatRequest, secretKey string) string {
messageStr = strings.TrimSuffix(messageStr, ",") messageStr = strings.TrimSuffix(messageStr, ",")
params = append(params, "messages=["+messageStr+"]") params = append(params, "messages=["+messageStr+"]")
sort.Sort(sort.StringSlice(params)) sort.Strings(params)
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
mac := hmac.New(sha1.New, []byte(secretKey)) mac := hmac.New(sha1.New, []byte(secretKey))
signURL := url signURL := url

View File

@ -84,7 +84,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
} }
// Number of generated images validation // Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false { if !isWithinRange(imageModel, imageRequest.N) {
// channel not azure // channel not azure
if channelType != common.ChannelTypeAzure { if channelType != common.ChannelTypeAzure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)