LangChain Splitter Source Code Reading Notes (1): CharacterTextSplitter
I. TextSplitter
TextSplitter inherits from BaseDocumentTransformer. It is an abstract class and cannot be instantiated directly.
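Core (internal) attributes:
_chunk_size: the target size of each chunk
_chunk_overlap: the size of the overlap between adjacent chunks
_length_function: the function used to measure size; it can be a token-counting function or something simpler such as the built-in len()
_keep_separator: Boolean (the type hint also allows "start" or "end"), whether the separator is kept in the chunks after splitting
_add_start_index: Boolean, whether to record, in the metadata of each returned document, the index of the chunk's first character in the original document
_strip_whitespace: Boolean, whether to strip leading and trailing whitespace from each chunk after splitting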
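Because _length_function decides what "size" means, the same text can have very different sizes. A minimal sketch, assuming a hypothetical whitespace word counter (word_length) as a stand-in for a real tokenizer-based counter:

```python
def char_length(text: str) -> int:
    # The simplest choice: size is measured in characters
    return len(text)


def word_length(text: str) -> int:
    # Hypothetical stand-in for a token counter: counts whitespace-separated words
    return len(text.split())


sample = "LangChain splits long documents into chunks"
print(char_length(sample))  # 43
print(word_length(sample))  # 6
```

With a chunk size of 10, the character measure would force this sentence into several chunks, while the word measure would fit it into one.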
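Core methods:
split_text(self, text: str) -> list[str]
The splitting method. It is abstract and must be implemented in each concrete subclass according to that subclass's splitting algorithm.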
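The abstract-class pattern can be sketched with Python's abc module. MiniTextSplitter and LineSplitter below are hypothetical, heavily simplified stand-ins, not LangChain classes:

```python
from abc import ABC, abstractmethod


class MiniTextSplitter(ABC):
    """Hypothetical, simplified stand-in for TextSplitter."""

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Each subclass supplies its own splitting algorithm."""


class LineSplitter(MiniTextSplitter):
    # Hypothetical concrete subclass: one chunk per non-empty line
    def split_text(self, text: str) -> list[str]:
        return [line for line in text.splitlines() if line]


try:
    MiniTextSplitter()  # abstract method not implemented, so this raises TypeError
except TypeError as exc:
    print("cannot instantiate:", exc)

print(LineSplitter().split_text("first\n\nsecond"))  # ['first', 'second']
```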
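create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]
Takes a list of texts and, optionally, a matching list of metadata dicts. Each text is split with split_text and every chunk is wrapped in a Document: doc.page_content holds the chunk text and doc.metadata holds a copy of that text's metadata, plus the chunk's start index (under the key start_index) when _add_start_index is enabled.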
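A simplified sketch of the idea, assuming a hypothetical Doc dataclass in place of Document and a hypothetical fixed-window fixed_split in place of split_text. The start index is located with str.find, searching from slightly before the point where the previous chunk ended so that overlapping chunks map to the right positions; this illustrates the approach rather than reproducing LangChain's code:

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Doc:
    # Hypothetical stand-in for LangChain's Document
    page_content: str
    metadata: dict = field(default_factory=dict)


def fixed_split(text: str, size: int, overlap: int) -> list[str]:
    # Hypothetical splitter: fixed-size windows overlapping by `overlap` characters
    step = size - overlap
    return [text[i:i + size] for i in range(0, max(len(text) - overlap, 1), step)]


def create_documents(texts, metadatas=None, size=4, overlap=1, add_start_index=True):
    metadatas = metadatas or [{}] * len(texts)
    docs = []
    for text, meta in zip(texts, metadatas):
        index, prev_len = 0, 0
        for chunk in fixed_split(text, size, overlap):
            chunk_meta = copy.deepcopy(meta)  # each chunk gets its own metadata copy
            if add_start_index:
                # Start searching `overlap` characters before the previous chunk's end,
                # so a repeated substring is not matched at an earlier position
                index = text.find(chunk, max(0, index + prev_len - overlap))
                chunk_meta["start_index"] = index
                prev_len = len(chunk)
            docs.append(Doc(chunk, chunk_meta))
    return docs


for d in create_documents(["abcdefghij"], [{"source": "demo"}]):
    print(d)
```

The three chunks "abcd", "defg" and "ghij" get start indices 0, 3 and 6, and each carries its own copy of {"source": "demo"}.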
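split_documents(self, documents: Iterable[Document]) -> list[Document]
Splits a list of documents and returns the resulting list of chunk documents. Internally it collects the page_content and metadata of every input document and passes them to create_documents, which does the actual splitting and wrapping.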
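The delegation can be sketched as follows, again with a hypothetical Doc class and a hypothetical, simplified create_documents that just splits on blank lines:

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    # Hypothetical stand-in for LangChain's Document
    page_content: str
    metadata: dict = field(default_factory=dict)


def create_documents(texts: list[str], metadatas: list[dict]) -> list[Doc]:
    # Hypothetical simplified create_documents: split on blank lines, copy metadata
    return [
        Doc(chunk, dict(meta))
        for text, meta in zip(texts, metadatas)
        for chunk in text.split("\n\n")
        if chunk
    ]


def split_documents(documents: list[Doc]) -> list[Doc]:
    # Unpack the documents into parallel lists, then delegate to create_documents
    texts = [doc.page_content for doc in documents]
    metadatas = [doc.metadata for doc in documents]
    return create_documents(texts, metadatas)


docs = [Doc("a\n\nb", {"id": 1}), Doc("c", {"id": 2})]
print(split_documents(docs))
```

Each chunk keeps the metadata of the document it came from, so "a" and "b" both carry id 1.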
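-------- Internal methods below ---------
_join_docs(self, docs: list[str], separator: str) -> str | None
Note that docs here is a list of strings, not Document objects. The method joins the strings into one long string with the given separator, for use by _merge_splits below. If _strip_whitespace is set, the result is stripped, and an empty result is returned as None, which is why _merge_splits checks whether doc is None.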
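_merge_splits(self, splits: Iterable[str], separator: str) -> list[str]
Merges pieces that were split too finely into chunks close to (but normally not above) self._chunk_size, and carries up to self._chunk_overlap worth of content from the end of each chunk into the start of the next, so that adjacent chunks overlap.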
```python
import logging
from collections.abc import Iterable

logger = logging.getLogger(__name__)


# Method of TextSplitter (excerpt)
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
    # We now want to combine these smaller pieces into medium size
    # chunks to send to the LLM.
    separator_len = self._length_function(separator)

    docs = []
    current_doc: list[str] = []
    total = 0
    for d in splits:
        len_ = self._length_function(d)
        # By default d is simply appended to current_doc; only when the if below
        # fires is the accumulated chunk pushed into docs
        if (
            total + len_ + (separator_len if len(current_doc) > 0 else 0)
            > self._chunk_size
        ):
            if total > self._chunk_size:
                logger.warning(
                    "Created a chunk of size %d, which is longer than the "
                    "specified %d",
                    total,
                    self._chunk_size,
                )
            if len(current_doc) > 0:
                doc = self._join_docs(current_doc, separator)
                if doc is not None:
                    docs.append(doc)
                # Keep on popping if:
                # - we have a larger chunk than in the chunk overlap
                # - or if we still have any chunks and the length is long
                while total > self._chunk_overlap or (
                    total + len_ + (separator_len if len(current_doc) > 0 else 0)
                    > self._chunk_size
                    and total > 0
                ):
                    total -= self._length_function(current_doc[0]) + (
                        separator_len if len(current_doc) > 1 else 0
                    )
                    current_doc = current_doc[1:]
        current_doc.append(d)
        total += len_ + (separator_len if len(current_doc) > 1 else 0)
    doc = self._join_docs(current_doc, separator)
    if doc is not None:
        docs.append(doc)
    return docs
```
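The key idea of this method: whenever adding the next piece would push current_doc past chunk_size, it first joins the pieces in current_doc and appends the result to docs. Then, instead of simply clearing current_doc, it removes text units from the head of current_doc one at a time until the length of current_doc is no greater than _chunk_overlap (and the incoming piece fits within chunk_size). The text left in current_doc becomes the start of the new chunk, and is therefore the overlap between the two chunks.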
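To watch the overlap mechanism in isolation, here is a standalone sketch of the same loop with len() as the length function. It is simplified by assumption: no logging and no whitespace stripping, and the function name merge_splits is illustrative:

```python
def merge_splits(
    splits: list[str], separator: str, chunk_size: int, chunk_overlap: int
) -> list[str]:
    # Standalone copy of the _merge_splits loop, measuring size with len()
    sep_len = len(separator)
    docs: list[str] = []
    current: list[str] = []
    total = 0  # equals len(separator.join(current))
    for d in splits:
        if total + len(d) + (sep_len if current else 0) > chunk_size:
            if current:
                # Emit the chunk accumulated so far
                docs.append(separator.join(current))
                # Pop pieces from the head until what remains is no larger than
                # the overlap and the incoming piece d would fit next to it
                while total > chunk_overlap or (
                    total + len(d) + (sep_len if current else 0) > chunk_size
                    and total > 0
                ):
                    total -= len(current[0]) + (sep_len if len(current) > 1 else 0)
                    current = current[1:]
        current.append(d)
        total += len(d) + (sep_len if len(current) > 1 else 0)
    if current:
        docs.append(separator.join(current))
    return docs


print(merge_splits(["aaa", "bbb", "ccc", "ddd"], " ", chunk_size=7, chunk_overlap=3))
# ['aaa bbb', 'bbb ccc', 'ccc ddd']
```

After "aaa bbb" is emitted, total is 7 > 3, so "aaa" plus one separator (4 units) is popped; the remaining "bbb" (3 units) fits within the overlap and becomes the head of the next chunk. With chunk_overlap=0 the same input yields ['aaa bbb', 'ccc ddd'].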
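II. CharacterTextSplitter
This class inherits from TextSplitter above. It adds a separator attribute and an is_separator_regex attribute (whether the separator is a regular expression), and implements the parent's abstract method split_text.
Its split_text method calls the module-level helper function _split_text_with_regex() to split the incoming text. Here is the code first: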
```python
import re
from typing import Literal


# Inside the CharacterTextSplitter class
def split_text(self, text: str) -> list[str]:
    """Split into chunks without re-inserting lookaround separators."""
    # 1. Determine split pattern: raw regex or escaped literal
    sep_pattern = (
        self._separator if self._is_separator_regex else re.escape(self._separator)
    )

    # 2. Initial split (keep separator if requested)
    splits = _split_text_with_regex(
        text, sep_pattern, keep_separator=self._keep_separator
    )

    # 3. Detect zero-width lookaround so we never re-insert it
    lookaround_prefixes = ("(?=", "(?<!", "(?<=", "(?!")
    is_lookaround = self._is_separator_regex and any(
        self._separator.startswith(p) for p in lookaround_prefixes
    )

    # 4. Decide merge separator:
    #    - if keep_separator or lookaround -> don't re-insert
    #    - else -> re-insert literal separator
    merge_sep = ""
    if not (self._keep_separator or is_lookaround):
        merge_sep = self._separator

    # 5. Merge adjacent splits and return
    return self._merge_splits(splits, merge_sep)


# Module-level helper (outside the class)
def _split_text_with_regex(
    text: str, separator: str, *, keep_separator: bool | Literal["start", "end"]
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            splits_ = re.split(f"({separator})", text)
            splits = (
                ([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)])
                if keep_separator == "end"
                else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)])
            )
            if len(splits_) % 2 == 0:
                splits += splits_[-1:]
            splits = (
                ([*splits, splits_[-1]])
                if keep_separator == "end"
                else ([splits_[0], *splits])
            )
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s]
```
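Setting aside separator retention, split_text is actually simple: it splits text on the separator with re.split (dropping empty pieces), then calls the parent's _merge_splits() to combine the pieces into chunks of suitable size, and returns them as list[str].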
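The non-retaining path can be reproduced in a few lines. This simplified sketch (split_plain is an illustrative name) combines the escaping done in split_text with the splitting done in _split_text_with_regex; its output is the list that would then be handed to _merge_splits together with the separator itself:

```python
import re


def split_plain(text: str, separator: str) -> list[str]:
    # The keep_separator=False path, for a literal (non-regex) separator
    if separator:
        splits = re.split(re.escape(separator), text)
    else:
        # An empty separator means: split into individual characters
        splits = list(text)
    # Empty strings (from doubled or trailing separators) are dropped
    return [s for s in splits if s]


print(split_plain("para one\n\npara two\n\n\n\npara three\n\n", "\n\n"))
# ['para one', 'para two', 'para three']
print(split_plain("abc", ""))
# ['a', 'b', 'c']
```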
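1. Pre-split processing
If the separator is a plain string (is_separator_regex is False), it has to be escaped with re.escape before being passed to re.split, so that regex metacharacters in it, such as . or |, are matched literally instead of being interpreted as pattern syntax.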
```python
# 1. Determine split pattern: raw regex or escaped literal
sep_pattern = (
    self._separator if self._is_separator_regex else re.escape(self._separator)
)
```
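A quick illustration of why the escape matters, using a dot as the separator (the sample string is arbitrary):

```python
import re

text = "v1.2.3"
# Unescaped, "." is a regex wildcard that matches every character
print(re.split(".", text))  # ['', '', '', '', '', '', '']
# Escaped, "." only matches a literal dot
print(re.split(re.escape("."), text))  # ['v1', '2', '3']
```

Every character matches the unescaped pattern, so all pieces are empty and the final if s filter in _split_text_with_regex would discard them all; the escaped pattern splits only at literal dots.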
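2. Keeping the separator
When keep_separator is truthy, the pattern is wrapped in a capturing group, so re.split also returns each matched separator as its own item at the odd indices. The pairing below attaches each separator to the piece before it ("end") or to the piece after it ("start", and also plain True), then adds back the last or first piece, which has no partner:
```python
splits = (
    ([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)])
    if keep_separator == "end"
    else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)])
)
if len(splits_) % 2 == 0:
    splits += splits_[-1:]
splits = (
    ([*splits, splits_[-1]])
    if keep_separator == "end"
    else ([splits_[0], *splits])
)
```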
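3. Choosing the merge separator
When merging, the separator is re-inserted between pieces only if it is not already kept inside them and is not a zero-width lookaround pattern; otherwise the pieces are joined with an empty string:
```python
merge_sep = ""
if not (self._keep_separator or is_lookaround):
    merge_sep = self._separator
```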