import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict

def get_model_histogram(model, bins=20):
    """
    Description:
        - build a histogram of every parameter gradient in the model, and keep
          a numpy snapshot of each gradient; both are stored in OrderedDicts
          keyed by parameter name (in ``model.named_parameters()`` order)

    Args:
        - model: (torch.nn.Module), torch model
        - bins: (int or sequence), forwarded to ``np.histogram`` (default 20)

    Returns:
        - gradshisto: OrderedDict name -> {'histogram': [...], 'bins': [...]}
          where 'histogram' holds the bin counts and 'bins' the bin edges
        - grads: OrderedDict name -> numpy array copy of the gradient
        Parameters whose ``.grad`` is None are skipped.
    """
    gradshisto = OrderedDict()
    grads = OrderedDict()
    for name, params in model.named_parameters():
        grad = params.grad
        if grad is None:
            continue
        # ``.cpu()`` is a no-op for CPU tensors and ``.numpy()`` shares memory,
        # so without the clone the stored array would alias the live gradient
        # and be rewritten by later in-place updates (zero_(), accumulation).
        params_np = grad.detach().cpu().clone().numpy()
        histogram, bin_edges = np.histogram(params_np.ravel(), bins=bins)
        gradshisto[name] = {
            'histogram': list(histogram),
            'bins': list(bin_edges),
        }
        grads[name] = params_np

    return gradshisto, grads


def get_model_norm_gradient(model):
    """
    Description:
        - compute the norm of each parameter's gradient and collect the
          results in an OrderedDict keyed by parameter name

    Args:
        - model: (torch.nn.Module), torch model

    Returns:
        - OrderedDict name -> python float (``grad.norm().item()``), in
          ``model.named_parameters()`` order; parameters without a
          gradient are left out
    """
    return OrderedDict(
        (param_name, param.grad.norm().item())
        for param_name, param in model.named_parameters()
        if param.grad is not None
    )


def get_grad_histogram(grads_sum, bins=20):
    """
    Description:
        - build one histogram per entry of a name -> gradient-values mapping

    Args:
        - grads_sum: (Mapping), name -> gradient values; numpy arrays, torch
          tensors (any device, with or without requires_grad) and other
          array-likes are accepted; entries that are None are skipped
        - bins: (int or sequence), forwarded to ``np.histogram`` (default 20)

    Returns:
        - OrderedDict name -> {'histogram': [...], 'bins': [...]} in the
          iteration order of grads_sum, where 'histogram' holds the bin
          counts and 'bins' the bin edges
    """
    gradshisto = OrderedDict()
    for name, grad in grads_sum.items():
        if grad is None:
            continue
        if torch.is_tensor(grad):
            # numpy cannot read GPU tensors or tensors that require grad.
            grad = grad.detach().cpu().numpy()
        values = np.asarray(grad).ravel()
        histogram, bin_edges = np.histogram(values, bins=bins)
        # One histogram per layer: the bin counts plus the bin edges.
        gradshisto[name] = {
            'histogram': list(histogram),
            'bins': list(bin_edges),
        }

    return gradshisto