Commit

👂 add whisper models for downloading
anuejn committed Oct 11, 2023
1 parent 6d1b7f3 commit cae71b5
Showing 6 changed files with 299 additions and 101 deletions.
10 changes: 10 additions & 0 deletions app/src/pages/LanguageSettings.tsx
@@ -264,6 +264,16 @@ export function LanguageSettingsPage(): JSX.Element {
<MainMaxWidthContainer>
<BackButton id={'back' /* for tour */} />

<Heading marginTop={majorScale(3)} marginBottom={majorScale(2)} paddingLeft={majorScale(1)}>
Whisper Models for {language.lang}
</Heading>
<ModelTable
models={language.whisper_models}
lang={language.lang}
type={'whisper'}
id={'whisper_table'}
/>

<Heading marginTop={majorScale(3)} marginBottom={majorScale(2)} paddingLeft={majorScale(1)}>
Transcription Models for {language.lang}
</Heading>
6 changes: 6 additions & 0 deletions app/src/pages/ModelManager.tsx
@@ -41,6 +41,7 @@ export function ModelManagerPage(): JSX.Element {
<Table.Head padding={0}>
<Table.TextHeaderCell {...firstColumnProps}>Language</Table.TextHeaderCell>
<Table.TextHeaderCell>Transcription Models</Table.TextHeaderCell>
<Table.TextHeaderCell>Whisper Models</Table.TextHeaderCell>
<Table.TextHeaderCell {...lastColumnProps} />
</Table.Head>

@@ -61,6 +62,11 @@
lang={lang.lang}
downloaded={downloaded}
/>
<ModelNumberTextCell
models={lang.whisper_models}
lang={lang.lang}
downloaded={downloaded}
/>
<Table.Cell {...lastColumnProps}>
<Tooltip content={'manage language'}>
<Icon color={theme.colors.default} icon={ChevronRightIcon} />
1 change: 1 addition & 0 deletions app/src/state/models.ts
@@ -25,6 +25,7 @@
export interface Language {
lang: string;
transcription_models: Model[];
whisper_models: Model[];
}

export type DownloadingModel = Model & {
95 changes: 61 additions & 34 deletions server/app/models.py
@@ -8,6 +8,7 @@
from urllib.parse import urlparse
from zipfile import ZipFile

import huggingface_hub
import requests
import yaml
from vosk import Model
@@ -40,7 +41,7 @@ class ModelDescription:
size: str
type: str
lang: str
compressed: bool = field(default=False)
download_type: str = field(default=False)
model_id: str = field(default=None)

def __post_init__(self):
@@ -58,9 +59,10 @@ def is_downloaded(self) -> bool:
class Language:
lang: str
transcription_models: List[ModelDescription] = field(default_factory=list)
whisper_models: List[ModelDescription] = field(default_factory=list)

def all_models(self):
return self.transcription_models
return self.transcription_models + self.whisper_models


class ModelDefaultDict(defaultdict):
@@ -81,6 +83,8 @@ def __init__(self):
models[model_description.model_id] = model_description
if model["type"] == "transcription":
languages[lang].transcription_models.append(model_description)
elif model["type"] == "whisper":
languages[lang].whisper_models.append(model_description)
self.available = dict(languages)
self.model_descriptions = models
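
For context, the loop above builds ModelDescription objects from plain dicts parsed out of the models YAML and routes each one by its "type" key into either transcription_models or whisper_models. Below is a minimal sketch of what a parsed whisper entry might look like; only the "type" key is read in the visible lines, so the remaining keys and all values are assumptions based on the ModelDescription fields used elsewhere in this diff:

    # Hypothetical parsed entry from the models YAML (keys beyond "type" are
    # assumed from the ModelDescription fields; all values are illustrative).
    whisper_model = {
        "lang": "en",
        "type": "whisper",               # routes the entry into languages[lang].whisper_models
        "size": "461M",
        "url": "openai/whisper-small",   # treated as a Hugging Face repo id when download_type is "huggingface"
        "download_type": "huggingface",  # or "http" / "http+zip", judging by download() below
    }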

@@ -122,38 +126,61 @@ def get(self, model_id: str) -> Union[Model]:
def download(self, model_id: str, task_uuid: str):
task: DownloadModelTask = tasks.get(task_uuid)
model = self.get_model_description(model_id)
with tempfile.TemporaryFile(dir=CACHE_DIR) as f:
response = requests.get(model.url, stream=True)
task.total = int(response.headers.get("content-length"))
task.state = DownloadModelState.DOWNLOADING

for data in response.iter_content(
chunk_size=max(int(task.total / 1000), 1024 * 1024)
):
task.add_progress(len(data))

f.write(data)
if task.canceled:
return

task.state = DownloadModelState.EXTRACTING
if model.compressed:
with ZipFile(f) as archive:
target_dir = model.path()
for info in archive.infolist():
if info.is_dir():
continue
path = target_dir / Path("/".join(info.filename.split("/")[1:]))
path.parent.mkdir(exist_ok=True, parents=True)

source = archive.open(info.filename)
target = open(path, "wb")
with source, target:
shutil.copyfileobj(source, target)
else:
f.seek(0)
with open(model.path(), "wb") as target:
shutil.copyfileobj(f, target)

if model.download_type.startswith("http"):
with tempfile.TemporaryFile(dir=CACHE_DIR) as f:
response = requests.get(model.url, stream=True)
task.total = int(response.headers.get("content-length"))
task.state = DownloadModelState.DOWNLOADING

for data in response.iter_content(
chunk_size=max(int(task.total / 1000), 1024 * 1024)
):
task.add_progress(len(data))

f.write(data)
if task.canceled:
return

task.state = DownloadModelState.EXTRACTING
if model.download_type.endswith("+zip"):
with ZipFile(f) as archive:
target_dir = model.path()
for info in archive.infolist():
if info.is_dir():
continue
path = target_dir / Path(
"/".join(info.filename.split("/")[1:])
)
path.parent.mkdir(exist_ok=True, parents=True)

source = archive.open(info.filename)
target = open(path, "wb")
with source, target:
shutil.copyfileobj(source, target)
else:
f.seek(0)
with open(model.path(), "wb") as target:
shutil.copyfileobj(f, target)
elif model.download_type == "huggingface":
api = huggingface_hub.HfApi()
repo_info = api.repo_info(model.url, files_metadata=True)
task.total = sum(f.size for f in repo_info.siblings)
with tempfile.TemporaryDirectory(dir=CACHE_DIR) as dir:
for f in repo_info.siblings:
url = huggingface_hub.hf_hub_url(model.url, f.rfilename)
with open(Path(dir) / f.rfilename, "wb") as file:
task.state = DownloadModelState.DOWNLOADING
response = requests.get(url, stream=True)
for data in response.iter_content(
chunk_size=max(int(task.total / 1000), 1024 * 1024)
):
task.add_progress(len(data))

file.write(data)
if task.canceled:
return
shutil.copytree(dir, model.path())

task.state = DownloadModelState.DONE
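
The new Hugging Face branch computes the total download size up front by summing per-file metadata for the repository, then streams each file with the same chunk-size heuristic as the HTTP branch (roughly a thousand progress updates, but at least 1 MiB per chunk). A minimal standalone sketch of that enumeration, with an illustrative repo id:

    import huggingface_hub

    # List a repository's files and their sizes, as download() does before
    # streaming each file individually. The repo id is only an example.
    api = huggingface_hub.HfApi()
    info = api.repo_info("openai/whisper-tiny", files_metadata=True)
    total = sum(f.size for f in info.siblings)

    for f in info.siblings:
        print(huggingface_hub.hf_hub_url("openai/whisper-tiny", f.rfilename), f.size)

    print("total bytes:", total)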
