ChatGPT 帮我优化代码（2024-06-20）

需求：把下面的代码改写成面向对象的风格

  该程序的主要任务是从指定的文本文件中提取 ROI（感兴趣区域）信息，统计不同 ROI 标签（如 56、63、69……）的出现次数，并绘制统计结果的条形图。通过把各项功能封装到 RoiAnalyzer 类中，代码变得更有条理、也更易维护。
【文本文件】：每行格式形如 `Case <病例名> has [56 63 69]`，程序用正则表达式从中提取病例名和 ROI 标签列表。

  • 源代码
    
    import re
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    def ret_roi_value_dict(txt_path):
        """Parse per-case ROI label lists from a text file.

        Every line matching ``Case <name> has [<v1> <v2> ...]`` produces one
        record. Lines that do not match are reported on stdout and skipped.

        Args:
            txt_path: Path of the text file to read (decoded as UTF-8).

        Returns:
            A list of ``{'case_name': str, 'pixel_value': list[int]}`` dicts,
            one per matched line, in file order.
        """
        pattern = re.compile(r'Case\s+(.*?)\s+has\s+\[(.*?)\]')
        output = []
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(txt_path, 'r', encoding='utf-8') as file:
            for line_number, line in enumerate(file, start=1):
                match = pattern.match(line)
                if match is None:
                    print(f"Line {line_number} not matched: {line.strip()}")
                    continue
                # split() with no argument collapses runs of whitespace, so
                # padding such as "[ 0  56 63]" parses cleanly.
                pixel_value = [int(token) for token in match.group(2).split()]
                output.append({'case_name': match.group(1), 'pixel_value': pixel_value})
        return output
    
    def show_roi_value(roi_value_dict, fig_save_path, save_dpi):
        """Plot how often each ROI label occurs and save the bar chart.

        Bars are drawn in ascending label order. The 2nd- to 6th-most-frequent
        labels are highlighted in red/blue/green/yellow/purple. Every other bar,
        including the single most frequent label, is grey.

        Args:
            roi_value_dict: List of records whose ``"pixel_value"`` entry is a
                list of ROI labels (as returned by ``ret_roi_value_dict``).
            fig_save_path: Destination image path passed to ``savefig``.
            save_dpi: Output resolution passed to ``savefig``.
        """
        # Flatten every case's label list into one sequence.
        all_pixel_values = [value for item in roi_value_dict for value in item["pixel_value"]]
        # np.unique returns labels sorted ascending with counts aligned to them.
        unique_elements, counts = np.unique(np.array(all_pixel_values), return_counts=True)
        # Positions in unique_elements ordered by descending count. The stable
        # sort makes tie-breaking deterministic regardless of array size.
        sorted_indices = np.argsort(counts, kind="stable")[::-1]

        top_colors = ['red', 'blue', 'green', 'yellow', 'purple']
        bar_colors = ['grey'] * len(unique_elements)
        # Skip rank 1, then colour ranks 2..6. sorted_indices already holds each
        # label's position in unique_elements, so no search is needed.
        # zip() stops early when there are fewer than six labels.
        for color, position in zip(top_colors, sorted_indices[1:]):
            bar_colors[position] = color

        fig = plt.figure(figsize=(10, 6))
        try:
            # NOTE(review): seaborn >= 0.13 warns when `palette` is passed
            # without `hue`; revisit once the seaborn version is pinned.
            ax = sns.barplot(x=unique_elements, y=counts, palette=bar_colors)
            ax.set_title('TB ROI Information')
            ax.set_xlabel('ROI label')
            # Counts label occurrences; this equals the number of cases only if
            # a case never lists the same label twice — TODO confirm.
            ax.set_ylabel('Case count')
            plt.xticks(rotation=45)  # rotate x tick labels for readability
            ax.grid(axis="y")
            # Annotate each bar with its count just above the bar top.
            for bar in ax.patches:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2, int(height), f'{int(height)}',
                        ha='center', va='bottom', color='black')
            # Lay out after all labels/annotations exist so nothing gets clipped.
            fig.tight_layout()
            fig.savefig(fig_save_path, dpi=save_dpi)
        finally:
            # Release the figure; otherwise every call leaks an open figure.
            plt.close(fig)
    
    def main():
        """Parse the ROI report for each keyword and save one combined chart."""
        project_dir = "/homes/xchang/Projects/Multi-center_Tuberculosis"
        report_dir = f"{project_dir}/segmentation/nnUNet_dl"
        keyword_list = ["MaXin_50_Newlabel_final", "MaXin_100", "ZhangRui_50_Newlabel_final", "ZhangRui_54"]

        # Merge the records from every keyword file into a single list.
        roi_value_dict = []
        for keyword in keyword_list:
            roi_value_dict.extend(ret_roi_value_dict(txt_path=f"{report_dir}/{keyword}.txt"))

        show_roi_value(roi_value_dict=roi_value_dict,
                       fig_save_path=f"{project_dir}/roi_new_values_final.png",
                       save_dpi=600)

    if __name__ == "__main__":
        # Run only when executed as a script, not when imported.
        main()
    
  • 优化后
    
    import re
    import numpy as np
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    class RoiAnalyzer:
        """Collect ROI labels from per-keyword text files and chart their frequencies.

        Typical use: ``RoiAnalyzer(keywords, base_path).run(fig_save_path)``.

        Attributes:
            keywords: File stems; each is read from ``<base_path>/<keyword>.txt``.
            base_path: Directory that holds the text files.
            roi_value_dict: Records loaded by ``collect_data``. Despite the name
                this is a list of ``{'case_name': str, 'pixel_value': list[int]}``.
        """

        def __init__(self, keywords, base_path):
            self.keywords = keywords
            self.base_path = base_path
            self.roi_value_dict = []

        def ret_roi_value_dict(self, txt_path):
            """Parse per-case ROI label lists from one text file.

            Every line matching ``Case <name> has [<v1> <v2> ...]`` produces one
            record. Lines that do not match are reported on stdout and skipped.

            Args:
                txt_path: Path of the text file to read (decoded as UTF-8).

            Returns:
                A list of ``{'case_name': str, 'pixel_value': list[int]}`` dicts
                in file order.
            """
            pattern = re.compile(r'Case\s+(.*?)\s+has\s+\[(.*?)\]')
            output = []
            # Explicit encoding so parsing does not depend on the platform locale.
            with open(txt_path, 'r', encoding='utf-8') as file:
                for line_number, line in enumerate(file, start=1):
                    match = pattern.match(line)
                    if match is None:
                        print(f"Line {line_number} not matched: {line.strip()}")
                        continue
                    # split() with no argument collapses runs of whitespace.
                    pixel_value = [int(token) for token in match.group(2).split()]
                    output.append({'case_name': match.group(1), 'pixel_value': pixel_value})
            return output

        def collect_data(self):
            """(Re)load records from every keyword file into ``self.roi_value_dict``.

            The list is rebuilt from scratch, so calling this (or ``run``) more
            than once does not duplicate records.
            """
            records = []
            for keyword in self.keywords:
                txt_path = f"{self.base_path}/{keyword}.txt"
                records.extend(self.ret_roi_value_dict(txt_path))
            self.roi_value_dict = records

        def show_roi_value(self, fig_save_path, save_dpi):
            """Plot how often each ROI label occurs and save the bar chart.

            Bars are drawn in ascending label order. The 2nd- to 6th-most-frequent
            labels are highlighted in red/blue/green/yellow/purple. Every other
            bar, including the single most frequent label, is grey.

            Args:
                fig_save_path: Destination image path passed to ``savefig``.
                save_dpi: Output resolution passed to ``savefig``.
            """
            # Flatten every case's label list into one sequence.
            all_pixel_values = [value for item in self.roi_value_dict for value in item["pixel_value"]]
            # np.unique returns labels sorted ascending with counts aligned to them.
            unique_elements, counts = np.unique(np.array(all_pixel_values), return_counts=True)
            # Positions in unique_elements ordered by descending count. The
            # stable sort keeps tie-breaking deterministic.
            sorted_indices = np.argsort(counts, kind="stable")[::-1]

            top_colors = ['red', 'blue', 'green', 'yellow', 'purple']
            bar_colors = ['grey'] * len(unique_elements)
            # Skip rank 1, then colour ranks 2..6. sorted_indices already holds
            # each label's position in unique_elements, so no search is needed.
            for color, position in zip(top_colors, sorted_indices[1:]):
                bar_colors[position] = color

            fig = plt.figure(figsize=(10, 6))
            try:
                # NOTE(review): seaborn >= 0.13 warns when `palette` is passed
                # without `hue`; revisit once the seaborn version is pinned.
                ax = sns.barplot(x=unique_elements, y=counts, palette=bar_colors)
                ax.set_title('TB ROI Information')
                ax.set_xlabel('ROI label')
                ax.set_ylabel('Case count')
                plt.xticks(rotation=45)  # rotate x tick labels for readability
                ax.grid(axis="y")
                # Annotate each bar with its count just above the bar top.
                for bar in ax.patches:
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width() / 2, int(height), f'{int(height)}',
                            ha='center', va='bottom', color='black')
                # Lay out after all labels/annotations exist so nothing is clipped.
                fig.tight_layout()
                fig.savefig(fig_save_path, dpi=save_dpi)
            finally:
                # Release the figure; otherwise every call leaks an open figure.
                plt.close(fig)

        def run(self, fig_save_path, save_dpi=600):
            """Load all keyword files, then plot and save the label-frequency chart."""
            self.collect_data()
            self.show_roi_value(fig_save_path, save_dpi)
    
    def main():
        """Entry point: analyse the four ROI report files and save the chart."""
        data_dir = "/homes/xchang/Projects/Multi-center_Tuberculosis/segmentation/nnUNet_dl"
        output_figure = "/homes/xchang/Projects/Multi-center_Tuberculosis/roi_new_values_final.png"
        report_names = [
            "MaXin_50_Newlabel_final",
            "MaXin_100",
            "ZhangRui_50_Newlabel_final",
            "ZhangRui_54",
        ]
        RoiAnalyzer(keywords=report_names, base_path=data_dir).run(fig_save_path=output_figure)

    if __name__ == "__main__":
        # Run only when executed as a script, not when imported.
        main()
    

posted @ 2024-06-20 19:23  Elina-Chang  阅读(50)  评论(0)    收藏  举报