Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 12 additions & 67 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,9 @@ jobs:
- "3.9"
- "3.10"
steps:
- name: Free Disk Space
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker image prune --all --force
df -h
- uses: actions/checkout@v3
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
- name: Set up Python
uses: actions/setup-python@v4
with:
Expand All @@ -80,11 +69,11 @@ jobs:
- name: Install dependencies
run: |
poetry check --lock
poetry install --sync --with dev
poetry install --with dev
- name: Authenticate HuggingFace CLI
if: env.HF_TOKEN != ''
run: |
pip install huggingface_hub
pip install huggingface_hub==0.33.0
huggingface-cli login --token "$HF_TOKEN"
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand All @@ -101,14 +90,6 @@ jobs:
name: Code Checks
runs-on: ubuntu-latest
steps:
- name: Free Disk Space
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker image prune --all --force
df -h
- uses: actions/checkout@v3
- name: Install Poetry
uses: snok/install-poetry@v1
Expand All @@ -117,9 +98,6 @@ jobs:
virtualenvs-in-project: true
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "poetry"
- name: Cache Models used with Tests
uses: actions/cache@v3
with:
Expand All @@ -131,7 +109,7 @@ jobs:
- name: Install dependencies
run: |
poetry check --lock
poetry install --sync --with dev
poetry install --with dev
- name: Check format
run: make check-format
- name: Docstring test
Expand All @@ -141,7 +119,7 @@ jobs:
- name: Authenticate HuggingFace CLI
if: env.HF_TOKEN != ''
run: |
pip install huggingface_hub
pip install huggingface_hub==0.33.0
huggingface-cli login --token "$HF_TOKEN"
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand All @@ -161,7 +139,6 @@ jobs:
name: Notebook Checks
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
notebook:
# - "Activation_Patching_in_TL_Demo"
Expand All @@ -181,60 +158,28 @@ jobs:
- "Patchscopes_Generation_Demo"
# - "T5"
steps:
- name: Free Disk Space
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo docker image prune --all --force
df -h
- uses: actions/checkout@v3
- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
# NOTE: Poetry cache disabled - causes huggingface-hub version conflicts
cache: "poetry"
- name: Re-use HuggingFace models cache
uses: actions/cache/restore@v3
with:
path: ~/.cache/huggingface/hub
key: ${{ runner.os }}-huggingface-models
- name: Install dependencies
run: |
poetry check --lock
poetry install --sync --with dev,jupyter
- name: Verify huggingface-hub version after install
run: |
VERSION=$(poetry run python -c "import huggingface_hub; print(huggingface_hub.__version__)")
echo "huggingface-hub version after poetry install: $VERSION"
poetry install --with dev,jupyter
- name: Install pandoc
uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: pandoc
version: 1.0
- name: Register Poetry venv as Jupyter kernel
run: |
poetry run python -m ipykernel install --user --name=poetry-env
- name: Ensure correct huggingface-hub version
run: |
# Force install the exact version from poetry.lock (0.33.0)
# transformers 4.46.3 requires huggingface-hub>=0.23.2,<1.0
poetry run pip install --force-reinstall --no-deps huggingface-hub==0.33.0
- name: Verify huggingface-hub version
run: |
VERSION=$(poetry run python -c "import huggingface_hub; print(huggingface_hub.__version__)")
echo "huggingface-hub version: $VERSION"
if [[ "$VERSION" == 1.* ]]; then
echo "ERROR: huggingface-hub version 1.x detected, but <1.0 is required!"
exit 1
fi
- name: Final version check before pytest
run: |
echo "=== Environment check ==="
poetry run which python
poetry run pip show huggingface-hub | grep Version
poetry run python -c "import transformers; print('transformers OK')"
- name: Check Notebook Output Consistency
# Note: currently only checks notebooks we have specifically set up for this
run: poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/${{ matrix.notebook }}.ipynb
Expand Down Expand Up @@ -272,7 +217,7 @@ jobs:
- name: Authenticate HuggingFace CLI
if: env.HF_TOKEN != ''
run: |
pip install huggingface_hub
pip install huggingface_hub==0.33.0
huggingface-cli login --token "$HF_TOKEN"
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion tests/acceptance/test_activation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_logit_attrs_works_for_all_input_shapes():
tokens=answer_tokens[:, 0],
incorrect_tokens=answer_tokens[:, 1],
)
assert torch.isclose(ref_logit_diffs, logit_diffs).all()
assert torch.isclose(ref_logit_diffs, logit_diffs, atol=1.1e-7).all()

# Single token
batch = -1
Expand Down
2 changes: 1 addition & 1 deletion tests/acceptance/test_hooked_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_cross_attention(our_model, huggingface_model, hello_world_tokens, decod
huggingface_cross_attn_out = huggingface_cross_attn(
decoder_hidden, key_value_states=encoder_hidden, cache_position=encoder_hidden
)[0]
assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)
assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-3, atol=1e-4)


def test_cross_attention_layer(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
Expand Down
30 changes: 19 additions & 11 deletions transformer_lens/ActivationCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,26 +524,34 @@ def logit_attrs(
if not isinstance(batch_slice, Slice):
batch_slice = Slice(batch_slice)

if isinstance(tokens, str):
tokens = torch.as_tensor(self.model.to_single_token(tokens))
# Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
tokens_for_shape_check = tokens

elif isinstance(tokens, int):
tokens = torch.as_tensor(tokens)
if isinstance(tokens_for_shape_check, str):
tokens_for_shape_check = torch.as_tensor(
self.model.to_single_token(tokens_for_shape_check)
)
elif isinstance(tokens_for_shape_check, int):
tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check)

logit_directions = self.model.tokens_to_residual_directions(tokens)

if incorrect_tokens is not None:
if isinstance(incorrect_tokens, str):
incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens))
# Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
incorrect_tokens_for_shape_check = incorrect_tokens

elif isinstance(incorrect_tokens, int):
incorrect_tokens = torch.as_tensor(incorrect_tokens)
if isinstance(incorrect_tokens_for_shape_check, str):
incorrect_tokens_for_shape_check = torch.as_tensor(
self.model.to_single_token(incorrect_tokens_for_shape_check)
)
elif isinstance(incorrect_tokens_for_shape_check, int):
incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check)

if tokens.shape != incorrect_tokens.shape:
if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape:
raise ValueError(
f"tokens and incorrect_tokens must have the same shape! \
(tokens.shape={tokens.shape}, \
incorrect_tokens.shape={incorrect_tokens.shape})"
(tokens.shape={tokens_for_shape_check.shape}, \
incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})"
)

# If incorrect_tokens was provided, take the logit difference
Expand Down
7 changes: 4 additions & 3 deletions transformer_lens/utilities/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch import nn

import transformer_lens
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

AvailableDeviceMemory = list[tuple[int, int]]
"""
Expand Down Expand Up @@ -83,11 +84,11 @@ def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.d
return torch.device("cuda", sorted_devices[0][0])


def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -> torch.device:
def get_best_available_device(cfg: HookedTransformerConfig) -> torch.device:
"""Gets the best available device to be used based on the passed in arguments

Args:
device (Union[torch.device, str]): Either the existing torch device or the string identifier
cfg (HookedTransformerConfig): Model and device configuration.

Returns:
torch.device: The best available device
Expand All @@ -103,7 +104,7 @@ def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -

def get_device_for_block_index(
index: int,
cfg: "transformer_lens.HookedTransformerConfig",
cfg: HookedTransformerConfig,
device: Optional[Union[torch.device, str]] = None,
):
"""
Expand Down
Loading