@@ -8,6 +8,18 @@
</tr>
</thead>
<tbody>
<tr>
<td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
<td style="word-wrap: break-word;">false</td>
<td>Boolean</td>
<td>Whether to enable the adaptive partitioner feature, which adapts the rescale and rebalance partitioners based on the load of the downstream tasks.</td>
</tr>
<tr>
<td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
<td style="word-wrap: break-word;">4</td>
<td>Integer</td>
<td>Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than 1.</td>
</tr>
<tr>
<td><h5>taskmanager.network.compression.codec</h5></td>
<td style="word-wrap: break-word;">LZ4</td>
@@ -26,6 +26,18 @@
<td>Boolean</td>
<td>Enable SSL support for the taskmanager data transport. This is applicable only when the global flag for internal SSL (security.ssl.internal.enabled) is set to true</td>
</tr>
<tr>
<td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
<td style="word-wrap: break-word;">false</td>
<td>Boolean</td>
<td>Whether to enable the adaptive partitioner feature, which adapts the rescale and rebalance partitioners based on the load of the downstream tasks.</td>
</tr>
<tr>
<td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
<td style="word-wrap: break-word;">4</td>
<td>Integer</td>
<td>Maximum number of channels to traverse when looking for the idlest channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than 1.</td>
</tr>
<tr>
<td><h5>taskmanager.network.compression.codec</h5></td>
<td style="word-wrap: break-word;">LZ4</td>
@@ -325,6 +325,34 @@ public enum CompressionCodec {
code(NETWORK_REQUEST_BACKOFF_MAX.key()))
.build());

/** Whether to upgrade the rebalance and rescale partitioners to adaptive, load-based partitioning. */
@Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
public static final ConfigOption<Boolean> ADAPTIVE_PARTITIONER_ENABLED =
key("taskmanager.network.adaptive-partitioner.enabled")
.booleanType()
.defaultValue(false)
.withDescription(
"Whether to enable adaptive partitioner feature for rescale and rebalance partitioners based on the loading of the downstream tasks.");

/**
* Maximum number of channels to traverse when looking for the idlest channel for rescale and
* rebalance partitioners when {@link #ADAPTIVE_PARTITIONER_ENABLED} is true.
*/
@Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
public static final ConfigOption<Integer> ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE =
key("taskmanager.network.adaptive-partitioner.max-traverse-size")
.intType()
.defaultValue(4)
.withDescription(
Description.builder()
.text(
"Maximum number of channels to traverse when looking for the idlest channel for rescale and rebalance partitioners when %s is enabled.",
code(ADAPTIVE_PARTITIONER_ENABLED.key()))
.linebreak()
.text(
"Note, the value of the configuration option must be greater than `1`.")
.build());
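For illustration only (not part of this change): once these options exist, they could be set programmatically through Flink's standard Configuration API. The snippet below is a hedged sketch; the value 8 is an arbitrary example.

import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;

// Enable load-based channel selection for rescale/rebalance partitioners.
Configuration conf = new Configuration();
conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED, true);
// Probe up to 8 channels per record; the value must be greater than 1.
conf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 8);

The same settings can also be written as taskmanager.network.adaptive-partitioner.enabled: true and taskmanager.network.adaptive-partitioner.max-traverse-size: 8 in the Flink configuration file.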

// ------------------------------------------------------------------------

/** Not intended to be instantiated. */
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.io.network.api.writer;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.io.IOReadableWritable;

import java.io.IOException;
import java.nio.ByteBuffer;

/** A record writer that selects the target channel based on the load of the downstream tasks. */
@Internal
public final class AdaptiveLoadBasedRecordWriter<T extends IOReadableWritable>
extends RecordWriter<T> {

private final int maxTraverseSize;

private int currentChannel = -1;

private final int numberOfSubpartitions;

AdaptiveLoadBasedRecordWriter(
ResultPartitionWriter writer, long timeout, String taskName, int maxTraverseSize) {
super(writer, timeout, taskName);
this.numberOfSubpartitions = writer.getNumberOfSubpartitions();
this.maxTraverseSize = Math.min(maxTraverseSize, numberOfSubpartitions);
[Review comment] davidradl (Contributor), Jan 30, 2026:

I am wondering why we need both numberOfSubpartitions and maxTraverseSize. Why not set numberOfSubpartitions to Math.min(maxTraverseSize, numberOfSubpartitions) and remove the private final int maxTraverseSize field? Then you would not need to check maxTraverseSize in the logic, since numberOfSubpartitions would always be the minimum, already accounting for maxTraverseSize.

Also, in a previous response to a review comment you said maxTraverseSize could not be 1, but it could end up as 1 if numberOfSubpartitions == 1, due to this Math.min. We should probably check for the numberOfSubpartitions == 1 case and not do adaptive processing.

}

@Override
public void emit(T record) throws IOException {
checkErroneous();

currentChannel = getIdlestChannelIndex();

ByteBuffer byteBuffer = serializeRecord(serializer, record);
targetPartition.emitRecord(byteBuffer, currentChannel);

if (flushAlways) {
targetPartition.flush(currentChannel);
}
}

@VisibleForTesting
int getIdlestChannelIndex() {
int bestChannelBuffersCount = Integer.MAX_VALUE;
long bestChannelBytesInQueue = Long.MAX_VALUE;
int bestChannel = 0;
for (int i = 1; i <= maxTraverseSize; i++) {
int candidateChannel = (currentChannel + i) % numberOfSubpartitions;
int candidateChannelBuffersCount =
targetPartition.getBuffersCountUnsafe(candidateChannel);
long candidateChannelBytesInQueue =
targetPartition.getBytesInQueueUnsafe(candidateChannel);

if (candidateChannelBuffersCount == 0) {
// If the candidate channel has no pending data, choose it directly.
return candidateChannel;
}

if (candidateChannelBuffersCount < bestChannelBuffersCount
|| (candidateChannelBuffersCount == bestChannelBuffersCount
&& candidateChannelBytesInQueue < bestChannelBytesInQueue)) {
bestChannel = candidateChannel;
bestChannelBuffersCount = candidateChannelBuffersCount;
bestChannelBytesInQueue = candidateChannelBytesInQueue;
}
}
return bestChannel;
}

/** Copied from {@link ChannelSelectorRecordWriter#broadcastEmit}. */
@Override
public void broadcastEmit(T record) throws IOException {
checkErroneous();

// Emitting to all channels in a for loop can be better than calling
// ResultPartitionWriter#broadcastRecord because the broadcastRecord
// method incurs extra overhead.
ByteBuffer serializedRecord = serializeRecord(serializer, record);
for (int channelIndex = 0; channelIndex < numberOfSubpartitions; channelIndex++) {
serializedRecord.rewind();
emit(record, channelIndex);
}

if (flushAlways) {
flushAll();
}
}
}
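To make the probing heuristic easier to study in isolation, here is a minimal, self-contained sketch of the same pick-the-least-loaded-of-the-next-k selection over plain arrays. The class and method names are illustrative only; the tie-breaking mirrors getIdlestChannelIndex above. Bounding the scan at maxTraverse keeps selection O(k) per record instead of O(numChannels), which is the trade-off behind taskmanager.network.adaptive-partitioner.max-traverse-size.

/** Standalone demo of the channel-selection heuristic; not part of this PR. */
public final class LeastLoadedChannelDemo {

    /**
     * Scans the maxTraverse channels after lastChannel (wrapping around) and returns the
     * one with the fewest queued buffers, breaking ties by queued bytes.
     */
    static int selectChannel(
            int lastChannel, int[] queuedBuffers, long[] queuedBytes, int maxTraverse) {
        int numChannels = queuedBuffers.length;
        int best = 0;
        int bestBuffers = Integer.MAX_VALUE;
        long bestBytes = Long.MAX_VALUE;
        for (int i = 1; i <= maxTraverse; i++) {
            int candidate = (lastChannel + i) % numChannels;
            if (queuedBuffers[candidate] == 0) {
                return candidate; // an empty channel is taken immediately
            }
            if (queuedBuffers[candidate] < bestBuffers
                    || (queuedBuffers[candidate] == bestBuffers
                            && queuedBytes[candidate] < bestBytes)) {
                best = candidate;
                bestBuffers = queuedBuffers[candidate];
                bestBytes = queuedBytes[candidate];
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] buffers = {3, 4, 5, 2};
        long[] bytes = {300L, 400L, 500L, 200L};
        // Starting after channel 0 with maxTraverse = 2, only channels 1 and 2 are probed,
        // so channel 1 wins even though channel 3 is globally the least loaded.
        System.out.println(selectChannel(0, buffers, bytes, 2)); // prints 1
    }
}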
@@ -18,7 +18,9 @@

package org.apache.flink.runtime.io.network.api.writer;

import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.util.Preconditions;

/** Utility class to encapsulate the logic of building a {@link RecordWriter} instance. */
public class RecordWriterBuilder<T extends IOReadableWritable> {
@@ -29,6 +31,11 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {

private String taskName = "test";

private boolean enabledAdaptivePartitioner = false;

private int maxTraverseSize =
NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.defaultValue();

public RecordWriterBuilder<T> setChannelSelector(ChannelSelector<T> selector) {
this.selector = selector;
return this;
@@ -44,11 +51,29 @@ public RecordWriterBuilder<T> setTaskName(String taskName) {
return this;
}

public RecordWriterBuilder<T> setEnabledAdaptivePartitioner(
boolean enabledAdaptivePartitioner) {
this.enabledAdaptivePartitioner = enabledAdaptivePartitioner;
return this;
}

public RecordWriterBuilder<T> setMaxTraverseSize(int maxTraverseSize) {
Preconditions.checkArgument(
maxTraverseSize > 1,
"The value of '%s' must be greater than 1 when '%s' is enabled.",
NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED.key());
this.maxTraverseSize = maxTraverseSize;
return this;
}

public RecordWriter<T> build(ResultPartitionWriter writer) {
if (selector.isBroadcast()) {
return new BroadcastRecordWriter<>(writer, timeout, taskName);
}
if (enabledAdaptivePartitioner) {
return new AdaptiveLoadBasedRecordWriter<>(writer, timeout, taskName, maxTraverseSize);
}
return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
}
}
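A hedged usage sketch of the extended builder; partitionWriter and the IOReadableWritable record type MyRecord are assumed to exist in the surrounding task setup and are not defined in this PR:

// Illustrative only: construct an adaptive, load-based writer.
RecordWriter<MyRecord> writer =
        new RecordWriterBuilder<MyRecord>()
                .setTaskName("adaptive-writer") // hypothetical task name
                .setEnabledAdaptivePartitioner(true)
                .setMaxTraverseSize(8) // must be > 1, enforced by the precondition above
                .build(partitionWriter);

Note that a broadcast channel selector still takes precedence in build(): the adaptive writer only replaces the non-broadcast ChannelSelectorRecordWriter path.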
@@ -60,6 +60,14 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvider {
/** Writes the given serialized record to the target subpartition. */
void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException;

/** Returns a best-effort count of bytes queued in the given subpartition (unsynchronized). */
default long getBytesInQueueUnsafe(int targetSubpartition) {
return 0;
}

/** Returns a best-effort count of buffers in use for the given subpartition (unsynchronized). */
default int getBuffersCountUnsafe(int targetSubpartition) {
return 0;
}

/**
* Writes the given serialized record to all subpartitions. One can also achieve the same effect
* by emitting the same record to all subpartitions one by one, however, this method can have
@@ -75,4 +75,9 @@ public interface BufferPool extends BufferProvider, BufferRecycler {

/** Returns the number of used buffers of this buffer pool. */
int bestEffortGetNumOfUsedBuffers();

/** Returns the number of buffers currently requested for the given target channel. */
default int getBuffersCountUnsafe(int targetChannel) {
return 0;
}
}
@@ -824,4 +824,9 @@ public static AvailabilityStatus from(
}
}
}

@Override
public int getBuffersCountUnsafe(int targetChannel) {
return subpartitionBuffersCount[targetChannel];
}
}
@@ -65,7 +65,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {

private TimerGauge hardBackPressuredTimeMsPerSecond = new TimerGauge();

private final long[] subpartitionWrittenBytes;

public BufferWritingResultPartition(
String owningTaskName,
Expand All @@ -91,6 +91,7 @@ public BufferWritingResultPartition(

this.subpartitions = checkNotNull(subpartitions);
this.unicastBufferBuilders = new BufferBuilder[subpartitions.length];
this.subpartitionWrittenBytes = new long[subpartitions.length];
}

@Override
@@ -114,6 +115,11 @@ public int getNumberOfQueuedBuffers() {

@Override
public long getSizeOfQueuedBuffersUnsafe() {
long totalWrittenBytes = 0;
for (int i = 0; i < subpartitions.length; i++) {
totalWrittenBytes += subpartitionWrittenBytes[i];
}

long totalNumberOfBytes = 0;

for (ResultSubpartition subpartition : subpartitions) {
@@ -123,6 +129,12 @@ public long getSizeOfQueuedBuffersUnsafe() {
return totalWrittenBytes - totalNumberOfBytes;
}

@Override
public long getBytesInQueueUnsafe(int targetSubpartition) {
return subpartitionWrittenBytes[targetSubpartition]
- subpartitions[targetSubpartition].getTotalNumberOfBytesUnsafe();
}

@Override
public int getNumberOfQueuedBuffers(int targetSubpartition) {
checkArgument(targetSubpartition >= 0 && targetSubpartition < numSubpartitions);
@@ -151,7 +163,7 @@ protected void flushAllSubpartitions(boolean finishProducers) {

@Override
public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
subpartitionWrittenBytes[targetSubpartition] += record.remaining();

BufferBuilder buffer = appendUnicastDataForNewRecord(record, targetSubpartition);

@@ -171,7 +183,9 @@ public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {

@Override
public void broadcastRecord(ByteBuffer record) throws IOException {
for (int i = 0; i < subpartitions.length; i++) {
subpartitionWrittenBytes[i] += record.remaining();
}

BufferBuilder buffer = appendBroadcastDataForNewRecord(record);

@@ -197,11 +211,11 @@ public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {

try (BufferConsumer eventBufferConsumer =
EventSerializer.toBufferConsumer(event, isPriorityEvent)) {
for (int i = 0; i < subpartitions.length; i++) {
// Retain the buffer so that it can be recycled by each subpartition of
// targetPartition
subpartitions[i].add(eventBufferConsumer.copy(), 0);
subpartitionWrittenBytes[i] += eventBufferConsumer.getWrittenBytes();
}
}
}
@@ -246,8 +260,8 @@ public void finish() throws IOException {
finishBroadcastBufferBuilder();
finishUnicastBufferBuilders();

for (int i = 0; i < subpartitions.length; i++) {
subpartitionWrittenBytes[i] += subpartitions[i].finish();
}

super.finish();
@@ -340,7 +354,7 @@ private void addToSubpartition(
protected int addToSubpartition(
int targetSubpartition, BufferConsumer bufferConsumer, int partialRecordLength)
throws IOException {
subpartitionWrittenBytes[targetSubpartition] += bufferConsumer.getWrittenBytes();
return subpartitions[targetSubpartition].add(bufferConsumer, partialRecordLength);
}

@@ -202,6 +202,10 @@ public boolean isNumberOfPartitionConsumerUndefined() {
/** Returns the number of queued buffers of the given target subpartition. */
public abstract int getNumberOfQueuedBuffers(int targetSubpartition);

/** Returns the number of buffers currently requested from the buffer pool for the given subpartition. */
public int getBuffersCountUnsafe(int targetSubpartition) {
return bufferPool.getBuffersCountUnsafe(targetSubpartition);
}

public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {
this.bufferPool.setMaxOverdraftBuffersPerGate(maxOverdraftBuffersPerGate);
}