Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add postgresql support #626

Merged
merged 7 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pilot/connections/manages/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig
Expand Down
197 changes: 197 additions & 0 deletions pilot/connections/rdbms/conn_postgresql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from typing import Iterable, Optional, Any
from sqlalchemy import text
from urllib.parse import quote
from pilot.connections.rdbms.base import RDBMSDatabase


class PostgreSQLDatabase(RDBMSDatabase):
driver = 'postgresql+psycopg2'
db_type = "postgresql"
db_dialect = 'postgresql'

@classmethod
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.driver
+ "://"
+ quote(user)
+ ":"
+ quote(pwd)
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
)
return cls.from_uri(db_url, engine_args, **kwargs)

def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute(
text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
)
view_results = self.session.execute(
text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
)
table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results)
self._all_tables = table_results.union(view_results)
self._metadata.reflect(bind=self._engine)
return self._all_tables


def get_grants(self):
session = self._db_sessions()
cursor = session.execute(text(f"""
SELECT DISTINCT grantee, privilege_type
FROM information_schema.role_table_grants
WHERE grantee = CURRENT_USER;"""))
grants = cursor.fetchall()
return grants

def get_collation(self):
"""Get collation."""
session = self._db_sessions()
cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"))
collation = cursor.fetchone()[0]
fangyinc marked this conversation as resolved.
Show resolved Hide resolved
return collation

def get_users(self):
"""Get user info."""
try:
cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';"))
users = cursor.fetchall()
return [user[0] for user in users]
except Exception as e:
fangyinc marked this conversation as resolved.
Show resolved Hide resolved
return []

def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \
FROM information_schema.columns WHERE table_name = :table_name",
),
{"table_name": table_name},
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

def get_charset(self):
"""Get character_set."""
session = self._db_sessions()
cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"))
character_set = cursor.fetchone()[0]
return character_set


def get_show_create_table(self,table_name):
cur = self.session.execute(
text(
f"""
SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
FROM pg_catalog.pg_attribute a
WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
SELECT max(a.attnum)
FROM pg_catalog.pg_attribute a
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
"""
)
)
rows = cur.fetchall()

create_table_query = f"CREATE TABLE {table_name} (\n"
for row in rows:
create_table_query += f" {row[0]} {row[1]},\n"
create_table_query = create_table_query.rstrip(',\n') + "\n)"

return create_table_query

def get_table_comments(self, db_name=None):
tablses = self.table_simple_info()
comments = []
for table in tablses:
table_name = table[0]
table_comment = self.get_show_create_table(table_name)
comments.append((table_name, table_comment))
return comments

def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["template0", "template1", "postgres"]
]

def get_database_names(self):
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["template0", "template1", "postgres"]
]

def get_current_db_name(self) -> str:
return self.session.execute(text("SELECT current_database()")).scalar()

def table_simple_info(self):
_sql = f"""
SELECT table_name, string_agg(column_name, ', ') AS schema_info
FROM (
SELECT c.relname AS table_name, a.attname AS column_name
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
WHERE c.relkind = 'r'
AND a.attnum > 0
AND NOT a.attisdropped
AND n.nspname NOT LIKE 'pg_%'
AND n.nspname != 'information_schema'
ORDER BY c.relname, a.attnum
) sub
GROUP BY table_name;
"""
cursor = self.session.execute(text(_sql))
results = cursor.fetchall()
return results

def get_fields(self, table_name, schema_name='public'):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
FROM information_schema.columns c
LEFT JOIN pg_catalog.pg_description d
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
"""
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]


def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"))
indexes = cursor.fetchall()
return [(index[0], index[1]) for index in indexes]
6 changes: 3 additions & 3 deletions pilot/scene/chat_dashboard/out_parser.py
fangyinc marked this conversation as resolved.
Show resolved Hide resolved
fangyinc marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def __init__(self, sep: str, is_stream_out: bool):
super().__init__(sep=sep, is_stream_out=is_stream_out)

def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
# clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", model_out_text)
response = json.loads(model_out_text)
chart_items: List[ChartItem] = []
if not isinstance(response, list):
response = [response]
Expand Down
8 changes: 6 additions & 2 deletions pilot/scene/chat_dashboard/prompt.py
fangyinc marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@

Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens

Give the correct {dialect} analysis SQL (don't use unprovided values such as 'paid'), analysis title(don't exist the same), display method and summary of brief analysis thinking, and respond in the following json format:
Give the correct {dialect} analysis SQL
1.Do not use unprovided values such as 'paid'
2.All queried values must have aliases, such as select count(*) as count from table
3.If the table structure definition uses the keywords of {dialect} as field names, you need to use escape characters, such as select `count` from table
4.Carefully check the correctness of the SQL, the SQL must be correct, display method and summary of brief analysis thinking, and respond in the following json format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
The important thing is: Please make sure to only return the json string, do not add any other content (for direct processing by the program), and the json can be parsed by Python json.loads
"""

RESPONSE_FORMAT = [
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def all_datasource_requires():
"""
pip install "db-gpt[datasource]"
"""
setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
setup_spec.extras["datasource"] = ["pymssql", "pymysql","psycopg2"]


def openai_requires():
Expand Down