8
8
9
9
"""
10
10
11
-
11
+ import warnings
12
12
from collections import defaultdict
13
13
from functools import reduce
14
- import warnings
15
14
15
+ import matplotlib .pyplot as plt
16
16
import numpy as np
17
17
from scipy .sparse import SparseEfficiencyWarning , lil_matrix
18
18
from scipy .sparse .linalg import spsolve
19
- import matplotlib .pyplot as plt
20
-
21
19
22
20
warnings .simplefilter ("ignore" , SparseEfficiencyWarning )
23
21
warnings .filterwarnings ("ignore" , category = SparseEfficiencyWarning )
@@ -44,6 +42,7 @@ class _Chi2GradientHessian:
44
42
The contributions to the Hessian matrix
45
43
46
44
"""
45
+
47
46
def __init__ (self , dim ):
48
47
self .chi2 = 0.
49
48
self .dim = dim
@@ -59,7 +58,6 @@ def update(chi2_grad_hess, incoming):
59
58
chi2_grad_hess : _Chi2GradientHessian
60
59
The ``_Chi2GradientHessian`` that will be updated
61
60
incoming : tuple
62
- TODO
63
61
64
62
"""
65
63
chi2_grad_hess .chi2 += incoming [0 ]
@@ -100,6 +98,7 @@ class Graph(object):
100
98
A list of the vertices in the graph
101
99
102
100
"""
101
+
103
102
def __init__ (self , edges , vertices ):
104
103
# The vertices and edges lists
105
104
self ._edges = edges
@@ -117,14 +116,16 @@ def _link_edges(self):
117
116
118
117
"""
119
118
index_id_dict = {i : v .id for i , v in enumerate (self ._vertices )}
120
- id_index_dict = {v_id : v_index for v_index , v_id in index_id_dict .items ()}
119
+ id_index_dict = {v_id : v_index for v_index , v_id in
120
+ index_id_dict .items ()}
121
121
122
122
# Fill in the vertices' `index` attribute
123
123
for v in self ._vertices :
124
124
v .index = id_index_dict [v .id ]
125
125
126
126
for e in self ._edges :
127
- e .vertices = [self ._vertices [id_index_dict [v_id ]] for v_id in e .vertex_ids ]
127
+ e .vertices = [self ._vertices [id_index_dict [v_id ]] for v_id in
128
+ e .vertex_ids ]
128
129
129
130
def calc_chi2 (self ):
130
131
r"""Calculate the :math:`\chi^2` error for the ``Graph``.
@@ -144,22 +145,34 @@ def _calc_chi2_gradient_hessian(self):
144
145
"""
145
146
n = len (self ._vertices )
146
147
dim = len (self ._vertices [0 ].pose .to_compact ())
147
- chi2_gradient_hessian = reduce (_Chi2GradientHessian .update , (e .calc_chi2_gradient_hessian () for e in self ._edges ), _Chi2GradientHessian (dim ))
148
+ chi2_gradient_hessian = reduce (_Chi2GradientHessian .update ,
149
+ (e .calc_chi2_gradient_hessian ()
150
+ for e in self ._edges ),
151
+ _Chi2GradientHessian (dim ))
148
152
149
153
self ._chi2 = chi2_gradient_hessian .chi2
150
154
151
155
# Fill in the gradient vector
152
- self ._gradient = np .zeros (n * dim , dtype = np . float64 )
153
- for idx , contrib in chi2_gradient_hessian .gradient .items ():
154
- self ._gradient [idx * dim : (idx + 1 ) * dim ] += contrib
156
+ self ._gradient = np .zeros (n * dim , dtype = float )
157
+ for idx , cont in chi2_gradient_hessian .gradient .items ():
158
+ self ._gradient [idx * dim : (idx + 1 ) * dim ] += cont
155
159
156
160
# Fill in the Hessian matrix
157
- self ._hessian = lil_matrix ((n * dim , n * dim ), dtype = np .float64 )
158
- for (row_idx , col_idx ), contrib in chi2_gradient_hessian .hessian .items ():
159
- self ._hessian [row_idx * dim : (row_idx + 1 ) * dim , col_idx * dim : (col_idx + 1 ) * dim ] = contrib
161
+ self ._hessian = lil_matrix ((n * dim , n * dim ), dtype = float )
162
+ for (row_idx , col_idx ), cont in chi2_gradient_hessian .hessian .items ():
163
+ x_start = row_idx * dim
164
+ x_end = (row_idx + 1 ) * dim
165
+ y_start = col_idx * dim
166
+ y_end = (col_idx + 1 ) * dim
167
+ self ._hessian [x_start :x_end , y_start :y_end ] = cont
160
168
161
169
if row_idx != col_idx :
162
- self ._hessian [col_idx * dim : (col_idx + 1 ) * dim , row_idx * dim : (row_idx + 1 ) * dim ] = np .transpose (contrib )
170
+ x_start = col_idx * dim
171
+ x_end = (col_idx + 1 ) * dim
172
+ y_start = row_idx * dim
173
+ y_end = (row_idx + 1 ) * dim
174
+ self ._hessian [x_start :x_end , y_start :y_end ] = \
175
+ np .transpose (cont )
163
176
164
177
def optimize (self , tol = 1e-4 , max_iter = 20 , fix_first_pose = True ):
165
178
r"""Optimize the :math:`\chi^2` error for the ``Graph``.
@@ -189,8 +202,10 @@ def optimize(self, tol=1e-4, max_iter=20, fix_first_pose=True):
189
202
190
203
# Check for convergence (from the previous iteration); this avoids having to calculate chi^2 twice
191
204
if i > 0 :
192
- rel_diff = (chi2_prev - self ._chi2 ) / (chi2_prev + np .finfo (float ).eps )
193
- print ("{:9d} {:20.4f} {:18.6f}" .format (i , self ._chi2 , - rel_diff ))
205
+ rel_diff = (chi2_prev - self ._chi2 ) / (
206
+ chi2_prev + np .finfo (float ).eps )
207
+ print (
208
+ "{:9d} {:20.4f} {:18.6f}" .format (i , self ._chi2 , - rel_diff ))
194
209
if self ._chi2 < chi2_prev and rel_diff < tol :
195
210
return
196
211
else :
@@ -207,7 +222,7 @@ def optimize(self, tol=1e-4, max_iter=20, fix_first_pose=True):
207
222
self ._gradient [:dim ] = 0.
208
223
209
224
# Solve for the updates
210
- dx = spsolve (self ._hessian , - self ._gradient ) # pylint: disable=invalid-unary-operand-type
225
+ dx = spsolve (self ._hessian , - self ._gradient )
211
226
212
227
# Apply the updates
213
228
for v , dx_i in zip (self ._vertices , np .split (dx , n )):
@@ -216,7 +231,8 @@ def optimize(self, tol=1e-4, max_iter=20, fix_first_pose=True):
216
231
# If we reached the maximum number of iterations, print out the final iteration's results
217
232
self .calc_chi2 ()
218
233
rel_diff = (chi2_prev - self ._chi2 ) / (chi2_prev + np .finfo (float ).eps )
219
- print ("{:9d} {:20.4f} {:18.6f}" .format (max_iter , self ._chi2 , - rel_diff ))
234
+ print ("{:9d} {:20.4f} {:18.6f}" .format (
235
+ max_iter , self ._chi2 , - rel_diff ))
220
236
221
237
def to_g2o (self , outfile ):
222
238
"""Save the graph in .g2o format.
@@ -234,7 +250,8 @@ def to_g2o(self, outfile):
234
250
for e in self ._edges :
235
251
f .write (e .to_g2o ())
236
252
237
- def plot (self , vertex_color = 'r' , vertex_marker = 'o' , vertex_markersize = 3 , edge_color = 'b' , title = None ):
253
+ def plot (self , vertex_color = 'r' , vertex_marker = 'o' , vertex_markersize = 3 ,
254
+ edge_color = 'b' , title = None ):
238
255
"""Plot the graph.
239
256
240
257
Parameters
0 commit comments