""" 知识库 """ import os import config_data as config import hashlib from datetime import datetime from langchain_chroma import Chroma from langchain_community.embeddings import DashScopeEmbeddings from langchain_text_splitters import RecursiveCharacterTextSplitter def check_md5(md5_str : str): """检查字符串是否已处理 return False: 未处理, True: 已处理 """ if not os.path.exists(config.md5_path): open(config.md5_path, 'w', encoding='utf-8').close() return False else: with open(config.md5_path, 'r', encoding='utf-8') as f: lines = f.readlines() for line in lines: line = line.strip() if line == md5_str: return True return False def save_md5(md5_str : str): """传入md5记录到文件内保存""" with open(config.md5_path, 'a', encoding='utf-8') as f: f.write(md5_str + '\n') def get_md5(input_str: str, encoding='utf-8'): """传入字符串转为md5, hashlib""" str_bytes = input_str.encode(encoding = encoding) md5_obj = hashlib.md5() md5_obj.update(str_bytes) md5_hex = md5_obj.hexdigest() return md5_hex class KnowledgeBaseService(object): def __init__(self): os.makedirs(config.persist_directory, exist_ok=True) self.chroma = Chroma( collection_name = config.collection_name, embedding_function = DashScopeEmbeddings(model = "text-embedding-v4"), persist_directory = config.persist_directory ) # 向量存储实例 self.spliter = RecursiveCharacterTextSplitter( chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap, # 连续段字符重叠 separators=config.separators, # 自然段划分 length_function=len, #长度统计 ) # 文本分割器 def upload_by_str(self, data : str, filename : str): """传入字符串,向量化,存库""" md5_hex = get_md5(data) if check_md5(md5_hex): return "[跳过]内容已存在知识库中" if len(data) > config.max_split_char_number: knowledge_chunks: list[str] = self.spliter.split_text(data) else: knowledge_chunks = [data] metadata = { "source": filename, "create_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "operator": "admin" } self.chroma.add_texts( texts = knowledge_chunks, metadatas = [metadata for _ in range(len(knowledge_chunks))], ) save_md5(md5_hex) return "成功存取" 
if __name__ == '__main__':
    # Manual smoke test: upload a short sample string and print the status message.
    kb_service = KnowledgeBaseService()
    upload_status = kb_service.upload_by_str(data="周杰伦2", filename="testfile")
    print(upload_status)