本部分介绍有关在无服务器 GPU 计算上加载数据的信息,具体适用于 ML 和 DL 应用程序。 请查看 本教程 ,详细了解如何使用 Spark Python API 加载和转换数据。
加载表格数据
使用 Spark Connect 从 Delta 表加载表格机器学习数据。
对于单节点训练,可以使用 PySpark 方法toPandas()将 Apache Spark 数据帧转换为 pandas 数据帧,然后使用 PySpark 方法to_numpy()选择性地转换为 NumPy 格式。
注释
Spark Connect 将分析和名称解析推迟到执行时间,这可能会更改代码的行为。 请参阅 “将 Spark 连接与 Spark 经典版进行比较”。
Spark Connect 支持大多数 PySpark API,包括 Spark SQL、Spark 上的 Pandas API、结构化流式处理和 MLlib(基于数据帧)。 有关最新支持的 API,请参阅 PySpark API 参考文档 。
有关其他限制,请参阅 无服务器计算限制。
在修饰器中 @distributed 加载数据
使用 无服务器 GPU API 进行分布式训练时,在 @distributed 修饰器内移动数据加载代码。 数据集大小可以超出 pickle 允许的最大大小,因此建议在装饰器内生成数据集,如下所示:
from serverless_gpu import distributed
# this may cause pickle error
dataset = get_dataset(file_path)
@distributed(gpus=8, remote=True)
def run_train():
# good practice
dataset = get_dataset(file_path)
....
数据加载性能
/Workspace 和 /Volumes 目录托管在远程 Unity 目录存储上。 如果数据集存储在 Unity 目录中,则数据加载速度受可用网络带宽的限制。 如果要训练多个轮次,建议先将数据复制到本地,特别是复制到位于超快速存储(NVMe SSD 磁盘)上的/tmp目录。
如果数据集很大,我们还建议使用以下技术来并行化训练和数据加载:
- 训练多个时期时,请在读取每个文件之前将数据集更新为在本地
/tmp目录中缓存文件。 在后续的时期,使用缓存的版本。 - 通过在 torch DataLoader API 中启用工作线程来并行提取数据。 设置
num_workers至少为 2。 默认情况下,每个工作线程会预获取两个工作项。 为了提高性能,可以增加num_workers(这将增加并行读取的数量)或prefetch_factor(这将增加每个工作线程预提取的项数)。