diff --git a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
index 0036d781c128e..4a0d72e2273e3 100644
--- a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
+++ b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
@@ -8,6 +8,18 @@
+ <tr>
+   <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+   <td style="word-wrap: break-word;">false</td>
+   <td>Boolean</td>
+   <td>Whether to enable the adaptive partitioner feature, which routes records for the rescale and rebalance partitioners based on the load of the downstream tasks.</td>
+ </tr>
+ <tr>
+   <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+   <td style="word-wrap: break-word;">4</td>
+   <td>Integer</td>
+   <td>Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than 1.</td>
+ </tr>
taskmanager.network.compression.codec |
LZ4 |
diff --git a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
index 3e6012bea1d91..adbb3c2adb24b 100644
--- a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
+++ b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
@@ -26,6 +26,18 @@
Boolean |
Enable SSL support for the taskmanager data transport. This is applicable only when the global flag for internal SSL (security.ssl.internal.enabled) is set to true |
+ <tr>
+   <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+   <td style="word-wrap: break-word;">false</td>
+   <td>Boolean</td>
+   <td>Whether to enable the adaptive partitioner feature, which routes records for the rescale and rebalance partitioners based on the load of the downstream tasks.</td>
+ </tr>
+ <tr>
+   <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+   <td style="word-wrap: break-word;">4</td>
+   <td>Integer</td>
+   <td>Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than 1.</td>
+ </tr>
taskmanager.network.compression.codec |
LZ4 |
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
index 018b11e5ecd7a..e3e0d42d73e28 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
@@ -325,6 +325,34 @@ public enum CompressionCodec {
code(NETWORK_REQUEST_BACKOFF_MAX.key()))
.build());
+ /** Whether to upgrade the rebalance and rescale partitioners to adaptive, load-based partitioning. */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Boolean> ADAPTIVE_PARTITIONER_ENABLED =
+ key("taskmanager.network.adaptive-partitioner.enabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription(
+ "Whether to enable the adaptive partitioner feature, which routes records for the rescale and rebalance partitioners based on the load of the downstream tasks.");
+
+ /**
+ * Maximum number of channels to traverse when looking for the idlest channel for the rescale
+ * and rebalance partitioners when {@link #ADAPTIVE_PARTITIONER_ENABLED} is true.
+ */
+ @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+ public static final ConfigOption<Integer> ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE =
+ key("taskmanager.network.adaptive-partitioner.max-traverse-size")
+ .intType()
+ .defaultValue(4)
+ .withDescription(
+ Description.builder()
+ .text(
+ "Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when %s is enabled.",
+ code(ADAPTIVE_PARTITIONER_ENABLED.key()))
+ .linebreak()
+ .text(
+ "Note that the value of this option must be greater than 1.")
+ .build());
+
// ------------------------------------------------------------------------
/** Not intended to be instantiated. */
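
For context, a minimal sketch of how the two new options could be set programmatically; the helper class and the value 8 are illustrative assumptions, and the same keys can be set in the Flink configuration file instead:

```java
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;

class AdaptivePartitionerConfigExample {
    // Illustrative only: enable the adaptive partitioner and let the writer probe
    // up to 8 downstream channels per record instead of the default 4.
    static Configuration enableAdaptivePartitioner() {
        Configuration conf = new Configuration();
        conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED, true);
        conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 8);
        return conf;
    }
}
```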
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
new file mode 100644
index 0000000000000..d71aaebb85326
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.io.IOReadableWritable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/** A record writer based on load of downstream tasks. */
+@Internal
+public final class AdaptiveLoadBasedRecordWriter<T extends IOReadableWritable>
+ extends RecordWriter<T> {
+
+ private final int maxTraverseSize;
+
+ private int currentChannel = -1;
+
+ private final int numberOfSubpartitions;
+
+ AdaptiveLoadBasedRecordWriter(
+ ResultPartitionWriter writer, long timeout, String taskName, int maxTraverseSize) {
+ super(writer, timeout, taskName);
+ this.numberOfSubpartitions = writer.getNumberOfSubpartitions();
+ this.maxTraverseSize = Math.min(maxTraverseSize, numberOfSubpartitions);
+ }
+
+ @Override
+ public void emit(T record) throws IOException {
+ checkErroneous();
+
+ currentChannel = getIdlestChannelIndex();
+
+ ByteBuffer byteBuffer = serializeRecord(serializer, record);
+ targetPartition.emitRecord(byteBuffer, currentChannel);
+
+ if (flushAlways) {
+ targetPartition.flush(currentChannel);
+ }
+ }
+
+ @VisibleForTesting
+ int getIdlestChannelIndex() {
+ int bestChannelBuffersCount = Integer.MAX_VALUE;
+ long bestChannelBytesInQueue = Long.MAX_VALUE;
+ int bestChannel = 0;
+ for (int i = 1; i <= maxTraverseSize; i++) {
+ int candidateChannel = (currentChannel + i) % numberOfSubpartitions;
+ int candidateChannelBuffersCount =
+ targetPartition.getBuffersCountUnsafe(candidateChannel);
+ long candidateChannelBytesInQueue =
+ targetPartition.getBytesInQueueUnsafe(candidateChannel);
+
+ if (candidateChannelBuffersCount == 0) {
+ // If the candidate channel has no pending data at all,
+ // choose it directly.
+ return candidateChannel;
+ }
+
+ if (candidateChannelBuffersCount < bestChannelBuffersCount
+ || (candidateChannelBuffersCount == bestChannelBuffersCount
+ && candidateChannelBytesInQueue < bestChannelBytesInQueue)) {
+ bestChannel = candidateChannel;
+ bestChannelBuffersCount = candidateChannelBuffersCount;
+ bestChannelBytesInQueue = candidateChannelBytesInQueue;
+ }
+ }
+ return bestChannel;
+ }
+
+ /** Copied from {@link ChannelSelectorRecordWriter#broadcastEmit}. */
+ @Override
+ public void broadcastEmit(T record) throws IOException {
+ checkErroneous();
+
+ // Emitting to all channels in a for loop can be better than calling
+ // ResultPartitionWriter#broadcastRecord because the broadcastRecord
+ // method incurs extra overhead.
+ ByteBuffer serializedRecord = serializeRecord(serializer, record);
+ for (int channelIndex = 0; channelIndex < numberOfSubpartitions; channelIndex++) {
+ serializedRecord.rewind();
+ emit(record, channelIndex);
+ }
+
+ if (flushAlways) {
+ flushAll();
+ }
+ }
+}
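
To make the selection policy concrete, here is a hedged walk-through that reuses the TestingResultPartitionWriter added in AdaptiveLoadBasedRecordWriterTest further down in this diff; it relies on package-private members of this writer and on that test class, so it is a sketch rather than standalone code:

```java
import org.apache.flink.core.io.IOReadableWritable;

import java.util.Map;

class SelectionPolicyExample {
    int pickChannel() {
        // Queue state per channel: queued buffers {3, 2, 4}, queued bytes {0, 0, 3}.
        ResultPartitionWriter partition =
                new AdaptiveLoadBasedRecordWriterTest.TestingResultPartitionWriter(
                        3, Map.of(0, 0L, 1, 0L, 2, 3L), Map.of(0, 3, 1, 2, 2, 4));
        AdaptiveLoadBasedRecordWriter<IOReadableWritable> writer =
                new AdaptiveLoadBasedRecordWriter<>(partition, 5L, "demo", 2);

        // Starting from currentChannel = -1 with maxTraverseSize = 2, only channels 0 and 1
        // are probed; neither is empty, so the fewest-buffers rule picks channel 1 (2 < 3).
        return writer.getIdlestChannelIndex(); // == 1
    }
}
```

A channel with an empty queue short-circuits the scan, so the per-record cost stays bounded by maxTraverseSize.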
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
index 78e6424844d86..0740f208a03ae 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
@@ -18,7 +18,9 @@
package org.apache.flink.runtime.io.network.api.writer;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.util.Preconditions;
/** Utility class to encapsulate the logic of building a {@link RecordWriter} instance. */
public class RecordWriterBuilder<T extends IOReadableWritable> {
@@ -29,6 +31,11 @@ public class RecordWriterBuilder {
private String taskName = "test";
+ private boolean enabledAdaptivePartitioner = false;
+
+ private int maxTraverseSize =
+ NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.defaultValue();
+
public RecordWriterBuilder<T> setChannelSelector(ChannelSelector<T> selector) {
this.selector = selector;
return this;
@@ -44,11 +51,29 @@ public RecordWriterBuilder setTaskName(String taskName) {
return this;
}
+ public RecordWriterBuilder<T> setEnabledAdaptivePartitioner(
+ boolean enabledAdaptivePartitioner) {
+ this.enabledAdaptivePartitioner = enabledAdaptivePartitioner;
+ return this;
+ }
+
+ public RecordWriterBuilder<T> setMaxTraverseSize(int maxTraverseSize) {
+ Preconditions.checkArgument(
+ maxTraverseSize > 1,
+ "The value of '%s' must be greater than 1 when '%s' is enabled.",
+ NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
+ NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED.key());
+ this.maxTraverseSize = maxTraverseSize;
+ return this;
+ }
+
public RecordWriter<T> build(ResultPartitionWriter writer) {
if (selector.isBroadcast()) {
return new BroadcastRecordWriter<>(writer, timeout, taskName);
- } else {
- return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
}
+ if (enabledAdaptivePartitioner) {
+ return new AdaptiveLoadBasedRecordWriter<>(writer, timeout, taskName, maxTraverseSize);
+ }
+ return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
}
}
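
As a usage sketch of the new builder path (the resultPartitionWriter argument, the Long record type, and the surrounding class are placeholders; it assumes the streaming partitioner classes are on the classpath and the sketch lives in this writer package):

```java
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

class AdaptiveWriterWiringExample {
    // Hypothetical wiring: with the flag enabled, the rebalance selector's round-robin
    // choice is bypassed and AdaptiveLoadBasedRecordWriter picks the channel per record.
    static RecordWriter<SerializationDelegate<StreamRecord<Long>>> buildAdaptiveWriter(
            ResultPartitionWriter resultPartitionWriter) {
        return new RecordWriterBuilder<SerializationDelegate<StreamRecord<Long>>>()
                .setChannelSelector(new RebalancePartitioner<Long>())
                .setTimeout(100)
                .setTaskName("adaptive-demo")
                .setEnabledAdaptivePartitioner(true)
                .setMaxTraverseSize(4)
                .build(resultPartitionWriter);
    }
}
```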
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index e283fac596fbc..04cfa0ad33d22 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -60,6 +60,16 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
/** Writes the given serialized record to the target subpartition. */
void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException;
+ /** Returns the number of bytes currently queued in the given target subpartition, without synchronization (best effort). */
+ default long getBytesInQueueUnsafe(int targetSubpartition) {
+ return 0;
+ }
+
+ /** Returns the number of buffers currently in use for the given target subpartition, without synchronization (best effort). */
+ default int getBuffersCountUnsafe(int targetSubpartition) {
+ return 0;
+ }
+
/**
* Writes the given serialized record to all subpartitions. One can also achieve the same effect
* by emitting the same record to all subpartitions one by one, however, this method can have
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
index c574607e28e16..5061d08353d1a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
@@ -75,4 +75,9 @@ public interface BufferPool extends BufferProvider, BufferRecycler {
/** Returns the number of used buffers of this buffer pool. */
int bestEffortGetNumOfUsedBuffers();
+
+ /** Returns the number of buffers currently requested for the given target channel. */
+ default int getBuffersCountUnsafe(int targetChannel) {
+ return 0;
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
index 873414c6fe2b8..f31bd95f1a15d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
@@ -824,4 +824,9 @@ public static AvailabilityStatus from(
}
}
}
+
+ @Override
+ public int getBuffersCountUnsafe(int targetChannel) {
+ return subpartitionBuffersCount[targetChannel];
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
index eae9260642a5a..a39575973c00d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
@@ -65,7 +65,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
private TimerGauge hardBackPressuredTimeMsPerSecond = new TimerGauge();
- private long totalWrittenBytes;
+ private final long[] subpartitionWrittenBytes;
public BufferWritingResultPartition(
String owningTaskName,
@@ -91,6 +91,7 @@ public BufferWritingResultPartition(
this.subpartitions = checkNotNull(subpartitions);
this.unicastBufferBuilders = new BufferBuilder[subpartitions.length];
+ this.subpartitionWrittenBytes = new long[subpartitions.length];
}
@Override
@@ -114,6 +115,11 @@ public int getNumberOfQueuedBuffers() {
@Override
public long getSizeOfQueuedBuffersUnsafe() {
+ long totalWrittenBytes = 0;
+ for (int i = 0; i < subpartitions.length; i++) {
+ totalWrittenBytes += subpartitionWrittenBytes[i];
+ }
+
long totalNumberOfBytes = 0;
for (ResultSubpartition subpartition : subpartitions) {
@@ -123,6 +129,12 @@ public long getSizeOfQueuedBuffersUnsafe() {
return totalWrittenBytes - totalNumberOfBytes;
}
+ @Override
+ public long getBytesInQueueUnsafe(int targetSubpartition) {
+ return subpartitionWrittenBytes[targetSubpartition]
+ - subpartitions[targetSubpartition].getTotalNumberOfBytesUnsafe();
+ }
+
@Override
public int getNumberOfQueuedBuffers(int targetSubpartition) {
checkArgument(targetSubpartition >= 0 && targetSubpartition < numSubpartitions);
@@ -151,7 +163,7 @@ protected void flushAllSubpartitions(boolean finishProducers) {
@Override
public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
- totalWrittenBytes += record.remaining();
+ subpartitionWrittenBytes[targetSubpartition] += record.remaining();
BufferBuilder buffer = appendUnicastDataForNewRecord(record, targetSubpartition);
@@ -171,7 +183,9 @@ public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOExcep
@Override
public void broadcastRecord(ByteBuffer record) throws IOException {
- totalWrittenBytes += ((long) record.remaining() * numSubpartitions);
+ for (int i = 0; i < subpartitions.length; i++) {
+ subpartitionWrittenBytes[i] += record.remaining();
+ }
BufferBuilder buffer = appendBroadcastDataForNewRecord(record);
@@ -197,11 +211,11 @@ public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws
try (BufferConsumer eventBufferConsumer =
EventSerializer.toBufferConsumer(event, isPriorityEvent)) {
- totalWrittenBytes += ((long) eventBufferConsumer.getWrittenBytes() * numSubpartitions);
- for (ResultSubpartition subpartition : subpartitions) {
+ for (int i = 0; i < subpartitions.length; i++) {
// Retain the buffer so that it can be recycled by each subpartition of
// targetPartition
- subpartition.add(eventBufferConsumer.copy(), 0);
+ subpartitions[i].add(eventBufferConsumer.copy(), 0);
+ subpartitionWrittenBytes[i] += eventBufferConsumer.getWrittenBytes();
}
}
}
@@ -246,8 +260,8 @@ public void finish() throws IOException {
finishBroadcastBufferBuilder();
finishUnicastBufferBuilders();
- for (ResultSubpartition subpartition : subpartitions) {
- totalWrittenBytes += subpartition.finish();
+ for (int i = 0; i < subpartitions.length; i++) {
+ subpartitionWrittenBytes[i] += subpartitions[i].finish();
}
super.finish();
@@ -340,7 +354,7 @@ private void addToSubpartition(
protected int addToSubpartition(
int targetSubpartition, BufferConsumer bufferConsumer, int partialRecordLength)
throws IOException {
- totalWrittenBytes += bufferConsumer.getWrittenBytes();
+ subpartitionWrittenBytes[targetSubpartition] += bufferConsumer.getWrittenBytes();
return subpartitions[targetSubpartition].add(bufferConsumer, partialRecordLength);
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 6cbcfc0c598e3..47b52caa8d6e6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -202,6 +202,11 @@ public boolean isNumberOfPartitionConsumerUndefined() {
/** Returns the number of queued buffers of the given target subpartition. */
public abstract int getNumberOfQueuedBuffers(int targetSubpartition);
+ /** Returns the number of buffers currently requested for the given target subpartition, without synchronization. */
+ public int getBuffersCountUnsafe(int targetSubpartition) {
+ return bufferPool.getBuffersCountUnsafe(targetSubpartition);
+ }
+
public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {
this.bufferPool.setMaxOverdraftBuffersPerGate(maxOverdraftBuffersPerGate);
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 505e67a30d78d..2a34487b3d6ad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,6 +23,7 @@
import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.execution.RecoveryClaimMode;
import org.apache.flink.core.fs.AutoCloseableRegistry;
@@ -95,6 +96,7 @@
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer;
@@ -1830,12 +1832,22 @@ private static <OUT> RecordWriter<SerializationDelegate<StreamRecord<OUT>>> crea
((ConfigurableStreamPartitioner) outputPartitioner).configure(numKeyGroups);
}
}
+ Configuration conf = environment.getJobConfiguration();
+ final boolean enabledAdaptivePartitioner =
+ (outputPartitioner instanceof RebalancePartitioner
+ || outputPartitioner instanceof RescalePartitioner)
+ && conf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED)
+ && bufferWriter.getNumberOfSubpartitions() > 1;
+ final int maxTraverseSize =
+ conf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE);
RecordWriter<SerializationDelegate<StreamRecord<OUT>>> output =
new RecordWriterBuilder<SerializationDelegate<StreamRecord<OUT>>>()
.setChannelSelector(outputPartitioner)
.setTimeout(bufferTimeout)
.setTaskName(taskNameWithSubtask)
+ .setEnabledAdaptivePartitioner(enabledAdaptivePartitioner)
+ .setMaxTraverseSize(maxTraverseSize)
.build(bufferWriter);
output.setMetricGroup(environment.getMetricGroup().getIOMetricGroup());
return output;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
new file mode 100644
index 0000000000000..b1b21c4585e03
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for {@link AdaptiveLoadBasedRecordWriter}. */
+class AdaptiveLoadBasedRecordWriterTest {
+
+ @Test
+ void testValidateMaxTraverseSize() {
+ // Test invalid value.
+ final RecordWriterBuilder<IOReadableWritable> builder = new RecordWriterBuilder<>();
+ assertThatThrownBy(() -> builder.setMaxTraverseSize(-1))
+ .isInstanceOf(IllegalArgumentException.class);
+ assertThatThrownBy(() -> builder.setMaxTraverseSize(0))
+ .isInstanceOf(IllegalArgumentException.class);
+ assertThatThrownBy(() -> builder.setMaxTraverseSize(1))
+ .isInstanceOf(IllegalArgumentException.class);
+
+ // Test for valid value.
+ builder.setMaxTraverseSize(2);
+ }
+
+ static Stream<Arguments> getTestingParams() {
+ return Stream.of(
+ // maxTraverseSize, bytesPerPartition, bufferPerPartition,
+ // targetResultPartitionIndex
+ Arguments.of(2, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+ Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+ Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+ Arguments.of(3, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+ Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+ Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+ Arguments.of(
+ 2, new long[] {1L, 2L, 3L, 1L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 0),
+ Arguments.of(
+ 2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {3, 2, 4, 2, 3, 4}, 1),
+ Arguments.of(
+ 2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {0, 0, 4, 2, 3, 4}, 0),
+ Arguments.of(
+ 4, new long[] {1L, 2L, 3L, 0L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 3),
+ Arguments.of(
+ 4, new long[] {1L, 1L, 1L, 1L, 2L, 3L}, new int[] {2, 3, 4, 0, 3, 4}, 3),
+ Arguments.of(
+ 4, new long[] {0L, 0L, 0L, 0L, 2L, 3L}, new int[] {2, 3, 0, 2, 3, 4}, 2));
+ }
+
+ @ParameterizedTest(
+ name =
+ "maxTraverseSize: {0}, bytesPerPartition: {1}, bufferPerPartition: {2}, targetResultPartitionIndex: {3}")
+ @MethodSource("getTestingParams")
+ void testGetIdlestChannelIndex(
+ int maxTraverseSize,
+ long[] bytesPerPartition,
+ int[] buffersPerPartition,
+ int targetResultPartitionIndex) {
+ TestingResultPartitionWriter resultPartitionWriter =
+ getTestingResultPartitionWriter(bytesPerPartition, buffersPerPartition);
+
+ AdaptiveLoadBasedRecordWriter<IOReadableWritable> adaptiveLoadBasedRecordWriter =
+ new AdaptiveLoadBasedRecordWriter<>(
+ resultPartitionWriter, 5L, "testingTask", maxTraverseSize);
+ assertThat(adaptiveLoadBasedRecordWriter.getIdlestChannelIndex())
+ .isEqualTo(targetResultPartitionIndex);
+ }
+
+ private static TestingResultPartitionWriter getTestingResultPartitionWriter(
+ long[] bytesPerPartition, int[] buffersPerPartition) {
+ final Map<Integer, Long> bytesPerPartitionMap = new HashMap<>();
+ final Map<Integer, Integer> bufferPerPartitionMap = new HashMap<>();
+ for (int i = 0; i < bytesPerPartition.length; i++) {
+ bytesPerPartitionMap.put(i, bytesPerPartition[i]);
+ bufferPerPartitionMap.put(i, buffersPerPartition[i]);
+ }
+
+ return new TestingResultPartitionWriter(
+ buffersPerPartition.length, bytesPerPartitionMap, bufferPerPartitionMap);
+ }
+
+ /** Test utils class to simulate {@link ResultPartitionWriter}. */
+ static final class TestingResultPartitionWriter implements ResultPartitionWriter {
+
+ private final int numberOfSubpartitions;
+ private final Map<Integer, Long> bytesPerPartition;
+ private final Map<Integer, Integer> bufferPerPartition;
+
+ TestingResultPartitionWriter(
+ int numberOfSubpartitions,
+ Map<Integer, Long> bytesPerPartition,
+ Map<Integer, Integer> bufferPerPartition) {
+ this.numberOfSubpartitions = numberOfSubpartitions;
+ this.bytesPerPartition = bytesPerPartition;
+ this.bufferPerPartition = bufferPerPartition;
+ }
+
+ // The methods that are used in the testing.
+
+ @Override
+ public long getBytesInQueueUnsafe(int targetSubpartition) {
+ return bytesPerPartition.getOrDefault(targetSubpartition, 0L);
+ }
+
+ @Override
+ public int getBuffersCountUnsafe(int targetSubpartition) {
+ return bufferPerPartition.getOrDefault(targetSubpartition, 0);
+ }
+
+ @Override
+ public int getNumberOfSubpartitions() {
+ return numberOfSubpartitions;
+ }
+
+ // The methods that are not used.
+
+ @Override
+ public void setup() throws IOException {}
+
+ @Override
+ public ResultPartitionID getPartitionId() {
+ return null;
+ }
+
+ @Override
+ public int getNumTargetKeyGroups() {
+ return 0;
+ }
+
+ @Override
+ public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {}
+
+ @Override
+ public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {}
+
+ @Override
+ public void broadcastRecord(ByteBuffer record) throws IOException {}
+
+ @Override
+ public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent)
+ throws IOException {}
+
+ @Override
+ public void alignedBarrierTimeout(long checkpointId) throws IOException {}
+
+ @Override
+ public void abortCheckpoint(long checkpointId, CheckpointException cause) {}
+
+ @Override
+ public void notifyEndOfData(StopMode mode) throws IOException {}
+
+ @Override
+ public CompletableFuture<Void> getAllDataProcessedFuture() {
+ return null;
+ }
+
+ @Override
+ public void setMetricGroup(TaskIOMetricGroup metrics) {}
+
+ @Override
+ public ResultSubpartitionView createSubpartitionView(
+ ResultSubpartitionIndexSet indexSet,
+ BufferAvailabilityListener availabilityListener)
+ throws IOException {
+ return null;
+ }
+
+ @Override
+ public void flushAll() {}
+
+ @Override
+ public void flush(int subpartitionIndex) {}
+
+ @Override
+ public void fail(@Nullable Throwable throwable) {}
+
+ @Override
+ public void finish() throws IOException {}
+
+ @Override
+ public boolean isFinished() {
+ return false;
+ }
+
+ @Override
+ public void release(Throwable cause) {}
+
+ @Override
+ public boolean isReleased() {
+ return false;
+ }
+
+ @Override
+ public void close() throws Exception {}
+
+ @Override
+ public CompletableFuture<?> getAvailableFuture() {
+ return null;
+ }
+ }
+}