Multi time step prediction #95

Open
ElNino9495 opened this issue Sep 9, 2024 · 1 comment

Comments

@ElNino9495

I'm struggling to implement a multi-time-step model where I use my prediction outputs as the inputs for the next iteration. Can someone please help me with this?
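A minimal, framework-free sketch of the loop being asked about. It assumes a one-step model that takes the two most recent states plus a forcing value for the target time (GraphCast also conditions on two input time steps). `autoregressive_rollout`, `toy_step` and the scalar "states" are hypothetical stand-ins for the real model and xarray datasets.

```python
from typing import Callable, List, Sequence

# Hypothetical one-step model: maps the two most recent states (t-1, t) plus
# a forcing value for t+1 to the predicted state at t+1.
StepFn = Callable[[Sequence[float], float], float]


def autoregressive_rollout(
    step_fn: StepFn,
    initial_states: Sequence[float],
    forcings: Sequence[float],
) -> List[float]:
    """Roll a one-step model forward, feeding each prediction back as input.

    initial_states holds the two most recent observed states; forcings has
    one entry per step to predict, so len(forcings) sets the rollout length.
    """
    if len(initial_states) != 2:
        raise ValueError("expected exactly two initial states")
    window = list(initial_states)  # sliding window of model inputs
    predictions: List[float] = []
    for forcing in forcings:
        next_state = step_fn(window, forcing)
        predictions.append(next_state)
        # Drop the oldest input and append the new prediction, so the next
        # step sees (previous state, predicted state) as its inputs.
        window = [window[1], next_state]
    return predictions


def toy_step(window: Sequence[float], forcing: float) -> float:
    """Toy stand-in for a model: linear extrapolation plus a forcing term."""
    prev, curr = window
    return curr + (curr - prev) + forcing


print(autoregressive_rollout(toy_step, [0.0, 1.0], [0.0, 0.0, 0.5]))
```

In the code posted in the follow-up comment, `rollout.chunked_prediction` from the graphcast package appears to perform this loop already: after each chunk it builds the next inputs from the latest predictions and forcings, and the number of steps comes from the time axis of `targets_template`. The repository's example notebook also wraps the normalized predictor in `autoregressive.Predictor` so one call can produce several steps, which may be worth checking when `num_steps_per_chunk` is greater than 1.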

@ElNino9495
Author

import functools

import haiku as hk
import jax
import xarray
from graphcast import casting, checkpoint, graphcast, normalization, rollout

# Load model parameters and configuration
with open('/data4/home/rohitsuresh/graphcast_/model/params/params-GraphCast_operational-ERA5-HRES_1979-2021-resolution_0.25-pressure_levels_13-mesh_2to6-precipitation_output_only.npz', 'rb') as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
params = ckpt.params
model_config = ckpt.model_config
task_config = ckpt.task_config

# Load normalization statistics
with open('/data4/home/rohitsuresh/graphcast_/model/stats/stats-diffs_stddev_by_level.nc', 'rb') as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open('/data4/home/rohitsuresh/graphcast_/model/stats/stats-mean_by_level.nc', 'rb') as f:
    mean_by_level = xarray.load_dataset(f).compute()
with open('/data4/home/rohitsuresh/graphcast_/model/stats/stats-stddev_by_level.nc', 'rb') as f:
    stddev_by_level = xarray.load_dataset(f).compute()

def construct_graphcast(model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):
    # One-step GraphCast model, run in bfloat16, with inputs normalized
    # and outputs predicted as normalized residuals.
    predictor = graphcast.GraphCast(model_config, task_config)
    predictor = casting.Bfloat16Cast(predictor)
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )
    return predictor

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    predictor = construct_graphcast(model_config, task_config)
    return predictor(inputs, targets_template=targets_template, forcings=forcings)

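# Bind the fixed configs, params and empty state so the jitted function
# only needs the per-call arguments (rng, inputs, targets_template, forcings).
def with_configs(fn):
    return functools.partial(fn, model_config=model_config, task_config=task_config)

def with_params(fn):
    return functools.partial(fn, params=params, state={})

def drop_state(fn):
    # hk.transform_with_state's apply returns (output, state); keep only the output.
    return lambda **kw: fn(**kw)[0]

run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))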

class Predictor:
    @classmethod
    def predict(cls, inputs, targets, forcings) -> xarray.Dataset:
        predictions = rollout.chunked_prediction(
            predictor_fn=run_forward_jitted,
            rng=jax.random.PRNGKey(0),
            inputs=inputs,
            targets_template=targets,
            forcings=forcings,
            # Steps predicted per call to predictor_fn; the total rollout
            # length is set by the time axis of targets_template.
            num_steps_per_chunk=2,
        )
        return predictions

# Assuming inputs, targets, and forcings are already prepared
predictions = Predictor.predict(inputs, targets, forcings)
predictions.to_dataframe().to_csv('predictions2024_0.25_step3.csv', sep=',')

ValueError: 'grid2mesh_gnn/_networks_builder/encoder_nodes_grid_nodes_mlp//linear_0/w' with retrieved shape (184, 512) does not match shape=[189, 512] dtype=dtype(bfloat16)
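In Haiku's message, the "retrieved shape" is the one stored in the checkpoint, so the first encoder layer was trained on 184 input features per grid node while the inputs passed in produce 189, i.e. 5 extra features. One hedged way to look for the cause is to compare what is actually supplied against what `task_config` lists (for example `task_config.input_variables` against `list(inputs.data_vars)`, and `task_config.pressure_levels` against the `level` coordinate). In the repository's example notebook, inputs, targets and forcings are built with `data_utils.extract_inputs_targets_forcings(..., **dataclasses.asdict(task_config))`, which selects variables from the config. The helper `diff_sets` below is hypothetical, and the names and levels in the usage are illustrative only.

```python
from typing import Dict, Hashable, Iterable, Set


def diff_sets(expected: Iterable[Hashable], present: Iterable[Hashable]) -> Dict[str, Set[Hashable]]:
    """Report items the config expects but are absent, and extras it does not expect."""
    expected_set, present_set = set(expected), set(present)
    return {
        "missing": expected_set - present_set,  # required by the config, not supplied
        "unexpected": present_set - expected_set,  # supplied, but not in the config
    }


# Illustrative values only; in practice these could come from task_config
# and from the xarray datasets passed to the predictor.
expected_vars = ["2m_temperature", "temperature", "geopotential"]
supplied_vars = ["2m_temperature", "temperature", "geopotential", "hypothetical_extra_field"]
print(diff_sets(expected_vars, supplied_vars))

expected_levels = [50, 500, 1000]
supplied_levels = [50, 500, 850, 1000]
print(diff_sets(expected_levels, supplied_levels))
```

Any non-empty "unexpected" set changes the per-node feature count the encoder sees, which is the kind of difference that would produce this error.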
