import os
import sys
import copy
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
#===============================================================
#
#===============================================================
def _boundaryPolylines(starts, end, positions, same):
    """Build NaN-broken polylines along one family of interior grid lines.

    Parameters
    ----------
    starts : 1-D float array, length n
        Coordinate, along the line direction, of the leading edge of each of
        the n cells the line passes.
    end : number
        Coordinate of the trailing edge of the last cell.
    positions : 1-D float array, length m
        Fixed (across) coordinate of each of the m grid lines.
    same : bool array, shape (m, n)
        True where the two cells on either side of line j at cell i hold the
        same value, i.e. where no boundary segment should be drawn.

    Returns
    -------
    along, across : 1-D float arrays, length m * (2n + 2)
        For each line the vertices are interleaved as
        ``start_0, inter_0, start_1, inter_1, ..., end, NaN`` where
        ``inter_i`` is ``start_i`` on a boundary and NaN otherwise.
        A NaN vertex breaks the polyline, so the segment ``inter_i ->
        start_{i+1}`` is drawn only on boundaries.
    """
    n = starts.size
    along = np.empty((positions.size, 2 * n + 2))
    along[:, 0:2 * n:2] = starts
    along[:, 1:2 * n:2] = np.where(same, np.nan, starts)
    along[:, 2 * n] = end
    along[:, 2 * n + 1] = np.nan  # separator between consecutive lines

    across = np.empty_like(along)
    across[:, :2 * n + 1] = positions[:, None]
    across[:, 2 * n + 1] = np.nan
    return along.reshape(-1), across.reshape(-1)


def getGridLines(dat, extent=None):
    """Return NaN-separated polylines tracing the edges between differing cells.

    A cell edge is part of a boundary when the two cells sharing it hold
    different values in `dat`. Only interior edges are considered; the outer
    border of the raster is never drawn.

    Parameters
    ----------
    dat : 2-D array, shape (ny, nx)
        Raster (e.g. basin ids), north-up: row 0 lies along `north` and
        column 0 along `west`.
    extent : (west, east, south, north) or None
        Geographic extent of `dat`. When None, the module-level globals
        ``west, east, south, north`` are used.

    Returns
    -------
    meridians : (lons, lats)
        Boundary segments lying on the nx-1 interior meridians.
    parallels : (lons, lats)
        Boundary segments lying on the ny-1 interior parallels.

    Each coordinate array is 1-D and uses NaN breaks. A single plot call
    per tuple therefore draws only the boundary segments.
    """
    if extent is None:
        extent = (west, east, south, north)
    lon_w, lon_e, lat_s, lat_n = extent

    # Take the grid size from the array itself so the geometry always
    # matches the raster being traced.
    ny, nx = np.shape(dat)
    lonsize = float(lon_e - lon_w) / nx
    latsize = float(lat_s - lat_n) / ny  # negative: rows run north -> south

    # Meridian j separates columns j and j+1.
    lats_north = np.linspace(lat_n, lat_s, ny + 1)[:-1]
    same_ew = (dat[:, :-1] == dat[:, 1:]).T  # shape (nx-1, ny)
    lats, lons = _boundaryPolylines(lats_north, lat_s,
                                    lon_w + np.arange(1, nx) * lonsize,
                                    same_ew)
    meridians = (lons, lats)

    # Parallel i separates rows i and i+1.
    lons_west = np.linspace(lon_w, lon_e, nx + 1)[:-1]
    same_ns = dat[:-1, :] == dat[1:, :]  # shape (ny-1, nx)
    lons, lats = _boundaryPolylines(lons_west, lon_e,
                                    lat_n + np.arange(1, ny) * latsize,
                                    same_ns)
    parallels = (lons, lats)

    return meridians, parallels
#===============================================================
#
#===============================================================
def draw(meridians, parallels,
         llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat,
         lonint=5, latint=5, miss=None, figname=None,
         data=None, extent=None, ncycle=30):
    """Plot a window of the raster with boundary polylines overlaid.

    The requested window is snapped outward to whole raster cells and
    clipped to the raster's extent before plotting.

    Parameters
    ----------
    meridians, parallels : (lons, lats)
        NaN-separated polylines, e.g. as returned by getGridLines.
    llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat : float
        Requested lower-left and upper-right corners of the window.
    lonint, latint : float
        Spacing of the labelled graticule lines.
    miss : number or None
        Missing-value sentinel; cells equal to it are masked. None or NaN
        disables masking.
    figname : str or None
        If given, the figure is also saved to this path at 300 dpi.
    data : 2-D array or None
        North-up raster (row 0 along the northern edge). Defaults to the
        module-level global ``dat``.
    extent : (west, east, south, north) or None
        Geographic extent of `data`. Defaults to the module-level globals.
    ncycle : int
        Cell values are reduced modulo `ncycle` before colour mapping.

    Raises
    ------
    ValueError
        If the requested window does not overlap the raster.
    """
    if data is None:
        data = dat
    if extent is None:
        extent = (west, east, south, north)
    lon_w, lon_e, lat_s, lat_n = extent
    nrow, ncol = np.shape(data)
    lonsize = float(lon_e - lon_w) / ncol
    latsize = float(lat_s - lat_n) / nrow  # negative: rows run north -> south

    # Snap the requested corners outward to whole cells, clipped to the raster.
    x0 = int(max(np.floor((llcrnrlon - lon_w) / lonsize), 0))
    x1 = int(min(np.ceil((urcrnrlon - lon_w) / lonsize), ncol))
    y0 = int(max(np.floor((urcrnrlat - lat_n) / latsize), 0))
    y1 = int(min(np.ceil((llcrnrlat - lat_n) / latsize), nrow))
    if x1 <= x0 or y1 <= y0:
        raise ValueError('requested window does not overlap the data extent')
    west_ = lon_w + lonsize * x0
    east_ = lon_w + lonsize * x1
    north_ = lat_n + latsize * y0
    south_ = lat_n + latsize * y1

    window = data[y0:y1, x0:x1]
    if miss is None or np.isnan(miss):
        dat_plt = window % ncycle
    else:
        dat_plt = np.ma.masked_where(window == miss, window % ncycle)

    m = Basemap(llcrnrlon=west_, llcrnrlat=south_,
                urcrnrlon=east_, urcrnrlat=north_)
    m.drawmeridians(np.arange(-180, 181, lonint), linewidth=0.3,
                    color='silver', labels=[0, 0, 0, 1])
    m.drawparallels(np.arange(-90, 91, latint), linewidth=0.3,
                    color='silver', labels=[1, 0, 0, 0])
    # origin='upper' because row 0 of the raster is its northern edge.
    m.imshow(dat_plt, origin='upper', interpolation='nearest',
             cmap=plt.cm.Accent, alpha=0.3)
    m.plot(meridians[0], meridians[1], linewidth=1, color='k')
    m.plot(parallels[0], parallels[1], linewidth=1, color='k')
    if figname is not None:
        plt.savefig(figname, bbox_inches='tight',
                    pad_inches=0.1, dpi=300)
    plt.show()
#===============================================================
#
#===============================================================
#---------------------------------------------------------------
# Prep. data.
#---------------------------------------------------------------
# Missing-value sentinel: it is written into `dat` for cells to hide and is
# passed to draw() as `miss` below, so both uses share this one definition.
MISS = -9999

# Basin (Japan): 3000 x 3000 int32 raster over 120-150E / 20-50N
# (0.01-degree cells).
ny, nx = 3000, 3000
west, east, south, north = 120, 150, 20, 50
f = 'Japan.merit_1k.basin.bin'
dat = np.fromfile(f, dtype=np.int32).reshape(ny,nx)
dat[dat > 2000] = MISS  # values above 2000 are hidden (masked in the plot)

# Basin (global): 1800 x 3600 int32 raster (0.1-degree cells).
#ny, nx = 1800, 3600
#west, east, south, north = -180, 180, -90, 90
#f = 'FLOW_v396.glb_06min.basin.bin'
#dat = np.fromfile(f, dtype=np.int32).reshape(ny,nx)
#dat[dat > 5000] = MISS


# catmxy: two int16 layers on a 300 x 300 grid over 135-140E / 35-40N.
#ny, nx = 300, 300
#west, east, south, north = 135, 140, 35, 40
#f = 'FLOW_v396.glb_06min.1min.catmxy.n35e135.bin'
# Widen to int32 first: the index arithmetic below can overflow int16.
#catmxy = np.fromfile(f, dtype=np.int16).reshape(2,ny,nx).astype(np.int32)
# Basin index
#dat = catmxy[0,:,:] + (catmxy[1,:,:]-1)*nx
# For plotting
#dat = dat%17 + dat*1e-4
#dat[catmxy[0,:,:] <= 0] = MISS
#del catmxy
#---------------------------------------------------------------
# Plot.
#---------------------------------------------------------------
meridians, parallels = getGridLines(dat)
draw(meridians, parallels,
     west, south, east, north, lonint=1, latint=1,
     miss=MISS, figname=None)