显存使用分析(PyTorch)
我们一直使用 PyTorch 进行模型训练,有时会出现显存不足的情况。除了找到对应的解决办法,比如:累加梯度、使用自动混合精度,还应该了解训练时,显存究竟在哪些环节被大量占用。主要有以下四个环节: 1. CUDA 运行内存 2. 模型的固定参数 3. 模型的前向计算 4. 模型的反向计算 5. 优化方法统计量 1. CUDA 运行内存 {#title-0} ========...
我们一直使用 PyTorch 进行模型训练,有时会出现显存不足的情况。除了找到对应的解决办法,比如:累加梯度、使用自动混合精度,还应该了解训练时,显存究竟在哪些环节被大量占用。主要有以下四个环节: 1. CUDA 运行内存 2. 模型的固定参数 3. 模型的前向计算 4. 模型的反向计算 5. 优化方法统计量 1. CUDA 运行内存 {#title-0} ========...