
Conditional Embedding Perturbation (CEP)#1235

Open
Koratahiu wants to merge 7 commits into Nerogar:master from Koratahiu:cep

Conversation

Contributor

@Koratahiu Koratahiu commented Dec 30, 2025

This draft implements the Conditional Embedding Perturbation (CEP) strategy proposed in the paper:
Slight Corruption in Pre-training Data Makes Better Diffusion Models (NeurIPS 2024 spotlight)

This method aims to improve the generation quality and diversity of diffusion models by mitigating the impact of "perfect" overfitting to training pairs. The paper demonstrates theoretically that standard training can cause the generated distribution to collapse to the empirical distribution of the training data.

CEP addresses this by introducing slight, dimension-scaled noise to the conditional embeddings (e.g., text encoder outputs) during training. Optimizing this perturbed objective forces the model to learn a smoother conditional manifold, reducing the distance to the true data distribution and preventing memorization.

Implementation Details

  • Adds a perturbation term $\delta$ to the text embeddings before they are passed to the model.
  • The noise is sampled from a uniform distribution and scaled by the embedding dimension, ensuring the corruption remains "slight" regardless of the model architecture (SD 1.5 vs SDXL vs Flux); see the sketch after this list.
  • All supported models are covered, with the option exposed in the UI.
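
A minimal sketch of the perturbation (not the PR code itself): the noise is assumed to be drawn uniformly from [-1, 1], and the scale uses the gamma / sqrt(d) form settled on later in this review; the function name and tensor shape are illustrative.

import math
import torch

def perturb_condition(embedding: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # embedding: (batch, seq_len, d) text-encoder output used as conditioning
    if gamma <= 0:
        return embedding                                    # gamma = 0 is a no-op
    d = embedding.shape[-1]
    scale = gamma / math.sqrt(d)                            # "slight": shrinks as the embedding dimension grows
    delta = (torch.rand_like(embedding) * 2 - 1) * scale    # uniform noise in [-scale, scale]
    return embedding + delta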

Usage

  • Enable Conditional Embedding Perturbation (CEP) (below timestep shifting)
  • Set CEP Gamma to 1

TODO

  • To be tested

@Koratahiu Koratahiu marked this pull request as ready for review February 7, 2026 12:08
@Koratahiu
Contributor Author

This has been tested with SDXL, Chroma, and Zib, and it works very well.
It is especially beneficial for Zib, which tends to rely on semantic patterns; CEP mitigates this through its perturbation noise.

Collaborator

@dxqb dxqb left a comment

Interesting. Code comments added, but I want to try it myself too. I remember that some people on Discord have tested it; it would be great if they could post their conclusions here as well.

  • Flux2 was added in the meantime



# Conditional Embedding Perturbation (CEP)
cep_label = components.label(frame, 10, 0, "Conditional Embedding Perturbation (CEP)",
Collaborator

Is there a gamma that is a no-op? In that case we wouldn't need an enabled switch. This is how most other parameters in OneTrainer work: there is a 0.0 value, for example, that doesn't do anything.

Contributor Author

Yeah, 0 is a no-op.
1 is the paper's default value (slight noise based on the dimension of the TEs), 2 is double that, and so on.
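
For illustration (not part of the PR), with the gamma / sqrt(d) scaling that this review later settles on and d = 4096 as an example embedding width:

import math

d = 4096
for gamma in (0.0, 1.0, 2.0):
    print(gamma, gamma / math.sqrt(d))   # 0.0, ~0.0156, ~0.0313: 0 adds nothing, and the scale grows linearly with gamma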

text_encoder_dropout_probability=config.text_encoder.dropout_probability,
)

if config.cep_enabled:
Collaborator

I'd prefer this call in Model.encode_text; this is where similar functionality (such as caption dropout) is implemented.

Contributor Author

Would model.encode_text do it on-the-fly without caching?
One benefit of this method is that it doesn't need re-caching

Collaborator

@dxqb dxqb Feb 12, 2026

model.encode_text takes the cached output and returns it, but it can (and does) still modify that cached output before returning it. That doesn't mean you have to cache the perturbation; it can be applied to the cached value.
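
A minimal sketch of that flow (names are illustrative, not OneTrainer's actual API; perturb_condition refers to the sketch earlier in this thread): the embedding is read from the cache as usual and fresh noise is applied per step, so the cache never has to be rebuilt.

def encode_text(caption, cache, gamma=0.0):
    emb = cache[caption]                        # cached text-encoder output, loaded unchanged
    if gamma > 0:
        emb = perturb_condition(emb, gamma)     # fresh perturbation every call; cache untouched
    return emb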

components.switch(frame, 9, 1, self.ui_state, "dynamic_timestep_shifting")


# Conditional Embedding Perturbation (CEP)
Collaborator

I think this option fits better near "Caption Dropout Probability"

Contributor Author

I think it fits both: it is injected 'noise' applied to the TE conditioning. However, wouldn't placing it in the TE settings require per-model setting handling? I'm trying to avoid that.

Collaborator

ok


return noise

def _apply_conditional_embedding_perturbation(
Collaborator

This doesn't use self: make it a @staticmethod and drop the self parameter.


# gamma controls perturbation magnitude (Paper uses gamma=1.0 as default baseline)
# Calculate scaling factor: sqrt(gamma / d)
scale = math.sqrt(gamma / d)
Collaborator

@dxqb dxqb Feb 15, 2026

I think this should be
scale = gamma / math.sqrt(d)


Contributor Author

Yeah, you're right; I had (1/√d) in my mind when I wrote this
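
For reference, a quick numeric check of the two forms (illustrative only; d = 768 is just an example hidden size):

import math

d = 768
for gamma in (0.5, 1.0, 2.0):
    print(gamma, math.sqrt(gamma / d), gamma / math.sqrt(d))
# 0.5: ~0.0255 vs ~0.0180
# 1.0: ~0.0361 vs ~0.0361 (the two only agree at gamma = 1)
# 2.0: ~0.0510 vs ~0.0722 (sqrt(gamma / d) grows with sqrt(gamma) rather than linearly)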

)

if config.cep_enabled:
    text_encoder_output = self._apply_conditional_embedding_perturbation(
Collaborator

@dxqb dxqb Feb 15, 2026

Should CEP also be applied during validation? It currently is, since validation uses the same predict().
Theoretically I guess not, because you want validation to be deterministic and comparable across time, but the effect might be minor.

Contributor Author

Isn't this also the case with caption dropout (which is in model.encode_text)?

Collaborator

good point, but that's definitely not good. I've added it here: #957 (comment)

@dxqb
Collaborator

dxqb commented Feb 15, 2026

* The noise is sampled from a Uniform distribution and scaled by the embedding dimension, ensuring the corruption remains "slight" regardless of the model architecture (SD 1.5 vs SDXL vs Flux).

I think it might still need tuning per model, because the magnitude of the embeddings differs by text encoder; the paper/PR only corrects for dimension.
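
An illustrative comparison of that point (the dimensions and RMS values below are made-up placeholders, not measurements of any real text encoder): at the same gamma, the relative strength of the corruption depends on how large the embeddings typically are.

import math

gamma = 1.0
encoders = {"encoder_A": (768, 1.0), "encoder_B": (4096, 0.2)}   # (embedding dim, assumed typical RMS)
for name, (d, rms) in encoders.items():
    scale = gamma / math.sqrt(d)
    print(name, scale / rms)   # ~0.036 vs ~0.078: relative corruption differs even though gamma is identical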
