在无服务器 GPU 计算上加载数据

本部分介绍有关在无服务器 GPU 计算上加载数据的信息,具体适用于 ML 和 DL 应用程序。 请查看 本教程 ,详细了解如何使用 Spark Python API 加载和转换数据。

加载表格数据

使用 Spark Connect 从 Delta 表加载表格机器学习数据。

对于单节点训练,可以使用 PySpark 方法toPandas()将 Apache Spark 数据帧转换为 pandas 数据帧,然后使用 PySpark 方法to_numpy()选择性地转换为 NumPy 格式。

注释

Spark Connect 将分析和名称解析推迟到执行时间,这可能会更改代码的行为。 请参阅 “将 Spark 连接与 Spark 经典版进行比较”。

Spark Connect 支持大多数 PySpark API,包括 Spark SQL、Spark 上的 Pandas API、结构化流式处理和 MLlib(基于数据帧)。 有关最新支持的 API,请参阅 PySpark API 参考文档

有关其他限制,请参阅 无服务器计算限制

在修饰器中 @distributed 加载数据

使用 无服务器 GPU API 进行分布式训练时,在 @distributed 修饰器内移动数据加载代码。 数据集大小可以超出 pickle 允许的最大大小,因此建议在装饰器内生成数据集,如下所示:

from serverless_gpu import distributed

# this may cause pickle error
dataset = get_dataset(file_path)
@distributed(gpus=8, remote=True)
def run_train():
  # good practice
  dataset = get_dataset(file_path)
  ....

数据加载性能

/Workspace/Volumes 目录托管在远程 Unity 目录存储上。 如果数据集存储在 Unity 目录中,则数据加载速度受可用网络带宽的限制。 如果要训练多个轮次,建议先将数据复制到本地,特别是复制到位于超快速存储(NVMe SSD 磁盘)上的/tmp目录。

如果数据集很大,我们还建议使用以下技术来并行化训练和数据加载:

  • 训练多个时期时,请在读取每个文件之前将数据集更新为在本地 /tmp 目录中缓存文件。 在后续的时期,使用缓存的版本。
  • 通过在 torch DataLoader API 中启用工作线程来并行提取数据。 设置 num_workers 至少为 2。 默认情况下,每个工作线程会预获取两个工作项。 为了提高性能,可以增加num_workers(这将增加并行读取的数量)或prefetch_factor(这将增加每个工作线程预提取的项数)。