Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase the size of simple maze example #229

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions simplemaze.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
0, -2680.8, 1558.7802154248689

30 changes: 15 additions & 15 deletions src/main/scala/symsim/examples/concrete/simplemaze/Maze.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package symsim
package examples.concrete.simplemaze

import cats.{Eq, Foldable, Monad}
import cats.{Eq, Foldable, Hash, Monad}
import cats.kernel.BoundedEnumerable

import org.scalacheck.{Arbitrary, Gen}

import symsim.concrete.Randomized

/**
Expand Down Expand Up @@ -51,19 +49,21 @@ object Maze
Agent[MazeState, MazeObservableState, MazeAction, MazeReward, Randomized],
Episodic:

// Maximum number of steps in an episode: isFinal reports a state as final
// once its time component (s._3) reaches this value.
val TimeHorizon: Int = 1200
// Maze dimensions: valid cells have x in 1..WIDTH and y in 1..HEIGHT.
val WIDTH : Int = 100
val HEIGHT: Int = 100

/** Tells whether state `s` ends the episode.
  *
  * An episode ends when the agent stands on the good final cell
  * `(WIDTH, HEIGHT)`, on the bad final cell `(WIDTH, HEIGHT - 1)` (the cell
  * that `mazeReward` penalizes with -1000), or when the time component
  * `s._3` has reached `TimeHorizon`.
  */
def isFinal (s: MazeState): Boolean =
  val position = (s._1, s._2)
  position == (WIDTH, HEIGHT) ||
    position == (WIDTH, HEIGHT - 1) ||
    s._3 >= TimeHorizon

// Maze is discrete, so the observation is simply the agent's (x, y)
// position: the first two components of the state. The third component
// (the one isFinal compares against TimeHorizon) is dropped.
def observe (s: MazeState): MazeObservableState = (s._1, s._2)

// We are not using the original reward function from AIAMA as it
// gives too unstable learning results.
//
// Reward for being in state `s`: 0 for the good final cell (WIDTH, HEIGHT),
// -1000 for the bad final cell (WIDTH, HEIGHT - 1), and -1 for every other
// cell. WIDTH and HEIGHT start with an upper-case letter, so the patterns
// below compare against their values instead of binding fresh variables.
private def mazeReward (s: MazeState): MazeReward = (s._1, s._2) match
  case (WIDTH, HEIGHT) => +0.0 // Good final state
  case (WIDTH, y) if y == HEIGHT - 1 => -1000.0 // Bad final state (dead)
  case (_, _) => -1.0


Expand All @@ -82,7 +82,7 @@ object Maze
if valid (result) then result else s

/** Tells whether the position of `s` is a cell the agent may occupy:
  * inside the WIDTH x HEIGHT grid (coordinates start at 1) and not the
  * blocked cell (2, 2). The time component of `s` is not inspected.
  */
def valid (s: MazeState): Boolean =
  val insideGrid = s._1 >= 1 && s._1 <= WIDTH && s._2 >= 1 && s._2 <= HEIGHT
  insideGrid && (s._1, s._2) != (2, 2)

// NOTE(review): presumably the probability that the agent's chosen action
// is carried out as intended; the code that reads it is not visible here —
// confirm.
val attention = 0.8

Expand All @@ -94,8 +94,8 @@ object Maze
yield (newState, mazeReward (newState))

// Draws a random initial state (x, y, 0); the guard keeps only states that
// are neither final (isFinal) nor invalid (valid).
// NOTE(review): x and y are each bound twice below. The later generators
// (bounded by WIDTH and HEIGHT) shadow the earlier ones bounded by 4 and 3,
// which look like leftovers from the 4x3 maze — confirm and remove them.
// NOTE(review): only the x generator is wrapped in Randomized.repeat;
// presumably that is what lets rejected draws be retried — confirm.
def initialize: Randomized[MazeState] = for
x <- Randomized.repeat(Randomized.between(1, 4))
y <- Randomized.between(1, 3)
x <- Randomized.repeat(Randomized.between(1, WIDTH))
y <- Randomized.between(1, HEIGHT)
t = 0
s = (x, y, t) if !isFinal(s) && valid (s)
yield s
Expand All @@ -116,8 +116,8 @@ object MazeInstances

given enumState: BoundedEnumerable[MazeObservableState] =
val ss = for
y <- List (1, 2, 3)
x <- List (1, 2, 3, 4)
y <- (1 to Maze.HEIGHT).toList
x <- (1 to Maze.WIDTH).toList
result = (x, y, 0)
if Maze.valid (result)
yield (result._1, result._2)
Expand All @@ -130,9 +130,9 @@ object MazeInstances
// CanTestIn instance for Randomized, reusing the instance that Randomized
// itself provides (Randomized.canTestInRandomized).
given canTestInScheduler: CanTestIn[Randomized] = Randomized.canTestInRandomized

/** ScalaCheck generator of maze states for property tests.
  *
  * Positions are drawn from the whole WIDTH x HEIGHT grid and filtered with
  * `Maze.valid`, so only the blocked cell (2, 2) is excluded. Filtering on
  * `x != 2 && y != 2` instead would reject every state in column 2 and in
  * row 2, leaving those cells untested.
  */
lazy val genMazeState: Gen[MazeState] = for
  y <- Gen.choose (1, Maze.HEIGHT)
  x <- Gen.choose (1, Maze.WIDTH)
  // NOTE(review): t starts at 1 although Maze.initialize starts episodes at
  // t = 0, so initial states are never generated — confirm this is intended.
  t <- Gen.choose (1, Maze.TimeHorizon)
  state = (x, y, t)
  if Maze.valid (state)
yield state

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,42 @@ class SarsaExperiments
alpha = 0.1,
gamma = 1,
epsilon0 = 0.1,
episodes = 60000,
episodes = 10000,
)

s"SimpleMaze experiment with ${sarsa}" in {

val policies = learnAndLog(sarsa)
.grouped (100)
.take (100)
.flatMap { _.headOption }
.toList

val policy = policies.head

withClue ("1,1") { policy (1, 1) should be (Up) }
withClue ("1,2") { policy (1, 2) should be (Up) }
withClue ("1,3") { policy (1, 3) should be (Right) }
withClue ("2,1") { policy (2, 1) should be (Left) }
withClue ("2,3") { policy (2, 3) should be (Right) }

// We leave 4,3 and 4,2 unconstrained (winning and losing
// final states)

// Which policy is optimal is a bit hard to
// establish, and depends on the constants in the reward
// function. We include several options to decrease flakiness of
// tests (mostly positions in column 3 are sensitive)

// this appears to be still a good move
withClue ("3,1") { policy (3,1) should be (Left) }
// Left is safest but AIAMA reports the dangerous Up
withClue ("3,2") { policy (3,2) should (be (Left) or be (Up)) }
// Up is safest but AIAMA reports the somewhat risky Right
withClue ("3,3") { policy (3,3) should (be (Up) or be (Right)) }
// Left is faster, down is safer
withClue ("4,1") { policy (4, 1) should (be (Down) or be (Left)) }

val results = eval (sarsa, policies)
results.save ("simplemaze.csv")
// .grouped (100)
// .take (100)
// .flatMap { _.headOption }
// .toList
//
// val policy = policies.head
//
// withClue ("1,1") { policy (1, 1) should be (Up) }
// withClue ("1,2") { policy (1, 2) should be (Up) }
// withClue ("1,3") { policy (1, 3) should be (Right) }
// withClue ("2,1") { policy (2, 1) should be (Left) }
// withClue ("2,3") { policy (2, 3) should be (Right) }
//
// // We leave 4,3 and 4,2 unconstrained (winning and losing
// // final states)
//
// // Which policy is optimal is a bit hard to
// // establish, and depends on the constants in the reward
// // function. We include several options to decrease flakiness of
// // tests (mostly positions in column 3 are sensitive)
//
// // this appears to be still a good move
// withClue ("3,1") { policy (3,1) should be (Left) }
// // Left is safest but AIAMA reports the dangerous Up
// withClue ("3,2") { policy (3,2) should (be (Left) or be (Up)) }
// // Up is safest but AIAMA reports the somewhat risky Right
// withClue ("3,3") { policy (3,3) should (be (Up) or be (Right)) }
// // Left is faster, down is safer
// withClue ("4,1") { policy (4, 1) should (be (Down) or be (Left)) }
//
// val results = eval (sarsa, policies)
// results.save ("simplemaze.csv")
}
Loading