started work on dqn

This commit is contained in:
aethrvmn 2024-11-12 23:33:41 +01:00
parent eacc3ce30a
commit 22bb3091c1

86
dqn/agent.nim Normal file
View file

@ -0,0 +1,86 @@
import std / [random]
import arraymancer
randomize()
# Define all the hyperparameters used in DDQN
type
HyperParams* = object
# generic hyperparams
batchSize: int = 64
discountFactor: float = 0.9
learningRate: float = 1e-4
# greedy-ε policy stuff
explorationRate: float = 1
explorationRateDecay: float = 0.9999975
explorationRateMinimum: float = 1e-2
# ddqn specific (arbitrary defaults)
burnin: int = 1000
learnEvery: int = 3
syncEvery: int = 100
memoryLength: int = 10000
OptimizerName = enum
adam
sgd
sgdMomentum
OptimizerParams* = object
# names are "SGD", "SGDMomentum", or "Adam"
name: OptimizerName = adam
learning_rate: float = 1e-5
momentum: float = 0.0
decay: float = 0.0
nesterov: bool = false
beta1: float = 0.9
beta2: float = 0.999
epsilon: float = 1e-8
ActivationFunction = enum
relu
sigmoid
softmax
tanh
NetworkParams* = object
hiddenLayersNum: int
hiddenLayersSize: int
activationFunction: ActivationFunction
type
Agent* = ref object
inputDims: int
actionDims: int
hParams: HyperParams
optimParams: OptimizerParams
networkParams: NetworkParams
save_dir: string = "chkpts"
load: bool = false
proc act*(model: Agent, state: Tensor): int =
var actionIndex = rand(model.actionDims)
if rand(1.0) > model.hParams.explorationRate:
echo "The agent has acted."
# TODO: Implement actual model
#[var
actionValues = model.forward(state)
actionIndex = argmax(actionValues).item()]#
actionIndex = 1
model.hParams.explorationRate *= model.hParams.explorationRateDecay
model.hParams.explorationRate = max(model.hParams.explorationRate, model.hParams.explorationRateMinimum)
return actionIndex
proc cache*(model: Agent, state: Tensor, nextState: Tensor, action: int, reward: float, done: bool) =
# TODO: Implement memory (either in Agent type or find another way)
model.memory
var
model: Agent = Agent(inputDims: 3, actionDims: 3)
testor = [1.0,2.0].toTensor()
var action = model.act(testor)