自定义线程池支持异步操作。

3
我希望有一个自定义线程池,满足以下要求:
  1. 真实线程根据池容量预先分配。如果需要生成并发任务,则实际工作可以自由使用标准的.NET线程池。
  2. 该池必须能够返回空闲线程数。返回的数字可能小于空闲线程的实际数量,但不能超过。当然,数字越准确越好。
  3. 将工作排队到池中应返回相应的Task,该Task应与基于任务的API兼容。
  4. 新功能最大作业容量(或并行度)应动态可调整。尝试减少容量不必立即生效,但增加容量应立即生效。

第一项的原理如下所示:

  • 机器不应同时运行超过N个工作项,其中N相对较小-在10到30之间。
  • 工作从数据库中提取,如果提取了K个,则我们要确保有K个空闲线程立即开始工作。如果工作从数据库中提取,但仍在等待下一个可用线程,则是不可接受的。

最后一项也解释了具有空闲线程计数的原因-我将从数据库中获取那么多工作项。它还解释了为什么报告的空闲线程计数绝不能高于实际数量-否则我可能会获取更多可立即启动的工作。

无论如何,这是我的实现以及用于测试它的小程序(BJE代表Background Job Engine):

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace TaskStartLatency
{
    public class BJEThreadPool
    {
        private sealed class InternalTaskScheduler : TaskScheduler
        {
            private int m_idleThreadCount;
            private readonly BlockingCollection<Task> m_bus;

            public InternalTaskScheduler(int threadCount, BlockingCollection<Task> bus)
            {
                m_idleThreadCount = threadCount;
                m_bus = bus;
            }

            public void RunInline(Task task)
            {
                Interlocked.Decrement(ref m_idleThreadCount);
                try
                {
                    TryExecuteTask(task);
                }
                catch
                {
                    // The action is responsible itself for the error handling, for the time being...
                }
                Interlocked.Increment(ref m_idleThreadCount);
            }

            public int IdleThreadCount
            {
                get { return m_idleThreadCount; }
            }

            #region Overrides of TaskScheduler

            protected override void QueueTask(Task task)
            {
                m_bus.Add(task);
            }

            protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued)
            {
                return TryExecuteTask(task);
            }

            protected override IEnumerable<Task> GetScheduledTasks()
            {
                throw new NotSupportedException();
            }

            #endregion

            public void DecrementIdleThreadCount()
            {
                Interlocked.Decrement(ref m_idleThreadCount);
            }
        }

        private class ThreadContext
        {
            private readonly InternalTaskScheduler m_ts;
            private readonly BlockingCollection<Task> m_bus;
            private readonly CancellationTokenSource m_cts;
            public readonly Thread Thread;

            public ThreadContext(string name, InternalTaskScheduler ts, BlockingCollection<Task> bus, CancellationTokenSource cts)
            {
                m_ts = ts;
                m_bus = bus;
                m_cts = cts;
                Thread = new Thread(Start)
                {
                    IsBackground = true,
                    Name = name
                };
                Thread.Start();
            }

            private void Start()
            {
                try
                {
                    foreach (var task in m_bus.GetConsumingEnumerable(m_cts.Token))
                    {
                        m_ts.RunInline(task);
                    }
                }
                catch (OperationCanceledException)
                {
                }
                m_ts.DecrementIdleThreadCount();
            }
        }

        private readonly InternalTaskScheduler m_ts;
        private readonly CancellationTokenSource m_cts = new CancellationTokenSource();
        private readonly BlockingCollection<Task> m_bus = new BlockingCollection<Task>();
        private readonly List<ThreadContext> m_threadCtxs = new List<ThreadContext>();

        public BJEThreadPool(int threadCount)
        {
            m_ts = new InternalTaskScheduler(threadCount, m_bus);
            for (int i = 0; i < threadCount; ++i)
            {
                m_threadCtxs.Add(new ThreadContext("BJE Thread " + i, m_ts, m_bus, m_cts));
            }
        }

        public void Terminate()
        {
            m_cts.Cancel();
            foreach (var t in m_threadCtxs)
            {
                t.Thread.Join();
            }
        }

        public Task Run(Action<CancellationToken> action)
        {
            return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
        }
        public Task Run(Action action)
        {
            return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts);
        }

        public int IdleThreadCount
        {
            get { return m_ts.IdleThreadCount; }
        }
    }

    class Program
    {
        static void Main()
        {
            const int THREAD_COUNT = 32;
            var pool = new BJEThreadPool(THREAD_COUNT);
            var tcs = new TaskCompletionSource<bool>();
            var tasks = new List<Task>();
            var allRunning = new CountdownEvent(THREAD_COUNT);

            for (int i = pool.IdleThreadCount; i > 0; --i)
            {
                var index = i;
                tasks.Add(pool.Run(cancellationToken =>
                {
                    Console.WriteLine("Started action " + index);
                    allRunning.Signal();
                    tcs.Task.Wait(cancellationToken);
                    Console.WriteLine("  Ended action " + index);
                }));
            }

            Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);

            allRunning.Wait();
            Debug.Assert(pool.IdleThreadCount == 0);

            int expectedIdleThreadCount = THREAD_COUNT;
            Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
            switch (Console.ReadKey().KeyChar)
            {
            case 'c':
                Console.WriteLine("Cancel All");
                tcs.TrySetCanceled();
                break;
            case 'e':
                Console.WriteLine("Error All");
                tcs.TrySetException(new Exception("Failed"));
                break;
            case 'a':
                Console.WriteLine("Abort All");
                pool.Terminate();
                expectedIdleThreadCount = 0;
                break;
            default:
                Console.WriteLine("Done All");
                tcs.TrySetResult(true);
                break;
            }
            try
            {
                Task.WaitAll(tasks.ToArray());
            }
            catch (AggregateException exc)
            {
                Console.WriteLine(exc.Flatten().InnerException.Message);
            }

            Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);

            pool.Terminate();
            Console.WriteLine("Press any key");
            Console.ReadKey();
        }
    }
}

这是一个非常简单的实现,看起来已经生效。然而,存在一个问题——BJEThreadPool.Run方法不允许异步方法。也就是说,我的实现不能添加以下重载:

public Task Run(Func<CancellationToken, Task> action)
{
    return Task.Factory.StartNew(() => action(m_cts.Token), m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
}
public Task Run(Func<Task> action)
{
    return Task.Factory.StartNew(action, m_cts.Token, TaskCreationOptions.DenyChildAttach, m_ts).Unwrap();
}

我在InternalTaskScheduler.RunInline中使用的模式在这种情况下不起作用。

所以,我的问题是如何添加对异步工作项的支持?只要满足帖子开头概述的要求,我可以更改整个设计。

编辑

我想澄清所需池的使用意图。请查看以下代码:

if (pool.IdleThreadCount == 0)
{
  return;
}

foreach (var jobData in FetchFromDB(pool.IdleThreadCount))
{
  pool.Run(CreateJobAction(jobData));
}

注意事项:
1. 代码将会定期运行,例如每一分钟。 2. 多台机器将同时监视同一个数据库运行该代码。 3. `FetchFromDB` 将使用在 Using SQL Server as a DB queue with multiple clients 中描述的技术来原子地从数据库中获取和锁定工作。 4. `CreateJobAction` 将调用由 `jobData`(工作代码)表示的代码,并在完成该代码后关闭工作。工作代码不受我的控制,它可能是任何东西 - 重 CPU 绑定代码或轻异步 IO 绑定代码、糟糕编写的同步 IO 绑定代码或混合所有这些。它可能运行几分钟,也可能运行几个小时。关闭工作是我的代码,它将是异步 IO 绑定代码。因此,返回的作业操作的签名是异步方法的签名。
第二项强调正确识别空闲线程数量的重要性。如果有 900 个待处理工作项和 10 台代理机器,我不能允许代理机器获取 300 个工作项并将其排队到线程池中。为什么?因为,代理机器很可能无法同时运行 300 个工作项。当然,它会运行一些工作项,但其他工作项将在线程池工作队列中等待。假设它运行 100 个工作项,并让 200 个工作项等待(即使 100 个工作项可能过于乐观)。这将产生 3 个完全加载的代理机器和 7 个空闲的代理机器。但实际上只有 900 个工作项中的 300 个正在被并发处理!!!
我的目标是在可用代理机器之间最大化工作分配。理想情况下,我应该评估代理的负载和待处理工作的“重量”,但这是一个艰巨的任务,预留给未来版本。现在,我希望为每个代理分配最大作业容量,并提供动态增加/减少作业容量而无需重新启动代理的手段。
下一个观察结果是:工作可能需要很长时间才能运行,并且可能全部是同步代码。据我所知,使用线程池线程进行这种类型的工作是不可取的。
编辑 2:
有一种说法是 `TaskScheduler` 仅适用于 CPU 绑定的工作。但如果我不知道工作的性质呢?我的意思是,它是一个通用的后台作业引擎,可以运行数千种不同类型的作业。我没有办法告诉“那个工作是 CPU 绑定的”、“那个工作是同步 IO 绑定的”,还有另一个是异步 IO 绑定的。我希望我能,但我不能。
编辑 3:
最终,我没有使用SemaphoreSlim,也没有使用TaskScheduler - 我最终明白了这是不合适和错误的,而且它使代码过于复杂。
然而,我仍然没有看到SemaphoreSlim是正确的方法。所提出的模式:
public async Task Enqueue(Func<Task> taskGenerator)
{
    await semaphore.WaitAsync();
    try
    {
        await taskGenerator();
    }
    finally
    {
        semaphore.Release();
    }
}

希望taskGenerator是异步IO绑定代码或打开新线程,但我无法确定要执行的工作是哪种。此外,从SemaphoreSlim.WaitAsync continuation code中了解到,如果信号量被解锁,则在WaitAsync()后面运行的代码将在同一线程上运行,这对我来说不太好。
无论如何,以下是我的实现,如果有人喜欢的话。不幸的是,我还没有理解如何动态减少池线程数,但这是另一个问题的主题。
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

namespace TaskStartLatency
{
    public interface IBJEThreadPool
    {
        void SetThreadCount(int threadCount);
        void Terminate();
        Task Run(Action action);
        Task Run(Action<CancellationToken> action);
        Task Run(Func<Task> action);
        Task Run(Func<CancellationToken, Task> action);
        int IdleThreadCount { get; }
    }

    public class BJEThreadPool : IBJEThreadPool
    {
        private interface IActionContext
        {
            Task Run(CancellationToken ct);
            TaskCompletionSource<object> TaskCompletionSource { get; }
        }

        private class ActionContext : IActionContext
        {
            private readonly Action m_action;

            public ActionContext(Action action)
            {
                m_action = action;
                TaskCompletionSource = new TaskCompletionSource<object>();
            }

            #region Implementation of IActionContext

            public Task Run(CancellationToken ct)
            {
                m_action();
                return null;
            }

            public TaskCompletionSource<object> TaskCompletionSource { get; private set; }

            #endregion
        }
        private class CancellableActionContext : IActionContext
        {
            private readonly Action<CancellationToken> m_action;

            public CancellableActionContext(Action<CancellationToken> action)
            {
                m_action = action;
                TaskCompletionSource = new TaskCompletionSource<object>();
            }

            #region Implementation of IActionContext

            public Task Run(CancellationToken ct)
            {
                m_action(ct);
                return null;
            }

            public TaskCompletionSource<object> TaskCompletionSource { get; private set; }

            #endregion
        }
        private class AsyncActionContext : IActionContext
        {
            private readonly Func<Task> m_action;

            public AsyncActionContext(Func<Task> action)
            {
                m_action = action;
                TaskCompletionSource = new TaskCompletionSource<object>();
            }

            #region Implementation of IActionContext

            public Task Run(CancellationToken ct)
            {
                return m_action();
            }

            public TaskCompletionSource<object> TaskCompletionSource { get; private set; }

            #endregion
        }
        private class AsyncCancellableActionContext : IActionContext
        {
            private readonly Func<CancellationToken, Task> m_action;

            public AsyncCancellableActionContext(Func<CancellationToken, Task> action)
            {
                m_action = action;
                TaskCompletionSource = new TaskCompletionSource<object>();
            }

            #region Implementation of IActionContext

            public Task Run(CancellationToken ct)
            {
                return m_action(ct);
            }

            public TaskCompletionSource<object> TaskCompletionSource { get; private set; }

            #endregion
        }

        private readonly CancellationTokenSource m_ctsTerminateAll = new CancellationTokenSource();
        private readonly BlockingCollection<IActionContext> m_bus = new BlockingCollection<IActionContext>();
        private readonly LinkedList<Thread> m_threads = new LinkedList<Thread>();
        private int m_idleThreadCount;

        private static int s_threadCount;

        public BJEThreadPool(int threadCount)
        {
            ReserveAdditionalThreads(threadCount);
        }

        private void ReserveAdditionalThreads(int n)
        {
            for (int i = 0; i < n; ++i)
            {
                var index = Interlocked.Increment(ref s_threadCount) - 1;

                var t = new Thread(Start)
                {
                    IsBackground = true,
                    Name = "BJE Thread " + index
                };
                Interlocked.Increment(ref m_idleThreadCount);
                t.Start();

                m_threads.AddLast(t);
            }
        }

        private void Start()
        {
            try
            {
                foreach (var actionContext in m_bus.GetConsumingEnumerable(m_ctsTerminateAll.Token))
                {
                    RunWork(actionContext).Wait();
                }
            }
            catch (OperationCanceledException)
            {
            }
            catch
            {
                // Should never happen - log the error
            }

            Interlocked.Decrement(ref m_idleThreadCount);
        }

        private async Task RunWork(IActionContext actionContext)
        {
            Interlocked.Decrement(ref m_idleThreadCount);
            try
            {
                var task = actionContext.Run(m_ctsTerminateAll.Token);
                if (task != null)
                {
                    await task;
                }
                actionContext.TaskCompletionSource.SetResult(null);
            }
            catch (OperationCanceledException)
            {
                actionContext.TaskCompletionSource.TrySetCanceled();
            }
            catch (Exception exc)
            {
                actionContext.TaskCompletionSource.TrySetException(exc);
            }
            Interlocked.Increment(ref m_idleThreadCount);
        }

        private Task PostWork(IActionContext actionContext)
        {
            m_bus.Add(actionContext);
            return actionContext.TaskCompletionSource.Task;
        }

        #region Implementation of IBJEThreadPool

        public void SetThreadCount(int threadCount)
        {
            if (threadCount > m_threads.Count)
            {
                ReserveAdditionalThreads(threadCount - m_threads.Count);
            }
            else if (threadCount < m_threads.Count)
            {
                throw new NotSupportedException();
            }
        }
        public void Terminate()
        {
            m_ctsTerminateAll.Cancel();
            foreach (var t in m_threads)
            {
                t.Join();
            }
        }

        public Task Run(Action action)
        {
            return PostWork(new ActionContext(action));
        }
        public Task Run(Action<CancellationToken> action)
        {
            return PostWork(new CancellableActionContext(action));
        }
        public Task Run(Func<Task> action)
        {
            return PostWork(new AsyncActionContext(action));
        }
        public Task Run(Func<CancellationToken, Task> action)
        {
            return PostWork(new AsyncCancellableActionContext(action));
        }

        public int IdleThreadCount
        {
            get { return m_idleThreadCount; }
        }

        #endregion
    }

    public static class Extensions
    {
        public static Task WithCancellation(this Task task, CancellationToken token)
        {
            return task.ContinueWith(t => t.GetAwaiter().GetResult(), token);
        }
    }

    class Program
    {
        static void Main()
        {
            const int THREAD_COUNT = 16;
            var pool = new BJEThreadPool(THREAD_COUNT);
            var tcs = new TaskCompletionSource<bool>();
            var tasks = new List<Task>();
            var allRunning = new CountdownEvent(THREAD_COUNT);

            for (int i = pool.IdleThreadCount; i > 0; --i)
            {
                var index = i;
                tasks.Add(pool.Run(async ct =>
                {
                    Console.WriteLine("Started action " + index);
                    allRunning.Signal();
                    await tcs.Task.WithCancellation(ct);
                    Console.WriteLine("  Ended action " + index);
                }));
            }

            Console.WriteLine("pool.IdleThreadCount = " + pool.IdleThreadCount);

            allRunning.Wait();
            Debug.Assert(pool.IdleThreadCount == 0);

            int expectedIdleThreadCount = THREAD_COUNT;
            Console.WriteLine("Press [c]ancel, [e]rror, [a]bort or any other key");
            switch (Console.ReadKey().KeyChar)
            {
            case 'c':
                Console.WriteLine("ancel All");
                tcs.TrySetCanceled();
                break;
            case 'e':
                Console.WriteLine("rror All");
                tcs.TrySetException(new Exception("Failed"));
                break;
            case 'a':
                Console.WriteLine("bort All");
                pool.Terminate();
                expectedIdleThreadCount = 0;
                break;
            default:
                Console.WriteLine("Done All");
                tcs.TrySetResult(true);
                break;
            }

            try
            {
                Task.WaitAll(tasks.ToArray());
            }
            catch (AggregateException exc)
            {
                Console.WriteLine(exc.Flatten().InnerException.Message);
            }

            Debug.Assert(pool.IdleThreadCount == expectedIdleThreadCount);

            pool.Terminate();
            Console.WriteLine("Press any key");
            Console.ReadKey();
        }
    }
}

2
你会如何总结你的问题,例如“机器不应该同时运行超过N个工作项”? - L.B
我认为当前的标题相当准确 - 我有一个不支持异步工作项的自定义线程池,问题是如何修改它以支持它们。 - mark
由于许多“async”操作实际上根本不使用新线程运行,因此您的问题有些奇怪... - Alexei Levenkov
从你的第一条评论中,我可以看出你不愿意接受对你的代码的批评,但我可以问一下,“为什么你需要这段代码?你试图解决什么实际问题?” - L.B
@标记 我想确保我获取的所有工作都不会等待线程可用 我希望能够理解你说的话。祝你好运。 - L.B
显示剩余11条评论
2个回答

3

异步的“工作项”通常基于异步IO。在运行时,异步IO不使用线程。任务调度程序用于执行CPU工作(基于委托的任务)。概念TaskScheduler不适用。您不能使用自定义TaskScheduler来影响异步代码的操作。

让您的工作项自我限制:

static SemaphoreSlim sem = new SemaphoreSlim(maxDegreeOfParallelism); //shared object

async Task MyWorkerFunction()
{
    await sem.WaitAsync();
    try
    {
        MyWork();
    }
    finally
    {
        sem.Release();
    }
}

请注意,没有 WaitOneAsync,只有 WaitAsync。同时,释放资源的操作必须在 finally 块中执行,以免在完成后未释放资源导致死锁。希望您不介意我进行了这些修改。 - Servy
我有一些要求。细节决定成败 - 请解释您的答案如何满足我的要求。请注意,MyWork 所表示的可能是任何东西 - 同步 CPU 绑定代码、同步 IO 绑定代码、异步 IO 绑定代码或它们的混合。此外,它可能是长时间运行的工作(数小时)或短时间运行的工作(数分钟)。 - mark
请帮我将编辑内容添加到帖子中。我希望这会更清楚地阐明我的上下文。 - mark
可能你的方法是做这件事的方法,我还不确定。我一直在暗示,但现在我已经明确将其添加到要求列表中 - 动态调整并行度(要求4)。如何使用您的解决方案实现它?接下来,工作在WaitAsync的继续执行中 - 它在线程池上,对吗?如果是这样,它可能是一个长时间运行的进程,占用线程池线程。我们对此满意吗? - mark

2
如另一个答案(由usr提到)中所述,您无法使用 TaskScheduler 来限制所有类型的工作(无论是并行还是非并行),因为它仅适用于CPU绑定的工作。 他还向您展示了如何使用 SemaphoreSlim 异步限制并行度。
您可以扩展这些概念以一般化几种方式。 对您最有益的可能是创建一种特殊类型的队列,该队列接受返回 Task 的操作,并以使得达到给定的最大并行度的方式执行它们。
public class FixedParallelismQueue
{
    private SemaphoreSlim semaphore;
    public FixedParallelismQueue(int maxDegreesOfParallelism)
    {
        semaphore = new SemaphoreSlim(maxDegreesOfParallelism,
            maxDegreesOfParallelism);
    }

    public async Task<T> Enqueue<T>(Func<Task<T>> taskGenerator)
    {
        await semaphore.WaitAsync();
        try
        {
            return await taskGenerator();
        }
        finally
        {
            semaphore.Release();
        }
    }
    public async Task Enqueue(Func<Task> taskGenerator)
    {
        await semaphore.WaitAsync();
        try
        {
            await taskGenerator();
        }
        finally
        {
            semaphore.Release();
        }
    }
}

这允许您为应用程序创建一个队列(如果需要,甚至可以有几个独立的队列),该队列具有固定的并行度。然后,当操作完成并返回一个Task时,您可以提供它们,并且队列将在可能的情况下安排它并返回表示该工作单元何时完成的Task

可能可以,我还不确定。我一直在暗示,但现在我已经明确将其添加到要求列表中 - 动态调整并行度(要求4)。使用您的解决方案如何实现?接下来,工作在WaitAsync的继续执行中 - 它在线程池上,是吗?如果是这样,它可能是一个长时间运行的进程,占用线程池线程。我们对此满意吗? - mark
  1. 我不知道如何调整SemaphoreSlim的计数,除非替换它。
  2. 你是说我应该知道工作的性质吗?这将是完美的,但我不能。你看,这些工作(或我们称之为后台作业)是由不同的团队编写的。编写它们的人的技能水平是不同的。它们已经使用不同的.NET技术编写了过去4年。旧代码纯粹是同步的,新代码包含异步补丁。同样,有些很重,有些很轻。我希望我能够在每个工作上指定它的“重量”。但我没有。
- mark
  1. 仍然不明白。如果度数从20变为30,我如何在同一个SemaphoreSlim实例中完成它?
  2. 你是说一些后台作业的代码必须更改吗?
- mark
@mark 释放信号量10次。我是说,将函数传递到此队列的任何功能实际上负责确定“任务”执行什么以及在哪里执行等内容。队列所做的所有工作都是将并行度限制为固定数量。没有更多的事情了。 - Servy
  1. 我虽然慢,但我正在进步。所以,我可以轻松地增加容量。但是如何减少它呢?
  2. 我相信我已经解释过了,我目前没有办法检查获取的工作并确定是否安全运行在线程池线程上,或者必须为其生成专用线程。这是现实的悲哀。
- mark
显示剩余5条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接