#################################################
##
## 这个文件包含了数值相对论所的二进制数据画图
## 小曲
## 2024/10/01 --- 2024/12/06 
##
#################################################

import numpy
import matplotlib.pyplot    as     plt
from   matplotlib.colors    import LogNorm
from   mpl_toolkits.mplot3d import Axes3D
## import torch
import AMSS_NCKU_Input      as input_data

import os


#########################################################################################

def plot_binary_data( filename, binary_outdir, figure_outdir ):

    figure_title0 = filename.replace(binary_outdir + "/", "")  # 去掉路径中的前缀
    figure_title  = figure_title0.replace(".bin", "")          # 去掉路径中的.bin
    
    print(                                        )
    print( " 正在读取二进制文件 = ", figure_title0 )

###################################

    # 打开文件
    # 根据 AMSS-NCKU 输出二进制文件中的数据顺序依次读入数据
    with open(filename, 'rb') as file:

        physical_time = numpy.fromfile( file, dtype=numpy.float64, count=1 )
        nx, ny, nz    = numpy.fromfile( file, dtype=numpy.int32,   count=3 )
        xmin, xmax    = numpy.fromfile( file, dtype=numpy.float64, count=2 )
        ymin, ymax    = numpy.fromfile( file, dtype=numpy.float64, count=2 )
        zmin, zmax    = numpy.fromfile( file, dtype=numpy.float64, count=2 )
        data          = numpy.fromfile( file, dtype=numpy.float64          )
        
        # 现在 data 数组包含了文件中的二进制数据
 
    print( " 读取的数组大小 = ",     data.shape                            ) 
    print( " 读取的数组长度 = ",     data.size                             ) 
    print( " 原始设定的数组长度 = ", nx, "*", ny, "*", nz, " = ", nx*ny*nz )
    
###################################

    # 将读入的数据转化为多维数组
    data_reshape = data.reshape( (nz, ny, nx) ) ## 这样的排列方式画出来才正常
    # print(data_reshape)

    # data1 = data_reshape[0,:,:]
    # print(data1)

    Rmin = [xmin, ymin, zmin] 
    Rmax = [xmax, ymax, zmax]
    N    = [nx, ny, nz]
    print( " 格点坐标最小值 = ", Rmin )
    print( " 格点坐标最大值 = ", Rmax )
    print( " 格点数目       = ", N    )
    
    print(                                )
    print( " 数据读取完成,接下来开始画图 " )
    print(                                )

    # 利用画图函数进行画图
    figure_title0    = filename.replace(binary_outdir + "/", "") # 去掉路径中的前缀
    figure_title     = figure_title.replace(".bin", "")          # 去掉最后的".bin"
    figure_title_new = figure_title[:-6]                         # 再去掉末尾的6个字符,代表的是迭代次数
    
    get_data_xy( Rmin, Rmax, N, data_reshape, physical_time[0], figure_title_new, figure_outdir )
    # 注意 numpy 从二进制文件中读取的 physical_time 是一个数组(尽管实际上只有一个元素)
    # 因此用 physical_time[0] 代表对应的时间值
    
    # 手动删除数据以清除内存
    del data
    del data_reshape
    
    print( " 二进制文件 ", figure_title0," 画图已完成 " )
    print(                                             )

    return
    
    
#########################################################################################




####################################################################################

# 这是一个对某一二进制数据的画图函数

def get_data_xy( Rmin, Rmax, n, data0, time, figure_title, figure_outdir ):

    figure_contourplot_outdir = os.path.join(figure_outdir, "contour plot")
    figure_densityplot_outdir = os.path.join(figure_outdir, "density plot")
    figure_surfaceplot_outdir = os.path.join(figure_outdir, "surface plot")

    # 根据读到的格点信息还原格点坐标
    x = numpy.linspace(Rmin[0], Rmax[0], n[0])
    y = numpy.linspace(Rmin[1], Rmax[1], n[1])
    # z = numpy.linspace(Rmin[2], Rmax[2], n[2])
    # print(x)
    # print(y)

    # 用 meshgrid 建立二维格点坐标
    # X, Y = numpy.meshgrid(x, y)                                # 因为 numpy 中的 meshgrid 函数会将行列互换,很坑啊
    # X, Y = torch.meshgrid(torch.tensor(x), torch.tensor(y))    # 然而 torch 中的 meshgrid 函数不会将行列互换
    Y, X = numpy.meshgrid(y, x)                                  # 因为 numpy 中的 meshgrid 函数会将行列互换
    X0 = numpy.transpose(X)                                      # 必须要取转置之后才能进行下面操作,很坑啊
    Y0 = numpy.transpose(Y)
    # print(X0.shape)
    # print(Y0.shape)
    # print(X0[:,0])
    # print(Y0[0,:])

    # 获取 xy 平面上的数据
    if input_data.Symmetry == "no-symmetry":
        data_xy = data0[n[0]//2,:,:]
    else:
        data_xy = data0[0,:,:]

    # 下面画出二维等高线图
    fig, ax = plt.subplots()
    # contourf = ax.contourf(X, Y, data_xy, 8, cmap='coolwarm', norm=LogNorm(vmin=1, vmax=10), levels=numpy.logspace(-2, 2, 8))  # 使用'coolwarm'色板,并设置标准色彩映射
    contourf = ax.contourf( X0, Y0, data_xy, cmap=plt.get_cmap('RdYlGn_r') )
    contour  = ax.contour(  X0, Y0, data_xy, 8, colors='k', linewidths=0.5 )   # 添加等高线
    cbar     = plt.colorbar(contourf)                                          # 添加色条
    ax.set_title(  figure_title + "  physical time = " + str(time) )           # 设置标题和轴标签
    ax.set_xlabel( "X axis" )
    ax.set_ylabel( "Y axis" )
    # plt.show()                                                               # 显示图像
    plt.savefig( os.path.join(figure_contourplot_outdir, figure_title + " time = " + str(time) + " contour_plot.pdf") )   # 保存图像
    plt.close()
    
    # 下面画出二维热图    
    # fig1 = plt.figure()
    fig1, ax  = plt.subplots()
    imshowfig = plt.imshow( data_xy, interpolation='nearest', extent=[X.min(), X.max(), Y.min(), Y.max()] )
    cbar      = plt.colorbar(imshowfig)                                     # 添加色条
    ax.set_title(  figure_title + "  physical time = " + str(time)  )       # 设置标题和轴标签
    ax.set_xlabel( "X axis" )
    ax.set_ylabel( "Y axis" )
    # plt.show() 
    plt.savefig( os.path.join(figure_densityplot_outdir, figure_title + " time = " + str(time) + " density_plot.pdf") )
    plt.close()

    # 下面画出三维图
    fig2 = plt.figure()                                                       # 创建一个新的图像
    ax = fig2.add_subplot( 111, projection='3d' )                             # 创建一个 3D 绘图区域
    ax.plot_surface( X0, Y0, data_xy, cmap='viridis' )                        # 绘制曲面
    ax.set_title(  figure_title + "  physical time = " + str(time) )          # 设置标题和轴标签
    ax.set_xlabel( "X axis" )
    ax.set_ylabel( "Y axis" )
    # plt.show()                                                              # 显示图像
    plt.savefig( os.path.join(figure_surfaceplot_outdir, figure_title + " time = " + str(time) + " surface_plot.pdf") )   # 保存图像
    plt.close()

    return

####################################################################################