TH1/AITrainPython/data_loader.py
2025-12-06 11:44:43 +08:00

286 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据加载器
从C#导出的JSONL文件中加载训练数据
"""
import json
import os
import numpy as np
class TrainingDataLoader:
    """Load training episodes from JSONL files exported by the C# side.

    Episode files are the regular files inside ``data_dir`` whose names start
    with ``episode_``. Every non-empty line of such a file is one JSON object
    describing a single step, with the keys ``State``, ``Actions``,
    ``SelectedAction``, ``Reward`` and ``Done``.
    """

    # Fallback state dimension returned when no episode data can be loaded.
    # Per the original author's note, real data usually has 570-572 features.
    DEFAULT_STATE_DIM = 572

    def __init__(self, data_dir='Data'):
        """Create a loader.

        Args:
            data_dir: Path of the directory that holds the episode files.
        """
        self.data_dir = data_dir

    def load_episode_file(self, file_path):
        """Load a single episode file.

        Args:
            file_path: Path of the JSONL file.

        Returns:
            A list of dicts, one per non-blank line (one per step). If the
            file cannot be read or any line fails to parse, the error is
            printed and an empty list is returned, so a partially read
            episode is never handed back.
        """
        episode_data = []
        try:
            # 'utf-8-sig' transparently drops a leading BOM if one is present.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        episode_data.append(json.loads(line))
        except Exception as e:
            # Deliberately broad: loading is best-effort, and one bad file
            # must not abort loading of the remaining files.
            print(f"加载文件失败 {file_path}: {e}")
            return []
        return episode_data

    def parse_episode_data(self, episode_data):
        """Convert the raw step dicts of one episode into numpy arrays.

        Args:
            episode_data: Raw data of a single episode (list of step dicts).

        Returns:
            A tuple ``(states, actions, rewards, dones, valid_actions_list)``:
            states: float32 array built from each step's ``State``;
                ``(n_steps, state_dim)`` when every step has the same length.
            actions: int64 array ``(n_steps,)`` holding, for each step, the
                index of ``SelectedAction`` inside that step's ``Actions``.
            rewards: float32 array ``(n_steps,)``.
            dones: bool array ``(n_steps,)``.
            valid_actions_list: list of uint64 arrays, the valid action IDs
                of each step.
        """
        states = []
        actions = []
        rewards = []
        dones = []
        valid_actions_list = []
        for step_data in episode_data:
            states.append(np.array(step_data['State'], dtype=np.float32))

            valid_actions = np.array(step_data['Actions'], dtype=np.uint64)
            valid_actions_list.append(valid_actions)

            # The training target is the position of the chosen action ID
            # within this step's list of valid action IDs.
            selected_action = int(step_data['SelectedAction'])
            matches = np.where(valid_actions == selected_action)[0]
            # Falling back to index 0 when the ID is absent keeps the
            # original behaviour; the author notes this should not happen.
            actions.append(matches[0] if len(matches) > 0 else 0)

            rewards.append(float(step_data['Reward']))
            dones.append(bool(step_data['Done']))

        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.bool_),
            valid_actions_list,
        )

    def _list_episode_files(self):
        """Return the sorted paths of all regular ``episode_*`` files.

        Returns:
            A list of file paths in sorted (lexicographic) name order, or
            ``None`` after printing a message when ``data_dir`` is not an
            existing directory.
        """
        if not os.path.isdir(self.data_dir):
            print(f"数据目录不存在: {self.data_dir}")
            return None
        # Sorting keeps the load order stable between runs.
        names = sorted(
            name for name in os.listdir(self.data_dir)
            if name.startswith('episode_')
        )
        paths = (os.path.join(self.data_dir, name) for name in names)
        # Directories are dropped here, before any truncation, so that they
        # never consume part of a ``max_episodes`` quota.
        return [path for path in paths if os.path.isfile(path)]

    def load_all_episodes(self, max_episodes=None):
        """Load and parse the episode files found in ``data_dir``.

        Args:
            max_episodes: Maximum number of episode files to read; ``None``
                reads all of them. Files that turn out to be empty or
                unreadable still count towards this limit.

        Returns:
            A list of ``(states, actions, rewards, dones,
            valid_actions_list)`` tuples as produced by
            ``parse_episode_data``. Empty when the directory is missing.
        """
        episode_files = self._list_episode_files()
        if episode_files is None:
            return []

        if max_episodes is not None:
            episode_files = episode_files[:max_episodes]
            print(f"⚡ 快速测试模式: 只加载前 {max_episodes} 个episodes")

        all_episodes = []
        for file_path in episode_files:
            episode_data = self.load_episode_file(file_path)
            if not episode_data:
                continue
            all_episodes.append(self.parse_episode_data(episode_data))
            if len(all_episodes) % 10 == 0:
                print(f"已加载 {len(all_episodes)} 个episode文件...")

        print(f"总共加载 {len(all_episodes)} 个episode文件")
        return all_episodes

    def get_state_dim(self):
        """Infer the state dimension from the first loadable episode.

        Only as many files as needed to find one non-empty episode are read,
        instead of loading the whole dataset.

        Returns:
            The second dimension of that episode's ``states`` array, or
            ``DEFAULT_STATE_DIM`` when no episode data is available.
        """
        for file_path in self._list_episode_files() or []:
            episode_data = self.load_episode_file(file_path)
            if not episode_data:
                continue
            states = self.parse_episode_data(episode_data)[0]
            return states.shape[1]
        return self.DEFAULT_STATE_DIM

    def group_episodes_by_action_pattern(self, all_episodes):
        """Group episodes by how many valid actions their steps offer.

        Args:
            all_episodes: List of episode tuples as returned by
                ``load_all_episodes``.

        Returns:
            A ``defaultdict(list)`` mapping a key to the episodes in that
            group. The key is ``"uniform_<k>"`` when every step of the
            episode has exactly ``k`` valid actions, and ``"mixed"``
            otherwise (including episodes without any step).
        """
        from collections import defaultdict

        grouped = defaultdict(list)
        for episode in all_episodes:
            valid_actions_list = episode[4]
            distinct_counts = {len(va) for va in valid_actions_list}
            if len(distinct_counts) == 1:
                key = f"uniform_{next(iter(distinct_counts))}"
            else:
                key = "mixed"
            grouped[key].append(episode)
        return grouped
class ReplayBuffer:
    """Fixed-capacity experience replay buffer backed by numpy arrays.

    Steps are stored contiguously in preallocated arrays. When a new episode
    does not fit, the oldest steps are dropped and the remaining ones are
    shifted to the front, so the buffer always holds the most recent steps
    in insertion order.
    """

    def __init__(self, capacity=10000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of steps the buffer can hold.
        """
        self.capacity = capacity
        # The arrays stay None until the first episode is added, because the
        # state dimension is only known at that point.
        self.states = None
        self.actions = None
        self.rewards = None
        self.dones = None
        # Kept as a Python list because each step may have a different
        # number of valid actions.
        self.valid_actions_ids = []
        self.current_size = 0

    def add_episode(self, states, actions, rewards, dones, valid_actions_list):
        """Append all steps of one episode, evicting the oldest on overflow.

        Args:
            states: ``(n_steps, state_dim)`` array.
            actions: ``(n_steps,)`` array.
            rewards: ``(n_steps,)`` array.
            dones: ``(n_steps,)`` array.
            valid_actions_list: List of ``n_steps`` arrays of uint64 action
                IDs.

        Raises:
            ValueError: If the five inputs do not all describe the same
                number of steps. Nothing is modified in that case.

        An empty episode is a no-op. If the episode is longer than the
        buffer capacity, only its last ``capacity`` steps are stored.
        """
        states = np.asarray(states)
        n_steps = len(states)

        # Validate up front so a mismatch cannot leave the arrays and the
        # valid-action list out of sync after a partial write.
        lengths = (len(actions), len(rewards), len(dones), len(valid_actions_list))
        if any(length != n_steps for length in lengths):
            raise ValueError(
                "add_episode inputs disagree on the number of steps: "
                f"states={n_steps}, actions={lengths[0]}, rewards={lengths[1]}, "
                f"dones={lengths[2]}, valid_actions_list={lengths[3]}"
            )
        if n_steps == 0:
            return

        # An episode longer than the whole buffer can only contribute its
        # most recent `capacity` steps.
        if n_steps > self.capacity:
            start = n_steps - self.capacity
            states = states[start:]
            actions = actions[start:]
            rewards = rewards[start:]
            dones = dones[start:]
            valid_actions_list = valid_actions_list[start:]
            n_steps = self.capacity

        # Lazily allocate storage once the state dimension is known.
        if self.states is None:
            state_dim = states.shape[1]
            self.states = np.empty((self.capacity, state_dim), dtype=np.float32)
            self.actions = np.empty(self.capacity, dtype=np.int64)
            self.rewards = np.empty(self.capacity, dtype=np.float32)
            self.dones = np.empty(self.capacity, dtype=np.bool_)

        free_slots = self.capacity - self.current_size
        if n_steps > free_slots:
            # Drop just enough of the oldest steps and shift the rest to the
            # front. Source and destination overlap; numpy slice assignment
            # handles overlapping memory correctly.
            drop = n_steps - free_slots
            kept = self.current_size - drop
            self.states[:kept] = self.states[drop:self.current_size]
            self.actions[:kept] = self.actions[drop:self.current_size]
            self.rewards[:kept] = self.rewards[drop:self.current_size]
            self.dones[:kept] = self.dones[drop:self.current_size]
            self.valid_actions_ids = self.valid_actions_ids[drop:]
            self.current_size = kept

        end = self.current_size + n_steps
        self.states[self.current_size:end] = states
        self.actions[self.current_size:end] = actions
        self.rewards[self.current_size:end] = rewards
        self.dones[self.current_size:end] = dones
        self.valid_actions_ids.extend(valid_actions_list)
        self.current_size = end

    def get_all(self):
        """Return everything currently stored, without copying.

        Returns:
            ``(states, actions, rewards, dones, valid_actions_ids)``. The
            arrays are views into the buffer's storage and the list is the
            buffer's own list, so later ``add_episode`` calls may change
            them; copy them if they must outlive further insertions. Before
            the first insertion, empty 1-D arrays and an empty list are
            returned.
        """
        if self.states is None:
            return (
                np.array([], dtype=np.float32),
                np.array([], dtype=np.int64),
                np.array([], dtype=np.float32),
                np.array([], dtype=bool),
                [],
            )
        size = self.current_size
        return (
            self.states[:size],
            self.actions[:size],
            self.rewards[:size],
            self.dones[:size],
            self.valid_actions_ids,
        )

    def clear(self):
        """Empty the buffer while keeping the allocated arrays for reuse."""
        # Only the write position is reset, which avoids reallocating.
        self.current_size = 0
        # Rebind instead of clearing in place so that a list previously
        # handed out by get_all() is not emptied behind the caller's back.
        self.valid_actions_ids = []

    def __len__(self):
        """Return the number of steps currently stored."""
        return self.current_size