Skip to content

Commit

Permalink
refactor story progress
Browse files Browse the repository at this point in the history
  • Loading branch information
neph1 committed Oct 31, 2024
1 parent 2d415ad commit 1c54c45
Show file tree
Hide file tree
Showing 21 changed files with 165 additions and 134 deletions.
15 changes: 1 addition & 14 deletions tale/cmds/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,20 +1368,7 @@ def do_flee(player: Player, parsed: base.ParseResult, ctx: util.Context) -> None
pass
raise ActionRefused("You can't flee anywhere!")


@cmd("save")
@disable_notify_action
@disabled_in_gamemode(GameMode.MUD)
def do_save(player: Player, parsed: base.ParseResult, ctx: util.Context) -> Generator:
"""Save your game."""
if not ctx.driver.do_check_savefile_free(player):
if not (yield "input", ("Are you sure you want to overwrite the previous save game?", lang.yesno)):
player.tell("Ok, not saved.")
return
ctx.driver.do_save(player)


@cmd("load", "reload", "restore", "restart")
@cmd("load", "reload", "restore")
@disable_notify_action
@disabled_in_gamemode(GameMode.MUD)
def do_load(player: Player, parsed: base.ParseResult, ctx: util.Context) -> None:
Expand Down
6 changes: 6 additions & 0 deletions tale/cmds/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,3 +904,9 @@ def do_create_item(player: Player, parsed: base.ParseResult, ctx: util.Context)
player.tell(item.name + ' added.', evoke=False)
else:
raise ParseError("Item could not be added")

@wizcmd("restart_story")
def do_restart(player: Player, parsed: base.ParseResult, ctx: util.Context) -> None:
"""Restart the game."""
player.tell("Restarting the game... Please reconnect")
os.execv(sys.executable, ['python3'] + sys.argv)
7 changes: 2 additions & 5 deletions tale/day_cycle/day_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import enum
from typing import List, Protocol
from tale import mud_context
from tale.driver import Driver
from tale.util import GameDateTime, call_periodically

class DayCycleEventObserver(Protocol):
Expand All @@ -20,12 +18,11 @@ class TimeOfDay(str, enum.Enum):
NIGHT = 'Night'
class DayCycle:

def __init__(self, driver: Driver):
def __init__(self, game_clock: GameDateTime):

self.game_date_time = driver.game_clock
self.game_date_time = game_clock
self.current_hour = self.game_date_time.clock.hour
self.observers: List[DayCycleEventObserver] = []
driver.register_periodicals(self)
self._time_of_day = TimeOfDay.DAY

def register_observer(self, observer: DayCycleEventObserver):
Expand Down
13 changes: 6 additions & 7 deletions tale/day_cycle/llm_day_cycle_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@



from typing import List
from tale.day_cycle.day_cycle import DayCycleEventObserver
from tale.driver import Driver
from tale.player import PlayerConnection


class LlmDayCycleListener(DayCycleEventObserver):

def __init__(self, driver):
self.driver = driver # type: Driver
def __init__(self, llm_util, players):
self.llm_util = llm_util
self.players = players # type: List[PlayerConnection]

def on_dawn(self):
self._describe_transition("night", "dawn")
Expand All @@ -26,4 +25,4 @@ def on_midnight(self):
pass

def _describe_transition(self, from_time: str, to_time: str):
self.driver.llm_util.describe_day_cycle_transition(list(self.driver.all_players.values())[0], from_time, to_time)
self.llm_util.describe_day_cycle_transition(self.players, from_time, to_time)
5 changes: 1 addition & 4 deletions tale/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from . import __version__ as tale_version_str, _check_required_libraries
from . import mud_context, errors, util, cmds, player, pubsub, charbuilder, lang, verbdefs, vfs, base
from .story import StoryContext, TickMethod, GameMode, MoneyType, StoryBase
from .story import TickMethod, GameMode, MoneyType, StoryBase
from .tio import DEFAULT_SCREEN_WIDTH
from .races import playable_races
from .errors import StoryCompleted
Expand Down Expand Up @@ -535,9 +535,6 @@ def _server_tick(self) -> None:
events, idle_time, subbers = topicinfo[topicname]
if events == 0 and not subbers and idle_time > 30:
pubsub.topic(topicname).destroy()
progress = self.story.increase_progress(0.0001)
if progress:
self.llm_util.advance_story_section(self.story)

def disconnect_idling(self, conn: player.PlayerConnection) -> None:
raise NotImplementedError
Expand Down
14 changes: 2 additions & 12 deletions tale/json_story.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from tale import load_items
from tale.day_cycle.day_cycle import DayCycle
from tale.day_cycle.llm_day_cycle_listener import LlmDayCycleListener
from tale.items import generic
from tale.llm.dynamic_story import DynamicStory
from tale.player import Player
from tale.random_event import RandomEvent
from tale.story import GameMode, StoryConfig
from tale.story import StoryConfig
import tale.parse_utils as parse_utils
import tale.llm.llm_cache as llm_cache

Expand All @@ -18,7 +15,7 @@ def __init__(self, path: str, config: StoryConfig):


def init(self, driver) -> None:
self.driver = driver
super(JsonStory, self).init(driver)
locs = {}
zones = []
world = parse_utils.load_json(self.path +'world.json')
Expand Down Expand Up @@ -56,13 +53,6 @@ def init(self, driver) -> None:
for item in extra_items:
self._catalogue.add_item(item)

if self.config.day_night:
self.day_cycle = DayCycle(driver)
if self.config.server_mode == GameMode.IF:
self.day_cycle.register_observer(LlmDayCycleListener(self.driver))

if self.config.random_events:
self.random_events = RandomEvent(driver)

def welcome(self, player: Player) -> str:
player.tell("<bright>Welcome to `%s'.</>" % self.config.name, end=True)
Expand Down
2 changes: 1 addition & 1 deletion tale/llm/contexts/BaseContext.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Union

from tale.story import StoryContext
from tale.story_context import StoryContext


class BaseContext(ABC):
Expand Down
23 changes: 20 additions & 3 deletions tale/llm/dynamic_story.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,43 @@
from tale import parse_utils
from tale.base import Item, Living, Location
from tale.coord import Coord
from tale.day_cycle.day_cycle import DayCycle
from tale.day_cycle.llm_day_cycle_listener import LlmDayCycleListener
from tale.item_spawner import ItemSpawner
from tale.llm.LivingNpc import LivingNpc
from tale.quest import Quest, QuestType
from tale.mob_spawner import MobSpawner
from tale.story import StoryBase, StoryContext
from tale.random_event import RandomEvent
from tale.story import GameMode, StoryBase

from tale.story_context import StoryContext
from tale.zone import Zone
import tale.llm.llm_cache as llm_cache

class DynamicStory(StoryBase):



def __init__(self) -> None:
self._zones = dict() # type: dict[str, Zone]
self._world = WorldInfo()
self._catalogue = Catalogue()
if isinstance(self.config.context, str):
self.config.context = StoryContext(self.config.context)

def init(self, driver) -> None:
if self.config.day_night:
self.day_cycle = DayCycle(driver.game_clock)
driver.register_periodicals(self.day_cycle)
if self.config.server_mode == GameMode.IF:
self.day_cycle.register_observer(LlmDayCycleListener(driver.llm_util, driver.all_players.values()))

if self.config.random_events:
self.random_events = RandomEvent(driver.llm_util, driver.all_players.values())

if isinstance(self.config.context, StoryContext):
driver.register_periodicals(self.config.context)

def get_zone(self, name: str) -> Zone:
""" Find a zone by name."""
return self._zones[name]
Expand Down Expand Up @@ -115,8 +134,6 @@ def save(self, save_name: str = '') -> None:
with open(os.path.join(save_path, 'world.json'), "w") as fp:
json.dump(story , fp, indent=4)

if self.driver:
self.config.epoch = self.driver.game_clock.clock.timestamp()
with open(os.path.join(save_path, 'story_config.json'), "w") as fp:
json.dump(parse_utils.save_story_config(self.config), fp, indent=4)

Expand Down
5 changes: 3 additions & 2 deletions tale/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import tale.parse_utils as parse_utils
import tale.llm.llm_cache as llm_cache
from tale.quest import Quest
from tale.story_context import StoryContext
from tale.web.web_utils import copy_single_image
from tale.zone import Zone

Expand Down Expand Up @@ -342,9 +343,9 @@ def set_story(self, story: DynamicStory):
if story.config.image_gen:
self._init_image_gen(story.config.image_gen)

def advance_story_section(self, story: DynamicStory) -> str:
def advance_story_section(self, story_context: StoryContext) -> str:
""" Increase the story progress"""
return self._story_building.advance_story_section(story or self.__story)
return self._story_building.advance_story_section(story_context or self.__story.config.context)

def _init_image_gen(self, image_gen: str):
""" Initialize the image generator"""
Expand Down
9 changes: 5 additions & 4 deletions tale/llm/story_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tale.llm.contexts.AdvanceStoryContext import AdvanceStoryContext
from tale.llm.dynamic_story import DynamicStory
from tale.llm.llm_io import IoUtil
from tale.story_context import StoryContext


class StoryBuilding():
Expand All @@ -23,11 +24,11 @@ def generate_story_background(self, world_mood: int, world_info: str, story_type
request_body = self.default_body
return self.io_util.synchronous_request(request_body, prompt=prompt)

def advance_story_section(self, story: DynamicStory) -> str:
story_context = AdvanceStoryContext(story.config.context)
prompt = self.advance_story_prompt.format(context=story_context.to_prompt_string())
def advance_story_section(self, story_context: StoryContext) -> str:
context = AdvanceStoryContext(story_context)
prompt = self.advance_story_prompt.format(context=context.to_prompt_string())
request_body = self.default_body
result = self.io_util.synchronous_request(request_body, prompt=prompt)
story.config.context.set_current_section(result)
story_context.set_current_section(result)
return result

3 changes: 2 additions & 1 deletion tale/parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from tale.npc_defs import StationaryMob, StationaryNpc, Trader
from tale.races import BodyType, UnarmedAttack
from tale.mob_spawner import MobSpawner
from tale.story import GameMode, MoneyType, StoryContext, TickMethod, StoryConfig
from tale.story import GameMode, MoneyType, TickMethod, StoryConfig
from tale.story_context import StoryContext
from tale.skills.weapon_type import WeaponType
import json
import re
Expand Down
15 changes: 4 additions & 11 deletions tale/random_event.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@


import random
from tale import _MudContext
from tale.driver import Driver
from tale.llm.dynamic_story import DynamicStory
from tale.player import PlayerConnection
from tale.util import call_periodically


class RandomEvent:

def __init__(self, driver: Driver):
self.driver = driver
self.driver.register_periodicals(self)
def __init__(self, llm_util, player: PlayerConnection):
self.llm_util = llm_util
self.player = player

@call_periodically(300, 600)
def _random_event(self):
self.narrative_event()

def narrative_event(self):
self.player = list(self.driver.all_players.values())[0] # type: PlayerConnection
self.driver.llm_util.generate_narrative_event(self.player.player.location)
self.llm_util.generate_narrative_event(self.player.player.location)
56 changes: 1 addition & 55 deletions tale/story.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import datetime
import enum
import math
import random
from typing import Optional, Any, List, Set, Generator, Union
from packaging.version import Version

Expand Down Expand Up @@ -74,7 +72,7 @@ def __init__(self) -> None:
self.server_mode = GameMode.IF # the actual game mode the server is operating in (will be set at startup time)
self.items = "" # items to populate the world with. only used by json loading
self.npcs = "" # npcs to populate the world with. only used by json loading
self.context = "" # type: Union[str, StoryContext] # context to giving background for the story.
self.context = "" # context to giving background for the story.
self.type = "" # brief description of the setting and type of story, for LLM context
self.world_info = "" # brief description of the world, for LLM context
self.world_mood = 0 # how safe is the world? 5 is a happy place, -5 is nightmare mode.
Expand Down Expand Up @@ -150,11 +148,6 @@ def story_failure(self, player) -> None:
player.tell("You have failed to complete the story.")
player.tell("\n")

def increase_progress(self, amount: float = 1.0) -> bool:
if isinstance(self.config.context, StoryContext):
return self.config.context.increase_progress(amount)
return False

def _verify(self, driver) -> None:
"""verify correctness and compatibility of the story configuration"""
if not isinstance(self.config, StoryConfig):
Expand All @@ -171,50 +164,3 @@ def _verify(self, driver) -> None:
tale_version_required = Version(self.config.requires_tale)
if tale_version < tale_version_required:
raise StoryConfigError("This game requires tale " + self.config.requires_tale + ", but " + tale_version_str + " is installed.")


class StoryContext:

def __init__(self, base_story: str = "") -> None:
self.base_story = base_story
self.current_section = ""
self.past_sections = []
self.progress = 0.0
self.length = 10.0
self.speed = 1.0

def increase_progress(self, amount: float = 1.0) -> bool:
""" increase the progress by the given amount, return True if the progress has changed past the integer value """
start_progess = math.floor(self.progress)
self.progress += random.random() * amount * self.speed
if self.progress >= self.length:
self.progress = self.length
return start_progess != math.floor(self.progress)

def set_current_section(self, section: str) -> None:
if self.current_section:
self.past_sections.append(self.current_section)
self.current_section = section

def to_context(self) -> str:
return f"<story> Base plot: {self.base_story}; Active section: {self.current_section}</story>"

def to_context_with_past(self) -> str:
return f"<story> Base plot: {self.base_story}; Past: {' '.join(self.past_sections) if self.past_sections else 'This is the beginning of the story'}; Active section:{self.current_section}; Progress: {self.progress}/{self.length};</story>"

def from_json(self, data: dict) -> 'StoryContext':
self.base_story = data.get("base_story", "")
self.current_section = data.get("current_section", "")
self.past_sections = data.get("past_sections", [])
self.progress = data.get("progress", 0.0)
self.length = data.get("length", 10.0)
self.speed = data.get("speed", 1.0)
return self

def to_json(self) -> dict:
return {"base_story": self.base_story,
"current_section": self.current_section,
"past_sections": self.past_sections,
"progress": self.progress,
"length": self.length,
"speed": self.speed}
2 changes: 1 addition & 1 deletion tale/story_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tale.llm.llm_utils import LlmUtil
from tale.player import PlayerConnection
from tale.quest import Quest, QuestType
from tale.story import StoryContext
from tale.story_context import StoryContext
from tale.zone import Zone


Expand Down
Loading

0 comments on commit 1c54c45

Please sign in to comment.