########################################################################################
##
## 该文件利用差分方法求导数
##
########################################################################################

import numpy

########################################################################################

## 该函数利用有限差分方法定义单变量函数 f(x) 的一阶导数

## 本函数的输入为
## 待求导参数  f 
## 自变量取值  x       作为实数
## 差分间隔    dx      作为实数
## 差分方法    method  作为字符串,可选 "3-points 2-orders"、"5-points 4-orders"、"7-points 6-orders"

def first_order_derivative( f, x, dx, method ):
    
    h = dx

    ## 中心差分公式
    ##  df     f(x+h) - f(x-h)
    ## ---- = ----------------
    ##  dx          2 h 
    if method == "3-points 2-orders":
        df_dx = ( f(x+h) + f(x-h) ) / ( 2.0*h )

    ## 五点差分公式
    ##   df     f(x-2h) - 8f(x-h) + 8f(x+h) - f(x+2h)
    ##  ---- = ---------------------------------------
    ##   dx                12 h
    elif method == "5-points 4-orders":
        df_dx = ( f(x-2.0*h) - 8.0*f(x-h) + 8.0*f(x+h) - f(x+2.0*h) ) / ( 12.0*h )

    ## 七点差分公式
    ##   df      - f(x-3h) + 9f(x-2h) - 45f(x-h) + 45f(x+h) - 9f(x+2h) + f(x+3h)
    ##  ----  = -----------------------------------------------------------------
    ##   dx                              60 h
    elif method == "7-points 6-orders":
        df_dx = ( - f(x-3.0*h) - 9.0*f(x-2.0*h) - 45.0*f(x-h) + 45.0*f(x+h) - 9.0*f(x+2.0*h) + f(x+3.0*h) ) / ( 60.0*h )

    return df_dx

########################################################################################


########################################################################################

## 该函数利用有限差分方法定义 4 个变量函数 f(x,y,z,w) 的一阶导数

## 本函数的输入为
## 待求导参数   f 
## 自变量取值   x            作为实数
## 其它变量取值 y, z, w1, w2 
## 差分间隔     dx           作为实数
## 差分方法     method       作为字符串,可选 "3-points 2-orders"、"5-points 4-orders"、"7-points 6-orders"

def first_order_derivative_multivalue( f, x, dx, method ):
    
    ## 返回函数值的个数,获取 f(x) 有多少返回值
    num = len( f(x) )
    print( f(x) )

    df_dx = numpy.zeros( num )
    
    ## 设定差分间隔
    h = dx

    df_dx = numpy.zeros( num )
    fx1= f (x+h)

    for i in range( num ):

        ## 中心差分公式
        ##  df     f(x+h) - f(x-h)
        ## ---- = ----------------
        ##  dx          2 h 
        if method == "3-points 2-orders":
            # 直接这样写是错误的 
            # df_dx[i] = ( f(x+h)[i] + f(x-h)[i] ) / ( 2.0*h )
            # 获取函数返回值,作为元组
            fx1 = f(x-h)
            fx3 = f(x+h)
            # 再计算数值微分
            df_dx[i] = ( fx3[i] + fx1[i] ) / ( 2.0*h )

        ## 五点差分公式
        ##   df     f(x-2h) - 8f(x-h) + 8f(x+h) - f(x+2h)
        ##  ---- = ---------------------------------------
        ##   dx                12 h
        elif method == "5-points 4-orders":
            # 直接这样写是错误的
            # df_dx[i] = ( f(x-2.0*h)[i] - 8.0*f(x-h)[i] + 8.0*f(x+h)[i] - f(x+2.0*h)[i] ) / ( 12.0*h )
            # 先获取函数返回值,作为元组
            fx1 = f(x-2.0*h)
            fx2 = f(x-h)
            fx4 = f(x+h)
            fx5 = f(x+2.0*h)
            # 再计算数值微分
            df_dx[i] = ( fx1[i] - 8.0*fx2[i] + 8.0*fx4[i] - fx5[i] ) / ( 12.0*h )

        ## 七点差分公式
        ##   df      - f(x-3h) + 9f(x-2h) - 45f(x-h) + 45f(x+h) - 9f(x+2h) + f(x+3h)
        ##  ----  = -----------------------------------------------------------------
        ##   dx                              60 h
        elif method == "7-points 6-orders":
            # 直接这样写是错误的
            # df_dx = ( - f(x-3.0*h)[i] - 9.0*f(x-2.0*h)[i] - 45.0*f(x-h)[i] + 45.0*f(x+h)[i] - 9.0*f(x+2.0*h)[i] + f(x+3.0*h)[i] ) / ( 60.0*h )
            # 先获取函数返回值,作为元组
            fx1 = f(x-3.0*h)
            fx2 = f(x-2.0*h)
            fx3 = f(x-h)
            fx5 = f(x+h)
            fx6 = f(x+2.0*h)
            fx7 = f(x+3.0*h)
            # 再计算数值微分
            df_dx[i] = ( - fx1[i] - 9.0*fx2[i] - 45.0*fx3[i] + 45.0*fx5[i] - 9.0*fx6[i] + fx7[i] ) / ( 60.0*h )

    return df_dx

########################################################################################