started work on dqn
parent eacc3ce30a
commit 22bb3091c1
1 changed file with 86 additions and 0 deletions
86
dqn/agent.nim
Normal file
@@ -0,0 +1,86 @@
import std / [random]
import arraymancer

randomize()

# Define all the hyperparameters used in DDQN
type
  HyperParams* = object
    # generic hyperparams
    batchSize: int = 64
    discountFactor: float = 0.9
    learningRate: float = 1e-4

    # greedy-ε policy stuff
    explorationRate: float = 1
    explorationRateDecay: float = 0.9999975
    explorationRateMinimum: float = 1e-2

    # ddqn specific (arbitrary defaults)
    burnin: int = 1000
    learnEvery: int = 3
    syncEvery: int = 100
    memoryLength: int = 10000

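  # Editor's note (not part of the original commit): with the defaults above,
  # explorationRate decays geometrically, eps_n = 0.9999975^n, so it only
  # reaches the 1e-2 floor after ln(0.01)/ln(0.9999975) ≈ 1.84 million steps.
  # Note also that default values on object fields require Nim 2.0 or newer.
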
  OptimizerName = enum
    adam
    sgd
    sgdMomentum

  OptimizerParams* = object
    # which optimizer to use, selected via the OptimizerName enum above
    name: OptimizerName = adam
    learning_rate: float = 1e-5
    momentum: float = 0.0
    decay: float = 0.0
    nesterov: bool = false
    beta1: float = 0.9
    beta2: float = 0.999
    epsilon: float = 1e-8

  ActivationFunction = enum
    relu
    sigmoid
    softmax
    tanh

  NetworkParams* = object
    hiddenLayersNum: int
    hiddenLayersSize: int
    activationFunction: ActivationFunction

type
  Agent* = ref object
    inputDims: int
    actionDims: int
    hParams: HyperParams
    optimParams: OptimizerParams
    networkParams: NetworkParams
    save_dir: string = "chkpts"
    load: bool = false

proc act*(model: Agent, state: Tensor): int =
  # pick a random action index by default (rand(n) is inclusive, hence the - 1)
  var actionIndex = rand(model.actionDims - 1)
  if rand(1.0) > model.hParams.explorationRate:
    echo "The agent has acted."
    # TODO: Implement actual model
    #[var
      actionValues = model.forward(state)
      actionIndex = argmax(actionValues).item()]#
    actionIndex = 1

  # decay epsilon, but never let it fall below the configured minimum
  model.hParams.explorationRate *= model.hParams.explorationRateDecay
  model.hParams.explorationRate = max(model.hParams.explorationRate, model.hParams.explorationRateMinimum)

  return actionIndex

proc cache*(model: Agent, state: Tensor, nextState: Tensor, action: int, reward: float, done: bool) =
  # TODO: Implement memory (either in Agent type or find another way)
  discard

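# Sketch (editor's addition, not part of this commit): one way the replay
# memory mentioned in the TODO above could be stored, using a bounded
# std/deques ring buffer. `Experience` and `remember` are assumptions made
# for illustration only; the commit itself does not define them.
import std / [deques]

type
  Experience = object
    state, nextState: Tensor[float]
    action: int
    reward: float
    done: bool

proc remember(memory: var Deque[Experience], exp: Experience, maxLen: int) =
  # drop the oldest transition once the buffer is full, then append the new one
  if memory.len >= maxLen:
    memory.popFirst()
  memory.addLast(exp)
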
var
  model: Agent = Agent(inputDims: 3, actionDims: 3)
  testor = [1.0, 2.0].toTensor()

var action = model.act(testor)
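# Quick check (editor's addition): with explorationRate starting at 1.0 the
# greedy branch in act() is not taken yet, so this prints a random index
# in 0 .. actionDims - 1.
echo "chosen action: ", action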