import numpy as np
from chroma_db import *
import chromadb

# Multiplicative per-rank decay used by extract_recency: the most recently
# accessed record scores 1.0, the next recency_decay, then recency_decay**2, ...
recency_decay = 0.9
# Weights applied in retrieve() to (relevance, poignance, recency), in that order.
gw = [1, 1, 1]
# NOTE(review): presumably the maximum poignancy score a record can carry;
# confirm against the code that assigns 'poignancy' metadata.
max_poig = 10

def extract_relevance(collection:chromadb.Collection, msgs:list[str], topK=50):
    """
    Query ``collection`` for the ``topK`` records most relevant to ``msgs``.

    :param collection: the Chroma collection to search.
    :param msgs: query texts forwarded to ``query_relevance_record``.
    :param topK: number of candidate records to request (default 50).
    :return: ``(relevance, relevance_records)`` where ``relevance[i]`` is
        ``1 - distance`` for the i-th record of the first result list and
        ``relevance_records`` is the raw query result.
    """
    relevance_records = query_relevance_record(collection, msgs, topK=topK)
    # Only the first result list ([0]) is used.
    # NOTE(review): presumably there is one result list per query text, so
    # results for msgs[1:] are ignored — confirm this is intended.
    # NOTE(review): 1 - distance is only a bounded similarity for a
    # cosine-like metric; confirm the collection's distance function.
    relevance = [1 - distance for distance in relevance_records['distances'][0]]
    return relevance, relevance_records

def extract_poignance(records: dict):
    """
    Extract the poignancy value of each record returned by **extract_relevance**.

    :param records: query result whose ``records['metadatas'][0]`` is a list
        of metadata dicts, each carrying a ``'poignancy'`` key.
    :return: list of poignancy values, in the same order as the records.
    """
    return [metadata['poignancy'] for metadata in records['metadatas'][0]]

def extract_most_poignant(collection:chromadb.Collection):
    """
    Retrieve the records with the highest poignance from ``collection``.

    Thin wrapper around ``query_most_piognancy_record``: its result is
    returned unchanged (previously it was computed and then discarded, so
    this function always returned None).
    """
    return query_most_piognancy_record(collection)
    

def extract_recency(records: dict, decay=None):
    """
    Extract a recency score for each record returned by **extract_relevance**.

    Records are ranked by their ``last_accessed`` metadata, newest first, and
    the record at 0-based rank ``r`` gets ``decay ** r``. The returned list is
    aligned with ``records['metadatas'][0]``, so ``recency[i]`` belongs to the
    same record as ``relevance[i]`` and ``poignance[i]``.

    :param records: query result whose ``records['metadatas'][0]`` is a list of
        metadata dicts carrying a sortable ``'last_accessed'`` value.
    :param decay: per-rank decay factor; ``None`` (default) uses the
        module-level ``recency_decay``.
    :return: list of recency scores in the same order as the records.
    """
    if decay is None:
        decay = recency_decay
    metadatas = records['metadatas'][0]
    # Record indices ordered newest-access first. sorted() is stable, so
    # records with equal last_accessed keep their original relative order.
    order = sorted(range(len(metadatas)),
                   key=lambda i: metadatas[i]['last_accessed'],
                   reverse=True)
    recency = [0.0] * len(metadatas)
    for rank, idx in enumerate(order):
        # Store at the record's own index, not in rank order; otherwise the
        # scores would not line up with relevance/poignance in retrieve().
        recency[idx] = decay ** rank
    return recency


def top_highest_k_values(scores, topK=30):
    """
    Return the ``topK`` highest scores paired with their original positions.

    The result is a list of ``(index, score)`` tuples ordered from highest to
    lowest score; equal scores keep their original relative order.
    """
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:topK]

def get_res(records, idx):
    """
    Build a flat result dict for the ``idx``-th record of a query result.

    The dict holds the record id under ``'node_id'`` followed by the
    ``node_type``, ``last_accessed``, ``created``, ``poignancy``, ``code`` and
    ``description`` metadata fields. Records whose ``node_type`` is not
    ``'skill'`` additionally get their document text under ``'summary'``.
    """
    node_id = records['ids'][0][idx]
    meta = records['metadatas'][0][idx]
    copied_fields = ('node_type', 'last_accessed', 'created',
                     'poignancy', 'code', 'description')
    entry = {'node_id': node_id}
    entry.update((field, meta[field]) for field in copied_fields)

    if entry['node_type'] == 'skill':
        return entry
    entry['summary'] = records['documents'][0][idx]
    return entry
    
def retrieve(collection:chromadb.Collection, msgs:list[str], topK=10):
    """
    Retrieve the ``topK`` best memory records for ``msgs`` from ``collection``.

    Candidates come from **extract_relevance**. Each candidate is scored as a
    weighted sum (weights ``gw``) of its relevance, its poignancy divided by
    ``max_poig``, and its recency. The ``topK`` highest-scoring candidates are
    returned best first, each converted to a dict by **get_res**.
    """
    relevance, relevance_records = extract_relevance(collection, msgs)
    poignance = extract_poignance(relevance_records)
    recency = extract_recency(relevance_records)

    w_rel, w_poig, w_rec = gw[0], gw[1], gw[2]
    scores = [
        # Poignancy is divided by max_poig so it contributes on a scale
        # comparable to relevance and recency instead of dominating the sum.
        # NOTE(review): assumes poignancy lies in [0, max_poig] — confirm.
        rel * w_rel + (poig / max_poig) * w_poig + rec * w_rec
        for rel, poig, rec in zip(relevance, poignance, recency)
    ]
    best = top_highest_k_values(scores, topK=topK)
    return [get_res(relevance_records, idx) for idx, _ in best]

if __name__ == '__main__':
    # Manual smoke test: open a collection and print what retrieve() returns
    # for the single query text 'ppp'.
    # NOTE(review): get_collection presumably comes from the
    # `from chroma_db import *` import; confirm what its two 'test' arguments
    # mean and whether the collection must already exist.
    collection = get_collection('test', 'test')
    print(retrieve(collection, ['ppp']))