Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(be+fe): custom DB Sourcecatalog #276

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions source/constructs/api/common/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ class MessageEnum(Enum):
SOURCE_JDBC_NO_CREDENTIAL = {1231: "No credential"}
SOURCE_JDBC_NO_AUTH = {1232: "No authorization"}
SOURCE_JDBC_DUPLICATE_AUTH = {1233: "Duplicate authorization"}
SOURCE_JDBC_ALREADY_EXISTS = {1234: "JDBC connection with the same instance already exists"}
SOURCE_GLUE_DATABASE_EXISTS = {1235: "Glue database with the same name already exists"}
SOURCE_GLUE_DATABASE_NO_INSTANCE = {1236: "Glue database does not exist"}
# label
LABEL_EXIST_FAILED = {1611: "Cannot create duplicated label"}

Expand Down
49 changes: 44 additions & 5 deletions source/constructs/api/data_source/crud.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
import datetime

from sqlalchemy import desc
from common.enum import ConnectionState, MessageEnum, Provider, SourceRegionStatus, SourceProviderStatus, SourceResourcesStatus, SourceAccountStatus
from common.enum import (ConnectionState,
MessageEnum,
Provider,
SourceRegionStatus,
SourceProviderStatus,
SourceResourcesStatus,
SourceAccountStatus)
from common.query_condition import QueryCondition, query_with_condition
from db.database import get_session
from db.models_data_source import S3BucketSource, Account, RdsInstanceSource, JDBCInstanceSource, SourceRegion, SourceProvider, SourceResource
from db.models_data_source import (S3BucketSource,
Account,
RdsInstanceSource,
JDBCInstanceSource,
SourceRegion,
SourceProvider,
SourceResource,
SourceGlueDatabase)
from common.exception_handler import BizException
from . import schemas

Expand Down Expand Up @@ -95,6 +108,11 @@ def list_rds_instance_source(condition: QueryCondition):
instances = query_with_condition(instances, condition)
return instances

def list_glue_database_by_name(name: str):
    """Return all SourceGlueDatabase rows whose ``glue_database_name`` equals ``name``.

    The result is the list produced by ``Query.all()``; it is empty when no
    row matches.
    """
    glue_query = get_session().query(SourceGlueDatabase)
    matching = glue_query.filter(SourceGlueDatabase.glue_database_name == name)
    return matching.all()

def list_jdbc_instance_source_by_instance_id(instance_id: str):
    """Return all JDBCInstanceSource rows whose ``instance_id`` equals ``instance_id``.

    The result is the list produced by ``Query.all()``; it is empty when no
    row matches.
    """
    jdbc_query = get_session().query(JDBCInstanceSource)
    matching = jdbc_query.filter(JDBCInstanceSource.instance_id == instance_id)
    return matching.all()

def list_jdbc_instance_source(provider_id: int):
# instances = Nonex
Expand Down Expand Up @@ -142,7 +160,7 @@ def get_jdbc_instance_source_glue_state(provider_id: int, account: str, region:
rds = get_session().query(JDBCInstanceSource).filter(JDBCInstanceSource.data_source_id == account_tmp[0],
JDBCInstanceSource.instance_id == instance_id,
JDBCInstanceSource.region == region,
JDBCInstanceSource.aws_account == account).order_by(
JDBCInstanceSource.account_id == account).order_by(
desc(JDBCInstanceSource.detection_history_id)).first()
if rds is not None:
return rds.glue_state
Expand Down Expand Up @@ -172,7 +190,10 @@ def get_rds_instance_source(account: str, region: str, instance_id: str):
RdsInstanceSource.region == region,
RdsInstanceSource.instance_id == instance_id).scalar()

def get_jdbc_instance_source(provider: str, account: str, region: str, instance_id: str):
def get_glue_database_source(provider: int, account: str, region: str, instance_id: str):
    """Placeholder for looking up a single Glue database source.

    Not implemented yet: the body is a bare ``pass``, so every call returns
    ``None`` and none of the arguments are used.
    """
    pass

def get_jdbc_instance_source(provider: int, account: str, region: str, instance_id: str):
return get_session().query(JDBCInstanceSource).filter(JDBCInstanceSource.account_provider_id == provider,
JDBCInstanceSource.account_id == account,
JDBCInstanceSource.region == region,
Expand Down Expand Up @@ -508,6 +529,24 @@ def get_source_rds_account_region():
.all()
)

def add_glue_database(glueDatabase: schemas.SourceGlueDatabase):
    """Insert a SourceGlueDatabase row copied from the request schema.

    Only the Glue-specific fields plus ``region`` and ``account_id`` are
    copied from ``glueDatabase``. The row is added, committed and refreshed
    on the session returned by ``get_session()`` and then returned.
    """
    # Attributes copied one-to-one from the schema onto the ORM row, in order.
    copied_fields = (
        "glue_database_name",
        "glue_database_location_uri",
        "glue_database_description",
        "glue_database_create_time",
        "glue_database_catalog_id",
        "region",
        "account_id",
    )

    session = get_session()

    record = SourceGlueDatabase()
    for field_name in copied_fields:
        setattr(record, field_name, getattr(glueDatabase, field_name))

    session.add(record)
    session.commit()
    # Reload the row after the commit before handing it back.
    session.refresh(record)

    return record

def add_jdbc_conn(jdbcConn: schemas.JDBCInstanceSource):
session = get_session()

Expand All @@ -528,7 +567,7 @@ def add_jdbc_conn(jdbcConn: schemas.JDBCInstanceSource):
jdbc_instance_source.jdbc_driver_jar_uri = jdbcConn.jdbc_driver_jar_uri
jdbc_instance_source.instance_class = jdbcConn.instance_class
jdbc_instance_source.instance_status = jdbcConn.instance_status
jdbc_instance_source.account_provider = jdbcConn.account_provider
jdbc_instance_source.account_provider_id = jdbcConn.account_provider
jdbc_instance_source.account_id = jdbcConn.account_id
jdbc_instance_source.region = jdbcConn.region
jdbc_instance_source.data_source_id = jdbcConn.data_source_id
Expand Down
125 changes: 125 additions & 0 deletions source/constructs/api/data_source/glue_database_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os

import boto3
import logging

from common.constant import const
from common.enum import ConnectionState, DatabaseType, Provider
from db.database import get_session
from db.models_data_source import DetectionHistory, RdsInstanceSource, Account
from . import crud
from . import service
from catalog.service import delete_catalog_by_database_region
from sqlalchemy.orm import Session
import asyncio

# STS client created once at import time; used below to assume the
# per-account IAM role before calling the RDS API.
sts_client = boto3.client('sts')
# Region name of the default boto3 session; passed to
# crud.update_rds_instance_count after each detection pass.
admin_account_region = boto3.session.Session().region_name
# getLogger() without a name returns the root logger, so the level set here
# applies to the root logger rather than to a module-specific one.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

def _iter_rds_instances(client):
    """Yield every DB instance description visible to ``client``, following pagination."""
    for page in client.get_paginator('describe_db_instances').paginate():
        yield from page['DBInstances']


async def detect_glue_database_connection(session: Session, aws_account_id: str):
    """Synchronise the stored RDS instance sources of one AWS account.

    NOTE(review): despite its name, this coroutine only talks to the RDS API
    and the ``RdsInstanceSource`` table; the history row type and the role
    session name are left untouched here - confirm the intended behaviour.

    Steps:
      1. Record a ``DetectionHistory`` row and assume the account's IAM role.
      2. For each agent region, list the DB instances whose engine is in
         ``const.RDS_SUPPORTED_ENGINES`` and merge an ``RdsInstanceSource``
         row for each, tagged with the new history id.
      3. Remove stored rows of that region whose instance is no longer
         reported, dropping their connection (if any) and their catalog.
      4. Refresh the account's RDS instance count.

    Args:
        session: SQLAlchemy session used for the queries and merges.
        aws_account_id: Account whose RDS instances are inspected.
    """
    iam_role_name = crud.get_iam_role(aws_account_id)
    history = DetectionHistory(aws_account=aws_account_id, source_type='jdbc', state=0)
    session.add(history)
    session.commit()
    assumed_role_object = sts_client.assume_role(
        RoleArn=f"{iam_role_name}",
        RoleSessionName="rds-instance-source-detection"
    )
    credentials = assumed_role_object['Credentials']
    regions = crud.get_account_agent_regions(aws_account_id)
    for region in regions:
        # :type : pyboto3.rds
        client = boto3.client(
            'rds',
            aws_access_key_id=credentials['AccessKeyId'],
            aws_secret_access_key=credentials['SecretAccessKey'],
            aws_session_token=credentials['SessionToken'],
            region_name=region
        )
        # Identifiers of the supported instances reported in this region.
        live_instance_ids = set()
        logger.info("detect_rds_data_source")
        for instance in _iter_rds_instances(client):
            logger.info(instance)
            if instance['Engine'] not in const.RDS_SUPPORTED_ENGINES:
                continue
            instance_id = str(instance['DBInstanceIdentifier'])
            live_instance_ids.add(instance_id)
            rds_instance_source = session.query(RdsInstanceSource).filter(
                RdsInstanceSource.instance_id == instance_id,
                RdsInstanceSource.region == region,
                RdsInstanceSource.aws_account == aws_account_id,
            ).scalar()
            if rds_instance_source is None:
                # The API may omit 'Endpoint' (e.g. while an instance is still
                # being created), so read it defensively instead of indexing.
                endpoint = instance.get('Endpoint') or {}
                rds_instance_source = RdsInstanceSource(
                    instance_id=instance_id,
                    region=region,
                    instance_class=instance['DBInstanceClass'],
                    engine=instance['Engine'],
                    instance_status=instance['DBInstanceStatus'],
                    address=endpoint.get('Address'),
                    port=endpoint.get('Port'),
                    master_username=instance['MasterUsername'],
                    aws_account=aws_account_id
                )
            rds_instance_source.detection_history_id = history.id
            session.merge(rds_instance_source)

        session.commit()

        # Stored rows of THIS region only: rows of other regions must not be
        # compared against this region's instance list.
        # NOTE(review): the ``account_id == 0`` condition is kept as found -
        # confirm it is intended.
        stored_rows = session.query(RdsInstanceSource).filter(
            RdsInstanceSource.aws_account == aws_account_id,
            RdsInstanceSource.account_id == 0,
            RdsInstanceSource.region == region,
        ).all()
        # instance id -> glue_state of rows whose instance is gone
        # (a glue_state of None means the instance was never connected).
        stale_glue_state = {}
        for row in stored_rows:
            if row.instance_id not in live_instance_ids:
                stale_glue_state.setdefault(row.instance_id, row.glue_state)

        for stale_instance_id, glue_state in stale_glue_state.items():
            if glue_state is not None:
                # The RDS instance is a connected data source.
                service.delete_rds_connection(aws_account_id, region, stale_instance_id)

            # Delete the data catalog too: the user may have connected the
            # instance, generated a catalog and disconnected it afterwards.
            try:
                delete_catalog_by_database_region(stale_instance_id, region, DatabaseType.RDS)
            except Exception as e:
                # Best effort: a failed catalog cleanup must not block row removal.
                logger.error(str(e))
            crud.delete_rds_instance_source_by_instance_id(aws_account_id, region, stale_instance_id)
        # TODO support multiple regions
        crud.update_rds_instance_count(account=aws_account_id, region=admin_account_region)


async def detect_multiple_account_in_async(accounts):
    """Run the detection coroutine for every account and wait for all of them.

    One session from ``get_session()`` is shared by all tasks. Each account
    gets its own asyncio task, and the coroutine returns once ``gather`` has
    collected every task.
    """
    shared_session = get_session()
    pending = [
        asyncio.create_task(detect_glue_database_connection(shared_session, account_id))
        for account_id in accounts
    ]
    await asyncio.gather(*pending)


def detect(accounts):
    """Synchronous entry point: run detection for ``accounts`` via ``asyncio.run``."""
    detection = detect_multiple_account_in_async(accounts)
    asyncio.run(detection)
47 changes: 44 additions & 3 deletions source/constructs/api/data_source/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def list_rds_instances(condition: QueryCondition):
page=condition.page,
))

# NOTE(review): the response model still names the JDBC schema although the
# endpoint lists Glue databases - confirm which schema is intended.
@router.post("/list-glue-database", response_model=BaseResponse[Page[schemas.JDBCInstanceSource]])
@inject_session
def list_glue_databases(condition: QueryCondition):
    """Return one page of the Glue databases selected by ``condition``.

    Returns ``None`` when the service yields ``None``; otherwise the result
    is paginated with the page size and page number carried by ``condition``.
    """
    glue_databases = service.list_glue_databases(condition)
    if glue_databases is not None:
        page_params = Params(
            size=condition.size,
            page=condition.page,
        )
        return paginate(glue_databases, page_params)
    return None

@router.post("/list-jdbc", response_model=BaseResponse[Page[schemas.JDBCInstanceSource]])
@inject_session
def list_jdbc_instances(condition: QueryCondition):
Expand Down Expand Up @@ -93,16 +104,37 @@ def sync_rds_connection(rds: schemas.SourceRdsConnection):
rds.rds_secret
)

@router.post("/delete-glue-database", response_model=BaseResponse)
@inject_session
def delete_glue_database(glueDatabase: schemas.SourceDeteteGlueDatabase):
    """Delete the Glue database identified by provider, account, region and name.

    ``account_provider`` is converted with ``int()`` before being handed to
    the service layer; the service result is returned unchanged.
    """
    provider_id = int(glueDatabase.account_provider)
    outcome = service.delete_glue_database(
        provider_id,
        glueDatabase.account_id,
        glueDatabase.region,
        glueDatabase.name
    )
    return outcome

# TODO
@router.post("/sync-glue-database", response_model=BaseResponse)
@inject_session
def sync_glue_database(jdbc: schemas.SourceGlueDatabase):
    """Ask the service layer to sync one Glue database.

    The database is identified by its account id, region and Glue database
    name. ``SourceGlueDatabase`` declares no ``instance`` field, so the name
    is read from ``glue_database_name``.
    NOTE(review): confirm that ``service.sync_glue_database`` expects the
    database name as its third argument.
    """
    return service.sync_glue_database(
        jdbc.account_id,
        jdbc.region,
        jdbc.glue_database_name
    )

@router.post("/delete-jdbc", response_model=BaseResponse)
@inject_session
def delete_jdbc_connection(jdbc: schemas.SourceDeteteJDBCConnection):
    """Delete the JDBC connection identified by provider, account, region and instance.

    ``account_provider`` is converted with ``int()`` before being handed to
    the service layer; the service result is returned unchanged.
    """
    provider_id = int(jdbc.account_provider)
    outcome = service.delete_jdbc_connection(
        provider_id,
        jdbc.account_id,
        jdbc.region,
        jdbc.instance
    )
    return outcome


# TODO
@router.post("/sync-jdbc", response_model=BaseResponse)
@inject_session
def sync_jdbc_connection(jdbc: schemas.SourceJDBCConnection):
Expand Down Expand Up @@ -133,8 +165,8 @@ def refresh_data_source(type: schemas.NewDataSource):

@router.get("/coverage", response_model=BaseResponse[schemas.SourceCoverage])
@inject_session
def get_data_source_coverage(provider_id:int):
return service.get_data_source_coverage()
def get_data_source_coverage(provider_id: int):
return service.get_data_source_coverage(provider_id)


@router.post("/list-account", response_model=BaseResponse[Page[schemas.Account]])
Expand Down Expand Up @@ -178,6 +210,11 @@ def get_secrets(account: str, region: str):
def get_admin_account_info():
return service.get_admin_account_info()

@router.post("/add-glue-database", response_model=BaseResponse)
@inject_session
def add_glue_database(glueDataBase: schemas.SourceGlueDatabase):
    """Register a Glue database as a data source; delegates to the service layer."""
    outcome = service.add_glue_database(glueDataBase)
    return outcome

@router.post("/add-jdbc-conn", response_model=BaseResponse)
@inject_session
def add_jdbc_conn(jdbcConn: schemas.JDBCInstanceSource):
Expand All @@ -188,6 +225,10 @@ def add_jdbc_conn(jdbcConn: schemas.JDBCInstanceSource):
def query_glue_connections(account: schemas.AdminAccountInfo):
return service.query_glue_connections(account)

@router.post("/query-glue-databases", response_model=BaseResponse)
@inject_session
def query_glue_databases(account: schemas.AdminAccountInfo):
    """Query the Glue databases for ``account``; delegates to the service layer."""
    outcome = service.query_glue_databases(account)
    return outcome

@router.post("/query-account-network", response_model=BaseResponse)
@inject_session
Expand Down
20 changes: 20 additions & 0 deletions source/constructs/api/data_source/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ class DynamodbTableSource(BaseModel):
class Config:
orm_mode = True

class SourceGlueDatabase(BaseModel):
    """Schema describing a Glue database registered as a data source.

    ``id`` is the only field not declared ``Optional``.
    NOTE(review): unlike the neighbouring source schemas, no ``Config`` with
    ``orm_mode`` is declared here - confirm whether this model is ever built
    from ORM rows.
    """

    id: int
    # Glue-side attributes of the database.
    glue_database_name: Optional[str]
    glue_database_description: Optional[str]
    glue_database_location_uri: Optional[str]
    glue_database_create_time: Optional[str]
    glue_database_catalog_id: Optional[str]
    # Account and region the database is registered under.
    account_id: Optional[str]
    region: Optional[str]
    # Row bookkeeping (version and create/modify audit fields).
    version: Optional[int]
    create_by: Optional[str]
    create_time: Optional[datetime.datetime]
    modify_by: Optional[str]
    modify_time: Optional[datetime.datetime]

class RdsInstanceSource(BaseModel):
id: int
Expand Down Expand Up @@ -198,6 +213,11 @@ class SourceJDBCConnection(BaseModel):
password: Optional[str]
secret: Optional[str]

class SourceDeteteGlueDatabase(BaseModel):
    """Request body of the ``/delete-glue-database`` endpoint.

    All four fields are required; together they identify the Glue database
    to delete.
    """
    account_provider: int
    account_id: str
    region: str
    # Name of the Glue database to delete.
    name: str

class SourceDeteteJDBCConnection(BaseModel):
account_provider: int
Expand Down
Loading
Loading