from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models
import requests
import os
import logging
import json
import numpy as np
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# Cấu hình logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load biến môi trường từ .env
load_dotenv()

# Khởi tạo model embedding toàn cục
_embedding_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
VECTOR_SIZE = 768

def call_openrouter_llama3(prompt, api_key):
    """Gọi API OpenRouter để xử lý prompt với mô hình Llama-3."""
    url = "https://openrouter.ai/api/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "meta-llama/llama-3-8b-instruct",
        "messages": [
            {"role": "system", "content": "Bạn là trợ lý AI chuyên tư vấn về các thông tin của Trường Đại học Y Dược Cần Thơ, trả lời bằng tiếng Việt tối đa 40 từ."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.2
    }
    try:
        response = requests.post(url, headers=headers, json=data)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except requests.RequestException as e:
        logging.error(f"Lỗi khi gọi OpenRouter API: {e}")
        raise
    """Điều chỉnh câu hỏi bằng LLM, thêm ngữ cảnh về Trường Đại học Y Dược Cần Thơ."""
    word_count = len(question.split())
    max_words = min(word_count * 8, 30)
    prompt = f"""
    Diễn đạt lại câu hỏi sau rõ ràng hơn, giữ đúng ý nghĩa gốc, đúng chính tả và thêm ngữ cảnh về Trường Đại học Y Dược Cần Thơ nếu phù hợp, tối đa {max_words} từ.
    Ví dụ:
    - Câu hỏi gốc: "hiệu trưởng là ai"
    - Câu hỏi điều chỉnh: "Ai là hiệu trưởng của Trường Đại học Y Dược Cần Thơ?"
    Câu hỏi: {question}
    """
    try:
        refined = call_openrouter_llama3(prompt, api_key)
        refined_word_count = len(refined.split())
        if refined_word_count > max_words:
            logging.warning(f"Câu hỏi điều chỉnh ({refined_word_count} từ) vượt quá giới hạn {max_words} từ, cắt bớt.")
            refined = " ".join(refined.split()[:max_words])
        logging.info(f"[RAG] Câu hỏi gốc: {question}, Câu hỏi điều chỉnh: {refined}")
        return refined
    except Exception as e:
        logging.error(f"Lỗi khi điều chỉnh câu hỏi: {e}")
        return question  # Trả về câu hỏi gốc nếu lỗi
# Load config
def load_keywords_from_config(collection_name: str, config_path="config.json"):
    """Tải danh sách keyword tương ứng với collection từ config.json"""
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    for collection in config.get("collections", []):
        if collection.get("collection_name") == collection_name:
            return collection.get("keywords", [])
    return []

#tương đồng 2 vector
def cosine_similarity(vec1, vec2):
    vec1, vec2 = np.array(vec1), np.array(vec2)
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))

#từ khóa gần nhất với vector q
def find_closest_keyword(query_vector, keywords, embedding_func, top_k=1):
    """Tìm từ khóa gần nhất với vector câu hỏi"""
    keyword_vectors = [embedding_func(kw) for kw in keywords]
    similarities = [cosine_similarity(query_vector, vec) for vec in keyword_vectors]
    top_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:top_k]
    return [(keywords[i], similarities[i]) for i in top_indices]

#ktr Qdrant
def initialize_qdrant_collection(client, collection_name):
    """Khởi tạo collection trong Qdrant nếu chưa tồn tại."""
    try:
        collections = client.get_collections()
        if collection_name not in [c.name for c in collections.collections]:
            client.create_collection(
                collection_name=collection_name,
                vectors_config=qdrant_models.VectorParams(
                    size=VECTOR_SIZE,
                    distance=qdrant_models.Distance.COSINE
                )
            )
            logging.info(f"[RAG] Đã tạo collection '{collection_name}' với vector size {VECTOR_SIZE}.")
        else:
            collection_info = client.get_collection(collection_name)
            if collection_info.config.params.vectors.size != VECTOR_SIZE:
                raise ValueError(f"Collection tồn tại nhưng có vector size {collection_info.config.params.vectors.size} khác với yêu cầu {VECTOR_SIZE}")
            logging.info(f"[RAG] Collection '{collection_name}' đã tồn tại với vector size phù hợp.")
    except Exception as e:
        logging.error(f"Lỗi khi khởi tạo collection: {e}")
        raise

#Xử lý văn bản
def clean_text(text):
    """Làm sạch văn bản: loại bỏ lặp lại và chuẩn hóa."""
    if not text:
        return ""
    
    # Loại bỏ lặp lại
    sentences = text.split(". ")
    unique_sentences = list(dict.fromkeys(sentences))
    cleaned_text = ". ".join(unique_sentences).strip()
    
    # Loại bỏ các từ thừa hoặc ký tự không cần thiết
    cleaning_patterns = [
        "Web - Thẻ Span: ",
        "Web - Đoạn văn: ",
        "Web - Danh sách: ",
        "Web - Tiêu đề: "
    ]
    
    for pattern in cleaning_patterns:
        cleaned_text = cleaned_text.replace(pattern, "")
    
    return cleaned_text if cleaned_text else text

#tạo vector nhúng
def get_embedding(text):
    """Tạo embedding cho văn bản sử dụng model đã khởi tạo."""
    try:
        emb = _embedding_model.encode(text).tolist()
        logging.info(f"[RAG] Embedding dimension: {len(emb)}")
        return emb
    except Exception as e:
        logging.error(f"Lỗi khi tạo embedding: {e}")
        raise

def RAG_agent(input_text=None):
    """Xử lý câu hỏi sử dụng RAG với Qdrant và LLM."""
    if not input_text:
        return {"error": "Input text is required."}

    qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333")
    api_key = os.environ.get("OPENROUTER_API_KEY")
    qdrant_api_key = os.environ.get("QDRANT_API_KEY")

    if api_key is None:
        raise ValueError("OPENROUTER_API_KEY không được cấu hình trong .env!")

    try:
        with open("config.json", "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error("Không tìm thấy file config.json!")
        raise
    except json.JSONDecodeError:
        logging.error("File config.json có định dạng không hợp lệ!")
        raise

    try:
        collections = config.get("collections", [])
        if not collections or not isinstance(collections, list):
            raise ValueError("Cấu trúc collections trong config.json không hợp lệ")

        collection_config = collections[0]
        collection_name = collection_config["collection_name"]
        prompt_template = collection_config["prompt"]
    except (KeyError, IndexError) as e:
        logging.error(f"Cấu hình collection không hợp lệ: {e}")
        raise

    logging.info(f"[RAG] Sử dụng collection: {collection_name}")

    try:
        if qdrant_api_key:
            client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        else:
            client = QdrantClient(url=qdrant_url)

        initialize_qdrant_collection(client, collection_name)

        from qdrant_client.http.models import PayloadSchemaType
        try:
            client.create_payload_index(
                collection_name=collection_name,
                field_name="content",
                field_schema=PayloadSchemaType.TEXT
            )
            logging.info("[RAG] Đã bật text index cho 'content'")
        except Exception as e:
            logging.warning(f"[RAG] Không thể tạo text index (có thể đã tồn tại): {e}")

        points_count = client.count(collection_name=collection_name)
        logging.info(f"[RAG] Số lượng điểm trong collection '{collection_name}': {points_count.count}")
    except Exception as e:
        logging.error(f"Lỗi khi kết nối Qdrant: {e}")
        raise

    try:
        refined_input = input_text.strip()
        # Xử lý trường hợp người dùng chỉ chào
        if refined_input.lower() in ["hello", "hi", "xin chào", "chào", "chào bạn"]:
            return "Xin chào! Tôi là trợ lý AI của Trường Đại học Y Dược Cần Thơ. Tôi có thể giúp bạn tìm kiếm thông tin về trường đại học này."

        logging.info(f"[RAG] Câu hỏi dùng để truy vấn: {refined_input}")

        query_vector = get_embedding(refined_input)
        keywords = load_keywords_from_config(collection_name)

        from qdrant_client.http.models import Filter, FieldCondition, MatchText

        use_filter = False
        query_filter = None

        if keywords:
            closest_keywords = find_closest_keyword(query_vector, keywords, get_embedding, top_k=1)
            best_keyword, best_score = closest_keywords[0]
            logging.info(f"[RAG] Từ khóa gần nhất: {best_keyword} (score={best_score:.3f})")

            if best_score >= 0.2:
                use_filter = True
                query_filter = Filter(
                    must=[
                        FieldCondition(
                            key="content",
                            match=MatchText(text=best_keyword)
                        )
                    ]
                )
                logging.info(f"[RAG] Sử dụng filter MatchText cho keyword: {best_keyword}")
            else:
                use_filter = True
                query_filter = Filter(
                    must=[
                        FieldCondition(
                            key="content",
                            match=MatchText(text=refined_input)
                        )
                    ]
                )
                logging.info(f"[RAG] Không tìm thấy keyword phù hợp, dùng MatchText trực tiếp từ câu hỏi")

        try:
            search_result = client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                query_filter=query_filter if use_filter else None,
                limit=190,
                with_payload=True,
                with_vectors=False
            )

            if not search_result and use_filter:
                logging.warning("[RAG] Không có kết quả với filter, fallback sang truy vấn vector gốc")
                search_result = client.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=200,
                    with_payload=True,
                    with_vectors=False
                )
        except Exception as e:
            logging.error(f"[RAG] Lỗi khi truy vấn Qdrant: {e}")
            search_result = []

        for point in search_result:
            print(f"ID: {point.id}, Score: {point.score:.3f}")
            print("Title:", point.payload.get('metadata', {}).get('title', ''))
            print("Section:", point.payload.get('metadata', {}).get('section', ''))
            print("Content:", point.payload.get('content', ''))

        if not search_result:
            logging.warning(f"[RAG] Qdrant trả về kết quả rỗng cho collection '{collection_name}'")
            try:
                answer = call_openrouter_llama3(refined_input, api_key)
                logging.info("[RAG] Đã gọi LLM trực tiếp do không có kết quả từ Qdrant")
                return {"warning": "Không tìm thấy tài liệu liên quan.", "answer": answer}
            except Exception as e:
                logging.error(f"Lỗi khi gọi LLM trực tiếp: {e}")
                return {"error": f"LLM error: {str(e)}"}

        if use_filter:
            filtered_points = search_result
        else:
            filtered_points = [point for point in search_result if point.score >= 0.2]

        if not filtered_points:
            logging.warning("[RAG] Không có tài liệu nào đạt ngưỡng score >= 0.2")
            try:
                answer = call_openrouter_llama3(refined_input, api_key)
                logging.info("[RAG] Đã gọi LLM trực tiếp do không có tài liệu đạt ngưỡng score")
                return {"warning": "Không tìm thấy tài liệu liên quan.", "answer": answer}
            except Exception as e:
                logging.error(f"Lỗi khi gọi LLM trực tiếp: {e}")
                return {"error": f"LLM error: {str(e)}"}

        # Ghép các chunks
        documents = []
        temp_doc = ""
        max_length = 1500

        for point in filtered_points:
            content = clean_text(point.payload.get("content", ""))
            if len(temp_doc) + len(content) < max_length:
                temp_doc += " " + content
            else:
                documents.append(temp_doc.strip())
                temp_doc = content
        if temp_doc:
            documents.append(temp_doc.strip())

        context_lines = [
            f"{i+1}. {doc}" for i, doc in enumerate(documents)
        ]
        context = "\n".join(context_lines)

        prompt = f"""{prompt_template}\n\nCâu hỏi của quý khách: {input_text}\n\n
        Thông tin có sẵn:\n{context}\n\n
        Trả lời:"""

        try:
            answer = call_openrouter_llama3(prompt, api_key)
            return answer
        except Exception as e:
            logging.error(f"Lỗi khi gọi LLM để trả lời: {e}")
            return "Xin lỗi, tôi gặp lỗi khi xử lý câu hỏi của bạn."

    except Exception as e:
        logging.error(f"Lỗi trong quá trình RAG: {e}", exc_info=True)
        return {"error": f"RAG error: {str(e)}"}

def process_with_ai_agent(input_text=None):
    """
    Xử lý văn bản đầu vào bằng AI Agent với Qdrant.
    Args:
        input_text (str): Văn bản đầu vào cần xử lý
    Returns:
        Kết quả xử lý từ AI Agent hoặc thông báo lỗi
    """
    try:
        response = RAG_agent(input_text)
        return response
    except Exception as e:
        logging.error(f"Lỗi trong quá trình xử lý AI Agent: {str(e)}")
        return None

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    print("=== AI Agent Qdrant CLI ===")
    print("Nhập câu hỏi (gõ 'exit' để thoát):")
    while True:
        user_input = input("Bạn: ").strip()
        if user_input.lower() in ["exit", "quit", "q"]:
            print("Kết thúc.")
            break
        if not user_input:
            continue
        result = process_with_ai_agent(user_input)
        print(f"AI: {result}\n")