Minor improvements in diamonds.py example
[linpy.git] / pypol / polyhedra.py
index a5048d1..aabe0fd 100644 (file)
@@ -291,18 +291,20 @@ class Polyhedron(Domain):
         return faces
 
     def _plot_2d(self, plot=None, **kwargs):
-        from matplotlib import pylab
         import matplotlib.pyplot as plt
-        from matplotlib.axes import Axes
         from matplotlib.patches import Polygon
         vertices = self._sort_polygon_2d(self.vertices())
         xys = [tuple(vertex.values()) for vertex in vertices]
         if plot is None:
             fig = plt.figure()
             plot = fig.add_subplot(1, 1, 1)
-            xs, ys = zip(*xys)
-            plot.set_xlim(float(min(xs)), float(max(xs)))
-            plot.set_ylim(float(min(ys)), float(max(ys)))
+        xmin, xmax = plot.get_xlim()
+        ymin, ymax = plot.get_xlim()
+        xs, ys = zip(*xys)
+        xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
+        ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
+        plot.set_xlim(xmin, xmax)
+        plot.set_ylim(ymin, ymax)
         plot.add_patch(Polygon(xys, closed=True, **kwargs))
         return plot
 
@@ -313,11 +315,11 @@ class Polyhedron(Domain):
         if plot is None:
             fig = plt.figure()
             axes = Axes3D(fig)
-            xmin, xmax = float('inf'), float('-inf')
-            ymin, ymax = float('inf'), float('-inf')
-            zmin, zmax = float('inf'), float('-inf')
         else:
             axes = plot
+        xmin, xmax = axes.get_xlim()
+        ymin, ymax = axes.get_xlim()
+        zmin, zmax = axes.get_xlim()
         poly_xyzs = []
         for vertices in self.faces():
             if len(vertices) == 0:
@@ -325,18 +327,16 @@ class Polyhedron(Domain):
             vertices = Polyhedron._sort_polygon_3d(vertices)
             vertices.append(vertices[0])
             face_xyzs = [tuple(vertex.values()) for vertex in vertices]
-            if plot is None:
-                xs, ys, zs = zip(*face_xyzs)
-                xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
-                ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
-                zmin, zmax = min(zmin, float(min(zs))), max(zmax, float(max(zs)))
+            xs, ys, zs = zip(*face_xyzs)
+            xmin, xmax = min(xmin, float(min(xs))), max(xmax, float(max(xs)))
+            ymin, ymax = min(ymin, float(min(ys))), max(ymax, float(max(ys)))
+            zmin, zmax = min(zmin, float(min(zs))), max(zmax, float(max(zs)))
             poly_xyzs.append(face_xyzs)
         collection = Poly3DCollection(poly_xyzs, **kwargs)
         axes.add_collection3d(collection)
-        if plot is None:
-            axes.set_xlim(xmin, xmax)
-            axes.set_ylim(ymin, ymax)
-            axes.set_zlim(zmin, zmax)
+        axes.set_xlim(xmin, xmax)
+        axes.set_ylim(ymin, ymax)
+        axes.set_zlim(zmin, zmax)
         return axes
 
     def plot(self, plot=None, **kwargs):