修改数据的生成
Ah, I understand now. We need to:
- First group trajectories by source-destination (SD) pairs
- Sample 10% of SD pairs for OOD
- Use all trajectories from those SD pairs as OOD data
- Use remaining SD pairs' trajectories to construct train/test sets
Let me modify the code:
I've modified the code to properly handle SD pairs and maintain the client_0 directory structure. Here are the key changes:
-
Added SD Pair Handling:
- New method get_sd_pair() to extract the source-destination pair from a trajectory
- New method group_by_sd_pairs() to group trajectories by their SD pairs
-
Modified Data Splitting:
- First groups all trajectories by SD pairs
- Samples 10% of SD pairs for OOD
- Takes all trajectories from selected OOD SD pairs
- Uses remaining SD pairs' trajectories for main dataset construction
-
Directory Structure:
- Maintains 'client_0' directory structure for output files
-
Process Flow:
- Load and merge all data
- Group trajectories by SD pairs
- Sample 10% of SD pairs for OOD
- Collect all trajectories from OOD SD pairs
- Use remaining SD pairs' trajectories for train/test splitting
- Generate detour and switch anomalies
- Apply 0.97 sampling rate to all datasets
- Save all processed datasets in 'client_0' directory
The output files will be saved in:
output_dir/
└── client_0/
├── train.pickle
├── test.pickle
├── detour.pickle
├── switch.pickle
└── ood.pickle
Would you like me to make any adjustments to this approach?
点击查看代码
import pickle
import numpy as np
import random
import os
from collections import defaultdict
from typing import Dict, List, Set, Tuple
class SimplifiedDetourGenerator:
    """Generates anomalous variants of trajectories ("detour" and "switch" outliers).

    Statistics are first learned from a trajectory corpus via build_connections():
    frequencies of 3-node sub-paths plus forward/backward node adjacency.
    Nodes are coordinate points converted to tuples so they are hashable.
    """

    def __init__(self):
        # Store (source, middle, dest) -> frequency of that 3-node sub-path
        self.path_frequency = defaultdict(int)
        # Store node -> set of nodes that can be reached from it (successors)
        self.next_nodes = defaultdict(set)
        # Store node -> set of nodes that can reach it (predecessors)
        self.prev_nodes = defaultdict(set)

    def build_connections(self, trajectories: Dict):
        """Build node connections and calculate path frequencies from trajectories.

        Args:
            trajectories: mapping id -> trajectory dict; each trajectory must
                contain an 'n_geo' list of coordinate points (list-like pairs).
        """
        for trajectory in trajectories.values():
            n_geo = trajectory['n_geo']
            # Slide a 3-node window (source, middle, dest) over the trajectory.
            for i in range(len(n_geo) - 2):
                source = tuple(n_geo[i])
                middle = tuple(n_geo[i + 1])
                dest = tuple(n_geo[i + 2])
                # Record three-node path frequency
                self.path_frequency[(source, middle, dest)] += 1
                # Record node connections
                self.next_nodes[source].add(middle)
                self.next_nodes[middle].add(dest)
                self.prev_nodes[middle].add(source)
                self.prev_nodes[dest].add(middle)

    def find_alternative_middle(self, source: Tuple, middle: Tuple, dest: Tuple) -> List:
        """Find alternative middle node with lower frequency.

        Returns the chosen alternative node as a list of coordinates, or None
        when no candidate exists.
        NOTE(review): indexing self.path_frequency / next_nodes / prev_nodes
        here inserts default entries for unseen keys (defaultdict side effect).
        """
        current_freq = self.path_frequency[(source, middle, dest)]
        # A candidate must be reachable from source AND able to reach dest.
        potential_middles = self.next_nodes[source] & self.prev_nodes[dest]
        potential_middles.discard(middle)
        if not potential_middles:
            return None
        alt_paths = []
        for alt_middle in potential_middles:
            freq = self.path_frequency[(source, alt_middle, dest)]
            if freq == 0 or freq < current_freq:
                alt_paths.append((freq, alt_middle))
        if not alt_paths:
            return None
        # Prefer the rarest *observed* path; never-observed candidates
        # (freq == 0) are ranked last via the +inf key.
        chosen_middle = min(alt_paths, key=lambda x: (x[0] if x[0] > 0 else float('inf')))[1]
        return list(chosen_middle)

    def generate_detour(self, trajectory: Dict) -> Dict:
        """Generate detour by replacing middle node with less frequent alternative.

        Returns the input trajectory unchanged when it is too short (< 3 points)
        or no alternative middle node is available.
        """
        n_geo = trajectory['n_geo']
        if len(n_geo) < 3:
            return trajectory
        # Pick a random 3-node window to perturb.
        start_idx = random.randint(0, len(n_geo) - 3)
        source = tuple(n_geo[start_idx])
        middle = tuple(n_geo[start_idx + 1])
        dest = tuple(n_geo[start_idx + 2])
        alt_middle = self.find_alternative_middle(source, middle, dest)
        if not alt_middle:
            return trajectory
        # Shallow copy is enough: the derived fields are fully rebuilt below.
        new_trajectory = trajectory.copy()
        new_trajectory['n_geo'] = (n_geo[:start_idx + 1] +
                                   [alt_middle] +
                                   n_geo[start_idx + 2:])
        new_trajectory['m_geo'] = new_trajectory['n_geo']
        # Rebuild 'npath' as stringified consecutive-point pairs.
        new_trajectory['npath'] = []
        for i in range(len(new_trajectory['n_geo']) - 1):
            pair = [
                str(list(new_trajectory['n_geo'][i])),
                str(list(new_trajectory['n_geo'][i+1]))
            ]
            new_trajectory['npath'].append(pair)
        return new_trajectory

    def find_switch_path(self, start_node: Tuple, end_node: Tuple, current_path: List[Tuple]) -> List:
        """Find alternative path with two intermediate nodes.

        Args:
            start_node: node preceding the segment to replace.
            end_node: node following the segment to replace.
            current_path: the two current intermediate nodes (tuples).

        Returns a list of two nodes (each a list of coordinates) with lower
        combined 3-path frequency than the current segment, or None.
        Ties at the minimum frequency are broken uniformly at random.
        """
        # Combined frequency of the two overlapping 3-node windows that make
        # up the current 4-node segment.
        current_freq = (
            self.path_frequency.get((start_node, current_path[0], current_path[1]), 0) +
            self.path_frequency.get((current_path[0], current_path[1], end_node), 0)
        )
        first_layer = self.next_nodes[start_node]
        if not first_layer:
            return None
        potential_paths = []
        for first_node in first_layer:
            # Second hop must continue from first_node and reach end_node.
            second_layer = self.next_nodes[first_node] & self.prev_nodes[end_node]
            for second_node in second_layer:
                if second_node != start_node and second_node != first_node:
                    total_freq = (
                        self.path_frequency.get((start_node, first_node, second_node), 0) +
                        self.path_frequency.get((first_node, second_node, end_node), 0)
                    )
                    if total_freq < current_freq:
                        potential_paths.append({
                            'path': [list(first_node), list(second_node)],
                            'frequency': total_freq
                        })
        if not potential_paths:
            return None
        min_freq = min(path['frequency'] for path in potential_paths)
        lowest_freq_paths = [
            path['path'] for path in potential_paths
            if path['frequency'] == min_freq
        ]
        return random.choice(lowest_freq_paths)

    def generate_switch(self, trajectory: Dict) -> Dict:
        """Generate route switching outlier.

        Tries up to 3 random positions; replaces a 2-node interior segment
        with a less frequent alternative. Returns the input trajectory
        unchanged when it is too short (< 5 points) or no replacement is found.
        """
        n_geo = trajectory['n_geo']
        if len(n_geo) < 5:
            return trajectory
        max_start_positions = 3
        for _ in range(max_start_positions):
            # Segment to replace: n_geo[switch_start+1 : switch_start+3],
            # anchored by before_node and after_node.
            switch_start = random.randint(1, len(n_geo) - 4)
            switch_end = switch_start + 3
            before_node = tuple(n_geo[switch_start])
            after_node = tuple(n_geo[switch_end])
            current_path = [tuple(n_geo[switch_start + 1]), tuple(n_geo[switch_start + 2])]
            new_segment = self.find_switch_path(before_node, after_node, current_path)
            if new_segment:
                # Only accept a genuinely different segment.
                original_segment = [tuple(node) for node in n_geo[switch_start+1:switch_end]]
                new_segment_tuples = [tuple(node) for node in new_segment]
                if new_segment_tuples != original_segment:
                    new_trajectory = trajectory.copy()
                    new_trajectory['n_geo'] = (n_geo[:switch_start+1] +
                                               new_segment +
                                               n_geo[switch_end:])
                    new_trajectory['m_geo'] = new_trajectory['n_geo']
                    # Rebuild 'npath' as stringified consecutive-point pairs.
                    new_trajectory['npath'] = []
                    for i in range(len(new_trajectory['n_geo']) - 1):
                        pair = [
                            str(list(new_trajectory['n_geo'][i])),
                            str(list(new_trajectory['n_geo'][i+1]))
                        ]
                        new_trajectory['npath'].append(pair)
                    return new_trajectory
        return trajectory
class SingleClientProcessor:
    """Merges raw trajectory pickles, splits them, and writes the processed
    datasets (train/test/detour/switch/ood) for a single client.

    A trajectory is a dict with an 'n_geo' list of coordinate points;
    'm_geo' and 'npath' are derived fields kept in sync here.
    """

    def __init__(self, config: Dict):
        """
        Args:
            config: must contain 'output_dir'; 'train_ratio' is read later
                by process_data().
        """
        self.config = config
        self.output_dir = config['output_dir']
        os.makedirs(self.output_dir, exist_ok=True)
        # All unique nodes (as tuples) seen across the merged datasets.
        self.all_nodes = set()
        self.detour_generator = SimplifiedDetourGenerator()

    def get_sd_pair(self, trajectory: Dict) -> Tuple:
        """Extract the (source, destination) node pair from a trajectory."""
        n_geo = trajectory['n_geo']
        return (tuple(n_geo[0]), tuple(n_geo[-1]))

    def group_by_sd_pairs(self, trajectories: Dict) -> Dict[Tuple, List]:
        """Group trajectories by their source-destination pairs."""
        sd_groups = defaultdict(list)
        for traj in trajectories.values():
            sd_groups[self.get_sd_pair(traj)].append(traj)
        return sd_groups

    def _merge_into(self, merged: Dict, data: Dict, prefix: str):
        """Copy trajectories into `merged` under '<prefix>_<i>' keys and
        collect their nodes into self.all_nodes."""
        for i, value in enumerate(data.values()):
            merged[f"{prefix}_{i}"] = value
            for point in value['n_geo']:
                self.all_nodes.add(tuple(point))

    def merge_datasets(self, train_path: str, id_path: str, ood_path: str) -> Dict:
        """Merge the three source datasets and collect unique nodes.

        Returns:
            One dict keyed 'train_<i>' / 'id_<i>' / 'ood_<i>'.
        """
        merged_data = {}
        self._merge_into(merged_data, self.load_pickle_data(train_path), "train")
        self._merge_into(merged_data, self.load_pickle_data(id_path), "id")
        self._merge_into(merged_data, self.load_pickle_data(ood_path), "ood")
        return merged_data

    def sample_trajectory(self, trajectory: Dict, sample_rate: float) -> Dict:
        """Return a copy of `trajectory` keeping ~sample_rate of its points
        (at least 2 when available), preserving point order."""
        if sample_rate >= 1.0:
            return trajectory.copy()
        n_geo = trajectory['n_geo']
        indices = np.arange(0, len(n_geo))
        # Clamp to the trajectory length: np.random.choice(replace=False)
        # raises ValueError when asked for more samples than the population,
        # which the bare max(2, ...) did for trajectories with < 2 points.
        size = min(len(n_geo), max(2, int(len(n_geo) * sample_rate)))
        sample_indices = np.sort(np.random.choice(indices, size=size, replace=False))
        new_trajectory = trajectory.copy()
        new_trajectory['n_geo'] = [n_geo[i] for i in sample_indices]
        new_trajectory['m_geo'] = new_trajectory['n_geo']
        # Rebuild 'npath' as stringified consecutive-point pairs.
        new_trajectory['npath'] = [
            [str(list(a)), str(list(b))]
            for a, b in zip(new_trajectory['n_geo'], new_trajectory['n_geo'][1:])
        ]
        return new_trajectory

    def load_pickle_data(self, filepath: str) -> Dict:
        """Load data from a pickle file.

        NOTE(review): pickle.load executes arbitrary code from the file —
        only use on trusted data.
        """
        with open(filepath, 'rb') as f:
            return pickle.load(f)

    def process_data(self, client_trajectories: List, other_trajectories: List, sample_rate: float):
        """Split, perturb, subsample and persist all datasets.

        Writes train/test/detour/switch/ood pickles under
        <output_dir>/client_0/.

        Args:
            client_trajectories: trajectories of the in-distribution SD
                pairs; shuffled in place.
            other_trajectories: trajectories from held-out SD pairs (OOD).
            sample_rate: point-keeping rate applied to every output.
        """
        output_dir = os.path.join(self.output_dir, 'client_0')
        os.makedirs(output_dir, exist_ok=True)
        # Split into train and test
        random.shuffle(client_trajectories)
        train_size = max(1, int(len(client_trajectories) * self.config['train_ratio']))
        train_trajectories = client_trajectories[:train_size]
        test_trajectories = client_trajectories[train_size:]
        # Apply point subsampling to both splits.
        sampled_train_data = {
            str(idx).zfill(4): self.sample_trajectory(traj, sample_rate)
            for idx, traj in enumerate(train_trajectories)
        }
        sampled_test_data = {
            str(idx).zfill(4): self.sample_trajectory(traj, sample_rate)
            for idx, traj in enumerate(test_trajectories)
        }
        # Anomaly budget; the slices below cap it at len(test_trajectories).
        num_test = len(test_trajectories)
        num_anomalies = num_test * 2
        # Generate detour and route-switch anomalies from test trajectories.
        detour_data = {
            str(i).zfill(4): self.sample_trajectory(
                self.detour_generator.generate_detour(traj), sample_rate
            )
            for i, traj in enumerate(test_trajectories[:num_anomalies])
        }
        switch_data = {
            str(i).zfill(4): self.sample_trajectory(
                self.detour_generator.generate_switch(traj), sample_rate
            )
            for i, traj in enumerate(test_trajectories[:num_anomalies])
        }
        # OOD data comes from the held-out SD pairs.
        ood_data = {}
        if other_trajectories:
            selected_ood = random.sample(other_trajectories, min(num_anomalies, len(other_trajectories)))
            ood_data = {
                str(i).zfill(4): self.sample_trajectory(traj, sample_rate)
                for i, traj in enumerate(selected_ood)
            }
        # Save datasets.  BUG FIX: the original wrote to an undefined name
        # `client_dir` (NameError); the correct target is the local output_dir.
        for name, data in [
            ('train', sampled_train_data),
            ('test', sampled_test_data),
            ('detour', detour_data),
            ('switch', switch_data),
            ('ood', ood_data)
        ]:
            with open(os.path.join(output_dir, f'{name}.pickle'), 'wb') as f:
                pickle.dump(data, f)
def main():
    """Entry point: merge raw datasets, hold out a fraction of SD pairs as
    OOD, and write the processed splits for client_0."""
    config = {
        'random_seed': 42,
        'base_dir': "/path/to/your/data/directory",  # Change this
        'city_name': "xian",
        'output_dir': "/path/to/your/output/directory",  # Change this
        'train_ratio': 0.05,
        'ood_sd_ratio': 0.1,  # 10% of SD pairs for OOD
        'sample_rate': 0.97
    }
    # Seed both RNGs so shuffles and subsampling are reproducible.
    random.seed(config['random_seed'])
    np.random.seed(config['random_seed'])
    processor = SingleClientProcessor(config)
    # Merge the three raw pickles for the configured city.
    city_root = os.path.join(config['base_dir'], config['city_name'])
    merged_data = processor.merge_datasets(
        os.path.join(city_root, "train.pickle"),
        os.path.join(city_root, "id.pickle"),
        os.path.join(city_root, "ood.pickle")
    )
    # Bucket trajectories by source-destination pair, then hold out a
    # random fraction of the pairs (not individual trajectories) for OOD.
    sd_groups = processor.group_by_sd_pairs(merged_data)
    all_pairs = list(sd_groups.keys())
    random.shuffle(all_pairs)
    n_held_out = max(1, int(len(all_pairs) * config['ood_sd_ratio']))
    held_out_pairs = all_pairs[:n_held_out]
    kept_pairs = all_pairs[n_held_out:]
    # Flatten the selected groups back into trajectory lists.
    ood_trajectories = [traj for pair in held_out_pairs for traj in sd_groups[pair]]
    main_trajectories = [traj for pair in kept_pairs for traj in sd_groups[pair]]
    # Frequency statistics for anomaly generation use the full corpus.
    processor.detour_generator.build_connections(merged_data)
    print(f"\nProcessing data with sample rate {config['sample_rate']}...")
    print(f"Total SD pairs: {len(all_pairs)}")
    print(f"OOD SD pairs: {len(held_out_pairs)}")
    print(f"Main SD pairs: {len(kept_pairs)}")
    client_dir = os.path.join(config['output_dir'], 'client_0')
    os.makedirs(client_dir, exist_ok=True)
    processor.process_data(
        client_trajectories=main_trajectories,
        other_trajectories=ood_trajectories,
        sample_rate=config['sample_rate']
    )
    print(f"Total trajectories: {len(merged_data)}")
    print(f"OOD trajectories: {len(ood_trajectories)}")
    print(f"Main trajectories: {len(main_trajectories)}")
    print("Data processing completed!")


if __name__ == "__main__":
    main()

浙公网安备 33010602011771号