comparison toto_wine_quality_train_eval.py @ 12:60778af2dd78 draft

planemo upload for repository https://forgemia.inra.fr/nathalie.rousse/use/-/tree/dnn/DNN/galaxy-tools/wine_quality_train_eval commit e7fd13c34ec074a7ebc246301b5a80069dcbcc3a-dirty
author siwaa
date Thu, 05 Dec 2024 16:03:49 +0000
parents b5f69f836e03
children dd7d99707a65
comparison
equal deleted inserted replaced
11:b5f69f836e03 12:60778af2dd78
98 import os 98 import os
99 import lightning.pytorch as pl 99 import lightning.pytorch as pl
100 import torchvision.transforms as T 100 import torchvision.transforms as T
101 ##toto## from IPython.display import display, HTML ##toto## HTML 101 ##toto## from IPython.display import display, HTML ##toto## HTML
102 from torch.utils.data import DataLoader, random_split 102 from torch.utils.data import DataLoader, random_split
103 from model_wine_lightning.modules.progressbar import CustomTrainProgressBar 103 ##toto## from model_wine_lightning.modules.progressbar import CustomTrainProgressBar
104 from model_wine_lightning.modules.data_load import WineQualityDataset 104 from model_wine_lightning.modules.data_load import WineQualityDataset
105 from model_wine_lightning.modules.data_load import Normalize, ToTensor 105 from model_wine_lightning.modules.data_load import Normalize, ToTensor
106 from model_wine_lightning.modules.model import LitRegression 106 from model_wine_lightning.modules.model import LitRegression
107 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger 107 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
108 import fidle 108 import fidle
275 logger= TensorBoardLogger(save_dir=f'{run_dir}/logs', name="reg_logs") 275 logger= TensorBoardLogger(save_dir=f'{run_dir}/logs', name="reg_logs")
276 276
277 # train model 277 # train model
278 trainer = pl.Trainer(accelerator='auto', max_epochs=100, 278 trainer = pl.Trainer(accelerator='auto', max_epochs=100,
279 logger=logger, num_sanity_val_steps=0, 279 logger=logger, num_sanity_val_steps=0,
280 callbacks=[savemodel_callback,CustomTrainProgressBar()]) 280 callbacks=[savemodel_callback])
281 ##toto##callbacks=[savemodel_callback,CustomTrainProgressBar()])
281 trainer.fit(model=reg, train_dataloaders=train_loader, 282 trainer.fit(model=reg, train_dataloaders=train_loader,
282 val_dataloaders=test_loader) 283 val_dataloaders=test_loader)
283 284
284 # ## Step 6 - Evaluate it 285 # ## Step 6 - Evaluate it
285 print("\n"+HEAD,"# ## Step 6 - Evaluate it\n") 286 print("\n"+HEAD,"# ## Step 6 - Evaluate it\n")