Merged
Commits
45 commits
feabe5a
refactor: removed unnecessary all-reduce ops and improved accuracy of…
le1nux Dec 6, 2025
a4775ad
chore: added documentation and renamed pytorch rms norm key
le1nux Dec 6, 2025
719e35e
feat: added timestamp and dtype to debugged model for input/output ac…
le1nux Dec 6, 2025
03354c1
feat: steppable component can now perform backward pass and optimizer…
le1nux Dec 6, 2025
9e661dc
feat: added fused and foreach options to Adam and AdamW optimizers
le1nux Dec 13, 2025
3518d00
refactor: profilers are now components
le1nux Dec 19, 2025
02e8fdd
feat: logger outputs now rank info
le1nux Dec 19, 2025
37b25d8
refactor: step information in profiling now part of the config instea…
le1nux Dec 19, 2025
52924ea
refactor: added new profiling setup to the profiling tutorial's config
le1nux Dec 19, 2025
3cfa305
refactor: experiments_root_path now passed in from outside
le1nux Dec 21, 2025
361ddc5
feat: profiling now available also in training loop
le1nux Dec 21, 2025
68dd9d2
feat: added memory profiling to kernel profiler
le1nux Dec 22, 2025
e4fe4b0
refactor: added experiments_root_path to warmstart API and improved e…
le1nux Dec 29, 2025
fbab937
refactor: refactored warmstart tutorial scripts
le1nux Dec 29, 2025
6b359c9
chore: Merge remote-tracking branch 'refs/remotes/origin/main'
le1nux Dec 30, 2025
93bd721
chore: Merge branch 'main' into monitoring_improvements
le1nux Dec 30, 2025
6bef8b0
fix: HSDP was not applied at all due to wrong condition check
le1nux Jan 4, 2026
a400a7c
refactor: allow data_parallel_replicate_degree to be -1 for auto-calc…
le1nux Jan 6, 2026
6d0e864
chore: improved device mesh logging
le1nux Jan 16, 2026
b893532
fix: in case of tp, we DP_SHARD > 1. Fixed the validation logic accor…
le1nux Jan 16, 2026
85388cd
fix: tp can now be used with dp_shard or dp_replicate
le1nux Jan 17, 2026
eb747fd
chore: improved tokenizer vocabulary warning
le1nux Jan 23, 2026
ba15c23
feat: added einsum transformer starter scripts
le1nux Jan 23, 2026
5692d6e
feat: added tokenizer to einsum example
le1nux Jan 23, 2026
840ece0
feat: added einsum transformer implementation
le1nux Jan 23, 2026
a6ed891
feat: added einsum transformer collate fn and fsdp wrapping
le1nux Jan 23, 2026
c1825b8
feat: added einsum transformer config
le1nux Jan 23, 2026
81392e3
chore: added example dataset
le1nux Jan 23, 2026
6b058ae
chore: Merge remote-tracking branch 'refs/remotes/origin/main'
le1nux Jan 23, 2026
9420c0b
chore: removed fixme since it was invalid
le1nux Jan 23, 2026
0e153ef
chore: Merge branch 'main' into monitoring_improvements
le1nux Jan 23, 2026
212bd12
fix: fixed merge conflict bug
le1nux Jan 25, 2026
8b9134f
chore: added maybe_model_list to compiled model
le1nux Feb 2, 2026
e07e7d7
refactor: optimized training loop by detaching compute graphs and onl…
le1nux Feb 2, 2026
c3f03e4
feat: added grouped sharding of fsdp units as blocks
le1nux Feb 2, 2026
da6c2d2
chore: added einsum transformer tutorial documentation
le1nux Feb 5, 2026
b2016ff
chore: fixed paths in einsum transformer tutorial
le1nux Feb 5, 2026
71a40e8
Merge pull request #428 from Modalities/einsum_transformer
le1nux Feb 5, 2026
ade0097
chore: updated test configs to latest component changes
le1nux Feb 10, 2026
e4e7d53
refactor: reverted back to allowing dp_shard only with TP
le1nux Feb 10, 2026
447d8c9
chore: updated optional config parameters
le1nux Feb 10, 2026
e4439cb
fix: all unit and e2e tests running through again
le1nux Feb 10, 2026
94019ca
refactor: all tutorials are running through again
le1nux Feb 11, 2026
76b07cb
chore: fixed checkpointing test
le1nux Feb 11, 2026
4234f55
chore: referencing now modalities preprint in README
le1nux Feb 11, 2026
5 changes: 4 additions & 1 deletion .gitignore
@@ -172,4 +172,7 @@ config_files/instruction_tuning
data/lorem_ipsum_instruct.jsonl
tutorials/scaling_up/logs*
tutorials/scaling_up/experiments_old/*
results/*
results/*
tutorials/einsum_transformer/experiments/*
tutorials/warmstart/experiments/*

25 changes: 15 additions & 10 deletions README.md
@@ -21,6 +21,8 @@

Modalities is a PyTorch-native framework for distributed training of Large Language Models (LLMs) and Foundation Models (FMs) at scale. Given the complexity of distributed training and rapid advancements in the field, we aim to provide a flexible and easy-to-use framework that enables researchers and practitioners to train and evaluate LLMs and FMs efficiently. Modalities is built on top of PyTorch and leverages the latest advancements in distributed training, such as Fully Sharded Data Parallel (FSDP), mixed precision training, Flash Attention and many more, to achieve state-of-the-art performance and throughput.

For a technical report on the architecture and latest benchmarks, check out our [Modalities pre-print](https://arxiv.org/abs/2602.08387).

We successfully scaled Modalities up to 2048 GPUs on two HPC centers, namely [Leonardo Booster](https://leonardo-supercomputer.cineca.eu/hpc-system/) and [MareNostrum 5](https://www.bsc.es/ca/marenostrum/marenostrum-5), featuring Nvidia A100 and H100 GPUs, respectively. The results of our scaling experiments can be found [here](#scaling-experiments).

Besides its scalability, Modalities makes it easy to integrate new components and features, such as custom attention mechanisms, loss functions, optimizers or models. We provide a series of tutorials to help you get started with training and evaluating models using Modalities. We achieve this level of extensibility through clear interfaces for each component type (e.g., model, optimizer), which a component must implement to be registered within Modalities at runtime.
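
To illustrate the idea, here is a minimal, self-contained sketch of such an interface-plus-registry pattern. The class names, keys and decorator below are hypothetical and do not reflect the actual Modalities registration API:

```python
from abc import ABC, abstractmethod
from typing import Callable, Type


class ComponentRegistry:
    """Toy registry mapping (component_key, variant_key) pairs to component classes."""

    def __init__(self) -> None:
        self._components: dict[tuple[str, str], Type] = {}

    def register(self, component_key: str, variant_key: str) -> Callable[[Type], Type]:
        # Decorator that stores a class under its (component_key, variant_key) pair.
        def _wrap(cls: Type) -> Type:
            self._components[(component_key, variant_key)] = cls
            return cls

        return _wrap

    def build(self, component_key: str, variant_key: str, **config):
        # Instantiate the registered class from keyword arguments (e.g., a parsed YAML config block).
        return self._components[(component_key, variant_key)](**config)


class LossIF(ABC):
    # Interface that every loss component would have to implement.
    @abstractmethod
    def __call__(self, predictions, targets):
        ...


registry = ComponentRegistry()


@registry.register(component_key="loss", variant_key="my_custom_loss")
class MyCustomLoss(LossIF):
    def __init__(self, prediction_key: str, target_key: str) -> None:
        self.prediction_key = prediction_key
        self.target_key = target_key

    def __call__(self, predictions, targets):
        return sum((p - t) ** 2 for p, t in zip(predictions, targets))


loss = registry.build("loss", "my_custom_loss", prediction_key="logits", target_key="target_ids")
print(loss([1.0, 2.0], [1.5, 1.0]))  # 1.25
```

In Modalities itself, the concrete interfaces and registration hooks are defined per component type; the sketch only conveys the general shape of the mechanism.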
@@ -277,7 +279,7 @@ In the following, we list the most important features of Modalities.
| Flash Attention | supported | A highly optimized attention mechanism that significantly reduces the computational burden and memory footprint of attention calculations, enabling faster training and inference on large models. |
| Tensor Parallelism | supported | Implements vertical model sharding as an efficient model parallelism technique |
| Sequence Parallelism | supported | Variant of Tensor Parallelism that shards on the sequence dimension |
| Pipeline Parallelism | supported | Support for GPipe. Alternative schedules such as (interleaved) 1F1B are being implemented. |
| Pipeline Parallelism | supported | Beta-level support for schedules such as GPipe, (interleaved) 1F1B and DualPipe. |
| FSDP 2 | supported | Improved version of the original FSDP |
| Torch Compile | supported | Speeds up training by JIT-compiling tensor operations into optimized kernels |
| Deferred Initialisation | supported | Instead of instantiating the model in CPU RAM, the modules are instantiated as fake tensors and operations are recorded. Once sharded (e.g., via FSDP), each rank only instantiates the local tensors by replaying the tensor operations. |
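
As a rough illustration of the deferred-initialisation idea in plain PyTorch (a sketch, not the Modalities implementation): meta tensors let you build the module graph without allocating parameter memory and materialise the local tensors only once the sharding layout is known.

```python
import torch
import torch.nn as nn

# Sketch only: build the module on the "meta" device, so no parameter memory is allocated yet.
with torch.device("meta"):
    model = nn.Linear(4096, 4096, bias=False)

# ... decide on the sharding layout / wrap with FSDP here ...

# Materialise (uninitialised) local storage and then initialise the local parameters.
model = model.to_empty(device="cpu")
with torch.no_grad():
    nn.init.normal_(model.weight, std=0.02)

print(model.weight.shape)  # torch.Size([4096, 4096])
```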
@@ -390,19 +392,22 @@ Further scaling results can be found at [MareNostrum5 Scaling Experiments](https
Modalities welcomes your contributions! Please check out our
[contributing](CONTRIBUTING.md) guidelines for details on formatting, testing,
etc.<br/><br/><br/>
Thanks so much to all of our amazing contributors!
Thanks so much to all of our contributors and collaborators!

<a href="https://github.com/modalities/modalities/graphs/contributors">
<img src="https://contrib.rocks/image?repo=modalities/modalities&r=" width="800px"/>
</a>

## Citation

@misc{modalities,
title={Modalities: A PyTorch-native framework for distributed and reproducible foundation model training.},
author={Lübbering, Max and Ali, Mehdi and Stollenwerk, Felix and Fromm, Michael and Weber, Alexander Arno and Rutmann, Richard},
year={2024},
howpublished={\url{https://github.com/Modalities/modalities}},
url="https://github.com/Modalities/modalities",
}

```
@misc{luebbering2026modalitiespytorchnativeframeworklargescale,
title={Modalities, a PyTorch-native Framework For Large-scale LLM Training and Research},
author={Max Lübbering and Timm Ruland and Richard Rutmann and Felix Stollenwerk and David Fitzek and Michael Fromm and Alexander Weber and Rafet Sifa and Nicolas Flores-Herr and Joachim Köhler and Mehdi Ali},
year={2026},
eprint={2602.08387},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2602.08387},
}
```
9 changes: 6 additions & 3 deletions config_files/training/config_lorem_ipsum_long_fsdp2.yaml
@@ -71,7 +71,7 @@ train_dataset:
config:
raw_data_path: ${settings.paths.train_dataset_path}
sequence_length: ${settings.step_profile.sequence_length}
sample_key: ${settings.referencing_keys.sample_key}
sample_key: ${settings.referencing_keys.sample_key}

train_dataloader:
component_key: data_loader
@@ -195,7 +195,7 @@ app_state:
component_key: app_state
variant_key: raw
config:
model:
model:
instance_key: initialized_model
pass_type: BY_REFERENCE
optimizer:
@@ -305,7 +305,7 @@ optimizer:
eps: 1e-8
weight_decay: 1e-1
weight_decay_groups_excluded: [embedding, layernorm]
wrapped_model:
wrapped_model:
instance_key: initialized_model
pass_type: BY_REFERENCE

Expand All @@ -318,6 +318,9 @@ gradient_clipper:
pass_type: BY_REFERENCE
norm_type: P2_NORM
max_norm: 1.0
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE

progress_subscriber:
component_key: progress_subscriber
145 changes: 70 additions & 75 deletions src/modalities/__main__.py
@@ -54,10 +54,10 @@ def main() -> None:
help="Path to the YAML training config file.",
)
@click.option(
"--test_comm",
is_flag=True,
default=False,
help="If set, run a communication test before training.",
"--experiments_root_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the root directory where experiment folders will be created.",
)
@click.option(
"--experiment_id",
Expand All @@ -71,61 +71,51 @@ def main() -> None:
default=None,
help="Optional path to a folder where error logs will be written.",
)
@click.option(
"--test_comm",
is_flag=True,
default=False,
help="If set, run a communication test before training.",
)
def CMD_entry_point_run_modalities(
config_file_path: Path,
test_comm: bool = False,
experiments_root_path: Path,
experiment_id: Optional[str] = None,
error_log_folder: Optional[Path] = None,
test_comm: bool = False,
):
"""Entrypoint to run the model training.

Args:
config_file_path (Path): Path to the YAML training config file.
test_comm (bool): If set, run a communication test before training.
experiments_root_path (Path): Path to the root directory where experiment folders will be created.
experiment_id (Optional[str]): Optional experiment ID to use for this run.
If not provided it will be generated. Default is None.
error_log_folder (Optional[Path]): Optional path to a folder where error logs will be written.
test_comm (bool): If set, run a communication test before training.
"""

def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
# Format an exception into a structured JSON string with error message, type, and stack trace.
error = {
"error": str(e),
"type": type(e).__name__,
"stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
}

return json.dumps({"environment": environment, "error": error}, indent=2)

try:
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
if test_comm:
print_rank_0("Running communication test...")
run_communication_test()
print_rank_0("Communication test succeeded.")

main_obj = Main(config_file_path, experiment_id=experiment_id)
main_obj = Main(config_file_path, experiments_root_path=experiments_root_path, experiment_id=experiment_id)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
except Exception as e:
if error_log_folder is not None:
environment = {
"rank": int(os.environ["RANK"] if "RANK" in os.environ else -1),
"local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1),
"world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1),
"hostname": socket.gethostname(),
}
error_log_folder = (
error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log"
)
error_log_folder.parent.mkdir(parents=True, exist_ok=True)
with open(error_log_folder, "w", encoding="utf-8") as f:
f.write(_format_exception_as_json(e, environment))

raise RuntimeError(f"An error occurred while running the training: {e}. ") from e
_exception_handling(e, error_log_folder)


@main.command(name="warmstart")
@click.option(
"--experiments_root_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the root directory where experiment folders will be created.",
)
@click.option(
"--config_file_path",
type=click_pathlib.Path(exists=True),
Expand All @@ -138,10 +128,22 @@ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
required=True,
help="Path to the file containing the model and optimizer checkpoint paths from the last successful checkpoint.",
)
def CMD_entry_point_warmstart_modalities(config_file_path: Path, last_checkpoint_info_file_path: Path):
@click.option(
"--error_log_folder",
type=click_pathlib.Path(),
default=None,
help="Optional path to a folder where error logs will be written.",
)
def CMD_entry_point_warmstart_modalities(
experiments_root_path: Path,
config_file_path: Path,
last_checkpoint_info_file_path: Path,
error_log_folder: Optional[Path] = None,
):
"""Entrypoint to run the model warmstart.

Args:
experiments_root_path (Path): Path to the root directory where experiment folders will be created.
config_file_path (Path): Path to the YAML warmstart config file.
last_checkpoint_info_file_path (Path): Path to the file containing the model and
optimizer checkpoint paths from the last successful checkpoint.
Expand All @@ -159,10 +161,15 @@ def get_last_checkpoint_resolver_fun(var_name: str, last_checkpoint_info_file_pa
get_last_checkpoint_resolver_fun, last_checkpoint_info_file_path=last_checkpoint_info_file_path
)
}
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
main_obj = Main(config_file_path, additional_resolver_funs=resolver_funs)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
try:
with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
main_obj = Main(
config_file_path, experiments_root_path=experiments_root_path, additional_resolver_funs=resolver_funs
)
components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel)
main_obj.run(components)
except Exception as e:
_exception_handling(e, error_log_folder)


@main.command(name="generate_text")
@@ -705,54 +712,42 @@ def profile():
required=True,
help="Path to the experiment output directory.",
)
@click.option(
"--num_wait_steps",
type=int,
default=1,
show_default=True,
help="Number of wait steps to skip in profiling.",
)
@click.option(
"--num_warmup_steps",
type=int,
default=1,
show_default=True,
help="Number of warmup steps to skip in profiling. Already recording but dropping the data.",
)
@click.option(
"--num_measurement_steps",
type=int,
default=3,
show_default=True,
help="Number of steps to measure during profiling.",
)
@click.option(
"--profiled_ranks",
type=str,
default="0",
help="Comma-separated list of profiled ranks (must not have spaces), e.g. --profiled_ranks '2,4,8'",
)
def CMD_entry_point_run_train_step_profiler(
config_file_path: Path,
experiment_root_path: Path,
num_wait_steps: int,
num_warmup_steps: int,
num_measurement_steps: int,
profiled_ranks: str,
):
"""Run train step profiler and write result to JSON if RANK=0."""
profiled_ranks_list = [int(i) for i in profiled_ranks.split(",")] if profiled_ranks != "" else [0]
logger.info(f"Running distributed profiling on ranks {profiled_ranks_list}")

ModalitiesProfilerStarter.run_distributed(
config_file_path=config_file_path,
num_measurement_steps=num_measurement_steps,
num_wait_steps=num_wait_steps,
num_warmup_steps=num_warmup_steps,
experiment_root_path=experiment_root_path,
profiled_ranks=profiled_ranks_list,
)


def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str:
# Format an exception into a structured JSON string with error message, type, and stack trace.
error = {
"error": str(e),
"type": type(e).__name__,
"stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
}
return json.dumps({"environment": environment, "error": error}, indent=2)


def _exception_handling(e: Exception, error_log_folder: Path | None):
if error_log_folder is not None:
environment = {
"rank": int(os.environ["RANK"] if "RANK" in os.environ else -1),
"local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1),
"world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1),
"hostname": socket.gethostname(),
}
error_log_folder = error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log"
error_log_folder.parent.mkdir(parents=True, exist_ok=True)
with open(error_log_folder, "w", encoding="utf-8") as f:
f.write(_format_exception_as_json(e, environment))

raise RuntimeError(f"An error occurred while running the training: {e}. ") from e


if __name__ == "__main__":
main()
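
For reference, a small standalone sketch of the per-rank error payload produced by the helpers above; the environment values and the exception are placeholders:

```python
import json
import socket
import traceback


def sketch_error_payload(e: Exception) -> str:
    # Mirrors the structure written by _format_exception_as_json / _exception_handling; values are illustrative.
    environment = {"rank": 0, "local_rank": 0, "world_size": 8, "hostname": socket.gethostname()}
    error = {
        "error": str(e),
        "type": type(e).__name__,
        "stacktrace": traceback.format_exception(type(e), e, e.__traceback__),
    }
    return json.dumps({"environment": environment, "error": error}, indent=2)


try:
    raise ValueError("simulated training failure")
except Exception as exc:
    print(sketch_error_payload(exc))
```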
4 changes: 2 additions & 2 deletions src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py
@@ -80,7 +80,7 @@ def _get_checkpointing_path(
num_target_tokens=str(num_target_tokens),
)

full_path = Path(self.checkpoint_path, experiment_id, entity_file_name)
full_path = Path(self.checkpoint_path, entity_file_name)
return full_path

@torch.no_grad()
@@ -224,7 +224,7 @@ def _get_checkpointing_folder_path(
num_target_steps=str(num_target_steps),
num_target_tokens=str(num_target_tokens),
)
full_path = Path(self.checkpoint_path, experiment_id, entity_file_name)
full_path = Path(self.checkpoint_path, entity_file_name)
return full_path

@torch.no_grad()