diff --git a/ui.py b/ui.py
index 6dc073a..a2f6ce7 100755
--- a/ui.py
+++ b/ui.py
@@ -2,6 +2,7 @@
 import os
 import random
+from typing import List
 from urllib.parse import quote_plus
 
 import altair as alt
@@ -9,38 +10,50 @@
 import pandas as pd
 import requests
 import streamlit as st
+from pydantic import BaseModel, computed_field
+from pydantic_settings import BaseSettings
 from wordcloud import WordCloud
 
-from utils import env_to_list, load_config
 
-config = load_config()
+class Config(BaseSettings):
+    title: str = "Collection Search API"
+    apiurl: str = "http://localhost:8000/v1"
+    termfields: str = "article_title,text_content"
+    termaggrs: str = "top"
+    maxwc: int = 30
+
+    @computed_field()
+    def full_title(self) -> str:
+        return self.title + " Explorer"
+
+    @computed_field()
+    def termfields_list(self) -> List[str]:
+        return self.termfields.split(",")
+
+    @computed_field()
+    def termaggrs_list(self) -> List[str]:
+        return self.termaggrs.split(",")
 
-config["title"] = (
-    os.getenv("TITLE", config.get("title", "Collection Search API")) + " Explorer"
-)
-config["apiurl"] = os.getenv(
-    "APIURL", config.get("apiurl", "http://localhost:8000/v1")
-).rstrip("/")
-config["termfields"] = env_to_list("TERMFIELDS") or config.get("termfields", [])
-config["termaggrs"] = env_to_list("TERMAGGRS") or config.get("termaggrs", [])
-config["maxwc"] = int(os.getenv("MAXWC", config.get("maxwc", 30)))
 
-st.set_page_config(page_title=config["title"], layout="wide")
-st.title(config["title"])
+config = Config()
+
+
+st.set_page_config(
+    page_title=config.full_title, layout="wide"  # type: ignore[arg-type]
+)
+st.title(config.full_title)
 
 
 # @st.cache(ttl=300)
 def load_data(cname, qstr, ep="search/overview"):
-    r = requests.get(
-        f"{config['apiurl']}/{cname}/{ep}?q={quote_plus(qstr)}", timeout=60
-    )
+    r = requests.get(f"{config.apiurl}/{cname}/{ep}?q={quote_plus(qstr)}", timeout=60)
     if r.ok:
         return r.json()
     return None
 
 
 def load_collections():
-    r = requests.get(f"{config['apiurl']}/collections", timeout=60)
+    r = requests.get(f"{config.apiurl}/collections", timeout=60)
     if r.ok:
         return r.json()
     return None
@@ -135,19 +148,20 @@ def load_collections():
             tbs[0].altair_chart(c, use_container_width=True)
             tbs[1].write(ov[v])
 
-for fld in config["termfields"]:
+for fld in config.termfields_list:  # type: ignore[attr-defined]
     cols = st.columns(3)
-    for i, aggr in enumerate(config["termaggrs"]):
+    aggr: str
+    for i, aggr in enumerate(config.termaggrs_list):  # type: ignore[arg-type]
         with cols[i]:
             tbs = st.tabs([f"{aggr} {fld} terms".title(), "Data"])
             tt = load_data(col, q, f"terms/{fld}/{aggr}")
             if tt:
                 sample = tt
-                if len(tt) > config["maxwc"]:
+                if len(tt) > config.maxwc:
                     if aggr == "rare":
-                        sample = dict(random.sample(list(tt.items()), config["maxwc"]))
+                        sample = dict(random.sample(list(tt.items()), config.maxwc))
                     else:
-                        sample = dict(list(tt.items())[: config["maxwc"]])
+                        sample = dict(list(tt.items())[: config.maxwc])
                 wc = WordCloud(background_color="white")
                 wc.generate_from_frequencies(sample)
                 fig, ax = plt.subplots()
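
Usage sketch (illustrative, not part of the patch): with pydantic-settings, the Config fields above are populated from environment variables matched case-insensitively by field name, so exporting TITLE, APIURL, TERMFIELDS, TERMAGGRS, or MAXWC overrides the defaults much like the removed os.getenv()/env_to_list() plumbing did. The snippet below is a minimal sketch of that behaviour, assuming the Config class exactly as defined in the diff; the example values are hypothetical.

    import os

    # Environment variables override the field defaults (matched case-insensitively).
    os.environ["TITLE"] = "My Archive"
    os.environ["TERMFIELDS"] = "author,text_content"

    config = Config()
    print(config.full_title)       # "My Archive Explorer"
    print(config.termfields_list)  # ["author", "text_content"]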