
A Case Study of Implementing an Efficient Shuffling Stream/Spliterator in Java

Sorting a Stream instance is straightforward and involves just a single API method call – achieving the opposite is not that easy.

In this article, we’ll see how to shuffle a Stream in Java, both eagerly and lazily, using custom Collectors and Spliterators.

Eager Shuffle Collector

One of the most pragmatic solutions to the above problem was already described by Heinz in this article.

Essentially, it encapsulates a compound operation: collecting the whole stream into a list, shuffling it with Collections#shuffle, and converting it back to a Stream:

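public static <T> Collector<T, ?, Stream<T>> toEagerShuffledStream() {
    return Collectors.collectingAndThen(
      // Collectors.toList() makes no mutability guarantees, so collect into an ArrayList
      Collectors.toCollection(ArrayList::new),
      list -> {
          Collections.shuffle(list); // shuffles the whole list up front
          return list.stream();
      });
}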

This solution works well if we want to process all stream elements in random order. It can backfire if we want to process only a small subset of them, because the whole collection gets shuffled in advance even if we request only a single element.
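To make that cost concrete, here is a minimal sketch that counts how many random indices Collections#shuffle draws when only the first element of the shuffled Stream is consumed. It assumes a variant of toEagerShuffledStream() that accepts a Random, and a hypothetical CountingRandom helper. Neither is part of the original code.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

// Hypothetical Random subclass that counts how many indices are drawn
class CountingRandom extends Random {
    int calls;

    @Override
    public int nextInt(int bound) {
        calls++;
        return super.nextInt(bound);
    }
}

class EagerCostDemo {

    // Assumed variant of toEagerShuffledStream() that takes the Random to use
    static <T> Collector<T, ?, Stream<T>> toEagerShuffledStream(Random random) {
        return Collectors.collectingAndThen(
          Collectors.toCollection(ArrayList::new),
          list -> {
              Collections.shuffle(list, random);
              return list.stream();
          });
    }

    // Consumes a single element and reports how many indices the shuffle drew
    static int drawsForOneElement(int size) {
        CountingRandom random = new CountingRandom();
        IntStream.range(0, size).boxed()
          .collect(toEagerShuffledStream(random))
          .findFirst(); // only one element is consumed from the shuffled Stream
        return random.calls;
    }

    public static void main(String[] args) {
        System.out.println(drawsForOneElement(1_000)); // 999
    }
}
```

Collections#shuffle traverses the list backwards from the last element to the second. It therefore draws N − 1 indices (999 for 1_000 elements), however few elements are consumed afterwards.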

Let’s have a look at a simple benchmark and the results it generated:

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.openjdk.jmh.annotations.*;

import static com.pivovarit.stream.RandomCollectors.toEagerShuffledStream;

@State(Scope.Benchmark)
public class RandomSpliteratorBenchmark {

    private List<String> source;

    @Param({"1", "10", "100", "1000", "10000", "100000"})
    public int limit;

    @Param({"100000"})
    public int size;

    @Setup(Level.Iteration)
    public void setUp() {
        source = IntStream.range(0, size)
          .boxed()
          .map(Object::toString)
          .collect(Collectors.toList());
    }

    @Benchmark
    public List<String> eager() {
        return source.stream()
          .collect(toEagerShuffledStream())
          .limit(limit)
          .collect(Collectors.toList());
    }
}

            (limit)   Mode  Cnt     Score     Error  Units
eager             1  thrpt    5   467.796 ±   9.074  ops/s
eager            10  thrpt    5   467.694 ±  17.166  ops/s
eager           100  thrpt    5   459.765 ±   8.048  ops/s
eager          1000  thrpt    5   467.934 ±  43.095  ops/s
eager         10000  thrpt    5   449.471 ±   5.549  ops/s
eager        100000  thrpt    5   331.111 ±   5.626  ops/s

As we can see, throughput stays roughly constant regardless of how many elements are consumed from the resulting Stream. Unfortunately, the absolute numbers are not impressive when only a few elements are needed. In such situations, shuffling the whole collection beforehand turns out to be quite wasteful.

Let’s see what we can do about it.

Lazy Shuffle Collector

To spare precious CPU cycles, instead of pre-shuffling the whole collection, we can pick random elements one at a time, only as many as the downstream operations actually request.

To achieve that, we need to implement a custom Spliterator that will allow us to iterate through objects in random order, and then we’ll be able to construct a Stream instance by using a helper method from the StreamSupport class:

public class RandomSpliterator<T> implements Spliterator<T> {

    // ...

    public static <T> Collector<T, ?, Stream<T>> toLazyShuffledStream() {
        return Collectors.collectingAndThen(
          // the spliterator removes elements, so it needs a mutable list
          Collectors.toCollection(ArrayList::new),
          list -> StreamSupport.stream(
            new RandomSpliterator<>(list, Random::new), false));
    }
}

Implementation Details

We can’t avoid evaluating the whole Stream even if we want to pick a single random element (which means there’s no support for infinite sequences). It’s therefore perfectly fine to initialize our RandomSpliterator<T> with a List<T>, but there’s a catch…

If a particular List implementation doesn’t support constant-time random access, this solution can turn out to be much slower than the eager approach. To protect ourselves from this scenario, we can perform a simple check when instantiating the Spliterator:

private RandomSpliterator(
  List<T> source, Supplier<? extends Random> random) {
    if (source.isEmpty()) { ... } // throw
    this.source = source instanceof RandomAccess 
      ? source 
      : new ArrayList<>(source);
    this.random = random.get();
}

Copying the elements into a new ArrayList is not free. Its one-off O(n) cost is still negligible compared to repeated index lookups in implementations that don’t provide O(1) random access.
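The copy matters because LinkedList#get(int) walks the chain of nodes from the nearer end, so every random lookup costs O(n). The sketch below isolates the guard from the constructor above. The helper name ensureRandomAccess is hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.RandomAccess;

class RandomAccessGuard {

    // Same check as in the constructor: reuse RandomAccess lists as they are,
    // copy anything else into an ArrayList once
    static <T> List<T> ensureRandomAccess(List<T> source) {
        return source instanceof RandomAccess ? source : new ArrayList<>(source);
    }

    public static void main(String[] args) {
        List<Integer> array = new ArrayList<>(List.of(1, 2, 3));
        List<Integer> linked = new LinkedList<>(List.of(1, 2, 3));

        System.out.println(ensureRandomAccess(array) == array);              // true: no copy
        System.out.println(ensureRandomAccess(linked) instanceof ArrayList);  // true: copied
    }
}
```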

Now we can implement the most important method: tryAdvance().

In this case, it’s fairly straightforward. In each iteration, we pick a random element and remove it from the source collection.

We don’t need to worry about mutating the source, since we don’t publish the RandomSpliterator, only a Collector that is based on it:

@Override
public boolean tryAdvance(Consumer<? super T> action) {
    int remaining = source.size();
    if (remaining > 0) {
        // ArrayList#remove(int) shifts every subsequent element one slot to the left
        action.accept(source.remove(random.nextInt(remaining)));
        return true;
    } else {
        return false;
    }
}

Besides this, we need to implement three other methods:

@Override
public Spliterator<T> trySplit() {
    return null; // to indicate that split is not possible
}

@Override
public long estimateSize() {
    return source.size();
}

@Override
public int characteristics() {
    return SIZED;
}

Now let’s try it and see whether it works:

IntStream.range(0, 10).boxed()
  .collect(toLazyShuffledStream())
  .forEach(System.out::println);

And the result (the order will differ between runs):

3
4
8
1
7
6
5
0
2
9
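For reference, the fragments above can be assembled into one self-contained sketch. This is not the author’s exact class. Plain null checks stand in for the validation elided in the constructor snippet (an assumption), so an empty input simply yields an empty Stream. LazyShuffleDemo is a hypothetical driver.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.RandomAccess;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

class RandomSpliterator<T> implements Spliterator<T> {

    private final List<T> source;
    private final Random random;

    // Assumption: null checks replace the validation elided in the article
    RandomSpliterator(List<T> source, Supplier<? extends Random> random) {
        Objects.requireNonNull(source, "source can't be null");
        Objects.requireNonNull(random, "random can't be null");
        // copy only if the list lacks O(1) random access
        this.source = source instanceof RandomAccess ? source : new ArrayList<>(source);
        this.random = random.get();
    }

    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        int remaining = source.size();
        if (remaining == 0) {
            return false;
        }
        // pick and remove a random element; the removal shifts the tail of the list
        action.accept(source.remove(random.nextInt(remaining)));
        return true;
    }

    @Override
    public Spliterator<T> trySplit() {
        return null; // splitting is not supported
    }

    @Override
    public long estimateSize() {
        return source.size(); // the list physically shrinks, so this stays exact
    }

    @Override
    public int characteristics() {
        return SIZED;
    }

    // The spliterator mutates the list, so collect into an ArrayList it can own
    static <E> Collector<E, ?, Stream<E>> toLazyShuffledStream() {
        return Collectors.collectingAndThen(
          Collectors.toCollection(ArrayList::new),
          list -> StreamSupport.stream(
            new RandomSpliterator<>(list, Random::new), false));
    }
}

class LazyShuffleDemo {
    public static void main(String[] args) {
        List<Integer> picked = IntStream.range(0, 100_000).boxed()
          .collect(RandomSpliterator.toLazyShuffledStream())
          .limit(3)
          .collect(Collectors.toList());
        System.out.println(picked); // three distinct values, different on every run
    }
}
```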

Performance Considerations

In this implementation, we replaced N array element swaps with M lookups/removals, where:

  • N – the collection size
  • M – the number of picked items

Generally, a single removal from an ArrayList is more expensive than a single element swap, because it shifts all subsequent elements. This makes the solution scale poorly, but it performs significantly better for relatively low values of M.
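As a rough back-of-the-envelope model, assume each index is drawn uniformly at random. Also assume that removing index r from an ArrayList of size k shifts the k − 1 − r elements behind it, which is what ArrayList’s internal removal does. A single removal then moves (k − 1)/2 elements on average, and picking M out of N elements moves about:

```latex
\mathbb{E}[\text{moves}] = \sum_{j=0}^{M-1} \frac{N-1-j}{2} = \frac{M(2N-M-1)}{4}
```

For N = 100_000 and M = N, that is N(N − 1)/4, roughly 2.5 × 10^9 element moves, compared to N − 1 = 99_999 swaps for Collections#shuffle. These are operation counts rather than timings. System#arraycopy moves elements in bulk, so a single moved element is much cheaper than a random swap, which likely explains why the lazy version still wins for small M.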

Let’s now see how this solution compares to the eager approach shown at the beginning (both measured on a collection of 100_000 objects):

            (limit)   Mode  Cnt     Score     Error  Units
eager             1  thrpt    5   467.796 ±   9.074  ops/s
eager            10  thrpt    5   467.694 ±  17.166  ops/s
eager           100  thrpt    5   459.765 ±   8.048  ops/s
eager          1000  thrpt    5   467.934 ±  43.095  ops/s
eager         10000  thrpt    5   449.471 ±   5.549  ops/s
eager        100000  thrpt    5   331.111 ±   5.626  ops/s
lazy              1  thrpt    5  1530.763 ±  72.096  ops/s
lazy             10  thrpt    5  1462.305 ±  23.860  ops/s
lazy            100  thrpt    5   823.212 ± 119.771  ops/s
lazy           1000  thrpt    5   166.786 ±  16.306  ops/s
lazy          10000  thrpt    5    19.475 ±   4.052  ops/s
lazy         100000  thrpt    5     4.097 ±   0.416  ops/s

As we can see, this solution outperforms the eager one when the number of processed Stream items is relatively low. As the ratio of processed items to collection size increases, however, the throughput drops drastically.

That’s because of the overhead of removing elements from the ArrayList holding the remaining objects. Each removal shifts all subsequent elements of the internal array one position to the left via System#arraycopy, an O(n) operation.

We can notice a similar pattern for much bigger collections (10_000_000 elements):

      (limit)    (size)   Mode  Cnt  Score   Err  Units
eager       1  10000000  thrpt    5  0.915        ops/s
eager      10  10000000  thrpt    5  0.783        ops/s
eager     100  10000000  thrpt    5  0.965        ops/s
eager    1000  10000000  thrpt    5  0.936        ops/s
eager   10000  10000000  thrpt    5  0.860        ops/s
lazy        1  10000000  thrpt    5  4.338        ops/s
lazy       10  10000000  thrpt    5  3.149        ops/s
lazy      100  10000000  thrpt    5  2.060        ops/s
lazy     1000  10000000  thrpt    5  0.370        ops/s
lazy    10000  10000000  thrpt    5  0.05         ops/s

…and much smaller ones (128 elements, mind the scale!):

       (limit)    (size)   Mode  Cnt       Score   Error  Units
eager        2     128    thrpt    5  246439.459          ops/s
eager        4     128    thrpt    5  333866.936          ops/s
eager        8     128    thrpt    5  340296.188          ops/s
eager       16     128    thrpt    5  345533.673          ops/s
eager       32     128    thrpt    5  231725.156          ops/s
eager       64     128    thrpt    5  314324.265          ops/s
eager      128     128    thrpt    5  270451.992          ops/s
lazy         2     128    thrpt    5  765989.718          ops/s
lazy         4     128    thrpt    5  659421.041          ops/s
lazy         8     128    thrpt    5  652685.515          ops/s
lazy        16     128    thrpt    5  470346.570          ops/s
lazy        32     128    thrpt    5  324174.691          ops/s
lazy        64     128    thrpt    5  186472.090          ops/s
lazy       128     128    thrpt    5  108105.699          ops/s

But could we do better than this?

Further Performance Improvements

Unfortunately, the scalability of the existing solution is quite disappointing. Let’s try to improve it, but before we do, let’s measure where the time is actually spent.

As expected, ArrayList#remove turns out to be one of the hot spots. In other words, the CPU spends a noticeable amount of time removing things from an ArrayList.

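Why is that? Removing an element from an ArrayList means removing it from the underlying array. The catch is that Java arrays can’t be resized and ArrayList keeps its elements contiguous. Each removal therefore shifts every subsequent element one position to the left; the array itself is not reallocated: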

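private void fastRemove(Object[] es, int i) {
    modCount++;
    final int newSize;
    if ((newSize = size - 1) > i)
        // shift the elements after index i one slot to the left, in place
        System.arraycopy(es, i + 1, es, i, newSize - i);
    // clear the now-unused last slot so it can be garbage-collected
    es[size = newSize] = null;
}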

What can we do about this? Avoid removing elements from an ArrayList.

In order to do that, we could avoid shrinking the list physically, and shrink it logically by tracking its size separately:

class ImprovedRandomSpliterator<T, LIST extends RandomAccess & List<T>> 
  implements Spliterator<T> {

    private final Random random;
    private final List<T> source;
    private int size;

    ImprovedRandomSpliterator(
      LIST source, Supplier<? extends Random> random) {
        Objects.requireNonNull(source, "source can't be null");
        Objects.requireNonNull(random, "random can't be null");

        this.source = source;
        this.random = random.get();
        this.size = this.source.size();
    }

Luckily, we can avoid concurrency issues since instances of this Spliterator are not supposed to be shared between threads.

Now, whenever we remove an element, we don’t need to shift anything. Instead, we can decrement our size tracker and ignore the part of the list beyond it.

Just before that, though, we need to move the last element into the slot of the element being returned, so that no remaining element ends up outside the tracked size:

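@Override
public boolean tryAdvance(Consumer<? super T> action) {
    if (size > 0) {
        int nextIdx = random.nextInt(size);
        int lastIdx = --size;

        // read the picked element and move the last remaining element into its slot;
        // this also works when nextIdx == lastIdx
        T elem = source.set(nextIdx, source.get(lastIdx));
        // the tail slot is now outside the logical size; clear it for the GC
        source.set(lastIdx, null);
        action.accept(elem);
        return true;
    } else {
        return false;
    }
}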
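To see the step in isolation, here is a sketch that drains a list with the same idea: move the last element into the gap, then shrink the list logically. The drain helper and its pick function are hypothetical. pick stands in for random.nextInt(size) so that the output is deterministic. The second call covers the case where the picked index is the last one (nextIdx == lastIdx), where the picked element has to be read before its slot is cleared.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntUnaryOperator;

class SwapRemoveTrace {

    // Drains `list` by moving the last element into the gap and shrinking logically.
    // `pick` maps the current logical size to an index in [0, size).
    static <T> List<T> drain(List<T> list, IntUnaryOperator pick) {
        List<T> out = new ArrayList<>();
        int size = list.size();
        while (size > 0) {
            int nextIdx = pick.applyAsInt(size);
            int lastIdx = --size;
            T elem = list.set(nextIdx, list.get(lastIdx)); // last element fills the gap
            list.set(lastIdx, null);                        // tail slot is no longer used
            out.add(elem);
        }
        return out;
    }

    public static void main(String[] args) {
        // Always picking index 0: "a" first, then "d" has moved to the front, and so on
        System.out.println(drain(new ArrayList<>(List.of("a", "b", "c", "d")), size -> 0));        // [a, d, c, b]
        // Always picking the last index exercises the nextIdx == lastIdx case
        System.out.println(drain(new ArrayList<>(List.of("a", "b", "c", "d")), size -> size - 1)); // [d, c, b, a]
    }
}
```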

If we profile it now, we can see that the expensive call is gone.

We’re ready to rerun the benchmarks and compare:

           (limit)  (size)   Mode  Cnt     Score     Error  Units
eager            1  100000  thrpt    5   454.396 ±  11.738  ops/s
eager           10  100000  thrpt    5   441.602 ±  40.503  ops/s
eager          100  100000  thrpt    5   456.167 ±  11.420  ops/s
eager         1000  100000  thrpt    5   443.149 ±   7.590  ops/s
eager        10000  100000  thrpt    5   431.375 ±  12.116  ops/s
eager       100000  100000  thrpt    5   328.376 ±   4.156  ops/s
lazy             1  100000  thrpt    5  1419.514 ±  58.778  ops/s
lazy            10  100000  thrpt    5  1336.452 ±  34.525  ops/s
lazy           100  100000  thrpt    5   926.438 ±  65.923  ops/s
lazy          1000  100000  thrpt    5   165.967 ±  17.135  ops/s
lazy         10000  100000  thrpt    5    19.673 ±   0.375  ops/s
lazy        100000  100000  thrpt    5     4.002 ±   0.305  ops/s
optimized        1  100000  thrpt    5  1478.069 ±  32.923  ops/s
optimized       10  100000  thrpt    5  1477.618 ±  72.917  ops/s
optimized      100  100000  thrpt    5  1448.584 ±  42.205  ops/s
optimized     1000  100000  thrpt    5  1435.818 ±  38.505  ops/s
optimized    10000  100000  thrpt    5   1060.88 ±  15.238  ops/s
optimized   100000  100000  thrpt    5   332.096 ±   7.071  ops/s


As you can see, we ended up with an implementation whose performance is far less sensitive to the number of elements we reach for.

In fact, the improved implementation performs on par with the Collections#shuffle-based one even in the pessimistic scenario (332.096 ± 7.071 vs 328.376 ± 4.156 ops/s, within the margin of error). Our work here is done.

And to put a small cherry on top, notice how we can leverage intersection types to ensure, at compile time, that only List instances that also implement RandomAccess get passed to it!
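The same intersection-type trick works for any generic method. In this sketch (the method name middle is hypothetical), the bound L extends RandomAccess & List<T> lets the compiler accept an ArrayList and reject a LinkedList:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.RandomAccess;

class IntersectionBoundDemo {

    // Accepts only arguments whose static type is both List<T> and RandomAccess
    static <T, L extends RandomAccess & List<T>> T middle(L list) {
        return list.get(list.size() / 2); // O(1) for RandomAccess implementations
    }

    public static void main(String[] args) {
        ArrayList<String> ok = new ArrayList<>(List.of("a", "b", "c"));
        System.out.println(middle(ok)); // b
        // middle(new java.util.LinkedList<>(ok)); // would not compile: not RandomAccess
    }
}
```

One design consequence is that the check relies on the static type. A variable declared as List<String> is rejected even if it holds an ArrayList, which is presumably why RandomCollectors collects into an ArrayList via toCollection(ArrayList::new) before creating the ImprovedRandomSpliterator.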

The Complete Example

…can also be found on GitHub.

package com.pivovarit.stream;

import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.RandomAccess;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.Supplier;

class ImprovedRandomSpliterator<T, LIST extends RandomAccess & List<T>> 
  implements Spliterator<T> {

    private final Random random;
    private final List<T> source;
    private int size;

    ImprovedRandomSpliterator(
      LIST source, Supplier<? extends Random> random) {
        Objects.requireNonNull(source, "source can't be null");
        Objects.requireNonNull(random, "random can't be null");

        this.source = source;
        this.random = random.get();
        this.size = this.source.size();
    }

    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        if (size > 0) {
            int nextIdx = random.nextInt(size);
            int lastIdx = --size;

            // move the last remaining element into the picked slot (safe when nextIdx == lastIdx)
            T elem = source.set(nextIdx, source.get(lastIdx));
            source.set(lastIdx, null); // clear the slot beyond the logical size
            action.accept(elem);
            return true;
        } else {
            return false;
        }
    }

    @Override
    public Spliterator<T> trySplit() {
        return null;
    }

    @Override
    public long estimateSize() {
        return size; // the list shrinks only logically, so report the tracked size
    }

    @Override
    public int characteristics() {
        return SIZED;
    }
}

package com.pivovarit.stream;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static java.util.stream.Collectors.toCollection;

public final class RandomCollectors {

    private RandomCollectors() {
    }

    public static <T> Collector<T, ?, Stream<T>> 
      toOptimizedLazyShuffledStream() {
        return Collectors.collectingAndThen(
            toCollection(ArrayList::new),
            list -> StreamSupport.stream(
              new ImprovedRandomSpliterator<>(list, Random::new), false));
    }

    public static <T> Collector<T, ?, Stream<T>> 
      toLazyShuffledStream() {
        return Collectors.collectingAndThen(
            toCollection(ArrayList::new),
            list -> StreamSupport.stream(
              new RandomSpliterator<>(list, Random::new), false));
    }

    public static <T> Collector<T, ?, Stream<T>> 
      toEagerShuffledStream() {
        return Collectors.collectingAndThen(
            toCollection(ArrayList::new),
            list -> {
                Collections.shuffle(list);
                return list.stream();
            });
    }
}


