Skip to content

Commit

Permalink
Fix database issue
Browse files Browse the repository at this point in the history
  • Loading branch information
BattlefieldDuck committed Nov 7, 2023
1 parent 3069402 commit e27a86f
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 82 deletions.
13 changes: 5 additions & 8 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request

from discordgsm.database import Database
from discordgsm.environment import env, environment
from discordgsm.main import tree
from discordgsm.service import gamedig, invite_link, public, whitelist_guilds
from discordgsm.service import database, gamedig, invite_link, public, whitelist_guilds
from discordgsm.translator import Locale, translations
from discordgsm.version import __version__

Expand Down Expand Up @@ -44,7 +43,7 @@ async def info():
return jsonify({
'version': __version__,
'invite_link': invite_link,
'statistics': await Database().statistics(),
'statistics': await database.statistics(),
})

@app.route('/api/v1/commands')
Expand Down Expand Up @@ -75,13 +74,13 @@ def guilds():
async def servers(game_id: str = None):
if game_id is None:
servers_count = {game_id: 0 for game_id in gamedig.games}
servers_count.update(await Database().count_servers_per_game())
servers_count.update(await database.count_servers_per_game())
return jsonify(servers_count)

if game_id not in gamedig.games:
return jsonify({'error': 'Invalid game id'})

servers = await Database().all_servers(game_id=game_id, filter_secret=True)
servers = await database.all_servers(game_id=game_id, filter_secret=True)
return jsonify(servers)

@app.route('/api/v1/channels')
Expand All @@ -90,10 +89,8 @@ async def channels(channel_id: str = None):
if channel_id is not None and not channel_id.isdigit():
return jsonify({'error': 'Invalid channel id'})

database = Database()

if channel_id is None:
servers_count = await Database().count_servers_per_channel()
servers_count = await database.count_servers_per_channel()
return jsonify(servers_count)

servers = await database.all_servers(channel_id=int(channel_id), filter_secret=True)
Expand Down
138 changes: 73 additions & 65 deletions discordgsm/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from pymongo import DeleteOne, MongoClient, UpdateMany, UpdateOne
import psycopg2
import psycopg2.pool
from dotenv import load_dotenv


Expand Down Expand Up @@ -55,23 +56,22 @@ def __enter__(self):
return self

def __exit__(self, type, value, traceback):
self.close()
self.dispose()

def connect(self):
DB_CONNECTION: str = os.getenv('DB_CONNECTION', 'sqlite')
DATABASE_URL: str = os.getenv('DATABASE_URL', '')

if DATABASE_URL.startswith('postgres://') or DATABASE_URL.startswith('postgresql://') or DB_CONNECTION == Driver.PostgreSQL.value:
self.driver = Driver.PostgreSQL
self.conn = self.__connect_psycopg2(DATABASE_URL)
self.pool = self.__connect_psycopg2(DATABASE_URL)
elif DB_CONNECTION == Driver.MongoDB.value:
self.driver = Driver.MongoDB
self.conn = MongoClient(DATABASE_URL)
self.collection = self.conn.get_default_database()['servers']
else:
self.driver = Driver.SQLite
self.conn = sqlite3.connect(os.path.join(os.path.dirname(
os.path.realpath(__file__)), '..', 'data', 'servers.db'))
self.database = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'servers.db')

def __connect_psycopg2(self, database_url: str, max_retries=3):
retries = 0
Expand All @@ -81,8 +81,8 @@ def __connect_psycopg2(self, database_url: str, max_retries=3):
time.sleep(1)

try:
conn = psycopg2.connect(database_url, sslmode=sslmode)
return conn
pool = psycopg2.pool.ThreadedConnectionPool(1, 10, database_url, sslmode=sslmode)
return pool
except psycopg2.OperationalError as e:
if retries >= max_retries:
raise e
Expand All @@ -94,7 +94,7 @@ def create_table_if_not_exists(self):
if self.driver == Driver.MongoDB:
return

cursor = self.cursor()
conn, cursor = self.cursor()

if self.driver == Driver.PostgreSQL:
cursor.execute('''
Expand Down Expand Up @@ -131,21 +131,41 @@ def create_table_if_not_exists(self):
style_data TEXT NOT NULL
)''')

self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

def close(self):
self.conn.close()
def dispose(self):
if self.driver == Driver.PostgreSQL:
self.pool.closeall()
elif self.driver == Driver.MongoDB:
self.conn.close()

def cursor(self):
try:
cursor = self.conn.cursor()
except psycopg2.InterfaceError: # connection already closed
# Reconnect
self.connect()
cursor = self.conn.cursor()
if self.driver == Driver.PostgreSQL:
try:
conn = self.pool.getconn()
cursor = conn.cursor()
except psycopg2.InterfaceError: # connection already closed
# Reconnect
self.connect()
conn = self.pool.getconn()
cursor = conn.cursor()

return conn, cursor
else:
conn = sqlite3.connect(self.database)
cursor = conn.cursor()
return conn, cursor

return cursor
def close(self, conn: sqlite3.Connection, cursor: sqlite3.Cursor, *, commit=False):
if commit:
conn.commit()

cursor.close()

if self.driver == Driver.PostgreSQL:
self.pool.putconn(conn)
else:
conn.close()

def transform(self, sql: str):
if self.driver == Driver.PostgreSQL:
Expand Down Expand Up @@ -187,10 +207,10 @@ def statistics(self):
(SELECT COUNT(*) FROM (SELECT DISTINCT game_id, address, query_port, query_extra FROM servers) x) as unique_servers
FROM servers'''

cursor = self.cursor()
conn, cursor = self.cursor()
cursor.execute(self.transform(sql))
row: tuple[int] = cursor.fetchone()
cursor.close()
self.close(conn, cursor)
row = [0, 0, 0, 0] if row is None else row

return {
Expand All @@ -212,11 +232,11 @@ def count_servers_per_game(self):
results.close()
return servers_count

cursor = self.cursor()
conn, cursor = self.cursor()
cursor.execute(self.transform(
'SELECT game_id, COUNT(*) FROM servers GROUP BY game_id'))
servers_count = {str(row[0]): int(row[1]) for row in cursor.fetchall()}
cursor.close()
self.close(conn, cursor)

return servers_count

Expand All @@ -232,11 +252,11 @@ def count_servers_per_channel(self):
results.close()
return servers_count

cursor = self.cursor()
conn, cursor = self.cursor()
cursor.execute(self.transform(
'SELECT channel_id, COUNT(*) FROM servers GROUP BY channel_id'))
servers_count = {str(row[0]): int(row[1]) for row in cursor.fetchall()}
cursor.close()
self.close(conn, cursor)

return servers_count

Expand Down Expand Up @@ -267,7 +287,7 @@ def __all_servers(self, *, channel_id: int = None, guild_id: int = None, message

return servers

cursor = self.cursor()
conn, cursor = self.cursor()

if channel_id:
cursor.execute(self.transform(
Expand All @@ -286,7 +306,7 @@ def __all_servers(self, *, channel_id: int = None, guild_id: int = None, message

servers = [Server.from_list(row, filter_secret)
for row in cursor.fetchall()]
cursor.close()
self.close(conn, cursor)

return servers

Expand All @@ -311,11 +331,11 @@ def distinct_servers(self):
results.close()
return servers

cursor = self.cursor()
conn, cursor = self.cursor()
cursor.execute(
'SELECT DISTINCT game_id, address, query_port, query_extra, status, result FROM servers')
servers = [QueryServer.create(row) for row in cursor.fetchall()]
cursor.close()
self.close(conn, cursor)

return servers

Expand Down Expand Up @@ -348,16 +368,10 @@ def add_server(self, s: Server):
INSERT INTO servers (position, guild_id, channel_id, game_id, address, query_port, query_extra, status, result, style_id, style_data)
VALUES ((SELECT IFNULL(MAX(position + 1), 0) FROM servers WHERE channel_id = ?), ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)'''

try:
cursor = self.cursor()
cursor.execute(self.transform(sql), (s.channel_id, s.guild_id, s.channel_id, s.game_id, s.address, s.query_port, stringify(
s.query_extra), s.status, stringify(s.result), s.style_id, stringify(s.style_data)))
self.conn.commit()
except psycopg2.Error as e:
self.conn.rollback()
raise e
finally:
cursor.close()
conn, cursor = self.cursor()
cursor.execute(self.transform(sql), (s.channel_id, s.guild_id, s.channel_id, s.game_id, s.address, s.query_port, stringify(
s.query_extra), s.status, stringify(s.result), s.style_id, stringify(s.style_data)))
self.close(conn, cursor, commit=True)

return self.__find_server(s.channel_id, s.address, s.query_port)

Expand All @@ -378,10 +392,9 @@ def update_servers_message_id(self, servers: list[Server]):

sql = 'UPDATE servers SET message_id = ? WHERE id = ?'
parameters = [(server.message_id, server.id) for server in servers]
cursor = self.cursor()
conn, cursor = self.cursor()
cursor.executemany(self.transform(sql), parameters)
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

@run_in_executor
def update_servers(self, servers: list[Server], *, channel_id: int = None):
Expand All @@ -406,10 +419,9 @@ def update_servers(self, servers: list[Server], *, channel_id: int = None):
parameters = [(server.status, stringify(server.result), server.game_id, server.address,
server.query_port, stringify(server.query_extra)) for server in servers]
sql = 'UPDATE servers SET status = ?, result = ? WHERE game_id = ? AND address = ? AND query_port = ? AND query_extra = ?'
cursor = self.cursor()
conn, cursor = self.cursor()
cursor.executemany(self.transform(sql), parameters)
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

@run_in_executor
def delete_servers(self, *, guild_id: int = None, channel_id: int = None, servers: list[Server] = None):
Expand All @@ -428,7 +440,7 @@ def delete_servers(self, *, guild_id: int = None, channel_id: int = None, server
if operations:
self.collection.bulk_write(operations)
else:
cursor = self.cursor()
conn, cursor = self.cursor()

if guild_id is not None:
sql = 'DELETE FROM servers WHERE guild_id = ?'
Expand All @@ -441,8 +453,7 @@ def delete_servers(self, *, guild_id: int = None, channel_id: int = None, server
parameters = [(server.id,) for server in servers]
cursor.executemany(self.transform(sql), parameters)

self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

@run_in_executor
def find_server(self, channel_id: int, address: str = None, query_port: int = None):
Expand All @@ -458,14 +469,14 @@ def __find_server(self, channel_id: int, address: str = None, query_port: int =

return Server.from_docs(result)

cursor = self.cursor()
conn, cursor = self.cursor()

sql = 'SELECT * FROM servers WHERE channel_id = ? AND address = ? AND query_port = ?'
cursor.execute(self.transform(
sql), (channel_id, address, query_port))

row = cursor.fetchone()
cursor.close()
self.close(conn, cursor)

if not row:
raise self.ServerNotFoundError()
Expand Down Expand Up @@ -500,11 +511,10 @@ def __swap_servers_positon(self, server1: Server, server2: Server):
"$set": {"position": server1.position, "message_id": server1.message_id}})
else:
sql = 'UPDATE servers SET position = case when position = ? then ? else ? end, message_id = case when message_id = ? then ? else ? end WHERE id IN (?, ?)'
cursor = self.cursor()
conn, cursor = self.cursor()
cursor.execute(self.transform(sql), (server1.position, server2.position, server1.position,
server1.message_id, server2.message_id, server1.message_id, server1.id, server2.id))
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

# Swap the position and message_id values in the server objects
server1.position, server2.position = server2.position, server1.position
Expand All @@ -520,10 +530,9 @@ def update_server_style_id(self, server: Server):
return

sql = 'UPDATE servers SET style_id = ? WHERE id = ?'
cursor = self.cursor()
conn, cursor = self.cursor()
cursor.execute(self.transform(sql), (server.style_id, server.id))
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

@run_in_executor
def update_servers_style_data(self, servers: list[Server]):
Expand All @@ -541,10 +550,9 @@ def update_servers_style_data(self, servers: list[Server]):
sql = 'UPDATE servers SET style_data = ? WHERE id = ?'
parameters = [(stringify(server.style_data), server.id)
for server in servers]
cursor = self.cursor()
conn, cursor = self.cursor()
cursor.executemany(self.transform(sql), parameters)
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

def __update_servers_channel_id(self, servers: list[Server], channel_id: int):
if self.driver == Driver.MongoDB:
Expand Down Expand Up @@ -574,10 +582,9 @@ def __update_servers_channel_id(self, servers: list[Server], channel_id: int):
sql = 'UPDATE servers SET channel_id = ?, position = (SELECT IFNULL(MAX(position + 1), 0) FROM servers WHERE channel_id = ?) WHERE id = ?'
parameters = [(channel_id, channel_id, server.id)
for server in servers]
cursor = self.cursor()
conn, cursor = self.cursor()
cursor.executemany(self.transform(sql), parameters)
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

def export(self, *, to_driver: str):
if to_driver not in drivers:
Expand All @@ -603,7 +610,9 @@ def export(self, *, to_driver: str):
if self.driver == Driver.SQLite:
# Export data to SQL file
with open(file, 'w', encoding='utf-8') as f:
for line in self.conn.iterdump():
conn, _ = self.cursor()

for line in conn.iterdump():
f.write('%s\n' % line)
elif self.driver == Driver.PostgreSQL:
DATABASE_URL: str = os.getenv('DATABASE_URL', '')
Expand Down Expand Up @@ -665,16 +674,15 @@ def import_(self, *, filename: str):
self.create_table_if_not_exists()

# Execute the SQL commands
cursor = self.cursor()
conn, cursor = self.cursor()

if self.driver == Driver.PostgreSQL:
cursor.execute(sql_script)
if self.driver == Driver.SQLite:
cursor.executescript(sql_script)

# Commit the changes and close the cursor
self.conn.commit()
cursor.close()
self.close(conn, cursor, commit=True)

print(f"Imported {len(sql_script.splitlines())} servers.")

Expand Down
Loading

0 comments on commit e27a86f

Please sign in to comment.