initial commit

This commit is contained in:
aethrvmn 2024-10-03 10:15:39 +02:00
commit 3f101abbd9
17 changed files with 1913534 additions and 0 deletions

177
.gitignore vendored Normal file
View file

@ -0,0 +1,177 @@
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python

61
LICENSE Normal file
View file

@ -0,0 +1,61 @@
Don't Be Evil License (DBEL) 1.0
1. Acceptance
By using, copying, modifying, or distributing the source code, training data, training environment, or its associated machine learning model weights (collectively the "Software"), you agree to comply with all terms outlined in this license.
2. Copyright License
The Licensor (defined below) grants you a non-exclusive, worldwide, royalty-free, non-sublicensable, non-transferable license to use, copy, modify, and distribute the Software, including associated model weights, training data, and training environments, subject to the conditions set forth in this license. This includes the right to create and distribute derivative works of the Software, provided that the limitations below are observed.
3. Non-Commercial Use Only
You may use, copy, modify, and distribute the Software and derivative works solely for non-commercial purposes. Non-commercial purposes include, but are not limited to:
- Personal research and study.
- Educational and academic projects.
- Public knowledge and hobby projects.
- Religious observance.
- Non-commercial research, or AI and machine learning (ML) experimentation.
4. Distribution and Monetization Provisions
Any use of the Software or derivative works for profit, or in a business context, including in monetized services and products, requires explicit, separate permission from the Licensor. The restrictions on commercial use apply to both the source code and any model weights produced by the Software.
Any distribution must include this license, and the non-commercial restriction must be maintained. Weights resulting from use of the Software, including but not limited to training or fine-tuning models, must be shared under this same license, ensuring all restrictions and conditions are preserved.
5. Integrity of the Licensor's Software
You may not alter, remove, or obscure any functionalities related to payment, donation, or attribution in any distributed version of the Licensed Materials. You must retain all notices of copyright, licensing, and attribution provided by the Licensor in any derivative works.
You may not alter or remove copyright, license, or trademark notices in the Software, and any public mention of the Software must include attribution to the Licensor.
6. Patents
This license grants you a patent license under any patents held by the Licensor that are directly related to the Software. If you or your company make any claim that the Software infringes on a patent, your rights under this license terminate immediately.
7. Distribution of Modifications
If you modify the Software, you must:
- Provide prominent and clear notice of any modifications.
- Retain all original notices of copyright, licensing, and attribution to the Licensor.
- Distribute modified versions under this license.
8. Fair Use
Nothing under this license restricts your rights under applicable laws regarding fair use of copyrighted material.
9. No Other Rights
These terms do not allow you to sublicense, assign, or transfer any of your rights to third parties, except as expressly allowed by the terms.
These terms do not prevent the Licensor from granting licenses to anyone else.
These terms do not imply any other licenses.
No other rights beyond those explicitly stated are granted.
10. Termination
Your rights under this license will automatically terminate if you breach any of its terms. The Licensor may provide you with a 30-day period to rectify any breach. If you fail to do so, or if you breach the terms again after rectification, your license will terminate permanently.
11. Disclaimer of Warranty
The Licensed Materials are provided “as-is”, without any warranties, express or implied, including but not limited to warranties of fitness for a particular purpose. The Licensor is not liable for any claims or damages arising from your use of the Licensed Materials.
12. Definitions
- "Licensor": The entity or individual offering the Licensed Materials under this license.
- "Licensed Materials": The software, source code, training data, training environment, model weights, and any associated AI/ML components provided under this license.
- "You": The individual or entity accepting the terms of this license, including any organization or entity that this individual or entity might work for or represent, including any entities under common control.
- "Your license": The license granted to you for the software under these terms.
- "Model weights": The machine learning model parameters generated by training or fine-tuning models using the Licensed Materials.
- "Use": Anything you do with the software requiring your license.
- "Trademark": Trademarks, service marks, and similar rights.

1
README.md Normal file
View file

@ -0,0 +1 @@
# Lyceum

65
lexicon/newnew.py Normal file
View file

@ -0,0 +1,65 @@
import os
from tqdm import tqdm
class Tokenizer:
    """Experimental tokenizer that peels known roots and affixes off words.

    The morpheme lists (prefixes, roots, suffixes) and the word list are plain
    text files with one entry per line.
    """

    def __init__(self):
        # Reserved marker tokens; generate_tokens() does not use them itself.
        self.special_tokens = ["<PAD>", "<UNK>", "<CLS>", "<SEP>", "<MASK>", "<SOS>", "<EOS>", "<BOS>"]
        self.token_dir = "tokens"
        self.token_filename = "tokens.txt"

    def _load_list_from_file(self, file_path):
        """Return the whitespace-stripped lines of *file_path* as a set.

        Raises:
            NameError: if *file_path* does not exist. (Kept as NameError so
                existing callers that catch it keep working.)
        """
        if not os.path.exists(file_path):
            raise NameError(f"File {file_path} not found. Are you sure it exists and/or that the name is correct?")
        with open(file_path, 'r', encoding='utf-8') as f:
            return {line.strip() for line in f}

    def generate_tokens(
        self,
        text_file,
        prefix_file=os.path.join("tokens", "prefixes.txt"),
        root_file=os.path.join("tokens", "roots.txt"),
        # The shipped list is named "suffixes.txt"; the old default
        # "suffix.txt" pointed at a file that is not part of the repository.
        suffix_file=os.path.join("tokens", "suffixes.txt"),
    ):
        """Strip roots, prefixes and suffixes from every word in *text_file*.

        Populates ``self.prefixes``, ``self.root_words``, ``self.suffixes`` and
        ``self.vocab`` from the given files and collects every root that was
        found inside a word in ``self.token_set``. Progress is printed to
        stdout for each word.

        Args:
            text_file: path of the word list to process.
            prefix_file: path of the prefix list.
            root_file: path of the root-word list.
            suffix_file: path of the suffix list.

        Raises:
            NameError: if any of the files does not exist.
        """
        self.token_set = set()
        self.prefixes = self._load_list_from_file(prefix_file)
        self.root_words = self._load_list_from_file(root_file)
        self.suffixes = self._load_list_from_file(suffix_file)
        self.vocab = self._load_list_from_file(text_file)

        # The morpheme sets never change inside the loop, so sort them once
        # instead of once per word. Empty entries (blank lines in the list
        # files) are dropped: an empty string is contained in every word and
        # would always win the match, and an empty suffix would erase the
        # word through the ``[:-0]`` slice.
        # NOTE(review): roots are tried shortest-first while affixes are tried
        # longest-first -- kept as in the original, confirm this is intended.
        roots_by_length = sorted((root for root in self.root_words if root), key=len)
        prefixes_by_length = sorted((p for p in self.prefixes if p), key=len, reverse=True)
        suffixes_by_length = sorted((s for s in self.suffixes if s), key=len, reverse=True)

        for compound_word in tqdm(self.vocab):
            compound_word = compound_word.strip()
            print(compound_word)

            # Remove every occurrence of the first root found in the word.
            for root_word in roots_by_length:
                if root_word in compound_word:
                    print(compound_word)
                    self.token_set.add(root_word)
                    compound_word = compound_word.replace(root_word, '')
                    print('--------------------------------------------')
                    break
            print(compound_word)

            # Strip the longest matching prefix from what is left.
            for prefix in prefixes_by_length:
                if compound_word.startswith(prefix):
                    print(f"Prefix {prefix}")
                    compound_word = compound_word[len(prefix):]
                    break
            print(compound_word)

            # Strip the longest matching suffix from what is left.
            for suffix in suffixes_by_length:
                if compound_word.endswith(suffix):
                    print(f"Suffix {suffix}")
                    compound_word = compound_word[:-len(suffix)]
                    break
            print(compound_word)
            print('\n')
if __name__ == '__main__':
    # generate_tokens() takes (text_file, prefix_file, root_file, suffix_file);
    # the paths are passed by keyword so each file lands in the right slot.
    tokenizer = Tokenizer()
    tokenizer.generate_tokens(
        text_file="tokens/words.txt",
        prefix_file="tokens/prefixes.txt",
        root_file="tokens/roots.txt",
        suffix_file="tokens/suffixes.txt",
    )

133
lexicon/newtokenizer.py Normal file
View file

@ -0,0 +1,133 @@
import os
import re
from tqdm import tqdm
class Tokenizer:
    """Builds a token vocabulary by splitting words into prefix, suffix and stem.

    Affix lists and the word list are plain text files with one entry per
    line. The resulting vocabulary is written to ``self.token_file``.
    """

    # Any character that is neither a word character nor '.'; compiled once
    # instead of on every call.
    _SPECIAL_SPLIT = re.compile(r"([^\w.])")
    _SPECIAL_CHAR = re.compile(r"[^\w.]")

    def __init__(self):
        self.prefixes = set()
        self.suffixes = set()
        self.vocab = set()
        self.special_tokens = ["<PAD>", "<UNK>", "<CLS>", "<SEP>", "<MASK>"]
        self.token_file = "tokens.txt"

    def _load_list_from_file(self, file_path):
        """Return the stripped lines of *file_path* as a set (empty if missing)."""
        if not os.path.exists(file_path):
            return set()
        with open(file_path, 'r', encoding='utf-8') as f:
            return {line.strip() for line in f}

    # The loader was originally defined under this name while every call site
    # used ``_load_list_from_file``; both names are kept working.
    _load__file = _load_list_from_file

    def generate_tokens(self, prefix_file, text_file, suffix_file):
        """Split every word of *text_file* into tokens and save the vocabulary.

        Args:
            prefix_file: path of the prefix list (a missing file means no prefixes).
            text_file: path of the word list; must exist.
            suffix_file: path of the suffix list (a missing file means no suffixes).
        """
        self.prefixes = self._load_list_from_file(prefix_file)
        self.suffixes = self._load_list_from_file(suffix_file)
        with open(text_file, 'r', encoding='utf-8') as f:
            words = {line.strip() for line in f}
        # Add single character tokens with trailing space (e.g., "a ", "I ")
        self.vocab = {w + ' ' if len(w) == 1 else w for w in words}
        # Process each word in text file
        for word in tqdm(words):
            tokens = self._split_word(word)
            self.vocab.update(tokens)
            print(f"{word} -> {tokens}")
        # Save all tokens to file
        self._save_tokens()

    def _save_tokens(self):
        """Write the special tokens, then the vocabulary (longest first)."""
        with open(self.token_file, 'w', encoding='utf-8') as f:
            for token in self.special_tokens:
                f.write(token + '\n')
            # Ties in length are broken alphabetically so the file does not
            # change from run to run with set iteration order.
            for token in sorted(self.vocab, key=lambda tok: (-len(tok), tok)):
                f.write(token + '\n')

    def _split_word(self, word):
        """Return ``[prefix?, suffix?, *stem_tokens]`` for *word*.

        The longest matching prefix is removed first, then the longest suffix
        matching what remains; the rest goes through _split_compound_word().
        The suffix is listed before the stem tokens.
        """
        tokens = []
        # Empty affixes (blank lines in the list files) are ignored: an empty
        # string matches every word, and an empty suffix would erase the word
        # through the ``[:-0]`` slice.
        prefix = max((p for p in self.prefixes if p and word.startswith(p)), key=len, default='')
        if prefix:
            tokens.append(prefix)
            word = word[len(prefix):]
        suffix = max((s for s in self.suffixes if s and word.endswith(s)), key=len, default='')
        if suffix:
            tokens.append(suffix)
            word = word[:-len(suffix)]
        # Split remaining middle part
        tokens.extend(self._split_compound_word(word))
        return tokens

    def _split_compound_word(self, word):
        """Split *word* at special characters and then by vocabulary entries.

        A special character (anything but a word character or '.') is glued to
        the token before it, e.g. ``"p-"``. Parts with no vocabulary match are
        kept whole.
        """
        tokens = []
        if not word:
            return tokens
        parts = self._SPECIAL_SPLIT.split(word)
        print(parts)
        for part in parts:
            part = part.strip()  # Clean up any leading/trailing whitespace
            if part == '':
                continue
            if self._SPECIAL_CHAR.match(part):
                # Attach the special character to the previous token (e.g., p-)
                if tokens:
                    tokens[-1] += part
                else:
                    # No previous token exists: the character stands alone
                    tokens.append(part)
            else:
                sub_tokens = self._split_by_vocab(part)
                if not sub_tokens:
                    # Nothing in the vocab matched: keep the part as it is
                    sub_tokens = [part]
                tokens.extend(sub_tokens)
        print(tokens)
        # A trailing whitespace character was dropped by strip() above, so
        # restore it here. Non-whitespace special characters were already
        # attached inside the loop and must not be appended a second time.
        if tokens and not word[-1].strip():
            tokens[-1] += word[-1]
        # Replace empty tokens with a fallback special token
        tokens = [token if token else "<UNK>" for token in tokens]
        print(tokens)
        return tokens

    def _split_by_vocab(self, word):
        """Greedily split *word* into the longest matching vocab entries.

        Stops at the first position where no vocab entry matches; the
        unmatched remainder is not returned.
        """
        tokens = []
        while word:
            # Probe prefixes of the word from longest to shortest. This finds
            # the same entry as scanning the length-sorted vocab but costs a
            # few set lookups instead of a full sort, and it never selects an
            # empty vocab entry (which would make no progress).
            for end in range(len(word), 0, -1):
                if word[:end] in self.vocab:
                    tokens.append(word[:end])
                    word = word[end:]
                    break
            else:
                break
        return tokens
if __name__ == '__main__':
    # Build the vocabulary from the word list and the two affix lists.
    tokenizer = Tokenizer()
    tokenizer.generate_tokens(
        prefix_file="tokens/prefixes.txt",
        text_file="tokens/words.txt",
        suffix_file="tokens/suffixes.txt",
    )

26
lexicon/tokenizer.py Normal file
View file

@ -0,0 +1,26 @@
import re
from collections import Counter
import torch
import hparams
def tokenize(filename):
    """Read *filename* and return its lowercased, whitespace-separated tokens.

    The file is decoded as UTF-8, like the other lexicon scripts, instead of
    relying on the platform's default encoding.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        text = f.read()
    return re.findall(r'\S+', text.lower())
def build_vocab(tokens, max_vocab_size):
    """Map "<PAD>", "<UNK>" and the most frequent tokens to integer ids.

    "<PAD>" gets id 0 and "<UNK>" id 1; the remaining ids follow descending
    token frequency, limited to the first *max_vocab_size* tokens.
    """
    counts = Counter(tokens)
    ranked = sorted(counts, key=lambda tok: counts[tok], reverse=True)
    entries = ["<PAD>", "<UNK>", *ranked[:max_vocab_size]]
    return dict(zip(entries, range(len(entries))))
def numericalize(tokens, word_to_idx):
    """Translate *tokens* to ids, using the "<UNK>" id for unknown tokens."""
    ids = []
    for token in tokens:
        # Looked up per token, so an empty token list never touches "<UNK>".
        unknown_id = word_to_idx["<UNK>"]
        ids.append(word_to_idx.get(token, unknown_id))
    return ids
def stringify(indices, idx_to_word):
    """Look up each index in *idx_to_word* and join the words with spaces."""
    words = []
    for idx in indices:
        words.append(idx_to_word[idx])
    return ' '.join(words)

File diff suppressed because it is too large Load diff

1997
lexicon/tokens/prefixes.txt Normal file

File diff suppressed because it is too large Load diff

6
lexicon/tokens/remdup.py Normal file
View file

@ -0,0 +1,6 @@
# Copy tokens.txt to roots.txt with duplicate lines removed.
with open('tokens.txt', 'r', encoding='utf-8') as f:
    # dict.fromkeys() drops duplicates but keeps first-seen order, so the
    # output is identical on every run (a set would be written in an order
    # that changes with string hash randomisation).
    words = list(dict.fromkeys(line.strip() for line in f))

with open('roots.txt', 'w', encoding='utf-8') as f:
    f.writelines(token + '\n' for token in words)

316215
lexicon/tokens/roots.txt Normal file

File diff suppressed because it is too large Load diff

1100
lexicon/tokens/suffixes.txt Normal file

File diff suppressed because it is too large Load diff

991873
lexicon/tokens/words.txt Normal file

File diff suppressed because it is too large Load diff

1051
poetry.lock generated Normal file

File diff suppressed because it is too large Load diff

19
pyproject.toml Normal file
View file

@ -0,0 +1,19 @@
[tool.poetry]
name = "lyceum"
version = "0.0.1"
description = "A school for RL pupils studying NLP"
authors = ["aethrvmn <aethrvmn@apotheke.earth>"]
license = "DBEL 1.0"
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.12"
torch = "^2.4.1"
numpy = "^2.1.1"
matplotlib = "^3.9.2"
tqdm = "^4.66.5"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

128
students/sac/agent.py Normal file
View file

@ -0,0 +1,128 @@
import os
import torch
import torch.nn.functional as F
import numpy as np
from buffer import ReplayBuffer
from networks import ActorNetwork, CriticNetwork, ValueNetwork
class SoftActorCritic():
    """Soft Actor-Critic agent with twin critics and a target value network.

    Owns a replay buffer, an actor, two critics, a value network and a target
    value network that tracks the value network through soft updates.
    """

    def __init__(self, alpha=3e-4, beta=3e-4, input_dims=(8,),
                 env=None, gamma=0.99, tau=5e-3, n_actions=2, max_size=1000000,
                 batch_size=256, reward_scale=2):
        """Create the networks and the replay buffer.

        Args:
            alpha: learning rate handed to the actor.
            beta: learning rate handed to the critics and value networks.
            input_dims: observation shape as a sequence, e.g. ``(8,)``.
            env: environment; only ``env.action_space.high`` is read.
            gamma: discount applied to the target value in the critic target.
            tau: interpolation factor of the soft target update.
            n_actions: number of action components.
            max_size: replay buffer capacity.
            batch_size: number of transitions sampled per learn() call.
            reward_scale: factor applied to rewards in the critic target.

        Raises:
            ValueError: if *env* is not given.
        """
        if env is None:
            raise ValueError("env is required: the actor is bounded by env.action_space.high")
        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.batch_size = batch_size
        self.n_actions = n_actions
        self.actor = ActorNetwork(alpha, input_dims, n_actions=n_actions, max_action=env.action_space.high)
        self.critic1 = CriticNetwork(beta, input_dims, n_actions=n_actions, name='critic1')
        self.critic2 = CriticNetwork(beta, input_dims, n_actions=n_actions, name='critic2')
        self.value = ValueNetwork(beta, input_dims, name='value')
        self.target_value = ValueNetwork(beta, input_dims, name='target_value')
        self.scale = reward_scale
        # tau=1 makes the target network an exact copy of the value network.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        """Sample an action for a single *observation* and return it as an array."""
        state = torch.as_tensor(
            np.asarray(observation, dtype=np.float32), device=self.actor.device
        ).unsqueeze(0)  # add the batch dimension the networks expect
        actions, _ = self.actor.sample_normal(state, reparametrize=False)
        return actions.cpu().detach().numpy()[0]

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, new_state, done)

    def update_network_parameters(self, tau=None):
        """Move the target value network towards the value network.

        Each target parameter becomes ``tau * value + (1 - tau) * target``.
        *tau* defaults to ``self.tau``.
        """
        if tau is None:
            tau = self.tau
        value_params = dict(self.value.named_parameters())
        target_params = dict(self.target_value.named_parameters())
        # In-place copy without autograd tracking: the target network is never
        # trained directly.
        with torch.no_grad():
            for name, param in value_params.items():
                target_param = target_params[name]
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def save_models(self):
        """Save every network through its own save() method."""
        print('... saving models ...')
        self.actor.save()
        self.critic1.save()
        self.critic2.save()
        self.value.save()
        self.target_value.save()

    def load_models(self):
        """Load every network through its own load() method."""
        print('... loading models ...')
        self.actor.load()
        self.critic1.load()
        self.critic2.load()
        self.value.load()
        self.target_value.load()

    def learn(self):
        """Run one update of the value network, the actor and both critics.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return

        state, action, reward, new_state, done = \
            self.memory.sample_buffer(self.batch_size)
        device = self.actor.device
        reward = torch.as_tensor(reward, dtype=torch.float, device=device)
        done = torch.as_tensor(done, dtype=torch.bool, device=device)
        new_state = torch.as_tensor(new_state, dtype=torch.float, device=device)
        state = torch.as_tensor(state, dtype=torch.float, device=device)
        action = torch.as_tensor(action, dtype=torch.float, device=device)

        value = self.value(state).view(-1)
        # The target value only feeds the critic target, so it is computed
        # without a graph; terminal transitions bootstrap from zero.
        with torch.no_grad():
            target_value = self.target_value(new_state).view(-1)
            target_value[done] = 0.0

        # --- value network ---
        actions, log_probs = self.actor.sample_normal(state, reparametrize=False)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic1(state, actions)
        q2_new_policy = self.critic2(state, actions)
        critic_value = torch.min(q1_new_policy, q2_new_policy).view(-1)

        self.value.optimizer.zero_grad()
        # Detached: this is a regression target, not something to train the
        # actor or critics through.
        value_target = (critic_value - log_probs).detach()
        value_loss = 0.5 * F.mse_loss(value, value_target)
        value_loss.backward()
        self.value.optimizer.step()

        # --- actor ---
        actions, log_probs = self.actor.sample_normal(state, reparametrize=True)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic1(state, actions)
        q2_new_policy = self.critic2(state, actions)
        critic_value = torch.min(q1_new_policy, q2_new_policy).view(-1)

        actor_loss = torch.mean(log_probs - critic_value)
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        # --- critics ---
        # zero_grad() here also clears what the actor update left in the critics.
        self.critic1.optimizer.zero_grad()
        self.critic2.optimizer.zero_grad()
        q_hat = self.scale * reward + self.gamma * target_value
        q1_old_policy = self.critic1(state, action).view(-1)
        q2_old_policy = self.critic2(state, action).view(-1)
        critic1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
        critic2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)

        critic_loss = critic1_loss + critic2_loss
        critic_loss.backward()
        self.critic1.optimizer.step()
        self.critic2.optimizer.step()

        self.update_network_parameters()

139
students/sac/brain.py Normal file
View file

@ -0,0 +1,139 @@
import os

import numpy as np
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
class CriticNetwork(nn.Module):
    """Network mapping a batch of (state, action) pairs to one value each."""

    def __init__(self, beta, input_dims, n_actions, name='critic', chkpt_dir='tmp/sac'):
        """Build the layers and the Adam optimizer and move them to the device.

        Args:
            beta: learning rate of the optimizer.
            input_dims: observation shape as a sequence; only ``input_dims[0]``
                is used to size the layers.
            n_actions: number of action components concatenated to the state.
            name: used to build the checkpoint file name.
            chkpt_dir: directory the checkpoint is written to.
        """
        super().__init__()
        self.input_dims = input_dims
        # Hidden width is half the observation size (at least one unit).
        self.fc_size = max(1, input_dims[0] // 2)
        self.n_actions = n_actions
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        self.input_layer = nn.Linear(self.input_dims[0]+n_actions, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.output_layer = nn.Linear(self.fc_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state, action):
        """Return a ``(batch, 1)`` tensor for the given states and actions."""
        action_value = T.tanh(self.input_layer(T.cat([state, action], dim=1)))
        # NOTE(review): the same middle layer (shared weights) is applied
        # three times, as in the original layer sequence -- confirm that
        # weight sharing is intended rather than three separate layers.
        for _ in range(3):
            action_value = T.tanh(self.middle_layer(action_value))
        return self.output_layer(action_value)

    def save(self):
        """Write the state dict to ``self.chkpt_file``, creating the directory."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Load the state dict from ``self.chkpt_file`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # machine; weights_only restricts unpickling to tensors and containers.
        self.load_state_dict(
            T.load(self.chkpt_file, map_location=self.device, weights_only=True)
        )
class ValueNetwork(nn.Module):
    """Network mapping a batch of states to one value each."""

    def __init__(self, beta, input_dims, name='value', chkpt_dir='tmp/sac'):
        """Build the layers and the Adam optimizer and move them to the device.

        Args:
            beta: learning rate of the optimizer.
            input_dims: observation shape as a sequence; only ``input_dims[0]``
                is used to size the hidden layers.
            name: used to build the checkpoint file name.
            chkpt_dir: directory the checkpoint is written to.
        """
        super().__init__()
        self.input_dims = input_dims
        # Hidden width is half the observation size (at least one unit).
        self.fc_size = max(1, input_dims[0] // 2)
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')

        self.input_layer = nn.Linear(*self.input_dims, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.output_layer = nn.Linear(self.fc_size, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return a ``(batch, 1)`` tensor of values for *state*."""
        state_value = T.tanh(self.input_layer(state))
        # The middle layer consumes the hidden activation, not the raw state
        # (whose width does not match this layer).
        state_value = T.tanh(self.middle_layer(state_value))
        return self.output_layer(state_value)

    def save(self):
        """Write the state dict to ``self.chkpt_file``, creating the directory."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Load the state dict from ``self.chkpt_file`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # machine; weights_only restricts unpickling to tensors and containers.
        self.load_state_dict(
            T.load(self.chkpt_file, map_location=self.device, weights_only=True)
        )
class ActorNetwork(nn.Module):
    """Gaussian policy network with tanh-squashed, scaled actions."""

    def __init__(self, alpha, input_dims, n_actions, max_action, name='actor', chkpt_dir='tmp/sac'):
        """Build the layers and the Adam optimizer and move them to the device.

        Args:
            alpha: learning rate of the optimizer.
            input_dims: observation shape as a sequence; only ``input_dims[0]``
                is used to size the hidden layers.
            n_actions: number of action components.
            max_action: scale applied to the squashed action (scalar or one
                value per action component).
            name: used to build the checkpoint file name.
            chkpt_dir: directory the checkpoint is written to.
        """
        super().__init__()
        self.input_dims = input_dims
        # Hidden width is a quarter of the observation size (at least one unit).
        self.fc_size = max(1, input_dims[0] // 4)
        self.n_actions = n_actions
        self.max_action = max_action
        self.name = name
        self.chkpt_dir = chkpt_dir
        self.chkpt_file = os.path.join(self.chkpt_dir, name+'_sac')
        # Lower bound for sigma and epsilon inside the log of the tanh correction.
        self.reparam_noise = 1e-6

        self.input_layer = nn.Linear(*self.input_dims, self.fc_size)
        self.middle_layer = nn.Linear(self.fc_size, self.fc_size)
        self.mu = nn.Linear(self.fc_size, self.n_actions)
        self.sigma = nn.Linear(self.fc_size, self.n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha, betas=(0.9, 0.9))
        self.device = T.device('cuda' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return ``(mu, sigma)`` of the action distribution for *state*.

        sigma is clamped to ``[reparam_noise, 1]``.
        """
        prob = T.tanh(self.input_layer(state))
        prob = T.tanh(self.middle_layer(prob))
        # NOTE(review): the middle layer is applied a second time (shared
        # weights, no activation), as in the original -- confirm intent.
        prob = self.middle_layer(prob)

        mu = self.mu(prob)
        sigma = T.clamp(self.sigma(prob), min=self.reparam_noise, max=1)
        return mu, sigma

    def sample_normal(self, state, reparametrize=True):
        """Sample actions for *state* and return ``(action, log_probs)``.

        Args:
            state: batch of observations.
            reparametrize: use ``rsample()`` so gradients flow through the
                sample; otherwise use ``sample()``.

        Returns:
            action: ``tanh(sample) * max_action``.
            log_probs: per-row log-probability with the tanh correction,
                summed over action components (``keepdim=True``).
        """
        mu, sigma = self.forward(state)
        probabilities = Normal(mu, sigma)
        raw_actions = probabilities.rsample() if reparametrize else probabilities.sample()

        squashed = T.tanh(raw_actions)
        max_action = T.as_tensor(self.max_action, dtype=T.float, device=self.device)
        action = squashed * max_action

        # Gaussian log-density minus the log-Jacobian of tanh. The correction
        # uses the unscaled tanh output, which stays inside (-1, 1) whatever
        # max_action is, so the log argument is always positive.
        # NOTE(review): the constant log(max_action) term of the scaling is
        # not subtracted -- confirm whether it matters for your use.
        log_probs = probabilities.log_prob(raw_actions)
        log_probs = log_probs - T.log(1 - squashed.pow(2) + self.reparam_noise)
        log_probs = log_probs.sum(1, keepdim=True)
        return action, log_probs

    def save(self):
        """Write the state dict to ``self.chkpt_file``, creating the directory."""
        if self.chkpt_dir:
            os.makedirs(self.chkpt_dir, exist_ok=True)
        T.save(self.state_dict(), self.chkpt_file)

    def load(self):
        """Load the state dict from ``self.chkpt_file`` onto ``self.device``."""
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # machine; weights_only restricts unpickling to tensors and containers.
        self.load_state_dict(
            T.load(self.chkpt_file, map_location=self.device, weights_only=True)
        )

36
students/sac/buffer.py Normal file
View file

@ -0,0 +1,36 @@
import numpy as np
class ReplayBuffer():
    """Fixed-capacity circular store of (state, action, reward, new_state, done)."""

    def __init__(self, max_size, input_shape, n_actions):
        """Allocate zero-filled arrays for *max_size* transitions.

        Args:
            max_size: capacity; the oldest entries are overwritten beyond it.
            input_shape: shape of one observation, as a sequence.
            n_actions: number of action components per transition.
        """
        self.mem_size = max_size
        # Total number of transitions ever stored; it is never reset, the
        # write position is derived from it modulo mem_size.
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape))
        self.new_state_memory = np.zeros((self.mem_size, *input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        # Builtin bool instead of np.bool: the alias does not exist in
        # NumPy 1.24-1.26, while bool selects the same dtype everywhere.
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, new_state, done):
        """Write one transition at the current position and advance the counter."""
        index = self.mem_cntr % self.mem_size

        self.state_memory[index] = state
        self.new_state_memory[index] = new_state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.terminal_memory[index] = done

        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Return *batch_size* stored transitions drawn uniformly at random.

        Indices are drawn with replacement from the filled part of the buffer,
        so a transition can appear more than once in a batch.

        Returns:
            Tuple ``(states, actions, rewards, new_states, dones)`` of arrays.
        """
        max_mem = min(self.mem_cntr, self.mem_size)
        batch = np.random.choice(max_mem, batch_size)

        states = self.state_memory[batch]
        new_states = self.new_state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        dones = self.terminal_memory[batch]

        return states, actions, rewards, new_states, dones