
Integrate prediction head into model's forward pass for loss_fn #6

@clementpoiret

Description

Ensure that `loss_fn` in `hackathon/objectdetection.py` uses the newly implemented prediction head: rebuild the model with `eqx.combine(params, static)`, then call `jax.vmap(model, ...)` over the batch to obtain the detection logits.

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests