在 parse_model 函数中添加了自定义模块支持


第一段代码(已修改版本)

在 parse_model 函数中添加了自定义模块支持:

n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain

# Custom modules support - Added for RFAConv, HSFPN, HATHead integration
custom_base_module_names = {'RFAConv', 'RFCBAMConv', 'RFABlock'}
is_base_module = m in base_modules or (hasattr(m, '__name__') and m.__name__ in custom_base_module_names)

if is_base_module: # 👈 使用 is_base_module
c1, c2 = ch[f], args[0]
# ... 后续处理

第二段代码(原始版本)

直接检查模块是否在 base_modules 中:

n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain

if m in base_modules: # 👈 直接使用 m in base_modules
c1, c2 = ch[f], args[0]
# ... 后续处理

修改的意义

第一段代码的修改允许:
1. 扩展自定义模块支持:除了内置的 base_modules,还支持 RFAConv、RFCBAMConv、RFABlock 这些自定义注意力模块
2. 动态检查模块名称:通过 hasattr(m, '__name__') 检查模块是否有名称属性
3. 灵活性更强:可以轻松添加更多自定义模块到 custom_base_module_names 集合中

posted @ 2025-11-17 23:03  量子我梦  阅读(4)  评论(0)    收藏  举报