Mercurial > repos > siwaa > wine_quality_train_eval
comparison toto_wine_quality_train_eval.py @ 12:60778af2dd78 draft
planemo upload for repository https://forgemia.inra.fr/nathalie.rousse/use/-/tree/dnn/DNN/galaxy-tools/wine_quality_train_eval commit e7fd13c34ec074a7ebc246301b5a80069dcbcc3a-dirty
author | siwaa |
---|---|
date | Thu, 05 Dec 2024 16:03:49 +0000 |
parents | b5f69f836e03 |
children | dd7d99707a65 |
comparison (legend: equal, deleted, inserted, replaced)
11:b5f69f836e03 | 12:60778af2dd78 |
---|---|
98 import os | 98 import os |
99 import lightning.pytorch as pl | 99 import lightning.pytorch as pl |
100 import torchvision.transforms as T | 100 import torchvision.transforms as T |
101 ##toto## from IPython.display import display, HTML ##toto## HTML | 101 ##toto## from IPython.display import display, HTML ##toto## HTML |
102 from torch.utils.data import DataLoader, random_split | 102 from torch.utils.data import DataLoader, random_split |
103 from model_wine_lightning.modules.progressbar import CustomTrainProgressBar | 103 ##toto## from model_wine_lightning.modules.progressbar import CustomTrainProgressBar |
104 from model_wine_lightning.modules.data_load import WineQualityDataset | 104 from model_wine_lightning.modules.data_load import WineQualityDataset |
105 from model_wine_lightning.modules.data_load import Normalize, ToTensor | 105 from model_wine_lightning.modules.data_load import Normalize, ToTensor |
106 from model_wine_lightning.modules.model import LitRegression | 106 from model_wine_lightning.modules.model import LitRegression |
107 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger | 107 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger |
108 import fidle | 108 import fidle |
275 logger= TensorBoardLogger(save_dir=f'{run_dir}/logs', name="reg_logs") | 275 logger= TensorBoardLogger(save_dir=f'{run_dir}/logs', name="reg_logs") |
276 | 276 |
277 # train model | 277 # train model |
278 trainer = pl.Trainer(accelerator='auto', max_epochs=100, | 278 trainer = pl.Trainer(accelerator='auto', max_epochs=100, |
279 logger=logger, num_sanity_val_steps=0, | 279 logger=logger, num_sanity_val_steps=0, |
280 callbacks=[savemodel_callback,CustomTrainProgressBar()]) | 280 callbacks=[savemodel_callback]) |
281 ##toto##callbacks=[savemodel_callback,CustomTrainProgressBar()]) | |
281 trainer.fit(model=reg, train_dataloaders=train_loader, | 282 trainer.fit(model=reg, train_dataloaders=train_loader, |
282 val_dataloaders=test_loader) | 283 val_dataloaders=test_loader) |
283 | 284 |
284 # ## Step 6 - Evaluate it | 285 # ## Step 6 - Evaluate it |
285 print("\n"+HEAD,"# ## Step 6 - Evaluate it\n") | 286 print("\n"+HEAD,"# ## Step 6 - Evaluate it\n") |