started work on dqn
This commit is contained in:
parent
eacc3ce30a
commit
22bb3091c1
1 changed files with 86 additions and 0 deletions
86
dqn/agent.nim
Normal file
86
dqn/agent.nim
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
import std / [random]
|
||||||
|
import arraymancer
|
||||||
|
|
||||||
|
randomize()  # seed the global std/random generator from the current time so exploration differs run to run
||||||
|
# Define all the hyperparameters used in DDQN.
# NOTE: object-field default values require Nim >= 2.0.
type
  HyperParams* = object
    ## Tunable knobs for the DDQN training loop.
    # generic hyperparams
    batchSize: int = 64          # minibatch size sampled from replay memory
    discountFactor: float = 0.9  # gamma: weight given to future rewards
    learningRate: float = 1e-4

    # greedy-ε policy stuff
    explorationRate: float = 1.0             # ε: start fully exploratory
    explorationRateDecay: float = 0.9999975  # multiplicative ε decay per action
    explorationRateMinimum: float = 1e-2     # floor so exploration never fully stops

    # ddqn specific (arbitrary defaults)
    burnin: int = 1000         # steps of pure experience collection before learning
    learnEvery: int = 3        # learn once every N environment steps
    syncEvery: int = 100       # sync the target network every N learn steps
    memoryLength: int = 10000  # replay-buffer capacity

  OptimizerName = enum
    adam
    sgd
    sgdMomentum

  OptimizerParams* = object
    ## Settings forwarded to the chosen optimizer. Fields a given optimizer
    ## does not use (e.g. momentum for adam) are simply ignored.
    name: OptimizerName = adam
    learningRate: float = 1e-5  # NEP1 camelCase; style-insensitivity keeps `learning_rate:` call sites valid
    momentum: float = 0.0       # sgdMomentum only
    decay: float = 0.0
    nesterov: bool = false      # sgdMomentum only
    beta1: float = 0.9          # adam: first-moment decay
    beta2: float = 0.999        # adam: second-moment decay
    epsilon: float = 1e-8       # adam: numerical-stability term

  ActivationFunction = enum
    relu
    sigmoid
    softmax
    tanh

  NetworkParams* = object
    ## Shape of the Q-network's hidden stack.
    hiddenLayersNum: int
    hiddenLayersSize: int
    activationFunction: ActivationFunction
||||||
|
type
  Agent* = ref object
    ## DDQN agent: bundles the observation/action dimensions with all
    ## configuration objects. A `ref object` so procs like `act` can
    ## mutate `hParams` (the ε schedule) through the handle.
    inputDims: int    # size of the observation vector fed to the network
    actionDims: int   # number of discrete actions available
    hParams: HyperParams
    optimParams: OptimizerParams
    networkParams: NetworkParams
    saveDir: string = "chkpts"  # NEP1 camelCase; style-insensitivity keeps `save_dir:` call sites valid
    load: bool = false          # whether to restore weights from saveDir on startup
|
||||||
|
|
||||||
|
proc act*(model: Agent, state: Tensor): int =
  ## ε-greedy action selection: with probability `explorationRate` return a
  ## uniformly random action index, otherwise exploit the Q-network
  ## (placeholder until the network forward pass is implemented).
  ## Also advances the ε decay schedule, so every call mutates `model.hParams`.
  # BUG FIX: `rand(n)` is INCLUSIVE of `n`, so the original
  # `rand(model.actionDims)` could return an out-of-range action index.
  # Sample over 0 .. actionDims-1 instead.
  var actionIndex = rand(model.actionDims - 1)
  if rand(1.0) > model.hParams.explorationRate:
    echo "The agent has acted."  # debug trace; consider std/logging for library use
    # TODO: Implement actual model
    #[var
      actionValues = model.forward(state)
      actionIndex = argmax(actionValues).item()]#
    actionIndex = 1

  # Decay ε after every action, clamped to the configured floor.
  model.hParams.explorationRate *= model.hParams.explorationRateDecay
  model.hParams.explorationRate = max(model.hParams.explorationRate,
                                      model.hParams.explorationRateMinimum)

  return actionIndex
|
||||||
|
|
||||||
|
proc cache*(model: Agent, state: Tensor, nextState: Tensor, action: int, reward: float, done: bool) =
  ## Intended to store one (state, nextState, action, reward, done) transition
  ## in the agent's replay memory.
  ## NOTE(review): `Agent` declares no `memory` field and the bare expression
  ## below is neither assigned nor discarded, so this proc does not compile
  ## as written — it is a placeholder pending the TODO below.
  # TODO: Implement memory (either in Agent type or find another way)
  model.memory
|
||||||
|
|
||||||
|
|
||||||
|
# Quick manual smoke test: construct a dummy agent and request one action.
var agent: Agent = Agent(inputDims: 3, actionDims: 3)
var testor = [1.0, 2.0].toTensor()

var action = agent.act(testor)
|
Loading…
Reference in a new issue