#!/usr/bin/python2
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib import cm
import matplotlib.pyplot as plt
# Figure with a single 3-D axes that will hold the atan2 surface.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Value range of atan2; used to normalise the face colours and as z limits.
vmin, vmax = -np.pi, np.pi

# Keyword arguments forwarded to the Poly3DCollection created further down.
plot_args = {'cmap':'YlGnBu_r', 'linewidth': 0.4, 'antialiased': True}

# Look the colormap up through pyplot: matplotlib.cm.get_cmap was deprecated
# in Matplotlib 3.7 and removed in 3.9, whereas pyplot.get_cmap is available
# in both old and current releases.
cmap = plt.get_cmap(plot_args['cmap'])
def modpi(x, ref):
    """Shift *x* by a multiple of 2*pi so that it lies within pi of *ref*.

    The returned angle is congruent to ``x`` modulo ``2*pi`` and, up to
    floating-point rounding, falls in the half-open interval
    ``[ref - pi, ref + pi)``.  Only arithmetic operators are used, so NumPy
    arrays are handled element-wise.

    Args:
        x: Angle in radians.
        ref: Reference angle in radians that the result should be close to.

    Returns:
        The representative of ``x`` (mod ``2*pi``) nearest to ``ref``.
    """
    # Move ref - pi to zero, reduce into [0, 2*pi), then move back.
    return (x + np.pi - ref) % (2*np.pi) - np.pi + ref
# Grid nodes (cell edges) along each axis: 31 nodes spanning [-1, 1].
xa = np.linspace(-1, 1, 31)
ya = np.linspace(-1, 1, 31)

# Cell centres: midpoints between consecutive nodes.
xc = (xa[1:] + xa[:-1]) / 2
yc = (ya[1:] + ya[:-1]) / 2

# Node coordinates as 2-D arrays laid out as [ix, iy], and atan2 at the nodes.
x, y = np.meshgrid(xa, ya, indexing='ij')
z = np.arctan2(y, x)

# atan2 at the cell centres, in the same [ix, iy] layout (via broadcasting).
zc = np.arctan2(yc[np.newaxis, :], xc[:, np.newaxis])
# Face colours: rescale each cell-centre angle from [vmin, vmax] to [0, 1] and
# look it up in the colormap.  flatten() walks zc row by row (ix outer, iy
# inner), which is the order in which the loop below appends the faces.
colors = cmap((zc.flatten() - vmin) / (vmax - vmin))

# One polygon per grid cell.
verts = []
for ix in range(len(xa) - 1):
    for iy in range(len(ya) - 1):
        zm = zc[ix, iy]

        # Corner nodes in cyclic order.  Each corner height is shifted by a
        # multiple of 2*pi towards the cell-centre value zm, so that a cell
        # lying on the branch cut of atan2 does not span a 2*pi jump.
        corners = [(ix, iy), (ix, iy + 1), (ix + 1, iy + 1), (ix + 1, iy)]
        quad = [(xa[i], ya[j], modpi(z[i, j], zm)) for i, j in corners]

        # Build the final polygon from the untouched quad instead of
        # rebinding the list that is being iterated.
        polygon = []
        for k, (px, py, pz) in enumerate(quad):
            # Exact comparison: only a grid node located exactly at (0, 0)
            # is special-cased.
            if (px, py) == (0., 0.):
                # Replace the origin corner by two points at the heights of
                # its cyclic neighbours (quad[k - 1] wraps around for k == 0).
                z_prev = quad[k - 1][2]
                z_next = quad[(k + 1) % len(quad)][2]
                polygon.extend([(px, py, z_prev), (px, py, z_next)])
            else:
                polygon.append((px, py, pz))
        verts.append(polygon)
# Draw every face in a single collection; the colours are given per face.
ax.add_collection3d(Poly3DCollection(verts, facecolors=colors, **plot_args))
ax.view_init(azim=-70, elev=60)

# add_collection3d is not relied on for autoscaling: set the limits by hand.
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(vmin, vmax)

# Same ticks and labels on x and y.  The Axes methods are used directly
# (rather than plt.xticks / plt.yticks) to match the z-axis calls below.
unit_ticks = [-1, -0.5, 0, 0.5, 1]
unit_labels = [r"$-1$", r"$-1/2$", r"$0$", r"$1/2$", r"$1$"]
ax.set_xticks(unit_ticks)
ax.set_xticklabels(unit_labels)
ax.set_yticks(unit_ticks)
ax.set_yticklabels(unit_labels)
ax.set_zticks([-np.pi, 0, np.pi])
ax.set_zticklabels([r"$-\pi$", r"$0$", r"$\pi$"])

# Fully transparent background panes.  ax.xaxis / yaxis / zaxis are used
# because the w_xaxis / w_yaxis / w_zaxis aliases were removed in
# Matplotlib 3.8; the un-prefixed attributes refer to the same 3-D axis
# objects and are also present in older releases.
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))

ax.set_xlabel(r"$x$", fontsize=18)
ax.set_ylabel(r"$y$", fontsize=18)
ax.set_zlabel(r"$\operatorname{atan2}(y, x)$", fontsize=18)

# Write the SVG before opening the interactive window.
plt.savefig("atan2.svg", bbox_inches="tight", transparent=True)
plt.show()