From 20a10b88413713c5bc74cb8940a92b6ac5f70933 Mon Sep 17 00:00:00 2001
From: Yuepeng Pan
Date: Mon, 19 Jan 2026 20:19:06 +0800
Subject: [PATCH] [FLINK-38943][runtime] Support Adaptive Partition Selection
 for RescalePartitioner & RebalancePartitioner

---
 .../all_taskmanager_network_section.html      |  12 +
 ...tty_shuffle_environment_configuration.html |  12 +
 .../NettyShuffleEnvironmentOptions.java       |  28 ++
 .../writer/AdaptiveLoadBasedRecordWriter.java | 107 ++++++++
 .../api/writer/RecordWriterBuilder.java       |  29 ++-
 .../api/writer/ResultPartitionWriter.java     |   8 +
 .../runtime/io/network/buffer/BufferPool.java |   5 +
 .../io/network/buffer/LocalBufferPool.java    |   5 +
 .../BufferWritingResultPartition.java         |  32 ++-
 .../io/network/partition/ResultPartition.java |   4 +
 .../streaming/runtime/tasks/StreamTask.java   |  12 +
 .../AdaptiveLoadBasedRecordWriterTest.java    | 241 ++++++++++++++++++
 12 files changed, 484 insertions(+), 11 deletions(-)
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java

diff --git a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
index 0036d781c128e..4a0d72e2273e3 100644
--- a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
+++ b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
@@ -8,6 +8,18 @@
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+            <td style="word-wrap: break-word;">false</td>
+            <td>Boolean</td>
+            <td>Whether to enable the adaptive partitioner feature for the rescale and rebalance partitioners, which selects the target channel based on the load of the downstream tasks.</td>
+        </tr>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+            <td style="word-wrap: break-word;">4</td>
+            <td>Integer</td>
+            <td>Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than `1`.</td>
+        </tr>
         <tr>
             <td><h5>taskmanager.network.compression.codec</h5></td>
             <td style="word-wrap: break-word;">LZ4</td>
diff --git a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
index 3e6012bea1d91..adbb3c2adb24b 100644
--- a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
+++ b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
@@ -26,6 +26,18 @@
             <td>Boolean</td>
             <td>Enable SSL support for the taskmanager data transport. This is applicable only when the global flag for internal SSL (security.ssl.internal.enabled) is set to true</td>
         </tr>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+            <td style="word-wrap: break-word;">false</td>
+            <td>Boolean</td>
+            <td>Whether to enable the adaptive partitioner feature for the rescale and rebalance partitioners, which selects the target channel based on the load of the downstream tasks.</td>
+        </tr>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+            <td style="word-wrap: break-word;">4</td>
+            <td>Integer</td>
+            <td>Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than `1`.</td>
+        </tr>
         <tr>
             <td><h5>taskmanager.network.compression.codec</h5></td>
             <td style="word-wrap: break-word;">LZ4</td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
index 018b11e5ecd7a..e3e0d42d73e28 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
@@ -325,6 +325,34 @@ public enum CompressionCodec {
                                                 code(NETWORK_REQUEST_BACKOFF_MAX.key()))
                                 .build());
 
+    /** Whether to enable adaptive partition selection for the rebalance and rescale partitioners. */
+    @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    public static final ConfigOption<Boolean> ADAPTIVE_PARTITIONER_ENABLED =
+            key("taskmanager.network.adaptive-partitioner.enabled")
+                    .booleanType()
+                    .defaultValue(false)
+                    .withDescription(
+                            "Whether to enable the adaptive partitioner feature for the rescale and rebalance partitioners, which selects the target channel based on the load of the downstream tasks.");
+
+    /**
+     * Maximum number of channels to traverse when looking for the idlest channel for the rescale
+     * and rebalance partitioners when {@link #ADAPTIVE_PARTITIONER_ENABLED} is true.
+     */
+    @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    public static final ConfigOption<Integer> ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE =
+            key("taskmanager.network.adaptive-partitioner.max-traverse-size")
+                    .intType()
+                    .defaultValue(4)
+                    .withDescription(
+                            Description.builder()
+                                    .text(
+                                            "Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when %s is enabled.",
+                                            code(ADAPTIVE_PARTITIONER_ENABLED.key()))
+                                    .linebreak()
+                                    .text(
+                                            "Note that the value of this option must be greater than `1`.")
+                                    .build());
+
     // ------------------------------------------------------------------------
 
     /** Not intended to be instantiated. */
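
For reviewers, a minimal usage sketch (not part of this patch): a job can opt in to the new options through the regular configuration API. Only the two option fields come from this change; the surrounding pipeline is hypothetical.

    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class EnableAdaptivePartitionerDemo {
        public static void main(String[] args) throws Exception {
            Configuration conf = new Configuration();
            // Opt in to load-based channel selection for rescale/rebalance exchanges.
            conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED, true);
            // Probe at most 8 candidate channels per record (must be > 1).
            conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 8);

            StreamExecutionEnvironment env =
                    StreamExecutionEnvironment.getExecutionEnvironment(conf);
            env.fromSequence(0, 1000).rebalance().print();
            env.execute("adaptive-partitioner-demo");
        }
    }

Since StreamTask reads the flags from the job configuration (see the StreamTask hunk below), passing them to getExecutionEnvironment(conf) is expected to be enough to reach the new code path.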
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
new file mode 100644
index 0000000000000..d71aaebb85326
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.io.IOReadableWritable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/** A record writer that selects the target channel based on the load of the downstream tasks. */
+@Internal
+public final class AdaptiveLoadBasedRecordWriter<T extends IOReadableWritable>
+        extends RecordWriter<T> {
+
+    /** Maximum number of candidate channels probed per record. */
+    private final int maxTraverseSize;
+
+    /** Channel chosen for the previous record; probing starts right after it. */
+    private int currentChannel = -1;
+
+    private final int numberOfSubpartitions;
+
+    AdaptiveLoadBasedRecordWriter(
+            ResultPartitionWriter writer, long timeout, String taskName, int maxTraverseSize) {
+        super(writer, timeout, taskName);
+        this.numberOfSubpartitions = writer.getNumberOfSubpartitions();
+        this.maxTraverseSize = Math.min(maxTraverseSize, numberOfSubpartitions);
+    }
+
+    @Override
+    public void emit(T record) throws IOException {
+        checkErroneous();
+
+        currentChannel = getIdlestChannelIndex();
+
+        ByteBuffer byteBuffer = serializeRecord(serializer, record);
+        targetPartition.emitRecord(byteBuffer, currentChannel);
+
+        if (flushAlways) {
+            targetPartition.flush(currentChannel);
+        }
+    }
+
+    @VisibleForTesting
+    int getIdlestChannelIndex() {
+        int bestChannelBuffersCount = Integer.MAX_VALUE;
+        long bestChannelBytesInQueue = Long.MAX_VALUE;
+        int bestChannel = 0;
+        // Probe up to maxTraverseSize channels in round-robin order, starting right after the
+        // previously chosen channel; prefer fewer queued buffers, then fewer queued bytes.
+        for (int i = 1; i <= maxTraverseSize; i++) {
+            int candidateChannel = (currentChannel + i) % numberOfSubpartitions;
+            int candidateChannelBuffersCount =
+                    targetPartition.getBuffersCountUnsafe(candidateChannel);
+            long candidateChannelBytesInQueue =
+                    targetPartition.getBytesInQueueUnsafe(candidateChannel);
+
+            if (candidateChannelBuffersCount == 0) {
+                // If there isn't any pending data in the candidate channel, choose it directly.
+                return candidateChannel;
+            }
+
+            if (candidateChannelBuffersCount < bestChannelBuffersCount
+                    || (candidateChannelBuffersCount == bestChannelBuffersCount
+                            && candidateChannelBytesInQueue < bestChannelBytesInQueue)) {
+                bestChannel = candidateChannel;
+                bestChannelBuffersCount = candidateChannelBuffersCount;
+                bestChannelBytesInQueue = candidateChannelBytesInQueue;
+            }
+        }
+        return bestChannel;
+    }
+
+    /** Copied from {@link ChannelSelectorRecordWriter#broadcastEmit}. */
+    @Override
+    public void broadcastEmit(T record) throws IOException {
+        checkErroneous();
+
+        // Emitting to all channels in a for loop can be better than calling
+        // ResultPartitionWriter#broadcastRecord because the broadcastRecord
+        // method incurs extra overhead.
+        ByteBuffer serializedRecord = serializeRecord(serializer, record);
+        for (int channelIndex = 0; channelIndex < numberOfSubpartitions; channelIndex++) {
+            serializedRecord.rewind();
+            emit(record, channelIndex);
+        }
+
+        if (flushAlways) {
+            flushAll();
+        }
+    }
+}
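
To make the probing rule concrete, here is a self-contained sketch of the same selection heuristic over a hypothetical snapshot of per-channel backlogs (plain Java, no Flink types; pickChannel mirrors getIdlestChannelIndex but is not the class above):

    public class IdlestChannelDemo {
        // Pick the channel with the fewest queued buffers, breaking ties by queued bytes,
        // probing at most maxTraverseSize channels after the previously used one.
        static int pickChannel(int current, int maxTraverseSize, int[] buffers, long[] bytes) {
            int bestBuffers = Integer.MAX_VALUE;
            long bestBytes = Long.MAX_VALUE;
            int best = 0;
            for (int i = 1; i <= Math.min(maxTraverseSize, buffers.length); i++) {
                int c = (current + i) % buffers.length;
                if (buffers[c] == 0) {
                    return c; // an empty queue wins immediately
                }
                if (buffers[c] < bestBuffers
                        || (buffers[c] == bestBuffers && bytes[c] < bestBytes)) {
                    best = c;
                    bestBuffers = buffers[c];
                    bestBytes = bytes[c];
                }
            }
            return best;
        }

        public static void main(String[] args) {
            int[] buffers = {2, 3, 1, 2};
            long[] bytes = {100L, 80L, 120L, 50L};
            // Starting after channel -1 with maxTraverseSize = 3, channels 0..2 are probed
            // and channel 2 wins with the smallest buffer backlog.
            System.out.println(pickChannel(-1, 3, buffers, bytes)); // prints 2
        }
    }

Probing only maxTraverseSize candidates keeps the per-record cost constant instead of linear in the number of channels, at the price of a possibly suboptimal choice.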
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
index 78e6424844d86..0740f208a03ae 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
@@ -18,7 +18,9 @@
 
 package org.apache.flink.runtime.io.network.api.writer;
 
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.util.Preconditions;
 
 /** Utility class to encapsulate the logic of building a {@link RecordWriter} instance. */
 public class RecordWriterBuilder<T extends IOReadableWritable> {
@@ -29,6 +31,11 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {
 
     private String taskName = "test";
 
+    private boolean enabledAdaptivePartitioner = false;
+
+    private int maxTraverseSize =
+            NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.defaultValue();
+
     public RecordWriterBuilder<T> setChannelSelector(ChannelSelector<T> selector) {
         this.selector = selector;
         return this;
@@ -44,11 +51,29 @@ public RecordWriterBuilder<T> setTaskName(String taskName) {
         return this;
     }
 
+    public RecordWriterBuilder<T> setEnabledAdaptivePartitioner(
+            boolean enabledAdaptivePartitioner) {
+        this.enabledAdaptivePartitioner = enabledAdaptivePartitioner;
+        return this;
+    }
+
+    public RecordWriterBuilder<T> setMaxTraverseSize(int maxTraverseSize) {
+        Preconditions.checkArgument(
+                maxTraverseSize > 1,
+                "The value of '%s' must be greater than 1 when '%s' is enabled.",
+                NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
+                NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED.key());
+        this.maxTraverseSize = maxTraverseSize;
+        return this;
+    }
+
     public RecordWriter<T> build(ResultPartitionWriter writer) {
         if (selector.isBroadcast()) {
             return new BroadcastRecordWriter<>(writer, timeout, taskName);
-        } else {
-            return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
         }
+        if (enabledAdaptivePartitioner) {
+            return new AdaptiveLoadBasedRecordWriter<>(writer, timeout, taskName, maxTraverseSize);
+        }
+        return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index e283fac596fbc..04cfa0ad33d22 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -60,6 +60,14 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvider {
 
     /** Writes the given serialized record to the target subpartition. */
     void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException;
 
+    /** Returns the number of queued bytes of the target subpartition, without synchronization. */
+    default long getBytesInQueueUnsafe(int targetSubpartition) {
+        return 0;
+    }
+
+    /** Returns the number of queued buffers of the target subpartition, without synchronization. */
+    default int getBuffersCountUnsafe(int targetSubpartition) {
+        return 0;
+    }
+
     /**
      * Writes the given serialized record to all subpartitions. One can also achieve the same
      * effect by emitting the same record to all subpartitions one by one, however, this method
      * can have
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
index c574607e28e16..5061d08353d1a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
@@ -75,4 +75,9 @@ public interface BufferPool extends BufferProvider, BufferRecycler {
 
     /** Returns the number of used buffers of this buffer pool. */
     int bestEffortGetNumOfUsedBuffers();
+
+    /** Returns the number of buffers requested for the target channel. */
+    default int getBuffersCountUnsafe(int targetChannel) {
+        return 0;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
index 873414c6fe2b8..f31bd95f1a15d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
@@ -824,4 +824,9 @@ public static AvailabilityStatus from(
             }
         }
     }
+
+    @Override
+    public int getBuffersCountUnsafe(int targetChannel) {
+        return subpartitionBuffersCount[targetChannel];
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
index eae9260642a5a..a39575973c00d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
@@ -65,7 +65,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
     private TimerGauge hardBackPressuredTimeMsPerSecond = new TimerGauge();
 
-    private long totalWrittenBytes;
+    /** Bytes written to each subpartition so far; replaces the partition-wide counter. */
+    private final long[] subpartitionWrittenBytes;
 
     public BufferWritingResultPartition(
             String owningTaskName,
@@ -91,6 +91,7 @@ public BufferWritingResultPartition(
 
         this.subpartitions = checkNotNull(subpartitions);
         this.unicastBufferBuilders = new BufferBuilder[subpartitions.length];
+        this.subpartitionWrittenBytes = new long[subpartitions.length];
     }
 
     @Override
@@ -114,6 +115,11 @@ public int getNumberOfQueuedBuffers() {
 
     @Override
     public long getSizeOfQueuedBuffersUnsafe() {
+        long totalWrittenBytes = 0;
+        for (int i = 0; i < subpartitions.length; i++) {
+            totalWrittenBytes += subpartitionWrittenBytes[i];
+        }
+
         long totalNumberOfBytes = 0;
 
         for (ResultSubpartition subpartition : subpartitions) {
@@ -123,6 +129,12 @@ public long getSizeOfQueuedBuffersUnsafe() {
         return totalWrittenBytes - totalNumberOfBytes;
     }
 
+    /** Returns the queued bytes of the target subpartition: written bytes minus consumed bytes. */
+    @Override
+    public long getBytesInQueueUnsafe(int targetSubpartition) {
+        return subpartitionWrittenBytes[targetSubpartition]
+                - subpartitions[targetSubpartition].getTotalNumberOfBytesUnsafe();
+    }
+
     @Override
     public int getNumberOfQueuedBuffers(int targetSubpartition) {
         checkArgument(targetSubpartition >= 0 && targetSubpartition < numSubpartitions);
@@ -151,7 +163,7 @@ protected void flushAllSubpartitions(boolean finishProducers) {
 
     @Override
     public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
-        totalWrittenBytes += record.remaining();
+        subpartitionWrittenBytes[targetSubpartition] += record.remaining();
 
         BufferBuilder buffer = appendUnicastDataForNewRecord(record, targetSubpartition);
 
@@ -171,7 +183,9 @@ public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
 
     @Override
     public void broadcastRecord(ByteBuffer record) throws IOException {
-        totalWrittenBytes += ((long) record.remaining() * numSubpartitions);
+        for (int i = 0; i < subpartitions.length; i++) {
+            subpartitionWrittenBytes[i] += record.remaining();
+        }
 
         BufferBuilder buffer = appendBroadcastDataForNewRecord(record);
 
@@ -197,11 +211,11 @@ public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws
         try (BufferConsumer eventBufferConsumer =
                EventSerializer.toBufferConsumer(event, isPriorityEvent)) {
-            totalWrittenBytes += ((long) eventBufferConsumer.getWrittenBytes() * numSubpartitions);
-            for (ResultSubpartition subpartition : subpartitions) {
+            for (int i = 0; i < subpartitions.length; i++) {
                 // Retain the buffer so that it can be recycled by each subpartition of
                 // targetPartition
-                subpartition.add(eventBufferConsumer.copy(), 0);
+                subpartitions[i].add(eventBufferConsumer.copy(), 0);
+                subpartitionWrittenBytes[i] += eventBufferConsumer.getWrittenBytes();
             }
         }
     }
@@ -246,8 +260,8 @@ public void finish() throws IOException {
         finishBroadcastBufferBuilder();
         finishUnicastBufferBuilders();
 
-        for (ResultSubpartition subpartition : subpartitions) {
-            totalWrittenBytes += subpartition.finish();
+        for (int i = 0; i < subpartitions.length; i++) {
+            subpartitionWrittenBytes[i] += subpartitions[i].finish();
         }
 
         super.finish();
@@ -340,7 +354,7 @@ private void addToSubpartition(
     protected int addToSubpartition(
             int targetSubpartition, BufferConsumer bufferConsumer, int partialRecordLength)
             throws IOException {
-        totalWrittenBytes += bufferConsumer.getWrittenBytes();
+        subpartitionWrittenBytes[targetSubpartition] += bufferConsumer.getWrittenBytes();
 
         return subpartitions[targetSubpartition].add(bufferConsumer, partialRecordLength);
    }
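
The per-subpartition backlog that feeds the channel selection is maintained as written bytes minus consumed bytes, as the getBytesInQueueUnsafe implementation above shows. A tiny worked illustration with hypothetical numbers:

    public class BacklogDemo {
        public static void main(String[] args) {
            long subpartitionWrittenBytes = 100L; // the producer appended 100 bytes so far
            long totalNumberOfBytesConsumed = 60L; // the consumer already fetched 60 bytes
            // getBytesInQueueUnsafe would report 40 bytes still queued for this subpartition.
            System.out.println(subpartitionWrittenBytes - totalNumberOfBytesConsumed); // 40
        }
    }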
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 6cbcfc0c598e3..47b52caa8d6e6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -202,6 +202,10 @@ public boolean isNumberOfPartitionConsumerUndefined() {
 
     /** Returns the number of queued buffers of the given target subpartition. */
     public abstract int getNumberOfQueuedBuffers(int targetSubpartition);
 
+    /** Returns the number of buffers requested for the given target subpartition. */
+    public int getBuffersCountUnsafe(int targetSubpartition) {
+        return bufferPool.getBuffersCountUnsafe(targetSubpartition);
+    }
+
     public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {
         this.bufferPool.setMaxOverdraftBuffersPerGate(maxOverdraftBuffersPerGate);
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 505e67a30d78d..2a34487b3d6ad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,6 +23,7 @@
 import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
 import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.core.execution.RecoveryClaimMode;
 import org.apache.flink.core.fs.AutoCloseableRegistry;
@@ -95,6 +96,7 @@
 import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer;
@@ -1830,12 +1832,22 @@ private static <OUT> RecordWriter<SerializationDelegate<StreamRecord<OUT>>> createRecordWriter(
                 ((ConfigurableStreamPartitioner) outputPartitioner).configure(numKeyGroups);
             }
         }
+        Configuration conf = environment.getJobConfiguration();
+        // Use the adaptive writer only for rebalance/rescale exchanges with more than one channel.
+        final boolean enabledAdaptivePartitioner =
+                (outputPartitioner instanceof RebalancePartitioner
+                                || outputPartitioner instanceof RescalePartitioner)
+                        && conf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED)
+                        && bufferWriter.getNumberOfSubpartitions() > 1;
+        final int maxTraverseSize =
+                conf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE);
 
         RecordWriter<SerializationDelegate<StreamRecord<OUT>>> output =
                 new RecordWriterBuilder<SerializationDelegate<StreamRecord<OUT>>>()
                         .setChannelSelector(outputPartitioner)
                         .setTimeout(bufferTimeout)
                         .setTaskName(taskNameWithSubtask)
+                        .setEnabledAdaptivePartitioner(enabledAdaptivePartitioner)
+                        .setMaxTraverseSize(maxTraverseSize)
                         .build(bufferWriter);
         output.setMetricGroup(environment.getMetricGroup().getIOMetricGroup());
         return output;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
new file mode 100644
index 0000000000000..b1b21c4585e03
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link AdaptiveLoadBasedRecordWriter}. */
+class AdaptiveLoadBasedRecordWriterTest {
+
+    @Test
+    void testValidateMaxTraverseSize() {
+        // Test invalid values.
+        final RecordWriterBuilder<IOReadableWritable> builder = new RecordWriterBuilder<>();
+        assertThatThrownBy(() -> builder.setMaxTraverseSize(-1))
+                .isInstanceOf(IllegalArgumentException.class);
+        assertThatThrownBy(() -> builder.setMaxTraverseSize(0))
+                .isInstanceOf(IllegalArgumentException.class);
+        assertThatThrownBy(() -> builder.setMaxTraverseSize(1))
+                .isInstanceOf(IllegalArgumentException.class);
+
+        // Test a valid value.
+        builder.setMaxTraverseSize(2);
+    }
+
+    static Stream<Arguments> getTestingParams() {
+        return Stream.of(
+                // maxTraverseSize, bytesPerPartition, bufferPerPartition,
+                // targetResultPartitionIndex
+                Arguments.of(2, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+                Arguments.of(3, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+                Arguments.of(
+                        2, new long[] {1L, 2L, 3L, 1L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 0),
+                Arguments.of(
+                        2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {3, 2, 4, 2, 3, 4}, 1),
+                Arguments.of(
+                        2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {0, 0, 4, 2, 3, 4}, 0),
+                Arguments.of(
+                        4, new long[] {1L, 2L, 3L, 0L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 3),
+                Arguments.of(
+                        4, new long[] {1L, 1L, 1L, 1L, 2L, 3L}, new int[] {2, 3, 4, 0, 3, 4}, 3),
+                Arguments.of(
+                        4, new long[] {0L, 0L, 0L, 0L, 2L, 3L}, new int[] {2, 3, 0, 2, 3, 4}, 2));
+    }
+
+    @ParameterizedTest(
+            name =
+                    "maxTraverseSize: {0}, bytesPerPartition: {1}, bufferPerPartition: {2}, targetResultPartitionIndex: {3}")
+    @MethodSource("getTestingParams")
+    void testGetIdlestChannelIndex(
+            int maxTraverseSize,
+            long[] bytesPerPartition,
+            int[] buffersPerPartition,
+            int targetResultPartitionIndex) {
+        TestingResultPartitionWriter resultPartitionWriter =
+                getTestingResultPartitionWriter(bytesPerPartition, buffersPerPartition);
+
+        AdaptiveLoadBasedRecordWriter<IOReadableWritable> adaptiveLoadBasedRecordWriter =
+                new AdaptiveLoadBasedRecordWriter<>(
+                        resultPartitionWriter, 5L, "testingTask", maxTraverseSize);
+        assertThat(adaptiveLoadBasedRecordWriter.getIdlestChannelIndex())
+                .isEqualTo(targetResultPartitionIndex);
+    }
+
+    private static TestingResultPartitionWriter getTestingResultPartitionWriter(
+            long[] bytesPerPartition, int[] buffersPerPartition) {
+        final Map<Integer, Long> bytesPerPartitionMap = new HashMap<>();
+        final Map<Integer, Integer> bufferPerPartitionMap = new HashMap<>();
+        for (int i = 0; i < bytesPerPartition.length; i++) {
+            bytesPerPartitionMap.put(i, bytesPerPartition[i]);
+            bufferPerPartitionMap.put(i, buffersPerPartition[i]);
+        }
+
+        return new TestingResultPartitionWriter(
+                buffersPerPartition.length, bytesPerPartitionMap, bufferPerPartitionMap);
+    }
+
+    /** Test utility class to simulate a {@link ResultPartitionWriter}. */
+    static final class TestingResultPartitionWriter implements ResultPartitionWriter {
+
+        private final int numberOfSubpartitions;
+        private final Map<Integer, Long> bytesPerPartition;
+        private final Map<Integer, Integer> bufferPerPartition;
+
+        TestingResultPartitionWriter(
+                int numberOfSubpartitions,
+                Map<Integer, Long> bytesPerPartition,
+                Map<Integer, Integer> bufferPerPartition) {
+            this.numberOfSubpartitions = numberOfSubpartitions;
+            this.bytesPerPartition = bytesPerPartition;
+            this.bufferPerPartition = bufferPerPartition;
+        }
+
+        // The methods that are used in the tests.
+
+        @Override
+        public long getBytesInQueueUnsafe(int targetSubpartition) {
+            return bytesPerPartition.getOrDefault(targetSubpartition, 0L);
+        }
+
+        @Override
+        public int getBuffersCountUnsafe(int targetSubpartition) {
+            return bufferPerPartition.getOrDefault(targetSubpartition, 0);
+        }
+
+        @Override
+        public int getNumberOfSubpartitions() {
+            return numberOfSubpartitions;
+        }
+
+        // The methods that are not used.
+ + @Override + public void setup() throws IOException {} + + @Override + public ResultPartitionID getPartitionId() { + return null; + } + + @Override + public int getNumTargetKeyGroups() { + return 0; + } + + @Override + public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {} + + @Override + public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {} + + @Override + public void broadcastRecord(ByteBuffer record) throws IOException {} + + @Override + public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) + throws IOException {} + + @Override + public void alignedBarrierTimeout(long checkpointId) throws IOException {} + + @Override + public void abortCheckpoint(long checkpointId, CheckpointException cause) {} + + @Override + public void notifyEndOfData(StopMode mode) throws IOException {} + + @Override + public CompletableFuture getAllDataProcessedFuture() { + return null; + } + + @Override + public void setMetricGroup(TaskIOMetricGroup metrics) {} + + @Override + public ResultSubpartitionView createSubpartitionView( + ResultSubpartitionIndexSet indexSet, + BufferAvailabilityListener availabilityListener) + throws IOException { + return null; + } + + @Override + public void flushAll() {} + + @Override + public void flush(int subpartitionIndex) {} + + @Override + public void fail(@Nullable Throwable throwable) {} + + @Override + public void finish() throws IOException {} + + @Override + public boolean isFinished() { + return false; + } + + @Override + public void release(Throwable cause) {} + + @Override + public boolean isReleased() { + return false; + } + + @Override + public void close() throws Exception {} + + @Override + public CompletableFuture getAvailableFuture() { + return null; + } + } +}