深度学习与PyTorch入门实战 - 14. Tensor高阶
501.85 KB
8 页
0 评论
语言 | 格式 | 评分 |
---|---|---|
中文(简体) | .pdf | 3 |
摘要 | ||
文档介绍了PyTorch框架中的Tensor高阶操作,详细讲解了torch.where和torch.gather的功能及其使用方法。通过示例代码展示了如何使用这些操作来实现特定的张量数据处理任务。torch.where根据条件选择元素值,而torch.gather则用于沿指定维度聚集张量中的元素。 | ||
AI总结 | ||
本文主要介绍了PyTorch中Tensor的高阶操作,重点围绕`torch.where`和`torch.gather`两种操作展开,以下是核心内容的总结:
### 1. `torch.where` 操作
- **功能**:根据条件选择元素,返回一个新的Tensor。其语法为 `torch.where(condition, x, y)`,其中:
- `condition` 是一个布尔型Tensor,表示选择的条件。
- `x` 和 `y` 是要选取的元素来源。
- 如果 `condition` 为True,选择 `x` 的对应元素;否则选择 `y` 的对应元素。
- **应用场景**:在生成模型中,条件选择可以用于处理标签生成问题,例如根据条件获取arginmax的相对标签。
### 2. `torch.gather` 操作
- **功能**:在指定的维度上,根据索引Tensor采集数据。其语法为 `torch.gather(input, dim, index)`,其中:
- `input` 是要采集的数据Tensor。
- `dim` 是采集的维度。
- `index` 是索引Tensor,表示采集的位置。
- **应用场景**:常用于根据特定索引提取Tensor中的特定元素。例如,在处理标签或概率分布时,可以通过`topk`获取索引,再用`gather`提取对应的数据。
### 示例代码与输出
- `prob = torch.randn(4, 10)`:生成一个形状为(4,10)的随机Tensor。
- `idx = prob.topk(dim=1, k=3)`:获取`prob`在维度1上的前3大值的索引。
- `label = torch.arange(10)`:生成一个从0到9的标签Tensor。
- `torch.gather(label.expand(4, 10), dim=1, index=idx.long())`:根据索引`idx`从`label`中提取对应的元素。
### 总结
`torch.where` 和 `torch.gather` 是PyTorch中常用的高阶Tensor操作,分别用于条件选择和索引采集。它们在深度学习中具有重要应用,特别是在处理复杂的数据操作和条件选择场景时,能够显著提升代码的简洁性和效率。通过示例代码可以清晰地看到它们的实际使用效果。 |
P1
P2
P3
P4
P5
P6
P7
下载文档到本地,方便使用
- 可预览页数已用完,剩余
1 页请下载阅读 -
文档评分