BatchNorm backward with mode spatial seems incorrect #1109

Closed
mindest opened this issue Aug 23, 2021 · 17 comments

Comments

@mindest

mindest commented Aug 23, 2021

Problem: MiOpenBatchNormBackward seems incorrect in mode spatial

Hi, we are trying to deploy an ONNX Runtime training model on an MI100 cluster. The model includes a BatchNorm layer. We use MIOpen BatchNorm in spatial mode to implement our BatchNorm operator, whose input has shape (N, C). For a two-dimensional input, spatial and per_activation modes should give the same result. However, with spatial mode the loss stops decreasing after a few steps, while switching to per_activation mode resolves it.
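Why the two modes should agree here: in spatial mode the mean and variance are taken per channel over the N, H and W axes, while in per_activation mode they are taken per (C, H, W) element over N only. With H = W = 1 (or a trailing length of 1), both reduce over exactly the same N values. A minimal pure-Python sketch checking this on a small (N, C, 1) input; the helper names are illustrative, not MIOpen or PyTorch API:

```python
import random


def spatial_stats(x):
    """Per-channel mean/variance over the N and L axes of an (N, C, L) nested list."""
    n, c, l = len(x), len(x[0]), len(x[0][0])
    means, variances = [], []
    for ch in range(c):
        vals = [x[i][ch][j] for i in range(n) for j in range(l)]
        mu = sum(vals) / len(vals)
        means.append(mu)
        variances.append(sum((v - mu) ** 2 for v in vals) / len(vals))  # biased, as in BN training
    return means, variances


def per_activation_stats(x):
    """Per-(channel, position) mean/variance over the N axis only."""
    n, c, l = len(x), len(x[0]), len(x[0][0])
    means = [[0.0] * l for _ in range(c)]
    variances = [[0.0] * l for _ in range(c)]
    for ch in range(c):
        for j in range(l):
            vals = [x[i][ch][j] for i in range(n)]
            mu = sum(vals) / n
            means[ch][j] = mu
            variances[ch][j] = sum((v - mu) ** 2 for v in vals) / n
    return means, variances


random.seed(0)
N, C, L = 8, 4, 1  # small stand-in for the (1024, 64, 1) input in the repro
x = [[[random.random() for _ in range(L)] for _ in range(C)] for _ in range(N)]

sm, sv = spatial_stats(x)
pm, pv = per_activation_stats(x)
# With L == 1 every channel has a single position, so both reductions see the same values.
modes_agree = all(abs(sm[ch] - pm[ch][0]) < 1e-12 and abs(sv[ch] - pv[ch][0]) < 1e-12
                  for ch in range(C))
print(modes_agree)
```

With a trailing length greater than 1 the two functions return differently shaped statistics, which is why the modes are only interchangeable when the spatial extent is 1.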

MIOpen version

$ apt list | grep miopen
miopen-hip/now 2.11.0.40200-21 amd64 [installed,local]
miopenkernels-gfx900-56kdb/now 1.1.0.40200-21 amd64 [installed,local]
miopenkernels-gfx900-64kdb/now 1.1.0.40200-21 amd64 [installed,local]
miopenkernels-gfx906-60kdb/now 1.1.0.40200-21 amd64 [installed,local]
miopenkernels-gfx906-64kdb/now 1.1.0.40200-21 amd64 [installed,local]
miopenkernels-gfx908-120kdb/now 1.1.0.40200-21 amd64 [installed,local]

Tested with torch 1.7.0 and 1.8.1; same issue. On CUDA, the results of both modes (spatial/per_activation) are comparable with the CPU results.

I built a small example that reproduces a similar issue (below). Any insights would be appreciated, thanks!

simple repro case
import torch
from torch import nn, optim

N = 1024
C = 64
lr = 1e-2 # BatchNorm forward seems okay (by setting lr = 0.0)
steps = 10000

class TestModel(torch.nn.Module):
    def __init__(self):
        super(TestModel, self).__init__()
        # self.ln1 = nn.Linear(C, C)
        # self.ln1.weight.data.fill_(0.01)
        # self.ln1.bias.data.fill_(0.01)
        self.bn = nn.BatchNorm1d(C)
        # self.ln2 = nn.Linear(C, C)
        # self.ln2.weight.data.fill_(0.02)
        # self.ln2.bias.data.fill_(0.02)

    def forward(self, x):
        # x = self.ln1(x)
        x = self.bn(x)
        # x = self.ln2(x)
        return x

net_cpu = TestModel()
net_spatial = TestModel().cuda()
net_per_act = TestModel().cuda()

criterion = nn.MSELoss()
opt_cpu = optim.Adam(net_cpu.parameters(), lr=lr)
opt_spatial = optim.Adam(net_spatial.parameters(), lr=lr)
opt_per_act = optim.Adam(net_per_act.parameters(), lr=lr)

for i in range(steps):
    # cpu batch norm
    torch.manual_seed(i)
    x_cpu = torch.rand([N, C, 1])
    target_cpu = torch.rand([N, C, 1])
    opt_cpu.zero_grad()
    y_cpu = net_cpu(x_cpu)
    loss_cpu = criterion(y_cpu, target_cpu)
    loss_cpu.backward()
    opt_cpu.step()

    # gpu batch norm, dim size 3, will use mode spatial
    x_spatial = x_cpu.cuda()
    target_spatial = target_cpu.cuda()
    opt_spatial.zero_grad()
    y_spatial = net_spatial(x_spatial)
    loss_spatial = criterion(y_spatial, target_spatial)
    loss_spatial.backward()
    opt_spatial.step()

    # gpu batch norm, dim size 2, will use mode per_activation
    x_per_act = x_spatial.squeeze()
    target_per_act = target_spatial.squeeze()
    opt_per_act.zero_grad()
    y_per_act = net_per_act(x_per_act)
    loss_per_act = criterion(y_per_act, target_per_act)
    loss_per_act.backward()
    opt_per_act.step()

    if i % 50 == 0:
        print(f'==================== step {i} ====================')
        delta_y_cpu_spatial = torch.norm(y_cpu.cuda() - y_spatial).item()
        delta_y_cpu_per_act = torch.norm(y_cpu.cuda().squeeze() - y_per_act).item()
        print('delta_y: cpu vs. gpu(spatial)', delta_y_cpu_spatial, '; cpu vs. gpu(per_act)', delta_y_cpu_per_act)

        delta_bn_weight_cpu_spatial = torch.norm(net_cpu.bn.weight.cuda() - net_spatial.bn.weight).item()
        delta_bn_weight_cpu_per_act = torch.norm(net_cpu.bn.weight.cuda().squeeze() - net_per_act.bn.weight).item()
        print('delta_bn_weight: cpu vs. gpu(spatial)', delta_bn_weight_cpu_spatial,
              '; cpu vs. gpu(per_act)', delta_bn_weight_cpu_per_act)

        delta_bn_bias_cpu_spatial = torch.norm(net_cpu.bn.bias.cuda() - net_spatial.bn.bias).item()
        delta_bn_bias_cpu_per_act = torch.norm(net_cpu.bn.bias.cuda().squeeze() - net_per_act.bn.bias).item()
        print('delta_bn_bias: cpu vs. gpu(spatial)', delta_bn_bias_cpu_spatial,
              '; cpu vs. gpu(per_act)', delta_bn_bias_cpu_per_act)

        delta_loss_cpu_spatial = torch.abs(loss_cpu.cuda() - loss_spatial).item()
        delta_loss_cpu_per_act = torch.abs(loss_cpu.cuda() - loss_per_act).item()
        print('delta_loss: cpu vs. gpu(spatial)', delta_loss_cpu_spatial,
              '; cpu vs. gpu(per_act)', delta_loss_cpu_per_act)
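For cross-checking gradients independently of MIOpen, here is a minimal pure-Python sketch of the standard training-mode BatchNorm backward for one channel (in spatial mode a channel's values are its N·H·W elements; with H = W = 1 that is just the N samples). The function names are hypothetical, and the analytic dx and dgamma are checked against central finite differences rather than against any library:

```python
import math
import random

EPS = 1e-5  # matches the default eps of torch.nn.BatchNorm1d


def bn_forward(x, gamma, beta, eps=EPS):
    """Training-mode BatchNorm for one channel; x holds that channel's values."""
    m = len(x)
    mu = sum(x) / m
    var = sum((v - mu) ** 2 for v in x) / m  # biased variance, as used for normalization
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(v - mu) * inv_std for v in x]
    y = [gamma * h + beta for h in x_hat]
    return y, x_hat, inv_std


def bn_backward(dy, x_hat, inv_std, gamma):
    """Reference (dx, dgamma, dbeta) for one channel."""
    m = len(dy)
    dbeta = sum(dy)
    dgamma = sum(d * h for d, h in zip(dy, x_hat))
    dx = [gamma * inv_std / m * (m * d - dbeta - h * dgamma) for d, h in zip(dy, x_hat)]
    return dx, dgamma, dbeta


def linear_loss(x, gamma, beta, w):
    """sum(w_i * y_i), so that dLoss/dy is exactly w."""
    y, _, _ = bn_forward(x, gamma, beta)
    return sum(wi * yi for wi, yi in zip(w, y))


random.seed(1)
m = 16  # e.g. 16 samples of one channel with H = W = 1
x = [random.random() for _ in range(m)]
w = [random.random() for _ in range(m)]  # plays the role of the upstream gradient dy
gamma, beta = 1.3, 0.2

_, x_hat, inv_std = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(w, x_hat, inv_std, gamma)

# Central finite differences as an independent check of the analytic gradients.
step = 1e-6
max_err = 0.0
for k in range(m):
    x_plus = list(x)
    x_plus[k] += step
    x_minus = list(x)
    x_minus[k] -= step
    numeric = (linear_loss(x_plus, gamma, beta, w) - linear_loss(x_minus, gamma, beta, w)) / (2 * step)
    max_err = max(max_err, abs(numeric - dx[k]))
dgamma_numeric = (linear_loss(x, gamma + step, beta, w) - linear_loss(x, gamma - step, beta, w)) / (2 * step)

print(f"max |dx - numeric| = {max_err:.2e}, |dgamma - numeric| = {abs(dgamma - dgamma_numeric):.2e}")
```

The same per-channel reference can be applied to each of the C channels of a dumped (N, C, 1) tensor to see which of dx, dscale or dbias diverges.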
@muralinr
Contributor

muralinr commented Aug 23, 2021

Hello,
Could you please share the BatchNorm config (N, C, H, W) used here for spatial mode? I see N = 1024 and C = 64. What are H and W?

@muralinr
Contributor

I ran the following BatchNorm forward and backward spatial-mode configs with N = 1024 and C = 64 and don't see any problem.

mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 52 -W 52 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 52 -W 52 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 2.185834 ms
stats: bnormf, 0, 708837888, 708837888, 0, 972.861504, 2.185834
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 52 -W 52 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 52 -W 52 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 708837888, 708837888, 0, 433.094368, 4.910047
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 4.910047 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 104 -W 104 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 104 -W 104 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 2835350016, 2835350016, 0, 457.785312, 18.580872
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 18.580872 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 104 -W 104 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 104 -W 104 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 5.556102 ms
stats: bnormf, 0, 2835350016, 2835350016, 0, 1530.938496, 5.556102
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 14 -W 14 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 14 -W 14 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.298707 ms
stats: bnormf, 0, 51380736, 51380736, 0, 516.031456, 0.298707
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 14 -W 14 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 14 -W 14 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 51380736, 51380736, 0, 407.370848, 0.378383
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.378383 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
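Reading the `stats:` lines: the byte counts and the sixth field in these logs are consistent with the arithmetic below. This is inferred only from the numbers printed in this thread, not from MIOpenDriver's source, so the meaning of the `2 * C` term (presumably the per-channel scale and bias) and of the factor of 3 is an assumption:

```python
def stats_bytes(n, c, h, w, elem_size):
    """Matches the 3rd/4th 'stats:' fields above (inferred): (N*C*H*W + 2*C) elements."""
    return (n * c * h * w + 2 * c) * elem_size


def stats_throughput(n_bytes, elapsed_ms):
    """Matches the 6th 'stats:' field above (inferred): 3 * bytes / elapsed time, in 1e9 units per second."""
    return 3 * n_bytes / (elapsed_ms * 1e6)


# bnorm (fp32), N=1024, C=64, H=W=52, forward: logged 708837888 bytes and 972.861504 at 2.185834 ms
fwd_bytes = stats_bytes(1024, 64, 52, 52, elem_size=4)
print(fwd_bytes, round(stats_throughput(fwd_bytes, 2.185834), 2))  # 708837888 972.86

# bnormfp16, N=1024, C=64, H=W=1: logged 131328 bytes
print(stats_bytes(1024, 64, 1, 1, elem_size=2))  # 131328
```

The small residual against the logged throughput comes from the elapsed time being printed with only six decimals.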

@junliume
Collaborator

Hi @muralinr, have you used or examined @mindest's repro code attached above?

@muralinr
Contributor

Yes, Jun. I am using the same N and C (of NCHW) as in the code. However, H and W are not clear to me, so I asked him to dump the NCHW config.

@junliume
Collaborator

MIOpenDriver may use a different tolerance level, so it's better to reproduce at the application level with the code provided by the user.
@mindest, could you share more details? Email is also fine; my contact is listed on my profile :)
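To illustrate the tolerance point: a pass/fail check of the kind MIOpenDriver prints reduces to comparing some error metric against a threshold, so the same GPU result can pass under a loose tolerance yet fail under a tight one. An error small enough to pass such a check may still matter once it accumulates over thousands of optimizer steps, as in the repro. Minimal sketch; the relative-error metric and both thresholds are hypothetical, since MIOpenDriver's actual tolerances are not shown in this thread:

```python
def max_rel_error(ref, got, floor=1e-12):
    """Largest elementwise |ref - got| / max(|ref|, floor)."""
    return max(abs(r - g) / max(abs(r), floor) for r, g in zip(ref, got))


def verify(ref, got, tol):
    """CPU-vs-GPU style pass/fail; `tol` is a hypothetical threshold."""
    return max_rel_error(ref, got) <= tol


ref = [0.5, -1.25, 2.0, 0.75]     # e.g. CPU reference gradients
got = [0.5, -1.25, 2.0, 0.7505]   # one element off by about 6.7e-4 relative
print(verify(ref, got, tol=1e-3), verify(ref, got, tol=1e-5))  # True False
```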

@mindest
Author

mindest commented Aug 23, 2021

@muralinr @junliume Thanks for the quick response! In the repro code, the input shape is (N, C, 1) for spatial mode. I think I also tried similar checks (forward only; it wasn't clear to me how to trigger a backward check), and they passed.
@muralinr Have you tried the Python code directly? With an input of shape (N, C, 1), should both modes (spatial/per_activation) give the same results?

@muralinr
Contributor

I tried NCHW (1024, 64, 1, 1), and all forward and backward tests pass for both spatial and per_activation modes.

mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.023998 ms
stats: bnormf, 0, 262656, 262656, 0, 32.834736, 0.023998
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 262656, 262656, 0, 25.000570, 0.031518
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.031518 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.018079 ms
stats: bnormf, 0, 131328, 131328, 0, 21.792356, 0.018079
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 131328, 131328, 0, 12.374259, 0.031839
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.031839 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.

mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.183192 ms
stats: bnormf, 0, 262656, 262656, 0, 4.301323, 0.183192
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 262656, 262656, 0, 1.077687, 0.731166
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.731166 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.220790 ms
stats: bnormf, 0, 131328, 131328, 0, 1.784429, 0.220790
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
mnandhim@x1001c4s1b1n0:/miopen/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 131328, 131328, 0, 0.507316, 0.776605
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.776605 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
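For the NCHW (1024, 64, 1, 1) config above, a packed tensor stores each channel's N values C elements apart; they are contiguous only when C = 1, the case later reported to work. A sketch of the packed-NCHW offset arithmetic, assuming a fully packed tensor with no padding; it illustrates the layout only and does not identify the cause of the failure:

```python
def nchw_strides(shape):
    """Element strides of a fully packed NCHW tensor with the given (N, C, H, W) shape."""
    _, c, h, w = shape
    return (c * h * w, h * w, w, 1)


def channel_offsets(shape, ch):
    """Flat element offsets of every value in channel `ch` (the set a spatial-mode reduction covers)."""
    n, _, h, w = shape
    sn, sc, sh, sw = nchw_strides(shape)
    return [i * sn + ch * sc + y * sh + x * sw
            for i in range(n) for y in range(h) for x in range(w)]


shape = (1024, 64, 1, 1)
print(nchw_strides(shape))               # (64, 1, 1, 1)
print(channel_offsets(shape, 3)[:4])     # [3, 67, 131, 195]
print(channel_offsets((4, 1, 1, 1), 0))  # [0, 1, 2, 3]
```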

@mindest
Author

mindest commented Aug 23, 2021

@muralinr Thanks, I will also try a backward check on the machine later. Could you try the Python code I provided above? If it gives the correct answer on your side, I suspect the cause is that I'm using the kernel incorrectly or using an outdated version.

@muralinr
Contributor

I used the same configs as in the Python code. I will ask our customer support team to run this Python model.

@mindest
Author

mindest commented Aug 25, 2021

Thanks @muralinr.

Tried the above forward/backward checks:

./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
./bin/MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0

They all passed.

check log
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.015520 ms
stats: bnormf, 0, 262656, 262656, 0, 50.771136, 0.015520
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 262656, 262656, 0, 35.686956, 0.022080
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.022080 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.013440 ms
stats: bnormf, 0, 131328, 131328, 0, 29.314286, 0.013440
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 1 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 131328, 131328, 0, 17.973724, 0.021920
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.021920 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.168320 ms
stats: bnormf, 0, 262656, 262656, 0, 4.681369, 0.168320
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnorm -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 262656, 262656, 0, 1.101252, 0.715520
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.715520 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 1 -b 0 -i 1 -t 1 -w 0
GPU Kernel Min Time Forward Batch Normalization Elapsed: 0.185920 ms
stats: bnormf, 0, 131328, 131328, 0, 2.119105, 0.185920
Forward train batch norm verification passed on saved mean
Forward train batch norm verification passed on saved inverse variance.
Forward batch norm verification passed on output
Forward Batch Norm Verifies on CPU and GPU.
aiscuser@node-0:~/MIOpen/build$ ./bin/MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
MIOpenDriver bnormfp16 -m 0 -s 1 -n 1024 -c 64 -H 1 -W 1 -V 1 -F 0 -b 1 -i 1 -t 1 -w 0
stats: bnormb, 0, 131328, 131328, 0, 0.485585, 0.811359
GPU Kernel Min Time Backwards Batch Normalization Elapsed: 0.811359 ms
Backwards prop batch norm verification passed on dx.
Backwards prop batch norm verification passed on dscale.
Backwards prop batch norm verification passed on dbias.
Backwards Prop Batch Norm Verifies on CPU and GPU.

@muralinr
Contributor

Thank you @mindest. We were able to run the model, and both forward and backward tests pass with these NCHW (1024, 64, 1, 1) tensor parameters. I experimented with the spatial-mode failures: BatchNorm spatial mode fails for this model with batch size N > 768. I will review the code and work with the teams to resolve this issue.

@mindest
Author

mindest commented Aug 25, 2021

Thanks @muralinr. The number 768 is consistent with my earlier findings: batch size 512 was fine, and a binary search showed that batch sizes up to 768 seemed to work. I dumped some intermediate gradient values and suspected it might be related to tensor strides or something similar (just an assumption).
Also, NCHW (1024, 1, 1, 1) works fine, while cases with C ≥ 2 fail. Hope this helps.
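The bisection over batch size described above can be sketched as follows. `fake_run` is a hypothetical placeholder for a real check (for example, running the repro for a few hundred steps and comparing the loss with the CPU run), and the sketch assumes failures are monotone in N, which the thread does not establish:

```python
def largest_passing(passes, lo, hi):
    """Largest n in [lo, hi] with passes(n) True, assuming passes(lo) is True and
    passes is monotone (True up to some threshold, False above it)."""
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so the loop always makes progress
        if passes(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo


def fake_run(n):
    # Hypothetical stand-in for "train with batch size n and check the loss".
    return n <= 768


print(largest_passing(fake_run, 512, 1024))  # 768
```

Each probe costs one training run, so bisecting 512..1024 needs about nine runs instead of hundreds.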

@junliume
Collaborator

junliume commented Sep 9, 2021

@mindest, has the issue been resolved? If so, please close this issue. Thanks!

@mindest
Author

mindest commented Sep 9, 2021

@junliume I thought Murali was working on the fix (spatial mode fails in some cases)? For my case, I just used per_activation mode as a workaround (H and W are both 1).
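The workaround amounts to choosing the mode from the input shape. A minimal sketch, assuming an (N, C, *spatial) shape; `choose_bn_mode` is a hypothetical helper, not ONNX Runtime or MIOpen code, and its return strings stand in for MIOpen's `miopenBNPerActivation` and `miopenBNSpatial` modes (`-m 0` and `-m 1` in the driver runs above):

```python
from math import prod


def choose_bn_mode(shape):
    """Pick a BatchNorm mode for an (N, C, *spatial) shape.

    When every spatial dim is 1, spatial and per_activation normalize over the
    same N values, so per_activation can be used in place of spatial.
    """
    if len(shape) < 2:
        raise ValueError("expected a shape of at least (N, C)")
    return "per_activation" if prod(shape[2:]) == 1 else "spatial"


print(choose_bn_mode((1024, 64, 1, 1)))    # per_activation
print(choose_bn_mode((1024, 64, 52, 52)))  # spatial
print(choose_bn_mode((1024, 64)))          # per_activation
```

For shapes with H·W > 1 the two modes are not equivalent, so this switch only sidesteps the spatial path for the degenerate case reported here.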

@muralinr
Contributor

muralinr commented Sep 9, 2021

@mindest I released a fix for this issue. Can you please test it and close the issue? Here is the PR: #1109

@mindest
Author

mindest commented Sep 9, 2021

@muralinr Thanks! Did you mean this PR? Closing this.

mindest closed this as completed Sep 9, 2021
@muralinr
Contributor

muralinr commented Sep 9, 2021

yes @mindest
