@@ -1465,11 +1465,10 @@ def _find_intermediate_color(lowcolor, highcolor, intermed):
14651465 diff_1 = float (highcolor [1 ] - lowcolor [1 ])
14661466 diff_2 = float (highcolor [2 ] - lowcolor [2 ])
14671467
1468- new_tuple = (lowcolor [0 ] + intermed * diff_0 ,
1469- lowcolor [1 ] + intermed * diff_1 ,
1470- lowcolor [2 ] + intermed * diff_2 )
1471-
1472- return new_tuple
1468+ inter_colors = np .array ([lowcolor [0 ] + intermed * diff_0 ,
1469+ lowcolor [1 ] + intermed * diff_1 ,
1470+ lowcolor [2 ] + intermed * diff_2 ])
1471+ return inter_colors
14731472
14741473 @staticmethod
14751474 def _unconvert_from_RGB_255 (colors ):
@@ -1501,7 +1500,7 @@ def _unconvert_from_RGB_255(colors):
15011500 return un_rgb_colors
15021501
15031502 @staticmethod
1504- def _map_z2color (zval , colormap , vmin , vmax ):
1503+ def _map_z2color (zvals , colormap , vmin , vmax ):
15051504 """
15061505 Returns the color corresponding zval's place between vmin and vmax
15071506
@@ -1518,42 +1517,14 @@ def _map_z2color(zval, colormap, vmin, vmax):
15181517 "of vmax." )
15191518 # find distance t of zval from vmin to vmax where the distance
15201519 # is normalized to be between 0 and 1
1521- t = (zval - vmin )/ float ((vmax - vmin ))
1522-
1523- # for colormaps of more than 2 colors, find two closest colors based
1524- # on relative position between vmin and vmax
1525- if len (colormap ) == 1 :
1526- t_color = colormap [0 ]
1527- else :
1528- num_steps = len (colormap ) - 1
1529- step = 1. / num_steps
1530-
1531- if t == 1.0 :
1532- t_color = FigureFactory ._find_intermediate_color (
1533- colormap [int (t / step ) - 1 ],
1534- colormap [int (t / step )],
1535- t
1536- )
1537- else :
1538- new_t = (t - int (t / step )* step )/ float (step )
1539-
1540- t_color = FigureFactory ._find_intermediate_color (
1541- colormap [int (t / step )],
1542- colormap [int (t / step ) + 1 ],
1543- new_t
1544- )
1545-
1546- t_color = (t_color [0 ]* 255.0 , t_color [1 ]* 255.0 , t_color [2 ]* 255.0 )
1547- labelled_color = 'rgb{}' .format (t_color )
1548-
1549- return labelled_color
1550-
1551- @staticmethod
1552- def _tri_indices (simplices ):
1553- """
1554- Returns a triplet of lists containing simplex coordinates
1555- """
1556- return ([triplet [c ] for triplet in simplices ] for c in range (3 ))
1520+ t = (zvals - vmin ) / float ((vmax - vmin ))
1521+ t_colors = FigureFactory ._find_intermediate_color (colormap [0 ],
1522+ colormap [1 ],
1523+ t )
1524+ t_colors = t_colors * 255.
1525+ labelled_colors = ['rgb(%s, %s, %s)' % (i , j , k )
1526+ for i , j , k in t_colors .T ]
1527+ return labelled_colors
15571528
15581529 @staticmethod
15591530 def _trisurf (x , y , z , simplices , colormap = None , color_func = None ,
@@ -1570,11 +1541,11 @@ def _trisurf(x, y, z, simplices, colormap=None, color_func=None,
15701541 points3D = np .vstack ((x , y , z )).T
15711542
15721543 # vertices of the surface triangles
1573- tri_vertices = list ( map ( lambda index : points3D [index ], simplices ))
1544+ tri_vertices = points3D [simplices ]
15741545
15751546 if not color_func :
15761547 # mean values of z-coordinates of triangle vertices
1577- mean_dists = [ np . mean ( tri [ :, 2 ]) for tri in tri_vertices ]
1548+ mean_dists = tri_vertices [ :, :, 2 ]. mean ( - 1 )
15781549 else :
15791550 # apply user inputted function to calculate
15801551 # custom coloring for triangle vertices
@@ -1590,38 +1561,47 @@ def _trisurf(x, y, z, simplices, colormap=None, color_func=None,
15901561
15911562 min_mean_dists = np .min (mean_dists )
15921563 max_mean_dists = np .max (mean_dists )
1593- facecolor = ([ FigureFactory ._map_z2color (zz , colormap , min_mean_dists ,
1594- max_mean_dists ) for zz in mean_dists ] )
1595- ii , jj , kk = FigureFactory . _tri_indices ( simplices )
1564+ facecolor = FigureFactory ._map_z2color (mean_dists , colormap ,
1565+ min_mean_dists , max_mean_dists )
1566+ ii , jj , kk = zip ( * simplices )
15961567
15971568 triangles = graph_objs .Mesh3d (x = x , y = y , z = z , facecolor = facecolor ,
15981569 i = ii , j = jj , k = kk , name = '' )
15991570
1600- if plot_edges is None : # the triangle sides are not plotted
1571+ if plot_edges is not True : # the triangle sides are not plotted
16011572 return graph_objs .Data ([triangles ])
16021573
16031574 # define the lists x_edge, y_edge and z_edge, of x, y, resp z
16041575 # coordinates of edge end points for each triangle
16051576 # None separates data corresponding to two consecutive triangles
1606- lists_coord = ([[[T [k % 3 ][c ] for k in range (4 )]+ [None ]
1607- for T in tri_vertices ] for c in range (3 )])
1608- if x_edge is None :
1609- x_edge = []
1610- for array in lists_coord [0 ]:
1611- for item in array :
1612- x_edge .append (item )
1613-
1614- if y_edge is None :
1615- y_edge = []
1616- for array in lists_coord [1 ]:
1617- for item in array :
1618- y_edge .append (item )
1619-
1620- if z_edge is None :
1621- z_edge = []
1622- for array in lists_coord [2 ]:
1623- for item in array :
1624- z_edge .append (item )
1577+ is_none = [ii is None for ii in [x_edge , y_edge , z_edge ]]
1578+ if any (is_none ):
1579+ if not all (is_none ):
1580+ raise ValueError ("If any (x_edge, y_edge, z_edge) is None, "
1581+ "all must be None" )
1582+ else :
1583+ x_edge = []
1584+ y_edge = []
1585+ z_edge = []
1586+
1587+ # Pull indices we care about, then add a None column to separate tris
1588+ ixs_triangles = [0 , 1 , 2 , 0 ]
1589+ pull_edges = tri_vertices [:, ixs_triangles , :]
1590+ x_edge_pull = np .hstack ([pull_edges [:, :, 0 ],
1591+ np .tile (None , [pull_edges .shape [0 ], 1 ])])
1592+ y_edge_pull = np .hstack ([pull_edges [:, :, 1 ],
1593+ np .tile (None , [pull_edges .shape [0 ], 1 ])])
1594+ z_edge_pull = np .hstack ([pull_edges [:, :, 2 ],
1595+ np .tile (None , [pull_edges .shape [0 ], 1 ])])
1596+
1597+ # Now unravel the edges into a 1-d vector for plotting
1598+ x_edge = np .hstack ([x_edge , x_edge_pull .reshape ([1 , - 1 ])[0 ]])
1599+ y_edge = np .hstack ([y_edge , y_edge_pull .reshape ([1 , - 1 ])[0 ]])
1600+ z_edge = np .hstack ([z_edge , z_edge_pull .reshape ([1 , - 1 ])[0 ]])
1601+
1602+ if not (len (x_edge ) == len (y_edge ) == len (z_edge )):
1603+ raise exceptions .PlotlyError ("The lengths of x_edge, y_edge and "
1604+ "z_edge are not the same." )
16251605
16261606 # define the lines for plotting
16271607 lines = graph_objs .Scatter3d (
@@ -5865,4 +5845,3 @@ def make_table_annotations(self):
58655845 font = dict (color = font_color ),
58665846 showarrow = False ))
58675847 return annotations
5868-
0 commit comments