109 lines
2.5 KiB
Go
109 lines
2.5 KiB
Go
|
package embedding
|
|||
|
|
|||
|
import (
|
|||
|
"context"
|
|||
|
)
|
|||
|
|
|||
|
func (s *Service) TextEmbedding(ctx context.Context, input []string) ([][]float32, error) {
|
|||
|
|
|||
|
var r = make([][]float32, len(input)-1)
|
|||
|
|
|||
|
for _, v := range input {
|
|||
|
embedding2, err := s.OpenAI.CreateEmbedding(ctx, []string{v})
|
|||
|
if err != nil {
|
|||
|
return nil, err
|
|||
|
}
|
|||
|
|
|||
|
r = append(r, embedding2[0])
|
|||
|
|
|||
|
//embedding, err := s.getCache(ctx, v)
|
|||
|
//if err != nil {
|
|||
|
// return r, err
|
|||
|
//}
|
|||
|
//
|
|||
|
//if embedding != nil {
|
|||
|
// r = append(r, embedding)
|
|||
|
// continue
|
|||
|
//} else {
|
|||
|
// embedding2, err := s.OpenAI.CreateEmbedding(ctx, []string{v})
|
|||
|
// if err != nil {
|
|||
|
// return nil, err
|
|||
|
// }
|
|||
|
//
|
|||
|
// r = append(r, embedding2[0])
|
|||
|
//
|
|||
|
// err = s.setCache(ctx, v, embedding2[0])
|
|||
|
// if err != nil {
|
|||
|
// return nil, err
|
|||
|
// }
|
|||
|
//}
|
|||
|
}
|
|||
|
|
|||
|
return r, nil
|
|||
|
}
|
|||
|
|
|||
|
//
|
|||
|
//func (s *Service) getCache(ctx context.Context, input string) ([]float32, error) {
|
|||
|
// md5Str, err := md5.Md5(input)
|
|||
|
// if err != nil {
|
|||
|
// return nil, err
|
|||
|
// }
|
|||
|
//
|
|||
|
// c, err := s.dao.WithContext(ctx).Embedding.Where(s.dao.Embedding.TextMd5.Eq(md5Str)).
|
|||
|
// Where(s.dao.Embedding.EmbeddingModel.Eq(s.config.OpenAI.EmbeddingModel)).
|
|||
|
// Count()
|
|||
|
// if c == 0 {
|
|||
|
// return nil, err
|
|||
|
// }
|
|||
|
//
|
|||
|
// first, err := s.dao.WithContext(ctx).Embedding.Where(s.dao.Embedding.TextMd5.Eq(md5Str)).
|
|||
|
// Where(s.dao.Embedding.EmbeddingModel.Eq(s.config.OpenAI.EmbeddingModel)).
|
|||
|
// First()
|
|||
|
// if err != nil {
|
|||
|
// return nil, err
|
|||
|
// }
|
|||
|
//
|
|||
|
// // byte to float32
|
|||
|
// return first.Vector, nil
|
|||
|
//}
|
|||
|
//
|
|||
|
//func (s *Service) setCache(ctx context.Context, input string, embedding []float32) error {
|
|||
|
// md5Str, err := md5.Md5(input)
|
|||
|
// if err != nil {
|
|||
|
// return err
|
|||
|
// }
|
|||
|
//
|
|||
|
// // redis 锁
|
|||
|
// var key = "lock_" + md5Str
|
|||
|
// lock, err := s.redis.Locker.Obtain(ctx, key, 3*time.Second, nil)
|
|||
|
// if errors.Is(err, redislock.ErrNotObtained) {
|
|||
|
// s.Logger.Sugar.Warnf("redis lock %s not obtained", md5Str)
|
|||
|
// } else if err != nil {
|
|||
|
// return err
|
|||
|
// }
|
|||
|
// defer func(lock *redislock.Lock, ctx context.Context) {
|
|||
|
// err := lock.Release(ctx)
|
|||
|
// if err != nil {
|
|||
|
// s.Logger.Sugar.Error(err)
|
|||
|
// }
|
|||
|
// }(lock, ctx)
|
|||
|
//
|
|||
|
// // 如果没有 cache,则设置
|
|||
|
// c, err := s.dao.WithContext(ctx).Embedding.Where(s.dao.Embedding.TextMd5.Eq(md5Str)).
|
|||
|
// Where(s.dao.Embedding.EmbeddingModel.Eq(s.config.OpenAI.EmbeddingModel)).
|
|||
|
// Count()
|
|||
|
// if err != nil {
|
|||
|
// return err
|
|||
|
// }
|
|||
|
// if c == 0 {
|
|||
|
// return s.dao.WithContext(ctx).Embedding.Create(&entity.Embedding{
|
|||
|
// Text: input,
|
|||
|
// TextMd5: md5Str,
|
|||
|
// Vector: embedding,
|
|||
|
// EmbeddingModel: s.config.OpenAI.EmbeddingModel,
|
|||
|
// })
|
|||
|
// }
|
|||
|
//
|
|||
|
// return nil
|
|||
|
//}
|