Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #160 from facebookresearch/heiner/addhow
Browse files Browse the repository at this point in the history
Heiner/addhow
  • Loading branch information
Heinrich Kuttler authored Jun 4, 2021
2 parents ed6388c + c5f077c commit cf877c2
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 53 deletions.
3 changes: 2 additions & 1 deletion include/nleobs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#define NLE_MESSAGE_SIZE 256
#define NLE_BLSTATS_SIZE 25
#define NLE_PROGRAM_STATE_SIZE 6
#define NLE_INTERNAL_SIZE 8
#define NLE_INTERNAL_SIZE 9
#define NLE_INVENTORY_SIZE 55
#define NLE_INVENTORY_STR_LENGTH 80
#define NLE_SCREEN_DESCRIPTION_LENGTH 80
Expand All @@ -16,6 +16,7 @@ typedef struct nle_observation {
int action;
int done;
char in_normal_game; /* Bool indicating if other obs are set. */
int how_done; /* If game is really_done, how it ended. */
short *glyphs; /* Size ROWNO * (COLNO - 1) */
unsigned char *chars; /* Size ROWNO * (COLNO - 1) */
unsigned char *colors; /* Size ROWNO * (COLNO - 1) */
Expand Down
20 changes: 11 additions & 9 deletions nle/env/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,16 +382,18 @@ def step(self, action: int):
done = True

info = {}
if end_status:
# TODO: fix stats
# stats = self._collect_stats(last_observation, end_status)
# stats = stats._asdict()
stats = {}
info["stats"] = stats

if self._stats_logger is not None:
self._stats_logger.writerow(stats)
# TODO: fix stats
# if end_status:
# # stats = self._collect_stats(last_observation, end_status)
# # stats = stats._asdict()
# # stats = {}
# # info["stats"] = stats
#
# # if self._stats_logger is not None:
# # self._stats_logger.writerow(stats)

info["end_status"] = end_status
info["is_ascended"] = self.env.how_done() == nethack.ASCENDED

return self._get_observation(observation), reward, done, info

Expand Down
3 changes: 3 additions & 0 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,6 @@ def get_current_seeds(self):

# Forwards to the wrapped _pynethack object's in_normal_game() and returns
# its result unchanged.
def in_normal_game(self):
return self._pynethack.in_normal_game()

# Forwards to the wrapped _pynethack object's how_done() and returns its
# result unchanged (callers compare it against e.g. nethack.ASCENDED).
def how_done(self):
return self._pynethack.how_done()
5 changes: 1 addition & 4 deletions nle/tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def rollout_env(env, max_rollout_len):
assert isinstance(done, bool)
assert isinstance(info, dict)
if done:
assert not info["is_ascended"]
break
env.close()
return reward
Expand Down Expand Up @@ -74,10 +75,6 @@ def compare_rollouts(env0, env1, max_rollout_len):
assert reward0 == reward1
assert done0 == done1

if done0:
assert "stats" in info0 # just to be sure
assert "stats" in info1

assert info0 == info1

if done0 or step >= max_rollout_len:
Expand Down
2 changes: 2 additions & 0 deletions nle/tests/test_nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def test_run_n_episodes(self, tmpdir, game, episodes=3):
ch = random.choice(ACTIONS)
_, done = game.step(ch)
if done:
# This will typically be DIED, but could be POISONING, etc.
assert int(game.how_done()) < int(nethack.GENOCIDED)
break

steps += 1
Expand Down
4 changes: 4 additions & 0 deletions nle/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# Copyright (c) Facebook, Inc. and its affiliates.

# Requires
# pip install pytest-benchmark
# to run

import pytest

import numpy as np
Expand Down
4 changes: 4 additions & 0 deletions src/end.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#endif
#include "dlb.h"

extern void FDECL(nle_done, (int));

/* add b to long a, convert wraparound to max value */
#define nowrap_add(a, b) (a = ((a + b) < 0 ? LONG_MAX : (a + b)))

Expand Down Expand Up @@ -1473,6 +1475,8 @@ int how;
/* don't bother counting to see whether it should be plural */
}

nle_done(how);

Sprintf(pbuf, "%s %s the %s...", Goodbye(), plname,
(how != ASCENDED)
? (const char *) ((flags.female && urole.name.f)
Expand Down
14 changes: 12 additions & 2 deletions src/nle.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ vt_char_color_extract(TMTCHAR *c)
case (TMT_COLOR_WHITE):
color = (c->a.bold) ? CLR_WHITE : CLR_GRAY; // c = 15:7
break;
case (TMT_COLOR_MAX):
break;
}

if (c->a.reverse) {
Expand Down Expand Up @@ -220,8 +222,8 @@ nle_fflush(FILE *stream)
/* Only act on fflush(stdout). */
if (stream != stdout) {
fprintf(stderr,
"Warning: nle_flush called with unexpected FILE pointer %d ",
(int) stream);
"Warning: nle_flush called with unexpected FILE pointer %p ",
stream);
return fflush(stream);
}
nle_ctx_t *nle = current_nle_ctx;
Expand Down Expand Up @@ -328,6 +330,14 @@ nethack_exit(int status)
nle_yield(NULL);
}

/* Stores the game-ending reason "how" in the how_done field of the
 * current NLE context's observation.
 * Called in really_done() in end.c to get "how". */
void
nle_done(int how)
{
    current_nle_ctx->observation->how_done = how;
}

nle_seeds_init_t *nle_seeds_init;

/* See rng.c. */
Expand Down
102 changes: 66 additions & 36 deletions win/rl/pynethack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ class Nethack
return obs_.in_normal_game;
}

/* Returns the value stored in obs_.how_done, converted to the
 * game_end_types enum. */
game_end_types
how_done()
{
    const auto end_type = static_cast<game_end_types>(obs_.how_done);
    return end_type;
}

private:
void
reset(FILE *ttyrec)
Expand Down Expand Up @@ -268,7 +274,8 @@ PYBIND11_MODULE(_pynethack, m)
.def("set_initial_seeds", &Nethack::set_initial_seeds)
.def("set_seeds", &Nethack::set_seeds)
.def("get_seeds", &Nethack::get_seeds)
.def("in_normal_game", &Nethack::in_normal_game);
.def("in_normal_game", &Nethack::in_normal_game)
.def("how_done", &Nethack::how_done);

py::module mn = m.def_submodule(
"nethack", "Collection of NetHack constants and functions");
Expand Down Expand Up @@ -354,6 +361,27 @@ PYBIND11_MODULE(_pynethack, m)
// From monsym.h.
mn.attr("MAXMCLASSES") = py::int_(static_cast<int>(MAXMCLASSES));

// game_end_types from hack.h (used in end.c)
py::enum_<game_end_types>(mn, "game_end_types",
"This is the way the game ends.")
.value("DIED", DIED)
.value("CHOKING", CHOKING)
.value("POISONING", POISONING)
.value("STARVING", STARVING)
.value("DROWNING", DROWNING)
.value("BURNING", BURNING)
.value("DISSOLVED", DISSOLVED)
.value("CRUSHING", CRUSHING)
.value("STONING", STONING)
.value("TURNED_SLIME", TURNED_SLIME)
.value("GENOCIDED", GENOCIDED)
.value("PANICKED", PANICKED)
.value("TRICKED", TRICKED)
.value("QUIT", QUIT)
.value("ESCAPED", ESCAPED)
.value("ASCENDED", ASCENDED)
.export_values();

// "Special" mapglyph
mn.attr("MG_CORPSE") = py::int_(MG_CORPSE);
mn.attr("MG_INVIS") = py::int_(MG_INVIS);
Expand Down Expand Up @@ -392,19 +420,20 @@ PYBIND11_MODULE(_pynethack, m)
[](int glyph) { return glyph_is_warning(glyph); });

py::class_<permonst>(mn, "permonst", "The permonst struct.")
.def("__init__",
// See https://github.com/pybind/pybind11/issues/2394
[](py::detail::value_and_holder &v_h, int index) {
if (index < 0 || index >= NUMMONS)
throw std::out_of_range(
"Index should be between 0 and NUMMONS ("
+ std::to_string(NUMMONS) + ") but got "
+ std::to_string(index));
v_h.value_ptr() = &mons[index];
v_h.inst->owned = false;
v_h.set_holder_constructed(true);
},
py::detail::is_new_style_constructor())
.def(
"__init__",
// See https://github.com/pybind/pybind11/issues/2394
[](py::detail::value_and_holder &v_h, int index) {
if (index < 0 || index >= NUMMONS)
throw std::out_of_range(
"Index should be between 0 and NUMMONS ("
+ std::to_string(NUMMONS) + ") but got "
+ std::to_string(index));
v_h.value_ptr() = &mons[index];
v_h.inst->owned = false;
v_h.set_holder_constructed(true);
},
py::detail::is_new_style_constructor())
.def_readonly("mname", &permonst::mname) /* full name */
.def_readonly("mlet", &permonst::mlet) /* symbol */
.def_readonly("mlevel", &permonst::mlevel) /* base monster level */
Expand Down Expand Up @@ -468,28 +497,29 @@ PYBIND11_MODULE(_pynethack, m)
mn, "objclass",
"The objclass struct.\n\n"
"All fields are constant and don't reflect user changes.")
.def("__init__",
// See https://github.com/pybind/pybind11/issues/2394
[](py::detail::value_and_holder &v_h, int i) {
if (i < 0 || i >= NUM_OBJECTS)
throw std::out_of_range(
"Index should be between 0 and NUM_OBJECTS ("
+ std::to_string(NUM_OBJECTS) + ") but got "
+ std::to_string(i));

/* Initialize. Cannot depend on o_init.c as it pulls
* in all kinds of other code. Instead, do what
* makedefs.c does and set it here.
* Alternative: Get the pointer from the game itself?
* Dangerous!
*/
objects[i].oc_name_idx = objects[i].oc_descr_idx = i;

v_h.value_ptr() = &objects[i];
v_h.inst->owned = false;
v_h.set_holder_constructed(true);
},
py::detail::is_new_style_constructor())
.def(
"__init__",
// See https://github.com/pybind/pybind11/issues/2394
[](py::detail::value_and_holder &v_h, int i) {
if (i < 0 || i >= NUM_OBJECTS)
throw std::out_of_range(
"Index should be between 0 and NUM_OBJECTS ("
+ std::to_string(NUM_OBJECTS) + ") but got "
+ std::to_string(i));

/* Initialize. Cannot depend on o_init.c as it pulls
* in all kinds of other code. Instead, do what
* makedefs.c does and set it here.
* Alternative: Get the pointer from the game itself?
* Dangerous!
*/
objects[i].oc_name_idx = objects[i].oc_descr_idx = i;

v_h.value_ptr() = &objects[i];
v_h.inst->owned = false;
v_h.set_holder_constructed(true);
},
py::detail::is_new_style_constructor())
.def_readonly("oc_name_idx",
&objclass::oc_name_idx) /* index of actual name */
.def_readonly(
Expand Down
5 changes: 4 additions & 1 deletion win/rl/winrl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ NetHackRL::fill_obs(nle_obs *obs)
obs->internal[5] = nle_seeds[0]; /* core */
obs->internal[6] = nle_seeds[1]; /* disp */
obs->internal[7] = u.uhunger;
obs->internal[8] =
u.urexp; /* score (careful! check botl_score() and end.c) */
}

if ((!program_state.something_worth_saving && !program_state.in_moveloop)
Expand Down Expand Up @@ -412,7 +414,8 @@ NetHackRL::fill_obs(nle_obs *obs)
}
}
if (obs->screen_descriptions) {
memcpy(obs->screen_descriptions, &screen_descriptions_, screen_descriptions_.size());
memcpy(obs->screen_descriptions, &screen_descriptions_,
screen_descriptions_.size());
}
}

Expand Down

0 comments on commit cf877c2

Please sign in to comment.