diff --git a/agents/ppo/agent.py b/agents/ppo/agent.py index bc1bcfb..69fb1cc 100644 --- a/agents/ppo/agent.py +++ b/agents/ppo/agent.py @@ -1,4 +1,3 @@ -import os import numpy as np import torch as T @@ -40,7 +39,7 @@ class Agent: print('.. done ...') def choose_action(self, observation): - state = T.tensor([observation], dtype=T.float).to(self.actor.device) + state = T.tensor(observation, dtype=T.float).to(self.actor.device) dist = self.actor(state) value = self.critic(state) diff --git a/tmp/ppo/actor_torch_ppo b/tmp/ppo/actor_torch_ppo index 4e9b721..60dc325 100644 Binary files a/tmp/ppo/actor_torch_ppo and b/tmp/ppo/actor_torch_ppo differ diff --git a/tmp/ppo/critic_torch_ppo b/tmp/ppo/critic_torch_ppo index b3373b1..03a0b45 100644 Binary files a/tmp/ppo/critic_torch_ppo and b/tmp/ppo/critic_torch_ppo differ