NumPy memmap: A Detailed Guide to Memory-Mapped Arrays

1. memmap Core Concepts

1.1 What Is a Memory-Mapped File?

A memory-mapped file maps a file on disk directly into the process's virtual address space, so the file can be read and written as if it were an in-memory array. Pages are loaded from disk only when they are actually accessed.

1.2 How It Works

Disk file  ↔  Virtual memory  ↔  Program access
    ↑              ↑
Physical disk   Physical RAM (pages loaded on demand)
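
A minimal sketch of the on-demand behavior (the file name lazy_demo.dat is arbitrary): creating a roughly 1 GB mapping returns almost immediately, because no data is read until individual pages are touched.

import numpy as np

# Creating a ~1 GB mapping is nearly instant: the kernel only sets up
# the mapping; nothing is read from disk yet.
big = np.memmap('lazy_demo.dat', dtype=np.float32, mode='w+',
                shape=(256, 1024, 1024))

# Touching a single element faults in only the page that contains it.
print(big[0, 0, 0])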

2. Creating and Initializing

2.1 Basic Creation

import numpy as np
import os

# Method 1: create a new file
shape = (1000, 1000)
dtype = np.float32

# 'w+' mode: create a new file, or overwrite an existing one
mmap = np.memmap('data.dat', dtype=dtype, mode='w+', shape=shape)

# Initialize the data
mmap[:] = np.random.randn(*shape)
mmap.flush()  # make sure the data is written to disk

2.2 Opening an Existing File

# Method 2: map an existing file (shape and dtype must be known)
# Suppose 'existing.dat' holds a (500, 500) float32 array
mmap = np.memmap('existing.dat', dtype=np.float32, mode='r+', shape=(500, 500))

# Method 3: infer the size from the file (dtype must still be known)
file_size = os.path.getsize('existing.dat')
itemsize = np.dtype(np.float32).itemsize
shape = (file_size // itemsize,)  # one-dimensional array
mmap = np.memmap('existing.dat', dtype=np.float32, mode='r', shape=shape)

2.3 The offset Parameter

# Start the mapping at a specific position in the file,
# skipping the first 1024 bytes (e.g. a file header)
mmap = np.memmap('data.bin',
                 dtype=np.float64,
                 mode='r+',
                 offset=1024,  # byte offset
                 shape=(100, 100))

# Computing the offset from a known header size
header_size = 128  # bytes
data_shape = (1000, 100)
element_size = np.dtype(np.float32).itemsize  # 4 bytes

mmap = np.memmap('file_with_header.bin',
                 dtype=np.float32,
                 mode='r+',
                 offset=header_size,
                 shape=data_shape)

3. Access Modes in Detail

3.1 The Four Modes Compared

# Exercising the different modes
filename = 'test.dat'
shape = (10, 10)
dtype = np.float32

# 1. Read-only mode 'r' (the file must already exist)
if os.path.exists(filename):
    mmap_r = np.memmap(filename, dtype=dtype, mode='r', shape=shape)
    # Reading works
    value = mmap_r[0, 0]
    # mmap_r[0, 0] = 1.0  # error! a read-only mapping cannot be modified

# 2. Read-write mode 'r+' (the file must already exist)
mmap_rplus = np.memmap(filename, dtype=dtype, mode='r+', shape=shape)
mmap_rplus[0, 0] = 42.0  # writable
value = mmap_rplus[0, 0]  # readable

# 3. Write-read mode 'w+'
# Caution: this truncates the file!
mmap_wplus = np.memmap(filename, dtype=dtype, mode='w+', shape=shape)
mmap_wplus[:] = 0  # reinitialize after the file is wiped

# 4. Copy-on-write mode 'c'
mmap_c = np.memmap(filename, dtype=dtype, mode='c', shape=shape)
mmap_c[0, 0] = 100  # the change exists only in memory
# Note: in 'c' mode, assignments are never written back to the file;
# the file on disk stays read-only, and flush() does not persist them.
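
A quick way to confirm the copy-on-write semantics (a small sketch, reusing test.dat from above): reopen the file read-only and check that the in-memory change never reached disk.

# Reopen read-only and verify that the 'c'-mode write was not persisted.
check = np.memmap(filename, dtype=dtype, mode='r', shape=shape)
print(mmap_c[0, 0])   # 100.0 (the private in-memory copy)
print(check[0, 0])    # 0.0   (the value on disk is unchanged)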

3.2 Choosing a Mode

def create_or_load_memmap(filename, shape, dtype, force_create=False):
    """
    Create a new memmap or load an existing one, whichever is appropriate.
    """
    if force_create or not os.path.exists(filename):
        print(f"Creating new file: {filename}")
        mmap = np.memmap(filename, dtype=dtype, mode='w+', shape=shape)
        # Initialize
        mmap[:] = 0
        mmap.flush()
        return mmap
    else:
        # Check that the file size matches the requested shape
        expected_size = np.prod(shape) * np.dtype(dtype).itemsize
        actual_size = os.path.getsize(filename)
        
        if actual_size == expected_size:
            print(f"Loading existing file: {filename}")
            return np.memmap(filename, dtype=dtype, mode='r+', shape=shape)
        else:
            print("File size mismatch; recreating")
            return np.memmap(filename, dtype=dtype, mode='w+', shape=shape)

4. Efficient Access Patterns

4.1 Sequential vs. Random Access

# Create test data
shape = (10000, 1000)
mmap = np.memmap('large_data.dat', dtype=np.float32, mode='w+', shape=shape)
mmap[:] = np.random.randn(*shape)
mmap.flush()

# 1. Sequential access (efficient)
def sequential_access(mmap):
    """Visit rows in order."""
    results = []
    for i in range(mmap.shape[0]):
        row_sum = mmap[i].sum()  # read one full row at a time
        results.append(row_sum)
    return results

# 2. Chunked access (more efficient)
def chunked_access(mmap, chunk_size=1000):
    """Visit the data one block at a time."""
    results = []
    n_rows = mmap.shape[0]
    
    for start in range(0, n_rows, chunk_size):
        end = min(start + chunk_size, n_rows)
        chunk = mmap[start:end]  # read a whole block at once
        chunk_results = chunk.sum(axis=1)
        results.extend(chunk_results)
    
    return results

# 3. Random access (inefficient; avoid when possible)
def random_access(mmap, indices):
    """Visit arbitrary rows."""
    results = []
    for idx in indices:
        row = mmap[idx]  # each access may trigger disk I/O
        results.append(row.sum())
    return results
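
A small benchmark sketch comparing the three patterns on the large_data.dat array created above. Timings are only indicative: once pages sit in the OS page cache, repeated runs will be much faster than a cold read.

import time

def timed(fn, *args):
    # Simple wall-clock timer around one call.
    t0 = time.perf_counter()
    fn(*args)
    return time.perf_counter() - t0

print("sequential:", timed(sequential_access, mmap))
print("chunked:   ", timed(chunked_access, mmap))
rng = np.random.default_rng(0)
print("random:    ", timed(random_access, mmap,
                           rng.integers(0, shape[0], size=1000)))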

4.2 Reusing a Single Mapping

# Keep one mapping alive instead of recreating memmap objects repeatedly
class MemmapManager:
    def __init__(self, filename, shape, dtype):
        self.filename = filename
        self.shape = shape
        self.dtype = dtype
        self._mmap = None
    
    @property
    def mmap(self):
        if self._mmap is None:
            self._mmap = np.memmap(self.filename, 
                                  dtype=self.dtype,
                                  mode='r+',
                                  shape=self.shape)
        return self._mmap
    
    def close(self):
        if self._mmap is not None:
            # Make sure pending changes reach the disk
            self._mmap.flush()
            if hasattr(self._mmap, '_mmap'):
                # _mmap is the private attribute holding the underlying
                # mmap.mmap object; closing it releases the mapping
                self._mmap._mmap.close()
            self._mmap = None
    
    def __enter__(self):
        return self.mmap
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

# Usage
with MemmapManager('data.dat', (1000, 1000), np.float32) as mmap:
    # use the mapping inside the with block
    result = mmap[100:200].mean()

5. Multiprocess Sharing

5.1 Basic Multiprocess Sharing

import numpy as np
from multiprocessing import Pool
import os

def worker_process(args):
    """Worker function: each process maps its own view of the file."""
    pid, filename, shape, dtype, start_row, end_row = args
    
    # Each process creates its own memmap view
    mmap = np.memmap(filename, dtype=dtype, mode='r+', shape=shape)
    
    results = []
    for i in range(start_row, end_row):
        # Process the data
        row_mean = mmap[i].mean()
        results.append((i, row_mean))
    
    # No explicit close needed; the mapping is released with the process
    return results

def parallel_process(filename, shape, dtype, n_processes=4):
    """Process a memmap file in parallel."""
    n_rows = shape[0]
    chunk_size = n_rows // n_processes
    
    # Build the argument list
    args_list = []
    for i in range(n_processes):
        start = i * chunk_size
        end = start + chunk_size if i < n_processes - 1 else n_rows
        args_list.append((i, filename, shape, dtype, start, end))
    
    # Run in parallel
    with Pool(processes=n_processes) as pool:
        all_results = pool.map(worker_process, args_list)
    
    # Merge the per-process results
    final_results = []
    for results in all_results:
        final_results.extend(results)
    
    return final_results
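
A usage sketch (file name and shape are illustrative). The `if __name__ == '__main__':` guard matters here: on platforms that spawn worker processes (Windows, and macOS by default), Pool re-imports the module, and unguarded top-level code would run again in every worker.

if __name__ == '__main__':
    shape = (10000, 100)
    data = np.memmap('parallel.dat', dtype=np.float32, mode='w+', shape=shape)
    data[:] = np.random.randn(*shape)
    data.flush()

    row_means = parallel_process('parallel.dat', shape, np.float32)
    print(len(row_means), row_means[:3])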

5.2 Using Shared Memory (more efficient)

from multiprocessing import shared_memory
import numpy as np
import os

class SharedMemmap:
    """Wrapper that stages a memmap's contents in shared memory."""
    def __init__(self, filename, shape, dtype):
        self.filename = filename
        self.shape = shape
        self.dtype = dtype
        
        # Create the shared-memory block (the fixed name means only one
        # instance can exist at a time; creation fails if it already exists)
        self.shm = shared_memory.SharedMemory(
            name='memmap_shared',
            create=True,
            size=int(np.prod(shape)) * np.dtype(dtype).itemsize
        )
        
        # NumPy array view over the shared buffer
        self.array = np.ndarray(
            shape, 
            dtype=dtype, 
            buffer=self.shm.buf
        )
        
        # Load initial contents from the file, if it exists
        if os.path.exists(filename):
            mmap = np.memmap(filename, dtype=dtype, mode='r', shape=shape)
            self.array[:] = mmap[:]
    
    def save(self):
        """Write the shared array back to the file."""
        mmap = np.memmap(self.filename, 
                         dtype=self.dtype, 
                         mode='w+', 
                         shape=self.shape)
        mmap[:] = self.array[:]
        mmap.flush()
    
    def close(self):
        """Persist and release the shared-memory block."""
        self.save()
        self.shm.close()
        self.shm.unlink()

6. Practical Applications

6.1 Large Matrix Operations

class LargeMatrix:
    """Work with matrices too large to fit in memory."""
    def __init__(self, filename, shape, dtype=np.float64):
        self.filename = filename
        self.shape = shape
        self.dtype = dtype
        
        # Create the backing file, or validate an existing one
        self._init_file()
        
    def _init_file(self):
        """Initialize the backing file."""
        expected_size = np.prod(self.shape) * np.dtype(self.dtype).itemsize
        
        if not os.path.exists(self.filename):
            # Create a new file
            mmap = np.memmap(self.filename, 
                           dtype=self.dtype, 
                           mode='w+', 
                           shape=self.shape)
            mmap[:] = 0
            mmap.flush()
        else:
            # Validate the file size
            actual_size = os.path.getsize(self.filename)
            if actual_size != expected_size:
                raise ValueError(f"File size mismatch: {actual_size} != {expected_size}")
    
    def matrix_multiply(self, other, chunk_size=1000):
        """Blocked matrix multiplication."""
        if self.shape[1] != other.shape[0]:
            raise ValueError("Incompatible matrix dimensions")
        
        result_shape = (self.shape[0], other.shape[1])
        result_file = 'result.dat'
        
        # Create the output file
        result_mmap = np.memmap(result_file, 
                               dtype=self.dtype,
                               mode='w+',
                               shape=result_shape)
        
        # Blocked multiplication
        for i in range(0, self.shape[0], chunk_size):
            i_end = min(i + chunk_size, self.shape[0])
            
            # Read one block of rows of A
            A_chunk = np.memmap(self.filename,
                               dtype=self.dtype,
                               mode='r',
                               shape=self.shape)[i:i_end]
            
            for j in range(0, other.shape[1], chunk_size):
                j_end = min(j + chunk_size, other.shape[1])
                
                # Read one block of columns of B
                B_chunk = np.memmap(other.filename,
                                   dtype=self.dtype,
                                   mode='r',
                                   shape=other.shape)[:, j:j_end]
                
                # Compute and write the result block
                result_mmap[i:i_end, j:j_end] = A_chunk @ B_chunk
        
        result_mmap.flush()
        return LargeMatrix(result_file, result_shape, self.dtype)
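
A usage sketch under the assumptions above (note that the class hard-codes result.dat as the output path and overwrites it on every call):

# Create two zero-initialized backing files and multiply them in blocks.
A = LargeMatrix('A.dat', (2000, 500))
B = LargeMatrix('B.dat', (500, 300))
C = A.matrix_multiply(B, chunk_size=256)
print(C.shape)  # (2000, 300)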

6.2 Time-Series Storage

class TimeSeriesStorage:
    """Fixed-capacity time-series storage (a ring buffer on disk)."""
    def __init__(self, filename, max_points, n_features, dtype=np.float32):
        self.filename = filename
        self.max_points = max_points
        self.n_features = n_features
        self.dtype = dtype
        self.shape = (max_points, n_features)
        
        self.current_idx = 0
        self.is_full = False
        
        # Set up the backing store
        self._init_storage()
    
    def _init_storage(self):
        """Initialize the storage file."""
        if not os.path.exists(self.filename):
            self.mmap = np.memmap(self.filename,
                                 dtype=self.dtype,
                                 mode='w+',
                                 shape=self.shape)
            self.mmap[:] = np.nan
        else:
            self.mmap = np.memmap(self.filename,
                                 dtype=self.dtype,
                                 mode='r+',
                                 shape=self.shape)
            
            # Find the first unused slot (NaN marks empty rows)
            for i in range(self.max_points):
                if np.isnan(self.mmap[i, 0]):
                    self.current_idx = i
                    break
            else:
                self.current_idx = self.max_points
                self.is_full = True
    
    def append(self, data):
        """Append new samples."""
        n_points = data.shape[0]
        
        if self.current_idx + n_points > self.max_points:
            # Ring buffer: wrap around and overwrite the oldest data
            remaining = self.max_points - self.current_idx
            self.mmap[self.current_idx:] = data[:remaining]
            self.mmap[:n_points - remaining] = data[remaining:]
            self.current_idx = n_points - remaining
            self.is_full = True
        else:
            self.mmap[self.current_idx:self.current_idx + n_points] = data
            self.current_idx += n_points
        
        self.mmap.flush()
    
    def get_recent(self, n_points):
        """Return the most recent n_points samples."""
        if not self.is_full and self.current_idx < n_points:
            return self.mmap[:self.current_idx].copy()
        
        start_idx = self.current_idx - n_points
        if start_idx < 0:
            start_idx += self.max_points
            return np.vstack([
                self.mmap[start_idx:],
                self.mmap[:self.current_idx]
            ])
        else:
            return self.mmap[start_idx:self.current_idx].copy()
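
A brief usage sketch (illustrative file name; this assumes each append is at most max_points rows, which the class does not check):

ts = TimeSeriesStorage('ticks.dat', max_points=1000, n_features=4)
ts.append(np.random.randn(300, 4).astype(np.float32))
ts.append(np.random.randn(900, 4).astype(np.float32))  # wraps around
recent = ts.get_recent(100)
print(recent.shape)  # (100, 4)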

7. Performance Tuning Tips

7.1 Preallocating Space

def create_preallocated_memmap(filename, shape, dtype, fill_value=0):
    """Create a memmap backed by a preallocated file."""
    # Total size in bytes
    total_size = np.prod(shape) * np.dtype(dtype).itemsize
    
    # Preallocate the file (a fast way to create a large file)
    with open(filename, 'wb') as f:
        f.seek(total_size - 1)
        f.write(b'\x00')
    
    # Create the memmap over the preallocated file
    mmap = np.memmap(filename, dtype=dtype, mode='r+', shape=shape)
    
    # Fill with the initial value if needed
    if fill_value != 0:
        chunk_size = 10000
        total_elements = np.prod(shape)
        
        for start in range(0, total_elements, chunk_size):
            end = min(start + chunk_size, total_elements)
            mmap.flat[start:end] = fill_value
    
    mmap.flush()
    return mmap
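
An equivalent, slightly more direct sketch using file truncation: truncate() extends the file with zero bytes (typically as a sparse file on Linux/macOS), so nothing needs to be written at all.

def create_preallocated_memmap_v2(filename, shape, dtype):
    # truncate() grows the file to the requested length, zero-filled,
    # often without allocating disk blocks up front.
    total_size = int(np.prod(shape)) * np.dtype(dtype).itemsize
    with open(filename, 'wb') as f:
        f.truncate(total_size)
    return np.memmap(filename, dtype=dtype, mode='r+', shape=shape)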

7.2 Adding an In-Memory Cache

class CachedMemmap:
    """memmap with a simple FIFO row cache."""
    def __init__(self, filename, shape, dtype):
        self.filename = filename
        self.shape = shape
        self.dtype = dtype
        self.cache = {}
        self.cache_size = 1000  # cache up to 1000 rows
        # One persistent read-only mapping, reused across reads
        self._mmap = np.memmap(filename, dtype=dtype, mode='r', shape=shape)
    
    def get_row(self, row_idx):
        """Fetch a row, serving repeated reads from the cache."""
        if row_idx in self.cache:
            return self.cache[row_idx]
        
        # Read from the mapping (may hit disk)
        row_data = self._mmap[row_idx].copy()
        
        # Evict the oldest entry when the cache is full
        if len(self.cache) >= self.cache_size:
            oldest_key = next(iter(self.cache))
            del self.cache[oldest_key]
        
        self.cache[row_idx] = row_data
        return row_data
    
    def invalidate_cache(self):
        """Drop all cached rows."""
        self.cache.clear()

8. Troubleshooting Common Issues

8.1 File Size Mismatch

def safe_memmap_load(filename, expected_shape, dtype):
    """Load a memmap safely, handling size mismatches."""
    if not os.path.exists(filename):
        raise FileNotFoundError(f"File not found: {filename}")
    
    expected_size = np.prod(expected_shape) * np.dtype(dtype).itemsize
    actual_size = os.path.getsize(filename)
    
    if actual_size < expected_size:
        # File too small: pad it with zero bytes
        print(f"Extending file from {actual_size} to {expected_size} bytes")
        with open(filename, 'ab') as f:
            f.write(b'\x00' * (expected_size - actual_size))
    elif actual_size > expected_size:
        # File too large: warn, or optionally truncate
        print(f"Warning: file size ({actual_size}) exceeds expected size ({expected_size})")
        # Optionally truncate:
        # with open(filename, 'r+b') as f:
        #     f.truncate(expected_size)
    
    return np.memmap(filename, dtype=dtype, mode='r+', shape=expected_shape)

8.2 Handling Limited Memory

def process_large_file_in_chunks(filename, chunk_size_mb=100):
    """Process a huge file in fixed-size chunks."""
    dtype = np.float64
    itemsize = np.dtype(dtype).itemsize  # 8 bytes
    
    # Total file size
    total_size = os.path.getsize(filename)
    total_elements = total_size // itemsize
    
    # Chunk size in elements
    chunk_elements = (chunk_size_mb * 1024 * 1024) // itemsize
    
    results = []
    for start in range(0, total_elements, chunk_elements):
        end = min(start + chunk_elements, total_elements)
        shape = (end - start,)
        
        # Map only the current chunk
        mmap = np.memmap(filename, 
                        dtype=dtype,
                        mode='r',
                        offset=start * itemsize,  # key: seek to the right position
                        shape=shape)
        
        # Process the current chunk (process_chunk is user-defined)
        chunk_result = process_chunk(mmap)
        results.append(chunk_result)
        
        # Release the mapping explicitly
        del mmap
    
    return combine_results(results)
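
process_chunk and combine_results are left to the caller; a minimal hypothetical pair that computes the overall sum of the file might look like this:

def process_chunk(chunk):
    # Reduce one chunk to a single value (here: its sum).
    return float(chunk.sum())

def combine_results(results):
    # Fold the per-chunk values into the final answer.
    return sum(results)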

8.3 Converting Data Types

def convert_memmap_dtype(input_file, output_file, 
                         input_dtype, output_dtype, 
                         shape):
    """Convert a memmap file from one dtype to another."""
    # Input file (read-only)
    input_mmap = np.memmap(input_file, 
                          dtype=input_dtype,
                          mode='r',
                          shape=shape)
    
    # Output file (newly created)
    output_mmap = np.memmap(output_file,
                           dtype=output_dtype,
                           mode='w+',
                           shape=shape)
    
    # Convert block by block
    chunk_size = 10000
    for i in range(0, shape[0], chunk_size):
        end = min(i + chunk_size, shape[0])
        chunk = input_mmap[i:end].astype(output_dtype)
        output_mmap[i:end] = chunk
    
    output_mmap.flush()
    return output_file

9. Advanced Scenarios

9.1 Database-Style Indexed Access

class IndexedMemmap:
    """memmap with an index for fast record lookup."""
    def __init__(self, data_file, index_file, dtype=np.float32):
        self.data_file = data_file
        self.index_file = index_file
        self.dtype = dtype
        
        # Load the index if present: an (n_records, 2) array of
        # (start offset, record length) pairs
        self.index = np.load(index_file) if os.path.exists(index_file) else None
        self.data_mmap = None
    
    def build_index(self, data_shape, record_sizes):
        """Build the index from a list of record lengths."""
        n_records = len(record_sizes)
        self.index = np.zeros((n_records, 2), dtype=np.int64)
        
        offset = 0
        for i, size in enumerate(record_sizes):
            self.index[i, 0] = offset  # start offset
            self.index[i, 1] = size    # record length
            offset += size
        
        # Persist the index
        np.save(self.index_file, self.index)
        
        # Create the data file
        total_elements = offset
        self.data_mmap = np.memmap(self.data_file,
                                  dtype=self.dtype,
                                  mode='w+',
                                  shape=(total_elements,))
    
    def get_record(self, record_id):
        """Fetch one record."""
        if self.data_mmap is None:
            self.data_mmap = np.memmap(self.data_file,
                                      dtype=self.dtype,
                                      mode='r',
                                      shape=(self.index[-1, 0] + self.index[-1, 1],))
        
        start, length = self.index[record_id]
        return self.data_mmap[start:start + length].copy()
    
    def update_record(self, record_id, data):
        """Overwrite one record in place."""
        if self.data_mmap is None:
            self.data_mmap = np.memmap(self.data_file,
                                      dtype=self.dtype,
                                      mode='r+',
                                      shape=(self.index[-1, 0] + self.index[-1, 1],))
        
        start, length = self.index[record_id]
        if len(data) != length:
            raise ValueError(f"Record length mismatch: {len(data)} != {length}")
        
        self.data_mmap[start:start + length] = data
        self.data_mmap.flush()
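
A brief usage sketch with illustrative file names (the index name carries an explicit .npy extension so np.save and np.load agree on the path; data_shape is unused by build_index):

idx = IndexedMemmap('records.dat', 'records_index.npy')
idx.build_index(None, record_sizes=[3, 5, 2])  # data_shape unused here
idx.update_record(1, np.arange(5, dtype=np.float32))
print(idx.get_record(1))  # [0. 1. 2. 3. 4.]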

9.2 Real-Time Streaming

class StreamingMemmapWriter:
    """Stream samples into a memmap-backed ring buffer."""
    def __init__(self, filename, max_samples, sample_shape, dtype=np.float32):
        self.filename = filename
        self.max_samples = max_samples
        self.sample_shape = sample_shape
        self.dtype = dtype
        
        # Full array shape
        self.total_shape = (max_samples,) + sample_shape
        
        # Preallocate the backing file
        self._preallocate()
        
        # Current write position
        self.current_pos = 0
        self.is_full = False
    
    def _preallocate(self):
        """Preallocate the file."""
        total_size = np.prod(self.total_shape) * np.dtype(self.dtype).itemsize
        
        if not os.path.exists(self.filename):
            with open(self.filename, 'wb') as f:
                f.seek(total_size - 1)
                f.write(b'\x00')
        
        # Create the memmap
        self.mmap = np.memmap(self.filename,
                             dtype=self.dtype,
                             mode='r+',
                             shape=self.total_shape)
    
    def write(self, data):
        """Write a batch of samples."""
        n_samples = data.shape[0]
        
        if self.current_pos + n_samples > self.max_samples:
            # Wrap around: overwrite the oldest samples
            remaining = self.max_samples - self.current_pos
            self.mmap[self.current_pos:] = data[:remaining]
            self.mmap[:n_samples - remaining] = data[remaining:]
            self.current_pos = n_samples - remaining
            self.is_full = True
        else:
            self.mmap[self.current_pos:self.current_pos + n_samples] = data
            self.current_pos += n_samples
        
        # Flush synchronously here; a real application could flush
        # from a background thread instead
        self.mmap.flush()
    
    def get_latest(self, n_samples):
        """Return the latest n_samples samples."""
        if not self.is_full and self.current_pos < n_samples:
            return self.mmap[:self.current_pos].copy()
        
        start = self.current_pos - n_samples
        if start < 0:
            start += self.max_samples
            return np.concatenate([
                self.mmap[start:],
                self.mmap[:self.current_pos]
            ], axis=0)
        else:
            return self.mmap[start:self.current_pos].copy()

9.3 Combining memmap with the GPU

import torch

class GPUMemmapLoader:
    """Load memmap data onto the GPU."""
    def __init__(self, filename, shape, dtype=np.float32, 
                 device='cuda', chunk_size=1024):
        self.filename = filename
        self.shape = shape
        self.dtype = dtype
        self.device = device
        self.chunk_size = chunk_size
        
        # CPU-side memmap
        self.cpu_mmap = np.memmap(filename,
                                 dtype=dtype,
                                 mode='r',
                                 shape=shape)
        
        # Matching torch dtype
        self.torch_dtype = self._get_torch_dtype(dtype)
    
    def _get_torch_dtype(self, np_dtype):
        """Map a NumPy dtype (type object) to a torch dtype."""
        dtype_map = {
            np.float32: torch.float32,
            np.float64: torch.float64,
            np.int32: torch.int32,
            np.int64: torch.int64,
        }
        return dtype_map.get(np_dtype, torch.float32)
    
    def load_to_gpu(self, indices=None):
        """Transfer data to the GPU."""
        if indices is None:
            # Load everything, chunk by chunk
            return self._load_chunked()
        else:
            # Load only the requested rows
            return self._load_indices(indices)
    
    def _load_chunked(self):
        """Chunked transfer to the GPU."""
        n_samples = self.shape[0]
        gpu_tensors = []
        
        for start in range(0, n_samples, self.chunk_size):
            end = min(start + self.chunk_size, n_samples)
            
            # Read the CPU-side data
            cpu_chunk = self.cpu_mmap[start:end].copy()
            
            # Move it to the GPU
            gpu_chunk = torch.from_numpy(cpu_chunk).to(
                device=self.device,
                dtype=self.torch_dtype
            )
            gpu_tensors.append(gpu_chunk)
        
        return torch.cat(gpu_tensors, dim=0)
    
    def _load_indices(self, indices):
        """Transfer the rows at the given indices to the GPU."""
        # Gather the requested rows in chunks
        data_chunks = []
        current_chunk = []
        
        for idx in sorted(indices):
            current_chunk.append(idx)
            
            if len(current_chunk) >= self.chunk_size:
                # Read one chunk of rows
                chunk_indices = np.array(current_chunk)
                cpu_data = self.cpu_mmap[chunk_indices].copy()
                
                # Move it to the GPU
                gpu_data = torch.from_numpy(cpu_data).to(
                    device=self.device,
                    dtype=self.torch_dtype
                )
                data_chunks.append(gpu_data)
                current_chunk = []
        
        # Handle the remainder
        if current_chunk:
            chunk_indices = np.array(current_chunk)
            cpu_data = self.cpu_mmap[chunk_indices].copy()
            gpu_data = torch.from_numpy(cpu_data).to(
                device=self.device,
                dtype=self.torch_dtype
            )
            data_chunks.append(gpu_data)
        
        return torch.cat(data_chunks, dim=0) if data_chunks else None
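
A usage sketch (assumes a file features.dat was written elsewhere with the matching shape and dtype; falls back to the CPU when no CUDA device is available):

device = 'cuda' if torch.cuda.is_available() else 'cpu'
loader = GPUMemmapLoader('features.dat', shape=(100000, 128),
                         dtype=np.float32, device=device)

# Full dataset, transferred in chunks
all_data = loader.load_to_gpu()

# Just a minibatch of rows
batch = loader.load_to_gpu(indices=[3, 17, 42, 99])
print(batch.shape, batch.device)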

10. Best Practices Summary

10.1 Performance Checklist

class OptimizedMemmap:
    """Optimized memmap usage patterns."""
    
    @staticmethod
    def create_optimized(filename, shape, dtype, 
                        order='C',  # row-major, good for sequential access
                        preallocate=True):
        """Create a memmap with sensible defaults."""
        
        # 1. Pick an appropriate data type
        if dtype is None:
            dtype = np.float32  # float32 by default to save space
        
        # 2. Preallocate the file
        if preallocate:
            total_size = np.prod(shape) * np.dtype(dtype).itemsize
            if not os.path.exists(filename):
                with open(filename, 'wb') as f:
                    f.seek(total_size - 1)
                    f.write(b'\x00')
        
        # 3. Create the memmap ('r+' reuses the preallocated file;
        #    'w+' creates it from scratch)
        mmap = np.memmap(filename,
                        dtype=dtype,
                        mode='r+' if preallocate else 'w+',
                        shape=shape,
                        order=order)
        
        return mmap
    
    @staticmethod
    def efficient_access(mmap, access_pattern='sequential', 
                        chunk_size=None):
        """Yield chunks according to an access strategy."""
        
        if chunk_size is None:
            # Derive a chunk size (in rows) from available memory
            import psutil
            available_memory = psutil.virtual_memory().available
            element_size = mmap.dtype.itemsize
            row_elements = int(np.prod(mmap.shape[1:])) or 1
            budget_elements = (available_memory // 4) // element_size  # use 1/4 of free RAM
            chunk_size = max(1, budget_elements // row_elements)
        
        if access_pattern == 'sequential':
            # Sequential sweep
            for i in range(0, mmap.shape[0], chunk_size):
                chunk = mmap[i:i+chunk_size]
                yield chunk
        
        elif access_pattern == 'strided':
            # Strided sweep (useful for convolution-like operations)
            stride = 2  # example stride
            for i in range(0, mmap.shape[0] - chunk_size + 1, stride):
                chunk = mmap[i:i+chunk_size]
                yield chunk
    
    @staticmethod
    def memory_usage_info(mmap):
        """Print memory-usage statistics."""
        import psutil
        import os
        
        process = psutil.Process(os.getpid())
        memory_info = process.memory_info()
        
        print(f"Process RSS: {memory_info.rss / 1024**2:.2f} MB")
        print(f"File size: {os.path.getsize(mmap.filename) / 1024**2:.2f} MB")
        print(f"Array shape: {mmap.shape}")
        print(f"dtype: {mmap.dtype}")
        print(f"Total elements: {np.prod(mmap.shape):,}")
        print(f"Theoretical footprint: {np.prod(mmap.shape) * mmap.dtype.itemsize / 1024**2:.2f} MB")

10.2 Error-Handling Template

def safe_memmap_operation(func):
    """Decorator that guards memmap operations."""
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
            return result
        except (ValueError, OSError, MemoryError) as e:
            print(f"memmap operation failed: {e}")
            
            # Try to release resources
            for arg in args:
                if isinstance(arg, np.memmap):
                    try:
                        arg.flush()
                        if hasattr(arg, '_mmap'):
                            arg._mmap.close()
                    except Exception:
                        pass
            
            # React according to the error type
            if isinstance(e, MemoryError):
                print("Out of memory; try a smaller chunk size")
                # A fallback strategy could go here
                return None
            elif isinstance(e, OSError):
                print("Filesystem error; check disk space and permissions")
                return None
            else:
                raise
    
    return wrapper

# Usage
@safe_memmap_operation
def process_large_dataset(filename, shape, dtype):
    mmap = np.memmap(filename, dtype=dtype, mode='r', shape=shape)
    result = mmap.mean()  # placeholder for the real processing logic
    return result

10.3 Monitoring and Debugging

import time
from contextlib import contextmanager

class MemmapMonitor:
    """Track memmap operations and memory usage."""
    
    def __init__(self):
        self.operations = []
        self.memory_snapshots = []
    
    @contextmanager
    def track_operation(self, operation, filename, shape, dtype):
        """Track one operation (used as a context manager)."""
        start_time = time.time()
        
        # Memory snapshot before the operation
        self._take_memory_snapshot('before_' + operation)
        
        try:
            # Run the wrapped block
            yield
            
            # Record success
            duration = time.time() - start_time
            self.operations.append({
                'operation': operation,
                'filename': filename,
                'shape': shape,
                'dtype': str(dtype),
                'duration': duration,
                'status': 'success'
            })
            
        except Exception as e:
            # Record failure
            duration = time.time() - start_time
            self.operations.append({
                'operation': operation,
                'filename': filename,
                'shape': shape,
                'dtype': str(dtype),
                'duration': duration,
                'status': 'failed',
                'error': str(e)
            })
            raise
        
        finally:
            # Final memory snapshot
            self._take_memory_snapshot('after_' + operation)
    
    def _take_memory_snapshot(self, label):
        """Record a memory snapshot."""
        import psutil
        import os
        
        process = psutil.Process(os.getpid())
        memory_info = process.memory_info()
        
        self.memory_snapshots.append({
            'label': label,
            'timestamp': time.time(),
            'rss_mb': memory_info.rss / 1024**2,
            'vms_mb': memory_info.vms / 1024**2
        })
    
    def generate_report(self):
        """Summarize the recorded operations."""
        report = {
            'total_operations': len(self.operations),
            'successful_operations': sum(1 for op in self.operations if op['status'] == 'success'),
            'failed_operations': sum(1 for op in self.operations if op['status'] == 'failed'),
            'average_duration': np.mean([op['duration'] for op in self.operations]),
            'memory_usage_trend': self.memory_snapshots
        }
        return report
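
A usage sketch, reusing the data.dat file created in section 2:

monitor = MemmapMonitor()
with monitor.track_operation('row_means', 'data.dat', (1000, 1000), np.float32):
    m = np.memmap('data.dat', dtype=np.float32, mode='r', shape=(1000, 1000))
    means = m.mean(axis=1)

print(monitor.generate_report())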

Summary

memmap is a powerful tool for working with large data. The key points:

  • Mode selection: choose the right mode for the job (r, r+, w+, c)
  • Chunked processing: always process large files in chunks
  • Memory management: flush and release resources promptly
  • Data types: pick the smallest dtype that fits the data
  • Access patterns: prefer sequential access; avoid random access
  • Multiprocessing: share the backing file across processes correctly
  • Error handling: add proper error handling and resource cleanup

Used well, memmap lets you work with datasets far larger than physical memory while keeping performance high.

