From 3e1c44839b7bd4c428e289f744e7e5606ce99630 Mon Sep 17 00:00:00 2001 From: ivamp Date: Sun, 10 Nov 2024 03:49:53 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E5=9F=BA=E6=9C=AC=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/create-collection.go | 218 ++++++++++++++++++ cmd/schedule.go | 23 +- cmd/wire.go | 2 + cmd/wire_gen.go | 18 +- cmd/worker.go | 5 + configs/config.yaml | 6 +- docs/docs.go | 97 ++++++-- docs/swagger.json | 97 ++++++-- docs/swagger.yaml | 78 +++++-- internal/base/app.go | 4 + internal/base/conf/conf.go | 14 +- internal/base/milvus/provide.go | 6 +- internal/dao/post_tags.gen.go | 9 - internal/dao/posts.gen.go | 6 +- internal/dao/tag_mappings.gen.go | 9 - internal/dao/tags.gen.go | 94 +------- internal/dao/user_tag_scores.gen.go | 9 - internal/entity/Post.go | 1 + internal/entity/Tag.go | 9 +- .../http/controller/application_v1/posts.go | 14 +- .../http/controller/application_v1/users.go | 96 +++++++- internal/handler/http/request/applications.go | 5 + internal/migrations/1_setup.sql | 15 +- internal/router/api.go | 1 + internal/service/category/categories.go | 9 + internal/service/embedding/provider.go | 10 +- internal/service/post/posts.go | 20 +- internal/service/post/provider.go | 20 +- internal/service/post/tags.go | 87 ++++++- internal/service/provider.go | 5 + internal/service/user/external.go | 14 ++ internal/service/user/posts.go | 14 +- internal/service/user/provider.go | 12 + internal/service/user/suggest.go | 70 ++++++ internal/service/user/summary.go | 59 +++++ internal/service/user/tags.go | 22 ++ pkg/consts/category.go | 8 + pkg/consts/posts.go | 5 +- 38 files changed, 945 insertions(+), 246 deletions(-) create mode 100644 cmd/create-collection.go create mode 100644 internal/service/user/suggest.go create mode 100644 internal/service/user/summary.go create mode 100644 pkg/consts/category.go diff --git a/cmd/create-collection.go 
b/cmd/create-collection.go new file mode 100644 index 0000000..2b16c95 --- /dev/null +++ b/cmd/create-collection.go @@ -0,0 +1,218 @@ +package cmd + +import ( + "context" + "github.com/milvus-io/milvus-sdk-go/v2/entity" + "github.com/spf13/cobra" + "leafdev.top/Ecosystem/recommender/internal/base" +) + +func init() { + RootCmd.AddCommand(createCollectionCmd) + createCollectionCmd.Flags().String("dim", "768", "模型的维度") + // 将 dim 参数标记为必填 + err := createCollectionCmd.MarkFlagRequired("dim") + if err != nil { + panic(err) + } +} + +var createCollectionCmd = &cobra.Command{ + Use: "create-collection", + Run: func(cmd *cobra.Command, args []string) { + app, err := CreateApp() + if err != nil { + panic(err) + return + } + // 获取 flag + dim, err := cmd.Flags().GetString("dim") + if err != nil { + panic(err) + } + + createMilvusTagCollection(app, dim) + createMilvusUserSummaryCollection(app, dim) + createMilvusPostCollection(app, dim) + }, +} + +func createMilvusTagCollection(app *base.Application, dim string) { + var ctx = context.Background() + var field = []*entity.Field{ + { + Name: "tag_id", + PrimaryKey: true, + AutoID: false, + DataType: entity.FieldTypeInt64, + }, + { + Name: "vector", + PrimaryKey: false, + DataType: entity.FieldTypeFloatVector, + TypeParams: map[string]string{ + "dim": dim, + }, + }, + } + + var schema = &entity.Schema{ + CollectionName: app.Config.Milvus.TagCollection, + Description: "", + AutoID: true, + Fields: field, + EnableDynamicField: true, + } + + err := app.Milvus.CreateCollection(ctx, schema, 2) + if err != nil { + panic(err) + } + + index := entity.NewGenericIndex("idx_tag_id", entity.Inverted, map[string]string{}) + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.TagCollection, "tag_id", index, false) + if err != nil { + panic(err) + } + + index, err = entity.NewIndexAUTOINDEX(entity.L2) + if err != nil { + panic(err) + } + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.TagCollection, "vector", index, false) + if err != nil { 
+ panic(err) + } +} + +func createMilvusUserSummaryCollection(app *base.Application, dim string) { + var ctx = context.Background() + var field = []*entity.Field{ + { + Name: "external_user_id", + PrimaryKey: true, + AutoID: false, + DataType: entity.FieldTypeInt64, + }, + { + Name: "application_id", + PrimaryKey: false, + AutoID: false, + DataType: entity.FieldTypeInt64, + }, + { + Name: "vector", + PrimaryKey: false, + DataType: entity.FieldTypeFloatVector, + TypeParams: map[string]string{ + "dim": dim, + }, + }, + } + + var schema = &entity.Schema{ + CollectionName: app.Config.Milvus.UserSummaryCollection, + Description: "", + AutoID: true, + Fields: field, + EnableDynamicField: true, + } + + err := app.Milvus.CreateCollection(ctx, schema, 2) + if err != nil { + panic(err) + } + + index := entity.NewGenericIndex("idx_external_user_id", entity.Inverted, map[string]string{}) + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.UserSummaryCollection, "external_user_id", index, false) + if err != nil { + panic(err) + } + + index = entity.NewGenericIndex("idx_application_id", entity.Inverted, map[string]string{}) + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.UserSummaryCollection, "application_id", index, false) + if err != nil { + panic(err) + } + + index, err = entity.NewIndexAUTOINDEX(entity.L2) + if err != nil { + panic(err) + } + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.UserSummaryCollection, "vector", index, false) + if err != nil { + panic(err) + } +} + +func createMilvusPostCollection(app *base.Application, dim string) { + var ctx = context.Background() + var field = []*entity.Field{ + { + Name: "post_id", + PrimaryKey: true, + AutoID: false, + DataType: entity.FieldTypeInt64, + }, + { + Name: "category_id", + PrimaryKey: false, + AutoID: false, + DataType: entity.FieldTypeInt64, + }, + { + Name: "application_id", + PrimaryKey: false, + AutoID: false, + DataType: entity.FieldTypeInt64, + }, + { + Name: "vector", + PrimaryKey: false, + 
DataType: entity.FieldTypeFloatVector, + TypeParams: map[string]string{ + "dim": dim, + }, + }, + } + + var schema = &entity.Schema{ + CollectionName: app.Config.Milvus.PostCollection, + Description: "", + AutoID: true, + Fields: field, + EnableDynamicField: true, + } + + err := app.Milvus.CreateCollection(ctx, schema, 2) + if err != nil { + panic(err) + } + + index := entity.NewGenericIndex("idx_post_id", entity.Inverted, map[string]string{}) + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.PostCollection, "post_id", index, false) + if err != nil { + panic(err) + } + + index = entity.NewGenericIndex("idx_category_id", entity.Inverted, map[string]string{}) + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.PostCollection, "category_id", index, false) + if err != nil { + panic(err) + } + + index = entity.NewGenericIndex("idx_application_id", entity.Inverted, map[string]string{}) + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.PostCollection, "application_id", index, false) + if err != nil { + panic(err) + } + + index, err = entity.NewIndexAUTOINDEX(entity.L2) + if err != nil { + panic(err) + } + err = app.Milvus.CreateIndex(ctx, app.Config.Milvus.PostCollection, "vector", index, false) + if err != nil { + panic(err) + } +} diff --git a/cmd/schedule.go b/cmd/schedule.go index ea9f51c..a8ea004 100644 --- a/cmd/schedule.go +++ b/cmd/schedule.go @@ -1,7 +1,9 @@ package cmd import ( + "context" "leafdev.top/Ecosystem/recommender/internal/base" + "sync" "github.com/spf13/cobra" ) @@ -25,15 +27,22 @@ var scheduleCmd = &cobra.Command{ } func runSchedule(app *base.Application) { - // var wg sync.WaitGroup - // var ctx = context.Background() + var wg sync.WaitGroup - // wg.Add(1) - // // 启动一个定时器 - // go func() { + var ctx = context.Background() - // }() + wg.Add(1) + // 启动一个定时器 + go func() { - // wg.Wait() + // defer cancel() + + // run embedding + + ctx.Done() + defer wg.Done() + }() + + wg.Wait() } diff --git a/cmd/wire.go b/cmd/wire.go index 24f27fa..b29679f 
100644 --- a/cmd/wire.go +++ b/cmd/wire.go @@ -7,6 +7,7 @@ import ( "leafdev.top/Ecosystem/recommender/internal/base" "leafdev.top/Ecosystem/recommender/internal/base/conf" "leafdev.top/Ecosystem/recommender/internal/base/logger" + "leafdev.top/Ecosystem/recommender/internal/base/milvus" "leafdev.top/Ecosystem/recommender/internal/base/orm" "leafdev.top/Ecosystem/recommender/internal/base/redis" "leafdev.top/Ecosystem/recommender/internal/base/s3" @@ -28,6 +29,7 @@ var ProviderSet = wire.NewSet( redis.NewRedis, s3.NewS3, batch.NewBatch, + milvus.NewService, service.Provider, handler.ProviderSet, router.ProviderSetRouter, diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index b41b9e0..0069697 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -11,6 +11,7 @@ import ( "leafdev.top/Ecosystem/recommender/internal/base" "leafdev.top/Ecosystem/recommender/internal/base/conf" "leafdev.top/Ecosystem/recommender/internal/base/logger" + "leafdev.top/Ecosystem/recommender/internal/base/milvus" "leafdev.top/Ecosystem/recommender/internal/base/orm" "leafdev.top/Ecosystem/recommender/internal/base/redis" "leafdev.top/Ecosystem/recommender/internal/base/s3" @@ -30,6 +31,7 @@ import ( "leafdev.top/Ecosystem/recommender/internal/service/application" "leafdev.top/Ecosystem/recommender/internal/service/auth" "leafdev.top/Ecosystem/recommender/internal/service/category" + "leafdev.top/Ecosystem/recommender/internal/service/embedding" "leafdev.top/Ecosystem/recommender/internal/service/jwks" "leafdev.top/Ecosystem/recommender/internal/service/post" "leafdev.top/Ecosystem/recommender/internal/service/stream" @@ -49,13 +51,15 @@ func CreateApp() (*base.Application, error) { applicationController := v1.NewApplicationController(authService, applicationService) application_v1ApplicationController := application_v1.NewApplicationController(authService, applicationService) streamService := stream.NewService(config) - postService := post.NewService(query, config, streamService) + client := 
milvus.NewService(config, loggerLogger) + redisRedis := redis.NewRedis(config) + embeddingService := embedding.NewService(config, loggerLogger, query, redisRedis) + postService := post.NewService(query, config, streamService, client, embeddingService) categoryService := category.NewService(query) postController := application_v1.NewPostController(authService, applicationService, postService, categoryService) categoryController := application_v1.NewCategoryController(authService, applicationService, postService, categoryService) - userService := user.NewService(query, postService, loggerLogger) - redisRedis := redis.NewRedis(config) - userController := application_v1.NewUserController(authService, applicationService, userService, postService, loggerLogger, redisRedis) + userService := user.NewService(query, postService, loggerLogger, client, embeddingService, config) + userController := application_v1.NewUserController(authService, applicationService, userService, postService, loggerLogger, redisRedis, categoryService) handlers := http.NewHandler(applicationController, application_v1ApplicationController, postController, categoryController, userController) api := router.NewApiRoute(handlers) swaggerRouter := router.NewSwaggerRoute() @@ -71,13 +75,13 @@ func CreateApp() (*base.Application, error) { grpcInterceptor := grpc.NewInterceptor(interceptorAuth, interceptorLogger) grpcHandlers := grpc.NewHandler(documentService, grpcInterceptor) handlerHandler := handler.NewHandler(grpcHandlers, handlers) - serviceService := service.NewService(loggerLogger, jwksJWKS, streamService, authService, applicationService, postService, categoryService, userService) + serviceService := service.NewService(loggerLogger, jwksJWKS, streamService, authService, applicationService, postService, categoryService, userService, embeddingService) batchBatch := batch.NewBatch(loggerLogger) s3S3 := s3.NewS3(config) - baseApplication := base.NewApplication(config, httpServer, handlerHandler, 
loggerLogger, serviceService, httpMiddleware, redisRedis, batchBatch, s3S3, db, query) + baseApplication := base.NewApplication(config, httpServer, handlerHandler, loggerLogger, serviceService, httpMiddleware, redisRedis, batchBatch, s3S3, db, query, client) return baseApplication, nil } // wire.go: -var ProviderSet = wire.NewSet(conf.ProviderConfig, logger.NewZapLogger, orm.NewGORM, dao.NewQuery, redis.NewRedis, s3.NewS3, batch.NewBatch, service.Provider, handler.ProviderSet, router.ProviderSetRouter, server.NewHTTPServer, base.NewApplication) +var ProviderSet = wire.NewSet(conf.ProviderConfig, logger.NewZapLogger, orm.NewGORM, dao.NewQuery, redis.NewRedis, s3.NewS3, batch.NewBatch, milvus.NewService, service.Provider, handler.ProviderSet, router.ProviderSetRouter, server.NewHTTPServer, base.NewApplication) diff --git a/cmd/worker.go b/cmd/worker.go index 1880510..4a6ca2b 100644 --- a/cmd/worker.go +++ b/cmd/worker.go @@ -81,6 +81,11 @@ func runWorker(app *base.Application) { } if processError == nil { + err = app.Service.Post.SavePostEmbedding(ctx, postEntity) + if err != nil { + app.Logger.Sugar.Error(err) + } + err = app.Service.Post.MarkAsProcessed(ctx, postEntity) if err != nil { app.Logger.Sugar.Error(err) diff --git a/configs/config.yaml b/configs/config.yaml index cb1935e..0eebb4c 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -44,10 +44,12 @@ s3: milvus: host: 127.0.0.1 port: 19530 - db_name: library + db_name: recommender # 由于 Milvus 不支持新增列, 如果更换了 Embedding Model,建议新建一个 Collection # 或者可以扩展张量 - document_collection: documents + user_summary_collection: user_summaries + tag_collection: tags + post_collection: posts user: password: diff --git a/docs/docs.go b/docs/docs.go index 259706d..5fa1cd7 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -718,34 +718,16 @@ const docTemplate = `{ "summary": "Dislike", "parameters": [ { - "description": "UserLikePost", - "name": "UserLikePost", + "description": "UserDislikePost", + "name": "UserDislikePost", 
"in": "body", "required": true, "schema": { - "$ref": "#/definitions/request.UserLikePost" + "$ref": "#/definitions/request.UserDislikePost" } } ], "responses": { - "200": { - "description": "OK", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/response.ResponseBody" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/entity.Category" - } - } - } - ] - } - }, "400": { "description": "Bad Request", "schema": { @@ -784,6 +766,45 @@ const docTemplate = `{ } } ], + "responses": { + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/response.ResponseBody" + } + } + } + } + }, + "/applications/v1/users/_suggest": { + "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "推荐资源", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "application_api" + ], + "summary": "Suggest", + "parameters": [ + { + "description": "UserSuggestsRequest", + "name": "UserSuggestsRequest", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/request.UserSuggestsRequest" + } + } + ], "responses": { "200": { "description": "OK", @@ -796,7 +817,10 @@ const docTemplate = `{ "type": "object", "properties": { "data": { - "$ref": "#/definitions/entity.Category" + "type": "array", + "items": { + "$ref": "#/definitions/entity.Post" + } } } } @@ -909,6 +933,9 @@ const docTemplate = `{ }, "updated_at": { "type": "string" + }, + "vectorized": { + "type": "boolean" } } }, @@ -1014,6 +1041,17 @@ const docTemplate = `{ } } }, + "request.UserDislikePost": { + "type": "object", + "properties": { + "external_user_id": { + "type": "string" + }, + "post_id": { + "type": "integer" + } + } + }, "request.UserLikePost": { "type": "object", "properties": { @@ -1025,6 +1063,21 @@ const docTemplate = `{ } } }, + "request.UserSuggestsRequest": { + "type": "object", + "required": [ + "category_id", + "external_user_id" + ], + "properties": { + "category_id": { + 
"type": "integer" + }, + "external_user_id": { + "type": "string" + } + } + }, "response.ResponseBody": { "type": "object", "properties": { diff --git a/docs/swagger.json b/docs/swagger.json index 71f0963..6abffb4 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -709,34 +709,16 @@ "summary": "Dislike", "parameters": [ { - "description": "UserLikePost", - "name": "UserLikePost", + "description": "UserDislikePost", + "name": "UserDislikePost", "in": "body", "required": true, "schema": { - "$ref": "#/definitions/request.UserLikePost" + "$ref": "#/definitions/request.UserDislikePost" } } ], "responses": { - "200": { - "description": "OK", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/response.ResponseBody" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/entity.Category" - } - } - } - ] - } - }, "400": { "description": "Bad Request", "schema": { @@ -775,6 +757,45 @@ } } ], + "responses": { + "400": { + "description": "Bad Request", + "schema": { + "$ref": "#/definitions/response.ResponseBody" + } + } + } + } + }, + "/applications/v1/users/_suggest": { + "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "推荐资源", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "application_api" + ], + "summary": "Suggest", + "parameters": [ + { + "description": "UserSuggestsRequest", + "name": "UserSuggestsRequest", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/request.UserSuggestsRequest" + } + } + ], "responses": { "200": { "description": "OK", @@ -787,7 +808,10 @@ "type": "object", "properties": { "data": { - "$ref": "#/definitions/entity.Category" + "type": "array", + "items": { + "$ref": "#/definitions/entity.Post" + } } } } @@ -900,6 +924,9 @@ }, "updated_at": { "type": "string" + }, + "vectorized": { + "type": "boolean" } } }, @@ -1005,6 +1032,17 @@ } } }, + "request.UserDislikePost": { + "type": "object", + "properties": { + 
"external_user_id": { + "type": "string" + }, + "post_id": { + "type": "integer" + } + } + }, "request.UserLikePost": { "type": "object", "properties": { @@ -1016,6 +1054,21 @@ } } }, + "request.UserSuggestsRequest": { + "type": "object", + "required": [ + "category_id", + "external_user_id" + ], + "properties": { + "category_id": { + "type": "integer" + }, + "external_user_id": { + "type": "string" + } + } + }, "response.ResponseBody": { "type": "object", "properties": { diff --git a/docs/swagger.yaml b/docs/swagger.yaml index 0c30b05..58f58c7 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -62,6 +62,8 @@ definitions: type: string updated_at: type: string + vectorized: + type: boolean type: object page.PagedResult-entity_Category: properties: @@ -133,6 +135,13 @@ definitions: - target_id - title type: object + request.UserDislikePost: + properties: + external_user_id: + type: string + post_id: + type: integer + type: object request.UserLikePost: properties: external_user_id: @@ -140,6 +149,16 @@ definitions: post_id: type: integer type: object + request.UserSuggestsRequest: + properties: + category_id: + type: integer + external_user_id: + type: string + required: + - category_id + - external_user_id + type: object response.ResponseBody: properties: data: {} @@ -554,24 +573,15 @@ paths: - application/json description: 从用户的标签喜好中移除内容 parameters: - - description: UserLikePost + - description: UserDislikePost in: body - name: UserLikePost + name: UserDislikePost required: true schema: - $ref: '#/definitions/request.UserLikePost' + $ref: '#/definitions/request.UserDislikePost' produces: - application/json responses: - "200": - description: OK - schema: - allOf: - - $ref: '#/definitions/response.ResponseBody' - - properties: - data: - $ref: '#/definitions/entity.Category' - type: object "400": description: Bad Request schema: @@ -596,15 +606,6 @@ paths: produces: - application/json responses: - "200": - description: OK - schema: - allOf: - - $ref: 
'#/definitions/response.ResponseBody' - - properties: - data: - $ref: '#/definitions/entity.Category' - type: object "400": description: Bad Request schema: @@ -614,6 +615,41 @@ paths: summary: Like tags: - application_api + /applications/v1/users/_suggest: + post: + consumes: + - application/json + description: 推荐资源 + parameters: + - description: UserSuggestsRequest + in: body + name: UserSuggestsRequest + required: true + schema: + $ref: '#/definitions/request.UserSuggestsRequest' + produces: + - application/json + responses: + "200": + description: OK + schema: + allOf: + - $ref: '#/definitions/response.ResponseBody' + - properties: + data: + items: + $ref: '#/definitions/entity.Post' + type: array + type: object + "400": + description: Bad Request + schema: + $ref: '#/definitions/response.ResponseBody' + security: + - ApiKeyAuth: [] + summary: Suggest + tags: + - application_api securityDefinitions: ApiKeyAuth: in: header diff --git a/internal/base/app.go b/internal/base/app.go index 795117e..5580cb6 100644 --- a/internal/base/app.go +++ b/internal/base/app.go @@ -1,6 +1,7 @@ package base import ( + "github.com/milvus-io/milvus-sdk-go/v2/client" "gorm.io/gorm" "leafdev.top/Ecosystem/recommender/internal/base/conf" "leafdev.top/Ecosystem/recommender/internal/base/logger" @@ -26,6 +27,7 @@ type Application struct { Redis *redis.Redis Batch *batch.Batch S3 *s3.S3 + Milvus client.Client } func NewApplication( @@ -40,6 +42,7 @@ func NewApplication( S3 *s3.S3, GORM *gorm.DB, DAO *dao.Query, + Milvus client.Client, ) *Application { return &Application{ Config: config, @@ -53,5 +56,6 @@ func NewApplication( S3: S3, GORM: GORM, DAO: DAO, + Milvus: Milvus, } } diff --git a/internal/base/conf/conf.go b/internal/base/conf/conf.go index 9e069c1..2052192 100644 --- a/internal/base/conf/conf.go +++ b/internal/base/conf/conf.go @@ -100,10 +100,12 @@ type OpenAI struct { } type Milvus struct { - Host string `yaml:"host" mapstructure:"host"` - Port int `yaml:"port" 
mapstructure:"port"` - DBName string `yaml:"db_name" mapstructure:"db_name"` - DocumentCollection string `yaml:"document_collection" mapstructure:"document_collection"` - User string `yaml:"user" mapstructure:"user"` - Password string `yaml:"password" mapstructure:"password"` + Host string `yaml:"host" mapstructure:"host"` + Port int `yaml:"port" mapstructure:"port"` + DBName string `yaml:"db_name" mapstructure:"db_name"` + User string `yaml:"user" mapstructure:"user"` + Password string `yaml:"password" mapstructure:"password"` + UserSummaryCollection string `yaml:"user_summary_collection" mapstructure:"user_summary_collection"` + TagCollection string `yaml:"tag_collection" mapstructure:"tag_collection"` + PostCollection string `yaml:"post_collection" mapstructure:"post_collection"` } diff --git a/internal/base/milvus/provide.go b/internal/base/milvus/provide.go index 2d62e96..c42a7b9 100644 --- a/internal/base/milvus/provide.go +++ b/internal/base/milvus/provide.go @@ -3,13 +3,13 @@ package milvus import ( "context" "github.com/milvus-io/milvus-sdk-go/v2/client" - "leafdev.top/Leaf/leaf-library/internal/base/conf" - "leafdev.top/Leaf/leaf-library/internal/base/logger" + "leafdev.top/Ecosystem/recommender/internal/base/conf" + "leafdev.top/Ecosystem/recommender/internal/base/logger" "strconv" ) -func NewMilvus(config *conf.Config, logger *logger.Logger) client.Client { +func NewService(config *conf.Config, logger *logger.Logger) client.Client { var address = config.Milvus.Host + ":" + strconv.Itoa(config.Milvus.Port) logger.Sugar.Infof("Waiting for milvus, address=%s, dbname=%s", address, config.Milvus.DBName) diff --git a/internal/dao/post_tags.gen.go b/internal/dao/post_tags.gen.go index ca917c5..102d240 100644 --- a/internal/dao/post_tags.gen.go +++ b/internal/dao/post_tags.gen.go @@ -58,11 +58,6 @@ func newPostTag(db *gorm.DB, opts ...gen.DOOption) postTag { db: db.Session(&gorm.Session{}), RelationField: field.NewRelation("Tag", "entity.Tag"), - Application: 
struct { - field.RelationField - }{ - RelationField: field.NewRelation("Tag.Application", "entity.Application"), - }, } _postTag.fillFieldMap() @@ -217,10 +212,6 @@ type postTagBelongsToTag struct { db *gorm.DB field.RelationField - - Application struct { - field.RelationField - } } func (a postTagBelongsToTag) Where(conds ...field.Expr) *postTagBelongsToTag { diff --git a/internal/dao/posts.gen.go b/internal/dao/posts.gen.go index 72120d2..324900e 100644 --- a/internal/dao/posts.gen.go +++ b/internal/dao/posts.gen.go @@ -36,6 +36,7 @@ func newPost(db *gorm.DB, opts ...gen.DOOption) post { _post.ApplicationId = field.NewUint(tableName, "application_id") _post.CategoryId = field.NewUint(tableName, "category_id") _post.Processed = field.NewBool(tableName, "processed") + _post.Vectorized = field.NewBool(tableName, "vectorized") _post.Application = postBelongsToApplication{ db: db.Session(&gorm.Session{}), @@ -71,6 +72,7 @@ type post struct { ApplicationId field.Uint CategoryId field.Uint Processed field.Bool + Vectorized field.Bool Application postBelongsToApplication Category postBelongsToCategory @@ -99,6 +101,7 @@ func (p *post) updateTableName(table string) *post { p.ApplicationId = field.NewUint(table, "application_id") p.CategoryId = field.NewUint(table, "category_id") p.Processed = field.NewBool(table, "processed") + p.Vectorized = field.NewBool(table, "vectorized") p.fillFieldMap() @@ -115,7 +118,7 @@ func (p *post) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (p *post) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 11) + p.fieldMap = make(map[string]field.Expr, 12) p.fieldMap["id"] = p.Id p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["updated_at"] = p.UpdatedAt @@ -125,6 +128,7 @@ func (p *post) fillFieldMap() { p.fieldMap["application_id"] = p.ApplicationId p.fieldMap["category_id"] = p.CategoryId p.fieldMap["processed"] = p.Processed + p.fieldMap["vectorized"] = p.Vectorized } diff --git 
a/internal/dao/tag_mappings.gen.go b/internal/dao/tag_mappings.gen.go index b0c8a1b..dd9f8d1 100644 --- a/internal/dao/tag_mappings.gen.go +++ b/internal/dao/tag_mappings.gen.go @@ -35,11 +35,6 @@ func newTagMapping(db *gorm.DB, opts ...gen.DOOption) tagMapping { db: db.Session(&gorm.Session{}), RelationField: field.NewRelation("Tag", "entity.Tag"), - Application: struct { - field.RelationField - }{ - RelationField: field.NewRelation("Tag.Application", "entity.Application"), - }, } _tagMapping.Application = tagMappingBelongsToApplication{ @@ -122,10 +117,6 @@ type tagMappingBelongsToTag struct { db *gorm.DB field.RelationField - - Application struct { - field.RelationField - } } func (a tagMappingBelongsToTag) Where(conds ...field.Expr) *tagMappingBelongsToTag { diff --git a/internal/dao/tags.gen.go b/internal/dao/tags.gen.go index 855f8e1..1ba6eca 100644 --- a/internal/dao/tags.gen.go +++ b/internal/dao/tags.gen.go @@ -29,12 +29,7 @@ func newTag(db *gorm.DB, opts ...gen.DOOption) tag { _tag.ALL = field.NewAsterisk(tableName) _tag.Id = field.NewUint(tableName, "id") _tag.Name = field.NewString(tableName, "name") - _tag.ApplicationId = field.NewUint(tableName, "application_id") - _tag.Application = tagBelongsToApplication{ - db: db.Session(&gorm.Session{}), - - RelationField: field.NewRelation("Application", "entity.Application"), - } + _tag.Vectorized = field.NewBool(tableName, "vectorized") _tag.fillFieldMap() @@ -44,11 +39,10 @@ func newTag(db *gorm.DB, opts ...gen.DOOption) tag { type tag struct { tagDo - ALL field.Asterisk - Id field.Uint - Name field.String - ApplicationId field.Uint - Application tagBelongsToApplication + ALL field.Asterisk + Id field.Uint + Name field.String + Vectorized field.Bool fieldMap map[string]field.Expr } @@ -67,7 +61,7 @@ func (t *tag) updateTableName(table string) *tag { t.ALL = field.NewAsterisk(table) t.Id = field.NewUint(table, "id") t.Name = field.NewString(table, "name") - t.ApplicationId = field.NewUint(table, 
"application_id") + t.Vectorized = field.NewBool(table, "vectorized") t.fillFieldMap() @@ -84,11 +78,10 @@ func (t *tag) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (t *tag) fillFieldMap() { - t.fieldMap = make(map[string]field.Expr, 4) + t.fieldMap = make(map[string]field.Expr, 3) t.fieldMap["id"] = t.Id t.fieldMap["name"] = t.Name - t.fieldMap["application_id"] = t.ApplicationId - + t.fieldMap["vectorized"] = t.Vectorized } func (t tag) clone(db *gorm.DB) tag { @@ -101,77 +94,6 @@ func (t tag) replaceDB(db *gorm.DB) tag { return t } -type tagBelongsToApplication struct { - db *gorm.DB - - field.RelationField -} - -func (a tagBelongsToApplication) Where(conds ...field.Expr) *tagBelongsToApplication { - if len(conds) == 0 { - return &a - } - - exprs := make([]clause.Expression, 0, len(conds)) - for _, cond := range conds { - exprs = append(exprs, cond.BeCond().(clause.Expression)) - } - a.db = a.db.Clauses(clause.Where{Exprs: exprs}) - return &a -} - -func (a tagBelongsToApplication) WithContext(ctx context.Context) *tagBelongsToApplication { - a.db = a.db.WithContext(ctx) - return &a -} - -func (a tagBelongsToApplication) Session(session *gorm.Session) *tagBelongsToApplication { - a.db = a.db.Session(session) - return &a -} - -func (a tagBelongsToApplication) Model(m *entity.Tag) *tagBelongsToApplicationTx { - return &tagBelongsToApplicationTx{a.db.Model(m).Association(a.Name())} -} - -type tagBelongsToApplicationTx struct{ tx *gorm.Association } - -func (a tagBelongsToApplicationTx) Find() (result *entity.Application, err error) { - return result, a.tx.Find(&result) -} - -func (a tagBelongsToApplicationTx) Append(values ...*entity.Application) (err error) { - targetValues := make([]interface{}, len(values)) - for i, v := range values { - targetValues[i] = v - } - return a.tx.Append(targetValues...) 
-} - -func (a tagBelongsToApplicationTx) Replace(values ...*entity.Application) (err error) { - targetValues := make([]interface{}, len(values)) - for i, v := range values { - targetValues[i] = v - } - return a.tx.Replace(targetValues...) -} - -func (a tagBelongsToApplicationTx) Delete(values ...*entity.Application) (err error) { - targetValues := make([]interface{}, len(values)) - for i, v := range values { - targetValues[i] = v - } - return a.tx.Delete(targetValues...) -} - -func (a tagBelongsToApplicationTx) Clear() error { - return a.tx.Clear() -} - -func (a tagBelongsToApplicationTx) Count() int64 { - return a.tx.Count() -} - type tagDo struct{ gen.DO } type ITagDo interface { diff --git a/internal/dao/user_tag_scores.gen.go b/internal/dao/user_tag_scores.gen.go index 0c3d33b..b0c5434 100644 --- a/internal/dao/user_tag_scores.gen.go +++ b/internal/dao/user_tag_scores.gen.go @@ -35,11 +35,6 @@ func newUserTagScore(db *gorm.DB, opts ...gen.DOOption) userTagScore { db: db.Session(&gorm.Session{}), RelationField: field.NewRelation("Tag", "entity.Tag"), - Application: struct { - field.RelationField - }{ - RelationField: field.NewRelation("Tag.Application", "entity.Application"), - }, } _userTagScore.Application = userTagScoreBelongsToApplication{ @@ -122,10 +117,6 @@ type userTagScoreBelongsToTag struct { db *gorm.DB field.RelationField - - Application struct { - field.RelationField - } } func (a userTagScoreBelongsToTag) Where(conds ...field.Expr) *userTagScoreBelongsToTag { diff --git a/internal/entity/Post.go b/internal/entity/Post.go index 28f6433..ec7186d 100644 --- a/internal/entity/Post.go +++ b/internal/entity/Post.go @@ -13,6 +13,7 @@ type Post struct { Category *Category `json:"category"` CategoryId *schema.EntityId `json:"category_id"` Processed bool `json:"processed"` + Vectorized bool `json:"vectorized"` } func (u *Post) TableName() string { diff --git a/internal/entity/Tag.go b/internal/entity/Tag.go index b6dcfeb..831ccd2 100644 --- 
a/internal/entity/Tag.go +++ b/internal/entity/Tag.go @@ -3,10 +3,11 @@ package entity import "leafdev.top/Ecosystem/recommender/internal/schema" type Tag struct { - Id schema.EntityId `gorm:"primarykey" json:"id"` - Name string `json:"name"` - Application *Application - ApplicationId schema.EntityId `json:"application_id"` + Id schema.EntityId `gorm:"primarykey" json:"id"` + Name string `json:"name"` + Vectorized bool `json:"vectorized"` + //Application *Application + //ApplicationId schema.EntityId `json:"application_id"` } func (u *Tag) TableName() string { diff --git a/internal/handler/http/controller/application_v1/posts.go b/internal/handler/http/controller/application_v1/posts.go index 1aa4590..efadfa5 100644 --- a/internal/handler/http/controller/application_v1/posts.go +++ b/internal/handler/http/controller/application_v1/posts.go @@ -10,6 +10,7 @@ import ( "leafdev.top/Ecosystem/recommender/internal/service/auth" "leafdev.top/Ecosystem/recommender/internal/service/category" "leafdev.top/Ecosystem/recommender/internal/service/post" + "leafdev.top/Ecosystem/recommender/pkg/consts" "net/http" ) @@ -94,6 +95,17 @@ func (pc *PostController) Save(c *gin.Context) { return } + exists, err := pc.postService.TargetIdExists(c, app, postSaveRequest.TargetId) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusBadRequest).Send() + return + } + + if exists { + response.Ctx(c).Status(http.StatusConflict).Error(consts.ErrPostTargetIdExists).Send() + return + } + var postEntity = &entity.Post{ Title: postSaveRequest.Title, Content: postSaveRequest.Content, @@ -103,7 +115,7 @@ func (pc *PostController) Save(c *gin.Context) { Processed: false, } - err = pc.postService.CreatePost(c, postEntity) + err = pc.postService.CreatePost(c, app, postEntity) response.Ctx(c).Error(err).Data(postEntity).Send() return diff --git a/internal/handler/http/controller/application_v1/users.go b/internal/handler/http/controller/application_v1/users.go index e0a4a72..8b4ee35 100644 
--- a/internal/handler/http/controller/application_v1/users.go +++ b/internal/handler/http/controller/application_v1/users.go @@ -8,6 +8,7 @@ import ( "leafdev.top/Ecosystem/recommender/internal/handler/http/response" "leafdev.top/Ecosystem/recommender/internal/service/application" "leafdev.top/Ecosystem/recommender/internal/service/auth" + "leafdev.top/Ecosystem/recommender/internal/service/category" "leafdev.top/Ecosystem/recommender/internal/service/post" "leafdev.top/Ecosystem/recommender/internal/service/user" "leafdev.top/Ecosystem/recommender/pkg/consts" @@ -26,6 +27,7 @@ type UserController struct { postService *post.Service logger *logger.Logger redis *redis.Redis + categoryService *category.Service } func NewUserController( @@ -35,7 +37,7 @@ func NewUserController( postService *post.Service, logger *logger.Logger, redis *redis.Redis, - + categoryService *category.Service, ) *UserController { return &UserController{ authService: authService, @@ -44,6 +46,7 @@ func NewUserController( postService: postService, logger: logger, redis: redis, + categoryService: categoryService, } } @@ -55,7 +58,6 @@ func NewUserController( // @Produce json // @Security ApiKeyAuth // @Param UserLikePost body request.UserLikePost true "UserLikePost" -// @Success 200 {object} response.ResponseBody{data=entity.Category} // @Failure 400 {object} response.ResponseBody // @Router /applications/v1/users/_like [post] func (uc *UserController) Like(c *gin.Context) { @@ -90,6 +92,11 @@ func (uc *UserController) Like(c *gin.Context) { return } + if !postEntity.Vectorized || !postEntity.Processed { + response.Ctx(c).Status(http.StatusBadRequest).Error(consts.ErrPostNotReady).Send() + return + } + // 检测是否有 var cacheKey = uc.redis.Prefix(TaskProcessing + ":" + userLikePostRequest.PostId.String()) // if exists @@ -130,8 +137,7 @@ func (uc *UserController) Like(c *gin.Context) { // @Accept json // @Produce json // @Security ApiKeyAuth -// @Param UserLikePost body request.UserLikePost true 
"UserLikePost" -// @Success 200 {object} response.ResponseBody{data=entity.Category} +// @Param UserDislikePost body request.UserDislikePost true "UserDislikePost" // @Failure 400 {object} response.ResponseBody // @Router /applications/v1/users/_dislike [post] func (uc *UserController) Dislike(c *gin.Context) { @@ -166,6 +172,11 @@ func (uc *UserController) Dislike(c *gin.Context) { return } + if !postEntity.Vectorized || !postEntity.Processed { + response.Ctx(c).Status(http.StatusBadRequest).Error(consts.ErrPostNotReady).Send() + return + } + // 检测是否有 var cacheKey = uc.redis.Prefix(TaskProcessing + ":" + userDislikePostRequest.PostId.String()) exists, err := uc.redis.Client.Exists(c, cacheKey).Result() @@ -197,3 +208,80 @@ func (uc *UserController) Dislike(c *gin.Context) { response.Ctx(c).Status(http.StatusNoContent).Send() return } + +// Suggest godoc +// @Summary Suggest +// @Description 推荐资源 +// @Tags application_api +// @Accept json +// @Produce json +// @Security ApiKeyAuth +// @Param UserSuggestsRequest body request.UserSuggestsRequest true "UserSuggestsRequest" +// @Success 200 {object} response.ResponseBody{data=[]entity.Post} +// @Failure 400 {object} response.ResponseBody +// @Router /applications/v1/users/_suggest [post] +func (uc *UserController) Suggest(c *gin.Context) { + app, err := uc.authService.GetApplication(c) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusBadRequest).Send() + return + } + + var userSuggestsRequest = &request.UserSuggestsRequest{} + + if err := c.ShouldBindJSON(userSuggestsRequest); err != nil { + response.Ctx(c).Error(err).Status(http.StatusBadRequest).Send() + return + } + + exists, err := uc.userService.UserExists(c, userSuggestsRequest.ExternalUserId, app) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusInternalServerError).Send() + return + } + + if !exists { + response.Ctx(c).Status(http.StatusNotFound).Send() + return + } + + externalUserEntity, err := uc.userService.GetExternalUser(c, 
userSuggestsRequest.ExternalUserId, app) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusInternalServerError).Send() + return + } + + // category + exists, err = uc.categoryService.CategoryExists(c, userSuggestsRequest.CategoryId, app) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusInternalServerError).Send() + return + } + + if !exists { + response.Ctx(c).Status(http.StatusNotFound).Error(consts.ErrCategoryNotExists).Send() + return + } + + categoryEntity, err := uc.categoryService.GetCategoryById(c, userSuggestsRequest.CategoryId) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusInternalServerError).Send() + return + } + + // 建议文章 + err = uc.userService.SummaryUser(c, externalUserEntity, app) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusInternalServerError).Error(err).Send() + return + } + + // 建议文章 + posts, err := uc.userService.SuggestPosts(c, externalUserEntity, categoryEntity) + if err != nil { + response.Ctx(c).Error(err).Status(http.StatusInternalServerError).Send() + return + } + + response.Ctx(c).Data(posts).Send() +} diff --git a/internal/handler/http/request/applications.go b/internal/handler/http/request/applications.go index 4caabe1..93e554f 100644 --- a/internal/handler/http/request/applications.go +++ b/internal/handler/http/request/applications.go @@ -19,3 +19,8 @@ type UserDislikePost struct { PostId schema.EntityId `json:"post_id" uri:"post_id"` ExternalUserId string `json:"external_user_id"` } + +type UserSuggestsRequest struct { + ExternalUserId string `json:"external_user_id" binding:"required"` + CategoryId schema.EntityId `json:"category_id" binding:"required"` +} diff --git a/internal/migrations/1_setup.sql b/internal/migrations/1_setup.sql index a5a5d7f..07d9e9f 100644 --- a/internal/migrations/1_setup.sql +++ b/internal/migrations/1_setup.sql @@ -51,11 +51,13 @@ CREATE TABLE `categories` CREATE TABLE `tags` ( - id serial NOT NULL, - name varchar(255) NOT NULL, - 
application_id bigint unsigned NOT NULL, + id serial NOT NULL, + name varchar(255) NOT NULL, + vectorized bool DEFAULT FALSE, +# application_id bigint unsigned NOT NULL, primary key (id), - foreign key (application_id) references applications (id) on delete cascade + index (vectorized) +# foreign key (application_id) references applications (id) on delete cascade ); CREATE TABLE `tag_mappings` @@ -79,10 +81,11 @@ CREATE TABLE `posts` content LONGTEXT NOT NULL, application_id bigint unsigned NOT NULL, processed BOOLEAN DEFAULT FALSE, + vectorized bool DEFAULT FALSE, category_id bigint unsigned, created_at timestamp DEFAULT CURRENT_TIMESTAMP, updated_at timestamp DEFAULT CURRENT_TIMESTAMP, - index (target_id, processed, application_id), + index (target_id, processed, vectorized, application_id), primary key (id), foreign key (application_id) references applications (id) on delete cascade, foreign key (category_id) references categories (id) @@ -138,6 +141,6 @@ DROP TABLE IF EXISTS `user_likes`; DROP TABLE IF EXISTS `posts`; DROP TABLE IF EXISTS `tags`; DROP TABLE IF EXISTS `categories`; +DROP TABLE IF EXISTS `external_users`; DROP TABLE IF EXISTS `application_tokens`; DROP TABLE IF EXISTS `applications`; -DROP TABLE IF EXISTS `external_users`; diff --git a/internal/router/api.go b/internal/router/api.go index 8bce702..75815f9 100644 --- a/internal/router/api.go +++ b/internal/router/api.go @@ -48,4 +48,5 @@ func (a *Api) InitApplicationApi(r *gin.RouterGroup) { r.POST("/users/_like", a.h.ApplicationUserApi.Like) r.POST("/users/_dislike", a.h.ApplicationUserApi.Dislike) + r.POST("/users/_suggest", a.h.ApplicationUserApi.Suggest) } diff --git a/internal/service/category/categories.go b/internal/service/category/categories.go index 5dfa49d..22038c9 100644 --- a/internal/service/category/categories.go +++ b/internal/service/category/categories.go @@ -40,3 +40,12 @@ func (s *Service) DeleteCategory(c context.Context, category *entity.Category) e func (s *Service) 
GetCategoryById(c context.Context, categoryId schema.EntityId) (*entity.Category, error) { return s.dao.WithContext(c).Category.Where(s.dao.Category.Id.Eq(categoryId.Uint())).First() } + +func (s *Service) CategoryExists(c context.Context, categoryId schema.EntityId, applicationEntity *entity.Application) (bool, error) { + count, err := s.dao.WithContext(c).Category. + Where(s.dao.Category.Id.Eq(categoryId.Uint())). + Where(s.dao.Category.ApplicationId.Eq(applicationEntity.Id.Uint())). + Count() + + return count > 0, err +} diff --git a/internal/service/embedding/provider.go b/internal/service/embedding/provider.go index b7ca8fc..fc33da7 100644 --- a/internal/service/embedding/provider.go +++ b/internal/service/embedding/provider.go @@ -2,10 +2,10 @@ package embedding import ( "github.com/tmc/langchaingo/llms/openai" - "leafdev.top/Leaf/leaf-library/internal/base/conf" - "leafdev.top/Leaf/leaf-library/internal/base/logger" - "leafdev.top/Leaf/leaf-library/internal/base/redis" - "leafdev.top/Leaf/leaf-library/internal/dao" + "leafdev.top/Ecosystem/recommender/internal/base/conf" + "leafdev.top/Ecosystem/recommender/internal/base/logger" + "leafdev.top/Ecosystem/recommender/internal/base/redis" + "leafdev.top/Ecosystem/recommender/internal/dao" ) type Service struct { @@ -19,7 +19,7 @@ type Service struct { func NewService(config *conf.Config, logger *logger.Logger, dao *dao.Query, redis *redis.Redis) *Service { llm, err := openai.New( openai.WithToken(config.OpenAI.ApiKey), - openai.WithBaseURL(config.OpenAI.InternalBaseUrl), + openai.WithBaseURL(config.OpenAI.BaseUrl), openai.WithEmbeddingModel(config.OpenAI.EmbeddingModel), ) diff --git a/internal/service/post/posts.go b/internal/service/post/posts.go index 5916a0e..f26f7ac 100644 --- a/internal/service/post/posts.go +++ b/internal/service/post/posts.go @@ -5,6 +5,7 @@ import ( "github.com/iVampireSP/pkg/page" "leafdev.top/Ecosystem/recommender/internal/entity" "leafdev.top/Ecosystem/recommender/internal/schema" +
"leafdev.top/Ecosystem/recommender/pkg/consts" ) func (s *Service) ListPosts(c context.Context, pagedResult *page.PagedResult[*entity.Post], category *entity.Category, application *entity.Application) error { @@ -29,8 +30,28 @@ func (s *Service) ListPosts(c context.Context, pagedResult *page.PagedResult[*en return nil } -func (s *Service) CreatePost(c context.Context, post *entity.Post) error { - err := s.dao.WithContext(c).Post.Create(post) +// TargetIdExists reports whether application already has a post with the given target id. +func (s *Service) TargetIdExists(c context.Context, application *entity.Application, targetId string) (bool, error) { + count, err := s.dao.WithContext(c).Post.Where(s.dao.Post.TargetId.Eq(targetId)). + Where(s.dao.Post.ApplicationId.Eq(application.Id.Uint())). + Count() + + return count > 0, err +} + +// CreatePost assigns post to applicationEntity and inserts it. It returns consts.ErrPostTargetIdExists when the application already has a post with the same target id. +func (s *Service) CreatePost(c context.Context, applicationEntity *entity.Application, post *entity.Post) error { + // A failed lookup must be reported, not read as "no duplicate". + exists, err := s.TargetIdExists(c, applicationEntity, post.TargetId) + if err != nil { + return err + } + if exists { + return consts.ErrPostTargetIdExists + } + + post.ApplicationId = applicationEntity.Id + err = s.dao.WithContext(c).Post.Create(post) if err != nil { return err diff --git a/internal/service/post/provider.go b/internal/service/post/provider.go index 423bc85..1c87247 100644 --- a/internal/service/post/provider.go +++ b/internal/service/post/provider.go @@ -1,21 +1,27 @@ package post import ( + "github.com/milvus-io/milvus-sdk-go/v2/client" "leafdev.top/Ecosystem/recommender/internal/base/conf" "leafdev.top/Ecosystem/recommender/internal/dao" + "leafdev.top/Ecosystem/recommender/internal/service/embedding" "leafdev.top/Ecosystem/recommender/internal/service/stream" ) type Service struct { - dao *dao.Query - config *conf.Config - stream *stream.Service + dao *dao.Query + config *conf.Config + stream *stream.Service + milvus client.Client + embedding *embedding.Service } -func NewService(dao *dao.Query, config *conf.Config,
stream *stream.Service, milvus client.Client, embedding *embedding.Service) *Service { return &Service{ - dao: dao, - config: config, - stream: stream, + dao: dao, + config: config, + stream: stream, + milvus: milvus, + embedding: embedding, } } diff --git a/internal/service/post/tags.go b/internal/service/post/tags.go index f40dd63..a4ded50 100644 --- a/internal/service/post/tags.go +++ b/internal/service/post/tags.go @@ -2,12 +2,14 @@ package post import ( "context" + entity2 "github.com/milvus-io/milvus-sdk-go/v2/entity" "leafdev.top/Ecosystem/recommender/internal/entity" ) -func (s *Service) GetTag(c context.Context, name string, applicationEntity *entity.Application) (*entity.Tag, error) { - var tmq = s.dao.WithContext(c).TagMapping.Where(s.dao.TagMapping.Name.Eq(name)). - Where(s.dao.TagMapping.ApplicationId.Eq(applicationEntity.Id.Uint())) +func (s *Service) GetTag(c context.Context, name string) (*entity.Tag, error) { + var tmq = s.dao.WithContext(c).TagMapping.Where(s.dao.TagMapping.Name.Eq(name)) + //Where(s.dao.TagMapping.ApplicationId.Eq(applicationEntity.Id.Uint())) + tmqCount, err := tmq.Count() if err != nil { return nil, err @@ -23,13 +25,23 @@ func (s *Service) GetTag(c context.Context, name string, applicationEntity *enti return r.Tag, nil } - return s.dao.WithContext(c).Tag.Where(s.dao.Tag.Name.Eq(name)). - Where(s.dao.Tag.ApplicationId.Eq(applicationEntity.Id.Uint())). + t, err := s.dao.WithContext(c).Tag.Where(s.dao.Tag.Name.Eq(name)). + //Where(s.dao.Tag.ApplicationId.Eq(applicationEntity.Id.Uint())). 
FirstOrCreate() + + if err != nil { + return nil, err + } + + if !t.Vectorized { + err = s.SaveTagEmbedding(c, t) + } + + return t, err } func (s *Service) HasBindTag(c context.Context, post *entity.Post, tagName string) (bool, error) { - tag, err := s.GetTag(c, tagName, post.Application) + tag, err := s.GetTag(c, tagName) if err != nil { return false, err } @@ -44,7 +56,7 @@ func (s *Service) HasBindTag(c context.Context, post *entity.Post, tagName strin } func (s *Service) BindTag(c context.Context, post *entity.Post, tagName string) error { - tag, err := s.GetTag(c, tagName, post.Application) + tag, err := s.GetTag(c, tagName) if err != nil { return err } @@ -69,3 +81,64 @@ func (s *Service) MarkAsProcessed(c context.Context, post *entity.Post) error { return err } + +func (s *Service) SaveTagEmbedding(c context.Context, tag *entity.Tag) error { + emb, err := s.embedding.TextEmbedding(c, []string{tag.Name}) + if err != nil { + return err + } + + var entityCols = []entity2.Column{ + entity2.NewColumnFloatVector("vector", s.config.OpenAI.EmbeddingDim, emb), + entity2.NewColumnInt64("tag_id", []int64{int64(tag.Id)}), + } + + _, err = s.milvus.Upsert(c, s.config.Milvus.TagCollection, "", entityCols...) 
+ + if err != nil { + return err + } + + _, err = s.dao.WithContext(c).Tag.Where(s.dao.Tag.Id.Eq(tag.Id.Uint())).Update(s.dao.Tag.Vectorized, true) + + return err +} + +func (s *Service) SavePostEmbedding(c context.Context, post *entity.Post) error { + tags, err := s.GetPostTags(c, post) + if err != nil { + return err + } + + var tagString = "" + for _, tag := range tags { + tagString += tag.Name + " " + } + + // 裁剪 > s.config.OpenAI.EmbeddingDim + if len(tagString) > s.config.OpenAI.EmbeddingDim { + tagString = tagString[:s.config.OpenAI.EmbeddingDim] + } + + emb, err := s.embedding.TextEmbedding(c, []string{tagString}) + if err != nil { + return err + } + + var entityCols = []entity2.Column{ + entity2.NewColumnFloatVector("vector", s.config.OpenAI.EmbeddingDim, emb), + entity2.NewColumnInt64("post_id", []int64{int64(post.Id)}), + entity2.NewColumnInt64("category_id", []int64{int64(post.CategoryId.Uint())}), + entity2.NewColumnInt64("application_id", []int64{int64(post.ApplicationId.Uint())}), + } + + _, err = s.milvus.Upsert(c, s.config.Milvus.PostCollection, "", entityCols...) 
+ + if err != nil { + return err + } + + _, err = s.dao.WithContext(c).Post.Where(s.dao.Post.Id.Eq(post.Id.Uint())).Update(s.dao.Post.Vectorized, true) + + return err +} diff --git a/internal/service/provider.go b/internal/service/provider.go index 9dd090f..68a24b2 100644 --- a/internal/service/provider.go +++ b/internal/service/provider.go @@ -5,6 +5,7 @@ import ( "leafdev.top/Ecosystem/recommender/internal/service/application" "leafdev.top/Ecosystem/recommender/internal/service/auth" "leafdev.top/Ecosystem/recommender/internal/service/category" + "leafdev.top/Ecosystem/recommender/internal/service/embedding" "leafdev.top/Ecosystem/recommender/internal/service/jwks" "leafdev.top/Ecosystem/recommender/internal/service/post" "leafdev.top/Ecosystem/recommender/internal/service/stream" @@ -22,10 +23,12 @@ type Service struct { Post *post.Service Category *category.Service User *user.Service + Embedding *embedding.Service } var Provider = wire.NewSet( jwks.NewJWKS, + embedding.NewService, stream.NewService, auth.NewAuthService, application.NewService, @@ -44,6 +47,7 @@ func NewService( post *post.Service, category *category.Service, user *user.Service, + embedding *embedding.Service, ) *Service { return &Service{ logger, @@ -54,5 +58,6 @@ func NewService( post, category, user, + embedding, } } diff --git a/internal/service/user/external.go b/internal/service/user/external.go index bb6f004..f30228c 100644 --- a/internal/service/user/external.go +++ b/internal/service/user/external.go @@ -5,6 +5,20 @@ import ( "leafdev.top/Ecosystem/recommender/internal/entity" ) +func (s *Service) GetExternalUser(c context.Context, externalUserId string, applicationEntity *entity.Application) (*entity.ExternalUser, error) { + iu, err := s.dao.WithContext(c).ExternalUser. + Where(s.dao.ExternalUser.ApplicationId.Eq(applicationEntity.Id.Uint())). 
+ Where(s.dao.ExternalUser.ExternalId.Eq(externalUserId)).First() + return iu, err +} + +func (s *Service) UserExists(c context.Context, externalUserId string, applicationEntity *entity.Application) (bool, error) { + iu, err := s.dao.WithContext(c).ExternalUser. + Where(s.dao.ExternalUser.ApplicationId.Eq(applicationEntity.Id.Uint())). + Where(s.dao.ExternalUser.ExternalId.Eq(externalUserId)).Count() + return iu > 0, err +} + func (s *Service) GetOrCreateExternalUser(c context.Context, externalUserId string, applicationEntity *entity.Application) (*entity.ExternalUser, error) { //Where(s.dao.UserTagScore.ExternalUserId.Eq(externalUserEntity.Id.Uint())). count, err := s.dao.WithContext(c).ExternalUser. diff --git a/internal/service/user/posts.go b/internal/service/user/posts.go index cf610e9..a77654c 100644 --- a/internal/service/user/posts.go +++ b/internal/service/user/posts.go @@ -4,13 +4,9 @@ import ( "context" "leafdev.top/Ecosystem/recommender/internal/entity" "leafdev.top/Ecosystem/recommender/internal/schema" - "time" + "leafdev.top/Ecosystem/recommender/pkg/consts" ) -const TaskProcessing = "user_likes" - -var LockTTL = time.Minute * 10 - func (s *Service) HasLiked(c context.Context, externalUserEntity *entity.ExternalUser, applicationEntity *entity.Application, postEntity *entity.Post) (bool, error) { count, err := s.dao.WithContext(c).UserLike.Where(s.dao.UserLike.ExternalUserId.Eq(externalUserEntity.Id.Uint())). Where(s.dao.UserLike.PostId.Eq(postEntity.Id.Uint())). 
@@ -68,6 +64,10 @@ func (s *Service) LikePost(c context.Context, externalUserEntity *entity.Externa // } //}(lock, c) + if !postEntity.Vectorized || !postEntity.Processed { + return consts.ErrPostNotReady + } + // get tags postTags, err := s.postService.GetPostTags(c, postEntity) if err != nil { @@ -111,6 +111,10 @@ func (s *Service) DislikePost(c context.Context, externalUserEntity *entity.Exte // } //}(lock, c) + if !postEntity.Vectorized || !postEntity.Processed { + return consts.ErrPostNotReady + } + // get tags postTags, err := s.postService.GetPostTags(c, postEntity) if err != nil { diff --git a/internal/service/user/provider.go b/internal/service/user/provider.go index 5776dc6..1832502 100644 --- a/internal/service/user/provider.go +++ b/internal/service/user/provider.go @@ -1,8 +1,11 @@ package user import ( + "github.com/milvus-io/milvus-sdk-go/v2/client" + "leafdev.top/Ecosystem/recommender/internal/base/conf" "leafdev.top/Ecosystem/recommender/internal/base/logger" "leafdev.top/Ecosystem/recommender/internal/dao" + "leafdev.top/Ecosystem/recommender/internal/service/embedding" "leafdev.top/Ecosystem/recommender/internal/service/post" ) @@ -10,16 +13,25 @@ type Service struct { dao *dao.Query postService *post.Service logger *logger.Logger + milvus client.Client + embedding *embedding.Service + config *conf.Config } func NewService( dao *dao.Query, postService *post.Service, logger *logger.Logger, + milvus client.Client, + embedding *embedding.Service, + config *conf.Config, ) *Service { return &Service{ dao: dao, postService: postService, logger: logger, + milvus: milvus, + embedding: embedding, + config: config, } } diff --git a/internal/service/user/suggest.go b/internal/service/user/suggest.go new file mode 100644 index 0000000..c02dfd6 --- /dev/null +++ b/internal/service/user/suggest.go @@ -0,0 +1,70 @@ +package user + +import ( + "context" + "fmt" + "github.com/milvus-io/milvus-sdk-go/v2/client" + entity2
"github.com/milvus-io/milvus-sdk-go/v2/entity" + "leafdev.top/Ecosystem/recommender/internal/entity" +) + +// SuggestPosts embeds the user's stored summary, runs an L2 vector search over the post collection restricted to the user's application and the given category, and loads the matching posts. +// It returns an empty slice when the search reports no hits. +func (s *Service) SuggestPosts(c context.Context, externalUserEntity *entity.ExternalUser, categoryEntity *entity.Category) ([]*entity.Post, error) { + emb, err := s.embedding.TextEmbedding(c, []string{externalUserEntity.Summary}) + if err != nil { + return nil, err + } + if len(emb) == 0 { + return nil, fmt.Errorf("embedding service returned no vector for the user summary") + } + + var filter = fmt.Sprintf("application_id == %d && category_id == %s", externalUserEntity.ApplicationId, categoryEntity.Id) + sp, err := entity2.NewIndexAUTOINDEXSearchParam(1) + if err != nil { + return nil, err + } + vector := entity2.FloatVector(emb[0]) + postResults, err := s.milvus.Search(c, s.config.Milvus.PostCollection, + []string{}, + filter, + []string{"post_id", "category_id"}, + []entity2.Vector{vector}, + "vector", + entity2.L2, + 3, + sp, client.WithLimit(7)) + if err != nil { + return nil, err + } + + var ids []uint + for _, res := range postResults { + // No hits: return an empty list right away. + if res.ResultCount == 0 { + return make([]*entity.Post, 0), nil + } + + var postIDColumn *entity2.ColumnInt64 + for _, field := range res.Fields { + if col, ok := field.(*entity2.ColumnInt64); ok && field.Name() == "post_id" { + postIDColumn = col + } + } + + // The result carries no usable post_id column: nothing to load. + if postIDColumn == nil { + return make([]*entity.Post, 0), nil + } + + for i := 0; i < res.ResultCount; i++ { + id, err := postIDColumn.ValueByIdx(i) + if err != nil { + return nil, err + } + ids = append(ids, uint(id)) + } + } + + return s.dao.WithContext(c).Post.Where(s.dao.Post.Id.In(ids...)).Find() +} diff --git a/internal/service/user/summary.go b/internal/service/user/summary.go new file mode 100644 index 0000000..80e603c --- /dev/null +++ b/internal/service/user/summary.go @@ -0,0 +1,59 @@ +package user + +import ( + "context" + entity2 "github.com/milvus-io/milvus-sdk-go/v2/entity" + "leafdev.top/Ecosystem/recommender/internal/entity" +
"leafdev.top/Ecosystem/recommender/pkg/consts" +) + +// SummaryUser rebuilds the user's summary string from the names of their highest-scored tags; when it differs from the stored summary it upserts the summary embedding into the user-summary collection and then persists the new summary. +// It returns consts.ErrExternalUserDoesNotLikeAnyPost when the user has no scored tags. +func (s *Service) SummaryUser(c context.Context, externalUserEntity *entity.ExternalUser, applicationEntity *entity.Application) error { + tags, err := s.GetHighScoreTags(c, externalUserEntity, applicationEntity) + if err != nil { + return err + } + if len(tags) == 0 { + return consts.ErrExternalUserDoesNotLikeAnyPost + } + + var tagString = "" + for _, tag := range tags { + tagString += tag.Name + " " + } + + // Cap the summary at EmbeddingDim bytes, backing off to a rune boundary so the cut never splits a multi-byte UTF-8 character. + limit := s.config.OpenAI.EmbeddingDim + if len(tagString) > limit { + for limit > 0 && tagString[limit]&0xC0 == 0x80 { + limit-- + } + tagString = tagString[:limit] + } + if externalUserEntity.Summary == tagString { + return nil + } + + // Write the vector before the summary: if embedding or upsert fails, the stored summary still differs and the next call retries. + emb, err := s.embedding.TextEmbedding(c, []string{tagString}) + if err != nil { + return err + } + var entityCols = []entity2.Column{ + entity2.NewColumnFloatVector("vector", s.config.OpenAI.EmbeddingDim, emb), + entity2.NewColumnInt64("external_user_id", []int64{int64(externalUserEntity.Id)}), + entity2.NewColumnInt64("application_id", []int64{int64(applicationEntity.Id)}), + } + if _, err = s.milvus.Upsert(c, s.config.Milvus.UserSummaryCollection, "", entityCols...); err != nil { + return err + } + + _, err = s.dao.WithContext(c).ExternalUser.Where(s.dao.ExternalUser.Id.Eq(externalUserEntity.Id.Uint())).Update(s.dao.ExternalUser.Summary, tagString) + if err != nil { + return err + } + // Keep the caller's entity in step with the database so later reads of Summary see the new value. + externalUserEntity.Summary = tagString + return nil +} diff --git a/internal/service/user/tags.go b/internal/service/user/tags.go index 430dd46..d8af8f3 100644 --- a/internal/service/user/tags.go +++ b/internal/service/user/tags.go @@ -112,3 +112,25 @@ func (s *Service) RemoveTags(c context.Context, externalUserEntity *entity.Exter return nil } + +func (s *Service) GetHighScoreTags(c context.Context, externalUserEntity *entity.ExternalUser, applicationEntity *entity.Application) ([]*entity.Tag, error) { + tagScores, err := s.dao.WithContext(c).UserTagScore. + Preload(s.dao.UserTagScore.Tag).
+ Where(s.dao.UserTagScore.ApplicationId.Eq(applicationEntity.Id.Uint())). + Where(s.dao.UserTagScore.ExternalUserId.Eq(externalUserEntity.Id.Uint())). + Order(s.dao.UserTagScore.Score.Desc()). + //Where(s.dao.UserTagScore.Score.Gt(3)). + Limit(30).Find() + + if err != nil { + return nil, err + } + + var tags []*entity.Tag + + for _, tagScore := range tagScores { + tags = append(tags, tagScore.Tag) + } + + return tags, nil +} diff --git a/pkg/consts/category.go b/pkg/consts/category.go new file mode 100644 index 0000000..3d14def --- /dev/null +++ b/pkg/consts/category.go @@ -0,0 +1,8 @@ +package consts + +import "errors" + +var ( + ErrCategoryExists = errors.New("category already exists") + ErrCategoryNotExists = errors.New("category does not exist") +) diff --git a/pkg/consts/posts.go b/pkg/consts/posts.go index 037e808..1ceb4f7 100644 --- a/pkg/consts/posts.go +++ b/pkg/consts/posts.go @@ -3,5 +3,8 @@ package consts import "errors" var ( - ErrAnotherOperationInProgress = errors.New("another operation in progress") + ErrPostNotReady = errors.New("post not ready") + ErrAnotherOperationInProgress = errors.New("another operation in progress") + ErrExternalUserDoesNotLikeAnyPost = errors.New("external user does not like any post") + ErrPostTargetIdExists = errors.New("post target id exists") )