Adding data loss term led to very poor performance #28

@annien094

Description

Hello!

I have been playing around with the Allen-Cahn example, which trained really well with no data at all, just the res_loss and ics_loss terms. I wanted to experiment with adding a data loss term to the model, which I expected would help training, since it gives the model even more information. However, it actually led to very poor performance. Here is a visualisation of the result:

[Image: visualisation of the predicted solution]

and the training loss history:

[Images: training loss history]

The data loss would not go down, and res_loss and ics_loss also stayed at significantly larger values than before. cas_weight also did not converge to 1 in this case. I tried turning causal training and NTK weighting on and off, which did not improve performance. Using only the data loss, with res_loss and ics_loss turned off, gave similarly poor performance.

I was wondering if you have had any experience with this? Any advice or suggestions would be much appreciated. Thank you in advance and I look forward to hearing from you.

Best wishes,
Annie
.
.
.
For your reference, the following are the code snippets relevant to adding the data term to the loss function.
In losses() I added:

    def losses(self, params, batch):
        # Unpack batch
        data_coords_batch, data_batch = batch["data"]
        res_batch = batch["res"]

        ...

        # Data loss: evaluate the network at the data coordinates
        # (column 0 is t, column 1 is x) and compare to the reference values
        u_pred_data = vmap(self.u_net, (None, 0, 0))(
            params, data_coords_batch[:, 0], data_coords_batch[:, 1]
        )
        data_loss = jnp.mean((data_batch - u_pred_data) ** 2)

        loss_dict = {"ics": ics_loss, "res": res_loss, "data": data_loss}
        return loss_dict
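
In case it is relevant: my understanding is that the trainer combines loss_dict with a weights dict keyed the same way, so the "data" key has to be added there as well. A minimal sketch of that aggregation under this assumption (total_loss and the example values are illustrative, not the library code):

import jax.numpy as jnp
from jax import tree_util

def total_loss(loss_dict, weights):
    # Weight each term, then sum over all terms. The "data" key must exist
    # in both dicts, or the tree structures will not match.
    weighted = tree_util.tree_map(lambda w, l: w * l, weights, loss_dict)
    return tree_util.tree_reduce(lambda a, b: a + b, weighted)

loss_dict = {"ics": jnp.array(0.3), "res": jnp.array(1.2), "data": jnp.array(0.8)}
weights = {"ics": jnp.array(100.0), "res": jnp.array(1.0), "data": jnp.array(1.0)}
print(total_loss(loss_dict, weights))  # 100.0 * 0.3 + 1.2 + 0.8 = 32.0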

In compute_diag_ntk() I added:

    def compute_diag_ntk(self, params, batch):
        # Unpack batch
        data_coords_batch, data_batch = batch["data"]
        res_batch = batch["res"]

        ...
        # Compute data_ntk
        data_ntk = vmap(ntk_fn, (None, None, 0, 0))(
            self.u_net, params, data_coords_batch[:, 0], data_coords_batch[:, 1]
        )

        ntk_dict = {"ics": ics_ntk, "res": res_ntk, "data": data_ntk}

        return ntk_dict
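
For reference, my understanding is that each entry of data_ntk is the NTK diagonal for one data point, i.e. the squared norm of the gradient of the scalar network output with respect to the parameters. A minimal sketch of such a per-sample function under that assumption (an illustrative re-implementation, not necessarily the library's ntk_fn):

import jax
import jax.numpy as jnp

def ntk_diag_entry(apply_fn, params, t, x):
    # Gradient of the scalar network output with respect to all parameters
    g = jax.grad(lambda p: apply_fn(p, t, x))(params)
    # The NTK diagonal entry K(z, z) is the squared norm of this gradient,
    # summed over every parameter leaf
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(g))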

where the data batch (data_coords_batch, data_batch) is generated with:

from functools import partial

import jax.numpy as jnp
from jax import pmap, random
from jaxpi.samplers import BaseSampler
from sklearn.model_selection import train_test_split

class DataSampler(BaseSampler):
    def __init__(self, coords, u, test_size, batch_size, train=True, rng_key=random.PRNGKey(1234)):
        super().__init__(batch_size, rng_key)
        self.u = u.flatten()
        self.coords = coords

        # Perform a train-test split, keeping the labels of both splits so
        # they stay aligned with the selected coordinates
        coords_train, coords_test, u_train, u_test = train_test_split(
            self.coords, self.u, test_size=test_size, random_state=42)
        # Use either the training or the testing split
        self.selected_coords = coords_train if train else coords_test
        # Convert to a jnp array so that it can be indexed
        self.selected_u = jnp.array(u_train if train else u_test)

    @partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation(self, key):
        "Generates data containing batch_size samples"
        idx = random.choice(
            key, self.selected_coords.shape[0], shape=(self.batch_size,))
        coords_batch = self.selected_coords[idx, :]
        u_batch = self.selected_u[idx]  # u is 1-D after flatten(), so a single index
        batch = (coords_batch, u_batch)
        return batch
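
For completeness, the batch dict consumed by losses() is assembled roughly like this (a sketch: coords, u_ref, model, and res_sampler stand in for the actual variables in my script, and I follow the iterator convention used by the other samplers in the examples):

# coords is an (N, 2) array of (t, x) points and u_ref the matching
# reference solution on those points
data_sampler = iter(DataSampler(coords, u_ref, test_size=0.2, batch_size=128))

batch = {
    "data": next(data_sampler),  # (coords_batch, u_batch)
    "res": next(res_sampler),    # the existing residual sampler
}
loss_dict = model.losses(params, batch)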
