Skip to content

Commit

Permalink
Added inline docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
FredDeCeuster committed Aug 12, 2024
1 parent 1367747 commit c7d0962
Show file tree
Hide file tree
Showing 9 changed files with 1,218 additions and 133 deletions.
93 changes: 61 additions & 32 deletions docs/src/examples/3D_stellar_wind.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,38 +53,38 @@
"name": "stdout",
"output_type": "stream",
"text": [
"--2024-08-09 14:55:49-- https://raw.githubusercontent.com/Ensor-code/phantom-models/main/Malfait%2B2024a/v10e00/wind.setup\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"--2024-08-12 08:55:22-- https://raw.githubusercontent.com/Ensor-code/phantom-models/main/Malfait%2B2024a/v10e00/wind.setup\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1209 (1.2K) [text/plain]\n",
"Saving to: ‘3D_stellar_wind_data/wind.setup’\n",
"\n",
"3D_stellar_wind_dat 100%[===================>] 1.18K --.-KB/s in 0s \n",
"\n",
"2024-08-09 14:55:49 (37.8 MB/s) - ‘3D_stellar_wind_data/wind.setup’ saved [1209/1209]\n",
"2024-08-12 08:55:22 (38.8 MB/s) - ‘3D_stellar_wind_data/wind.setup’ saved [1209/1209]\n",
"\n",
"--2024-08-09 14:55:50-- https://raw.githubusercontent.com/Ensor-code/phantom-models/main/Malfait%2B2024a/v10e00/wind.in\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"--2024-08-12 08:55:23-- https://raw.githubusercontent.com/Ensor-code/phantom-models/main/Malfait%2B2024a/v10e00/wind.in\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 5579 (5.4K) [text/plain]\n",
"Saving to: ‘3D_stellar_wind_data/wind.in’\n",
"\n",
"3D_stellar_wind_dat 100%[===================>] 5.45K --.-KB/s in 0s \n",
"\n",
"2024-08-09 14:55:50 (32.0 MB/s) - ‘3D_stellar_wind_data/wind.in’ saved [5579/5579]\n",
"2024-08-12 08:55:23 (29.3 MB/s) - ‘3D_stellar_wind_data/wind.in’ saved [5579/5579]\n",
"\n",
"--2024-08-09 14:55:51-- https://raw.githubusercontent.com/Ensor-code/phantom-models/main/Malfait%2B2024a/v10e00/wind_v10e00\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"--2024-08-12 08:55:24-- https://raw.githubusercontent.com/Ensor-code/phantom-models/main/Malfait%2B2024a/v10e00/wind_v10e00\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 70236960 (67M) [application/octet-stream]\n",
"Saving to: ‘3D_stellar_wind_data/wind.dump’\n",
"\n",
"3D_stellar_wind_dat 100%[===================>] 66.98M 46.4MB/s in 1.4s \n",
"3D_stellar_wind_dat 100%[===================>] 66.98M 44.0MB/s in 1.5s \n",
"\n",
"2024-08-09 14:55:59 (46.4 MB/s) - ‘3D_stellar_wind_data/wind.dump’ saved [70236960/70236960]\n",
"2024-08-12 08:55:29 (44.0 MB/s) - ‘3D_stellar_wind_data/wind.dump’ saved [70236960/70236960]\n",
"\n"
]
}
Expand Down Expand Up @@ -383,7 +383,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "489b9ce1c0fa4c33bbe84c1f47738d3f",
"model_id": "b9433978f3d9491cbbe91a5cb283bd1b",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -397,7 +397,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77fddf4ee9e449b19c0edf2e0eb779b5",
"model_id": "676803bf63d3470f94c15ceae7ba99e0",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -411,7 +411,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8ee3c398df7242088564bd196141b293",
"model_id": "012175d314f34c60ac39273bdc9c8d61",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -425,7 +425,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e4bc431f27e43a999bcb05270231e77",
"model_id": "1365ec3f89a0437fa5336139cd38dff9",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -470,7 +470,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -494,7 +494,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -521,7 +521,10 @@
"\n",
"# Fix all parameters, except for the ones we want to fit\n",
"model.fix_all()\n",
"model.free(['log_H2', 'log_velocity_r', 'log_temperature'])"
"model.free(['log_H2', 'log_velocity_r', 'log_temperature'])\n",
"\n",
"# Save the initial guess\n",
"model.save('3D_stellar_wind_recon_init.h5')"
]
},
{
Expand All @@ -533,7 +536,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -570,7 +573,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -598,7 +601,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -634,7 +637,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -665,7 +668,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -696,7 +699,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -734,18 +737,24 @@
"## Experiments"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can fit the model to the synthetic data.\n",
"First, to determine the relative weights for each loss function in the total loss, we run 3 iterations and observe the loss values. Then, we define the weight of each individual loss function by the inverse of its current value (renormalisation step), such that in the following iterations they all contribute equally."
]
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/3 [00:00<?, ?it/s]/home/frederikd/.local/lib/python3.9/site-packages/torch/autograd/__init__.py:200: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 9010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n",
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n",
"100%|██████████| 3/3 [00:36<00:00, 12.09s/it]\n"
"100%|██████████| 3/3 [00:42<00:00, 14.11s/it]\n"
]
}
],
Expand All @@ -758,8 +767,19 @@
" w_smt = 1.0e+0,\n",
" w_cnt = 1.0e+0,\n",
")\n",
"losses.renormalise_all()\n",
"losses.reset()"
"losses.renormalise_all() # Renormalise the losses\n",
"losses.reset() # Reset the losses to renormalised values\n",
"\n",
"# Save the (partially) reconstructed model\n",
"gmodel_recon.model.save('3D_stellar_wind_recon_000.h5')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the renormalised weights in place, we now run 100 iterations of the reconstruction algorithm.\n",
"Note that one can still specify relative weights for the losses, these are an additional factor to the renormalisation. Here, we take them all equal to one, such that all losses contribute equally."
]
},
{
Expand Down Expand Up @@ -796,9 +816,17 @@
" w_smt = 1.0e+0,\n",
" w_cnt = 1.0e+0,\n",
")\n",
"gmodel_recon.model.save('3D_stellar_wind_recon_100.h5')\n",
"losses.plot()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we run an additional 500 itereations of the reconstruction algortihm."
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -831,6 +859,7 @@
" w_smt = 1.0e+0,\n",
" w_cnt = 1.0e+0,\n",
")\n",
"gmodel_recon.model.save('3D_stellar_wind_recon_600.h5')\n",
"losses.plot()"
]
}
Expand Down
Loading

0 comments on commit c7d0962

Please sign in to comment.