diff --git a/pyproject.toml b/pyproject.toml index 26a69ea..04246bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "hydra-core==1.3.2", "loguru == 0.7.3", "numpy==2.1.2", - "ocf-data-sampler==0.5.5", + "ocf-data-sampler==0.5.27", "pandas==2.2.3", "s3fs==2025.7.0", "safetensors==0.5.2", diff --git a/src/cloudcasting_app/app.py b/src/cloudcasting_app/app.py index de77c95..f8816f5 100644 --- a/src/cloudcasting_app/app.py +++ b/src/cloudcasting_app/app.py @@ -19,7 +19,7 @@ from safetensors.torch import load_model from loguru import logger -from cloudcasting_app.data import get_input_data, prepare_satellite_data, sat_path +from cloudcasting_app.data import SatelliteDownloader, sat_path, get_input_data # Get package version try: @@ -28,8 +28,6 @@ __version__ = "v?" -xr.set_options(keep_attrs=True) - # --------------------------------------------------------------------------- # Model will use GPU if available @@ -62,7 +60,8 @@ def app(t0=None): # --------------------------------------------------------------------------- # 1. Prepare the input data logger.info("Downloading satellite data") - prepare_satellite_data(t0) + satellite_downloader = SatelliteDownloader() + satellite_downloader.prepare_satellite_data(t0) # --------------------------------------------------------------------------- # 2. Load model @@ -90,12 +89,8 @@ def app(t0=None): # 3. 
def get_satellite_timestamps(sat_zarr_path: str) -> pd.DatetimeIndex:
    """Get the datetimes of the satellite data

    Args:
        sat_zarr_path: The path to the satellite zarr

    Returns:
        pd.DatetimeIndex: All available satellite timestamps
    """
    # The time values are read out while the zip store is still open
    with zarr.storage.ZipStore(sat_zarr_path) as store:
        ds = xr.open_zarr(store)
        return pd.to_datetime(ds.time.values)


def crop_input_area(ds: xr.Dataset) -> xr.Dataset:
    """Crop the satellite data to the fixed-size input area

    The crop starts at the x/y coordinates which lon_lat_to_geostationary_area_coords
    returns for (lon_min, lat_min) and keeps the first x_size and y_size elements along
    x_geostationary and y_geostationary respectively.

    Args:
        ds: The satellite dataset. Its x_geostationary coords must be in decreasing order,
            its y_geostationary coords in ascending order, and `ds.data` must have an
            "area" attribute

    Returns:
        xr.Dataset: The cropped dataset with the x-axis in its original (decreasing) order

    Raises:
        AssertionError: If the coordinate ordering is not as expected, or if the cropped
            data does not have exactly x_size by y_size elements
    """
    x_min, y_min = lon_lat_to_geostationary_area_coords(lon_min, lat_min, ds.data.attrs["area"])

    x_coords = ds.x_geostationary.values
    y_coords = ds.y_geostationary.values

    # These checks are raised explicitly, rather than written as `assert` statements, so
    # that they still run when python is started with -O
    # x-axis is expected to be in decreasing order
    if not x_coords[0] > x_coords[1]:
        raise AssertionError("x_geostationary is expected to be in decreasing order")
    # y-axis is expected to be in ascending order
    if not y_coords[0] < y_coords[1]:
        raise AssertionError("y_geostationary is expected to be in ascending order")

    # Flip the x-axis to ascending order so it can be label-sliced from x_min upwards
    ds = ds.isel(x_geostationary=slice(None, None, -1))

    ds = (
        ds
        .sel(x_geostationary=slice(x_min, None), y_geostationary=slice(y_min, None))
        .isel(x_geostationary=slice(0, x_size), y_geostationary=slice(0, y_size))
    )

    cropped_size = (len(ds.x_geostationary), len(ds.y_geostationary))
    if cropped_size != (x_size, y_size):
        raise AssertionError(
            f"Cropped data has (x, y) size {cropped_size} but expected {(x_size, y_size)}"
        )

    return ds.isel(x_geostationary=slice(None, None, -1))  # flip back
def get_input_data(ds: xr.Dataset, t0: pd.Timestamp) -> torch.Tensor:
    """Get the input data required to run the model for init-time t0

    Args:
        ds: The satellite dataset to take the inputs from
        t0: The init-time

    Returns:
        torch.Tensor: The float32 values of `ds.data` at the 12 15-minutely timestamps up
            to and including t0. Timestamps missing from `ds` are filled with -1
    """
    # Slice the data
    required_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min")
    ds = ds.reindex(time=required_timestamps)

    # Convert to arrays
    X = ds.data.values.astype(np.float32)

    # Convert NaNs to -1
    X = np.nan_to_num(X, nan=-1)

    return torch.Tensor(X)


class SatelliteDownloader:
    """Downloads the satellite data and prepares it for the model

    After `combine_5_and_15_sat_data()` has run, `use_5_minute` records whether the
    5-minutely (True) or the 15-minutely (False) data was selected. It is None before.
    """

    def __init__(self):
        self.use_5_minute: bool | None = None

    def prepare_satellite_data(self, t0: pd.Timestamp) -> None:
        """Download the satellite data, preprocess it, and save it to `sat_path`

        Args:
            t0: The init-time which the satellite data is required for
        """

        # Download the 5 and/or 15 minutely satellite data
        self.download_all_sat_data()

        # Select between the 5/15 minute satellite data sources
        ds = self.combine_5_and_15_sat_data()

        # Check the required expected timestamps are available
        self.check_required_timestamps_available(ds, t0)

        # Crop the input area to expected
        ds = crop_input_area(ds)

        # Reorder channels
        ds = ds.sel(variable=channel_order)

        # Reshape to (channel, time, height, width)
        ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")

        # Resave
        ds.to_zarr(sat_path)

    def download_all_sat_data(self) -> None:
        """Download the sat data"""
        # Clean out old files
        logger.info("Cleaning out old satellite data")
        for loc in [sat_path, sat_5_path, sat_15_path]:
            # The downloaded .zarr.zip archives are single files, which shutil.rmtree
            # cannot remove, so directories and files are cleaned up separately
            if os.path.isdir(loc):
                shutil.rmtree(loc)
            elif os.path.exists(loc):
                os.remove(loc)

        sat_5_dl_path = os.getenv("SATELLITE_ZARR_PATH")
        sat_15_dl_path = os.getenv("SATELLITE_15_ZARR_PATH")

        for remote_path, local_path, label in [
            (sat_5_dl_path, sat_5_path, "5-min"),
            (sat_15_dl_path, sat_15_path, "15-min"),
        ]:
            if remote_path is not None:
                fs, _ = fsspec.core.url_to_fs(remote_path)
                if fs.exists(remote_path):
                    logger.info(f"Downloading {label} satellite data")
                    fs.get(remote_path, local_path)
                else:
                    logger.info(f"No {label} data available to download")

    def combine_5_and_15_sat_data(self) -> xr.Dataset:
        """Select between the 5 and 15-minutely satellite data and load the selected data

        Returns:
            xr.Dataset: The selected satellite data, loaded into memory

        Raises:
            FileNotFoundError: If neither the 5- nor the 15-minutely data was downloaded
        """
        # Check which satellite data exists
        exists_5_minute = os.path.exists(sat_5_path)
        exists_15_minute = os.path.exists(sat_15_path)

        if not (exists_5_minute or exists_15_minute):
            raise FileNotFoundError("Neither 5- nor 15-minutely data was found.")

        # Find the delay in the 5- and 15-minutely data
        if exists_5_minute:
            datetimes_5min = get_satellite_timestamps(sat_5_path)
            logger.info(
                f"Latest 5-minute timestamp is {datetimes_5min.max()}. "
                f"All the datetimes are: \n{datetimes_5min}",
            )

        if exists_15_minute:
            datetimes_15min = get_satellite_timestamps(sat_15_path)
            logger.info(
                f"Latest 15-minute timestamp is {datetimes_15min.max()}. "
                f"All the datetimes are: \n{datetimes_15min}",
            )

        # If both 5- and 15-minute data exists, use the most recent
        if exists_5_minute and exists_15_minute:
            use_5_minute = datetimes_5min.max() > datetimes_15min.max()
        else:
            # If only one exists, use that
            use_5_minute = exists_5_minute

        # Store the choice in satellite data
        self.use_5_minute = use_5_minute

        # Pick the path of the selected data
        if use_5_minute:
            logger.info("Using 5-minutely data.")
            selected_path = sat_5_path
        else:
            logger.info("Using 15-minutely data.")
            selected_path = sat_15_path

        # Open and return the satellite data
        with zarr.storage.ZipStore(selected_path) as store:
            ds = xr.open_zarr(store).compute()

        return ds

    @staticmethod
    def check_required_timestamps_available(ds: xr.Dataset, t0: pd.Timestamp) -> None:
        """Check that all timestamps required for init-time t0 are in the satellite data

        Args:
            ds: The satellite dataset
            t0: The init-time

        Raises:
            Exception: If any of the required timestamps are missing from `ds`
        """
        available_timestamps = pd.to_datetime(ds.time.values)

        # Need 12 timestamps of 15 minutely data up to and including time t0
        expected_timestamps = pd.date_range(t0-pd.Timedelta("165min"), t0, freq="15min")

        timestamps_available = np.isin(expected_timestamps, available_timestamps)

        if not timestamps_available.all():
            missing_timestamps = expected_timestamps[~timestamps_available]
            # Report the timestamps found in the data, not the boolean availability mask
            raise Exception(
                "Some required timestamps missing\n"
                f"Required timestamps: {expected_timestamps}\n"
                f"Available timestamps: {available_timestamps}\n"
                f"Missing timestamps: {missing_timestamps}",
            )
"y_geostationary"] + list(ds_y_hat.sat_pred.dims)== + ["init_time", "variable", "step", "y_geostationary", "x_geostationary"] ) # Make sure all the coords are correct