Hello!
I have been playing around with the Allen-Cahn example, which trained really well without any data but just the res_loss and ics_loss terms. I wanted to experiment adding a data loss term to the model, which I expect would help with the training as we give the model even more information. However, it actually led to very poor performance. Here is the visualisation of the result:
and the training loss history:
The data loss wouldn't go down, and the res_loss and ics_loss also stayed at significantly larger values than before. cas_weight did not converge to 1 in this case either. I tried turning causal training and NTK weighting on and off, which did not improve performance. Using only the data loss, with res_loss and ics_loss turned off, also gave similarly poor performance.
I was wondering if you have had any experience with this? Any advice or suggestions would be much appreciated. Thank you in advance and I look forward to hearing from you.
Best wishes,
Annie
For your reference, the following are snippets of code relevant to adding the additional data term to the loss function.
In losses() I added:
```python
def losses(self, params, batch):
    # Unpack batch
    data_coords_batch, data_batch = batch["data"]
    res_batch = batch["res"]
    ...

    # Data loss: MSE between the measurements and the network prediction
    # at the measurement coordinates (t = coords[:, 0], x = coords[:, 1])
    u_pred_data = vmap(self.u_net, (None, 0, 0))(
        params, data_coords_batch[:, 0], data_coords_batch[:, 1]
    )
    data_loss = jnp.mean((data_batch - u_pred_data) ** 2)

    loss_dict = {"ics": ics_loss, "res": res_loss, "data": data_loss}
    return loss_dict
```
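For context, here is a minimal sketch of how I assemble the batch dict that losses() unpacks; the sampler names (ics_sampler, res_sampler), the data arrays (coords, u_ref), and the train-step call are illustrative stand-ins rather than the repo's exact code:

```python
# Minimal sketch: wiring the new "data" key into the batch dict.
# ics_sampler / res_sampler stand in for the example's existing samplers;
# DataSampler is the class shown further below.
data_sampler = iter(DataSampler(coords, u_ref, test_size=0.2, batch_size=1024))

for step in range(num_steps):
    batch = {
        "ics": next(ics_sampler),
        "res": next(res_sampler),
        "data": next(data_sampler),  # (coords_batch, u_batch) tuple
    }
    state = train_step(state, batch)  # hypothetical train-step call
```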
In compute_diag_ntk() I added:
```python
def compute_diag_ntk(self, params, batch):
    # Unpack batch
    data_coords_batch, data_batch = batch["data"]
    res_batch = batch["res"]
    ...

    # Diagonal NTK of u_net evaluated at the data coordinates
    data_ntk = vmap(ntk_fn, (None, None, 0, 0))(
        self.u_net, params, data_coords_batch[:, 0], data_coords_batch[:, 1]
    )

    ntk_dict = {"ics": ics_ntk, "res": res_ntk, "data": data_ntk}
    return ntk_dict
```
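In case the weighting itself matters here, my understanding of the NTK balancing rule (following Wang et al.) is sketched below: each term's weight is the total mean diagonal NTK divided by that term's own mean, so a term whose kernel is small gets boosted. This is a paraphrase for illustration, not the repo's exact compute_weights implementation:

```python
import jax.numpy as jnp

def ntk_weights(ntk_dict):
    # Mean diagonal NTK per loss term, e.g. keys "ics", "res", "data"
    mean_ntk = {k: jnp.mean(v) for k, v in ntk_dict.items()}
    total = sum(mean_ntk.values())
    # A term with a small mean kernel receives a large weight
    return {k: total / v for k, v in mean_ntk.items()}
```

The total loss would then be the weighted sum, sum(w[k] * loss_dict[k]) over the three terms.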
where data_coords_batch is generated with:
```python
from functools import partial

import jax.numpy as jnp
from jax import random, pmap
from sklearn.model_selection import train_test_split

# BaseSampler is the base sampler class from this repo


class DataSampler(BaseSampler):
    def __init__(self, coords, u, test_size, batch_size, train=True, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.u = u.flatten()
        self.coords = coords

        # Perform train-test split
        coords_train, coords_test, u_train, _ = train_test_split(
            self.coords, self.u, test_size=test_size, random_state=42
        )
        self.u_train = jnp.array(u_train)  # Convert to jnp arrays so they can be indexed
        # Use either the training or the testing coordinates
        self.selected_coords = jnp.array(coords_train if train else coords_test)

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = random.choice(key, self.selected_coords.shape[0], shape=(self.batch_size,))
        coords_batch = self.selected_coords[idx, :]
        u_batch = self.u_train[idx]  # u is flattened, so index with a 1D idx (was u_train[idx, :])
        batch = (coords_batch, u_batch)  # fixed: was data_coords_batch, which is undefined here
        return batch
```
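One detail worth noting: since data_generation is wrapped in pmap, the key needs a leading device axis, and the returned batch carries a per-device leading dimension. A quick smoke test (the coords and u_ref arrays are illustrative):

```python
import jax
from jax import random

sampler = DataSampler(coords, u_ref, test_size=0.2, batch_size=128)
keys = random.split(random.PRNGKey(0), jax.local_device_count())
coords_batch, u_batch = sampler.data_generation(keys)
# coords_batch: (n_devices, 128, 2), u_batch: (n_devices, 128)
```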
