51工具盒子

依楼听风雨
笑看云卷云舒,淡观潮起潮落

PyTorch 初始化 API

网络参数初始化的优劣在极大程度上决定了网络的最终性能。比较推荐的初始化方式有 He 初始化是,将参数初始化为服从高斯分布或均匀分布的较小随机整数,同时对参数方差加以规范化。

处于稳定状态下的神经网络,其参数和数据均值为 0。

PyTorch 中实现了较多的参数初始化 API, 主要如下:

  1. 均匀分布初始化
  2. 正态分布初始化
  3. 全0初始化
  4. 全1初始化
  5. 固定值初始化
  6. kaiming 初始化
  7. xavier 初始化

示例代码:

import torch
import torch.nn.functional as F
import torch.nn as nn


# 1. 均匀分布随机初始化
def test01():

    linear = nn.Linear(5, 3)
    # 从0-1均匀分布产生参数
    nn.init.uniform_(linear.weight)
    print(linear.weight.data)


# 2. 固定初始化
def test02():

    linear = nn.Linear(5, 3)
    nn.init.constant_(linear.weight, 5)
    print(linear.weight.data)


# 3. 全0初始化
def test03():

    linear = nn.Linear(5, 3)
    nn.init.zeros_(linear.weight)
    print(linear.weight.data)


# 4. 全1初始化
def test04():

    linear = nn.Linear(5, 3)
    nn.init.ones_(linear.weight)
    print(linear.weight.data)


# 5. 正态分布随机初始化
def test05():

    linear = nn.Linear(5, 3)
    nn.init.normal_(linear.weight, mean=0, std=1)
    print(linear.weight.data)


# 6. kaiming 初始化
def test06():

    # kaiming 正态分布初始化
    linear = nn.Linear(5, 3)
    nn.init.kaiming_normal_(linear.weight)
    print(linear.weight.data)

    # kaiming 均匀分布初始化
    linear = nn.Linear(5, 3)
    nn.init.kaiming_uniform_(linear.weight)
    print(linear.weight.data)


# 7. xavier 初始化
def test07():

    # xavier 正态分布初始化
    linear = nn.Linear(5, 3)
    nn.init.xavier_normal_(linear.weight)
    print(linear.weight.data)

    # xavier 均匀分布初始化
    linear = nn.Linear(5, 3)
    nn.init.xavier_uniform_(linear.weight)
    print(linear.weight.data)


if __name__ == '__main__':
    test07()

赞(2)
未经允许不得转载:工具盒子 » PyTorch 初始化 API