"""Chroma-backed vector store service for the knowledge base.

Loads knowledge files, splits them into chunks, and stores the chunks in a
persistent Chroma collection. An md5 record file is used so that a file whose
content was already ingested is not ingested again.
"""

import os
import os.path

from langchain_chroma import Chroma
from langchain_community.docstore.document import Document as Documents
from langchain_text_splitters import RecursiveCharacterTextSplitter

from utils.config_handler import chroma_conf
from model.factory import embedding_model
from utils.path_tool import get_abs_path
from utils.file_handler import (
    pdf_loader,
    txt_loader,
    listdir_with_allowed_type,
    get_file_md5_hex,
)
from utils.logger_handler import logger


class VectorStoreSerivce:
    """Wrap a Chroma collection together with the text splitter used to fill it.

    All settings (collection name, persist directory, chunking parameters,
    retrieval ``k``, data path, md5 record path) are read from ``chroma_conf``.
    """

    def __init__(self):
        """Create the Chroma store and the recursive text splitter."""
        self.vector_store = Chroma(
            collection_name=chroma_conf["collection_name"],
            embedding_function=embedding_model,
            persist_directory=get_abs_path(chroma_conf["persist_directory"]),
        )
        self.spliter = RecursiveCharacterTextSplitter(
            chunk_size=chroma_conf["chunk_size"],
            chunk_overlap=chroma_conf["chunk_overlap"],
            separators=chroma_conf["separators"],
            length_function=len,
        )

    def get_retriever(self):
        """Return a retriever over the store that fetches ``chroma_conf["k"]`` results."""
        return self.vector_store.as_retriever(
            search_kwargs={"k": chroma_conf["k"]}
        )

    def load_document(self) -> None:
        """Ingest every allowed knowledge file that has not been ingested yet.

        For each file returned by ``listdir_with_allowed_type`` the md5 of the
        file is looked up in the md5 record file; known files are skipped.
        New files are loaded, split into chunks, added to the vector store,
        and their md5 is appended to the record file. A failure while
        processing one file is logged and does not stop the remaining files.
        """
        # Resolve once: the same absolute path is used for the existence
        # check, for creating the file, for reading and for appending.
        md5_path = get_abs_path(chroma_conf["md5_hex_store"])

        def check_md5_hex(md5_for_check: str) -> bool:
            """Return True if ``md5_for_check`` is already recorded."""
            if not os.path.exists(md5_path):
                # Create the record file at the resolved path (the original
                # used the raw config value, which may point elsewhere).
                open(md5_path, "w", encoding="utf-8").close()
                return False

            with open(md5_path, "r", encoding="utf-8") as f:
                return any(line.strip() == md5_for_check for line in f)

        def save_md5(md5_for_save: str) -> None:
            """Append ``md5_for_save`` to the record file, one md5 per line."""
            with open(md5_path, "a", encoding="utf-8") as f:
                f.write(md5_for_save + "\n")

        def get_file_document(read_path: str):
            """Load ``read_path`` with the loader matching its suffix.

            Returns an empty list for any suffix other than txt / pdf.
            """
            if read_path.endswith("txt"):
                return txt_loader(read_path)

            if read_path.endswith("pdf"):
                return pdf_loader(read_path)

            return []

        allowed_file_path = listdir_with_allowed_type(
            chroma_conf["data_path"],
            tuple(chroma_conf["allow_knowledge_file_type"]),
        )

        for path in allowed_file_path:
            md5_hex = get_file_md5_hex(path)

            if check_md5_hex(md5_hex):
                logger.info(f"[加载知识库]: {path} 的 md5已存在, 跳过")
                continue

            try:
                document: list[Documents] = get_file_document(path)

                if not document:
                    logger.warning(f"[加载知识库]{path}, 没有知识库,跳过")
                    # Actually skip: without this the empty list reached the
                    # splitter and triggered a second, misleading warning.
                    continue

                split_doc: list[Documents] = self.spliter.split_documents(document)

                if not split_doc:
                    logger.warning(f"[加载知识库]{path}分片后无内容,跳过")
                    continue

                self.vector_store.add_documents(split_doc)

                # Record the md5 only after a successful add, so a failed
                # file is retried on the next run and is never loaded twice.
                save_md5(md5_hex)

                logger.info(f"[加载知识库]{path} 加载成功")
            except Exception as e:
                # Best-effort ingestion: log with traceback and move on to
                # the next file.
                logger.error(f"[加载知识库]{path} 失败,错误 {str(e)}", exc_info=True)


if __name__ == '__main__':
    # Manual smoke test: ingest the knowledge base, then run one query.
    vs = VectorStoreSerivce()
    vs.load_document()
    retriever = vs.get_retriever()
    res = retriever.invoke("迷路")
    for r in res:
        print(r.page_content)
        print("-" * 20)