from pydantic import BaseModel
from openai import OpenAI
from pathlib import Path
import pypdf
import json
import os

import logging

# Configure the root logger once at import time: records at INFO and above are
# emitted, each formatted as "<timestamp> - <level name> - <message>".
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Module-level OpenAI client used by get_authors(). The API key is looked up in
# the OPENAI_API_KEY environment variable; os.environ.get yields None when the
# variable is not set, and that None is what gets passed to the constructor.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)


# One author entry of the structured extraction result (an element of
# Metadata.authors).
class Author(BaseModel):
    # Author's name; system_prompt asks the model to use "" when a name cannot
    # be clearly identified.
    name: str
    # Institutions/affiliations for this author; system_prompt asks for an
    # empty list when none are listed.
    affiliations: list[str]


# Top-level structured result for one document. get_authors() passes this
# class as text_format so the response is parsed into an instance of it.
class Metadata(BaseModel):
    # Document title; system_prompt asks for "" when ambiguous or missing.
    title: str
    # Authors of the document; system_prompt asks for an empty list when no
    # authors can be identified.
    authors: list[Author]


# System message sent with every request in get_authors(). It instructs the
# model to extract the title, the authors, and each author's affiliations from
# first-page text, and spells out the placeholder values for missing data
# ("" for strings, [] for lists).
# NOTE(review): the prompt text contains typos ("afflication", "other wise").
# They are left untouched here because this string is sent to the model
# verbatim, so editing it changes runtime behavior; fix deliberately and
# re-check extraction quality.
system_prompt = """
Act as an expert metadata extraction assistant.
Analyze the following text, which is extracted from the first page of a document (likely a scientific paper or report).
Your goal is to extract the document title, all authors, and their corresponding affiliations.

Extraction Guidelines:
-   **Title:** Extract the main title of the document. If ambiguous or missing, use "".
-   **Authors:**
    -   Identify all listed authors. Maintain the order presented in the text if possible.
    -   For each author:
        -   Extract their full name as accurately as possible. Use "" if a name cannot be clearly identified for an entry.
        -   Extract all associated institutions/affiliations mentioned for that specific author.
        -   If an author has no listed institution, use an empty list `[]`.
        -   If there are many authors and only one afflication, these authors all come from the same afflication. other wise find the corresponding afflication by indicator.
-   **Handling Missing Data:** If no authors can be identified in the text, the "authors" field in the JSON should be an empty list `[]`.
"""


def get_authors(content):
    """Request structured title/author metadata for a piece of document text.

    ``content`` is sent as the user message, preceded by ``system_prompt`` as
    the system message, to ``client.responses.parse`` with the ``gpt-4o``
    model and ``Metadata`` as the target format.

    Returns:
        The ``output_parsed`` attribute of the response.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": content},
    ]
    parsed_response = client.responses.parse(
        model="gpt-4o",
        input=messages,
        text_format=Metadata,
    )
    return parsed_response.output_parsed


def extract_first_page_text(pdf_path):
    """Return the whitespace-normalized text of a PDF's first page.

    Args:
        pdf_path: Location of the PDF, handed to ``pypdf.PdfReader``. It must
            expose a ``.name`` attribute (e.g. a ``pathlib.Path``) because
            the log messages use it.

    Returns:
        The first page's text with every run of whitespace collapsed to a
        single space, or ``None`` when the PDF has no pages, the first page
        yields no non-whitespace text, or the file cannot be read. Each of
        those cases is logged before ``None`` is returned.
    """
    try:
        pages = pypdf.PdfReader(pdf_path).pages
        if len(pages) == 0:
            logging.warning("PDF has no pages: %s", pdf_path.name)
            return None

        raw_text = pages[0].extract_text()
        # split()/join collapses whitespace runs and trims both ends, so a
        # page holding only whitespace normalizes to "". Checking emptiness
        # AFTER normalizing keeps such pages from being returned as "" (which
        # callers would treat as real content).
        cleaned_text = " ".join(raw_text.split()) if raw_text else ""
        if not cleaned_text:
            logging.warning(
                "No text found on the first page of %s", pdf_path.name
            )
            return None

        # Round-trip through UTF-8 with errors="replace" so characters that
        # cannot be encoded (e.g. lone surrogates) are substituted instead of
        # failing later when the text is serialized.
        return cleaned_text.encode("utf-8", errors="replace").decode("utf-8")
    except pypdf.errors.PdfReadError as e:
        logging.error("Error reading PDF file %s: %s", pdf_path.name, e)
        return None
    except FileNotFoundError:
        logging.error("PDF file not found: %s", pdf_path)
        return None
    except Exception as e:
        # Same message as before; logging.exception additionally records the
        # traceback so unexpected failures can be diagnosed.
        logging.exception(
            "An unexpected error occurred while processing %s: %s",
            pdf_path.name,
            e,
        )
        return None


def main(pdf_directory: Path, result_path: Path):
    """Extract metadata for every PDF under a directory into a JSONL file.

    Each ``*.pdf`` found recursively under ``pdf_directory`` has its
    first-page text extracted and sent to ``get_authors``. The result is
    dumped to a dict, tagged with a ``"filename"`` key (the file's base name),
    and appended to ``result_path`` as one JSON object per line.

    Args:
        pdf_directory: Root directory searched recursively for ``*.pdf``.
        result_path: Output file, opened in append mode with UTF-8 encoding,
            so existing content is kept.

    A failure on one file is logged and does not stop the remaining files.
    """
    with open(result_path, "a", encoding="utf-8") as f:
        # Sorted so that repeated runs visit files in a reproducible order
        # instead of whatever order the filesystem yields.
        for file in sorted(pdf_directory.rglob("*.pdf")):
            try:
                logging.info("Extract %s's authors", file.name)
                first_page_text = extract_first_page_text(file)
                logging.info(first_page_text)
                if first_page_text is None:
                    continue

                metadata = get_authors(first_page_text)
                # NOTE(review): get_authors returns the response's
                # output_parsed, which appears to be optional in the OpenAI
                # SDK (e.g. on a refusal) — confirm. Report that case
                # explicitly rather than via an AttributeError on None.
                if metadata is None:
                    logging.warning(
                        "%s: model returned no parsed metadata", file.name
                    )
                    continue

                result = metadata.model_dump()
                result["filename"] = file.name
                f.write(json.dumps(result) + "\n")
                # Flush per record so completed results survive an
                # interruption of a long run and can be followed live.
                f.flush()
            except Exception as e:
                # Best-effort batch: keep going, but keep the traceback.
                logging.exception("%s: %s", file.name, e)


if __name__ == "__main__":
    import argparse

    # Command-line entry point. Both options are required strings:
    #   --paper   directory passed to main() as pdf_directory
    #   --result  file passed to main() as result_path
    cli = argparse.ArgumentParser()
    for flag in ("--paper", "--result"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    main(Path(options.paper), Path(options.result))
