Skip to content

Commit

Permalink
Add get and set participant data functions to code node
Browse files Browse the repository at this point in the history
  • Loading branch information
proteusvacuum committed Dec 30, 2024
1 parent 1f1508b commit 6dc10b7
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 3 deletions.
39 changes: 36 additions & 3 deletions apps/pipelines/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Literal

import tiktoken
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.core.validators import validate_email
from jinja2 import meta
Expand All @@ -23,7 +24,7 @@
from apps.channels.datamodels import Attachment
from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history
from apps.chat.models import ChatMessageType
from apps.experiments.models import ExperimentSession, ParticipantData
from apps.experiments.models import Experiment, ExperimentSession, ParticipantData
from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes
from apps.pipelines.nodes.base import NodeSchema, OptionsSource, PipelineNode, PipelineState, UiSchema, Widgets
Expand Down Expand Up @@ -630,6 +631,11 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen

DEFAULT_FUNCTION = """# You must define a main function, which takes the node input as a string.
# Return a string to pass to the next node.
# Available functions:
# - get_participant_data(key_name: str) -> str | None
# - set_participant_data(key_name: str, data: Any) -> None
def main(input: str, **kwargs) -> str:
return input
"""
Expand Down Expand Up @@ -686,15 +692,15 @@ def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineSt
)

custom_locals = {}
custom_globals = self._get_custom_globals()
custom_globals = self._get_custom_globals(state)
try:
exec(byte_code, custom_globals, custom_locals)
result = str(custom_locals[function_name](input))
except Exception as exc:
raise PipelineNodeRunError(exc) from exc
return PipelineState.from_node_output(node_id=node_id, output=result)

def _get_custom_globals(self):
def _get_custom_globals(self, state: PipelineState):
from RestrictedPython.Eval import (
default_guarded_getitem,
default_guarded_getiter,
Expand All @@ -710,10 +716,37 @@ def _get_custom_globals(self):
"_getitem_": default_guarded_getitem,
"_getiter_": default_guarded_getiter,
"_write_": lambda x: x,
"get_participant_data": self._get_participant_data(state),
"set_participant_data": self._set_participant_data(state),
}
)
return custom_globals

def _set_participant_data(self, state: PipelineState):
def set_particpant_data(key_name: str, value: str) -> None:
content_type = ContentType.objects.get_for_model(Experiment)
session = state["experiment_session"]
participant_data, _ = ParticipantData.objects.get_or_create(
participant=session.participant,
content_type=content_type,
object_id=session.experiment.id,
team=session.experiment.team,
)
participant_data.data[key_name] = value
participant_data.save()

return set_particpant_data

def _get_participant_data(self, state: PipelineState):
def get_particpant_data(key_name: str):
session = state["experiment_session"]
participant_data: ParticipantData = ParticipantData.objects.for_experiment(session.experiment).get(
participant=session.participant
)
return participant_data.data.get(key_name)

return get_particpant_data

def _get_custom_builtins(self):
allowed_modules = {
"json",
Expand Down
67 changes: 67 additions & 0 deletions apps/pipelines/tests/test_code_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from apps.experiments.models import ParticipantData
from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.nodes.base import PipelineState
from apps.pipelines.tests.utils import (
Expand All @@ -11,6 +12,7 @@
end_node,
start_node,
)
from apps.utils.factories.experiment import ExperimentSessionFactory
from apps.utils.factories.pipelines import PipelineFactory
from apps.utils.pytest import django_db_with_data

Expand All @@ -20,6 +22,11 @@ def pipeline():
return PipelineFactory()


@pytest.fixture()
def experiment_session():
return ExperimentSessionFactory()


IMPORTS = """
import json
import datetime
Expand Down Expand Up @@ -112,3 +119,63 @@ def test_code_node_runtime_errors(pipeline, code, input, error):
]
with pytest.raises(PipelineNodeRunError, match=error):
create_runnable(pipeline, nodes).invoke(PipelineState(messages=[input]))["messages"][-1]


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
def test_get_participant_data(pipeline, experiment_session):
ParticipantData.objects.create(
team=experiment_session.team,
content_object=experiment_session.experiment,
participant=experiment_session.participant,
data={"fun_facts": {"personality": "fun loving", "body_type": "robot"}},
)

code = """
def main(input, **kwargs):
return get_participant_data("fun_facts")["body_type"]
"""
nodes = [
start_node(),
code_node(code),
end_node(),
]
assert (
create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=[input]))[
"messages"
][-1]
== "robot"
)


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
def test_update_participant_data(pipeline, experiment_session):
output = "moody"
participant_data = ParticipantData.objects.create(
team=experiment_session.team,
content_object=experiment_session.experiment,
participant=experiment_session.participant,
data={"fun_facts": {"personality": "fun loving", "body_type": "robot"}},
)

code = f"""
def main(input, **kwargs):
facts = get_participant_data("fun_facts")
facts["personality"] = "{output}"
set_participant_data("fun_facts", facts)
return get_participant_data("fun_facts")["personality"]
"""
nodes = [
start_node(),
code_node(code),
end_node(),
]
assert (
create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=[input]))[
"messages"
][-1]
== output
)
participant_data.refresh_from_db()
assert participant_data.data["fun_facts"]["personality"] == output

0 comments on commit 6dc10b7

Please sign in to comment.