From 0c40a493f2a0e62f754fcc566e3f15cb04de1330 Mon Sep 17 00:00:00 2001
From: Aditya Pratap Singh
Date: Sat, 24 Jan 2026 13:39:33 +0530
Subject: [PATCH 1/2] gcg: Detach gradients post-backward() & add tensor cleanup in token_gradients()

- Add .detach() after gradient extraction to break lingering computation graphs
- Explicit del for loop-accumulated tensors (grads, losses)
- torch.cuda.empty_cache() post-iteration to defragment CUDA allocator

Prevents OOM at 1000+ steps by ensuring ~no memory growth per iter
(verified via nvidia-smi / torch.cuda.memory_summary())

Fixes #961

Co-Authored-By: Claude Opus 4.5
---
 .../gcg/attack/gcg/gcg_attack.py | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
index 3f24d89f2..5ed68add6 100644
--- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
+++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
@@ -68,7 +68,17 @@ def token_gradients(
 
     loss.backward()
 
-    return one_hot.grad.clone()
+    # Clone and detach the gradient to break the computation graph
+    grad = one_hot.grad.clone().detach()
+
+    # Explicitly clear references to free memory
+    del one_hot, input_embeds, embeds, full_embeds, logits, targets, loss
+
+    # Clear CUDA cache to release GPU memory
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return grad
 
 
 class GCGAttackPrompt(AttackPrompt):
@@ -144,9 +154,11 @@ def step(
                             j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str
                         )
                     )
+                del grad  # Explicitly delete old grad before reassignment
                 grad = new_grad
             else:
                 grad += new_grad
+                del new_grad  # Clean up new_grad after use
 
         with torch.no_grad():
             control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
@@ -155,6 +167,8 @@
             )
         del grad, control_cand
         gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
         # Search
         loss = torch.zeros(len(control_cands) * batch_size).to(main_device)
@@ -192,6 +206,10 @@
                             f"loss={loss[j * batch_size : (j + 1) * batch_size].min().item() / (i + 1):.4f}"  # type: ignore[operator]
                         )
 
+                    # Periodically clear CUDA cache during search to prevent memory buildup
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+
             min_idx = loss.argmin()
             model_idx = min_idx // batch_size
             batch_idx = min_idx % batch_size

From 5c03cde629d1395251f9fae2909aee8aef4cb2c0 Mon Sep 17 00:00:00 2001
From: Aditya Pratap Singh
Date: Sat, 24 Jan 2026 13:39:42 +0530
Subject: [PATCH 2/2] gcg attack_manager: Add gc.collect() post-worker & Python 3.13 annotations

- gc.collect() after task completion to force Python GC on leaked refs
- from __future__ import annotations for forward-ref compatibility (3.13+)
- torch.cuda.empty_cache() after gradient ops in ModelWorker
- Memory cleanup after test_all() in main run loop

Complements per-iter cleanup; total peak mem now stable across 1000 steps

Fixes #961

Co-Authored-By: Claude Opus 4.5
---
 .../gcg/attack/base/attack_manager.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
index b648c0dfb..290eb739a 100644
--- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
+++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+from __future__ import annotations
+
 import gc
 import json
 import logging
@@ -814,6 +816,12 @@ def control_weight_fn(_: int) -> float:
 
         self.control_str = last_control
 
+        # Clean up memory after test_all() which creates temporary PromptManagers
+        del model_tests
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return self.control_str, loss, steps
 
     def test(
@@ -1644,7 +1652,12 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
             ob, fn, args, kwargs = task
             if fn == "grad":
                 with torch.enable_grad():  # type: ignore[no-untyped-call, unused-ignore]
-                    results.put(ob.grad(*args, **kwargs))
+                    result = ob.grad(*args, **kwargs)
+                    results.put(result)
+                    del result
+                    # Clear CUDA cache after gradient computation to prevent memory accumulation
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
             else:
                 with torch.no_grad():
                     if fn == "logits":
@@ -1657,6 +1670,9 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
                         results.put(ob.test_loss(*args, **kwargs))
                     else:
                         results.put(fn(*args, **kwargs))
+            # Clean up the task object to free memory
+            del ob
+            gc.collect()
             tasks.task_done()
 
     def start(self) -> "ModelWorker":
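
Both patches lean on the same cleanup pattern: detach the one gradient you keep, drop every
intermediate reference, then let the CUDA caching allocator release its blocks. The sketch below
reproduces that pattern in isolation so the flat-memory claim can be checked without running the
full attack; the nn.Embedding/nn.Linear pair and the sum() loss are illustrative placeholders, not
PyRIT's model or token_gradients().

import gc

import torch
import torch.nn as nn


def one_hot_token_grad(embedding: nn.Embedding, head: nn.Linear, token_ids: torch.Tensor) -> torch.Tensor:
    """Gradient of a toy loss w.r.t. a one-hot token matrix, with explicit cleanup."""
    vocab_size = embedding.num_embeddings
    one_hot = torch.zeros(token_ids.shape[0], vocab_size, device=token_ids.device, dtype=embedding.weight.dtype)
    one_hot.scatter_(1, token_ids.unsqueeze(1), 1.0)
    one_hot.requires_grad_()

    embeds = one_hot @ embedding.weight  # differentiable embedding lookup
    loss = head(embeds).sum()  # stand-in objective
    loss.backward()

    grad = one_hot.grad.clone().detach()  # detach so no computation graph outlives this call

    # Drop intermediates so the caching allocator can reuse their blocks, then release the cache.
    del one_hot, embeds, loss
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return grad


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = nn.Embedding(1000, 64).to(device)
    head = nn.Linear(64, 1000).to(device)
    for p in list(emb.parameters()) + list(head.parameters()):
        p.requires_grad_(False)  # only the one-hot matrix needs a gradient
    ids = torch.randint(0, 1000, (20,), device=device)

    for step in range(5):
        g = one_hot_token_grad(emb, head, ids)
        del g
        gc.collect()
        if torch.cuda.is_available():
            # Allocated bytes should stay flat from one iteration to the next.
            print(step, torch.cuda.memory_allocated())

Note that torch.cuda.empty_cache() only returns cached, unoccupied blocks to the driver; it does not
free live tensors, and calling it every iteration trades a small allocator cost for a lower footprint
as reported by nvidia-smi.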