Skip to content

Commit

Permalink
Merge pull request #9 from benyaming/feature/ws-tracking
Browse files Browse the repository at this point in the history
Feature - vehicle tracking
  • Loading branch information
benyaming authored Aug 12, 2024
2 parents 6c9e864 + dab3ccf commit 8a69871
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 8 deletions.
2 changes: 1 addition & 1 deletion israel_transport_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Env(BaseSettings):
SCHED_HOURS: int
SCHED_MINS: int

DB_BATCH_SIZE: int = 300_000
WS_UPDATE_INTERVAL: int = 5


env = Env()
25 changes: 22 additions & 3 deletions israel_transport_api/siri/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from israel_transport_api.misc import http_client
from israel_transport_api.siri.exceptions import SiriException
from israel_transport_api.siri.models import IncomingRoute, IncomingRoutesResponse
from israel_transport_api.siri.siri_models import MonitoredStopVisit

from israel_transport_api.siri.siri_models import MonitoredStopVisit, VehicleLocation

RETRY_COUNT = 5
logger = logging.getLogger('siri_client')
Expand Down Expand Up @@ -59,7 +58,27 @@ async def get_incoming_routes(
arrival_time = stop_visit.monitored_vehicle_journey.monitored_call.expected_arrival_time.replace(tzinfo=None)
eta = (arrival_time - dt.now()).seconds // 60
route = await routes_repository.find_route_by_id(int(stop_visit.monitored_vehicle_journey.line_ref), conn)
incoming_routes.append(IncomingRoute(eta=eta, route=route))
incoming_routes.append(
IncomingRoute(
eta=eta,
route=route,
plate_number=stop_visit.monitored_vehicle_journey.vehicle_ref
)
)

resp = IncomingRoutesResponse(stop_info=stop_info, incoming_routes=sorted(incoming_routes, key=lambda r: r.eta))
return resp


async def get_vehicle_location(vehicle_plate_number: str, stop_code: int) -> VehicleLocation:
siri_data = await _make_request(stop_code, 30)
vehicle = list(
filter(
lambda m: m.monitored_vehicle_journey.vehicle_ref == vehicle_plate_number,
siri_data
)
)
if len(vehicle) == 0:
raise ValueError # todo
current_location = vehicle[0].monitored_vehicle_journey.vehicle_location
return current_location
14 changes: 14 additions & 0 deletions israel_transport_api/siri/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel, Field

from israel_transport_api.gtfs.models import Stop, Route
from israel_transport_api.siri.siri_models import VehicleLocation


class IncomingRoutesResponse(BaseModel):
Expand All @@ -13,7 +14,20 @@ class IncomingRoutesResponse(BaseModel):

class IncomingRoute(BaseModel):
eta: int
plate_number: str
route: Route


class VehicleLocationResponse(BaseModel):
latitude: float
longitude: float

@classmethod
def from_siri_model(cls, siri_model: VehicleLocation):
return cls(
latitude=siri_model.latitude,
longitude=siri_model.longitude
)


IncomingRoutesResponse.update_forward_refs()
26 changes: 24 additions & 2 deletions israel_transport_api/siri/router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from fastapi import APIRouter, Path, Query, Request
from asyncio import sleep

from israel_transport_api.siri.client import get_incoming_routes
from fastapi import APIRouter, Path, Query, Request, WebSocket, WebSocketDisconnect

from israel_transport_api.siri.client import get_incoming_routes, get_vehicle_location
from israel_transport_api.siri.models import IncomingRoutesResponse
from israel_transport_api.config import env


siri_router = APIRouter(prefix='/siri', tags=['Siri'])

Expand All @@ -13,3 +17,21 @@ async def get_routes_for_stop(
monitoring_interval: int = Query(30, description='Monitoring interval in minutes')
) -> IncomingRoutesResponse:
return await get_incoming_routes(request.app.state.conn, stop_code, monitoring_interval)


@siri_router.websocket('/track_vehicle/{stop_code}/{vehicle_plate_number}')
async def track_vehicle(
ws: WebSocket,
stop_code: int = Path(..., description='Stop code for tracking'),
vehicle_plate_number: str = Path(..., description='Vehicle plate number')
):
await ws.accept()
try:
previous_resp = None
while True:
resp = await get_vehicle_location(vehicle_plate_number, stop_code)
resp and previous_resp != resp and await ws.send_json(resp.model_dump())
previous_resp = resp
await sleep(env.WS_UPDATE_INTERVAL)
except WebSocketDisconnect:
pass
17 changes: 16 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = [
{name = "Benyamin Ginzburg", email = "[email protected]"},
]
dependencies = [
"fastapi==0.111.1",
"fastapi[websockets]==0.111.1",
"pydantic==2.8.2",
"pydantic-settings==2.4.0",
"uvicorn==0.30.4",
Expand Down

0 comments on commit 8a69871

Please sign in to comment.