import chromadb
import openai
import os
from tqdm import tqdm
import tiktoken
from chromadb.errors import NotEnoughElementsException
import re
from colorama import Fore, Style

# Instructions (assumes Windows OS)
# In the console/terminal, use this command to install the necessary Python libraries on your machine:
#     pip install chromadb openai tqdm tiktoken colorama
# Note: this script uses the legacy Completions interface, so it needs the pre-1.0 openai package (pip install "openai<1").
# Place this script (knowledge_extractor.py) next to a directory named 'documents'. Put the text files you want to use as sources of information inside that folder.
# Edit the line below to reflect your OpenAI API key.
# In the console, at the location of this script, enter this command to run it: python ./knowledge_extractor.py

# openai.api_key = YOUR_OPENAI_API_KEY_HERE
openai.api_key = os.getenv("OPENAI_API_KEY")  # delete this line if you're using the line above

chroma_client = chromadb.Client()
collection = chroma_client.create_collection(name="my_collection")


def detect_hard_line_breaks(data):
    # Heuristic: count lines where sentence punctuation is followed by more lowercase text
    # on the same line. If almost no lines look like that, the file probably uses hard
    # line breaks in the middle of sentences rather than one paragraph per line.
    lines = data.strip().split('\n')
    num_lines = len(lines)
    lines_with_soft_breaks = sum(1 for line in lines if re.search(r'[.!?;:]\s+[a-z]', line))
    proportion_soft_breaks = lines_with_soft_breaks / num_lines
    return proportion_soft_breaks < 0.1  # Adjust the threshold based on your observations


def chunk_text(text, chunk_size=300):
    # Split the text into chunks of roughly chunk_size words each.
    words = text.split()
    chunks = []
    for i in range(0, len(words), chunk_size):
        chunk = " ".join(words[i:i + chunk_size])
        chunks.append(chunk)
    return chunks


def read_and_embed_file(file_name, title):
    with open(file_name, 'r', encoding="utf-8") as file:
        data = file.read()

    if detect_hard_line_breaks(data):
        # Split the text into chunks of 300 words each
        paragraphs = chunk_text(data, 300)
    else:
        # Split into paragraphs using a regular expression
        paragraphs = re.split(r'\n{1,2}', data.strip())
        paragraphs = [p for p in paragraphs if p.strip()]  # drop empty paragraphs

    metadata_list = []
    ids_list = []
    for idx, paragraph in enumerate(paragraphs):
        metadata_list.append({"source": title})
        ids_list.append(f"{title}-id{idx + 1}")  # prefix with the title so IDs stay unique across files

    collection.add(
        documents=paragraphs,
        metadatas=metadata_list,
        ids=ids_list
    )
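
# Note: collection.add() above receives raw text (no "embeddings" argument) and the collection
# was created without an embedding_function, so Chroma embeds each paragraph with its built-in
# default sentence-transformer model (all-MiniLM-L6-v2 at the time of writing). A minimal sketch
# of using OpenAI embeddings instead, assuming your chromadb version ships the bundled
# OpenAIEmbeddingFunction helper:
#
#   from chromadb.utils import embedding_functions
#   openai_ef = embedding_functions.OpenAIEmbeddingFunction(
#       api_key=os.getenv("OPENAI_API_KEY"),
#       model_name="text-embedding-ada-002",
#   )
#   collection = chroma_client.create_collection(name="my_collection", embedding_function=openai_ef)
#
# The rest of the script is unchanged either way; only the vectors stored for each paragraph differ.
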
def generate_prompt(sources, question):
    return f"""
Answer the question below using the sources to get relevant information.
Cite sources using in-text citations with square brackets.
For example: [1] refers to source 1 and [2] refers to source 2.
Cite once per sentence.
If the context doesn't answer the question, output "I don't know".

Sources:
{sources}

Question: {question}

Result:"""


def make_openai_call(context, question):
    model_name = "text-davinci-003"
    max_tokens = 4097
    max_tokens = max_tokens - 200  # Reserve 200 tokens for the completion

    sources = ''
    total_tokens = 0

    # To get the tokeniser corresponding to a specific model in the OpenAI API:
    encoding = tiktoken.encoding_for_model(model_name)
    question_tokens = len(encoding.encode(question))

    # Pack retrieved paragraphs into the prompt until the token budget is used up.
    for idx, paragraph in enumerate(context):
        paragraph_tokens = len(encoding.encode(paragraph))
        if total_tokens + paragraph_tokens + question_tokens <= max_tokens:
            sources += f"Source {idx + 1}: {paragraph}\n"
            total_tokens += paragraph_tokens
        else:
            break

    prompt = generate_prompt(sources, question)

    response = openai.Completion.create(
        model=model_name,
        prompt=prompt,
        temperature=0.6,
        max_tokens=200,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["\n"]
    )
    return response['choices'][0]['text']


def pretty_print_results(query, summary, sources):
    print("--------------------")
    print(Fore.YELLOW + summary)
    print(Fore.BLUE + "*********")
    print("Sources:")
    for idx, source in enumerate(sources):
        print(f"{idx + 1}: {source}\n")
    print("*********" + Style.RESET_ALL)


if __name__ == "__main__":
    documents_folder = "documents"
    all_files = [file for file in os.listdir(documents_folder) if file.endswith(".txt")]

    # Use tqdm to create a progress bar for the loop
    for file in tqdm(all_files, desc="Processing files"):
        file_path = os.path.join(documents_folder, file)
        print(f"Processing: {file_path}")
        read_and_embed_file(file_path, file.split('.')[0])

    while True:
        query = input(Fore.GREEN + "Enter your question or type 'exit' to quit: ")
        if query.lower() == "exit":
            break

        results = collection.query(
            query_texts=[query],
            n_results=5
        )
        top_results = results["documents"][0]
        pretty_print_results(query, make_openai_call(top_results, query).strip(), top_results)
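
# ---------------------------------------------------------------------------
# Optional notes (sketches under stated assumptions, not part of the script above):
#
# 1. OpenAI has since retired text-davinci-003. If the completion call fails with a
#    model-not-found error, a completions-style model such as gpt-3.5-turbo-instruct
#    can be substituted for model_name in make_openai_call.
#
# 2. Older chromadb releases raise NotEnoughElementsException (imported at the top) when
#    n_results exceeds the number of stored chunks. With very small document sets, one
#    possible guard around the query is:
#
#        try:
#            results = collection.query(query_texts=[query], n_results=5)
#        except NotEnoughElementsException:
#            results = collection.query(query_texts=[query], n_results=collection.count())
#
# 3. chromadb.Client() keeps everything in memory, so the documents are re-embedded on every
#    run. Newer chromadb releases (0.4+) expose a persistent client that could replace the
#    two setup lines near the top:
#
#        chroma_client = chromadb.PersistentClient(path="./chroma_db")
#        collection = chroma_client.get_or_create_collection(name="my_collection")
#
# 4. For reference, collection.query() returns one inner list per query text:
#    results["documents"][0] holds the n_results closest paragraphs for the first (only) query,
#    with results["metadatas"][0] and results["distances"][0] in the same order.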