Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stream breaking up, fix blocking of JupyterLab load on repository fetch #14

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions jupyterlab_gallery/git_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import contextmanager
from pathlib import Path
from subprocess import run
from threading import Lock
from typing import Optional
import re
import os
Expand Down Expand Up @@ -42,9 +43,13 @@ def has_updates(repo_path: Path) -> bool:
return data["behind"] is not None


_git_credential_lock = Lock()


@contextmanager
def git_credentials(token: Optional[str], account: Optional[str]):
if token and account:
_git_credential_lock.acquire()
try:
path = Path(__file__).parent
os.environ["GIT_ASKPASS"] = str(path / "git_askpass.py")
Expand All @@ -59,5 +64,6 @@ def git_credentials(token: Optional[str], account: Optional[str]):
del os.environ["GIT_PULLER_TOKEN"]
del os.environ["GIT_TERMINAL_PROMPT"]
del os.environ["GIT_ASKPASS"]
_git_credential_lock.release()
else:
yield
127 changes: 84 additions & 43 deletions jupyterlab_gallery/gitpuller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# - reconnecting to the event stream when refreshing the browser
# - handling multiple waiting pulls
from tornado import gen, web, locks
import asyncio
import logging
import traceback

Expand All @@ -15,8 +16,7 @@
import os
from queue import Queue, Empty
from collections import defaultdict
from numbers import Number
from typing import Optional
from typing import Optional, TypedDict

import git
from jupyter_server.base.handlers import JupyterHandler
Expand Down Expand Up @@ -49,13 +49,15 @@ def update(self, op_code: int, cur_count, max_count=None, message=""):
self.prev_stage = self.max_stage
self.max_stage = new_stage

if isinstance(cur_count, Number) and isinstance(max_count, Number):
if isinstance(cur_count, (int, float)) and isinstance(max_count, (int, float)):
progress = self.prev_stage + cur_count / max_count * (
self.max_stage - self.prev_stage
)
self.queue.put(
{
"progress": self.prev_stage
+ cur_count / max_count * (self.max_stage - self.prev_stage),
"message": message,
}
Update(
progress=progress,
message=message,
)
)
# self.queue.join()

Expand All @@ -76,13 +78,17 @@ def initialize_repo(self):

def clone_task():
with git_credentials(token=self._token, account=self._account):
git.Repo.clone_from(
self.git_url,
self.repo_dir,
branch=self.branch_name,
progress=progress,
)
progress.queue.put(None)
try:
git.Repo.clone_from(
self.git_url,
self.repo_dir,
branch=self.branch_name,
progress=progress,
)
except Exception as e:
progress.queue.put(e)
finally:
progress.queue.put(None)

threading.Thread(target=clone_task).start()
# TODO: add configurable timeout
Expand All @@ -101,21 +107,31 @@ def update(self):
yield from super().update()


class Update(TypedDict):
progress: float
message: str


class SyncHandlerBase(JupyterHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if "pull_status_queues" not in self.settings:
self.settings["pull_status_queues"] = defaultdict(Queue)

# store the most recent message from each queue to re-emit when client re-connects
self.last_message = {}
if "last_message" not in self.settings:
self.settings["last_message"] = {}

# We use this lock to make sure that only one sync operation
# can be happening at a time. Git doesn't like concurrent use!
if "git_lock" not in self.settings:
self.settings["git_lock"] = locks.Lock()

if "enqueue_task" not in self.settings:
task = asyncio.create_task(self._enqueue_messages())
self.settings["enqueue_task"] = task
task.add_done_callback(lambda task: self.settings.pop("enqueue_task"))

def get_login_url(self):
# raise on failed auth, not redirect
# can't redirect EventStream to login
Expand All @@ -136,14 +152,14 @@ async def _pull(
):
q = self.settings["pull_status_queues"][exhibit_id]
try:
q.put_nowait({"phase": "waiting", "message": "Waiting for a git lock"})
await self.git_lock.acquire(1)
q.put_nowait(Update(progress=0.01, message="Waiting for a lock"))
await self.git_lock.acquire(5)
q.put_nowait(Update(progress=0.02, message="Lock acquired"))
except gen.TimeoutError:
q.put_nowait(
{
"phase": "error",
"message": "Another git operations is currently running, try again in a few minutes",
}
gen.TimeoutError(
"Another git operations is currently running, try again in a few minutes"
)
)
return

Expand Down Expand Up @@ -180,8 +196,8 @@ async def _pull(

def pull():
try:
for line in gp.pull():
q.put_nowait(line)
for update in gp.pull():
q.put_nowait(update)
# Sentinel when we're done
q.put_nowait(None)
except Exception as e:
Expand All @@ -199,23 +215,15 @@ async def emit(self, data: dict):
self.write("data: {}\n\n".format(serialized_data))
await self.flush()

async def _stream(self):
# We gonna send out event streams!
self.set_header("content-type", "text/event-stream")
self.set_header("cache-control", "no-cache")

# start by re-emitting last message so that client can catch up after reconnecting
for _exhibit_id, msg in self.last_message.items():
await self.emit(msg)

async def _enqueue_messages(self):
last_message = self.settings["last_message"]
queues = self.settings["pull_status_queues"]

# stream new messages as they are put on respective queues
while True:
empty_queues = 0
# copy to avoid error due to size change during iteration:
queues_view = queues.copy()
for exhibit_id, q in queues_view.items():
# try to consume next message
try:
progress = q.get_nowait()
except Empty:
Expand Down Expand Up @@ -252,12 +260,45 @@ async def _stream(self):
"exhibit_id": exhibit_id,
}

self.last_message[exhibit_id] = msg
try:
await self.emit(msg)
except StreamClosedError:
self.log.warn("git puller stream got closed")
pass
last_message[exhibit_id] = msg

if empty_queues == len(queues_view):
await gen.sleep(0.5)
await gen.sleep(0.1)

async def _stream(self):
# We gonna send out event streams!
self.set_header("content-type", "text/event-stream")
self.set_header("cache-control", "no-cache")

# https://bugzilla.mozilla.org/show_bug.cgi?id=833462
await self.emit({"phase": "connected"})

last_message = self.settings["last_message"]
last_message_sent = {}

# stream new messages as they are put on respective queues
while True:
messages_view = last_message.copy()
unchanged = 0
for exhibit_id, msg in messages_view.items():
# emit an update if anything changed
if last_message_sent.get(exhibit_id) == msg:
unchanged += 1
continue
last_message_sent[exhibit_id] = msg
try:
await self.emit(msg)
except StreamClosedError as e:
# this is expected to happen whenever client closes (e.g. user
# closes the browser or refreshes the tab with JupterLab)
if e.real_error:
self.warn.info(
f"git puller stream got closed with error {e.real_error}"
)
else:
self.log.info("git puller stream closed")
# return to stop reading messages, so that the next
# client who connects can consume them
return
if unchanged == len(messages_view):
await gen.sleep(0.1)
32 changes: 25 additions & 7 deletions jupyterlab_gallery/manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Optional
from threading import Thread

from traitlets.config.configurable import LoggingConfigurable
from traitlets import Dict, List, Unicode
Expand All @@ -13,6 +16,12 @@


class GalleryManager(LoggingConfigurable):
_has_updates: dict[str, Optional[bool]] = defaultdict(lambda: None)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._background_tasks = set()

root_dir = Unicode(
config=False,
allow_none=True,
Expand Down Expand Up @@ -71,6 +80,13 @@ def get_local_path(self, exhibit) -> Path:
repository_name = extract_repository_name(exhibit["git"])
return clone_destination / repository_name

def _check_updates(self, exhibit):
local_path = self.get_local_path(exhibit)
with git_credentials(
account=exhibit.get("account"), token=exhibit.get("token")
):
self._has_updates[local_path] = has_updates(local_path)

def get_exhibit_data(self, exhibit):
data = {}

Expand All @@ -90,14 +106,16 @@ def get_exhibit_data(self, exhibit):
data["isCloned"] = exists
if exists:
fetch_head = local_path / ".git" / "FETCH_HEAD"
if fetch_head.exists():
head = local_path / ".git" / "HEAD"
date_head = fetch_head if fetch_head.exists() else head
if date_head.exists():
data["lastUpdated"] = datetime.fromtimestamp(
fetch_head.stat().st_mtime
date_head.stat().st_mtime
).isoformat()
with git_credentials(
account=exhibit.get("account"), token=exhibit.get("token")
):
# TODO: this is blocking initial load; can we make it async?
data["updatesAvailable"] = has_updates(local_path)
data["updatesAvailable"] = self._has_updates[local_path]

def check_updates():
self._check_updates(exhibit)

Thread(target=check_updates).start()
return data
Loading
Loading