105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
import os.path
|
|
|
|
from langchain_chroma import Chroma
|
|
from utils.config_handler import chroma_conf
|
|
from model.factory import embedding_model
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
from utils.config_handler import chroma_conf
|
|
import os
|
|
from utils.path_tool import get_abs_path
|
|
from utils.file_handler import pdf_loader, txt_loader, listdir_with_allowed_type, get_file_md5_hex
|
|
from utils.logger_handler import logger
|
|
from langchain_community.docstore.document import Document as Documents
|
|
|
|
class VectorStoreSerivce:
    """Own a persistent Chroma collection and the splitter used to fill it.

    Every setting comes from ``chroma_conf``: collection name, persist
    directory, chunking parameters, retriever ``k``, knowledge-base data path,
    allowed file types and the path of the md5 bookkeeping file.
    """

    def __init__(self):
        # Persistent collection; persist_directory is resolved with
        # get_abs_path so it does not depend on the working directory.
        self.vector_store = Chroma(
            collection_name=chroma_conf["collection_name"],
            embedding_function=embedding_model,
            persist_directory=get_abs_path(chroma_conf["persist_directory"]),
        )

        # Splitter used by load_document(); chunk length is measured with len()
        # (characters). The attribute name "spliter" is kept for existing callers.
        self.spliter = RecursiveCharacterTextSplitter(
            chunk_size=chroma_conf["chunk_size"],
            chunk_overlap=chroma_conf["chunk_overlap"],
            separators=chroma_conf["separators"],
            length_function=len,
        )

    def get_retriever(self):
        """Return a retriever over the store that fetches ``chroma_conf["k"]`` results per query."""
        return self.vector_store.as_retriever(
            search_kwargs={"k": chroma_conf["k"]}
        )

    @staticmethod
    def _read_md5_store(md5_path: str) -> set[str]:
        """Return the digests recorded in ``md5_path``; create an empty store if it is missing."""
        if not os.path.exists(md5_path):
            # Create the store at the resolved path (and its directory), so the
            # file later appended to is the same one that is read here.
            parent = os.path.dirname(md5_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            open(md5_path, "w", encoding="utf-8").close()
            return set()

        with open(md5_path, "r", encoding="utf-8") as f:
            # One digest per line; surrounding whitespace/newlines are ignored.
            return {line.strip() for line in f if line.strip()}

    @staticmethod
    def _append_md5(md5_path: str, md5_hex: str) -> None:
        """Append one digest (newline-terminated) to the md5 store."""
        with open(md5_path, "a", encoding="utf-8") as f:
            f.write(md5_hex + "\n")

    @staticmethod
    def _load_file_documents(read_path: str) -> "list[Documents]":
        """Load ``read_path`` with the loader matching its extension; [] for unsupported types."""
        lowered = read_path.lower()
        if lowered.endswith(".txt"):
            return txt_loader(read_path)
        if lowered.endswith(".pdf"):
            return pdf_loader(read_path)
        return []

    def load_document(self):
        """Split new knowledge-base files and add the chunks to the vector store.

        Files are de-duplicated by content md5. A digest is recorded only after
        its chunks were added successfully, so a failed file is retried on the
        next call. Per-file errors are logged and do not stop the remaining files.
        """
        md5_path = get_abs_path(chroma_conf["md5_hex_store"])
        known_md5 = self._read_md5_store(md5_path)

        # NOTE(review): data_path is passed unresolved, unlike persist_directory
        # and md5_hex_store which go through get_abs_path — confirm that
        # listdir_with_allowed_type resolves it itself.
        allowed_file_path = listdir_with_allowed_type(
            chroma_conf["data_path"],
            tuple(chroma_conf["allow_knowledge_file_type"]),
        )

        for path in allowed_file_path:
            try:
                md5_hex = get_file_md5_hex(path)
                # get_file_md5_hex's failure contract is not visible here; guard
                # against a falsy digest so chunks are never added without a
                # recorded md5 (which would re-add them on every run).
                if not md5_hex:
                    logger.warning(f"[加载知识库]{path} 无法计算md5, 跳过")
                    continue

                if md5_hex in known_md5:
                    logger.info(f"[加载知识库]: {path} 的 md5已存在, 跳过")
                    continue

                document = self._load_file_documents(path)
                if not document:
                    logger.warning(f"[加载知识库]{path}, 没有知识库,跳过")
                    continue

                split_doc = self.spliter.split_documents(document)
                if not split_doc:
                    logger.warning(f"[加载知识库]{path}分片后无内容,跳过")
                    continue

                self.vector_store.add_documents(split_doc)

                # Record the md5 to avoid loading the same content again; also
                # track it in memory so duplicate files within this run are skipped.
                self._append_md5(md5_path, md5_hex)
                known_md5.add(md5_hex)
                logger.info(f"[加载知识库]{path} 加载成功")
            except Exception as e:
                # Per-file boundary: log with traceback and continue with the next file.
                logger.error(f"[加载知识库]{path} 失败,错误 {str(e)}", exc_info=True)


# Correctly spelled alias; the original (misspelled) name stays for existing callers.
VectorStoreService = VectorStoreSerivce
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: ingest the knowledge base, then show the chunks
    # retrieved for a sample query, each followed by a separator line.
    service = VectorStoreSerivce()
    service.load_document()

    hits = service.get_retriever().invoke("迷路")
    for hit in hits:
        print(hit.page_content)
        print("-" * 20)
|