import chromadb
import openai
from chromadb.utils import embedding_functions

# Sample records exercised by the ``__main__`` demo at the bottom of the file.
# The three records differ only in ``node_id`` and ``summary``.
# NOTE(review): '2023-1-0' is not a real calendar date; presumably placeholder
# data — confirm before relying on the ``created`` field.
example = dict(node_id='3', node_type='chat', description='test',
               summary='tst', code='NULL', last_accessed='2023-1-10',
               created='2023-1-0', embeddings=[1, 2, 3], poignancy=10)

example1 = dict(node_id='2', node_type='chat', description='test',
                summary='ppp', code='NULL', last_accessed='2023-1-10',
                created='2023-1-0', embeddings=[1, 2, 3], poignancy=10)

example2 = dict(node_id='10', node_type='chat', description='test',
                summary='tst', code='NULL', last_accessed='2023-1-10',
                created='2023-1-0', embeddings=[1, 2, 3], poignancy=10)

def get_openai_key(key_path:str) -> str:
    """
    Read an OpenAI API key from a text file.

    :param key_path: Path to a file whose whole content is the key.
    :return: The file content with surrounding whitespace (e.g. a trailing
        newline) stripped.
    :raises OSError: If the file cannot be opened.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(key_path, 'r', encoding='utf-8') as f:
        return f.read().strip()

def get_collection(db_file_name:str, collection_name:str):
    """
    Open a persistent collection, creating it if it does not exist yet.

    :param db_file_name: Storage location handed to ``chromadb.PersistentClient``.
    :param collection_name: Name of the collection to fetch or create.
    :return: The collection, configured with cosine distance
        (``"hnsw:space": "cosine"``) and chromadb's default embedding function.
    """
    # Alternative: embed with OpenAI by using this embedding function instead:
    #   embedding_functions.OpenAIEmbeddingFunction(
    #       api_key=get_openai_key("key.txt"),
    #       model_name="text-embedding-ada-002",
    #   )
    embedder = embedding_functions.DefaultEmbeddingFunction()
    return chromadb.PersistentClient(path=db_file_name).get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
        embedding_function=embedder,
    )

def add_record(collection:chromadb.Collection, records:list[dict]):
    """
    Insert *records* into *collection*.

    Each record becomes one entry whose id is ``record['node_id']``. The
    document text is ``record['description']`` for ``'skill'`` nodes and
    ``record['summary']`` for every other node type. ``node_type``,
    ``last_accessed``, ``created``, ``poignancy``, ``code`` and
    ``description`` are stored as metadata.

    The record's ``embeddings`` field is NOT used: no embeddings are passed to
    ``collection.add``, so the collection computes them from the documents
    with its own embedding function.

    An empty *records* list is a no-op.

    :param collection: Target collection.
    :param records: Record dicts containing all of the keys listed above.
    :raises KeyError: If a record lacks one of the required keys.
    """
    # chromadb's id validation raises on an empty id list; nothing to insert
    # simply means nothing to do.
    if not records:
        return

    ids = []
    docs = []
    metadatas = []
    for record in records:
        ids.append(record['node_id'])
        metadatas.append({
            'node_type': record['node_type'],
            'last_accessed': record['last_accessed'],
            'created': record['created'],
            'poignancy': record['poignancy'],
            'code': record['code'],
            'description': record['description'],
        })
        # Skill nodes are indexed by their description, all others by summary.
        text_key = 'description' if record['node_type'] == 'skill' else 'summary'
        docs.append(record[text_key])

    collection.add(ids=ids, metadatas=metadatas, documents=docs)


def delete_record(collection:chromadb.Collection, records:list[dict]):
    """
    Delete the entries whose id equals ``record['node_id']`` for each record.

    Ids that are not present in the collection are left to chromadb, which
    does nothing for them. An empty *records* list is a no-op.

    :param collection: Collection to delete from.
    :param records: Record dicts; only ``node_id`` is read.
    :raises KeyError: If a record has no ``node_id``.
    """
    ids = [record['node_id'] for record in records]
    # Never call delete() with an empty id list. Depending on the chromadb
    # version, that list is either rejected or treated as "no id filter".
    if not ids:
        return
    collection.delete(ids=ids)

def query_relevance_record(collection:chromadb.Collection, keys:list[str], topK:int=50):
    """
    Query the *topK* records most relevant to each search key.

    :param collection: Collection to search.
    :param keys: Query texts. A single string is accepted and treated as a
        one-element list.
    :param topK: Number of results requested per query text (default 50).
    :return: Whatever ``collection.query`` returns. For chromadb this is a
        query result holding one result list per query text.
    """
    # Normalise a bare string so callers may pass "text" or ["text"].
    if isinstance(keys, str):
        keys = [keys]
    return collection.query(
        query_texts=keys,
        n_results=topK,
    )

def query_most_piognancy_record(collection:chromadb.Collection, poignancy:int=10):
    """
    Fetch every record whose ``poignancy`` metadata equals *poignancy*.

    The default of 10 preserves the original behaviour. This module treats 10
    as the highest poignancy, and the sample records all use it. The filter is
    an exact match: records with any other value, including values above
    *poignancy*, are not returned.

    (The misspelling "piognancy" in the name is kept for caller compatibility.)

    :param collection: Collection to read from.
    :param poignancy: Poignancy value to match (default 10).
    :return: The result of ``collection.get`` with that metadata filter.
    """
    return collection.get(
        where={
            "poignancy": poignancy
        }
    )

    
def update_record(collection:chromadb.Collection, records:list[dict]):
    """
    Overwrite the documents and metadata of existing entries with *records*.

    Each record targets the entry whose id is ``record['node_id']``. The new
    document text is ``record['description']`` for ``'skill'`` nodes and
    ``record['summary']`` for every other node type. ``node_type``,
    ``last_accessed``, ``created``, ``poignancy``, ``code`` and
    ``description`` replace the stored metadata. The record's ``embeddings``
    field is not used.

    An empty *records* list is a no-op.

    :param collection: Collection to update.
    :param records: Record dicts containing all of the keys listed above.
    :raises KeyError: If a record lacks one of the required keys.
    """
    # chromadb's id validation raises on an empty id list; nothing to update
    # simply means nothing to do.
    if not records:
        return

    ids = []
    docs = []
    metadatas = []
    for record in records:
        ids.append(record['node_id'])
        metadatas.append({
            'node_type': record['node_type'],
            'last_accessed': record['last_accessed'],
            'created': record['created'],
            'poignancy': record['poignancy'],
            'code': record['code'],
            'description': record['description'],
        })
        # Skill nodes are indexed by their description, all others by summary.
        text_key = 'description' if record['node_type'] == 'skill' else 'summary'
        docs.append(record[text_key])

    collection.update(ids=ids, documents=docs, metadatas=metadatas)

if __name__ == "__main__":
    # Demo: open (or create) collection 'test' at storage path 'test', insert
    # the sample records, then run both queries.
    collection = get_collection('test', 'test')
    add_record(collection, [example, example1, example2])
    # add_record(collection, [example1])
    # query_relevance_record expects a list of query texts.
    print(query_relevance_record(collection, ["tst"]))
    print(query_most_piognancy_record(collection))
    # Other operations to try:
    # update_record(collection, [example1])
    # delete_record(collection, [example])
    # delete_record(collection, [example1])

