# agent_proj/rag/vector_store.py
#
# 105 lines
# 3.5 KiB
# Python

import os.path
from langchain_chroma import Chroma
from utils.config_handler import chroma_conf
from model.factory import embedding_model
from langchain_text_splitters import RecursiveCharacterTextSplitter
from utils.config_handler import chroma_conf
import os
from utils.path_tool import get_abs_path
from utils.file_handler import pdf_loader, txt_loader, listdir_with_allowed_type, get_file_md5_hex
from utils.logger_handler import logger
from langchain_community.docstore.document import Document as Documents
class VectorStoreSerivce:
    """Wrap a persistent Chroma collection and the splitter used to fill it.

    All settings come from ``chroma_conf``: collection name, persist
    directory, chunking parameters, retrieval ``k``, data directory, allowed
    knowledge-file types and the path of the MD5 ledger file.
    """

    def __init__(self):
        # Chroma collection persisted under the configured directory; texts
        # are embedded with the project-wide ``embedding_model``.
        self.vector_store = Chroma(
            collection_name=chroma_conf["collection_name"],
            embedding_function=embedding_model,
            persist_directory=get_abs_path(chroma_conf["persist_directory"]),
        )
        # Attribute name ``spliter`` is kept for compatibility with callers.
        self.spliter = RecursiveCharacterTextSplitter(
            chunk_size=chroma_conf["chunk_size"],
            chunk_overlap=chroma_conf["chunk_overlap"],
            separators=chroma_conf["separators"],
            length_function=len,
        )

    def get_retriever(self):
        """Return a retriever over the collection returning ``chroma_conf["k"]`` hits."""
        return self.vector_store.as_retriever(
            search_kwargs={"k": chroma_conf["k"]}
        )

    @staticmethod
    def _md5_store_path() -> str:
        """Return the absolute path of the ledger of already-ingested file MD5s."""
        return get_abs_path(chroma_conf["md5_hex_store"])

    @classmethod
    def _load_md5_set(cls) -> set:
        """Read the MD5 ledger into a set, creating an empty ledger if it is missing."""
        md5_path = cls._md5_store_path()
        if not os.path.exists(md5_path):
            # Create the ledger at the resolved absolute path (not the raw
            # config value, which would be relative to the CWD), making sure
            # its directory exists so later appends succeed.
            parent = os.path.dirname(md5_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            open(md5_path, "w", encoding="utf-8").close()
            return set()
        with open(md5_path, "r", encoding="utf-8") as f:
            return {line.strip() for line in f if line.strip()}

    @classmethod
    def _save_md5(cls, md5_for_save: str) -> None:
        """Append one MD5 to the ledger so the file is skipped on later runs."""
        with open(cls._md5_store_path(), "a", encoding="utf-8") as f:
            f.write(md5_for_save + "\n")

    @staticmethod
    def _get_file_document(read_path: str):
        """Load ``read_path`` with the loader matching its extension.

        Returns whatever ``txt_loader`` / ``pdf_loader`` return, or ``[]`` for
        any other extension. Matching is case-insensitive and requires the dot.
        """
        lowered = read_path.lower()
        if lowered.endswith(".txt"):
            return txt_loader(read_path)
        if lowered.endswith(".pdf"):
            return pdf_loader(read_path)
        return []

    def load_document(self):
        """Split each new knowledge file under the data directory and add it to the store.

        A file is identified by its MD5 (``get_file_md5_hex``). Files whose MD5
        is already in the ledger are skipped. The MD5 is recorded only after
        the file's chunks were added, so a file that failed is retried on the
        next run. Per-file errors are logged and do not stop the loop.
        """
        # Read the ledger once rather than re-scanning the file per input.
        known_md5 = self._load_md5_set()
        allowed_file_path = listdir_with_allowed_type(
            chroma_conf["data_path"],
            tuple(chroma_conf["allow_knowledge_file_type"]),
        )
        for path in allowed_file_path:
            md5_hex = get_file_md5_hex(path)
            if md5_hex in known_md5:
                logger.info(f"[加载知识库]: {path} 的 md5已存在, 跳过")
                continue
            try:
                document: list[Documents] = self._get_file_document(path)
                if not document:
                    logger.warning(f"[加载知识库]{path}, 没有知识库,跳过")
                    # Nothing was loaded; actually skip, as the log says.
                    continue
                split_doc: list[Documents] = self.spliter.split_documents(document)
                if not split_doc:
                    logger.warning(f"[加载知识库]{path}分片后无内容,跳过")
                    continue
                self.vector_store.add_documents(split_doc)
                # Record the MD5 only after a successful insert; also track it
                # in memory so an identical file later in this run is skipped.
                self._save_md5(md5_hex)
                known_md5.add(md5_hex)
                logger.info(f"[加载知识库]{path} 加载成功")
            except Exception as e:
                logger.error(f"[加载知识库]{path} 失败,错误 {str(e)}", exc_info=True)
if __name__ == '__main__':
    # Manual smoke run: ingest any new knowledge files, then show what a
    # sample query retrieves, each hit followed by a separator line.
    service = VectorStoreSerivce()
    service.load_document()
    hits = service.get_retriever().invoke("迷路")
    for hit in hits:
        print(hit.page_content, "-" * 20, sep="\n")