你可以透過 Serverless GPU Python API 在多個 GPU 上啟動分散式工作負載——無論是在同一節點內還是跨多個節點。 API 提供簡單、統一的介面,可抽象化 GPU 佈建、環境設定和工作負載分配的詳細資料。 只需進行最少的程式碼變更,您就可以從同一筆記本電腦從單一 GPU 訓練無縫遷移到跨遠端 GPU 的分散式執行。
快速入門
用於分散式訓練的無伺服器 GPU API 預裝於 Databricks 筆記型電腦的無伺服器 GPU 運算環境中。 我們建議使用 GPU 環境 4 及以上。 要用它做分散式訓練,匯入並使用 distributed Decorator 來分散你的訓練函數。
以下程式碼片段展示了基本 @distributed使用方式:
# Import the distributed decorator
from serverless_gpu import distributed
# Decorate your training function with @distributed and specify the number of GPUs, the GPU type,
# and whether or not the GPUs are remote
@distributed(gpus=8, gpu_type='A10', remote=True)
def run_train():
...
以下是一個完整範例,將多層感知器(MLP)模型訓練於筆記型電腦中的 8 個 A10 GPU 節點上:
建立你的模型並定義效用函數。
# Define the model import os import torch import torch.distributed as dist import torch.nn as nn def setup(): dist.init_process_group("nccl") torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) def cleanup(): dist.destroy_process_group() class SimpleMLP(nn.Module): def __init__(self, input_dim=10, hidden_dim=64, output_dim=1): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2), nn.Linear(hidden_dim, output_dim) ) def forward(self, x): return self.net(x)匯入 serverless_gpu 函式庫與 分散 式模組。
import serverless_gpu from serverless_gpu import distributed把模型訓練程式碼包裝成一個函式,然後用
@distributed裝飾器裝飾這個函式。@distributed(gpus=8, gpu_type='A10', remote=True) def run_train(num_epochs: int, batch_size: int) -> None: import mlflow import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler, TensorDataset # 1. Set up multi node environment setup() device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") # 2. Apply the Torch distributed data parallel (DDP) library for data-parellel training. model = SimpleMLP().to(device) model = DDP(model, device_ids=[device]) # 3. Create and load dataset. x = torch.randn(5000, 10) y = torch.randn(5000, 1) dataset = TensorDataset(x, y) sampler = DistributedSampler(dataset) dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size) # 4. Define the training loop. optimizer = optim.Adam(model.parameters(), lr=0.001) loss_fn = nn.MSELoss() for epoch in range(num_epochs): sampler.set_epoch(epoch) model.train() total_loss = 0.0 for step, (xb, yb) in enumerate(dataloader): xb, yb = xb.to(device), yb.to(device) optimizer.zero_grad() loss = loss_fn(model(xb), yb) # Log loss to MLflow metric mlflow.log_metric("loss", loss.item(), step=step) loss.backward() optimizer.step() total_loss += loss.item() * xb.size(0) mlflow.log_metric("total_loss", total_loss) print(f"Total loss for epoch {epoch}: {total_loss}") cleanup()透過呼叫分散式函式並使用使用者定義的參數來執行分散式訓練。
run_train.distributed(num_epochs=3, batch_size=1)執行時,會在筆記本儲存格輸出中產生 MLflow 執行連結。 點擊 MLflow 跑動連結,或在 實驗 面板中找到它,查看跑動結果。
分散式執行細節
無伺服器 GPU API 包含幾個關鍵元件:
- 計算管理器:處理資源分配與管理
- 執行環境:管理 Python 環境與相依關係
- 啟動器:協調作業執行與監控
在分散式模式下運行時:
- 該函式會序列化並分配到指定數量的 GPU 上
- 每個 GPU 都會執行一個參數相同的函式副本
- 環境在所有節點間同步
- 所有顯示卡都會收集並回傳結果
若 remote 設為 True,則工作負載分配至遠端 GPU。 如果 remote 設為 False,則工作負載會運行在連接於目前筆記本的單一 GPU 節點上。 如果節點有多顆 GPU 晶片,所有晶片都會被利用。
此 API 支援熱門的平行訓練程式庫,例如 分散式資料平行 (DDP)、 全分片資料平行 (FSDP)、 DeepSpeed 和 Ray。
你可以在 筆記本範例中找到更多真實的分散式訓練場景,利用各種函式庫。
使用 Ray 啟動
無伺服器 GPU API 也支援使用 Ray 來啟動分散式訓練,並透過 @ray_launch 裝飾器進行操作,該裝飾器是疊加在 @distributed 之上。
每個ray_launch任務首先啟動一個 PyTorch 分散式會合,以決定 Ray 的主工作者並收集 IP 地址。 Rank-0 會啟動 ray start --head (如果啟用則匯出指標)、設定 RAY_ADDRESS,並執行你裝飾的函式作為 Ray 驅動程式。 其他節點會透過 ray start --address 加入,並等待驅動程式寫入完成標記。
更多配置細節:
- 要在每個節點啟用 Ray 系統的度量收集,請使用
RayMetricsMonitorremote=True。 - 在裝飾函式 中 用標準 Ray API 定義 Ray 執行時選項(actors、datasets、placement groups 和 scheduleling)。
- 在裝飾器參數或筆記本環境 外 ,管理叢集範圍的控制項(GPU 數量與類型、遠端與本地模式、非同步行為,以及 Databricks 池環境變數)。
以下範例展示了如何使用 @ray_launch:
from serverless_gpu.ray import ray_launch
@ray_launch(gpus=16, remote=True, gpu_type='A10')
def foo():
import os
import ray
print(ray.state.available_resources_per_node())
return 1
foo.distributed()
完整範例請參考這本 筆記本,它啟動 Ray 以訓練多台 A10 GPU 上的 Resnet18 神經網路。
FAQs
資料載入碼應該放在哪裡?
使用 無伺服器 GPU API 進行分散式訓練時,請將資料載入程式碼移入 @distributed 裝飾器內部。 資料集大小可能會超過 pickle 所允許的最大限制,因此建議在裝飾器中生成資料集,如下所示:
from serverless_gpu import distributed
# this may cause pickle error
dataset = get_dataset(file_path)
@distributed(gpus=8, remote=True)
def run_train():
# good practice
dataset = get_dataset(file_path)
....
我可以使用保留的 GPU 池嗎?
如果你的工作區有可用的保留GPU池(請向管理員確認),並且你在@distributed 裝飾器中將remote 指定為True,則工作負載會預設在保留的GPU池上啟動。 如果你想使用隨選 GPU 池,請在呼叫分散式函式前先設定環境變數 DATABRICKS_USE_RESERVED_GPU_POOL 為 , False 如下所示:
import os
os.environ['DATABRICKS_USE_RESERVED_GPU_POOL'] = 'False'
@distributed(gpus=8, remote=True)
def run_train():
...
瞭解詳情
關於 API 參考,請參考 Serverless GPU Python API 文件。