diff --git a/stix2/datastore/relational_db/relational_db.py b/stix2/datastore/relational_db/relational_db.py index aa1a391c..0fa674f5 100644 --- a/stix2/datastore/relational_db/relational_db.py +++ b/stix2/datastore/relational_db/relational_db.py @@ -56,7 +56,7 @@ def _add(store, stix_data, allow_custom=True, version="2.1"): class RelationalDBStore(DataStoreMixin): def __init__( self, database_connection_url, allow_custom=True, version=None, - create_db=True, instantiate_database=True, *stix_object_classes, + instantiate_database=True, force_recreate=False, *stix_object_classes, ): """ Initialize this store. @@ -68,6 +68,8 @@ def __init__( version: TODO: unused so far instantiate_database: Whether tables, etc should be created in the database (only necessary the first time) + force_recreate: Drops old database and creates new one (useful if + the schema has changed and the tables need to be updated) *stix_object_classes: STIX object classes to map into table schemas (and ultimately database tables, if instantiation is desired). This can be used to limit which table schemas are created, if @@ -76,38 +78,32 @@ def __init__( them. """ database_connection = create_engine(database_connection_url) - self.database_exists = database_exists(database_connection.url) - if create_db: - if self.database_exists: - drop_database(database_connection_url) - create_database(database_connection_url) - self.database_exists = database_exists(database_connection.url) - if self.database_exists: - self.metadata = MetaData() - create_table_objects( - self.metadata, stix_object_classes, - ) + self.metadata = MetaData() + create_table_objects( + self.metadata, stix_object_classes, + ) - super().__init__( - source=RelationalDBSource( - database_connection, - metadata=self.metadata, - ), - sink=RelationalDBSink( - database_connection, - allow_custom=allow_custom, - version=version, - instantiate_database=instantiate_database, - metadata=self.metadata, - ), - ) + super().__init__( + source=RelationalDBSource( + database_connection, + metadata=self.metadata, + ), + sink=RelationalDBSink( + database_connection, + allow_custom=allow_custom, + version=version, + instantiate_database=instantiate_database, + force_recreate=force_recreate, + metadata=self.metadata, + ), + ) class RelationalDBSink(DataSink): def __init__( self, database_connection_or_url, allow_custom=True, version=None, - instantiate_database=True, *stix_object_classes, metadata=None, + instantiate_database=True, force_recreate=False, *stix_object_classes, metadata=None, ): """ Initialize this sink. Only one of stix_object_classes and metadata @@ -119,8 +115,10 @@ def __init__( allow_custom: Whether custom content is allowed when processing dict content to be added to the sink version: TODO: unused so far - instantiate_database: Whether tables, etc should be created in the - database (only necessary the first time) + instantiate_database: Whether the database, tables, etc should be + created (only necessary the first time) + force_recreate: Drops old database and creates new one (useful if + the schema has changed and the tables need to be updated) *stix_object_classes: STIX object classes to map into table schemas (and ultimately database tables, if instantiation is desired). This can be used to limit which table schemas are created, if @@ -140,6 +138,10 @@ def __init__( else: self.database_connection = database_connection_or_url + self.database_exists = database_exists(self.database_connection.url) + if force_recreate: + self._create_database() + if metadata: self.metadata = metadata else: @@ -156,6 +158,8 @@ def __init__( self.tables_dictionary[canonicalize_table_name(t.name, t.schema)] = t if instantiate_database: + if not self.database_exists: + self._create_database() self._create_schemas() self._instantiate_database() @@ -169,6 +173,12 @@ def _create_schemas(self): def _instantiate_database(self): self.metadata.create_all(self.database_connection) + def _create_database(self): + if self.database_exists: + drop_database(self.database_connection.url) + create_database(self.database_connection.url) + self.database_exists = database_exists(self.database_connection.url) + def generate_stix_schema(self): for t in self.metadata.tables.values(): print(CreateTable(t).compile(self.database_connection)) diff --git a/stix2/datastore/relational_db/relational_db_testing.py b/stix2/datastore/relational_db/relational_db_testing.py index e948b929..853d5d9f 100644 --- a/stix2/datastore/relational_db/relational_db_testing.py +++ b/stix2/datastore/relational_db/relational_db_testing.py @@ -101,10 +101,11 @@ def main(): False, None, True, - True, + False, stix2.Directory, ) - if store.database_exists: + + if store.sink.database_exists: store.sink.generate_stix_schema() store.sink.clear_tables()