Understand torch.scatter_()

Yu Yang
Jul 15, 2020

1. Official Documentation

First, note that scatter_() is an in-place function: it modifies the input tensor (self) directly and also returns it.

The official documentation [1] for scatter_(dim, index, src) → Tensor lists three parameters: dim, the index tensor, and the source tensor src. dim is the dimension along which the values in index are used as positions; the coordinates along all other dimensions are kept unchanged. As the function name suggests, the goal is to scatter the values of src into the input tensor self. Conceptually, we loop over every position of the index tensor, read the src value at that same position, find its target position in self (the same position, except that the coordinate along dim is replaced by the index value), and overwrite the old value there.

Note that src can also be a scalar. In that case, the same value is scattered to every position the index tensor points to. For a 3-D tensor with a tensor src, the documentation writes the update as:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
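The three formulas above are one loop, written out for each value of dim. As a rough sketch (not PyTorch's actual implementation), the loop can be spelled out for any number of dimensions. scatter_reference is a hypothetical name, and the comparison uses index tensors without repeated targets, because PyTorch does not guarantee which value wins when two entries target the same location.

```python
import itertools

import torch


def scatter_reference(self, dim, index, src):
    """Hypothetical, loop-based sketch of self.scatter_(dim, index, src).

    Visits every position `pos` of `index`, replaces the coordinate along
    `dim` with index[pos], and writes src[pos] (or the scalar `src`) there.
    Modifies `self` in place and returns it, like scatter_().
    """
    for pos in itertools.product(*(range(n) for n in index.shape)):
        target = list(pos)
        target[dim] = int(index[pos])  # only the coordinate along `dim` changes
        self[tuple(target)] = src[pos] if torch.is_tensor(src) else src
    return self


# Compare against the built-in on a 3-D example without duplicate targets.
torch.manual_seed(0)
src = torch.arange(24, dtype=torch.float).view(2, 3, 4)
for dim in range(3):
    # argsort of random numbers yields a permutation along `dim`,
    # so every target location is written exactly once
    index = torch.argsort(torch.rand(2, 3, 4), dim=dim)
    expected = torch.zeros(2, 3, 4).scatter_(dim, index, src)
    actual = scatter_reference(torch.zeros(2, 3, 4), dim, index, src)
    print(dim, torch.equal(expected, actual))
```

Each line should print the value of dim followed by True.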

2. Graphical Diagram for dim=0

For simplicity, let us consider two-dimensional matrices here. Let us first understand dim.

When dim=0, the row index comes from the index tensor while the column index stays the same. This means the jth column of the source tensor is only ever scattered into the jth column of the input tensor. Let us update the input tensor by hand, step by step, using the following example.

import torch
import numpy as np
src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
print(src)
# tensor([[ 1.,  2.,  3.,  4.,  5.],
#         [ 6.,  7.,  8.,  9., 10.]])
input_tensor = torch.zeros(3, 5)
print(input_tensor)
# tensor([[0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])
index_tensor = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index_tensor)
# tensor([[0, 1, 2, 0, 0],
#         [2, 0, 0, 1, 2]])
# Try to work out the result by hand before running the next line.
dim = 0
print(input_tensor.scatter_(dim, index_tensor, src))
# ...

Step 1: scatter the 1st column of src into the 1st column of input_tensor. The 1st column of index_tensor is [0, 2], so 1 goes to row 0 and 6 goes to row 2.

Step 2: scatter the 2nd column of src into the 2nd column of input_tensor. The 2nd column of index_tensor is [1, 0], so 2 goes to row 1 and 7 goes to row 0.

Steps 3/4/5: scatter the remaining columns in the same way. In the end, we get the following diagram.

Check it in Python. Well done!

> tensor([[ 1.,  7.,  8.,  4.,  5.],
          [ 0.,  2.,  0.,  9.,  0.],
          [ 6.,  0.,  3.,  0., 10.]])
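The column-by-column procedure from Steps 1 to 5 can also be replayed explicitly in code. This sketch rebuilds the same tensors with torch.arange instead of NumPy (only to keep it self-contained), prints the tensor after each step, and finally compares the manual result with scatter_().

```python
import torch

# Same tensors as in the dim=0 example above.
src = torch.arange(1, 11, dtype=torch.float).view(2, 5)
index_tensor = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
manual = torch.zeros(3, 5)

# dim=0: the column index j never changes; only the row comes from index_tensor.
for j in range(index_tensor.shape[1]):          # step j + 1 handles column j
    for i in range(index_tensor.shape[0]):      # each entry in that column
        row = index_tensor[i, j].item()
        manual[row, j] = src[i, j]
    print(f"after step {j + 1}:\n{manual}")

print(torch.equal(manual, torch.zeros(3, 5).scatter_(0, index_tensor, src)))  # True
```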

Note that when dim=0, the values in the index tensor are row indices of the input tensor, so they must be non-negative and smaller than the number of rows in the input. In general, every index value must lie between 0 and input_tensor.shape[dim] - 1, so the following should be True.

input_tensor.shape[dim] > index_tensor.max().item()
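These shape and value rules can be collected into a small pre-flight check. check_scatter_args is a hypothetical helper written from the constraints listed in the PyTorch documentation; PyTorch performs its own checks and raises its own errors, so this sketch only makes the rules explicit. It assumes a non-negative dim.

```python
import torch


def check_scatter_args(self, dim, index, src=None):
    """Hypothetical helper: raise ValueError if self.scatter_(dim, index, src)
    would violate the shape/value rules stated in the PyTorch docs.
    Assumes a non-negative `dim`."""
    if index.dim() != self.dim():
        raise ValueError("index and self must have the same number of dimensions")
    if torch.is_tensor(src) and src.dim() != self.dim():
        raise ValueError("src and self must have the same number of dimensions")
    for d in range(self.dim()):
        # index may be smaller than self; only along `dim` may it be larger
        if d != dim and index.size(d) > self.size(d):
            raise ValueError(f"index.size({d}) > self.size({d})")
        if torch.is_tensor(src) and index.size(d) > src.size(d):
            raise ValueError(f"index.size({d}) > src.size({d})")
    # every index value must be a valid position along `dim`
    if index.numel() > 0 and (index.min().item() < 0 or index.max().item() >= self.size(dim)):
        raise ValueError(f"index values must lie in [0, {self.size(dim) - 1}]")


src = torch.arange(1, 11, dtype=torch.float).view(2, 5)
index_tensor = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
check_scatter_args(torch.zeros(3, 5), 0, index_tensor, src)  # passes silently

try:
    check_scatter_args(torch.zeros(2, 5), 0, index_tensor, src)  # only 2 rows now
except ValueError as err:
    print(err)  # index values must lie in [0, 1]
```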

3. Graphical Diagram for dim=1

Similarly, we can work it out when dim=1. Let us try the following example.

src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
input_tensor = torch.zeros(3, 5)
index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]])
dim = 1
print(input_tensor.scatter_(dim, index_tensor, src))

Step 1: scatter the 1st row of src to the 1st row of input_tensor. 1 to col3, 2 to col0, 3 to col2, 4 to col1, 5 to col4.

Step 2: scatter the 2nd row of src to the 2nd row of input_tensor.

Note that the 2nd row of index_tensor contains the index 1 twice. To make the update clearer, I split this step into two substeps.

Step 2.1: scatter 6 to col2, 7 to col0, 8 to col1, 9 to col3.

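Step 2.2: scatter 10. Its index is also 1, but 8 is already there, so 8 gets overwritten by 10. Be careful with this case: the PyTorch documentation warns that when indices are not unique, the behavior is nondeterministic (one of the values from src is picked arbitrarily), so this "last write wins" result should not be relied on.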

Done! Let’s check the results. Correct! 😄

> tensor([[ 2.,  4.,  3.,  1.,  5.],
          [ 7., 10.,  6.,  9.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]])
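The repeated index 1 in Step 2 is exactly the case where the outcome is not guaranteed. Below is a minimal sketch for detecting such collisions before calling scatter_(). has_duplicate_targets is a hypothetical helper; it relies on the observation that two entries write to the same location only when they lie in the same slice along dim and hold the same index value.

```python
import torch


def has_duplicate_targets(index, dim):
    """Hypothetical helper: True if two entries of `index` would write to the
    same location of self in scatter_(dim, index, src)."""
    if index.size(dim) < 2:
        return False
    sorted_index, _ = torch.sort(index, dim=dim)
    # after sorting along `dim`, repeated values sit next to each other
    return bool((torch.diff(sorted_index, dim=dim) == 0).any())


print(has_duplicate_targets(torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), dim=0))  # False
print(has_duplicate_targets(torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]]), dim=1))  # True
```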

4. Graphical Diagram for a Trickier Example

Finally, let’s try a trickier example where src is a scalar value and the index tensor is smaller than the input tensor (shape (3, 1) versus (3, 5)).

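Note that the input tensor and the index tensor must always have the same number of dimensions, and this is why you may sometimes see unsqueeze() in others’ code. Their sizes do not have to match, though: the documentation only requires index.size(d) <= input.size(d) for every dimension d other than dim (and, when src is a tensor, index.size(d) <= src.size(d) for every d). In this example, the index tensor has the same size as the input along dimension 0 and is smaller along dim=1.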
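A common place where this unsqueeze() pattern shows up is one-hot encoding: a 1-D tensor of class labels is turned into an (N, 1) index tensor and a scalar 1 is scattered along dim=1. The labels and num_classes below are made-up values for illustration, and torch.nn.functional.one_hot is only used as a cross-check.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([2, 0, 3])  # made-up class labels, shape (3,)
num_classes = 4

# scatter_ needs index.dim() == self.dim(), so turn (3,) into (3, 1)
index = labels.unsqueeze(1)
one_hot = torch.zeros(len(labels), num_classes).scatter_(1, index, 1.0)
print(one_hot)

# Cross-check with the built-in one_hot (which returns an integer tensor)
print(torch.equal(one_hot, F.one_hot(labels, num_classes).float()))  # True
```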

input_tensor = torch.from_numpy(np.arange(1, 16)).float().view(3, 5)  # 2-D input, shape (3, 5)
index_tensor = torch.tensor([4, 0, 1]).unsqueeze(1)  # unsqueeze so the index is 2-D too, shape (3, 1)
src = 0
dim = 1
print(input_tensor.scatter_(dim, index_tensor, src))

Note that when src is a scalar, it behaves as if src were a tensor of the same shape as the index tensor, filled with that scalar.
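To see this concretely, the sketch below runs the same update twice, once with the scalar 0 and once with an explicit src tensor built with torch.full to match index_tensor's shape, and compares the results.

```python
import torch

base = torch.arange(1, 16, dtype=torch.float).view(3, 5)
index_tensor = torch.tensor([4, 0, 1]).unsqueeze(1)  # shape (3, 1)

# scalar src, as in the example above
with_scalar = base.clone().scatter_(1, index_tensor, 0)
# explicit src: a tensor shaped like index_tensor, filled with the same scalar
with_tensor = base.clone().scatter_(1, index_tensor, torch.full(index_tensor.shape, 0.0))
print(torch.equal(with_scalar, with_tensor))  # True
```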

Since dim = 1, we scatter row by row: in the 1st row, 0 goes to col4; in the 2nd row, 0 goes to col0; in the 3rd row, 0 goes to col1.

Checking the result: great job! 🌟

> tensor([[ 1.,  2.,  3.,  4.,  0.],
          [ 0.,  7.,  8.,  9., 10.],
          [11.,  0., 13., 14., 15.]])
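With dim=1 and exactly one index per row, the update reads "in row i, write to column cols[i]", which is also what advanced-indexing assignment does. The sketch below (cols is just the index tensor before unsqueeze) uses that as an independent cross-check; the equivalence is specific to this one-index-per-row shape, not a general replacement for scatter_().

```python
import torch

a = torch.arange(1, 16, dtype=torch.float).view(3, 5)
b = a.clone()
cols = torch.tensor([4, 0, 1])

a.scatter_(1, cols.unsqueeze(1), 0)     # the example above
b[torch.arange(b.size(0)), cols] = 0    # row i, column cols[i]
print(torch.equal(a, b))  # True
```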

I hope this tutorial helps you better understand torch.scatter_()!

References:

[1] https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_
