通过谓词限制流

219

是否有一种Java 8的流操作可以将一个(可能是无限的)Stream限制到第一个不符合谓词的元素?

在Java 9中,我们可以像下面的示例一样使用takeWhile来打印所有小于10的数字。

IntStream
    .iterate(1, n -> n + 1)
    .takeWhile(n -> n < 10)
    .forEach(System.out::println);

在Java 8中没有这样的操作,最好的一般实现方式是什么?


1
可能有用的信息在:https://dev59.com/dGIj5IYBdhLWcg3w4Ivt。 - nobeh
我在想,架构师们如何能够在没有遇到这种用例的情况下通过“我们实际上可以用这个做什么”这个问题。截至Java 8,流(Streams)只对现有数据结构真正有帮助:-/ - Thorbjørn Ravn Andersen
使用Java 9,编写以下代码将更加容易: IntStream.iterate(1, n->n<10, n->n+1).forEach(System.out::print); - Marc Dzaebel
19个回答

167

JDK 9增加了takeWhiledropWhile操作。你的示例代码:

IntStream
    .iterate(1, n -> n + 1)
    .takeWhile(n -> n < 10)
    .forEach(System.out::println);

当在JDK 9下编译和运行时,代码将会按你预期的方式执行。

JDK 9已经发布。可以在此处下载:JDK 9发布版


3
JDK9дёӯStreamзҡ„takeWhile/dropWhileйў„и§Ҳж–ҮжЎЈзҡ„зӣҙжҺҘй“ҫжҺҘпјҡhttp://download.java.net/jdk9/docs/api/java/util/stream/Stream.html - Miles
10
takeWhiledropWhile在Scala、Python、Groovy、Ruby、Haskell和Clojure中非常普遍。与skiplimit的不对称性很不幸。也许本应该将skiplimit称为droptake,但除非您已经熟悉Haskell,否则这些名称并不那么直观易懂。 - Stuart Marks
3
我理解 dropXXXtakeXXX 更为流行,但个人认为类似 SQL 的 limitXXXskipXXX 也可以接受。我发现这种新的不对称性比术语的个人选择更加令人困惑... :) (顺便说一下:Scala 也有 drop(int)take(int) - Lukas Eder
4
好的,让我在生产环境中升级到Jdk 9。许多开发人员仍在使用Jdk8,这样的功能应该从Streams开始就包含在内。 - wilmol
2
IntStream .iterate(1, n -> n + 1) .takeWhile(n -> n < 10) can be simplified to IntStream .iterate(1, n -> n < 10, n -> n + 1) - Holger
显示剩余3条评论

85

使用 Java 8 Stream 可能可以完成这样的操作,但不能保证效率——例如,您不能一定将这样的操作并行化,因为您必须按顺序查看元素。

API 没有提供一种简单的方法来实现它,但可能最简单的方法是使用 Stream.iterator(),对 Iterator 进行“take-while”实现的包装,然后返回到 Spliterator,最后返回到 Stream。 或者——也许——包装 Spliterator,尽管在此实现中它不再能够被分割。

以下是未经测试的在 Spliterator 上实现 takeWhile 的内容:

static <T> Spliterator<T> takeWhile(
    Spliterator<T> splitr, Predicate<? super T> predicate) {
  return new Spliterators.AbstractSpliterator<T>(splitr.estimateSize(), 0) {
    boolean stillGoing = true;
    @Override public boolean tryAdvance(Consumer<? super T> consumer) {
      if (stillGoing) {
        boolean hadNext = splitr.tryAdvance(elem -> {
          if (predicate.test(elem)) {
            consumer.accept(elem);
          } else {
            stillGoing = false;
          }
        });
        return hadNext && stillGoing;
      }
      return false;
    }
  };
}

static <T> Stream<T> takeWhile(Stream<T> stream, Predicate<? super T> predicate) {
   return StreamSupport.stream(takeWhile(stream.spliterator(), predicate), false);
}

8
理论上,使用无状态谓词并行化takeWhile很容易。在并行批处理中评估条件(假设谓词不会抛出异常或在执行多几次时具有副作用)。问题是在Streams使用的递归分解(fork/join框架)上下文中执行它。实际上,Streams非常低效。 - Aleksandr Dubinsky
98
如果Streams没有过多地关注自动并行处理,它们本可以更好。在能够使用Streams的地方,只有一小部分地方需要并行处理。此外,如果Oracle真的很关心性能,他们本可以使JVM JIT自动向量化,并获得更大的性能提升,而不必打扰开发人员。现在这就是正确实现自动并行处理的方式。 - Aleksandr Dubinsky
现在Java 9已经发布,您应该更新这个答案。 - Radiodef
10
不,@Radiodef。这个问题明确要求Java 8的解决方案。 - Renato Back

58

allMatch() 是一种短路函数,因此您可以使用它来停止处理。主要的缺点是你需要测试两次:一次用于确定是否应该处理它,另一次用于确定是否继续。

IntStream
    .iterate(1, n -> n + 1)
    .peek(n->{if (n<10) System.out.println(n);})
    .allMatch(n->n < 10);

5
一开始我觉得这个方法名很不直观,但是文档证实Stream.allMatch()是一种短路操作。所以,即使在一个无限流如IntStream.iterate()中,它也会完成。当然,回过头来看,这是一种明智的优化。 - Bailey Parker
3
这很不错,但我认为它并没有很好地传达出其意图是“窥视”体的信息。如果下个月我遇到了它,我会花一分钟想想为什么之前的程序员要检查allMatch然后忽略答案。 - Joshua Goldberg
14
这个解决方案的缺点是它返回一个布尔值,因此您无法像通常情况下那样收集流的结果。 - neXus

37
作为@StuartMarks的回答的后续。我的StreamEx库有与当前JDK-9实现兼容的takeWhile操作。在JDK-9下运行时,它会通过MethodHandle.invokeExact(非常快)委托给JDK实现。在JDK-8下运行时,将使用“polyfill”实现。因此,使用我的库可以这样解决问题:
IntStreamEx.iterate(1, n -> n + 1)
           .takeWhile(n -> n < 10)
           .forEach(System.out::println);

你为什么没有在StreamEx类中实现它? - th0masb
@Someguy 我已经实现了它。 - Tagir Valeev

14

takeWhileprotonpack库提供的函数之一。

Stream<Integer> infiniteInts = Stream.iterate(0, i -> i + 1);
Stream<Integer> finiteInts = StreamUtils.takeWhile(infiniteInts, i -> i < 10);

assertThat(finiteInts.collect(Collectors.toList()),
           hasSize(10));

12

更新:Java 9 Stream 现在配备了一个takeWhile 方法。

不需要使用任何hack或其他解决方案,只需使用它!


我相信这可以大大改善: (也许有人可以使它线程安全)

Stream<Integer> stream = Stream.iterate(0, n -> n + 1);

TakeWhile.stream(stream, n -> n < 10000)
         .forEach(n -> System.out.print((n == 0 ? "" + n : "," + n)));

一种肯定能行的hack...不太优雅,但它有效 ~:D

class TakeWhile<T> implements Iterator<T> {

    private final Iterator<T> iterator;
    private final Predicate<T> predicate;
    private volatile T next;
    private volatile boolean keepGoing = true;

    public TakeWhile(Stream<T> s, Predicate<T> p) {
        this.iterator = s.iterator();
        this.predicate = p;
    }

    @Override
    public boolean hasNext() {
        if (!keepGoing) {
            return false;
        }
        if (next != null) {
            return true;
        }
        if (iterator.hasNext()) {
            next = iterator.next();
            keepGoing = predicate.test(next);
            if (!keepGoing) {
                next = null;
            }
        }
        return next != null;
    }

    @Override
    public T next() {
        if (next == null) {
            if (!hasNext()) {
                throw new NoSuchElementException("Sorry. Nothing for you.");
            }
        }
        T temp = next;
        next = null;
        return temp;
    }

    public static <T> Stream<T> stream(Stream<T> s, Predicate<T> p) {
        TakeWhile tw = new TakeWhile(s, p);
        Spliterator split = Spliterators.spliterator(tw, Integer.MAX_VALUE, Spliterator.ORDERED);
        return StreamSupport.stream(split, false);
    }

}

8

你可以使用Java8 + rxjava

import java.util.stream.IntStream;
import rx.Observable;


// Example 1)
IntStream intStream  = IntStream.iterate(1, n -> n + 1);
Observable.from(() -> intStream.iterator())
    .takeWhile(n ->
          {
                System.out.println(n);
                return n < 10;
          }
    ).subscribe() ;


// Example 2
IntStream intStream  = IntStream.iterate(1, n -> n + 1);
Observable.from(() -> intStream.iterator())
    .takeWhile(n -> n < 10)
    .forEach( n -> System.out.println(n));

7

实际上,在Java 8中有两种方法可以做到这一点,而不需要任何额外的库或使用Java 9。

如果你想在控制台上打印从2到20的数字,可以这样做:

IntStream.iterate(2, (i) -> i + 2).peek(System.out::println).allMatch(i -> i < 20);

或者

IntStream.iterate(2, (i) -> i + 2).peek(System.out::println).anyMatch(i -> i >= 20);

输出结果在两种情况下都相同:
2
4
6
8
10
12
14
16
18
20

还没有人提到anyMatch。这就是这篇文章的原因。


6

这是从JDK 9的java.util.stream.Stream.takeWhile(Predicate)复制的源代码。为了兼容JDK 8,稍有不同。

static <T> Stream<T> takeWhile(Stream<T> stream, Predicate<? super T> p) {
    class Taking extends Spliterators.AbstractSpliterator<T> implements Consumer<T> {
        private static final int CANCEL_CHECK_COUNT = 63;
        private final Spliterator<T> s;
        private int count;
        private T t;
        private final AtomicBoolean cancel = new AtomicBoolean();
        private boolean takeOrDrop = true;

        Taking(Spliterator<T> s) {
            super(s.estimateSize(), s.characteristics() & ~(Spliterator.SIZED | Spliterator.SUBSIZED));
            this.s = s;
        }

        @Override
        public boolean tryAdvance(Consumer<? super T> action) {
            boolean test = true;
            if (takeOrDrop &&               // If can take
                    (count != 0 || !cancel.get()) && // and if not cancelled
                    s.tryAdvance(this) &&   // and if advanced one element
                    (test = p.test(t))) {   // and test on element passes
                action.accept(t);           // then accept element
                return true;
            } else {
                // Taking is finished
                takeOrDrop = false;
                // Cancel all further traversal and splitting operations
                // only if test of element failed (short-circuited)
                if (!test)
                    cancel.set(true);
                return false;
            }
        }

        @Override
        public Comparator<? super T> getComparator() {
            return s.getComparator();
        }

        @Override
        public void accept(T t) {
            count = (count + 1) & CANCEL_CHECK_COUNT;
            this.t = t;
        }

        @Override
        public Spliterator<T> trySplit() {
            return null;
        }
    }
    return StreamSupport.stream(new Taking(stream.spliterator()), stream.isParallel()).onClose(stream::close);
}

4

以下是使用整数完成的版本 - 如问题所要求。

用法:

StreamUtil.takeWhile(IntStream.iterate(1, n -> n + 1), n -> n < 10);

这里是 StreamUtil 的代码:
import java.util.PrimitiveIterator;
import java.util.Spliterators;
import java.util.function.IntConsumer;
import java.util.function.IntPredicate;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;

public class StreamUtil
{
    public static IntStream takeWhile(IntStream stream, IntPredicate predicate)
    {
        return StreamSupport.intStream(new PredicateIntSpliterator(stream, predicate), false);
    }

    private static class PredicateIntSpliterator extends Spliterators.AbstractIntSpliterator
    {
        private final PrimitiveIterator.OfInt iterator;
        private final IntPredicate predicate;

        public PredicateIntSpliterator(IntStream stream, IntPredicate predicate)
        {
            super(Long.MAX_VALUE, IMMUTABLE);
            this.iterator = stream.iterator();
            this.predicate = predicate;
        }

        @Override
        public boolean tryAdvance(IntConsumer action)
        {
            if (iterator.hasNext()) {
                int value = iterator.nextInt();
                if (predicate.test(value)) {
                    action.accept(value);
                    return true;
                }
            }

            return false;
        }
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接