Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"hydra-core==1.3.2",
"loguru == 0.7.3",
"numpy==2.1.2",
"ocf-data-sampler==0.5.5",
"ocf-data-sampler==0.5.27",
"pandas==2.2.3",
"s3fs==2025.7.0",
"safetensors==0.5.2",
Expand Down
21 changes: 10 additions & 11 deletions src/cloudcasting_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from safetensors.torch import load_model
from loguru import logger

from cloudcasting_app.data import get_input_data, prepare_satellite_data, sat_path
from cloudcasting_app.data import SatelliteDownloader, sat_path, get_input_data

# Get package version
try:
Expand All @@ -28,8 +28,6 @@
__version__ = "v?"


xr.set_options(keep_attrs=True)

# ---------------------------------------------------------------------------

# Model will use GPU if available
Expand Down Expand Up @@ -62,7 +60,8 @@ def app(t0=None):
# ---------------------------------------------------------------------------
# 1. Prepare the input data
logger.info("Downloading satellite data")
prepare_satellite_data(t0)
satellite_downloader = SatelliteDownloader()
satellite_downloader.prepare_satellite_data(t0)

# ---------------------------------------------------------------------------
# 2. Load model
Expand Down Expand Up @@ -90,12 +89,8 @@ def app(t0=None):
# 3. Get inference inputs
logger.info("Preparing inputs")

# TODO check the spatial dimensions of this zarr
# Get inputs
ds = xr.open_zarr(sat_path)

# Reshape to (channel, time, height, width)
ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
ds = xr.open_zarr(sat_path).compute()

X = get_input_data(ds, t0)

Expand Down Expand Up @@ -130,8 +125,12 @@ def app(t0=None):
# Save predictions to the latest path and to path with timestring
out_dir = os.environ["OUTPUT_PREDICTION_DIRECTORY"]

latest_zarr_path = f"{out_dir}/latest.zarr"
t0_string_zarr_path = t0.strftime(f"{out_dir}/%Y-%m-%dT%H:%M.zarr")
if satellite_downloader.use_5_minute:
latest_zarr_path = f"{out_dir}/latest.zarr"
t0_string_zarr_path = t0.strftime(f"{out_dir}/%Y-%m-%dT%H:%M.zarr")
else:
latest_zarr_path = f"{out_dir}/latest_0-deg.zarr"
t0_string_zarr_path = t0.strftime(f"{out_dir}/%Y-%m-%dT%H:%M_0-deg.zarr")

fs = fsspec.open(out_dir).fs
for path in [latest_zarr_path, t0_string_zarr_path]:
Expand Down
Loading
Loading