diff --git a/lib/util/io_channel.c b/lib/util/io_channel.c index 1e14b8869..b91549451 100644 --- a/lib/util/io_channel.c +++ b/lib/util/io_channel.c @@ -52,6 +52,7 @@ struct io_device { spdk_io_channel_destroy_cb destroy_cb; spdk_io_device_unregister_cb unregister_cb; uint32_t ctx_size; + uint32_t for_each_count; TAILQ_ENTRY(io_device) tailq; bool unregistered; @@ -218,6 +219,7 @@ spdk_io_device_register(void *io_device, spdk_io_channel_create_cb create_cb, dev->destroy_cb = destroy_cb; dev->unregister_cb = NULL; dev->ctx_size = ctx_size; + dev->for_each_count = 0; dev->unregistered = false; pthread_mutex_lock(&g_devlist_mutex); @@ -276,6 +278,12 @@ spdk_io_device_unregister(void *io_device, spdk_io_device_unregister_cb unregist return; } + if (dev->for_each_count > 0) { + SPDK_ERRLOG("io_device %p has %u for_each calls outstanding\n", io_device, dev->for_each_count); + pthread_mutex_unlock(&g_devlist_mutex); + return; + } + dev->unregister_cb = unregister_cb; dev->unregistered = true; TAILQ_REMOVE(&g_io_devices, dev, tailq); @@ -400,6 +408,7 @@ spdk_io_channel_get_thread(struct spdk_io_channel *ch) struct call_channel { void *io_device; + struct io_device *dev; spdk_channel_msg fn; void *ctx; @@ -459,6 +468,7 @@ _call_channel(void *ctx) thread = TAILQ_NEXT(thread, tailq); } + ch_ctx->dev->for_each_count--; pthread_mutex_unlock(&g_devlist_mutex); spdk_thread_send_msg(ch_ctx->orig_thread, _call_completion, ch_ctx); @@ -489,6 +499,8 @@ spdk_for_each_channel(void *io_device, spdk_channel_msg fn, void *ctx, TAILQ_FOREACH(thread, &g_threads, tailq) { TAILQ_FOREACH(ch, &thread->io_channels, tailq) { if (ch->dev->io_device == io_device) { + ch->dev->for_each_count++; + ch_ctx->dev = ch->dev; ch_ctx->cur_thread = thread; pthread_mutex_unlock(&g_devlist_mutex); spdk_thread_send_msg(thread, _call_channel, ch_ctx); diff --git a/test/unit/lib/util/io_channel.c/io_channel_ut.c b/test/unit/lib/util/io_channel.c/io_channel_ut.c index 5afd0b12a..efccf0e42 100644 --- a/test/unit/lib/util/io_channel.c/io_channel_ut.c +++ b/test/unit/lib/util/io_channel.c/io_channel_ut.c @@ -170,6 +170,78 @@ for_each_channel_remove(void) free_threads(); } +struct unreg_ctx { + bool ch_done; + bool foreach_done; +}; + +static void +unreg_ch_done(void *io_device, struct spdk_io_channel *_ch, void *_ctx) +{ + struct unreg_ctx *ctx = _ctx; + + ctx->ch_done = true; +} + +static void +unreg_foreach_done(void *io_device, void *_ctx) +{ + struct unreg_ctx *ctx = _ctx; + + ctx->foreach_done = true; +} + +static void +for_each_channel_unreg(void) +{ + struct spdk_io_channel *ch0; + struct io_device *dev; + struct unreg_ctx ctx = {}; + int io_target; + + allocate_threads(1); + CU_ASSERT(TAILQ_EMPTY(&g_io_devices)); + spdk_io_device_register(&io_target, channel_create, channel_destroy, sizeof(int)); + CU_ASSERT(!TAILQ_EMPTY(&g_io_devices)); + dev = TAILQ_FIRST(&g_io_devices); + SPDK_CU_ASSERT_FATAL(dev != NULL); + CU_ASSERT(TAILQ_NEXT(dev, tailq) == NULL); + set_thread(0); + ch0 = spdk_get_io_channel(&io_target); + spdk_for_each_channel(&io_target, unreg_ch_done, &ctx, unreg_foreach_done); + + spdk_io_device_unregister(&io_target, NULL); + /* + * There is an outstanding foreach call on the io_device, so the unregister should not + * have removed the device. + */ + CU_ASSERT(dev == TAILQ_FIRST(&g_io_devices)); + spdk_io_device_register(&io_target, channel_create, channel_destroy, sizeof(int)); + /* + * There is already a device registered at &io_target, so a new io_device should not + * have been added to g_io_devices. + */ + CU_ASSERT(dev == TAILQ_FIRST(&g_io_devices)); + CU_ASSERT(TAILQ_NEXT(dev, tailq) == NULL); + + poll_thread(0); + CU_ASSERT(ctx.ch_done == true); + CU_ASSERT(ctx.foreach_done == true); + /* + * There are no more foreach operations outstanding, so we can unregister the device, + * even though a channel still exists for the device. + */ + spdk_io_device_unregister(&io_target, NULL); + CU_ASSERT(TAILQ_EMPTY(&g_io_devices)); + + set_thread(0); + spdk_put_io_channel(ch0); + + poll_threads(); + + free_threads(); +} + static void thread_name(void) { @@ -318,6 +390,7 @@ main(int argc, char **argv) CU_add_test(suite, "thread_alloc", thread_alloc) == NULL || CU_add_test(suite, "thread_send_msg", thread_send_msg) == NULL || CU_add_test(suite, "for_each_channel_remove", for_each_channel_remove) == NULL || + CU_add_test(suite, "for_each_channel_unreg", for_each_channel_unreg) == NULL || CU_add_test(suite, "thread_name", thread_name) == NULL || CU_add_test(suite, "channel", channel) == NULL ) {