Skip to content

Commit

Permalink
Merge pull request #39 from zmek/plot-supply-demand-examples
Browse files Browse the repository at this point in the history
Plot supply demand examples
  • Loading branch information
zmek authored Jan 2, 2025
2 parents 0797cb0 + 35b1c88 commit ad72c01
Show file tree
Hide file tree
Showing 11 changed files with 607 additions and 664 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
[conda-badge]: https://img.shields.io/conda/vn/conda-forge/patientflow
[conda-link]: https://github.com/conda-forge/patientflow-feedstock
[license-badge]: https://img.shields.io/badge/License-MIT-yellow.svg
[![ORCID](https://img.shields.io/badge/ORCID-0000-0001-7389-1527-brightgreen)](https://orcid.org/0000-0001-7389-1527)

<!-- [pypi-link]: https://pypi.org/project/patientflow/
[pypi-platforms]: https://img.shields.io/pypi/pyversions/patientflow
[pypi-version]: https://img.shields.io/pypi/v/patientflow -->
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/img/supply_demand_examples/Discharges.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/img/supply_demand_examples/Net_position.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/img/supply_demand_examples/Patients_in_ED.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified notebooks/img/supply_demand_examples/Total_demand.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,204 changes: 563 additions & 641 deletions notebooks/plots/Plot_suppy_and_demand_examples.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ sympy>=1.12
tabulate==0.9.0
xgboost>=2.0.3
joblib>=1.4.2
scikit-learn>=1.4.2
scikit-learn>=1.4.0,<1.5.0
10 changes: 8 additions & 2 deletions src/patientflow/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import xgboost as xgb
from xgboost import XGBClassifier
import pandas as pd
from joblib import dump
import json
Expand Down Expand Up @@ -177,7 +177,13 @@ def chronological_cross_validation(pipeline, X, y, n_splits=5):

# Initialise the model with given hyperparameters
def initialise_xgb(params):
model = xgb.XGBClassifier(n_jobs=-1, eval_metric="logloss")
model = XGBClassifier(
n_jobs=-1,
eval_metric="logloss",
use_label_encoder=False,
enable_categorical=True,
scikit_learn=True, # Add this parameter
)
model.set_params(**params)
return model

Expand Down
53 changes: 33 additions & 20 deletions src/patientflow/viz/prob_dist_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools

import numpy as np
from matplotlib import pyplot as plt

Expand All @@ -10,55 +9,69 @@ def prob_dist_plot(
directory_path=None,
figsize=(6, 3),
include_titles=False,
truncate_at_beds=20,
truncate_at_beds=(0, 20),
text_size=None,
bar_colour="#5B9BD5",
file_name=None,
min_beds_lines=None,
plot_min_beds_lines=True,
plot_bed_base=None,
):
plt.figure(figsize=figsize)

if not file_name:
file_name = (
title.replace(" ", "_").replace("/n", "_").replace("%", "percent") + ".png"
)

if isinstance(truncate_at_beds, (int, float)):
upper_bound = truncate_at_beds
lower_bound = 0
else:
lower_bound, upper_bound = truncate_at_beds
lower_bound = max(0, lower_bound) if lower_bound > 0 else lower_bound

mask = (prob_dist_data.index >= lower_bound) & (prob_dist_data.index <= upper_bound)
filtered_data = prob_dist_data[mask]

plt.bar(
prob_dist_data.index[0 : truncate_at_beds + 1],
prob_dist_data["agg_proba"].values[0 : truncate_at_beds + 1],
filtered_data.index,
filtered_data["agg_proba"].values,
color=bar_colour,
)

plt.xlim(-0.5, truncate_at_beds + 0.5)
plt.xticks(
np.arange(0, truncate_at_beds + 1, 5)
) # Set x-axis ticks at every 5 units
tick_start = (lower_bound // 5) * 5
tick_end = upper_bound + 1
plt.xticks(np.arange(tick_start, tick_end, 5))

if min_beds_lines:
if plot_min_beds_lines and min_beds_lines:
colors = itertools.cycle(
plt.cm.gray(np.linspace(0.3, 0.7, len(min_beds_lines)))
)

for point in min_beds_lines:
plt.axvline(
x=min_beds_lines[point],
x=prob_dist_data.index[min_beds_lines[point]],
linestyle="--",
linewidth=2,
color=next(colors),
label=f"{point*100:.0f}% probability",
)
plt.legend(loc="upper right", fontsize=14)

plt.legend(loc="upper right")
if plot_bed_base:
for point in plot_bed_base:
plt.axvline(
x=plot_bed_base[point], linewidth=2, color="red", label="bed balance"
)
plt.legend(loc="upper right", fontsize=14)

if text_size:
plt.tick_params(axis="both", which="major", labelsize=text_size)

if include_titles:
plt.title(title, fontsize=text_size)
plt.xlabel("Number of beds")
plt.ylabel("Probability")
plt.tick_params(axis="both", which="major", labelsize=14)
plt.xlabel("Number of beds", fontsize=text_size)
if include_titles:
plt.title(title, fontsize=text_size)
plt.ylabel("Probability", fontsize=text_size)

plt.tight_layout()

if directory_path:
plt.savefig(directory_path / file_name.replace(" ", "_"), dpi=300)
plt.show()

0 comments on commit ad72c01

Please sign in to comment.