PyTorch FSDP(Fully Sharded Data Parallel)の完全実装ガイド:大規模言語モデル時代の分散学習戦略

序論

現代のAI開発において、大規模言語モデル(LLM)の学習は単一GPUでは実現不可能な規模に達しています。GPT-3の175B パラメータ、PaLMの540B パラメータ、そしてGPT-4クラスのモデルでは、従来のData Parallel(DP)やDistributed Data Parallel(DDP)では対応できないメモリ制約が存在します。

PyTorch Fully Sharded Data Parallel(FSDP)は、この課題を解決するために設計された分散学習フレームワークです。FSDPは、Facebookの研究チームがFairScale プロジェクトで開発したFully Sharded Data Parallelism技術をPyTorch 1.11以降で公式サポートしたものです。

本記事では、FSDPの理論的背景から実装の詳細、運用時の最適化戦略まで、実際のプロダクション環境で活用するために必要な全ての知識を体系的に解説します。

第1章:FSDP の理論的基盤とアーキテクチャ

1.1 分散学習パラダイムの進化

従来の分散学習手法には以下の制約がありました:

手法メモリ使用量通信オーバーヘッドスケーラビリティ実装複雑度
Data ParallelN×モデルサイズ低(GPU数に制限)
Distributed Data ParallelN×モデルサイズ中(All-Reduce通信)
Model Parallelモデルサイズ/N高(パイプライン化必要)
FSDPモデルサイズ/N + α非常に高

FSDPは、各GPU上でモデルパラメータを分割(shard)して保持し、必要な計算時のみ他のGPUからパラメータを収集する(gather)することで、メモリ効率性とスケーラビリティを両立しています。

1.2 FSDPの数学的基盤

FSDPの動作原理を数学的に表現すると以下のようになります:

モデルパラメータを $\theta = {\theta_1, \theta_2, …, \theta_N}$ とし、これを$P$個のGPUに分散する場合:

GPU_i が保持するパラメータ: θ_i = {θ_{i×k}, θ_{i×k+1}, ..., θ_{(i+1)×k-1}}
ここで k = N/P (Nはパラメータ数、Pは GPU数)

フォワードパス時の計算:

1. All-Gather: θ_complete = gather(θ_1, θ_2, ..., θ_P)
2. Forward: y = f(x, θ_complete)
3. Parameter Discard: θ_complete を破棄(メモリ解放)

バックワードパス時の計算:

1. All-Gather: θ_complete = gather(θ_1, θ_2, ..., θ_P)
2. Backward: ∇θ_complete = backward(loss, θ_complete)
3. Reduce-Scatter: ∇θ_i = reduce_scatter(∇θ_complete)
4. Parameter Update: θ_i ← θ_i - α × ∇θ_i

この仕組みにより、各GPUは全パラメータの1/P分のメモリ使用量でモデル学習が可能となります。

1.3 FSDPの内部実装メカニズム

FSDPは以下の4つの主要コンポーネントで構成されています:

1. Parameter Sharding Manager 各レイヤーのパラメータを複数のGPU間で分割し、メタデータを管理します。

2. Communication Backend NCCL(NVIDIA Collective Communications Library)を使用したAll-GatherとReduce-Scatter操作を最適化します。

3. Memory Manager 動的なメモリ割り当てと解放を管理し、メモリフラグメンテーションを最小化します。

4. Gradient Synchronization Engine 勾配の同期と精度保持を効率的に実行します。

第2章:FSDP の基本実装

2.1 環境構築と依存関係

FSDPを使用するための基本的な環境設定:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import os

# 分散学習の初期化
def setup_distributed():
    """分散学習環境の初期化"""
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
    else:
        # 単一ノード環境での設定
        rank = 0
        world_size = torch.cuda.device_count()
        local_rank = 0
    
    # NCCL バックエンドを使用
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=rank,
        world_size=world_size
    )
    
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank

rank, world_size, local_rank = setup_distributed()

2.2 基本的なFSDPモデルの構築

簡単なTransformerモデルでFSDPの基本的な使用方法を示します:

class SimpleTransformerBlock(nn.Module):
    """単純なTransformerブロック"""
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        )
    
    def forward(self, x):
        # Self-Attention with residual connection
        norm_x = self.norm1(x)
        attn_out, _ = self.attn(norm_x, norm_x, norm_x)
        x = x + attn_out
        
        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x

class SimpleTransformer(nn.Module):
    """単純なTransformerモデル"""
    def __init__(self, vocab_size, seq_len, dim, num_layers, num_heads):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_embedding = nn.Parameter(torch.randn(seq_len, dim))
        
        self.layers = nn.ModuleList([
            SimpleTransformerBlock(dim, num_heads)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size)
    
    def forward(self, x):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_embedding[:seq_len]
        
        for layer in self.layers:
            x = layer(x)
        
        x = self.norm(x)
        return self.head(x)

# モデルの初期化
model = SimpleTransformer(
    vocab_size=50000,
    seq_len=2048,
    dim=1024,
    num_layers=24,
    num_heads=16
)

# FSDPでラップ
fsdp_model = FSDP(
    model,
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=local_rank,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float16,
    ),
)

2.3 学習ループの実装

FSDPを使用した実際の学習ループ:

def train_fsdp_model(model, train_dataloader, epochs=10):
    """FSDP モデルの学習関数"""
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=0.01
    )
    
    # 学習率スケジューラ
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs
    )
    
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        
        for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
            # データをGPUに移動
            input_ids = input_ids.cuda(local_rank, non_blocking=True)
            labels = labels.cuda(local_rank, non_blocking=True)
            
            # 勾配をゼロクリア
            optimizer.zero_grad()
            
            # フォワードパス
            logits = model(input_ids)
            
            # 損失計算
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1)
            )
            
            # バックワードパス
            loss.backward()
            
            # 勾配クリッピング(重要)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # パラメータ更新
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            # ログ出力(rank 0のみ)
            if rank == 0 and batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        scheduler.step()
        
        # エポック終了時の統計情報
        avg_loss = total_loss / num_batches
        if rank == 0:
            print(f"Epoch {epoch} finished. Average Loss: {avg_loss:.4f}")

# ダミーデータローダーの作成(実際の実装では適切なデータセットを使用)
def create_dummy_dataloader(batch_size=8, seq_len=2048, vocab_size=50000):
    """学習用のダミーデータローダー"""
    dataset_size = 1000
    
    input_ids = torch.randint(0, vocab_size, (dataset_size, seq_len))
    labels = torch.randint(0, vocab_size, (dataset_size, seq_len))
    
    dataset = torch.utils.data.TensorDataset(input_ids, labels)
    
    # DistributedSampler を使用してデータを分散
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
    )
    
    return dataloader

# 学習実行
train_dataloader = create_dummy_dataloader()
train_fsdp_model(fsdp_model, train_dataloader)

第3章:高度な設定とカスタマイゼーション

3.1 シャーディング戦略の詳細

FSDPでは複数のシャーディング戦略が利用可能です:

from torch.distributed.fsdp import ShardingStrategy

# 1. FULL_SHARD: 完全なパラメータシャーディング(最もメモリ効率的)
full_shard_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
)

# 2. SHARD_GRAD_OP: 勾配とオプティマイザ状態のみシャーディング
shard_grad_op_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
)

# 3. NO_SHARD: シャーディングなし(DDPと同等)
no_shard_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.NO_SHARD,
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
)

# 4. HYBRID_SHARD: ノード内でのみシャーディング
hybrid_shard_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
)

各戦略のメモリ使用量とパフォーマンス特性:

戦略メモリ削減率通信量計算効率適用場面
FULL_SHARD最大(1/N)超大規模モデル
SHARD_GRAD_OP中(約1/2)中規模モデル
NO_SHARDなし最高小規模モデル・デバッグ
HYBRID_SHARD高(1/K, Kはノード内GPU数)マルチノード環境

3.2 CPUオフロード機能

CPU オフロードを使用することで、さらなるメモリ削減が可能です:

from torch.distributed.fsdp import CPUOffload

# パラメータのみCPUオフロード
cpu_offload_params = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
)

# パラメータと勾配の両方をCPUオフロード
cpu_offload_full = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
)

# カスタム CPUオフロード設定
class CustomCPUOffload:
    """カスタム CPUオフロード設定"""
    def __init__(self, offload_threshold_mb=1000):
        self.offload_threshold_mb = offload_threshold_mb
    
    def should_offload(self, module):
        """モジュールサイズに基づいてオフロード判定"""
        param_size_mb = sum(p.numel() * p.element_size() for p in module.parameters()) / (1024 * 1024)
        return param_size_mb > self.offload_threshold_mb

# 実装例:選択的CPUオフロード
def selective_cpu_offload_wrapper(model, threshold_mb=1000):
    """大きなレイヤーのみCPUオフロードを適用"""
    custom_offload = CustomCPUOffload(threshold_mb)
    
    for name, module in model.named_modules():
        if custom_offload.should_offload(module):
            print(f"Applying CPU offload to {name}")
            module = FSDP(
                module,
                cpu_offload=CPUOffload(offload_params=True),
            )
    
    return model

3.3 混合精度設定の最適化

FSDPでは細かい混合精度制御が可能です:

from torch.distributed.fsdp import MixedPrecision

# 基本的な混合精度設定
basic_mixed_precision = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float16,
)

# より安定した学習のための設定
stable_mixed_precision = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float32,  # 勾配集約は float32 で実行
    buffer_dtype=torch.float32,  # バッファは float32 で保持
    cast_forward_inputs=True,    # 入力の自動キャスト
    cast_root_forward_inputs=True,  # ルートモジュール入力のキャスト
)

# BFloat16 を使用した設定(A100などでサポート)
bfloat16_mixed_precision = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# モデルに適用
fsdp_model_mixed = FSDP(
    model,
    mixed_precision=stable_mixed_precision,
    auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=1e6),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)

3.4 カスタムラップポリシー

効率的なパフォーマンスを得るために、カスタムラップポリシーを定義できます:

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial

# Transformer ベースモデル用の特化ポリシー
transformer_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        SimpleTransformerBlock,  # 自定義のTransformerブロック
    },
)

# サイズベースの詳細設定
size_wrap_policy = partial(
    size_based_auto_wrap_policy,
    min_num_params=1e6,  # 100万パラメータ以上のモジュールをラップ
)

# カスタムラップ条件
def custom_wrap_policy(module, recurse, nonwrapped_numel):
    """カスタムラップポリシー"""
    # LayerNorm は常にラップしない
    if isinstance(module, nn.LayerNorm):
        return False
    
    # 埋め込み層は常にラップ
    if isinstance(module, nn.Embedding):
        return True
    
    # その他はサイズベースで判定
    return nonwrapped_numel >= 1e6

# 適用例
custom_fsdp_model = FSDP(
    model,
    auto_wrap_policy=custom_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)

第4章:パフォーマンス最適化とチューニング

4.1 通信効率の最適化

FSDPのパフォーマンスは通信効率に大きく依存します:

from torch.distributed.fsdp import BackwardPrefetch

# バックワードプリフェッチの設定
def create_optimized_fsdp_model(model, strategy="backward_pre"):
    """最適化されたFSDPモデルの作成"""
    
    # バックワードプリフェッチ戦略
    if strategy == "backward_pre":
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE
    elif strategy == "backward_post":
        backward_prefetch = BackwardPrefetch.BACKWARD_POST
    else:
        backward_prefetch = None
    
    # 最適化されたFSDP設定
    fsdp_config = {
        "sharding_strategy": ShardingStrategy.FULL_SHARD,
        "auto_wrap_policy": size_based_auto_wrap_policy(min_num_params=1e6),
        "backward_prefetch": backward_prefetch,
        "forward_prefetch": True,  # フォワードプリフェッチ有効化
        "limit_all_gathers": True,  # All-Gather制限
        "use_orig_params": False,   # 元パラメータ使用を無効化
        "mixed_precision": MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float16,
        ),
    }
    
    return FSDP(model, **fsdp_config)

# 通信パターンの分析とプロファイリング
def profile_fsdp_communication(model, input_data, output_file="fsdp_profile.txt"):
    """FSDP通信パターンのプロファイリング"""
    
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        # フォワードパス
        output = model(input_data)
        loss = output.sum()
        
        # バックワードパス
        loss.backward()
    
    # プロファイル結果の保存
    prof.export_chrome_trace(output_file)
    
    # 通信統計の出力
    print("=== Communication Statistics ===")
    key_averages = prof.key_averages()
    for evt in key_averages:
        if "all_gather" in evt.key.lower() or "reduce_scatter" in evt.key.lower():
            print(f"{evt.key}: {evt.cpu_time_total/1000:.2f}ms CPU, "
                  f"{evt.cuda_time_total/1000:.2f}ms CUDA")

4.2 メモリ使用量の最適化

メモリ使用量を最小化するための戦略:

import gc
from torch.distributed.fsdp import StateDictType

class MemoryOptimizedFSDP:
    """メモリ最適化されたFSDPラッパー"""
    
    def __init__(self, model, max_memory_gb=80):
        self.max_memory_gb = max_memory_gb
        self.model = self._create_optimized_model(model)
        self.memory_tracker = []
    
    def _create_optimized_model(self, model):
        """メモリ最適化されたFSDPモデルの作成"""
        return FSDP(
            model,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            cpu_offload=CPUOffload(offload_params=True),
            auto_wrap_policy=size_based_auto_wrap_policy(min_num_params=5e5),
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            mixed_precision=MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float16,
            ),
            limit_all_gathers=True,
        )
    
    def track_memory_usage(self, step_name):
        """メモリ使用量の追跡"""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / (1024**3)  # GB
            reserved = torch.cuda.memory_reserved() / (1024**3)    # GB
            
            self.memory_tracker.append({
                'step': step_name,
                'allocated_gb': allocated,
                'reserved_gb': reserved,
            })
            
            print(f"{step_name}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")
    
    def train_step_with_memory_management(self, batch, optimizer):
        """メモリ管理を考慮した学習ステップ"""
        self.track_memory_usage("before_forward")
        
        # フォワードパス
        output = self.model(batch['input_ids'])
        loss = torch.nn.functional.cross_entropy(
            output.view(-1, output.size(-1)),
            batch['labels'].view(-1)
        )
        
        self.track_memory_usage("after_forward")
        
        # バックワードパス
        loss.backward()
        
        self.track_memory_usage("after_backward")
        
        # オプティマイザーステップ
        optimizer.step()
        optimizer.zero_grad()
        
        # 明示的なメモリクリーンアップ
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        
        self.track_memory_usage("after_cleanup")
        
        return loss.item()
    
    def save_checkpoint(self, path, optimizer=None):
        """チェックポイントの保存"""
        # CPU への統合状態辞書の保存
        with FSDP.state_dict_type(
            self.model,
            StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        ):
            state_dict = self.model.state_dict()
        
        if dist.get_rank() == 0:
            checkpoint = {
                'model_state_dict': state_dict,
                'memory_stats': self.memory_tracker,
            }
            
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            
            torch.save(checkpoint, path)
            print(f"Checkpoint saved to {path}")

# 使用例
def memory_efficient_training():
    """メモリ効率的な学習の実行例"""
    # ダミーモデルとデータ
    model = SimpleTransformer(vocab_size=50000, seq_len=2048, dim=1024, num_layers=24, num_heads=16)
    
    # メモリ最適化FSDPでラップ
    memory_optimized_model = MemoryOptimizedFSDP(model, max_memory_gb=80)
    
    # オプティマイザー
    optimizer = torch.optim.AdamW(
        memory_optimized_model.model.parameters(),
        lr=1e-4,
        weight_decay=0.01
    )
    
    # 学習データ(ダミー)
    batch = {
        'input_ids': torch.randint(0, 50000, (4, 2048)).cuda(),
        'labels': torch.randint(0, 50000, (4, 2048)).cuda(),
    }
    
    # 学習ステップの実行
    for step in range(100):
        loss = memory_optimized_model.train_step_with_memory_management(batch, optimizer)
        
        if step % 10 == 0:
            print(f"Step {step}: Loss = {loss:.4f}")
        
        # 定期的なチェックポイント保存
        if step % 50 == 0 and step > 0:
            memory_optimized_model.save_checkpoint(f"checkpoint_step_{step}.pt", optimizer)

4.3 動的バッチサイズ調整

メモリ制約に応じた動的バッチサイズ調整:

class AdaptiveBatchSizeFSDP:
    """適応的バッチサイズ調整機能付きFSDP"""
    
    def __init__(self, model, initial_batch_size=8, max_memory_gb=80):
        self.model = model
        self.current_batch_size = initial_batch_size
        self.max_memory_gb = max_memory_gb
        self.memory_threshold = max_memory_gb * 0.9  # 90%で警告
        self.batch_size_history = []
    
    def adjust_batch_size_based_on_memory(self):
        """メモリ使用量に基づいてバッチサイズを調整"""
        if not torch.cuda.is_available():
            return self.current_batch_size
        
        current_memory_gb = torch.cuda.memory_allocated() / (1024**3)
        
        if current_memory_gb > self.memory_threshold:
            # メモリ使用量が高い場合、バッチサイズを減らす
            new_batch_size = max(1, int(self.current_batch_size * 0.75))
            print(f"Memory usage high ({current_memory_gb:.2f}GB), "
                  f"reducing batch size from {self.current_batch_size} to {new_batch_size}")
            self.current_batch_size = new_batch_size
            
        elif current_memory_gb < self.memory_threshold * 0.6:
            # メモリ使用量が低い場合、バッチサイズを増やす
            new_batch_size = int(self.current_batch_size * 1.25)
            print(f"Memory usage low ({current_memory_gb:.2f}GB), "
                  f"increasing batch size from {self.current_batch_size} to {new_batch_size}")
            self.current_batch_size = new_batch_size
        
        self.batch_size_history.append({
            'step': len(self.batch_size_history),
            'batch_size': self.current_batch_size,
            'memory_gb': current_memory_gb,
        })
        
        return self.current_batch_size
    
    def create_dynamic_dataloader(self, dataset):
        """動的バッチサイズ対応データローダー"""
        class DynamicBatchSampler:
            def __init__(self, dataset, fsdp_wrapper):
                self.dataset = dataset
                self.fsdp_wrapper = fsdp_wrapper
                self.indices = list(range(len(dataset)))
            
            def __iter__(self):
                random.shuffle(self.indices)
                i = 0
                while i < len(self.indices):
                    # 現在のバッチサイズを取得
                    current_bs = self.fsdp_wrapper.adjust_batch_size_based_on_memory()
                    batch_indices = self.indices[i:i+current_bs]
                    yield batch_indices
                    i += current_bs
            
            def __len__(self):
                return len(self.indices) // self.current_batch_size
        
        sampler = DynamicBatchSampler(dataset, self)
        
        def collate_fn(batch):
            # バッチデータの処理
            input_ids = torch.stack([item[0] for item in batch])
            labels = torch.stack([item[1] for item in batch])
            return {'input_ids': input_ids, 'labels': labels}
        
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True,
        )

第5章:実際の大規模モデル実装例

5.1 GPT-style モデルの完全実装

実際のプロダクション環境で使用可能なGPTスタイルモデルの実装:

import math
from dataclasses import dataclass
from typing import Optional

@dataclass
class GPTConfig:
    """GPT モデルの設定"""
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    block_size: int = 1024

class CausalSelfAttention(nn.Module):
    """Causal Self-Attention with Flash Attention optimization"""
    
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        
        # Query, Key, Value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        
        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        # Flash attention compatibility
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # Calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # Attention (materializes the large (T,T) matrix for all the queries and keys)
        if self.flash:
            # Efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True
            )
        else:
            # Manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # Output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):
    """Multi-Layer Perceptron with GELU activation"""
    
    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    """Transformer Block"""
    
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    """GPT Language Model"""
    
    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),  # Word Token Embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),  # Word Position Embeddings
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),  # Final Layer Norm
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight tying
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)
        
        # Apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # Report number of parameters
        print("Number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialize weights according to GPT-2 paper"""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # Forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # If we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

# FSDP でラップされた GPT モデルの作成
def create_fsdp_gpt_model(config):
    """FSDP でラップされた GPT モデルを作成"""
    
    # モデル初期化
    model = GPT(config)
    
    # カスタムラップポリシー(Blockレベルでラップ)
    def gpt_wrap_policy(module, recurse, nonwrapped_numel):
        """GPT用のカスタムラップポリシー"""
        if isinstance(module, Block):
            return True
        return nonwrapped_numel >= 1e6
    
    # FSDP設定
    fsdp_config = {
        "sharding_strategy": ShardingStrategy.FULL_SHARD,
        "auto_wrap_policy": gpt_wrap_policy,
        "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
        "mixed_precision": MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
        "device_id": local_rank,
        "limit_all_gathers": True,
    }
    
    # FSDPでラップ
    fsdp_model = FSDP(model, **fsdp_config)
    
    return fsdp_model

# 使用例
def train_large_gpt_model():
    """大規模GPTモデルの学習例"""
    
    # 大規模モデル設定
    config = GPTConfig(
        vocab_size=50304,
        n_layer=48,      # 48 layers
        n_head=32,       # 32 attention heads
        n_embd=2048,     # 2048 embedding dimension
        dropout=0.1,
        bias=False,
        block_size=2048, # 2048 context length
    )
    
    # FSDPモデル作成
    fsdp_gpt = create_fsdp_gpt_model(config)
    
    print(f"Model created with {fsdp_gpt.get_num_params()/1e6:.1f}M parameters")
    
    # オプティマイザー設定
    optimizer = torch.optim.AdamW(
        fsdp_gpt.parameters(),
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95),
    )
    
    # 学習率スケジューラー
    def get_lr(step, warmup_steps=2000, max_steps=100000):
        """コサインアニーリング学習率スケジューラー"""
        if step < warmup_steps:
            return 3e-4 * step / warmup_steps
        if step > max_steps:
            return 3e-4 * 0.1
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return 3e-4 * 0.1 + coeff * (3e-4 - 3e-4 * 0.1)
    
    # 学習ループ
    fsdp_gpt.train()
    for step in range(10000):
        # ダミーデータ
        batch_size = 4
        seq_len = 2048
        x = torch.randint(0, config.vocab_size, (batch_size, seq_len)).cuda()
        y = torch.randint(0, config.vocab_size, (batch_size, seq_len)).cuda()
        
        # 学習率更新
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        
        # フォワード・バックワード
        optimizer.zero_grad()
        logits, loss = fsdp_gpt(x, y)
        loss.backward()
        
        # 勾配クリッピング
        torch.nn.utils.clip_grad_norm_(fsdp_gpt.parameters(), 1.0)
        
        optimizer.step()
        
        # ログ出力
        if rank == 0 and step % 100 == 0:
            print(f"Step {step}: Loss = {loss.item():.4f}, LR = {lr:.2e}")

5.2 チェックポイント管理とモデル保存

大規模モデルのチェックポイント管理は重要な課題です:

from torch.distributed.fsdp import (
    FullStateDictConfig,
    LocalStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.checkpoint import save_state_dict, load_state_dict
import os
from pathlib import Path

class FSDPCheckpointManager:
    """FSDP モデルのチェックポイント管理"""
    
    def __init__(self, model, optimizer, checkpoint_dir="./checkpoints"):
        self.model = model
        self.optimizer = optimizer
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)
        
    def save_full_checkpoint(self, step, loss, save_optimizer=True):
        """完全なチェックポイントの保存(CPU統合)"""
        checkpoint_path = self.checkpoint_dir / f"checkpoint_step_{step}"
        checkpoint_path.mkdir(exist_ok=True)
        
        # CPU統合状態辞書の設定
        full_state_dict_config = FullStateDictConfig(
            offload_to_cpu=True,
            rank0_only=True
        )
        
        # モデル状態辞書の保存
        with FSDP.state_dict_type(
            self.model, StateDictType.FULL_STATE_DICT, full_state_dict_config
        ):
            model_state_dict = self.model.state_dict()
        
        # rank 0 のみが保存を実行
        if dist.get_rank() == 0:
            checkpoint_data = {
                'step': step,
                'model_state_dict': model_state_dict,
                'loss': loss,
                'config': getattr(self.model, 'config', None),
            }
            
            if save_optimizer:
                checkpoint_data['optimizer_state_dict'] = self.optimizer.state_dict()
            
            torch.save(checkpoint_data, checkpoint_path / "pytorch_model.bin")
            
            # メタデータの保存
            metadata = {
                'step': step,
                'loss': loss,
                'timestamp': str(datetime.now()),
                'world_size': dist.get_world_size(),
            }
            
            with open(checkpoint_path / "metadata.json", 'w') as f:
                json.dump(metadata, f, indent=2)
            
            print(f"Full checkpoint saved at step {step}")
    
    def save_sharded_checkpoint(self, step, loss):
        """シャード化チェックポイントの保存(分散)"""
        checkpoint_path = self.checkpoint_dir / f"sharded_checkpoint_step_{step}"
        checkpoint_path.mkdir(exist_ok=True)
        
        # シャード化状態辞書の設定
        sharded_state_dict_config = ShardedStateDictConfig(
            offload_to_cpu=True
        )
        
        # 分散チェックポイント保存
        with FSDP.state_dict_type(
            self.model, StateDictType.SHARDED_STATE_DICT, sharded_state_dict_config
        ):
            model_state_dict = self.model.state_dict()
            optimizer_state_dict = FSDP.optim_state_dict(self.model, self.optimizer)
            
            state_dict = {
                "model": model_state_dict,
                "optimizer": optimizer_state_dict,
                "step": step,
                "loss": loss,
            }
            
            save_state_dict(
                state_dict=state_dict,
                storage_writer=checkpoint_path,
            )
        
        print(f"Sharded checkpoint saved at step {step}")
    
    def load_full_checkpoint(self, checkpoint_path):
        """完全チェックポイントの読み込み"""
        checkpoint_path = Path(checkpoint_path)
        
        if dist.get_rank() == 0:
            checkpoint = torch.load(checkpoint_path / "pytorch_model.bin", map_location='cpu')
        else:
            checkpoint = None
        
        # 全rankに broadcast
        checkpoint = [checkpoint]
        dist.broadcast_object_list(checkpoint, src=0)
        checkpoint = checkpoint[0]
        
        # モデル状態の復元
        with FSDP.state_dict_type(
            self.model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=False)
        ):
            self.model.load_state_dict(checkpoint['model_state_dict'])
        
        # オプティマイザー状態の復元
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        return checkpoint['step'], checkpoint['loss']
    
    def load_sharded_checkpoint(self, checkpoint_path):
        """シャード化チェックポイントの読み込み"""
        checkpoint_path = Path(checkpoint_path)
        
        with FSDP.state_dict_type(
            self.model, StateDictType.SHARDED_STATE_DICT
        ):
            model_state_dict = self.model.state_dict()
            optimizer_state_dict = FSDP.optim_state_dict(self.model, self.optimizer)
            
            state_dict = {
                "model": model_state_dict,
                "optimizer": optimizer_state_dict,
            }
            
            load_state_dict(
                state_dict=state_dict,
                storage_reader=checkpoint_path,
            )
            
            # 状態の復元
            self.model.load_state_dict(state_dict["model"])
            FSDP.optim_state_dict_to_load(
                self.model, self.optimizer, state_dict["optimizer"]
            )
        
        # メタデータの読み込み
        with open(checkpoint_path / "metadata.json", 'r') as f:
            metadata = json.load(f)
        
        return metadata['step'], metadata['loss']
    
    def cleanup_old_checkpoints(self, keep_last_n=5):
        """古いチェックポイントのクリーンアップ"""
        if dist.get_rank() != 0:
            return
        
        # チェックポイントディレクトリの一覧取得
        checkpoint_dirs = [d for d in self.checkpoint_dir.iterdir() if d.is_dir()]
        checkpoint_dirs.sort(key=lambda x: x.stat().st_mtime, reverse=True)
        
        # 古いチェックポイントを削除
        for old_checkpoint in checkpoint_dirs[keep_last_n:]:
            import shutil
            shutil.rmtree(old_checkpoint)
            print(f"Removed old checkpoint: {old_checkpoint}")

# 使用例
def training_with_checkpointing():
    """チェックポイント機能付きの学習"""
    
    # モデルとオプティマイザーの初期化
    config = GPTConfig(n_layer=24, n_head=16, n_embd=1024)
    model = create_fsdp_gpt_model(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    # チェックポイントマネージャー
    checkpoint_manager = FSDPCheckpointManager(model, optimizer)
    
    # 学習開始ステップ
    start_step = 0
    
    # 既存チェックポイントからの復旧(オプション)
    latest_checkpoint = checkpoint_manager.checkpoint_dir / "latest_checkpoint"
    if latest_checkpoint.exists():
        try:
            start_step, last_loss = checkpoint_manager.load_sharded_checkpoint(latest_checkpoint)
            print(f"Resumed from step {start_step}, loss: {last_loss:.4f}")
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            start_step = 0
    
    # 学習ループ
    model.train()
    for step in range(start_step, 10000):
        # ダミー学習データ
        batch_size = 4
        seq_len = 1024
        x = torch.randint(0, config.vocab_size, (batch_size, seq_len)).cuda()
        y = torch.randint(0, config.vocab_size, (batch_size, seq_len)).cuda()
        
        # 学習ステップ
        optimizer.zero_grad()
        logits, loss = model(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        # 定期的なチェックポイント保存
        if step % 1000 == 0 and step > 0:
            # シャード化チェックポイント(高速)
            checkpoint_manager.save_sharded_checkpoint(step, loss.item())
            
            # latest_checkpoint へのシンボリックリンク更新
            latest_path = checkpoint_manager.checkpoint_dir / "latest_checkpoint"
            if latest_path.exists():
                latest_path.unlink()
            latest_path.symlink_to(f"sharded_checkpoint_step_{step}")
        
        if step % 5000 == 0 and step > 0:
            # 完全チェックポイント(長期保存用)
            checkpoint_manager.save_full_checkpoint(step, loss.item())
            
            # 古いチェックポイントのクリーンアップ
            checkpoint_manager.cleanup_old_checkpoints(keep_last_n=3)
        
        # ログ出力
        if rank == 0 and step % 100 == 0:
            print(f"Step {step}: Loss = {loss.item():.4f}")

第6章:限界とリスク

6.1 技術的限界

FSDPには以下の技術的限界が存在します:

通信オーバーヘッド All-GatherとReduce-Scatter操作により、通信量がDDPと比較して増加します。特に小規模なモデルでは、通信コストが計算コストを上回る場合があります。

メモリアクセスパターン パラメータの動的な収集と破棄により、メモリアクセスパターンが複雑化し、キャッシュ効率が低下する可能性があります。

デバッグの困難性 分散環境での動的なパラメータ管理により、従来のデバッグツールでは問題の特定が困難になる場合があります。

6.2 パフォーマンスリスク

リスク要因影響度対策
通信帯域幅不足InfiniBandなど高速ネットワークの使用
GPU間の非対称性同一スペックGPUの使用
メモリフラグメンテーション定期的なメモリクリーンアップ
勾配同期の遅延バックワードプリフェッチの活用

6.3 運用上のリスク

スケーラビリティの限界 GPU数の増加に伴い、通信オーバーヘッドが指数的に増加する場合があります。特に数千GPU規模では、ネットワークトポロジーの最適化が不可欠です。

障害耐性 単一GPUの障害により、全体の学習プロセスが停止するリスクがあります。チェックポイント戦略と自動復旧機能の実装が重要です。

6.4 不適切なユースケース

以下の場合、FSDPの使用は推奨されません:

小規模モデル(< 1億パラメータ) 通信オーバーヘッドが学習効率を大幅に低下させる可能性があります。

単一GPU環境 FSDPの利点を活用できず、実装の複雑性のみが増加します。

リアルタイム推論 パラメータの動的収集により、推論レイテンシが増加します。推論時は統合モデルの使用を推奨します。

結論

PyTorch FSDPは、大規模言語モデル時代の分散学習において、メモリ効率性とスケーラビリティの課題を解決する重要な技術です。本記事で解説した実装例と最適化戦略を活用することで、従来では不可能だった規模のモデル学習が可能になります。

しかし、FSDPの導入には適切な設計と運用戦略が不可欠です。特に、通信効率の最適化、メモリ管理、チェックポイント戦略は、プロダクション環境での成功に直結する要素です。

今後のAI開発において、FSDPはより大規模で高性能なモデルの実現を支える基盤技術として、その重要性がさらに高まることが予想されます。継続的な技術動向の把握と