Skip to content

Commit

Permalink
Merge pull request #86 from mediacloud/cleanup
Browse files Browse the repository at this point in the history
Cleanup again- pydantic_config in ui.py
  • Loading branch information
pgulley authored Jul 12, 2024
2 parents 8133f35 + 394b603 commit 3d9d4c8
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,58 @@

import os
import random
from typing import List
from urllib.parse import quote_plus

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import requests
import streamlit as st
from pydantic import BaseModel, computed_field
from pydantic_settings import BaseSettings
from wordcloud import WordCloud

from utils import env_to_list, load_config

config = load_config()
def _split_csv(raw: str) -> List[str]:
    """Split a comma-separated string, trimming blanks and dropping empty items."""
    return [item.strip() for item in raw.split(",") if item.strip()]


class Config(BaseSettings):
    """Settings for the explorer UI.

    Each field carries the default used when no override is supplied;
    overrides are resolved by ``BaseSettings`` (pydantic-settings).
    The comma-separated string fields are exposed as lists through the
    ``*_list`` computed fields.
    """

    title: str = "Collection Search API"
    apiurl: str = "http://localhost:8000/v1"
    termfields: str = "article_title,text_content"
    termaggrs: str = "top"
    maxwc: int = 30

    def model_post_init(self, context: object, /) -> None:
        """Normalize values after validation."""
        # Callers build URLs as f"{apiurl}/...", so a trailing slash on the
        # configured value would put "//" into every request path.
        self.apiurl = self.apiurl.rstrip("/")

    @computed_field()
    def full_title(self) -> str:
        """Page title: the configured title followed by " Explorer"."""
        return self.title + " Explorer"

    @computed_field()
    def termfields_list(self) -> List[str]:
        """``termfields`` as a list; an empty setting yields an empty list."""
        return _split_csv(self.termfields)

    @computed_field()
    def termaggrs_list(self) -> List[str]:
        """``termaggrs`` as a list; an empty setting yields an empty list."""
        return _split_csv(self.termaggrs)

# Build the settings object once at import time.
config = Config()


# Configure the page exactly once, before any other Streamlit call; a second
# set_page_config call (or one issued after st.title) is rejected by Streamlit.
st.set_page_config(
    page_title=config.full_title, layout="wide"  # type: ignore[arg-type]
)
st.title(config.full_title)


# @st.cache(ttl=300)
def load_data(cname, qstr, ep="search/overview"):
    """Query one collection endpoint and return its decoded JSON.

    Args:
        cname: Collection name, used as a path segment under the API URL.
        qstr: Query string; URL-encoded with ``quote_plus`` into ``?q=``.
        ep: Endpoint path below the collection.

    Returns:
        The parsed JSON body when the response is OK, otherwise ``None``.
    """
    # A single request using attribute access on the settings object; the
    # dict-style config["apiurl"] lookup does not work on Config.
    url = f"{config.apiurl}/{cname}/{ep}?q={quote_plus(qstr)}"
    r = requests.get(url, timeout=60)
    if r.ok:
        return r.json()
    return None


def load_collections():
    """Fetch ``{apiurl}/collections`` and return its decoded JSON.

    Returns:
        The parsed JSON body when the response is OK, otherwise ``None``.
    """
    # A single request using attribute access on the settings object; the
    # dict-style config["apiurl"] lookup does not work on Config.
    r = requests.get(f"{config.apiurl}/collections", timeout=60)
    if r.ok:
        return r.json()
    return None
Expand Down Expand Up @@ -135,19 +148,20 @@ def load_collections():
tbs[0].altair_chart(c, use_container_width=True)
tbs[1].write(ov[v])

for fld in config["termfields"]:
for fld in config.termfields_list: # type: ignore[attr-defined]
cols = st.columns(3)
for i, aggr in enumerate(config["termaggrs"]):
aggr: str
for i, aggr in enumerate(config.termaggrs_list): # type: ignore[arg-type]
with cols[i]:
tbs = st.tabs([f"{aggr} {fld} terms".title(), "Data"])
tt = load_data(col, q, f"terms/{fld}/{aggr}")
if tt:
sample = tt
if len(tt) > config["maxwc"]:
if len(tt) > config.maxwc:
if aggr == "rare":
sample = dict(random.sample(list(tt.items()), config["maxwc"]))
sample = dict(random.sample(list(tt.items()), config.maxwc))
else:
sample = dict(list(tt.items())[: config["maxwc"]])
sample = dict(list(tt.items())[: config.maxwc])
wc = WordCloud(background_color="white")
wc.generate_from_frequencies(sample)
fig, ax = plt.subplots()
Expand Down

0 comments on commit 3d9d4c8

Please sign in to comment.