搜索

pdf文档 深度学习与PyTorch入门实战 - 14. Tensor高阶

501.85 KB 8 页 3 下载 108 浏览 0 评论 0 收藏
语言 格式 评分
中文(简体)
.pdf
3
摘要
文档主要介绍了PyTorch中两个高级张量操作函数:torch.where和torch.gather。torch.where根据条件选择性地从两个张量中选择元素,定义为:给定条件,若满足则选择x的元素,否则选择y的元素。torch.gather则用于沿着指定维度按索引gather值。文档通过示例展示了这两个函数的使用方法和应用场景,包括如何利用argmax和topk函数进行索引操作。
AI总结
### 文档总结:PyTorch 高阶张量操作 #### 1. `torch.where` 函数 - **功能**:根据条件 `condition` 从两个张量 `x` 和 `y` 中选择元素。 - **定义**: $$ out_{i}=\left\{\begin{aligned}&x_{i} & \text{if } condition_{i} \\ &y_{i} & \text{otherwise}\end{aligned}\right. $$ - **示例**: - 条件 `cond` 为一个 2x2 的张量。 - 当 `cond > 0.5` 时,选择 `a`,否则选择 `b`。 - 输出结果为一个 2x2 的张量,根据条件选择相应的元素。 #### 2. `torch.gather` 函数 - **功能**:沿着指定维度 `dim` 收集输入张量 `input` 中的元素。 - **定义**: - 根据索引 `index` 从 `input` 中选择元素。 - 输出结果的形状与 `index` 相同。 - **示例**: - 示例中使用 `topk` 函数获取预测结果的索引 `idx`。 - 使用 `torch.gather` 从扩展后的 `label` 张量中提取对应位置的值。 - 输出结果为一个 4x3 的张量,每个元素对应 `label` 中的索引位置。 #### 核心观点: - `torch.where` 和 `torch.gather` 是 PyTorch 中常用的高阶张量操作。 - `torch.where` 用于根据条件选择元素,`torch.gather` 用于沿着指定维度收集元素。 - 两个函数在深度学习模型中常用于处理预测结果和标签映射等任务。 #### 示例结果: - `torch.where` 示例输出: ```python tensor([[0., 0.], [0., 1.]]) ``` - `torch.gather` 示例输出: ```python tensor([[107, 104, 109], [107, 104, 109], [108, 101, 103], [108, 106, 100]]) ```
P1
P2
P3
P4
P5
P6
P7
下载文档到本地,方便使用
- 可预览页数已用完,剩余 1 页请下载阅读 -
文档评分
请文明评论,理性发言.