Modifying Data Generation
Ah, I understand now. We need to:
- First group trajectories by source-destination (SD) pairs (record format sketched below)
- Sample 10% of SD pairs for OOD
- Use all trajectories from those SD pairs as OOD data
- Use remaining SD pairs' trajectories to construct train/test sets
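For concreteness, a trajectory record is assumed to look like the following (layout inferred from how the code below reads `n_geo`/`m_geo`/`npath`; coordinates are illustrative), and its SD pair is just the first and last node:

traj = {
    'n_geo': [[108.91, 34.27], [108.92, 34.27], [108.93, 34.28]],  # ordered node sequence
    'm_geo': [[108.91, 34.27], [108.92, 34.27], [108.93, 34.28]],  # mirrors n_geo in this script
    'npath': [['[108.91, 34.27]', '[108.92, 34.27]'],
              ['[108.92, 34.27]', '[108.93, 34.28]']],             # consecutive node pairs as strings
}
sd_pair = (tuple(traj['n_geo'][0]), tuple(traj['n_geo'][-1]))      # (source, destination)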
Let me modify the code:
I've modified the code to properly handle SD pairs and maintain the client_0 directory structure. Here are the key changes:
- **Added SD Pair Handling**
  - New method `get_sd_pair()` to extract the source-destination pair from a trajectory
  - New method `group_by_sd_pairs()` to group trajectories by SD pair
- **Modified Data Splitting** (see the invariant check after this list)
  - First groups all trajectories by SD pairs
  - Samples 10% of SD pairs for OOD
  - Takes all trajectories from the selected OOD SD pairs
  - Uses the remaining SD pairs' trajectories for main dataset construction
- **Directory Structure**
  - Maintains the 'client_0' directory structure for output files
- **Process Flow**
  - Load and merge all data
  - Group trajectories by SD pairs
  - Sample 10% of SD pairs for OOD
  - Collect all trajectories from the OOD SD pairs
  - Use the remaining SD pairs' trajectories for train/test splitting
  - Generate detour and switch anomalies
  - Apply the 0.97 sampling rate to all datasets
  - Save all processed datasets in the 'client_0' directory
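Because the split happens at the SD-pair level rather than the trajectory level, an OOD trajectory can never share a source-destination pair with the train/test data. A minimal check of that invariant, hypothetical and using the variable names from main() in the code below:

ood_pairs = {processor.get_sd_pair(t) for t in ood_trajectories}
main_pairs = {processor.get_sd_pair(t) for t in main_trajectories}
assert ood_pairs.isdisjoint(main_pairs)  # disjoint by construction of the split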
The output files will be saved in:
output_dir/
└── client_0/
    ├── train.pickle
    ├── test.pickle
    ├── detour.pickle
    ├── switch.pickle
    └── ood.pickle
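A quick way to sanity-check the outputs, as a minimal sketch (the directory below is a placeholder for your configured output_dir):

import os
import pickle

client_dir = '/path/to/your/output/directory/client_0'  # placeholder
for name in ('train', 'test', 'detour', 'switch', 'ood'):
    with open(os.path.join(client_dir, f'{name}.pickle'), 'rb') as f:
        data = pickle.load(f)
    print(f'{name}: {len(data)} trajectories')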
Would you like me to make any adjustments to this approach?
Full code:
import pickle
import numpy as np
import random
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
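# Overview of the two anomaly types produced below:
#   - "detour": one intermediate node of a trajectory is replaced by a
#     less-frequent alternative that still connects the same neighbours.
#   - "switch": a two-node inner segment is replaced by a lower-frequency
#     alternative path between the same endpoints.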
class SimplifiedDetourGenerator:
    def __init__(self):
        # Store (source, middle, dest) -> frequency
        self.path_frequency = defaultdict(int)
        # Store node -> set of nodes that can be reached from it
        self.next_nodes = defaultdict(set)
        # Store node -> set of nodes that can reach it
        self.prev_nodes = defaultdict(set)
        
    def build_connections(self, trajectories: Dict):
        """Build node connections and calculate path frequencies from trajectories"""
        for trajectory in trajectories.values():
            n_geo = trajectory['n_geo']
            for i in range(len(n_geo) - 2):
                source = tuple(n_geo[i])
                middle = tuple(n_geo[i + 1])
                dest = tuple(n_geo[i + 2])
                
                # Record three-node path frequency
                self.path_frequency[(source, middle, dest)] += 1
                
                # Record node connections
                self.next_nodes[source].add(middle)
                self.next_nodes[middle].add(dest)
                self.prev_nodes[middle].add(source)
                self.prev_nodes[dest].add(middle)
    def find_alternative_middle(self, source: Tuple, middle: Tuple, dest: Tuple) -> Optional[List]:
        """Find alternative middle node with lower frequency"""
        current_freq = self.path_frequency[(source, middle, dest)]
        potential_middles = self.next_nodes[source] & self.prev_nodes[dest]
        potential_middles.discard(middle)
        
        if not potential_middles:
            return None
            
        alt_paths = []
        for alt_middle in potential_middles:
            freq = self.path_frequency[(source, alt_middle, dest)]
            if freq == 0 or freq < current_freq:
                alt_paths.append((freq, alt_middle))
                
        if not alt_paths:
            return None
            
        # Prefer the least-frequent observed path; zero-frequency (never-seen)
        # middles are mapped to +inf so they are chosen only when no observed
        # alternative exists.
        chosen_middle = min(alt_paths, key=lambda x: (x[0] if x[0] > 0 else float('inf')))[1]
        return list(chosen_middle)
    def generate_detour(self, trajectory: Dict) -> Dict:
        """Generate detour by replacing middle node with less frequent alternative"""
        n_geo = trajectory['n_geo']
        if len(n_geo) < 3:
            return trajectory
            
        start_idx = random.randint(0, len(n_geo) - 3)
        
        source = tuple(n_geo[start_idx])
        middle = tuple(n_geo[start_idx + 1])
        dest = tuple(n_geo[start_idx + 2])
        
        alt_middle = self.find_alternative_middle(source, middle, dest)
        if not alt_middle:
            return trajectory
            
        new_trajectory = trajectory.copy()
        new_trajectory['n_geo'] = (n_geo[:start_idx + 1] + 
                                 [alt_middle] + 
                                 n_geo[start_idx + 2:])
        new_trajectory['m_geo'] = new_trajectory['n_geo']
        
        new_trajectory['npath'] = []
        for i in range(len(new_trajectory['n_geo']) - 1):
            pair = [
                str(list(new_trajectory['n_geo'][i])),
                str(list(new_trajectory['n_geo'][i+1]))
            ]
            new_trajectory['npath'].append(pair)
            
        return new_trajectory
    def find_switch_path(self, start_node: Tuple, end_node: Tuple, current_path: List[Tuple]) -> Optional[List]:
        """Find alternative path with two intermediate nodes"""
        current_freq = (
            self.path_frequency.get((start_node, current_path[0], current_path[1]), 0) +
            self.path_frequency.get((current_path[0], current_path[1], end_node), 0)
        )
        
        first_layer = self.next_nodes[start_node]
        if not first_layer:
            return None
            
        potential_paths = []
        for first_node in first_layer:
            second_layer = self.next_nodes[first_node] & self.prev_nodes[end_node]
            for second_node in second_layer:
                if second_node != start_node and second_node != first_node:
                    total_freq = (
                        self.path_frequency.get((start_node, first_node, second_node), 0) +
                        self.path_frequency.get((first_node, second_node, end_node), 0)
                    )
                    
                    if total_freq < current_freq:
                        potential_paths.append({
                            'path': [list(first_node), list(second_node)],
                            'frequency': total_freq
                        })
                    
        if not potential_paths:
            return None
            
        min_freq = min(path['frequency'] for path in potential_paths)
        lowest_freq_paths = [
            path['path'] for path in potential_paths 
            if path['frequency'] == min_freq
        ]
        
        return random.choice(lowest_freq_paths)
    def generate_switch(self, trajectory: Dict) -> Dict:
        """Generate route switching outlier"""
        n_geo = trajectory['n_geo']
        if len(n_geo) < 5:
            return trajectory
            
        max_start_positions = 3  # try up to three random positions before giving up
        
        for _ in range(max_start_positions):
            switch_start = random.randint(1, len(n_geo) - 4)
            switch_end = switch_start + 3
            
            before_node = tuple(n_geo[switch_start])
            after_node = tuple(n_geo[switch_end])
            current_path = [tuple(n_geo[switch_start + 1]), tuple(n_geo[switch_start + 2])]
            
            new_segment = self.find_switch_path(before_node, after_node, current_path)
            
            if new_segment:
                original_segment = [tuple(node) for node in n_geo[switch_start+1:switch_end]]
                new_segment_tuples = [tuple(node) for node in new_segment]
                
                if new_segment_tuples != original_segment:
                    new_trajectory = trajectory.copy()
                    new_trajectory['n_geo'] = (n_geo[:switch_start+1] + 
                                            new_segment + 
                                            n_geo[switch_end:])
                    new_trajectory['m_geo'] = new_trajectory['n_geo']
                    
                    new_trajectory['npath'] = []
                    for i in range(len(new_trajectory['n_geo']) - 1):
                        pair = [
                            str(list(new_trajectory['n_geo'][i])),
                            str(list(new_trajectory['n_geo'][i+1]))
                        ]
                        new_trajectory['npath'].append(pair)
                        
                    return new_trajectory
        
        return trajectory
class SingleClientProcessor:
    def __init__(self, config: Dict):
        self.config = config
        self.output_dir = config['output_dir']
        os.makedirs(self.output_dir, exist_ok=True)
        self.all_nodes = set()
        self.detour_generator = SimplifiedDetourGenerator()
        
    def get_sd_pair(self, trajectory: Dict) -> Tuple:
        """Extract source-destination pair from trajectory"""
        n_geo = trajectory['n_geo']
        return (tuple(n_geo[0]), tuple(n_geo[-1]))
        
    def group_by_sd_pairs(self, trajectories: Dict) -> Dict[Tuple, List]:
        """Group trajectories by source-destination pairs"""
        sd_groups = defaultdict(list)
        for traj_id, traj in trajectories.items():
            sd_pair = self.get_sd_pair(traj)
            sd_groups[sd_pair].append(traj)
        return sd_groups
    def merge_datasets(self, train_path: str, id_path: str, ood_path: str) -> Dict:
        """Merge datasets and collect unique nodes"""
        train_data = self.load_pickle_data(train_path)
        id_data = self.load_pickle_data(id_path)
        ood_data = self.load_pickle_data(ood_path)
        
        merged_data = {}
        # Re-key each source with a distinct prefix and collect all unique nodes.
        for prefix, data in (('train', train_data), ('id', id_data), ('ood', ood_data)):
            for i, value in enumerate(data.values()):
                merged_data[f"{prefix}_{i}"] = value
                for point in value['n_geo']:
                    self.all_nodes.add(tuple(point))

        return merged_data
    def sample_trajectory(self, trajectory: Dict, sample_rate: float) -> Dict:
        """Sample trajectory points at given rate"""
        if sample_rate >= 1.0:
            return trajectory.copy()
            
        n_geo = trajectory['n_geo']
        indices = np.arange(0, len(n_geo))
        sample_indices = np.sort(np.random.choice(
            indices, 
            size=max(2, int(len(n_geo) * sample_rate)),
            replace=False
        ))
        
        new_trajectory = trajectory.copy()
        new_trajectory['n_geo'] = [n_geo[i] for i in sample_indices]
        new_trajectory['m_geo'] = new_trajectory['n_geo']
        
        new_trajectory['npath'] = []
        for i in range(len(new_trajectory['n_geo']) - 1):
            pair = [
                str(list(new_trajectory['n_geo'][i])),
                str(list(new_trajectory['n_geo'][i+1]))
            ]
            new_trajectory['npath'].append(pair)
            
        return new_trajectory
    def load_pickle_data(self, filepath: str) -> Dict:
        """Load data from pickle file"""
        with open(filepath, 'rb') as f:
            return pickle.load(f)
    def process_data(self, client_trajectories: List, other_trajectories: List, sample_rate: float):
        """Process data with specified splitting strategy"""
        output_dir = os.path.join(self.output_dir, 'client_0')
        os.makedirs(output_dir, exist_ok=True)
        
        # Split into train and test
        random.shuffle(client_trajectories)
        train_size = max(1, int(len(client_trajectories) * self.config['train_ratio']))
        
        train_trajectories = client_trajectories[:train_size]
        test_trajectories = client_trajectories[train_size:]
        
        # Apply sampling
        sampled_train_data = {
            str(idx).zfill(4): self.sample_trajectory(traj, sample_rate) 
            for idx, traj in enumerate(train_trajectories)
        }
        
        sampled_test_data = {
            str(idx).zfill(4): self.sample_trajectory(traj, sample_rate)
            for idx, traj in enumerate(test_trajectories)
        }
        
        # Generate anomalies
        # Target up to twice the test-set size; the slices below are naturally
        # capped at the number of available test trajectories.
        num_test = len(test_trajectories)
        num_anomalies = num_test * 2
        
        # Generate detours and switches
        detour_data = {
            str(i).zfill(4): self.sample_trajectory(
                self.detour_generator.generate_detour(traj), sample_rate
            )
            for i, traj in enumerate(test_trajectories[:num_anomalies])
        }
        
        switch_data = {
            str(i).zfill(4): self.sample_trajectory(
                self.detour_generator.generate_switch(traj), sample_rate
            )
            for i, traj in enumerate(test_trajectories[:num_anomalies])
        }
        
        # Handle OOD data
        ood_data = {}
        if other_trajectories:
            selected_ood = random.sample(other_trajectories, min(num_anomalies, len(other_trajectories)))
            ood_data = {
                str(i).zfill(4): self.sample_trajectory(traj, sample_rate)
                for i, traj in enumerate(selected_ood)
            }
        
        # Save datasets
        for name, data in [
            ('train', sampled_train_data),
            ('test', sampled_test_data),
            ('detour', detour_data),
            ('switch', switch_data),
            ('ood', ood_data)
        ]:
            with open(os.path.join(output_dir, f'{name}.pickle'), 'wb') as f:
                pickle.dump(data, f)
def main():
    config = {
        'random_seed': 42,
        'base_dir': "/path/to/your/data/directory",  # Change this
        'city_name': "xian",
        'output_dir': "/path/to/your/output/directory",  # Change this
        'train_ratio': 0.05,
        'ood_sd_ratio': 0.1,  # 10% of SD pairs for OOD
        'sample_rate': 0.97
    }
    
    # Set random seeds
    random.seed(config['random_seed'])
    np.random.seed(config['random_seed'])
    
    # Initialize processor
    processor = SingleClientProcessor(config)
    
    # Merge datasets
    merged_data = processor.merge_datasets(
        os.path.join(config['base_dir'], config['city_name'], "train.pickle"),
        os.path.join(config['base_dir'], config['city_name'], "id.pickle"),
        os.path.join(config['base_dir'], config['city_name'], "ood.pickle")
    )
    
    # Group trajectories by SD pairs
    sd_groups = processor.group_by_sd_pairs(merged_data)
    
    # Sample 10% of SD pairs for OOD
    sd_pairs = list(sd_groups.keys())
    random.shuffle(sd_pairs)
    ood_sd_count = max(1, int(len(sd_pairs) * config['ood_sd_ratio']))
    
    # Split SD pairs
    ood_sd_pairs = sd_pairs[:ood_sd_count]
    main_sd_pairs = sd_pairs[ood_sd_count:]
    
    # Collect trajectories
    ood_trajectories = []
    main_trajectories = []
    
    for sd_pair in ood_sd_pairs:
        ood_trajectories.extend(sd_groups[sd_pair])
    
    for sd_pair in main_sd_pairs:
        main_trajectories.extend(sd_groups[sd_pair])
    
    # Build connections for detour generation using all trajectories
    processor.detour_generator.build_connections(merged_data)
    
    print(f"\nProcessing data with sample rate {config['sample_rate']}...")
    print(f"Total SD pairs: {len(sd_pairs)}")
    print(f"OOD SD pairs: {len(ood_sd_pairs)}")
    print(f"Main SD pairs: {len(main_sd_pairs)}")
    
    
    processor.process_data(
        client_trajectories=main_trajectories,
        other_trajectories=ood_trajectories,
        sample_rate=config['sample_rate']
    )
    
    print(f"Total trajectories: {len(merged_data)}")
    print(f"OOD trajectories: {len(ood_trajectories)}")
    print(f"Main trajectories: {len(main_trajectories)}")
    print("Data processing completed!")
if __name__ == "__main__":
    main()
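As a quick illustration of the detour logic, here is a toy run of SimplifiedDetourGenerator on hand-made trajectories (illustrative coordinates, not real data; assumes the classes above are defined in the same session):

gen = SimplifiedDetourGenerator()
toy = {
    'a': {'n_geo': [[0, 0], [1, 0], [2, 0]]},  # frequent route via (1, 0)
    'b': {'n_geo': [[0, 0], [1, 0], [2, 0]]},
    'c': {'n_geo': [[0, 0], [1, 1], [2, 0]]},  # rare route via (1, 1)
}
gen.build_connections(toy)
detour = gen.generate_detour({'n_geo': [[0, 0], [1, 0], [2, 0]]})
print(detour['n_geo'])  # [[0, 0], [1, 1], [2, 0]]: the rarer middle node is chosen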