Training done
This commit is contained in:
parent
bb43e56f8c
commit
1f91ec9d5d
4 changed files with 4 additions and 4 deletions
|
@ -82,7 +82,7 @@ class ActorNetwork(nn.Module):
|
||||||
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
||||||
|
|
||||||
def load_checkpoint(self, filename = 'actor_torch_ppo'):
|
def load_checkpoint(self, filename = 'actor_torch_ppo'):
|
||||||
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
|
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename), map_location=self.device))
|
||||||
|
|
||||||
|
|
||||||
class CriticNetwork(nn.Module):
|
class CriticNetwork(nn.Module):
|
||||||
|
@ -114,4 +114,4 @@ class CriticNetwork(nn.Module):
|
||||||
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
T.save(self.state_dict(), os.path.join(self.chkpt_dir, filename))
|
||||||
|
|
||||||
def load_checkpoint(self, filename = 'critic_torch_ppo'):
|
def load_checkpoint(self, filename = 'critic_torch_ppo'):
|
||||||
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename)))
|
self.load_state_dict(T.load(os.path.join(self.chkpt_dir, filename), map_location=self.device))
|
||||||
|
|
4
main.py
4
main.py
|
@ -14,8 +14,8 @@ random.seed(1)
|
||||||
np.random.seed(1)
|
np.random.seed(1)
|
||||||
T.manual_seed(1)
|
T.manual_seed(1)
|
||||||
|
|
||||||
n_episodes = 1000
|
n_episodes = 10000
|
||||||
game_len = 10000
|
game_len = 20000
|
||||||
|
|
||||||
figure_file = 'plots/score.png'
|
figure_file = 'plots/score.png'
|
||||||
|
|
||||||
|
|
BIN
tmp/ppo/player_0_actor
Executable file
BIN
tmp/ppo/player_0_actor
Executable file
Binary file not shown.
BIN
tmp/ppo/player_0_critic
Executable file
BIN
tmp/ppo/player_0_critic
Executable file
Binary file not shown.
Loading…
Reference in a new issue