def add(self, item: StateNode):
    if not isinstance(item.action, torch.Tensor):
        raise ValueError(f"action must be Tensor not {type(item.action)}")
    self.buffer.append(item)
def sample(self, batch_size):
    transition = random.sample(self.buffer, batch_size)
    # Seed the batch with the first sampled transition.
    states = transition[0].state
    actions = transition[0].action
    if actions.ndim == 0:
        actions = actions.unsqueeze(0)
    rewards = [transition[0].reward]
    next_states = transition[0].next_state
    dones = [transition[0].done]
    # Concatenate the remaining transitions along the batch dimension.
    for node in transition[1:]:
        states = torch.cat((states, node.state), dim=0)
        if node.action.ndim == 0:
            node.action = node.action.unsqueeze(0)
        actions = torch.cat((actions, node.action), dim=0)
        next_states = torch.cat((next_states, node.next_state), dim=0)
        rewards.append(node.reward)
        dones.append(node.done)
    return (states.reshape((batch_size, self.state_dim)),
            actions.reshape((-1, 1)),
            torch.tensor(rewards).reshape((-1, 1)),
            next_states.reshape((batch_size, self.state_dim)),
            torch.tensor(dones))
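For context, add and sample assume a StateNode record holding one transition and a buffer object that exposes self.buffer and self.state_dim. A minimal sketch of what that surrounding code might look like; the StateNode field types, the ReplayBuffer class name, and its constructor are assumptions for illustration, not code from the original:

import collections
import random
from dataclasses import dataclass

import torch


@dataclass
class StateNode:
    # Assumed layout of one transition; field names follow how sample() reads them.
    state: torch.Tensor       # shape (1, state_dim)
    action: torch.Tensor      # scalar or shape (1,) tensor
    reward: float
    next_state: torch.Tensor  # shape (1, state_dim)
    done: bool


class ReplayBuffer:
    def __init__(self, capacity: int, state_dim: int):
        # Bounded FIFO buffer: oldest transitions are discarded once capacity is reached.
        self.buffer = collections.deque(maxlen=capacity)
        self.state_dim = state_dim

    # add() and sample() from above would live here.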
def take_action(self, state: torch.Tensor):
    if not isinstance(state, torch.Tensor):
        raise ValueError(f"state must be tensor not {type(state)}")
    # Epsilon-greedy exploration: random action with probability epsilon,
    # otherwise the greedy action under the current policy network.
    p = np.random.random()
    if p < self.epsilon:
        action = torch.tensor(np.random.randint(self.action_dim))
    else:
        Qsa = self.policy_net(state)
        if Qsa.dim() == 1:
            Qsa = Qsa.unsqueeze(0)  # add a batch dimension for a single state
        _, action = torch.max(Qsa, dim=1)
    return action
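The greedy branch relies on torch.max along the action dimension returning both the maximum Q-value and its index, and take_action returns that index tensor. A self-contained check of this behaviour (the sample Q-values are illustrative):

import torch

Qsa = torch.tensor([[0.1, 0.7, 0.2]])   # one state, three actions
values, action = torch.max(Qsa, dim=1)  # values -> tensor([0.7]), action -> tensor([1])
assert action.item() == 1               # index of the best action is returned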
def forward(self, x):
    hidden = self.relu(self.fc1(x))     # shared hidden representation
    A = self.fc_A(hidden)               # advantage stream
    V = self.fc_V(hidden)               # state-value stream
    Q = V + A - A.mean(1).view(-1, 1)   # Q is computed from V and A (mean-subtracted advantage)
    return Q
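forward references fc1, fc_A, fc_V, and relu; a minimal sketch of the module they presumably belong to (the class name, hidden_dim, and layer sizes are assumptions for illustration):

import torch
import torch.nn as nn


class DuelingQNet(nn.Module):
    # Assumed container for the layers used in forward(); names beyond fc1/fc_A/fc_V/relu
    # are illustrative, not taken from the original.
    def __init__(self, state_dim: int, hidden_dim: int, action_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)    # shared feature layer
        self.relu = nn.ReLU()
        self.fc_A = nn.Linear(hidden_dim, action_dim)  # advantage head, one value per action
        self.fc_V = nn.Linear(hidden_dim, 1)           # state-value head

Subtracting the mean advantage keeps V and A identifiable: without it, any constant could shift between the two streams and still produce the same Q-values.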