
import os
import sys
# Run relative to the script's own directory so the relative data paths
# below resolve. abspath() guards against a relative __file__, whose
# dirname can be '' (os.chdir('') raises FileNotFoundError).
os.chdir(os.path.dirname(os.path.abspath(__file__)))
import re
import pandas as pd
from tqdm import tqdm
import openpyxl
from copy import copy
# from joblib import Parallel, delayed

# Input workbook to annotate and the workbook the annotated copy is saved to.
# Both are relative paths, resolved against the current working directory.
input_file_path = 'info/测试大表2.xlsx'
output_file_path = 'output/论文被引用统计-陈老师-截止2025年X月XX日_企业国家匹配.xlsx'

# Reference workbooks: famous companies, institution -> country, and the
# global country/region list.
famous_companies_path = 'info/qiyeguojia/知名企业5.8.xlsx'
institution_country_path = 'info/qiyeguojia/机构国家汇总5.8.xlsx'
global_countries_path = 'info/qiyeguojia/全局国家地区5.8.xlsx'

# Make sure the output directory exists. exist_ok=True avoids the
# check-then-create race; the guard skips os.makedirs('') (which raises)
# should the output path ever have no directory component.
output_dir = os.path.dirname(output_file_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

def process_column_name(col_name):
    """Normalise a header label: drop newlines, then trim surrounding whitespace.

    Values that are not strings (e.g. ``None`` or numbers) are returned
    unchanged.
    """
    if not isinstance(col_name, str):
        return col_name
    without_newlines = col_name.replace('\n', '')
    return without_newlines.strip()

def load_reference_data():
    """Load the company, institution->country and country reference tables.

    Reads the workbooks named by the module-level ``famous_companies_path``,
    ``institution_country_path`` and ``global_countries_path``.

    Returns:
        A 4-tuple ``(company_dict, institution_country_mapping,
        country_index_mapping, english_to_chinese)``:

        * ``company_dict``: company name or alias (stripped) -> the row label
          of the company in the famous-companies sheet plus one.
        * ``institution_country_mapping``: lower-cased text before the first
          comma of each institution name -> the value of its ``国家`` cell.
        * ``country_index_mapping``: ``国家/地区`` value -> 1-based row position.
        * ``english_to_chinese``: ``国家/地区`` value -> ``国家中文名`` value.

    Rows whose key cell is blank are skipped instead of raising, and blank
    company names are never stored (an empty key would match every
    institution in a substring test).
    """
    print("正在加载知名企业列表...")
    famous_companies_df = pd.read_excel(famous_companies_path)
    company_dict = {}
    for idx, row in famous_companies_df.iterrows():
        company_name = row["企业名"]
        if pd.isna(company_name):
            # Blank cell: nothing to register for this row.
            continue

        # The primary name comes first, then any ';'-separated aliases;
        # all of them map to the same index.
        names = [str(company_name)]
        aliases = row.get("别名", "")
        if isinstance(aliases, str):
            names.extend(aliases.split(';'))

        company_index = idx + 1
        for name in names:
            name = name.strip()
            if name:
                company_dict[name] = company_index

    print("正在加载机构国家汇总信息...")
    institution_country_df = pd.read_excel(institution_country_path)
    institution_country_mapping = {}
    for name, country in zip(institution_country_df['引文机构名'], institution_country_df['国家']):
        # Blank cells are read as NaN (a float); skip them rather than
        # crash on .split() or store a meaningless 'nan' country.
        if not isinstance(name, str) or pd.isna(country):
            continue
        institution_country_mapping[name.split(',')[0].strip().lower()] = country

    print("正在加载全局国家地区信息...")
    global_countries_df = pd.read_excel(global_countries_path)
    countries = list(global_countries_df['国家/地区'])
    # Indices stay tied to the sheet row position even when blank rows are
    # skipped, so existing numbering is unchanged.
    country_index_mapping = {
        country: idx + 1
        for idx, country in enumerate(countries)
        if isinstance(country, str)
    }
    english_to_chinese = {
        country: chinese
        for country, chinese in zip(countries, global_countries_df['国家中文名'])
        if isinstance(country, str)
    }

    return company_dict, institution_country_mapping, country_index_mapping, english_to_chinese

_LATIN_LETTER_RE = re.compile(r'[a-zA-Z]')
_INSTITUTION_SPLIT_RE = re.compile(r'[;\n]')
_PART_SPLIT_RE = re.compile(r'[(),]')
# company name -> compiled whole-word regex, or None for names without
# Latin letters (those are matched by plain substring test).
_COMPANY_PATTERN_CACHE = {}


def _company_pattern(company):
    """Return the cached whitespace-delimited regex for a Latin-letter company name, else None."""
    try:
        return _COMPANY_PATTERN_CACHE[company]
    except KeyError:
        pass
    if _LATIN_LETTER_RE.search(company):
        # A name with Latin letters must be delimited by whitespace or the
        # string boundaries, so "Google" does not match "Googleplex".
        pattern = re.compile(r'(^|\s)' + re.escape(company) + r'(\s|$)')
    else:
        pattern = None
    _COMPANY_PATTERN_CACHE[company] = pattern
    return pattern


def _find_company(inst_clean, company_dict):
    """Return ``(name, index_str)`` for the first company (dict order) found in ``inst_clean``, else None."""
    for company, index in company_dict.items():
        if not isinstance(company, str) or not company:
            # An empty key would match every institution via `in`.
            continue
        pattern = _company_pattern(company)
        if pattern is not None:
            found = pattern.search(inst_clean) is not None
        else:
            found = company in inst_clean
        if found:
            return company, str(index)
    return None


def _resolve_country(english_country, english_to_chinese):
    """Return the key of ``english_to_chinese`` for ``english_country``, else None.

    An exact key wins; otherwise the first key (dict order) that contains
    ``english_country`` as a substring is used.
    """
    if english_country in english_to_chinese:
        return english_country
    for country in english_to_chinese:
        if isinstance(country, str) and english_country in country:
            return country
    return None


def _find_country(inst_lower, institution_country_mapping, country_index_mapping, english_to_chinese):
    """Return ``(chinese_name, index_str)`` for the first resolvable part of ``inst_lower``, else None."""
    for part in _PART_SPLIT_RE.split(inst_lower):
        part = part.strip()
        # Drop a leading article so "the university of x" matches "university of x".
        if part.startswith('the '):
            part = part[4:].strip()
        part = part.replace("univ.", "university").replace("corp.", "corporation")

        if part not in institution_country_mapping:
            continue
        english_country = str(institution_country_mapping[part]).strip()
        if not english_country:
            # '' is a substring of every country name; never match on it.
            continue
        country = _resolve_country(english_country, english_to_chinese)
        if country is not None:
            return english_to_chinese[country], str(country_index_mapping[country])
    return None


def match_companies_and_countries(institution, company_dict, institution_country_mapping, country_index_mapping, english_to_chinese):
    """Match each institution in a cell against the company and country tables.

    Args:
        institution: Cell text holding one or more institutions separated by
            ``;`` or newlines. Non-string or blank input yields four empty
            strings.
        company_dict: Company name/alias -> index.
        institution_country_mapping: Lower-cased institution name -> English
            country name.
        country_index_mapping: Country key -> index.
        english_to_chinese: Country key -> Chinese country name.

    Returns:
        ``(company_names, company_indexes, country_names, country_indexes)``,
        each a ``;``-joined string. Per institution the company index is the
        matched index or ``"无"``, and the country name/index are the match or
        ``"?"``. Empty names are dropped from the joined name strings, so
        ``company_names`` can hold fewer entries than ``company_indexes``.

    Company names containing Latin letters must appear as whitespace-delimited
    tokens (case-sensitive); other names match as substrings. Country lookup
    is done on the lower-cased institution, split on ``(``, ``)`` and ``,``.
    An exact country key is preferred over a substring hit, so "Niger" does
    not resolve to "Nigeria". Whitespace-only fragments are ignored.
    """
    if not isinstance(institution, str) or institution.strip() == "":
        return "", "", "", ""

    company_names = []
    company_indexes = []
    country_names = []
    country_indexes = []

    for inst in _INSTITUTION_SPLIT_RE.split(institution):
        inst_clean = inst.strip()
        if not inst_clean:
            # e.g. "A; ;B" or a trailing separator: not an institution.
            continue
        # Normalise full-width commas and pad commas with spaces so that
        # ",Google" still matches as a whitespace-delimited token.
        inst_clean = inst_clean.replace("，", ",").replace(",", " , ").strip()

        company = _find_company(inst_clean, company_dict)
        if company is None:
            company_names.append("")
            company_indexes.append("无")
        else:
            company_names.append(company[0])
            company_indexes.append(company[1])

        country = _find_country(
            inst_clean.lower(),
            institution_country_mapping,
            country_index_mapping,
            english_to_chinese,
        )
        if country is None:
            country_names.append("?")
            country_indexes.append("?")
        else:
            country_names.append(country[0])
            country_indexes.append(country[1])

    # Join the results; empty names are filtered out, indexes are not.
    company_names_str = ";".join([n for n in company_names if n])
    company_indexes_str = ";".join(company_indexes)
    country_names_str = ";".join([n for n in country_names if n])
    country_indexes_str = ";".join(country_indexes)

    return company_names_str, company_indexes_str, country_names_str, country_indexes_str

def process_row(index, row, company_dict, institution_country_mapping, country_index_mapping, english_to_chinese):
    """Run the company/country matching for one data row.

    Reads the row's '引文机构' value (empty string when absent) and passes it
    to ``match_companies_and_countries``.

    Returns:
        ``(index, company_names, company_indexes, country_names,
        country_indexes)``. If anything raises, the error is printed and the
        four result fields are empty strings.
    """
    try:
        matched_names, matched_name_idx, matched_countries, matched_country_idx = (
            match_companies_and_countries(
                row.get('引文机构', ''),
                company_dict,
                institution_country_mapping,
                country_index_mapping,
                english_to_chinese,
            )
        )
    except Exception as e:
        # Best effort: report the failure and leave this row's results blank.
        print(f"处理行 {index} 时发生错误: {e}")
        return index, "", "", "", ""
    return index, matched_names, matched_name_idx, matched_countries, matched_country_idx

def _copy_workbook(wb_original):
    """Return a new workbook holding a cell-by-cell copy of every sheet of ``wb_original``.

    Per sheet this copies the sheet properties, the sheet format, every
    cell's value and (when styled) its font, border, fill, number format,
    protection and alignment, and finally the merged ranges. Only these
    items are copied.
    """
    wb_new = openpyxl.Workbook()
    # Remove the blank sheet a fresh workbook is created with.
    if 'Sheet' in wb_new.sheetnames:
        del wb_new['Sheet']

    for sheet_name in wb_original.sheetnames:
        ws_original = wb_original[sheet_name]
        ws_new = wb_new.create_sheet(sheet_name)

        ws_new.sheet_properties = copy(ws_original.sheet_properties)
        ws_new.sheet_format = copy(ws_original.sheet_format)

        for row in ws_original.rows:
            for cell in row:
                new_cell = ws_new.cell(row=cell.row, column=cell.column, value=cell.value)
                if cell.has_style:
                    new_cell.font = copy(cell.font)
                    new_cell.border = copy(cell.border)
                    new_cell.fill = copy(cell.fill)
                    new_cell.number_format = copy(cell.number_format)
                    new_cell.protection = copy(cell.protection)
                    new_cell.alignment = copy(cell.alignment)

        for merged_cell_range in ws_original.merged_cells.ranges:
            ws_new.merge_cells(str(merged_cell_range))

    return wb_new


def _locate_columns(ws, required_columns, header_row):
    """Map each required column name to its 1-based column in ``ws``.

    Header cells of ``header_row`` are compared after ``process_column_name``
    normalisation; if several cells match, the right-most wins. A name that
    is not found gets a new column appended after the sheet's last column,
    with the name written into ``header_row``.
    """
    wanted = {process_column_name(col): col for col in required_columns}
    column_indices = {}
    for i, cell in enumerate(ws[header_row], start=1):
        col = wanted.get(process_column_name(cell.value))
        if col is not None:
            column_indices[col] = i

    max_col = ws.max_column
    for col in required_columns:
        if col not in column_indices:
            max_col += 1
            column_indices[col] = max_col
            ws.cell(row=header_row, column=max_col, value=col)
    return column_indices


def main():
    """Annotate the input workbook with matched companies and countries.

    Steps:
      1. Load the reference tables.
      2. Read the first sheet of ``input_file_path`` with pandas (header on
         Excel row 4, data from Excel row 8).
      3. Match every row's '引文机构' value via ``process_row``.
      4. Copy the workbook with openpyxl, write the four result columns into
         the first sheet and save to ``output_file_path``.
    """
    header_row = 4       # Excel row holding the column names (1-based)
    first_data_row = 8   # Excel row of the first data record (1-based)
    institution_column = "引文机构"
    # Order matches the result fields returned by process_row (after index).
    output_columns = [
        "知名企业名称（参考知名企业列表）",
        "引文机构在知名企业中的索引",
        "引文机构所属国家",
        "引文机构所属国家索引",
    ]
    required_columns = [institution_column, *output_columns]

    company_dict, institution_country_mapping, country_index_mapping, english_to_chinese = load_reference_data()

    print("正在读取Excel文件...")
    original_header = pd.read_excel(input_file_path, nrows=0, header=header_row - 1)
    column_names = original_header.columns.tolist()

    input_df = pd.read_excel(input_file_path, skiprows=first_data_row - 1, header=None, names=column_names)

    print("表头元素：")
    print(column_names)
    print("\n数据行数：", input_df.shape[0])
    print("数据列数：", input_df.shape[1])

    # The worksheet header is matched after normalisation (newlines and
    # surrounding blanks removed); apply the same rule to the pandas column
    # so a header such as "引文机构\n" is still found.
    if institution_column not in input_df.columns:
        for name in input_df.columns:
            if process_column_name(name) == institution_column:
                input_df = input_df.rename(columns={name: institution_column})
                break
        else:
            print("警告: 输入表中未找到“引文机构”列，匹配结果将全部为空")

    print("开始并行处理数据...")
    # Rows are processed sequentially, in DataFrame order; results[k]
    # therefore belongs to Excel row first_data_row + k.
    results = [
        process_row(
            index, row,
            company_dict, institution_country_mapping, country_index_mapping, english_to_chinese
        )
        for index, row in tqdm(input_df.iterrows(), total=input_df.shape[0])
    ]

    print("正在读取原始Excel文件以保留格式...")
    wb_original = openpyxl.load_workbook(input_file_path)
    wb_new = _copy_workbook(wb_original)

    # Only the first sheet (the one pandas read above) receives results.
    main_sheet_name = wb_original.sheetnames[0]
    ws_new = wb_new[main_sheet_name]

    print("正在更新主工作表中的企业和国家数据...")
    column_indices = _locate_columns(ws_new, required_columns, header_row)

    # Write the results straight from the result tuples. Routing them
    # through per-cell DataFrame assignment would put strings into columns
    # pandas may have typed float64 (all-empty in the sheet), which newer
    # pandas versions warn about or reject.
    for offset, (_, *values) in enumerate(results):
        excel_row = first_data_row + offset
        for col, value in zip(output_columns, values):
            ws_new.cell(row=excel_row, column=column_indices[col], value=value)

    print(f"正在保存为Excel格式 {output_file_path} ...")
    wb_new.save(output_file_path)
    print(f"成功保存到 {output_file_path}，保留了原始格式")

    print("处理完成!")

if __name__ == "__main__":
    main()
