Mercurial > repos > siwaa > wine_quality_train_eval
diff toto_wine_quality_train_eval.py @ 12:60778af2dd78 draft
planemo upload for repository https://forgemia.inra.fr/nathalie.rousse/use/-/tree/dnn/DNN/galaxy-tools/wine_quality_train_eval commit e7fd13c34ec074a7ebc246301b5a80069dcbcc3a-dirty
author | siwaa |
---|---|
date | Thu, 05 Dec 2024 16:03:49 +0000 |
parents | b5f69f836e03 |
children | dd7d99707a65 |
line wrap: on
line diff
--- a/toto_wine_quality_train_eval.py Thu Dec 05 15:55:00 2024 +0000 +++ b/toto_wine_quality_train_eval.py Thu Dec 05 16:03:49 2024 +0000 @@ -100,7 +100,7 @@ import torchvision.transforms as T ##toto## from IPython.display import display, HTML ##toto## HTML from torch.utils.data import DataLoader, random_split -from model_wine_lightning.modules.progressbar import CustomTrainProgressBar +##toto## from model_wine_lightning.modules.progressbar import CustomTrainProgressBar from model_wine_lightning.modules.data_load import WineQualityDataset from model_wine_lightning.modules.data_load import Normalize, ToTensor from model_wine_lightning.modules.model import LitRegression @@ -277,7 +277,8 @@ # train model trainer = pl.Trainer(accelerator='auto', max_epochs=100, logger=logger, num_sanity_val_steps=0, - callbacks=[savemodel_callback,CustomTrainProgressBar()]) + callbacks=[savemodel_callback]) + ##toto##callbacks=[savemodel_callback,CustomTrainProgressBar()]) trainer.fit(model=reg, train_dataloaders=train_loader, val_dataloaders=test_loader)