How do you use MDC with a ForkJoinPool?


Building on "How do you use MDC with thread pools?": how do you use MDC with a ForkJoinPool? Specifically, how do you wrap a ForkJoinTask so that the MDC values are set before the task executes?

4 Answers

The following seems to work for me:

import java.lang.Thread.UncaughtExceptionHandler;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.MDC;

/**
 * A {@link ForkJoinPool} that inherits MDC contexts from the thread that queues a task.
 *
 * @author Gili Tzabari
 */
public final class MdcForkJoinPool extends ForkJoinPool
{
    /**
     * Creates a new MdcForkJoinPool.
     *
     * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
     * @param factory     the factory for creating new threads. For default value, use
     *                    {@link #defaultForkJoinWorkerThreadFactory}.
     * @param handler     the handler for internal worker threads that terminate due to unrecoverable errors encountered
     *                    while executing tasks. For default value, use {@code null}.
     * @param asyncMode   if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
     *                    joined. This mode may be more appropriate than default locally stack-based mode in applications
     *                    in which worker threads only process event-style asynchronous tasks. For default value, use
     *                    {@code false}.
     * @throws IllegalArgumentException if parallelism less than or equal to zero, or greater than implementation limit
     * @throws NullPointerException     if the factory is null
     * @throws SecurityException        if a security manager exists and the caller is not permitted to modify threads
     *                                  because it does not hold
     *                                  {@link java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler,
        boolean asyncMode)
    {
        super(parallelism, factory, handler, asyncMode);
    }

    @Override
    public void execute(ForkJoinTask<?> task)
    {
        // See https://dev59.com/QW025IYBdhLWcg3wVkYR#19329668
        super.execute(wrap(task, MDC.getCopyOfContextMap()));
    }

    @Override
    public void execute(Runnable task)
    {
        // See https://dev59.com/QW025IYBdhLWcg3wVkYR#19329668
        super.execute(wrap(task, MDC.getCopyOfContextMap()));
    }

    private <T> ForkJoinTask<T> wrap(ForkJoinTask<T> task, Map<String, String> newContext)
    {
        return new ForkJoinTask<T>()
        {
            private static final long serialVersionUID = 1L;
            /**
             * If non-null, overrides the value returned by the underlying task.
             */
            private final AtomicReference<T> override = new AtomicReference<>();

            @Override
            public T getRawResult()
            {
                T result = override.get();
                if (result != null)
                    return result;
                return task.getRawResult();
            }

            @Override
            protected void setRawResult(T value)
            {
                override.set(value);
            }

            @Override
            protected boolean exec()
            {
                // According to ForkJoinTask.fork() "it is a usage error to fork a task more than once unless it has completed
                // and been reinitialized". We therefore assume that this method does not have to be thread-safe.
                Map<String, String> oldContext = beforeExecution(newContext);
                try
                {
                    task.invoke();
                    return true;
                }
                finally
                {
                    afterExecution(oldContext);
                }
            }
        };
    }

    private Runnable wrap(Runnable task, Map<String, String> newContext)
    {
        return () ->
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                task.run();
            }
            finally
            {
                afterExecution(oldContext);
            }
        };
    }

    /**
     * Invoked before running a task.
     *
     * @param newValue the new MDC context
     * @return the old MDC context
     */
    private Map<String, String> beforeExecution(Map<String, String> newValue)
    {
        Map<String, String> previous = MDC.getCopyOfContextMap();
        if (newValue == null)
            MDC.clear();
        else
            MDC.setContextMap(newValue);
        return previous;
    }

    /**
     * Invoked after running a task.
     *
     * @param oldValue the old MDC context
     */
    private void afterExecution(Map<String, String> oldValue)
    {
        if (oldValue == null)
            MDC.clear();
        else
            MDC.setContextMap(oldValue);
    }
}

and

import java.util.Map;
import java.util.concurrent.CountedCompleter;
import org.slf4j.MDC;

/**
 * A {@link CountedCompleter} that inherits MDC contexts from the thread that queues a task.
 *
 * @author Gili Tzabari
 * @param <T> The result type returned by this task's {@code get} method
 */
public abstract class MdcCountedCompleter<T> extends CountedCompleter<T>
{
    private static final long serialVersionUID = 1L;
    private final Map<String, String> newContext;

    /**
     * Creates a new MdcCountedCompleter instance using the MDC context of the current thread.
     */
    protected MdcCountedCompleter()
    {
        this(null);
    }

    /**
     * Creates a new MdcCountedCompleter instance using the MDC context of the current thread.
     *
     * @param completer this task's completer; {@code null} if none
     */
    protected MdcCountedCompleter(CountedCompleter<?> completer)
    {
        super(completer);
        this.newContext = MDC.getCopyOfContextMap();
    }

    /**
     * The main computation performed by this task.
     */
    protected abstract void computeWithContext();

    @Override
    public final void compute()
    {
        Map<String, String> oldContext = beforeExecution(newContext);
        try
        {
            computeWithContext();
        }
        finally
        {
            afterExecution(oldContext);
        }
    }

    /**
     * Invoked before running a task.
     *
     * @param newValue the new MDC context
     * @return the old MDC context
     */
    private Map<String, String> beforeExecution(Map<String, String> newValue)
    {
        Map<String, String> previous = MDC.getCopyOfContextMap();
        if (newValue == null)
            MDC.clear();
        else
            MDC.setContextMap(newValue);
        return previous;
    }

    /**
     * Invoked after running a task.
     *
     * @param oldValue the old MDC context
     */
    private void afterExecution(Map<String, String> oldValue)
    {
        if (oldValue == null)
            MDC.clear();
        else
            MDC.setContextMap(oldValue);
    }
}

  1. Run your tasks on MdcForkJoinPool instead of a plain ForkJoinPool (a minimal usage sketch follows below).
  2. Extend MdcCountedCompleter instead of CountedCompleter.
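
A minimal usage sketch, assuming the two classes above are on the classpath (the pool size, the requestId key, and the demo class name are illustrative, not part of the original answer):

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import org.slf4j.MDC;

public final class MdcForkJoinPoolDemo
{
    public static void main(String[] args) throws InterruptedException
    {
        // Set the context on the submitting thread; MdcForkJoinPool copies it
        // into every task it wraps in execute(...).
        MDC.put("requestId", "req-42");

        MdcForkJoinPool pool = new MdcForkJoinPool(4,
            ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false);

        // The worker thread sees requestId=req-42 while this task runs; afterwards
        // the worker's previous context is restored by afterExecution(...).
        pool.execute(() ->
            System.out.println("MDC in worker: " + MDC.getCopyOfContextMap()));

        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
    }
}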

Is there a way to override the default ForkJoinPool implementation with a custom one like the implementation you posted? I don't want to inject my own executor service into every CompletableFuture async call. - Ihor M.
I can override ForkJoinPool with my own ForkJoinWorkerThreadFactory, but apparently that is not enough: I set the MDC context on a thread, but when new tasks arrive the thread object does not seem to be recreated (tasks are added to a work queue and a ForkJoinWorkerThread processes them one at a time). So I need to set/unset the MDC context per ForkJoinTask rather than per ForkJoinWorkerThread. - Ihor M.
@IhorM. Did you apply all the parts of the answer above? Specifically, MdcForkJoinPool wraps your tasks and sets/restores the MDC before/after each execution. - Gili
@Gili Question: why didn't you also override ForkJoinPool's submit() methods? - Ihor M.
@Gili I used new MdcFJPool().submit(() -> IntStream.range(1, 10).parallel().peek(i -> System.out.println(Thread.currentThread().getName() + " mdc: " + MDC.get("blah"))).sum()). The first FJ worker thread has the MDC value, but because .sum() uses ReduceTask, which extends AbstractTask > CountedCompleter, none of the other threads have the MDC value. I don't see a way to plug MdcCountedCompleter in here. - martin-g


Here is some additional information to go along with @Gili's answer.

Testing shows that the solution works (note that there may be lines with no context at all, but at least they will not have the wrong context, which is what happens with a normal ForkJoinPool).

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertThat;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import org.junit.Test;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.classic.encoder.PatternLayoutEncoder;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.OutputStreamAppender;

public class MDCForkJoinPoolTest {

    private static final Logger log = (Logger) LoggerFactory.getLogger("mdc-test");

    // you can demonstrate the problem I'm trying to fix by changing the below to a normal ForkJoinPool and then running the test
    private ForkJoinPool threads = new MDCForkJoinPool(16);
    private Semaphore threadsRunning = new Semaphore(-99);
    private ByteArrayOutputStream bio = new ByteArrayOutputStream();

    @Test
    public void shouldCopyManagedDiagnosticContextWhenUsingForkJoinPool() throws Exception {
        for (int i = 0 ; i < 100; i++) {
            Thread t = new Thread(simulatedRequest(), "MDC-Test-"+i);
            t.setDaemon(true);
            t.start();
        }

        // set up the appender to grab the output
        LoggerContext lc = (LoggerContext) LoggerFactory.getILoggerFactory();
        OutputStreamAppender<ILoggingEvent> appender = new OutputStreamAppender<>();
        PatternLayoutEncoder encoder = new PatternLayoutEncoder();
        encoder.setPattern("%X{mdc_val:-}=%m%n");
        encoder.setContext(lc);
        encoder.start();
        appender.setEncoder(encoder);
        appender.setImmediateFlush(true);
        appender.setContext(lc);
        appender.setOutputStream(bio);
        appender.start();
        log.addAppender(appender);
        log.setAdditive(false);
        log.setLevel(Level.INFO);

        assertThat("timed out waiting for threads to complete.", threadsRunning.tryAcquire(300, TimeUnit.SECONDS), is(true));

        Set<String> ids = new HashSet<>();
        try (BufferedReader r = new BufferedReader(new InputStreamReader(new ByteArrayInputStream(bio.toByteArray()), Charset.forName("utf8")))) {
            r.lines().forEach(line->{
                System.out.println(line);
               String[] vals = line.split("=");
               if (!vals[0].isEmpty()) {
                   ids.add(vals[0]);
                   assertThat(vals[1], startsWith(vals[0]));
               }
            });
        }

        assertThat(ids.size(), is(100));
    }

    private Runnable simulatedRequest() {
        return () -> {
            String id = UUID.randomUUID().toString();
            MDC.put("mdc_val", id);
            Map<String, String> context = MDC.getCopyOfContextMap();
            threads.submit(()->{
                MDC.setContextMap(context);
                IntStream.range(0, 100).parallel().forEach((i)->{
                   log.info("{} - {}", id, i); 
                });
            }).join();
            threadsRunning.release();
        };
    }
}

In addition, here are the other methods from the original answer that should also be overridden (they also require an import of java.util.concurrent.Callable):

    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }

    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }

    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()), result);
    }

    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }

    private <T> Callable<T> wrap(Callable<T> task, Map<String, String> newContext)
    {
        return () ->
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                return task.call();
            }
            finally
            {
                afterExecution(oldContext);
            }
        };
    }

@BenL. - I tried your test (thanks), but as you pointed out, there are cases with no context. With some basic debugging it seems MdcCountedCompleter needs to be used, but I can't figure out how. Any ideas? - Alain P


I am not familiar with ForkJoinPool, but you could pass the MDC keys/values of interest to the ForkJoinTask instance you create before submitting it to the ForkJoinPool.

Given that as of logback version 1.1.5 MDC values are no longer inherited by child threads, there are not many options. They are:

  1. Pass the relevant MDC keys/values when instantiating the ForkJoinTask (a rough sketch of this option appears after this list)
  2. Extend ForkJoinPool so that MDC keys/values are passed on to newly created threads
  3. Create your own ThreadFactory that sets the MDC keys/values on newly created threads

Note that I have not actually implemented option 2 or 3.
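
As a rough sketch of option 1 (not part of this answer; the class name MdcAwareAction is hypothetical and the actual work is left as a placeholder), the task can copy the submitting thread's MDC map when it is constructed and restore it around compute():

import java.util.Map;
import java.util.concurrent.RecursiveAction;
import org.slf4j.MDC;

public class MdcAwareAction extends RecursiveAction
{
    private static final long serialVersionUID = 1L;

    // Captured on the thread that constructs (i.e. queues) the task.
    private final Map<String, String> context = MDC.getCopyOfContextMap();

    @Override
    protected void compute()
    {
        Map<String, String> previous = MDC.getCopyOfContextMap();
        if (context == null)
            MDC.clear();
        else
            MDC.setContextMap(context);
        try
        {
            // ... the actual work, possibly forking further MdcAwareAction instances ...
        }
        finally
        {
            // Restore whatever the worker thread had before this task ran.
            if (previous == null)
                MDC.clear();
            else
                MDC.setContextMap(previous);
        }
    }
}

The drawback, as Gili notes in the comments below, is that every caller has to remember to do this manually.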


Ceki, I was looking for a way to automatically inherit the MDC of the thread that queues a task. The problem with option 1 is that users often forget to inherit the MDC manually. The problem with options 2 and 3 is that the MDC values should be taken from the queuing thread, not the executing thread: a single executing thread runs multiple tasks, each of which may have different MDC values. Hopefully that clarifies what I meant. Thanks anyway. - Gili
Regarding your comment: Item 1: it would have to be some custom ForkJoinTask implementation, and since you cannot inject an MDC context into, for example, CompletableFuture$AsyncRun, it is not feasible. Item 2: you can extend ForkJoinPool, but the MDC context should be propagated when a new task is added to the work queue, not when a thread is created, because the same thread is reused to process multiple tasks. Item 3: not a viable solution, because you would be setting the MDC context on a thread that is constructed only once but processes many tasks. - Ihor M.


I was stuck on the same problem. Obviously, using a custom ForkJoinPool every time you need to run a parallel Java stream is not ideal, since it requires quite a bit of code.

However, I think I found a smaller solution than the one proposed by the topic starter:

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import org.slf4j.MDC;

import lombok.extern.slf4j.Slf4j;

@Slf4j
public class MdcTest {

    public static void main(String[] args) {
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            list.add(i);
        }
        
        MDC.put("someKey", "iter");
        
        list.stream()
            .parallel()
            .peek(mdcParallelStreamKeeper())
            .forEach(i -> log.info("List item={} with MDC={}", i, MDC.getCopyOfContextMap()));
    }

    private static Consumer<? super Integer> mdcParallelStreamKeeper() {
        Map<String, String> contextMap = MDC.getCopyOfContextMap();
        return i -> {
            MDC.clear();
            MDC.setContextMap(contextMap);
        };
    }
}

Basically, you just need to have the mdcParallelStreamKeeper method somewhere and simply use it.

Update #1: This approach has a problem with MDC cleanup.
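
A minimal sketch of why the cleanup matters (this demo class is hypothetical and not part of the answer): the peek consumer copies the context onto common-pool worker threads but nothing ever removes it, so workers reused by a later, unrelated parallel stream may still carry the stale value.

import java.util.Map;
import java.util.stream.IntStream;

import org.slf4j.MDC;

public class MdcCleanupDemo {

    public static void main(String[] args) {
        MDC.put("someKey", "request-1");
        Map<String, String> context = MDC.getCopyOfContextMap();

        // First stream: every worker that participates copies "request-1" into its
        // own MDC and never clears it.
        IntStream.range(0, 1_000).parallel()
            .peek(i -> MDC.setContextMap(context))
            .forEach(i -> { /* work */ });

        // The submitting thread moves on and clears its own context...
        MDC.clear();

        // ...but common-pool workers picked up by a later stream may still see the
        // old value (whether they do depends on which workers get scheduled).
        IntStream.range(0, 1_000).parallel()
            .forEach(i -> {
                String stale = MDC.get("someKey");
                if (stale != null) {
                    System.out.println(Thread.currentThread().getName() + " still sees " + stale);
                }
            });
    }
}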
