From 1d3ea2acd5c80b9bc02e592abe096607e45f6c46 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Wed, 27 Sep 2023 12:03:05 +0800
Subject: [PATCH 1/2] fix:chromadb version

---
 pilot/connections/conn_spark.py | 4 ++--
 setup.py                        | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pilot/connections/conn_spark.py b/pilot/connections/conn_spark.py
index cf7a1e963..d46be65f1 100644
--- a/pilot/connections/conn_spark.py
+++ b/pilot/connections/conn_spark.py
@@ -1,6 +1,4 @@
 from typing import Optional, Any
-from pyspark.sql import SparkSession, DataFrame
-from sqlalchemy import text
 
 from pilot.connections.base import BaseConnect
 
@@ -19,6 +17,7 @@ class SparkConnect(BaseConnect):
     driver: str = "spark"
     """db dialect"""
     dialect: str = "sparksql"
+    from pyspark.sql import SparkSession, DataFrame
 
     def __init__(
         self,
@@ -30,6 +29,7 @@ def __init__(
         """Initialize the Spark DataFrame from Datasource path
         return: Spark DataFrame
         """
+        from pyspark.sql import SparkSession
         self.spark_session = (
             spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
         )
diff --git a/setup.py b/setup.py
index ab9f33448..6aa97b61d 100644
--- a/setup.py
+++ b/setup.py
@@ -281,7 +281,7 @@ def core_requires():
         "importlib-resources==5.12.0",
         "psutil==5.9.4",
         "python-dotenv==1.0.0",
-        "colorama==0.4.10",
+        "colorama==0.4.6",
         "prettytable",
         "cachetools",
     ]
@@ -312,7 +312,7 @@ def knowledge_requires():
     setup_spec.extras["knowledge"] = [
         "spacy==3.5.3",
         # "chromadb==0.3.22",
-        "chromadb",
+        "chromadb==0.4.10",
         "markdown",
         "bs4",
         "python-pptx",

From 92c25fe39524813fb0602e4d081599e780218810 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Wed, 27 Sep 2023 12:40:30 +0800
Subject: [PATCH 2/2] fix:pyspark lazy load

---
 pilot/connections/conn_spark.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pilot/connections/conn_spark.py b/pilot/connections/conn_spark.py
index d46be65f1..d89ca9399 100644
--- a/pilot/connections/conn_spark.py
+++ b/pilot/connections/conn_spark.py
@@ -17,12 +17,11 @@ class SparkConnect(BaseConnect):
     driver: str = "spark"
     """db dialect"""
     dialect: str = "sparksql"
-    from pyspark.sql import SparkSession, DataFrame
 
     def __init__(
         self,
         file_path: str,
-        spark_session: Optional[SparkSession] = None,
+        spark_session: Optional = None,
         engine_args: Optional[dict] = None,
         **kwargs: Any,
     ) -> None:
@@ -30,6 +29,7 @@ def __init__(
         return: Spark DataFrame
         """
         from pyspark.sql import SparkSession
+
         self.spark_session = (
             spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
         )
@@ -47,7 +47,7 @@ def from_file_path(
         except Exception as e:
             print("load spark datasource error" + str(e))
 
-    def create_df(self, path) -> DataFrame:
+    def create_df(self, path):
         """Create a Spark DataFrame from Datasource path(now support parquet, jdbc, orc, libsvm, csv, text, json.).
         return: Spark DataFrame
         reference:https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
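
Note: the lazy-load change in PATCH 2/2 follows a common pattern for optional heavy dependencies: the module- and class-level pyspark imports (and the SparkSession / DataFrame annotations) are removed, and the import is deferred into __init__, so pilot.connections.conn_spark can be imported even when pyspark is not installed. Below is a minimal sketch of that pattern; the class name, the create_df body, and any attribute not visible in the diff are illustrative assumptions, not the project's exact code.

from typing import Any, Optional


class SparkConnectSketch:
    """Illustrative connector that defers the pyspark import to construction time."""

    driver: str = "spark"
    dialect: str = "sparksql"

    def __init__(
        self,
        file_path: str,
        spark_session: Optional[Any] = None,  # untyped so pyspark is not needed for annotations
        **kwargs: Any,
    ) -> None:
        # Deferred import: a missing pyspark fails here, at construction time,
        # not when the module itself is imported.
        from pyspark.sql import SparkSession

        self.spark_session = (
            spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
        )
        self.path = file_path

    def create_df(self, path: str):
        # No return annotation, so pyspark's DataFrame type is not required at import time.
        # CSV is used here purely as an example; Spark also reads parquet, json, orc, etc.
        return self.spark_session.read.option("header", "true").csv(path)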