Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Buffer messages to wait for reconnect #697

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions oidc-controller/api/routers/acapy_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..db.session import get_db

from ..core.config import settings
from ..routers.socketio import sio, connections_reload
from ..routers.socketio import buffered_emit, connections_reload

logger: structlog.typing.FilteringBoundLogger = structlog.getLogger(__name__)

Expand Down Expand Up @@ -39,9 +39,6 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db

# Get the saved websocket session
pid = str(auth_session.id)
connections = connections_reload()
sid = connections.get(pid)
logger.debug(f"sid: {sid} found for pid: {pid}")

if webhook_body["state"] == "presentation-received":
logger.info("presentation-received")
Expand All @@ -51,12 +48,10 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
if webhook_body["verified"] == "true":
auth_session.proof_status = AuthSessionState.VERIFIED
auth_session.presentation_exchange = webhook_body["by_format"]
if sid:
await sio.emit("status", {"status": "verified"}, to=sid)
await buffered_emit("status", {"status": "verified"}, to_pid=pid)
else:
auth_session.proof_status = AuthSessionState.FAILED
if sid:
await sio.emit("status", {"status": "failed"}, to=sid)
await buffered_emit("status", {"status": "failed"}, to_pid=pid)

await AuthSessionCRUD(db).patch(
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
Expand All @@ -67,8 +62,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
logger.info("ABANDONED")
logger.info(webhook_body["error_msg"])
auth_session.proof_status = AuthSessionState.ABANDONED
if sid:
await sio.emit("status", {"status": "abandoned"}, to=sid)
await buffered_emit("status", {"status": "abandoned"}, to_pid=pid)

await AuthSessionCRUD(db).patch(
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
Expand All @@ -93,8 +87,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
):
logger.info("EXPIRED")
auth_session.proof_status = AuthSessionState.EXPIRED
if sid:
await sio.emit("status", {"status": "expired"}, to=sid)
await buffered_emit("status", {"status": "expired"}, to_pid=pid)

await AuthSessionCRUD(db).patch(
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
Expand Down
7 changes: 2 additions & 5 deletions oidc-controller/api/routers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..db.session import get_db

# Access to the websocket
from ..routers.socketio import connections_reload, sio
from ..routers.socketio import buffered_emit, connections_reload

from ..verificationConfigs.crud import VerificationConfigCRUD
from ..verificationConfigs.helpers import VariableSubstitutionError
Expand All @@ -58,8 +58,6 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)):
auth_session = await AuthSessionCRUD(db).get(pid)

pid = str(auth_session.id)
connections = connections_reload()
sid = connections.get(pid)

"""
Check if proof is expired. But only if the proof has not been started.
Expand All @@ -75,8 +73,7 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)):
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
)
# Send message through the websocket.
if sid:
await sio.emit("status", {"status": "expired"}, to=sid)
await buffered_emit("status", {"status": "expired"}, to_pid=pid)

return {"proof_status": auth_session.proof_status}

Expand Down
9 changes: 2 additions & 7 deletions oidc-controller/api/routers/presentation_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..authSessions.models import AuthSession, AuthSessionState

from ..core.config import settings
from ..routers.socketio import sio, connections_reload
from ..routers.socketio import buffered_emit, connections_reload
from ..routers.oidc import gen_deep_link
from ..db.session import get_db

Expand Down Expand Up @@ -49,16 +49,11 @@ async def send_connectionless_proof_req(
pres_exch_id
)

# Get the websocket session
connections = connections_reload()
sid = connections.get(str(auth_session.id))

# If the qrcode has been scanned, toggle the verified flag
if auth_session.proof_status is AuthSessionState.NOT_STARTED:
auth_session.proof_status = AuthSessionState.PENDING
await AuthSessionCRUD(db).patch(auth_session.id, auth_session)
if sid:
await sio.emit("status", {"status": "pending"}, to=sid)
await buffered_emit("status", {"status": "pending"}, to_pid=auth_session.id)

msg = auth_session.presentation_request_msg

Expand Down
75 changes: 66 additions & 9 deletions oidc-controller/api/routers/socketio.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import socketio # For using websockets
import logging
import time

logger = logging.getLogger(__name__)


connections = {}
message_buffers = {}
buffer_timeout = 60 # Timeout in seconds

sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")

sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io")


Expand All @@ -18,18 +19,74 @@ async def connect(sid, socket):

@sio.event
async def initialize(sid, data):
global connections
# Store websocket session matched to the presentation exchange id
connections[data.get("pid")] = sid
global connections, message_buffers
pid = data.get("pid")
connections[pid] = sid
# Initialize buffer if it doesn't exist
if pid not in message_buffers:
message_buffers[pid] = []


@sio.event
async def disconnect(sid):
global connections
global connections, message_buffers
logger.info(f">>> disconnect : sid={sid}")
# Remove websocket session from the store
if len(connections) > 0:
connections = {k: v for k, v in connections.items() if v != sid}
# Find the pid associated with the sid
pid = next((k for k, v in connections.items() if v == sid), None)
if pid:
# Remove pid from connections
del connections[pid]


async def buffered_emit(event, data, to_pid=None):
global connections, message_buffers

connections = connections_reload()
sid = connections.get(to_pid)
logger.debug(f"sid: {sid} found for pid: {to_pid}")

if sid:
try:
await sio.emit(event, data, room=sid)
except:
# If send fails, buffer the message
buffer_message(to_pid, event, data)
else:
# Buffer the message if the target is not connected
buffer_message(to_pid, event, data)


def buffer_message(pid, event, data):
global message_buffers
current_time = time.time()
if pid not in message_buffers:
message_buffers[pid] = []
# Add message with timestamp and event name
message_buffers[pid].append((event, data, current_time))
# Clean up old messages
message_buffers[pid] = [
(msg_event, msg_data, timestamp)
for msg_event, msg_data, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]


@sio.event
async def fetch_buffered_messages(sid, pid):
global message_buffers
current_time = time.time()
if pid in message_buffers:
# Filter messages that are still valid (i.e., within the buffer_timeout)
valid_messages = [
(msg_event, msg_data, timestamp)
for msg_event, msg_data, timestamp in message_buffers[pid]
if current_time - timestamp <= buffer_timeout
]
# Emit each valid message
for event, data, _ in valid_messages:
await sio.emit(event, data, room=sid)
# Reassign the valid_messages back to message_buffers[pid] to clean up old messages
message_buffers[pid] = valid_messages


def connections_reload():
Expand Down
7 changes: 0 additions & 7 deletions oidc-controller/api/templates/assets/js/socket.io.475.min.js

This file was deleted.

7 changes: 7 additions & 0 deletions oidc-controller/api/templates/assets/js/socket.io.481.min.js

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion oidc-controller/api/templates/verified_credentials.html
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ <h1 class="mb-3 fw-bolder fs-1">Continue with:</h1>
>
DEBUG Disconnect Web Socket
</button>

<button
class="btn btn-primary mt-4"
v-on:click="socket.connect()"
title="Reconnect Websocket"
>
DEBUG Reconnect Web Socket
</button>
</div>

<hr v-if="mobileDevice" />
Expand Down Expand Up @@ -163,7 +171,7 @@ <h1 class="mb-3 fw-bolder fs-1">Continue with:</h1>
</div>
</div>

<script src="/static/js/socket.io.475.min.js"></script>
<script src="/static/js/socket.io.481.min.js"></script>
<script src="/static/js/vue.global.prod.3512.js"></script>
</body>

Expand Down Expand Up @@ -383,6 +391,8 @@ <h5 v-if="state.showScanned" class="fw-bolder mb-3">
`Socket connecting. SID: ${this.socket.id}. PID: {{pid}}. Recovered? ${this.socket.recovered} `
);
this.socket.emit("initialize", { pid: "{{pid}}" });
// Emit the `fetch_buffered_messages` event with `pid` as a string using Jinja templating
this.socket.emit('fetch_buffered_messages', '{{ pid }}');
});

this.socket.on("connect_error", (error) => {
Expand Down
Loading
Loading