Added train by epoch for Trainer and added support for texts #12

Open · wants to merge 122 commits into base: main

Conversation

MarcusLoppe
Contributor

Trainer

  • Added a train function that trains by epochs instead of steps (a rough sketch of the loop is shown after this list)
    - Added an option to display a loss graph (maybe remove, since it requires matplotlib?)
  • Added checkpoint_every_epoch to init so checkpoints can be saved every X epochs.
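For reference, a hedged sketch of what the epoch-based loop looks like, with a toy model, toy data and a generic checkpoint path (the real Trainer uses accelerate, the repo's dataloader and its own checkpoint naming; only checkpoint_every_epoch and the display-graph option come from the PR description):

import torch

model = torch.nn.Linear(3, 1)                         # toy stand-in for the autoencoder
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
dataloader = [torch.randn(4, 3) for _ in range(10)]   # toy batches

num_epochs = 5
checkpoint_every_epoch = 2                            # save every X epochs
display_graph = False                                 # needs matplotlib if enabled
epoch_losses = []

for epoch in range(num_epochs):
    total = 0.
    for batch in dataloader:                          # one full pass over the data per epoch
        loss = model(batch).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total / len(dataloader))

    if (epoch + 1) % checkpoint_every_epoch == 0:
        torch.save(model.state_dict(), f'checkpoint.epoch_{epoch + 1}.pt')

if display_graph:
    import matplotlib.pyplot as plt                   # optional dependency
    plt.plot(epoch_losses)
    plt.xlabel('epoch')
    plt.ylabel('average loss')
    plt.show()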

Data.py

  • Modified custom_collate so it won't pad if the data isn't a tensor (see the sketch after this list)
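A sketch of the intent of that change (assumed shape, not necessarily the repo's exact code): tensor fields are still padded to the longest item in the batch, while anything else, such as a list of text strings, is passed through untouched.

import torch
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch, pad_id = -1):
    is_dict = isinstance(batch[0], dict)
    if is_dict:
        keys = batch[0].keys()
        batch = [[d[k] for k in keys] for d in batch]

    output = []
    for datum in zip(*batch):
        if torch.is_tensor(datum[0]):
            # variable-length tensors are padded to the longest in the batch
            datum = pad_sequence(datum, batch_first = True, padding_value = pad_id)
        else:
            # non-tensor data (e.g. texts) is simply collected into a list
            datum = list(datum)
        output.append(datum)

    output = tuple(output)
    if is_dict:
        output = dict(zip(keys, output))
    return output

With this, an item like {'vertices': ..., 'faces': ..., 'texts': 'a chair'} batches cleanly, since the collate never tries to pad the string.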

MeshAutoencoder

  • It was crashing if texts was in the args; added a dummy parameter to prevent this.

Setup.py

  • The setup file was missing a comma

@lucidrains
Owner

this is awesome Marcus! will take a look at it tomorrow morning and get it merged!

@adeerAI

adeerAI commented Dec 14, 2023

Awesome work, thanks for contributing your suggestions to main, this makes things easier to understand on the user's side.

@@ -741,6 +741,7 @@ def forward(
vertices: TensorType['b', 'nv', 3, float],
faces: TensorType['b', 'nf', 3, int],
face_edges: Optional[TensorType['b', 'e', 2, int]] = None,
texts: Optional[List[str]] = None,
Owner


ah, so the text is actually only conditioned on at the transformer stage, through cross attention

basically the autoencoder is given the job of only compressing meshes

Contributor Author


Yes, I know :) But if you pass it a dict with texts, it will give an error since the arg doesn't exist.
So then you would need two dataset classes.

Either replace the model(**forward_args) call so it uses the parameters directly:
model(vertices = data["vertices"], faces = data["faces"])

Or just implement a dummy texts arg :) There is probably a better solution (both options are sketched below)
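A minimal sketch of the two options (ToyAutoencoder is a stand-in, not the repo's MeshAutoencoder):

from typing import List, Optional
import torch

class ToyAutoencoder(torch.nn.Module):
    # option 2: accept (and ignore) a dummy texts kwarg, so the same dataset
    # dict can be splatted into both the autoencoder and the transformer
    def forward(self, vertices, faces, face_edges = None, texts: Optional[List[str]] = None):
        return vertices.float().mean()                # placeholder "loss"

data = {
    'vertices': torch.randn(1, 8, 3),
    'faces': torch.randint(0, 8, (1, 6, 3)),
    'texts': ['a chair'],
}

model = ToyAutoencoder()

# option 1: pass only the parameters the autoencoder actually needs
loss = model(vertices = data['vertices'], faces = data['faces'])

# option 2: splat the whole dict; the dummy texts arg absorbs the extra key
loss = model(**data)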

Owner


oh got it! yea, i can take care of that within the trainer class (just scrub out the text and text_embed keys)

Contributor Author


I don't think that will work; I'm not 100% sure, since the dataloader passes the data along and maybe copies it(?).

But it won't work if you access it without copying it, because the dataset returns the data without copying/cloning: when you do del on a key, it removes it completely from the dataset.
So if you train the encoder and then want to train a transformer, you'll need to recreate the dataset since the texts key has been removed.
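A minimal sketch of the concern (toy dataset, not the repo's class): if __getitem__ returns the stored dict by reference, a del inside the trainer mutates the dataset itself, so filtering into a new dict is the safer move.

from torch.utils.data import Dataset

class ToyMeshDataset(Dataset):
    def __init__(self, entries):
        self.entries = entries                        # list of dicts: vertices, faces, texts

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        return self.entries[idx]                      # returned by reference, not copied

dataset = ToyMeshDataset([{'vertices': ..., 'faces': ..., 'texts': 'a chair'}])

item = dataset[0]
del item['texts']                                     # also deletes 'texts' from dataset.entries[0]

# safer: build a filtered copy instead of deleting in place
item = dataset[0]
forward_kwargs = {k: v for k, v in item.items() if k not in ('texts', 'text_embeds')}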


I'd prefer it if the dataset returns the text along with each set of vertices and faces.

@@ -367,7 +370,63 @@ def forward(self):
self.wait()

self.print('training complete')
def train(self, num_epochs, diplay_graph = False):
Owner


small typo here diplay



self.print('Training complete')
if diplay_graph:
Owner


so i haven't documented this, but you can already use wandb.ai experiment tracker

you just have to do

trainer = Trainer(..., use_wandb_tracking = True)

with trainer.trackers('meshgpt', 'one-experiment-name'):
  trainer.train()

@MarcusLoppe
Contributor Author

Btw, since I don't really think grad_accum_every is very useful, I removed it from the train function, what is your opinion?

I forgot and left grad_accum_every in the loss calculation, so if it won't be used in the train function it should be removed from:

self.accelerator.backward(loss / self.grad_accum_every)

@lucidrains
Owner

Btw, since I don't really think grad_accum_every is very useful, I removed it from the train function, what is your opinion?

I forgot and left grad_accum_every in the loss calculation, so if it won't be used in the train function it should be removed from:

self.accelerator.backward(loss / self.grad_accum_every)

i'm sure researchers will want to stretch to the next level if this approach pans out (multiple meshes, scenes etc)

probably good to keep it for the gpu poor
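For context, a minimal sketch of gradient accumulation with toy pieces (not the repo's Trainer), showing why the loss is divided by grad_accum_every before backward: the optimizer only steps once every grad_accum_every batches, and the division turns the accumulated gradient into an average, so the update matches one larger batch.

import torch

model = torch.nn.Linear(3, 1)                         # toy stand-ins
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
dataloader = [torch.randn(2, 3) for _ in range(8)]

grad_accum_every = 4                                  # effective batch size = 4x the loader's

for step, batch in enumerate(dataloader):
    loss = model(batch).pow(2).mean()
    (loss / grad_accum_every).backward()              # scale so accumulated grads average out

    if (step + 1) % grad_accum_every == 0:
        optimizer.step()                              # one update per grad_accum_every batches
        optimizer.zero_grad()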

@MarcusLoppe
Contributor Author

Another thing :) I'm not very experienced with GitHub forks, but it seems like the pull request picked up commits made later than when I opened the request.

I made a bit of an error and replaced the entire meshgpt_pytorch.py, since there was some weird stash thing and I wanted to ensure it was up to date. I reverted, but it seems like that stash thing messed it up a bit; please double-check whether this is the case.
