Description
An extremely common task in climate science is to calculate a climatology (a monthly or seasonal average over many years) of some statistic from a long spatio-temporal dataset. This is exactly what I am trying to do in my pangeo use case. To make this concrete, here is an example reproducible on pangeo.pydata.org.
import xarray as xr
import gcsfs
# open the 5-day-average CM2.6 dataset directly from Google Cloud Storage
data_path = 'pangeo-data/cm2.6/control/temp_salt_u_v-5day_avg/'
ds = xr.open_zarr(gcsfs.GCSMap(data_path))
The dataset looks like this:
<xarray.Dataset>
Dimensions: (nv: 2, st_edges_ocean: 51, st_ocean: 50, time: 1460, xt_ocean: 3600, xu_ocean: 3600, yt_ocean: 2700, yu_ocean: 2700)
Coordinates:
* nv (nv) float64 1.0 2.0
* st_edges_ocean (st_edges_ocean) float64 0.0 10.07 20.16 30.29 40.47 ...
* st_ocean (st_ocean) float64 5.034 15.1 25.22 35.36 45.58 55.85 ...
* time (time) object 0181-01-03 12:00:00 0181-01-08 12:00:00 ...
* xt_ocean (xt_ocean) float64 -279.9 -279.8 -279.7 -279.6 -279.5 ...
* xu_ocean (xu_ocean) float64 -279.9 -279.8 -279.7 -279.6 -279.5 ...
* yt_ocean (yt_ocean) float64 -81.11 -81.07 -81.02 -80.98 -80.94 ...
* yu_ocean (yu_ocean) float64 -81.09 -81.05 -81.0 -80.96 -80.92 ...
Data variables:
salt (time, st_ocean, yt_ocean, xt_ocean) float32 dask.array<shape=(1460, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
temp (time, st_ocean, yt_ocean, xt_ocean) float32 dask.array<shape=(1460, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
u (time, st_ocean, yu_ocean, xu_ocean) float32 dask.array<shape=(1460, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
v (time, st_ocean, yu_ocean, xu_ocean) float32 dask.array<shape=(1460, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
Its size is close to 12 TB uncompressed.
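The chunk layout matters for everything that follows: each variable is stored as one chunk per (time, level) pair. A quick sanity check using standard dask array attributes:
# each chunk has shape (1, 1, 2700, 3600) in float32, i.e. roughly 39 MB
print(ds.temp.data.chunksize)      # (1, 1, 2700, 3600)
print(ds.temp.data.npartitions)    # 1460 * 50 = 73000 chunks per variable
print(ds.temp.nbytes / 1e12)       # ~2.8 TB per variable, ~11.3 TB for all four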
Calculating the climatology is trivial:
ds_mm_clim = ds.groupby('time.month').mean(dim='time')
giving:
<xarray.Dataset>
Dimensions: (month: 12, nv: 2, st_edges_ocean: 51, st_ocean: 50, xt_ocean: 3600, xu_ocean: 3600, yt_ocean: 2700, yu_ocean: 2700)
Coordinates:
* nv (nv) float64 1.0 2.0
* st_edges_ocean (st_edges_ocean) float64 0.0 10.07 20.16 30.29 40.47 ...
* st_ocean (st_ocean) float64 5.034 15.1 25.22 35.36 45.58 55.85 ...
* xt_ocean (xt_ocean) float64 -279.9 -279.8 -279.7 -279.6 -279.5 ...
* xu_ocean (xu_ocean) float64 -279.9 -279.8 -279.7 -279.6 -279.5 ...
* yt_ocean (yt_ocean) float64 -81.11 -81.07 -81.02 -80.98 -80.94 ...
* yu_ocean (yu_ocean) float64 -81.09 -81.05 -81.0 -80.96 -80.92 ...
* month (month) int64 1 2 3 4 5 6 7 8 9 10 11 12
Data variables:
salt (month, st_ocean, yt_ocean, xt_ocean) float32 dask.array<shape=(12, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
temp (month, st_ocean, yt_ocean, xt_ocean) float32 dask.array<shape=(12, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
u (month, st_ocean, yu_ocean, xu_ocean) float32 dask.array<shape=(12, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
v (month, st_ocean, yu_ocean, xu_ocean) float32 dask.array<shape=(12, 50, 2700, 3600), chunksize=(1, 1, 2700, 3600)>
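Even though the expression is a one-liner, the task graph behind each output variable is very large. A rough way to see that (a sketch using the standard dask collection interface):
# at least one task per input chunk (1460 * 50 = 73000 for salt alone),
# plus all of the groupby/mean reduction tasks on top of them
graph = ds_mm_clim.salt.data.__dask_graph__()
print(len(graph))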
Now I want to either persist this or, even better, save it as a new dataset:
# hack to get a gcsfs filesystem with write permissions: authenticate in the
# browser, then reuse the resulting credentials for a new filesystem object
fs_w_permissions = gcsfs.GCSFileSystem(token='browser')
token = fs_w_permissions.session.credentials
fs_w_token = gcsfs.GCSFileSystem(token=token)
output_path = data_path[:-1] + '-monthly-climatology'
gcsmap_output = gcsfs.mapping.GCSMap(output_path, gcs=fs_w_token)
# build the zarr store lazily, then run the whole write on the cluster
save_future = ds_mm_clim.to_zarr(gcsmap_output, compute=False)
save_future.compute(retries=10)
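For completeness, the "persist" option mentioned above is much simpler, and the climatology itself is small enough to hold in cluster memory (12 × 50 × 2700 × 3600 float32 values per variable is about 23 GB, so roughly 93 GB for all four variables). A sketch:
# keep the reduced dataset in distributed memory instead of writing it out;
# the 12 TB of inputs still have to be streamed through the workers to compute it
ds_mm_clim = ds_mm_clim.persist()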
In my head, the save-to-zarr computation should be a pretty "streamable" operation: load and aggregate all values of the variable `salt` for `month==1` (January) and `st_ocean==0` (the first vertical level); store the result; and move on to the next variable / month / level.
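For illustration, a hand-rolled version of that streaming pattern might look something like the sketch below. This is hypothetical and untested at this scale; it assumes an xarray version whose `to_zarr` supports the `region` argument, and that the target store has already been initialized by the `compute=False` call above.
for i, (month, ds_month) in enumerate(ds.groupby('time.month')):
    # reduce one month of 5-day averages down to a single climatological field
    clim = ds_month.mean(dim='time').expand_dims(month=[month])
    # region writes only accept variables that carry the region dimension,
    # so drop coordinates (st_ocean, xt_ocean, ...) that were already written
    clim = clim.drop_vars([v for v in clim.coords if 'month' not in clim[v].dims])
    # write just this month's slice into the pre-initialized zarr store
    clim.to_zarr(gcsmap_output, region={'month': slice(i, i + 1)})
Ideally, of course, the dask scheduler would achieve the same effect automatically for the one-line groupby version.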
However, the full groupby / to_zarr computation does not run very well on the current stack: the rate at which data is read outpaces the rate at which it is written, leading to huge memory consumption. This is evident in the dashboard screenshot below:
Here I have a cluster of 80 high-memory workers with 22 GB of RAM each, 1.76 TB in total. Yet the cluster reports nearly 7 TB in memory, and the workers are spilling lots of data to disk (I didn't even realize the workers had significant hard-drive space). This seems inefficient, although perhaps there is some logic to it that I don't grasp.
This has been discussed in numerous previous issues.
- xarray groupby monthly mean fail case #99
- groupby on dask objects doesn't handle chunks well pydata/xarray#1832
- Scheduler fail case: centering data with dask.array dask/dask#874
In many ways, this is a duplicate of those issues. However, those issues are also entangled with worker / scheduler configuration problems that we have mostly overcome on pangeo.pydata.org. Here the workers are not dying; they are just operating in what appears to be a sub-optimal way.
Perhaps now is the time to tackle how this operation is scheduled at the dask level? Or perhaps there is not actually a problem at all? The calculation is slowly ticking forward in an evidently stable way.