PyTorch的take_along_dim如何为?

摘要:接前面一篇take_along_axis的文章,本文主要介绍在PyTorch框架下,功能基本一样的函数take_along_dim。二者除了命名和一些关键词参数不一致之外,用法是一样的。需要注意的是,两者都要求输入的数组和索引数组维度数量一
技术背景 在此前的一篇博客中,我们介绍过take_along_axis这个算子的具体使用方法。这里针对于Pytorch的take_along_dim算子,再重新介绍一次。 Numpy版本使用 这里我们展示的案例是基于numpy-2.0.1版本实现的: $ python3 -m pip show numpy Name: numpy Version: 2.0.1 Summary: Fundamental package for array computing in Python Home-page: https://numpy.org Author: Travis E. Oliphant et al. 示例如下: In [1]: import numpy as np In [2]: a = np.arange(12).reshape((1,4,3)) In [3]: a Out[3]: array([[[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]]) In [4]: idx = np.array([1,2]) In [6]: b = np.take_along_axis(a, idx[None,:,None], axis=1) In [7]: b Out[7]: array([[[3, 4, 5], [6, 7, 8]]]) In [8]: b = np.take_along_axis(a, idx[None,None,:], axis=2) In [9]: b Out[9]: array([[[ 1, 2], [ 4, 5], [ 7, 8], [10, 11]]]) 在这个基础示例中,我们分别展示了同一个索引矩阵,在不同的维度上进行索引的结果。使用take_along_axis有一个默认的要求:原始数组和索引数组的维度数量需要保持一致。但是因为这里的索引矩阵是一维的,那么我们只要用slice的方法对索引矩阵进行扩维就好了。例如,我们需要在第二个维度进行提取,那么就可以用arr[None,:,None]来进行扩维。 PyTorch版实现 这里我们使用的torch是2.5.1的稳定版: $ python3 -m pip show torch Name: torch Version: 2.5.1 Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration Home-page: https://pytorch.org/ Author: PyTorch Team Author-email: packages@pytorch.org License: BSD-3-Clause Location: /miniconda3/envs/pytorch/lib/python3.9/site-packages Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions Required-by: torchaudio, torchmetrics, torchvision 相关的API接口文档如下: 其实实现起来跟numpy的操作是非常类似的: In [1]: import torch as tc In [2]: a = tc.arange(12).reshape((1,4,3)) In [3]: idx = tc.tensor([1,2]) In [4]: b = tc.take_along_dim(a, idx[None,:,None], dim=1) In [5]: b Out[5]: tensor([[[3, 4, 5], [6, 7, 8]]]) In [6]: b = tc.take_along_dim(a, idx[None,None,:], dim=2) In [7]: b Out[7]: tensor([[[ 1, 2], [ 4, 5], [ 7, 8], [10, 11]]]) 可以说是基本一致。那么同样的,也是要做一个扩维的处理。
阅读全文