r"""
Sediment-transport helper functions, based on:
- http://www.leovanrijn-sediment.com/papers/Formulaesandtransport.pdf
- van Rijn, 1993

van Rijn, Ch. 2:
$$u=\frac{u_{*,c}}{\kappa}\ln\left(\frac{z}{z_0}\right)$$
$$\tau_{b,c} = \rho g \frac{u_{avg}^2}{C^2}$$
$$C = 18 \log_{10}\left(\frac{12h}{k_s}\right)$$

van Dijk and Kleinhans, 2005: $k_s=2.5D_{50}$
van Rijn, Ch. 4, p. 21: $k_s=2D_{50}=D_{90}$
"""

from __future__ import print_function, division
import numpy as np
from .linwaves import tau_waves



def D_nondim(d50, s=2.6, v_kin=1.3e-6):
    """
    Dimensionless grain size D* = d50 * ((s - 1) g / nu**2)**(1/3).

    :param d50: median grain size (SI units, g = 9.81 is used)
    :param s: relative sediment density rho_sed / rho_wat
    :param v_kin: kinematic viscosity of water
    :return: D*, same shape as d50
    """
    reduced_gravity = (s - 1) * 9.81
    length_scale_inv = (reduced_gravity / v_kin ** 2) ** (1 / 3)
    return d50 * length_scale_inv


def shields_critical(d50, **kwargs):
    """
    Critical Shields parameter for incipient motion (Soulsby, 1997).

    :param d50: median grain size
    :param kwargs: forwarded to D_nondim (s, v_kin)
    :return: critical Shields parameter
    """
    return _shields_critical(D_nondim(d50, **kwargs))


def _shields_critical(D):
    """soulsby 1997"""
    return 0.3 / (1+1.2*D) + 0.055 * (1-np.exp(-0.02*D))


def shields_critical_susp(d50, **kwargs):
    """
    Critical Shields parameter used for the suspension threshold.

    theta_cr,s = 0.3 / (1 + D*) + 0.1 (1 - exp(-0.05 D*))

    :param d50: median grain size
    :param kwargs: forwarded to D_nondim (s, v_kin)
    :return: critical Shields parameter for suspension
    """
    D = D_nondim(d50, **kwargs)
    fine_grain_part = 0.3 / (1 + D)
    coarse_grain_part = 0.1 * (1 - np.exp(-0.05 * D))
    return fine_grain_part + coarse_grain_part


def u_critical(h, d50, s=2.6, susp=False, **kwargs):
    """
    Critical (threshold) flow velocity for sediment motion.

    A base-10 log law (5.75 ~= ln(10) / 0.4) is applied to the critical
    shear velocity sqrt(theta_cr (s - 1) g d50).

    :param h: water depth
    :param d50: median grain size
    :param s: relative sediment density
    :param susp: if true use the suspension threshold, else incipient motion
    :param kwargs: forwarded to the threshold function (e.g. v_kin)
    :return: critical velocity
    """
    threshold = shields_critical_susp if susp else shields_critical
    Sh = threshold(d50, s=s, **kwargs)
    u_star_cr = (Sh * (s - 1) * 9.81 * d50) ** .5
    return 5.75 * np.log10(2 * h / d50) * u_star_cr


def shearstress(u, h, d50, rho_wat=1020):
    """
    Bed shear stress from flow velocity (van Dijk and Kleinhans, 2005).

    tau = rho_wat * g * (u / C)**2, with Chezy C = 18 log10(12 h / k_s)
    and roughness k_s = 2.5 d50.

    :param u: flow velocity
    :param h: water depth
    :param d50: median grain size
    :param rho_wat: water density; defaults to 1020, the value that was
        previously hard-coded, so existing callers are unaffected
    :return: bed shear stress
    """
    chezy = 18 * np.log10(12 * h / (2.5 * d50))
    return rho_wat * 9.81 * (u / chezy)**2


def shields(u, h, d50, rho_sed=2650, rho_wat=1020):
    """
    Shields parameter tau / ((rho_sed - rho_wat) g d50).

    rho_wat is forwarded to shearstress so that the numerator and the
    denominator use the same water density.

    :param u: flow velocity
    :param h: water depth
    :param d50: median grain size
    :param rho_sed: sediment density
    :param rho_wat: water density
    :return: Shields parameter
    """
    return shearstress(u, h, d50, rho_wat=rho_wat) / ((rho_sed - rho_wat)*9.81*d50)


def fall_velocity(d50, v_kin=1.3e-6, **kwargs):
    """
    Settling velocity of a sediment grain (van Rijn 1984 / Zanke form).

    w_s = 10 nu / d50 * (sqrt(1 + 0.01 D*^3) - 1)

    van Rijn gives this form for sand of roughly 100-1000 um.

    :param d50: median grain size
    :param v_kin: kinematic viscosity of water
    :param kwargs: forwarded to D_nondim (e.g. s)
    :return: settling velocity
    """
    D = D_nondim(d50, v_kin=v_kin, **kwargs)
    # The bracket is NOT squared.  For small D* it expands to 0.005 D*^3, which
    # gives (s-1) g d^2 / (20 nu), i.e. close to Stokes' law; squaring it would
    # instead scale as D*^6 and badly underestimate w_s.
    return (10*v_kin/d50)*(np.sqrt(1+.01*D**3) - 1)


def Rouse(tau_tide, tau_waves, d50, rho_wat=1020, **kwargs):
    """
    Rouse number w_s / (kappa u_*), with kappa = 0.41.

    The shear velocity is u_* = sqrt((tau_tide + tau_waves) / rho_wat), i.e. the
    current and wave shear stresses are simply added.

    :param tau_tide: current (tidal) bed shear stress
    :param tau_waves: wave bed shear stress
    :param d50: median grain size
    :param rho_wat: water density
    :param kwargs: forwarded to fall_velocity (v_kin, s)
    :return: Rouse number
    """
    u_star = np.sqrt((tau_tide + tau_waves) / rho_wat)
    w_s = fall_velocity(d50, **kwargs)
    return w_s / (0.41 * u_star)


def Rouse2(u, hsig, h, d50, T=5):
    """
    Rouse number computed from a current velocity and a wave height.

    The current shear stress comes from shearstress(u, h, d50).  The wave shear
    stress comes from linwaves.tau_waves(hsig, T, h, d50).  Presumably hsig is the
    significant wave height and T the wave period; confirm against linwaves.
    """
    return Rouse(shearstress(u, h, d50), tau_waves(hsig, T, h, d50), d50)


def bed_load_transport(u, h, d50, C=.029, **kwargs):
    """
    Approximate transport as C * u * |u|**2, set to zero below the critical velocity.

    The sign (or, for complex u, the phase) of u is kept, so the result points
    in the flow direction.

    :param u: flow velocity; scalar, list or array (real or complex)
    :param h: water depth; scalar or array broadcastable against u
    :param d50: median grain size
    :param C: transport coefficient
    :param kwargs: forwarded to u_critical (e.g. s, susp, v_kin)
    :return: transport, as an array with the broadcast shape of u and u_critical
    """
    u_cr = u_critical(h, d50, **kwargs)
    # np.where instead of u.copy() plus masked assignment.  It also works for
    # Python scalars and lists, and it broadcasts when h/d50 arrays make u_cr a
    # different shape than u.
    um = np.where(np.absolute(u) < u_cr, 0, u)
    return C*um*np.absolute(um)**2


def bed_load_vRijn(u, h, d50, rho_sed=2650., rho_wat=1020):
    """
    Bed-load transport after van Rijn (2007).

    q_b = 0.015 rho_sed u h (d50 / h)**1.2 Me**1.5, with the mobility parameter
    Me = (|u| - u_cr) / sqrt((s - 1) g d50).

    :param u: velocity (may be complex array)
    :param h: depth
    :param d50: grain size
    :param rho_sed: sediment density
    :param rho_wat: water density
    :return: transport signal; zero where |u| is below the critical velocity
    """
    a = 0.015
    s = rho_sed / rho_wat
    u_cr = u_critical(h, d50, s=s)
    # Clip at zero: below threshold there is no transport.  A negative Me raised
    # to 1.5 would give NaN (numpy floats) or a complex number (Python floats).
    Me = np.maximum(np.absolute(u) - u_cr, 0) / ((s-1)*9.81*d50)**.5
    return a * rho_sed * u * h * (d50/h)**1.2 * Me**1.5


def plot_shields_diagram(ax):
    """
    Draw the Soulsby (1997) critical Shields curve on a log-log axis.

    :param ax: matplotlib Axes to draw on (modified in place)
    """
    D = 10**np.linspace(0.2, 3.2, 100)
    # _shields_critical is built from NumPy ufuncs, so it takes the array directly;
    # np.vectorize would only add a slow Python-level loop.
    ax.loglog(D, _shields_critical(D))
    ax.set_xlim(1, 10000)
    ax.set_ylim(0.01, 1)
    ax.set_xticklabels(['{:.0f}'.format(t) for t in ax.get_xticks()])
    ax.set_yticklabels(['{:.2f}'.format(t) for t in ax.get_yticks()])
    ax.grid(True)
    ax.set_xlabel('$D^*$ (-)')
    # Raw string: '\T' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    ax.set_ylabel(r'$\Theta_{cr}$ (-)')

