Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(awel): New MessageConverter and more AWEL operators #1039

Merged
merged 6 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,8 @@ thirdparty

# typescript
*.tsbuildinfo
/web/next-env.d.ts
/web/next-env.d.ts

# Ignore awel DAG visualization files
/examples/**/*.gv
/examples/**/*.gv.pdf
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pre-commit: fmt test ## Run formatting and unit tests before committing
test: $(VENV)/.testenv ## Run unit tests
$(VENV_BIN)/pytest dbgpt

.PHONY: test-doc
test-doc: $(VENV)/.testenv ## Run doctests
# -k "not test_" skips tests that are not doctests.
$(VENV_BIN)/pytest --doctest-modules -k "not test_" dbgpt/core

.PHONY: coverage
coverage: setup ## Run tests and report coverage
$(VENV_BIN)/pytest dbgpt --cov=dbgpt
Expand Down
6 changes: 6 additions & 0 deletions dbgpt/app/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def __init__(self, chat_param: Dict):
is_stream=True, dag_name="llm_stream_model_dag"
)

# Get the message version, default is v1 in app
# In v1, we will transform the message to compatible format of specific model
# In the future, we will upgrade the message version to v2, and the message will be compatible with all models
self._message_version = chat_param.get("message_version", "v1")

class Config:
"""Configuration for this pydantic object."""

Expand Down Expand Up @@ -185,6 +190,7 @@ async def __call_base(self):
"temperature": float(self.prompt_template.temperature),
"max_new_tokens": int(self.prompt_template.max_new_tokens),
"echo": self.llm_echo,
"version": self._message_version,
}
return payload

Expand Down
24 changes: 24 additions & 0 deletions dbgpt/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,39 @@
CacheValue,
)
from dbgpt.core.interface.llm import (
DefaultMessageConverter,
LLMClient,
MessageConverter,
ModelExtraMedata,
ModelInferenceMetrics,
ModelMetadata,
ModelOutput,
ModelRequest,
ModelRequestContext,
)
from dbgpt.core.interface.message import (
AIMessage,
BaseMessage,
ConversationIdentifier,
HumanMessage,
MessageIdentifier,
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
OnceConversation,
StorageConversation,
SystemMessage,
)
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
from dbgpt.core.interface.prompt import (
BasePromptTemplate,
ChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
PromptManager,
PromptTemplate,
StoragePromptTemplate,
SystemPromptTemplate,
)
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.core.interface.storage import (
Expand All @@ -49,14 +61,26 @@
"ModelMessage",
"LLMClient",
"ModelMessageRoleType",
"ModelExtraMedata",
"MessageConverter",
"DefaultMessageConverter",
"OnceConversation",
"StorageConversation",
"BaseMessage",
"SystemMessage",
"AIMessage",
"HumanMessage",
"MessageStorageItem",
"ConversationIdentifier",
"MessageIdentifier",
"PromptTemplate",
"PromptManager",
"StoragePromptTemplate",
"BasePromptTemplate",
"ChatPromptTemplate",
"MessagesPlaceholder",
"SystemPromptTemplate",
"HumanPromptTemplate",
"BaseOutputParser",
"SQLOutputParser",
"Serializable",
Expand Down
24 changes: 23 additions & 1 deletion dbgpt/core/awel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

"""

import logging
from typing import List, Optional

from dbgpt.component import SystemApp
Expand Down Expand Up @@ -39,6 +40,8 @@
)
from .trigger.http_trigger import HttpTrigger

logger = logging.getLogger(__name__)

__all__ = [
"initialize_awel",
"DAGContext",
Expand Down Expand Up @@ -89,14 +92,24 @@ def initialize_awel(system_app: SystemApp, dag_dirs: List[str]):

def setup_dev_environment(
dags: List[DAG],
host: Optional[str] = "0.0.0.0",
host: Optional[str] = "127.0.0.1",
port: Optional[int] = 5555,
logging_level: Optional[str] = None,
logger_filename: Optional[str] = None,
show_dag_graph: Optional[bool] = True,
) -> None:
"""Setup a development environment for AWEL.

Just using in development environment, not production environment.

Args:
dags (List[DAG]): The DAGs.
host (Optional[str], optional): The host. Defaults to "127.0.0.1"
port (Optional[int], optional): The port. Defaults to 5555.
logging_level (Optional[str], optional): The logging level. Defaults to None.
logger_filename (Optional[str], optional): The logger filename. Defaults to None.
show_dag_graph (Optional[bool], optional): Whether show the DAG graph. Defaults to True.
If True, the DAG graph will be saved to a file and open it automatically.
"""
import uvicorn
from fastapi import FastAPI
Expand All @@ -118,6 +131,15 @@ def setup_dev_environment(
system_app.register_instance(trigger_manager)

for dag in dags:
if show_dag_graph:
try:
dag_graph_file = dag.visualize_dag()
if dag_graph_file:
logger.info(f"Visualize DAG {str(dag)} to {dag_graph_file}")
except Exception as e:
logger.warning(
f"Visualize DAG {str(dag)} failed: {e}, if your system has no graphviz, you can install it by `pip install graphviz` or `sudo apt install graphviz`"
)
for trigger in dag.trigger_nodes:
trigger_manager.register_trigger(trigger)
trigger_manager.after_register()
Expand Down
137 changes: 134 additions & 3 deletions dbgpt/core/awel/dag/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import Executor
from functools import cache
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union

from dbgpt.component import SystemApp

Expand Down Expand Up @@ -177,7 +176,10 @@ async def before_dag_run(self):
pass

async def after_dag_end(self):
"""The callback after DAG end"""
"""The callback after DAG end,

This method may be called multiple times, please make sure it is idempotent.
"""
pass


Expand Down Expand Up @@ -299,6 +301,20 @@ def set_dependency(self, nodes: DependencyType, is_upstream: bool = True) -> Non
self._downstream.append(node)
node._upstream.append(self)

def __repr__(self):
cls_name = self.__class__.__name__
if self.node_name and self.node_name:
return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
if self.node_id:
return f"{cls_name}(node_id={self.node_id})"
if self.node_name:
return f"{cls_name}(node_name={self.node_name})"
else:
return f"{cls_name}"

def __str__(self):
return self.__repr__()


def _build_task_key(task_name: str, key: str) -> str:
return f"{task_name}___$$$$$$___{key}"
Expand Down Expand Up @@ -496,6 +512,15 @@ async def _after_dag_end(self) -> None:
tasks.append(node.after_dag_end())
await asyncio.gather(*tasks)

def print_tree(self) -> None:
"""Print the DAG tree"""
_print_format_dag_tree(self)

def visualize_dag(self, view: bool = True, **kwargs) -> Optional[str]:
"""Create the DAG graph"""
self.print_tree()
return _visualize_dag(self, view=view, **kwargs)

def __enter__(self):
DAGVar.enter_dag(self)
return self
Expand All @@ -516,3 +541,109 @@ def _get_nodes(node: DAGNode, is_upstream: Optional[bool] = True) -> set[DAGNode
for node in stream_nodes:
nodes = nodes.union(_get_nodes(node, is_upstream))
return nodes


def _print_format_dag_tree(dag: DAG) -> None:
for node in dag.root_nodes:
_print_dag(node)


def _print_dag(
node: DAGNode,
level: int = 0,
prefix: str = "",
last: bool = True,
level_dict: Dict[str, Any] = None,
):
if level_dict is None:
level_dict = {}

connector = " -> " if level != 0 else ""
new_prefix = prefix
if last:
if level != 0:
new_prefix += " "
print(prefix + connector + str(node))
else:
if level != 0:
new_prefix += "| "
print(prefix + connector + str(node))

level_dict[level] = level_dict.get(level, 0) + 1
num_children = len(node.downstream)
for i, child in enumerate(node.downstream):
_print_dag(child, level + 1, new_prefix, i == num_children - 1, level_dict)


def _print_dag_tree(root_nodes: List[DAGNode], level_sep: str = " ") -> None:
def _print_node(node: DAGNode, level: int) -> None:
print(f"{level_sep * level}{node}")

_apply_root_node(root_nodes, _print_node)


def _apply_root_node(
root_nodes: List[DAGNode],
func: Callable[[DAGNode, int], None],
) -> None:
for dag_node in root_nodes:
_handle_dag_nodes(False, 0, dag_node, func)


def _handle_dag_nodes(
is_down_to_up: bool,
level: int,
dag_node: DAGNode,
func: Callable[[DAGNode, int], None],
):
if not dag_node:
return
func(dag_node, level)
stream_nodes = dag_node.upstream if is_down_to_up else dag_node.downstream
level += 1
for node in stream_nodes:
_handle_dag_nodes(is_down_to_up, level, node, func)


def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
"""Visualize the DAG

Args:
dag (DAG): The DAG to visualize
view (bool, optional): Whether view the DAG graph. Defaults to True.

Returns:
Optional[str]: The filename of the DAG graph
"""
try:
from graphviz import Digraph
except ImportError:
logger.warn("Can't import graphviz, skip visualize DAG")
return None

dot = Digraph(name=dag.dag_id)
# Record the added edges to avoid adding duplicate edges
added_edges = set()

def add_edges(node: DAGNode):
if node.downstream:
for downstream_node in node.downstream:
# Check if the edge has been added
if (str(node), str(downstream_node)) not in added_edges:
dot.edge(str(node), str(downstream_node))
added_edges.add((str(node), str(downstream_node)))
add_edges(downstream_node)

for root in dag.root_nodes:
add_edges(root)
filename = f"dag-vis-{dag.dag_id}.gv"
if "filename" in kwargs:
filename = kwargs["filename"]
del kwargs["filename"]

if not "directory" in kwargs:
from dbgpt.configs.model_config import LOGDIR

kwargs["directory"] = LOGDIR

return dot.render(filename, view=view, **kwargs)
Loading
Loading