Skip to content

Commit b879047

Browse files
bottlerfacebook-github-bot
authored andcommitted
work with old linalg
Summary: solve and lstsq have moved around in torch. Cope with both. Reviewed By: patricklabatut Differential Revision: D29302316 fbshipit-source-id: b34f0b923e90a357f20df359635929241eba6e74
1 parent 5284de6 commit b879047

File tree

7 files changed

+65
-14
lines changed

7 files changed

+65
-14
lines changed

pytorch3d/common/compat.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Tuple
8+
9+
import torch
10+
11+
12+
"""
13+
Some functions which depend on PyTorch versions.
14+
"""
15+
16+
17+
def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
18+
"""
19+
Like torch.linalg.solve, tries to return X
20+
such that AX=B, with A square.
21+
"""
22+
if hasattr(torch.linalg, "solve"):
23+
# PyTorch version >= 1.8.0
24+
return torch.linalg.solve(A, B)
25+
26+
return torch.solve(B, A).solution
27+
28+
29+
def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
30+
"""
31+
Like torch.linalg.lstsq, tries to return X
32+
such that AX=B.
33+
"""
34+
if hasattr(torch.linalg, "lstsq"):
35+
# PyTorch version >= 1.9
36+
return torch.linalg.lstsq(A, B).solution
37+
38+
solution = torch.lstsq(B, A).solution
39+
if A.shape[1] < A.shape[0]:
40+
return solution[: A.shape[1]]
41+
return solution
42+
43+
44+
def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover
45+
"""
46+
Like torch.linalg.qr.
47+
"""
48+
if hasattr(torch.linalg, "qr"):
49+
# PyTorch version >= 1.9
50+
return torch.linalg.qr(A)
51+
return torch.qr(A)

pytorch3d/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
)
2727
from .se3 import se3_exp_map, se3_log_map
2828
from .so3 import (
29-
so3_exponential_map,
3029
so3_exp_map,
30+
so3_exponential_map,
3131
so3_log_map,
3232
so3_relative_angle,
3333
so3_rotation_angle,

pytorch3d/transforms/se3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import torch
8+
from pytorch3d.common.compat import solve
89

9-
from .so3 import hat, _so3_exp_map, so3_log_map
10+
from .so3 import _so3_exp_map, hat, so3_log_map
1011

1112

1213
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
@@ -173,7 +174,7 @@ def se3_log_map(
173174
# log_translation is V^-1 @ T
174175
T = transform[:, 3, :3]
175176
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
176-
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
177+
log_translation = solve(V, T[:, :, None])[:, :, 0]
177178

178179
return torch.cat((log_translation, log_rotation), dim=1)
179180

pytorch3d/transforms/so3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ..transforms import acos_linear_extrapolation
1313

14+
1415
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
1516

1617

tests/test_acos_linear_extrapolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
import torch
1212
from common_testing import TestCaseMixin
13+
from pytorch3d.common.compat import lstsq
1314
from pytorch3d.transforms import acos_linear_extrapolation
1415

1516

@@ -64,8 +65,7 @@ def _test_acos_outside_bounds(self, x, y, dydx, bound):
6465
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
6566
# fit a line: slope * x + bias = y
6667
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
67-
solution = torch.linalg.lstsq(x_1, y[:, None]).solution
68-
slope, bias = solution.view(-1)[:2]
68+
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
6969
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2)
7070
# test that the desired slope is the same as the fitted one
7171
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)

tests/test_se3.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,10 @@
1010
import numpy as np
1111
import torch
1212
from common_testing import TestCaseMixin
13+
from pytorch3d.common.compat import qr
1314
from pytorch3d.transforms.rotation_conversions import random_rotations
1415
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
15-
from pytorch3d.transforms.so3 import (
16-
so3_exp_map,
17-
so3_log_map,
18-
so3_rotation_angle,
19-
)
16+
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
2017

2118

2219
class TestSE3(TestCaseMixin, unittest.TestCase):
@@ -201,7 +198,7 @@ def test_se3_log_singularity(self, batch_size: int = 100):
201198
r = [identity, rot180]
202199
r.extend(
203200
[
204-
torch.qr(identity + torch.randn_like(identity) * 1e-6)[0]
201+
qr(identity + torch.randn_like(identity) * 1e-6)[0]
205202
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
206203
# this adds random noise to the second half
207204
# of the random orthogonal matrices to generate

tests/test_so3.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
import torch
1313
from common_testing import TestCaseMixin
14+
from pytorch3d.common.compat import qr
1415
from pytorch3d.transforms.so3 import (
1516
hat,
1617
so3_exp_map,
@@ -46,7 +47,7 @@ def init_rot(batch_size: int = 10):
4647
# TODO(dnovotny): replace with random_rotation from random_rotation.py
4748
rot = []
4849
for _ in range(batch_size):
49-
r = torch.qr(torch.randn((3, 3), device=device))[0]
50+
r = qr(torch.randn((3, 3), device=device))[0]
5051
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
5152
if f.sum() % 2 == 0:
5253
f = 1 - f
@@ -142,7 +143,7 @@ def test_so3_log_singularity(self, batch_size: int = 100):
142143
# add random rotations and random almost orthonormal matrices
143144
r.extend(
144145
[
145-
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
146+
qr(identity + torch.randn_like(identity) * 1e-4)[0]
146147
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
147148
# this adds random noise to the second half
148149
# of the random orthogonal matrices to generate
@@ -242,7 +243,7 @@ def test_so3_cos_bound(self, batch_size: int = 100):
242243
r = [identity, rot180]
243244
r.extend(
244245
[
245-
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
246+
qr(identity + torch.randn_like(identity) * 1e-4)[0]
246247
for _ in range(batch_size - 2)
247248
]
248249
)

0 commit comments

Comments
 (0)