RAG_proj/knowledge_base.py

91 lines
2.8 KiB
Python

"""
知识库
"""
import os
import config_data as config
import hashlib
from datetime import datetime
from langchain_chroma import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
def check_md5(md5_str: str) -> bool:
    """Check whether an MD5 digest has already been recorded.

    The record file at ``config.md5_path`` is read line by line and each
    whitespace-stripped line is compared with ``md5_str``. If the record file
    does not exist yet, an empty one is created.

    Args:
        md5_str: Digest string to look up.

    Returns:
        True if the digest is already recorded (already processed),
        False otherwise (not processed yet).
    """
    try:
        with open(config.md5_path, 'r', encoding='utf-8') as f:
            # Stream the file and stop at the first match instead of loading
            # every line into memory first.
            return any(line.strip() == md5_str for line in f)
    except FileNotFoundError:
        # No record file yet: fall through and create it below.
        pass
    # Append mode creates the file when missing but, unlike 'w', never
    # truncates records written by someone else between the two open() calls.
    open(config.md5_path, 'a', encoding='utf-8').close()
    return False
def save_md5(md5_str : str):
    """Append an MD5 digest to the record file, one digest per line.

    Args:
        md5_str: Digest string to record; a newline is appended after it.
    """
    with open(config.md5_path, mode='a', encoding='utf-8') as md5_file:
        record_line = md5_str + '\n'
        md5_file.write(record_line)
def get_md5(input_str: str, encoding='utf-8'):
    """Return the hexadecimal MD5 digest of a string.

    Args:
        input_str: Text to hash.
        encoding: Encoding used to turn the text into bytes before hashing.

    Returns:
        The MD5 digest as a hexadecimal string.
    """
    return hashlib.md5(input_str.encode(encoding)).hexdigest()
class KnowledgeBaseService(object):
    """Store text in a Chroma vector collection, skipping content already recorded."""

    def __init__(self, embedding_model: str = "text-embedding-v4"):
        """Create the persistence directory, the vector store and the text splitter.

        Args:
            embedding_model: Model name passed to ``DashScopeEmbeddings``.
                Defaults to the value that was previously hard-coded.
        """
        os.makedirs(config.persist_directory, exist_ok=True)
        # Vector store instance.
        self.chroma = Chroma(
            collection_name=config.collection_name,
            embedding_function=DashScopeEmbeddings(model=embedding_model),
            persist_directory=config.persist_directory,
        )
        # Text splitter. The attribute name (spelled "spliter") is kept
        # unchanged because external code may access it.
        self.spliter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,  # character overlap between consecutive chunks
            separators=config.separators,  # separators used to break the text (natural paragraphs)
            length_function=len,  # chunk length is measured with len()
        )

    def upload_by_str(self, data: str, filename: str, operator: str = "admin"):
        """Vectorize a string and store it in the vector collection.

        The MD5 digest of ``data`` is checked first; content whose digest is
        already recorded is skipped. The digest is recorded only after the
        texts were added, so a failed insert can be retried.

        Args:
            data: Text to store.
            filename: Stored as the ``source`` metadata of every chunk.
            operator: Stored as the ``operator`` metadata of every chunk.

        Returns:
            A status message: the skip message when the digest was already
            recorded, otherwise the success message.
        """
        md5_hex = get_md5(data)
        if check_md5(md5_hex):
            return "[跳过]内容已存在知识库中"

        # Only split texts longer than the configured threshold; shorter
        # texts are stored as a single chunk.
        if len(data) > config.max_split_char_number:
            knowledge_chunks: list[str] = self.spliter.split_text(data)
        else:
            knowledge_chunks = [data]

        metadata = {
            "source": filename,
            "create_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "operator": operator,
        }
        self.chroma.add_texts(
            texts=knowledge_chunks,
            # One independent dict per chunk, so a mutation of one chunk's
            # metadata cannot leak into the others.
            metadatas=[dict(metadata) for _ in knowledge_chunks],
        )
        save_md5(md5_hex)
        return "成功存取"
if __name__ == '__main__':
    # Manual smoke run: upload one sample string and print the status message.
    kb_service = KnowledgeBaseService()
    upload_status = kb_service.upload_by_str("周杰伦2", "testfile")
    print(upload_status)