> pytorch-fsdp2
Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.
curl "https://skillshub.wtf/Orchestra-Research/AI-Research-SKILLs/pytorch-fsdp2?format=md"Skill: Use PyTorch FSDP2 (fully_shard) correctly in a training script
This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.
FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in place to modules. See: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.
When to use this skill
Use FSDP2 when:
- Your model doesn’t fit on one GPU (parameters + gradients + optimizer state).
- You want eager-mode, DTensor-based per-parameter sharding (more inspectable, with simpler sharded state dicts) rather than FSDP1's flat-parameter sharding.
- You may later compose DP with Tensor Parallel using DeviceMesh.
Avoid (or be careful) if:
- You need checkpoints that are strictly backward compatible across PyTorch versions (the DCP docs warn that this is not guaranteed).
- You’re forced onto older PyTorch versions without the FSDP2 stack.
Alternatives (when FSDP2 is not the best fit)
- DistributedDataParallel (DDP): the standard data-parallel wrapper; use it when the full model, gradients, and optimizer state fit on each GPU and you only need gradient synchronization.
- FullyShardedDataParallel (FSDP1): Use the original FSDP wrapper for parameter sharding across data-parallel workers.
Reference: references/pytorch_ddp_notes.md, references/pytorch_fsdp1_api.md.
Contract the agent must follow
- Launch with `torchrun` and set the CUDA device per process (usually via `LOCAL_RANK`).
- Apply `fully_shard()` bottom-up, i.e., shard submodules (e.g., Transformer blocks) before the root module.
- Call `model(input)`, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
- Create the optimizer after sharding and make sure it is built on the DTensor parameters (post-`fully_shard`).
- Checkpoint using Distributed Checkpoint (DCP) or the distributed state-dict helpers, not naive `torch.save(model.state_dict())`, unless you deliberately gather to full tensors.
(Each of these rules is directly described in the official API docs/tutorial; see references.)
Step-by-step procedure
0) Version & environment sanity
- Prefer a recent stable PyTorch release in which FSDP2 (`fully_shard`) and DCP are documented as current APIs.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are visible to each process.
Reference: references/pytorch_fsdp2_tutorial.md (launch commands and setup), references/pytorch_fully_shard_api.md (user contract).
1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s); see the sketch below.
Reference: references/pytorch_device_mesh_tutorial.md (why DeviceMesh exists & how it manages process groups).
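A minimal sketch of this initialization, assuming a `torchrun` launch (the `dp` mesh-dim name is illustrative):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def init_distributed():
    # torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # Optional: a 1-D DeviceMesh over all data-parallel ranks.
    return init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))
```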
2) Build model on meta device (recommended for very large models)
For big models, initialize on meta, apply sharding, then materialize weights on GPU:
with torch.device("meta"): model = ...- apply
fully_shard(...)on submodules, thenfully_shard(model) model.to_empty(device="cuda")model.reset_parameters()(or your init routine)
Reference: references/pytorch_fsdp2_tutorial.md (migration guide shows this flow explicitly).
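A sketch of the meta-device flow, assuming `MyModel`, `config`, and `TransformerBlock` stand in for your own classes:

```python
import torch
from torch.distributed.fsdp import fully_shard

with torch.device("meta"):
    model = MyModel(config)  # parameters live on the meta device (no real memory)

# Shard bottom-up (detailed in step 3), then materialize sharded storage on GPU.
for m in model.modules():
    if isinstance(m, TransformerBlock):
        fully_shard(m)
fully_shard(model)

model.to_empty(device="cuda")  # allocate real (sharded) storage
model.reset_parameters()       # or your own weight-init routine
```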
3) Apply fully_shard() bottom-up (wrapping policy = “apply where needed”)
Do not call `fully_shard` only on the topmost module.
Recommended sharding pattern for transformer-like models:
- Iterate modules: `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- Then `fully_shard(model, ...)`
Why: `fully_shard` forms "parameter groups" for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better communication/compute overlap and lower peak memory.
Reference: references/pytorch_fully_shard_api.md (bottom-up requirement and why).
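After this pattern is applied, every parameter is a sharded DTensor. A quick sanity check (a sketch, not required in production code):

```python
from torch.distributed.tensor import DTensor

# Each parameter should now be a DTensor sharded over the mesh; the root
# module also gains FSDPModule methods such as unshard() and
# set_requires_gradient_sync().
for name, p in model.named_parameters():
    assert isinstance(p, DTensor), f"{name} was not sharded"
```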
4) Configure reshard_after_forward for memory/perf trade-offs
Default behavior: `None` means `True` for non-root modules and `False` for the root module (a good default).
Heuristics:
- If you're memory-bound: keep the defaults or force `True` on many blocks.
- If you're throughput-bound and can afford the memory: consider keeping unsharded params longer (root is often `False`).
- Advanced: pass an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it is a meaningful divisor of the world size.
Reference: references/pytorch_fully_shard_api.md (full semantics).
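Illustrative ways to set the knob (`block` is a placeholder for a submodule; pick one option per module):

```python
from torch.distributed.fsdp import fully_shard

fully_shard(block, reshard_after_forward=True)    # lowest memory: free unsharded params right after forward
# fully_shard(block, reshard_after_forward=False) # more throughput: keep params unsharded until backward
# fully_shard(block, reshard_after_forward=8)     # advanced: reshard to a smaller group (e.g., intra-node)
```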
5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload
Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient-reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
Reference: references/pytorch_fully_shard_api.md (MixedPrecisionPolicy / OffloadPolicy classes).
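A sketch of a typical BF16 configuration (whether BF16 is numerically acceptable is model-specific; `block` is a placeholder):

```python
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy

mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)

fully_shard(block, mp_policy=mp)  # BF16 compute and BF16 gradient reduction
# With CPU offload on top (expect extra host<->device traffic):
# fully_shard(block, mp_policy=mp, offload_policy=CPUOffloadPolicy())
```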
6) Optimizer, gradient clipping, accumulation
- Create the optimizer after sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync-style behavior, use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1's `no_sync()`.
Gradient clipping:
- Use the approach shown in the FSDP2 tutorial ("Gradient Clipping and Optimizer with DTensor"), because parameters/gradients are DTensors.
Reference: references/pytorch_fsdp2_tutorial.md.
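A sketch of these training-loop pieces; `dataloader` and `accum_steps` are placeholders, and the `clip_grad_norm_` usage follows the DTensor-aware pattern from the tutorial:

```python
import torch

# Built after all fully_shard calls, so it sees DTensor parameters.
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step, batch in enumerate(dataloader):
    last_microbatch = (step + 1) % accum_steps == 0
    # FSDP2 replacement for FSDP1's no_sync(): only reduce gradients
    # on the final microbatch of each accumulation window.
    model.set_requires_gradient_sync(last_microbatch)
    loss = model(batch).mean()  # call the module, not .forward()
    loss.backward()
    if last_microbatch:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optim.step()
        optim.zero_grad()
```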
7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:
A) Distributed Checkpoint (DCP) — best default
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces multiple files (often at least one per rank) and operates “in place”.
B) Distributed state dict helpers
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For the optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`
Avoid:
- Saving DTensor state dicts with plain `torch.save`, unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
References:
- references/pytorch_dcp_overview.md (DCP behavior and caveats)
- references/pytorch_dcp_recipe.md and references/pytorch_dcp_async_recipe.md (end-to-end usage)
- references/pytorch_fsdp2_tutorial.md (DTensor vs DCP state-dict flows)
- references/pytorch_examples_fsdp2.md (working checkpoint scripts)
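A sketch of approach B, gathering a full (unsharded) state dict, e.g. for export or single-process loading; the option choices are illustrative:

```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)

# Gather full tensors (offloaded to CPU) from the sharded model.
full_sd = get_model_state_dict(
    model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
)

# Later: load a full state dict back into a sharded model, broadcasting
# the rank-0 copy to all ranks.
set_model_state_dict(
    model,
    full_sd,
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
```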
Workflow checklists (copy-paste friendly)
Workflow A: Retrofit FSDP2 into an existing training script
- Launch with `torchrun` and initialize the process group.
- Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- Create the optimizer after sharding so it captures DTensor parameters.
- Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- Add DCP save/load via `torch.distributed.checkpoint` helpers.
Reference: references/pytorch_fsdp2_tutorial.md, references/pytorch_fully_shard_api.md, references/pytorch_device_mesh_tutorial.md, references/pytorch_dcp_recipe.md.
Workflow B: Add DCP save/load (minimal pattern)
- Wrap state in `Stateful` or assemble state via `get_state_dict`.
- Call `dcp.save(...)` from all ranks to a shared path.
- Call `dcp.load(...)` and restore with `set_state_dict`.
- Validate any resharding assumptions when loading into a different mesh.
Reference: references/pytorch_dcp_recipe.md.
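A minimal sketch of this workflow using the state-dict helpers with DCP (`ckpt/latest` is an illustrative path; all ranks must participate in save and load):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

# Save: every rank writes its own shards under the shared checkpoint dir.
model_sd, optim_sd = get_state_dict(model, optim)
dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id="ckpt/latest")

# Load: rebuild the sharded model/optimizer first, then load in place.
model_sd, optim_sd = get_state_dict(model, optim)
dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id="ckpt/latest")
set_state_dict(model, optim, model_state_dict=model_sd, optim_state_dict=optim_sd)
```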
Debug checklist (what the agent should check first)
- All ranks on distinct GPUs? If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
- Did you accidentally call `forward()` directly? Use `model(input)` or explicitly `unshard()` / register the forward method.
- Is `fully_shard()` applied bottom-up? If only the root is sharded, expect worse memory/perf and possible confusion.
- Optimizer created at the right time? It must be built on DTensor parameters after sharding.
- Checkpointing path consistent?
  - If using DCP, don't mix it with ad-hoc `torch.save` unless you understand the conversions.
  - Be mindful of PyTorch-version compatibility warnings for DCP.
Common issues and fixes
- Forward hooks not running → call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- Optimizer sees non-DTensor params → create the optimizer after all `fully_shard` calls.
- Only the root module sharded → apply `fully_shard` bottom-up on submodules before the root.
- Memory spikes after forward → set `reshard_after_forward=True` for more modules.
- Gradient accumulation desync → use `set_requires_gradient_sync` instead of FSDP1's `no_sync()`.
Reference: references/pytorch_fully_shard_api.md, references/pytorch_fsdp2_tutorial.md.
Minimal reference implementation outline (agent-friendly)
The coding agent should implement a script with these labeled blocks:
- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state-dict helpers
Concrete examples live in references/pytorch_examples_fsdp2.md and the official tutorial reference.
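A compressed sketch of how these labeled blocks wire together (the functions refer to the blocks above and the per-step sketches; paths and intervals are illustrative):

```python
def main():
    mesh = init_distributed()        # step 1: process group, device, DeviceMesh
    model = build_model_meta(mesh)   # steps 2-5: meta init, fully_shard, policies, materialize
    optim = build_optimizer(model)   # step 6: created after sharding, over DTensor params
    for step, batch in enumerate(build_dataloader()):
        train_step(model, optim, batch)                      # model(batch), backward, clip, step
        if step % 1000 == 0:
            checkpoint_save(model, optim, f"ckpt/step_{step}")  # step 7: DCP save
```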
References
- references/pytorch_fsdp2_tutorial.md
- references/pytorch_fully_shard_api.md
- references/pytorch_ddp_notes.md
- references/pytorch_fsdp1_api.md
- references/pytorch_device_mesh_tutorial.md
- references/pytorch_tp_tutorial.md
- references/pytorch_dcp_overview.md
- references/pytorch_dcp_recipe.md
- references/pytorch_dcp_async_recipe.md
- references/pytorch_examples_fsdp2.md
- references/torchtitan_fsdp_notes.md (optional, production notes)
- references/ray_train_fsdp2_example.md (optional, integration example)
Accelerate LLM inference using speculative decoding, Medusa multiple heads, and lookahead decoding techniques. Use when optimizing inference speed (1.5-3.6× speedup), reducing latency for real-time applications, or deploying models with limited compute. Covers draft models, tree-based attention, Jacobi iteration, parallel token generation, and production deployment strategies.