286 lines
10 KiB
Python
286 lines
10 KiB
Python
"""
|
||
数据加载器
|
||
从C#导出的JSONL文件中加载训练数据
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import numpy as np
|
||
|
||
|
||
class TrainingDataLoader:
    """Load and parse training episodes stored as JSONL files.

    Every regular file in ``data_dir`` whose name starts with ``episode_`` is
    treated as one episode. Each non-blank line of such a file is one JSON
    object describing a single step, with the keys ``State``, ``Actions``,
    ``SelectedAction``, ``Reward`` and ``Done``.
    """

    # Fallback returned by get_state_dim() when no episode can be loaded.
    # The original code notes that real data is typically 570-572 wide.
    DEFAULT_STATE_DIM = 572

    def __init__(self, data_dir='Data'):
        """
        Args:
            data_dir: Path of the directory that holds the episode files.
        """
        self.data_dir = data_dir

    def _list_episode_files(self):
        """Return the sorted paths of regular files named ``episode_*``.

        Sorting keeps the load order stable between runs. Directories are
        dropped here so that they never count against ``max_episodes``.
        """
        names = sorted(
            name for name in os.listdir(self.data_dir)
            if name.startswith('episode_')
        )
        paths = [os.path.join(self.data_dir, name) for name in names]
        return [path for path in paths if os.path.isfile(path)]

    def load_episode_file(self, file_path):
        """Load a single episode file.

        Args:
            file_path: Path of the JSONL file.

        Returns:
            A list of dicts, one per step. An empty list is returned when the
            file cannot be read or any line fails to parse (the failure is
            printed, not raised).
        """
        episode_data = []

        try:
            # 'utf-8-sig' transparently drops a leading BOM if one is present.
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    episode_data.append(json.loads(line))
        except Exception as e:
            # Deliberate best-effort boundary: one unreadable file must not
            # abort loading the rest of the data set.
            print(f"加载文件失败 {file_path}: {e}")
            return []

        return episode_data

    def parse_episode_data(self, episode_data):
        """Convert the raw step dicts of one episode into arrays.

        Args:
            episode_data: Raw data of a single episode (list of step dicts).

        Returns:
            A tuple ``(states, actions, rewards, dones, valid_actions_list)``:
            states: float32 array built from each step's ``State``.
            actions: int64 array, index of the selected action inside that
                step's valid-action array.
            rewards: float32 array, one entry per step.
            dones: bool array, one entry per step.
            valid_actions_list: list of uint64 arrays, the valid action IDs
                of each step.
        """
        states = []
        actions = []
        rewards = []
        dones = []
        valid_actions_list = []

        for step_data in episode_data:
            states.append(np.array(step_data['State'], dtype=np.float32))

            valid_actions = np.array(step_data['Actions'], dtype=np.uint64)
            valid_actions_list.append(valid_actions)

            # Translate the selected action ID into its position in the
            # valid-action array.
            selected_action = int(step_data['SelectedAction'])
            matches = np.where(valid_actions == selected_action)[0]
            # If the ID is missing (not expected to happen) fall back to
            # index 0 instead of failing.
            actions.append(matches[0] if len(matches) > 0 else 0)

            rewards.append(float(step_data['Reward']))
            dones.append(bool(step_data['Done']))

        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.bool_),
            valid_actions_list,
        )

    def load_all_episodes(self, max_episodes=None):
        """Load and parse the episode files found in ``data_dir``.

        Args:
            max_episodes: Maximum number of episode files to read; ``None``
                reads all of them. Files that turn out to be empty or
                unreadable still count against this limit.

        Returns:
            A list of tuples
            ``(states, actions, rewards, dones, valid_actions_list)``.
            Empty when ``data_dir`` is not an existing directory.
        """
        # isdir (not exists): a plain file at this path would make
        # os.listdir raise instead of reporting the problem.
        if not os.path.isdir(self.data_dir):
            print(f"数据目录不存在: {self.data_dir}")
            return []

        episode_files = self._list_episode_files()

        if max_episodes is not None:
            episode_files = episode_files[:max_episodes]
            print(f"⚡ 快速测试模式: 只加载前 {max_episodes} 个episodes")

        all_episodes = []
        for file_path in episode_files:
            episode_data = self.load_episode_file(file_path)
            if not episode_data:
                continue

            all_episodes.append(self.parse_episode_data(episode_data))

            loaded = len(all_episodes)
            if loaded % 10 == 0:
                print(f"已加载 {loaded} 个episode文件...")

        print(f"总共加载 {len(all_episodes)} 个episode文件")
        return all_episodes

    def get_state_dim(self):
        """Infer the state dimension from the first loadable episode.

        Only files up to and including the first non-empty episode are read;
        the rest of the data set is not touched.

        Returns:
            The second dimension of that episode's state array, or
            ``DEFAULT_STATE_DIM`` when no episode can be loaded.
        """
        if not os.path.isdir(self.data_dir):
            print(f"数据目录不存在: {self.data_dir}")
            return self.DEFAULT_STATE_DIM

        for file_path in self._list_episode_files():
            episode_data = self.load_episode_file(file_path)
            if not episode_data:
                continue
            states = self.parse_episode_data(episode_data)[0]
            return states.shape[1]

        return self.DEFAULT_STATE_DIM

    def group_episodes_by_action_pattern(self, all_episodes):
        """Group episodes by how many valid actions their steps offer.

        Args:
            all_episodes: List of episode tuples as returned by
                ``load_all_episodes``.

        Returns:
            A ``defaultdict(list)`` mapping a pattern key to its episodes.
            Episodes whose steps all have the same number ``k`` of valid
            actions go under ``"uniform_k"``; every other episode (including
            one with no steps) goes under ``"mixed"``.
        """
        from collections import defaultdict

        grouped = defaultdict(list)

        for episode in all_episodes:
            valid_actions_list = episode[4]
            action_counts = {len(va) for va in valid_actions_list}

            if len(action_counts) == 1:
                key = f"uniform_{next(iter(action_counts))}"
            else:
                key = "mixed"

            grouped[key].append(episode)

        return grouped
|
||
|
||
|
||
class ReplayBuffer:
    """Fixed-capacity experience replay buffer backed by NumPy arrays.

    Steps are stored in insertion order. When a new episode does not fit,
    the oldest steps are dropped to make room. The per-step valid-action
    arrays have varying length, so they are kept in a plain Python list.
    """

    def __init__(self, capacity=10000):
        """
        Args:
            capacity: Maximum number of steps the buffer can hold.
        """
        self.capacity = capacity
        # The arrays are allocated lazily on the first non-empty add_episode
        # call, once the state dimension is known.
        self.states = None
        self.actions = None
        self.rewards = None
        self.dones = None
        self.valid_actions_ids = []
        self.current_size = 0

    def add_episode(self, states, actions, rewards, dones, valid_actions_list):
        """Append one complete episode, evicting the oldest steps if needed.

        An episode longer than ``capacity`` is truncated to its most recent
        ``capacity`` steps. An empty episode is a no-op.

        Args:
            states: (n_steps, state_dim) numpy array.
            actions: (n_steps,) numpy array.
            rewards: (n_steps,) numpy array.
            dones: (n_steps,) numpy array.
            valid_actions_list: list of arrays of uint64 action IDs, one
                array per step.
        """
        n_steps = len(states)

        # An episode larger than the whole buffer can only contribute its
        # tail. The slice start is written as n_steps - capacity (not
        # -capacity) so that capacity == 0 yields an empty slice.
        if n_steps > self.capacity:
            start = n_steps - self.capacity
            states = states[start:]
            actions = actions[start:]
            rewards = rewards[start:]
            dones = dones[start:]
            valid_actions_list = valid_actions_list[start:]
            n_steps = len(states)

        if n_steps == 0:
            return

        # First real insert: allocate storage now that state_dim is known.
        if self.states is None:
            state_dim = states.shape[1]
            self.states = np.empty((self.capacity, state_dim), dtype=np.float32)
            self.actions = np.empty(self.capacity, dtype=np.int64)
            self.rewards = np.empty(self.capacity, dtype=np.float32)
            self.dones = np.empty(self.capacity, dtype=np.bool_)

        # Number of oldest steps that must be evicted to fit the episode.
        overflow = self.current_size + n_steps - self.capacity
        if overflow > 0:
            keep_size = self.current_size - overflow
            # Shift the surviving steps to the front of the arrays.
            self.states[:keep_size] = self.states[overflow:self.current_size]
            self.actions[:keep_size] = self.actions[overflow:self.current_size]
            self.rewards[:keep_size] = self.rewards[overflow:self.current_size]
            self.dones[:keep_size] = self.dones[overflow:self.current_size]
            self.valid_actions_ids = self.valid_actions_ids[overflow:]
            self.current_size = keep_size

        begin = self.current_size
        end = begin + n_steps
        self.states[begin:end] = states
        self.actions[begin:end] = actions
        self.rewards[begin:end] = rewards
        self.dones[begin:end] = dones
        self.valid_actions_ids.extend(valid_actions_list)
        self.current_size = end

    def get_all(self):
        """Return everything currently stored, without copying.

        Returns:
            A tuple ``(states, actions, rewards, dones, valid_actions_ids)``.
            The arrays are views into the internal storage and the list is
            the internal list itself, so later add_episode/clear calls can
            change what the caller sees. Before the first non-empty insert,
            empty arrays and an empty list are returned.
        """
        if self.states is None:
            return (
                np.array([], dtype=np.float32),
                np.array([], dtype=np.int64),
                np.array([], dtype=np.float32),
                np.array([], dtype=bool),
                []
            )

        size = self.current_size
        return (
            self.states[:size],
            self.actions[:size],
            self.rewards[:size],
            self.dones[:size],
            self.valid_actions_ids
        )

    def clear(self):
        """Empty the buffer while keeping the allocated arrays for reuse."""
        self.current_size = 0
        self.valid_actions_ids.clear()

    def __len__(self):
        """Return the number of steps currently stored."""
        return self.current_size
|
||
|