Skip to content

Commit

Permalink
Vector data visualization for chat data (#2172)
Browse files Browse the repository at this point in the history
Signed-off-by: shanhaikang.shk <[email protected]>
  • Loading branch information
GITHUBear authored Dec 17, 2024
1 parent 433550b commit ed96b95
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 2 deletions.
5 changes: 4 additions & 1 deletion dbgpt/app/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def _generate_numbered_list(self) -> str:
},
# {"response_data_text":" the default display method, suitable for single-line or simple content display"},
{
"response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
"response_scatter_chart": "Suitable for exploring relationships between variables, detecting outliers, etc."
},
{
"response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
Expand All @@ -527,6 +527,9 @@ def _generate_numbered_list(self) -> str:
{
"response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
},
{
"response_vector_chart": "Suitable for projecting high-dimensional vector data onto a two-dimensional plot through the PCA algorithm."
},
]

return "\n".join(
Expand Down
53 changes: 53 additions & 0 deletions dbgpt/app/scene/chat_db/auto_execute/out_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import xml.etree.ElementTree as ET
from typing import Dict, NamedTuple

import numpy as np
import pandas as pd
import sqlparse

from dbgpt._private.config import Config
Expand Down Expand Up @@ -68,6 +70,52 @@ def parse_prompt_response(self, model_out_text):
logger.error(f"json load failed:{clean_str}")
return SqlAction("", clean_str, "", "")

def parse_vector_data_with_pca(self, df):
    """Project a vector column of ``df`` onto 2-D via PCA for plotting.

    Scans the columns for the first one whose cells hold vectors —
    either Python lists or JSON-encoded bytes (e.g. an OceanBase
    VECTOR column) — and replaces it with ``__x``/``__y`` columns
    holding a 2-component PCA projection, suitable for a scatter chart.

    Args:
        df: Result DataFrame from the executed SQL. It is not mutated.

    Returns:
        A ``(dataframe, visualizable)`` tuple. ``visualizable`` is True
        and the returned frame carries ``__x``/``__y`` when a usable
        vector column was found and projected; otherwise the original
        ``df`` is returned unchanged together with False.

    Raises:
        ImportError: A projection is required but scikit-learn is not
            installed. (Frames without vector data never need it.)
    """

    def _as_vector(value):
        # Accept either an in-memory list or JSON-encoded bytes;
        # anything else (or malformed JSON) is not a vector.
        if isinstance(value, list):
            return value
        if isinstance(value, (bytes, bytearray)):
            try:
                decoded = json.loads(value.decode())
            except (ValueError, UnicodeDecodeError):
                return None
            if isinstance(decoded, list):
                return decoded
        return None

    nrow, ncol = df.shape
    if nrow == 0 or ncol == 0:
        return df, False

    # Locate the first column whose first cell parses as a vector.
    # ``iat`` is positional, so a non-default index cannot break this
    # (the previous label-based ``[0]`` lookup could raise KeyError).
    vec_col = -1
    first_vec = None
    for i_col in range(ncol):
        first_vec = _as_vector(df.iat[0, i_col])
        if first_vec is not None:
            vec_col = i_col
            break
    if vec_col == -1:
        return df, False

    # PCA with 2 components needs at least 2 samples and 2 dimensions.
    if min(nrow, len(first_vec)) < 2:
        return df, False

    # Every row must decode to a vector of the same dimension,
    # otherwise the data cannot form a rectangular matrix.
    vectors = [_as_vector(v) for v in df.iloc[:, vec_col]]
    if any(v is None or len(v) != len(first_vec) for v in vectors):
        return df, False

    # Imported lazily so results without vector data never require it.
    try:
        from sklearn.decomposition import PCA
    except ImportError as exc:
        raise ImportError(
            "Could not import scikit-learn package. "
            "Please install it with `pip install scikit-learn`."
        ) from exc

    X_pca = PCA(n_components=2).fit_transform(np.array(vectors))

    # Copy every non-vector column, then append the 2-D projection.
    new_df = pd.DataFrame()
    for i_col in range(ncol):
        if i_col == vec_col:
            continue
        new_df[df.columns[i_col]] = df.iloc[:, i_col]
    new_df["__x"] = [pos[0] for pos in X_pca]
    new_df["__y"] = [pos[1] for pos in X_pca]
    return new_df, True

def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
api_call_element = ET.Element("chart-view")
Expand All @@ -83,6 +131,11 @@ def parse_view_response(self, speak, data, prompt_response) -> str:
if prompt_response.sql:
df = data(prompt_response.sql)
param["type"] = prompt_response.display

if param["type"] == "response_vector_chart":
df, visualizable = self.parse_vector_data_with_pca(df)
param["type"] = "response_scatter_chart" if visualizable else "response_table"

param["sql"] = prompt_response.sql
param["data"] = json.loads(
df.to_json(orient="records", date_format="iso", date_unit="s")
Expand Down
112 changes: 112 additions & 0 deletions dbgpt/datasource/rdbms/dialect/oceanbase/ob_dialect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,111 @@
"""OB Dialect support."""

import re

from sqlalchemy import util
from sqlalchemy.dialects import registry
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser, _re_compile


class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
    """OceanBase table definition parser.

    Specializes the MySQL ``SHOW CREATE TABLE`` parser so OceanBase
    DDL reflects correctly: ``SPATIAL``/``VECTOR`` keys, the
    ``BLOCK_SIZE ... LOCAL`` index option, and schema-qualified
    foreign-key targets.
    """

    def __init__(self, dialect, preparer, *, default_schema=None):
        """Initialize OceanBaseTableDefinitionParser.

        Args:
            dialect: Owning SQLAlchemy dialect instance.
            preparer: Identifier preparer supplying the quote characters
                used to build the reflection regexes.
            default_schema: Current default schema; foreign keys that
                reference ``default_schema.table`` are rewritten to a
                bare ``table`` during parsing.
        """
        MySQLTableDefinitionParser.__init__(self, dialect, preparer)
        self.default_schema = default_schema

    def _prep_regexes(self):
        """Build reflection regexes, then override the key/FK patterns."""
        super()._prep_regexes()

        # Rebuild the quoting fragments exactly as the base class does.
        _final = self.preparer.final_quote
        quotes = dict(
            zip(
                ("iq", "fq", "esc_fq"),
                [
                    re.escape(s)
                    for s in (
                        self.preparer.initial_quote,
                        _final,
                        self.preparer._escape_identifier(_final),
                    )
                ],
            )
        )

        # Same shape as the upstream MySQL pattern, except SPATIAL and
        # VECTOR keys match without capturing <type> (so they reflect as
        # plain indexes) and OceanBase's "BLOCK_SIZE <n> [LOCAL]" index
        # option is accepted.
        self._re_key = _re_compile(
            r" "
            r"(?:(SPATIAL|VECTOR|(?P<type>\S+)) )?KEY"
            r"(?: +{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq})?"
            r"(?: +USING +(?P<using_pre>\S+))?"
            r" +\((?P<columns>.+?)\)"
            r"(?: +USING +(?P<using_post>\S+))?"
            r"(?: +(KEY_)?BLOCK_SIZE *[ =]? *(?P<keyblock>\S+) *(LOCAL)?)?"
            r"(?: +WITH PARSER +(?P<parser>\S+))?"
            r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
            r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
            r",?$".format(iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"])
        )

        # FK pattern additionally allows a schema-qualified target table
        # ("schema"."table"), which OceanBase emits.
        kw = quotes.copy()
        kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
        self._re_fk_constraint = _re_compile(
            r" "
            r"CONSTRAINT +"
            r"{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq} +"
            r"FOREIGN KEY +"
            r"\((?P<local>[^\)]+?)\) REFERENCES +"
            r"(?P<table>{iq}[^{fq}]+{fq}"
            r"(?:\.{iq}[^{fq}]+{fq})?) *"
            r"\((?P<foreign>(?:{iq}[^{fq}]+{fq}(?: *, *)?)+)\)"
            r"(?: +(?P<match>MATCH \w+))?"
            r"(?: +ON UPDATE (?P<onupdate>{on}))?"
            r"(?: +ON DELETE (?P<ondelete>{on}))?".format(
                iq=quotes["iq"], esc_fq=quotes["esc_fq"], fq=quotes["fq"], on=kw["on"]
            )
        )

    def _parse_constraints(self, line):
        """Parse a CONSTRAINT line.

        Post-processes the base-class result: strips the default schema
        from foreign-key targets and normalizes an explicit RESTRICT
        referential action to ``None`` to match MySQL reflection output.
        """
        ret = super()._parse_constraints(line)
        if ret:
            tp, spec = ret
            if tp == "partition":
                # Partition definitions are passed through untouched.
                return ret
            if (
                tp == "fk_constraint"
                and len(spec["table"]) == 2
                and spec["table"][0] == self.default_schema
            ):
                # Drop the schema qualifier when it is the default schema.
                spec["table"] = spec["table"][1:]
            # Regex groups are None when the clause is absent; guard with
            # ``or ""`` before lowering (``spec.get(..., "")`` returned the
            # stored None and crashed with AttributeError).
            if (spec.get("onupdate") or "").lower() == "restrict":
                spec["onupdate"] = None
            if (spec.get("ondelete") or "").lower() == "restrict":
                spec["ondelete"] = None
        return ret


class OBDialect(pymysql.MySQLDialect_pymysql):
"""OBDialect expend."""

supports_statement_cache = True

def __init__(self, **kwargs):
    """Initialize OBDialect.

    Loads the ``pyobvector`` VECTOR type and registers it in
    ``ischema_names`` so reflected VECTOR columns resolve to it.

    Raises:
        ImportError: If the ``pyobvector`` package is not installed.
    """
    try:
        from pyobvector import VECTOR  # type: ignore
    except ImportError as exc:
        # Chain the original error so the failing import is visible.
        raise ImportError(
            "Could not import pyobvector package. "
            "Please install it with `pip install pyobvector`."
        ) from exc
    super().__init__(**kwargs)
    self.ischema_names["VECTOR"] = VECTOR

def initialize(self, connection):
"""Ob dialect initialize."""
super(OBDialect, self).initialize(connection)
Expand All @@ -22,5 +121,18 @@ def get_isolation_level(self, dbapi_connection):
self.server_version_info = (5, 7, 19)
return super(OBDialect, self).get_isolation_level(dbapi_connection)

@util.memoized_property
def _tabledef_parser(self):
    """Return the OceanBase table definition parser, built on first use.

    Construction is deferred (memoized property) so that the dialect
    has already retrieved server version information when the parser
    is created.
    """
    return OceanBaseTableDefinitionParser(
        self,
        self.identifier_preparer,
        default_schema=self.default_schema_name,
    )


registry.register("mysql.ob", __name__, "OBDialect")
2 changes: 1 addition & 1 deletion dbgpt/vis/tags/vis_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def default_chart_type_prompt() -> str:
"non-numeric columns"
},
{
"response_scatter_plot": "Suitable for exploring relationships between "
"response_scatter_chart": "Suitable for exploring relationships between "
"variables, detecting outliers, etc."
},
{
Expand Down

0 comments on commit ed96b95

Please sign in to comment.