import os
import math
import numpy
import torch
import torch.nn as nn
from model import *
import argparse
import json
# input_size = 32
# num_units = 128
# num_layers = 1
# bidirectional = True
# recurrent_bias_enabled = True
# lstm1 = nn.LSTM(input_size=input_size,
#                 hidden_size=num_units,
#                 num_layers=num_layers,
#                 batch_first=False,
#                 bidirectional= bidirectional,
#                 bias= recurrent_bias_enabled)
# lstm2 = nn.LSTM(input_size=input_size,
#                 hidden_size=num_units,
#                 num_layers=num_layers + 1,
#                 batch_first=False,
#                 bidirectional= bidirectional,
#                 bias= recurrent_bias_enabled)

# print("LSTM1:")
# for name,params in lstm1.named_parameters():
#     print(f"name:{name},params:{params.shape}")

# print("=============================================")

# print("LSTM2:")
# for name,params in lstm2.named_parameters():
#     print(f"name:{name},params:{params.shape}")

class objdict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` stores ``d["key"] = v`` and
    ``del d.key`` removes ``d["key"]``. A missing key surfaces as
    ``AttributeError`` (never ``KeyError``) from attribute access, so
    ``hasattr`` and ``getattr(obj, name, default)`` behave normally.

    Note: ``__getattr__`` only runs after ordinary attribute lookup fails,
    so a key that shares its name with a ``dict`` method (e.g. ``"items"``)
    must still be read with subscript syntax.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError("No such attribute: " + name) from None

    def __setattr__(self, name, value):
        # Route every attribute assignment into the mapping itself.
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError("No such attribute: " + name) from None
        
# Command-line interface: the only option is the path to a JSON file holding
# the parameters passed to BiLSTM below.
# NOTE(review): this parser defines no resume option, so the "Default ignored
# when resuming." wording in the help text looks inherited from a training
# script — confirm before relying on it.
parser = argparse.ArgumentParser(description='PyTorch BiLSTM Sequential MNIST Example')
parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
args = parser.parse_args()

# Load the parameters. JSON text is UTF-8 by specification (RFC 8259), so
# decode it explicitly instead of relying on the platform's locale encoding,
# which open() uses by default (e.g. GBK/cp936 on Chinese-locale Windows would
# mis-decode or reject non-ASCII content).
with open(args.params, encoding="utf-8") as d:
    trainer_params = json.load(d)
# Wrap in objdict so values can be read as attributes (trainer_params.key)
# as well as by subscript.
trainer_params = objdict(trainer_params)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the model from the loaded parameters and move it onto that device.
model = BiLSTM(trainer_params).to(device)

# Quantize the model. The return value is discarded, so this presumably
# modifies `model` in place. The meaning of ('INT', 8, 0) is defined by
# BiLSTM.quantize in model.py — presumably (number format, bit width, ...);
# TODO confirm.
model.quantize('INT', 8, 0)

