Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CodeNode #992

Merged
merged 5 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/pipelines/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def logger(self):

class Widgets(StrEnum):
expandable_text = "expandable_text"
code = "code"
toggle = "toggle"
select = "select"
float = "float"
Expand Down
121 changes: 120 additions & 1 deletion apps/pipelines/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import datetime
import inspect
import json
import time
from typing import Literal

import tiktoken
Expand All @@ -14,13 +17,14 @@
from pydantic.config import ConfigDict
from pydantic_core import PydanticCustomError
from pydantic_core.core_schema import FieldValidationInfo
from RestrictedPython import compile_restricted, safe_builtins, safe_globals

from apps.assistants.models import OpenAiAssistant
from apps.channels.datamodels import Attachment
from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history
from apps.chat.models import ChatMessageType
from apps.experiments.models import ExperimentSession, ParticipantData
from apps.pipelines.exceptions import PipelineNodeBuildError
from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes
from apps.pipelines.nodes.base import NodeSchema, OptionsSource, PipelineNode, PipelineState, UiSchema, Widgets
from apps.pipelines.tasks import send_email_from_pipeline
Expand Down Expand Up @@ -622,3 +626,118 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen
return AgentAssistantChat(adapter=adapter, history_manager=history_manager)
else:
return AssistantChat(adapter=adapter, history_manager=history_manager)


# Starter code shown in the code widget and substituted when a node is saved with empty code.
# The body of ``main`` must stay indented so that the default itself compiles; the schema
# fixture for CodeNode pins this exact text (4-space indent) as the field default.
DEFAULT_FUNCTION = """# You must define a main function, which takes the node input as a string.
# Return a string to pass to the next node.
def main(input: str, **kwargs) -> str:
    return input
"""


class CodeNode(PipelineNode):
    """Runs python"""

    # NOTE: the class docstring above is left untouched on purpose: the CodeNode schema
    # fixture pins "description": "Runs python", which appears to be derived from it.

    model_config = ConfigDict(json_schema_extra=NodeSchema(label="Python Node"))
    # User-supplied source; it must define ``main(input, **kwargs)`` (see ``validate_code``).
    code: str = Field(
        default=DEFAULT_FUNCTION,
        description="The code to run",
        json_schema_extra=UiSchema(widget=Widgets.code),
    )

    @field_validator("code")
    def validate_code(cls, value, info: FieldValidationInfo):
        """Check that the code compiles under RestrictedPython and defines a usable ``main``.

        An empty value is replaced with ``DEFAULT_FUNCTION``. The code must define exactly
        one top-level function, ``main``, whose parameters are ``input`` and ``**kwargs``.

        Returns:
            The (possibly defaulted) source string.

        Raises:
            PydanticCustomError: with type ``invalid_code`` when compilation fails or any of
                the structural rules above is violated.
        """
        if not value:
            value = DEFAULT_FUNCTION
        try:
            byte_code = compile_restricted(
                value,
                filename="<inline code>",
                mode="exec",
            )
            custom_locals = {}
            # NOTE(security): ``exec`` with a globals dict lacking ``__builtins__`` gets the
            # real builtins injected, so top-level statements of the submitted code run here
            # without the import guard used in ``_process``. Left as-is because disallowed
            # imports are currently expected to surface at run time, not at validation.
            # TODO: validate against the restricted globals instead.
            exec(byte_code, {}, custom_locals)

            try:
                main = custom_locals["main"]
            except KeyError:
                raise SyntaxError("You must define a 'main' function") from None

            # ``main = 1`` and similar would otherwise escape as a raw TypeError from
            # ``inspect.signature`` below.
            if not inspect.isfunction(main):
                raise SyntaxError("You must define a 'main' function")

            for name, item in custom_locals.items():
                if name != "main" and inspect.isfunction(item):
                    raise SyntaxError(
                        "You can only define a single function, 'main' at the top level. "
                        "You may use nested functions inside that function if required"
                    )

            parameters = inspect.signature(main).parameters
            # Names alone are not enough: ``def main(input, kwargs)`` has the right names but
            # is not the ``**kwargs`` signature that the error message promises.
            has_var_keyword = (
                "kwargs" in parameters and parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
            )
            if list(parameters) != ["input", "kwargs"] or not has_var_keyword:
                raise SyntaxError("The main function should have the signature main(input, **kwargs) only.")

        except SyntaxError as exc:
            raise PydanticCustomError("invalid_code", "{error}", {"error": exc.msg}) from exc
        return value

    def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
        """Compile the node's code, call its ``main`` with ``input`` and emit the result.

        The return value of ``main`` is converted with ``str`` before being stored as the
        node output.

        Raises:
            PipelineNodeRunError: wrapping any exception raised while executing the code.
        """
        function_name = "main"
        byte_code = compile_restricted(
            self.code,
            filename="<inline code>",
            mode="exec",
        )

        # Execute in ONE namespace. With separate globals/locals dicts, a top-level
        # ``import re`` is bound in the locals dict, but ``main`` resolves free names through
        # its globals, so the imported module would be invisible inside ``main``.
        namespace = self._get_custom_globals()
        try:
            exec(byte_code, namespace)
            result = str(namespace[function_name](input))
        except Exception as exc:
            raise PipelineNodeRunError(exc) from exc
        return PipelineState.from_node_output(node_id=node_id, output=result)

    def _get_custom_globals(self):
        """Return the globals dict that the user code is executed with.

        Starts from RestrictedPython's ``safe_globals`` and adds the custom builtins, the
        ``json``/``datetime``/``time`` modules and the guard hooks for item access,
        iteration and writes.
        """
        from RestrictedPython.Eval import (
            default_guarded_getitem,
            default_guarded_getiter,
        )

        custom_globals = safe_globals.copy()
        custom_globals.update(
            {
                "__builtins__": self._get_custom_builtins(),
                "json": json,
                "datetime": datetime,
                "time": time,
                "_getitem_": default_guarded_getitem,
                "_getiter_": default_guarded_getiter,
                "_write_": lambda x: x,
            }
        )
        return custom_globals

    def _get_custom_builtins(self):
        """Return the builtins available to user code, including a guarded ``__import__``.

        Extends RestrictedPython's ``safe_builtins`` with a handful of aggregate helpers and
        replaces ``__import__`` with a version that only admits the allow-listed modules.
        """
        allowed_modules = {
            "json",
            "re",
            "datetime",
            "time",
        }
        custom_builtins = safe_builtins.copy()
        # Reviewer note (not for this PR): It would be nice to have some way of telling/showing
        # the user which builtin modules and imports are available. Maybe we can utilize
        # codemirror's autocomplete
        custom_builtins.update(
            {
                "min": min,
                "max": max,
                "sum": sum,
                "abs": abs,
                "all": all,
                "any": any,
                "datetime": datetime,
            }
        )

        def guarded_import(name, *args, **kwargs):
            # Only exact allow-listed module names are importable from user code.
            if name not in allowed_modules:
                raise ImportError(f"Importing '{name}' is not allowed")
            return __import__(name, *args, **kwargs)

        custom_builtins["__import__"] = guarded_import
        return custom_builtins
16 changes: 16 additions & 0 deletions apps/pipelines/tests/data/CodeNode.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"description": "Runs python",
"properties": {
"code": {
"default": "# You must define a main function, which takes the node input as a string.\n# Return a string to pass to the next node.\ndef main(input: str, **kwargs) -> str:\n return input\n",
"description": "The code to run",
"title": "Code",
"type": "string",
"ui:widget": "code"
}
},
"title": "CodeNode",
"type": "object",
"ui:flow_node_type": "pipelineNode",
"ui:label": "Python Node"
}
114 changes: 114 additions & 0 deletions apps/pipelines/tests/test_code_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import json
from unittest import mock

import pytest

from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.nodes.base import PipelineState
from apps.pipelines.tests.utils import (
code_node,
create_runnable,
end_node,
start_node,
)
from apps.utils.factories.pipelines import PipelineFactory
from apps.utils.pytest import django_db_with_data


@pytest.fixture()
def pipeline():
    """Provide a pipeline built by ``PipelineFactory`` for each test."""
    factory_pipeline = PipelineFactory()
    return factory_pipeline


# Code sample that imports every allow-listed module and parses its input as JSON.
# The body of ``main`` must be indented, otherwise the sample does not compile.
IMPORTS = """
import json
import datetime
import re
import time
def main(input, **kwargs):
    return json.loads(input)
"""


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
    ("code", "input", "output"),
    [
        ("def main(input, **kwargs):\n\treturn f'Hello, {input}!'", "World", "Hello, World!"),
        ("", "foo", "foo"),  # No code just returns the input
        ("def main(input, **kwargs):\n\t'foo'", "", "None"),  # No return value will return "None"
        (IMPORTS, json.dumps({"a": "b"}), str(json.loads('{"a": "b"}'))),  # Importing json will work
    ],
)
def test_code_node(pipeline, code, input, output):
    """A start -> code -> end pipeline yields the code node's output as the last message."""
    runnable = create_runnable(pipeline, [start_node(), code_node(code), end_node()])
    final_state = runnable.invoke(PipelineState(messages=[input]))
    assert final_state["messages"][-1] == output


# Code sample with a second top-level function next to ``main``; used to trigger the
# "single function" build error. Both bodies must be indented so the sample compiles and
# reaches that check, not a generic syntax error.
EXTRA_FUNCTION = """
def other(foo):
    return f"other {foo}"

def main(input, **kwargs):
    return other(input)
"""


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
    ("code", "input", "error"),
    [
        ("this{}", "", "SyntaxError: invalid syntax at statement: 'this{}"),
        (
            EXTRA_FUNCTION,
            "",
            (
                "You can only define a single function, 'main' at the top level. "
                "You may use nested functions inside that function if required"
            ),
        ),
        ("def other(input):\n\treturn input", "", "You must define a 'main' function"),
        (
            "def main(input, others, **kwargs):\n\treturn input",
            "",
            r"The main function should have the signature main\(input, \*\*kwargs\) only\.",
        ),
    ],
)
def test_code_node_build_errors(pipeline, code, input, error):
    """Invalid code is rejected with a ``PipelineNodeBuildError`` matching ``error``."""
    node_definitions = [start_node(), code_node(code), end_node()]
    # Building and invoking both stay inside the context manager: either may raise.
    with pytest.raises(PipelineNodeBuildError, match=error):
        runnable = create_runnable(pipeline, node_definitions)
        runnable.invoke(PipelineState(messages=[input]))["messages"][-1]


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
    ("code", "input", "error"),
    [
        (
            "import collections\ndef main(input, **kwargs):\n\treturn input",
            "",
            "Importing 'collections' is not allowed",
        ),
        ("def main(input, **kwargs):\n\treturn f'Hello, {blah}!'", "", "name 'blah' is not defined"),
    ],
)
def test_code_node_runtime_errors(pipeline, code, input, error):
    """Code that fails while executing surfaces as a ``PipelineNodeRunError`` matching ``error``."""
    node_definitions = [start_node(), code_node(code), end_node()]
    # Building and invoking both stay inside the context manager, as in the original test.
    with pytest.raises(PipelineNodeRunError, match=error):
        runnable = create_runnable(pipeline, node_definitions)
        runnable.invoke(PipelineState(messages=[input]))["messages"][-1]
12 changes: 12 additions & 0 deletions apps/pipelines/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,15 @@ def extract_structured_data_node(provider_id: str, provider_model_id: str, data_
"data_schema": data_schema,
},
}


def code_node(code: str | None = None):
    """Build the definition dict of a ``CodeNode`` for use in test pipelines.

    Args:
        code: Source for the node. When ``None``, a small valid ``main`` function that
            greets its input is used.

    Returns:
        A dict with a fresh ``id``, the ``CodeNode`` type name and the ``code`` param.
    """
    if code is None:
        # The fallback has to be a complete ``main`` definition: a bare ``return``
        # statement is not valid on its own and defines no ``main`` for the node to call.
        code = "def main(input, **kwargs):\n\treturn f'Hello, {input}!'"
    return {
        "id": str(uuid4()),
        "type": nodes.CodeNode.__name__,
        "params": {
            "code": code,
        },
    }
Loading
Loading