18 changes: 17 additions & 1 deletion pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import gc
import json
import logging
@@ -814,6 +816,12 @@ def control_weight_fn(_: int) -> float:

self.control_str = last_control

# Clean up memory after test_all() which creates temporary PromptManagers
del model_tests
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

return self.control_str, loss, steps

def test(
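
The hunk above pairs an explicit del with gc.collect() and torch.cuda.empty_cache() so the tensors held by the temporary PromptManagers are actually released once the test results have been consumed. A minimal sketch of that pattern, with a hypothetical run_model_tests() standing in for self.test_all():

import gc
import torch

def run_model_tests():
    # Hypothetical stand-in for self.test_all(): allocates temporary tensors
    # on the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return [torch.randn(256, 256, device=device) for _ in range(4)]

model_tests = run_model_tests()
passed = sum(int(t.mean().item() > 0) for t in model_tests)

# Drop the only reference, reclaim the Python objects, then return the
# cached CUDA blocks to the driver so other processes can reuse them.
del model_tests
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Note that empty_cache() does not free tensors that are still referenced; it only releases blocks the caching allocator holds in reserve, which is why the del and gc.collect() come first.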
@@ -1644,7 +1652,12 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
ob, fn, args, kwargs = task
if fn == "grad":
with torch.enable_grad(): # type: ignore[no-untyped-call, unused-ignore]
results.put(ob.grad(*args, **kwargs))
result = ob.grad(*args, **kwargs)
results.put(result)
del result
# Clear CUDA cache after gradient computation to prevent memory accumulation
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
with torch.no_grad():
if fn == "logits":
@@ -1657,6 +1670,9 @@ def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]
results.put(ob.test_loss(*args, **kwargs))
else:
results.put(fn(*args, **kwargs))
# Clean up the task object to free memory
del ob
gc.collect()
tasks.task_done()

def start(self) -> "ModelWorker":
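
The two hunks above change the worker loop so each result is handed to the results queue and the worker's own references (result and the task object ob) are dropped before the next task is pulled. A simplified, single-process sketch of that loop, assuming a plain queue.Queue instead of the multiprocessing JoinableQueues used by ModelWorker:

import gc
import queue
import torch

def worker_loop(tasks: queue.Queue, results: queue.Queue) -> None:
    # Simplified stand-in for ModelWorker.run(); the real code dispatches on a
    # method name ("grad", "logits", "test", ...) rather than on a callable.
    while True:
        task = tasks.get()
        if task is None:              # sentinel: shut the worker down
            tasks.task_done()
            break
        fn, args = task
        result = fn(*args)
        results.put(result)
        # Drop the worker's references so only the queue keeps the result
        # alive, then release cached GPU memory between tasks.
        del result, task
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        tasks.task_done()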
20 changes: 19 additions & 1 deletion pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py
@@ -68,7 +68,17 @@ def token_gradients(

loss.backward()

return one_hot.grad.clone()
# Clone and detach the gradient to break the computation graph
grad = one_hot.grad.clone().detach()

# Explicitly clear references to free memory
del one_hot, input_embeds, embeds, full_embeds, logits, targets, loss

# Clear CUDA cache to release GPU memory
if torch.cuda.is_available():
torch.cuda.empty_cache()

return grad


class GCGAttackPrompt(AttackPrompt):
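
token_gradients() now returns a detached copy of the gradient and explicitly drops the intermediate tensors, so the autograd graph built through the one-hot embedding does not survive the call. A standalone sketch of the same pattern; the embedding matrix, token ids, and loss function here are toy placeholders rather than PyRIT's actual inputs:

import torch

def one_hot_token_grad(embedding_weight, token_ids, loss_fn):
    # Build a differentiable one-hot lookup, as token_gradients() does.
    one_hot = torch.zeros(
        token_ids.shape[0],
        embedding_weight.shape[0],
        device=embedding_weight.device,
        dtype=embedding_weight.dtype,
    )
    one_hot.scatter_(1, token_ids.unsqueeze(1), 1.0)
    one_hot.requires_grad_()

    embeds = one_hot @ embedding_weight
    loss = loss_fn(embeds)
    loss.backward()

    # Detach a copy so the returned tensor does not keep the graph (and every
    # intermediate activation) alive after the function returns.
    grad = one_hot.grad.clone().detach()

    del one_hot, embeds, loss
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return grad

# Toy usage: a random "embedding table" and a throwaway scalar loss.
weight = torch.randn(100, 16)
ids = torch.randint(0, 100, (8,))
grad = one_hot_token_grad(weight, ids, lambda e: e.pow(2).mean())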
@@ -144,9 +154,11 @@ def step(
j - 1, control_cand, filter_cand=filter_cand, curr_control=self.control_str
)
)
del grad # Explicitly delete old grad before reassignment
grad = new_grad
else:
grad += new_grad
del new_grad # Clean up new_grad after use

with torch.no_grad():
control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
@@ -155,6 +167,8 @@
)
del grad, control_cand
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Search
loss = torch.zeros(len(control_cands) * batch_size).to(main_device)
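
In step(), the stale aggregated gradient is deleted before the name is rebound and each per-prompt new_grad is dropped as soon as it has been folded in, so at most one extra gradient-sized tensor is alive at a time. A hedged sketch of that accumulation pattern, with grad_fns as hypothetical callables standing in for self.prompts[j].grad(...):

import gc
import torch

def aggregate_prompt_grads(grad_fns):
    grad = None
    for fetch in grad_fns:
        new_grad = fetch()
        new_grad = new_grad / new_grad.norm(dim=-1, keepdim=True)
        if grad is None:
            grad = new_grad
        elif grad.shape != new_grad.shape:
            del grad              # drop the stale tensor before rebinding the name
            grad = new_grad
        else:
            grad += new_grad
            del new_grad          # the per-prompt temporary is no longer needed
    return grad

# Toy usage: three callables producing same-shaped random "gradients".
fns = [lambda: torch.randn(20, 50) for _ in range(3)]
grad = aggregate_prompt_grads(fns)
# ... sample control candidates from grad under torch.no_grad(), then:
del grad
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()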
@@ -192,6 +206,10 @@ def step(
f"loss={loss[j * batch_size : (j + 1) * batch_size].min().item() / (i + 1):.4f}" # type: ignore[operator]
)

# Periodically clear CUDA cache during search to prevent memory buildup
if torch.cuda.is_available():
torch.cuda.empty_cache()

min_idx = loss.argmin()
model_idx = min_idx // batch_size
batch_idx = min_idx % batch_size
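
The last hunk flushes the CUDA cache at the end of each candidate-scoring iteration, so the allocator's reserved pool does not keep growing while the search loop walks over models and candidate batches. A minimal sketch of that pattern, with batches and score_fn as hypothetical placeholders:

import torch

def evaluate_candidates(batches, score_fn):
    losses = []
    for batch in batches:
        with torch.no_grad():
            losses.append(score_fn(batch).cpu())
        # Release cached blocks between batches; tensors still referenced
        # (the collected losses) are unaffected.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return torch.cat(losses)

Since empty_cache() only returns unreferenced cached blocks to the driver and later allocations must go back through the CUDA allocator, calling it every iteration trades some throughput for a lower reserved-memory footprint.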