import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
#===============================================================
#     Convert model coordinates ($lon, $lat) to geographical
# coordinates ($Lon, $Lat).
#     ($plon, $plat) is the geographical coordinate of the
# model's north pole.
#===============================================================
def convert(lon, lat, plon, plat):
    """Map model coordinates to geographical coordinates.

    Both coordinate systems are projected stereographically onto the
    complex plane, z = tan((90 - lat) / 2) * exp(i * lon), so the
    geographical North Pole is z = 0.  A Moebius transform then sends the
    model north pole (z = 0) to (plon, plat) and the model south pole
    (z -> infinity) to (plon + 180, plat).

    Parameters
    ----------
    lon, lat : float or array_like
        Model longitude / latitude in degrees (broadcast against each other).
    plon, plat : float
        Geographical longitude / latitude (degrees) of the model north pole.

    Returns
    -------
    Lon, Lat : float or ndarray
        Geographical longitude wrapped to 0..360 degrees and geographical
        latitude in degrees.  NumPy scalars for scalar input, arrays for
        array input.
    """
    d2r = np.pi / 180.0

    # Stereographic positions of the model pole and of the input points.
    zpole = np.tan((90.0 - plat) * d2r * 0.5) * np.exp(1j * (plon * d2r))
    zq = np.tan((90.0 - lat) * d2r * 0.5) * np.exp(1j * (lon * d2r))
    zQ = zpole * (1.0 - zq) / (1.0 + zq)

    Lon = np.angle(zQ) / d2r
    # np.angle yields values in [-180, 180]; shift negatives up by 360.
    # Done arithmetically rather than by masked item assignment so that
    # NumPy scalars (np.float64, which is never `type(...) is float`) work
    # as well as arrays.
    Lon = Lon + 360.0 * (Lon < 0.0)
    Lat = 90.0 - 2.0 * np.arctan(np.abs(zQ)) / d2r

    return Lon, Lat
#===============================================================
#     Get the grid lines of a tripolar grid.
#     The two displaced poles must have the same latitude and be
# symmetric with respect to the line of z = 0, i.e. lie 180 degrees
# apart in longitude.
#     $lonsize is the grid spacing in longitude.
#     $lats_south are the latitudes of the parallels south of the
# poles. $lats_south[0] and $lats_south[-1] must equal -90.0 and
# $plat, respectively.
#     $lons_north are the model-coordinate longitudes of the
# meridians in the area north of the poles.
#===============================================================
def getGridLines_tripolar(
        plon, plat, lonsize, lats_south, lons_north, smoothness=100):
    """Build every grid line of a tripolar grid as (lon, lat) point arrays.

    North of the displaced poles the grid is defined in model coordinates
    (model longitude -90..90) and mapped to geography with ``convert``.
    South of them it is a plain regular longitude/latitude grid.

    Parameters
    ----------
    plon, plat : float
        Geographical longitude / latitude of the model north pole.
    lonsize : float
        Spacing of the southern meridians.  Also used as the step of the
        northern model-latitude parallels.
    lats_south : sequence of float
        Southern parallel latitudes; the first and last entries are the end
        points of every southern meridian, and the interior entries become
        the southern parallels.
    lons_north : iterable of float
        Model longitudes of the northern meridians.
    smoothness : int or float
        Segments per line; each line has ``int(smoothness + 1)`` points.

    Returns
    -------
    (list of ndarray, list of ndarray)
        Longitudes and latitudes of each line, ordered as: northern
        meridians, northern parallels, southern meridians, southern
        parallels.
    """
    npts = int(smoothness + 1)
    line_lons, line_lats = [], []

    def add(lons, lats):
        # Record one polyline.
        line_lons.append(lons)
        line_lats.append(lats)

    # Northern meridians: fixed model longitude, model latitude -90..90.
    model_lats = np.linspace(-90.0, 90.0, npts)
    for mlon in lons_north:
        add(*convert(np.ones(model_lats.shape) * mlon, model_lats, plon, plat))

    # Northern parallels: fixed model latitude, model longitude -90..90.
    # Per ``convert``, along model longitude +/-90 (geographic latitude
    # plat) one model-latitude step equals one geographic-longitude step,
    # so stepping by lonsize lines these up with the southern meridians.
    model_lons = np.linspace(-90.0, 90.0, npts)
    for mlat in np.arange(-90.0, 90.0 + lonsize * 0.1, lonsize):
        add(*convert(model_lons, np.ones(model_lons.shape) * mlat, plon, plat))

    # Southern meridians: geographic longitude fixed, spanning lats_south.
    for glon in np.arange(0.0, 360.0, lonsize):
        glats = np.linspace(lats_south[0], lats_south[-1], npts)
        add(np.ones(glats.shape) * glon, glats)

    # Southern parallels: interior entries of lats_south, full circle.
    for glat in lats_south[1:-1]:
        glons = np.linspace(0.0, 360.0, npts)
        add(glons, np.ones(glons.shape) * glat)

    return line_lons, line_lats
#===============================================================
#
#===============================================================
def sample_tripolar_ortho():
    """Draw a sample tripolar grid on an orthographic Basemap and show it.

    The model north pole is placed at (60 E, ~63.33 N) with 3-degree
    southern meridian spacing.  Line colour comes from the module-level
    ``linecolor``.
    """
    pole_lon = 60.0
    pole_lat = 63.333700283430055
    spacing = 3.0
    south_lats = np.linspace(-90.0, pole_lat, 91)
    north_lons = np.linspace(-90.0, 90.0, 46)
    grid_lons, grid_lats = getGridLines_tripolar(
        pole_lon, pole_lat, spacing, south_lats, north_lons, 100)

    bmap = Basemap(projection='ortho', lon_0=120.0, lat_0=40.0,
                   resolution='i')
    bmap.drawmapboundary(linewidth=0.4, color='k')
    bmap.fillcontinents(color='silver', lake_color='silver')

    # Projected coordinates above this are blanked out; presumably this is
    # how Basemap marks points that are not visible in the ortho view.
    too_far = 1e20
    for line_lon, line_lat in zip(grid_lons, grid_lats):
        xs, ys = bmap(line_lon, line_lat)
        xs[xs > too_far] = np.nan
        ys[ys > too_far] = np.nan
        bmap.plot(xs, ys, linewidth=1, color=linecolor)
    plt.show()
#===============================================================
#
#===============================================================
# Colour of the grid lines; read as a global by sample_tripolar_ortho().
linecolor = '#636363'

if __name__ == '__main__':
    # Draw the demo only when run as a script, not when imported.
    sample_tripolar_ortho()