recommender/internal/service/user/suggest.go
2024-11-10 03:49:53 +08:00

71 lines
1.7 KiB
Go

package user
import (
"context"
"fmt"
"github.com/milvus-io/milvus-sdk-go/v2/client"
entity2 "github.com/milvus-io/milvus-sdk-go/v2/entity"
"leafdev.top/Ecosystem/recommender/internal/entity"
)
// SuggestPosts recommends posts for an external user within a category.
//
// The user's Summary is embedded via s.embedding, and the resulting vector is
// used for an L2 similarity search against the configured Milvus post
// collection, filtered to the user's application and the given category. The
// post_id values of the hits are then loaded from the database via s.dao.Post.
//
// An empty (non-nil) slice is returned when the search yields no hits or the
// hits carry no int64 post_id column. Errors from embedding, search-parameter
// construction, the Milvus search, result decoding, or the database query are
// returned wrapped with context.
//
// NOTE(review): the database lookup uses an IN clause, so the returned posts
// are in database order rather than similarity order — confirm whether callers
// rely on ranking.
func (s *Service) SuggestPosts(c context.Context, externalUserEntity *entity.ExternalUser, categoryEntity *entity.Category) ([]*entity.Post, error) {
	emb, err := s.embedding.TextEmbedding(c, []string{externalUserEntity.Summary})
	if err != nil {
		return nil, fmt.Errorf("embedding user summary: %w", err)
	}
	// Indexing emb[0] below would panic on an empty result.
	if len(emb) == 0 {
		return nil, fmt.Errorf("embedding user summary: no embedding returned")
	}

	// %v formats the category id correctly for integer ids (%s would emit
	// "%!s(uint=N)", an invalid Milvus expression) and identically for strings.
	filter := fmt.Sprintf("application_id == %d && category_id == %v", externalUserEntity.ApplicationId, categoryEntity.Id)

	sp, err := entity2.NewIndexAUTOINDEXSearchParam(1)
	if err != nil {
		return nil, fmt.Errorf("building search params: %w", err)
	}

	postResults, err := s.milvus.Search(c, s.config.Milvus.PostCollection,
		[]string{},
		filter,
		[]string{"post_id", "category_id"},
		[]entity2.Vector{entity2.FloatVector(emb[0])},
		"vector",
		entity2.L2,
		3,
		sp, client.WithLimit(7))
	if err != nil {
		return nil, fmt.Errorf("searching posts in milvus: %w", err)
	}

	var ids []uint
	for _, res := range postResults {
		// No hits for this query vector.
		if res.ResultCount == 0 {
			continue
		}

		var postIDColumn *entity2.ColumnInt64
		for _, field := range res.Fields {
			if field.Name() != "post_id" {
				continue
			}
			if col, ok := field.(*entity2.ColumnInt64); ok {
				postIDColumn = col
			}
		}
		// post_id is absent or not int64: nothing usable in this result set.
		if postIDColumn == nil {
			continue
		}

		for i := 0; i < res.ResultCount; i++ {
			id, err := postIDColumn.ValueByIdx(i)
			if err != nil {
				return nil, fmt.Errorf("reading post_id at index %d: %w", i, err)
			}
			ids = append(ids, uint(id))
		}
	}

	// Nothing matched: skip the database round trip entirely.
	if len(ids) == 0 {
		return make([]*entity.Post, 0), nil
	}

	posts, err := s.dao.Post.WithContext(c).Where(s.dao.Post.Id.In(ids...)).Find()
	if err != nil {
		return nil, fmt.Errorf("loading suggested posts: %w", err)
	}
	return posts, nil
}