Implement loading model from gcs
ovejabu committed Jun 28, 2024
1 parent 3fdc481 commit 463ebf5
Showing 2 changed files with 40 additions and 8 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/cd.yml
@@ -14,6 +14,7 @@ jobs:
env:
PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
SERVICE_ACCOUNT_KEY: ${{ secrets.GCP_SA_KEY }}
GCS_BUCKET: ${{ secrets.GCS_BUCKET }}
APP_NAME: 'flight-delay-api'
REGION: 'us-central1'
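Note: the new GCS_BUCKET secret must exist in the repository settings before this workflow can read it. One way to add it is the GitHub CLI (the bucket name below is a placeholder):

gh secret set GCS_BUCKET --body "your-bucket-name"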

@@ -29,7 +30,7 @@ jobs:
run: |
echo "PROJECT_ID=${{ env.PROJECT_ID }}"
echo "SERVICE_ACCOUNT_KEY=${{ env.SERVICE_ACCOUNT_KEY }}" | head -c 20
echo "GCS_BUCKET=${{ env.GCS_BUCKET }}"
- name: Set up Google Cloud SDK
uses: google-github-actions/setup-gcloud@v2
@@ -59,6 +60,7 @@ jobs:
--region ${{ env.REGION }} \
--allow-unauthenticated \
--set-env-vars GCS_BUCKET=${{ secrets.GCS_BUCKET }} \
--format 'value(status.url)')
echo "url=$URL" >> $GITHUB_OUTPUT
echo "Deployed to $URL"
44 changes: 37 additions & 7 deletions challenge/model.py
@@ -9,20 +9,22 @@
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

from google.cloud import storage

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_PATH = "challenge/model.pkl"
LOCAL_MODEL_PATH = "challenge/models/model.pkl"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
THRESHOLD_IN_MINUTES = 15
GCS_MODEL_PATH = "model.pkl"
GCS_BUCKET = os.getenv("GCS_BUCKET")


class DelayModel:
def __init__(self):
self._model = self._load_model(MODEL_PATH)
self._model = self._load_model(LOCAL_MODEL_PATH)
self.top_10_features = [
"OPERA_Latin American Wings",
"MES_7",
@@ -49,11 +51,39 @@ def _load_model(self, path: str) -> Union[None, object]:
if os.path.exists(path):
try:
with open(path, "rb") as f:
return pickle.load(f)
model = pickle.load(f)
logger.info("Model loaded from local folder.")
return model
except Exception as e:
logger.error(f"Error loading model: {e}")
return None
logger.warning(f"Model file not found at {path}.")
if GCS_BUCKET:
return self._load_model_from_gcs(GCS_BUCKET)
logger.warning("Model file not found")

def _load_model_from_gcs(self, bucket_name: str) -> Union[None, object]:
"""
Load the model from a Google Cloud Storage bucket.
Args:
bucket_name (str): Name of the bucket.
Returns:
object: Loaded model or None if the file does not exist.
"""
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(GCS_MODEL_PATH)
if blob.exists():
try:
with blob.open("rb") as f:
model = pickle.load(f)
logger.info("Model loaded from GCS bucket.")
return model
except Exception as e:
logger.error(f"Error loading model from GCS: {e}")
return None
logger.warning("Model file not found in GCS bucket")
return None

def _get_min_diff(self, data: pd.Series) -> float:
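The GCS fallback above only succeeds if a trained model already sits at model.pkl in the bucket. A minimal upload sketch using the same google-cloud-storage client (upload_model is a hypothetical helper, not part of this commit):

from google.cloud import storage

def upload_model(bucket_name: str, local_path: str = "challenge/models/model.pkl") -> None:
    # Push the locally trained pickle to the bucket so _load_model_from_gcs can find it.
    client = storage.Client()
    blob = client.bucket(bucket_name).blob("model.pkl")  # key must match GCS_MODEL_PATH
    blob.upload_from_filename(local_path)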
@@ -165,9 +195,9 @@ def fit(self, features: pd.DataFrame, target: pd.DataFrame) -> None:
logger.info("Model performance: \n%s", report)
logger.info("Confusion matrix: \n%s", cm)
# Save the model
with open(MODEL_PATH, "wb") as f:
with open(LOCAL_MODEL_PATH, "wb") as f:
pickle.dump(self._model, f)
logger.info(f"Model saved at {MODEL_PATH}")
logger.info(f"Model saved at {LOCAL_MODEL_PATH}")
return

def predict(self, features: pd.DataFrame) -> List[int]:

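Putting the pieces together, a quick local check of the load order (the bucket name is hypothetical). Because GCS_BUCKET is read at import time, it has to be set before challenge.model is imported:

import os

os.environ["GCS_BUCKET"] = "my-model-bucket"  # hypothetical; read once at module import

from challenge.model import DelayModel

# Tries challenge/models/model.pkl first, then falls back to
# gs://my-model-bucket/model.pkl via _load_model_from_gcs.
model = DelayModel()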