Mercurial > repos > siwaa > wine_quality
comparison wine_quality.py @ 0:143b15001522 draft
planemo upload for repository https://forgemia.inra.fr/nathalie.rousse/use/-/tree/dnn/DNN/galaxy-tools/wine_quality commit e7c4e447552083db7eaecbdf139a7c359fe9becc
| author | siwaa |
|---|---|
| date | Wed, 04 Dec 2024 15:25:26 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:143b15001522 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 # coding: utf-8 | |
| 3 | |
| 4 ######### ######### ######### ######### ######### ######### ######### | |
| 5 # | |
| 6 # File created from fidlemore/model_wine_lightning/wine_quality_predict.py | |
| 7 # | |
| 8 # Modifications : | |
| 9 # | |
| 10 # - restore sanitized text for -data value | |
| 11 # | |
| 12 # - quality.txt file containing only quality value (extracted from report) | |
| 13 # | |
| 14 ######### ######### ######### ######### ######### ######### ######### | |
| 15 | |
| 16 ############################################################################### | |
| 17 # Module : model_wine_lightning | |
| 18 # | |
| 19 # This code has been extracted from 01-DNN-Wine-Regression-lightning.ipynb | |
| 20 # (fidle-tp/fidle-master-3.0.11/Wine.Lightning) then modified. | |
| 21 # Only last part kept : | |
| 22 # - Restore saved model from checkpoint | |
| 23 # - Evaluate the model => not kept | |
| 24 # - Make some predictions => only 1 prediction | |
| 25 # | |
| 26 # Inputs : | |
| 27 # | |
| 28 # Data of wine for which quality is going to be predicted is given | |
| 29 # by -data_json_filepath or by -data (only one of both). | |
| 30 # | |
| 31 # -data_json_filepath : data file path (.json) containing data. | |
| 32 # -data : data (string format) | |
| 33 # | |
| 34 # Example of wine data : { "fixed acidity": 11.2, | |
| 35 # "volatile acidity": 0.28, | |
| 36 # "citric acid": 0.56, | |
| 37 # "residual sugar": 1.9, | |
| 38 # "chlorides": 0.075, | |
| 39 # "free sulfur dioxide": 17, | |
| 40 # "total sulfur dioxide": 60, | |
| 41 # "density": 0.998, | |
| 42 # "pH": 3.16, | |
| 43 # "sulphates": 0.58, | |
| 44 # "alcohol": 9.8 } | |
| 45 # | |
| 46 # -model_ckpt_filepath : checkpoint model file path (.ckpt) to be loaded. | |
| 47 # | |
| 48 # -norm_config_json_filepath : normalization configuration file (.json) | |
| 49 # containing information (norm_config) that has been returned by the model | |
| 50 # wine_quality_train_eval running. | |
| 51 # | |
| 52 # Outputs : | |
| 53 # | |
| 54 # Output files under "OUTPUTS" folder (created if it does not exist) | |
| 55 # | |
| 56 # - Quality prediction value (float) | |
| 57 # | |
| 58 # - Report (printed on screen; saving it as report_json_filepath (.json) is currently disabled) containing: | |
| 59 # - Quality prediction value | |
| 60 # - Wine data | |
| 61 # - error message, more message, warning message | |
| 62 # | |
| 63 # - Log files into Wine.Lightning/run/LWINE1/logs/reg_logs | |
| 64 # | |
| 65 # - Screen display containing running information | |
| 66 # | |
| 67 ############################################################################### | |
| 68 | |
| 69 # <img width="800px" src="../fidle/img/header.svg"></img> | |
| 70 # | |
| 71 # # <!-- TITLE --> [LWINE1] - Wine quality prediction with a Dense Network (DNN) | |
| 72 # <!-- DESC --> Another example of regression, with a wine quality prediction, using PyTorch Lightning | |
| 73 # <!-- AUTHOR : Achille Mbogol Touye (EFFILIA-MIAI/SIMaP) --> | |
| 74 # | |
| 75 # ## Objectives : | |
| 76 # - Predict the **quality of wines**, based on their analysis | |
| 77 # - Understanding the principle and the architecture of a regression with a dense neural network with backup and restore of the trained model. | |
| 78 # | |
| 79 # The **[Wine Quality datasets](https://archive.ics.uci.edu/ml/datasets/wine+Quality)** are made up of analyses of a large number of wines, with an associated quality (between 0 and 10) | |
| 80 # This dataset is provided by: | |
| 81 # Paulo Cortez, University of Minho, Guimarães, Portugal, http://www3.dsi.uminho.pt/pcortez | |
| 82 # A. Cerdeira, F. Almeida, T. Matos and J. Reis, Viticulture Commission of the Vinho Verde Region (CVRVV), Porto, Portugal, @2009 | |
| 83 # This dataset can be retrieved from [University of California Irvine (UCI)](https://archive-beta.ics.uci.edu/ml/datasets/wine+quality) | |
| 84 # | |
| 85 # | |
| 86 # Due to privacy and logistical issues, only physicochemical and sensory variables are available | |
| 87 # There is no data about grape types, wine brand, wine selling price, etc. | |
| 88 # | |
| 89 # - fixed acidity | |
| 90 # - volatile acidity | |
| 91 # - citric acid | |
| 92 # - residual sugar | |
| 93 # - chlorides | |
| 94 # - free sulfur dioxide | |
| 95 # - total sulfur dioxide | |
| 96 # - density | |
| 97 # - pH | |
| 98 # - sulphates | |
| 99 # - alcohol | |
| 100 # - quality (score between 0 and 10) | |
| 101 # | |
| 102 # ## What we're going to do : | |
| 103 # | |
| 104 # - (Retrieve data) | |
| 105 # - (Preparing the data) | |
| 106 # - (Build a model) | |
| 107 # - Train and save the model | |
| 108 # - Restore saved model | |
| 109 # - Evaluate the model | |
| 110 # - Make some predictions | |
| 111 # | |
# Tag put in front of the console messages printed by this script.
HEAD = "[wine_quality/wine_quality]"

# ## Step 1 - Import and init
print(f"\n{HEAD}", "# ## Step 1 - Import and init\n")
| 116 | |
| 117 # Import some packages | |
| 118 import os | |
| 119 import pandas as pd | |
| 120 import torch | |
| 121 import torchvision.transforms as T | |
| 122 from model_wine_lightning.modules.data_load import NormalizeFeatures | |
| 123 from model_wine_lightning.modules.data_load import ToTensorFeatures | |
| 124 from model_wine_lightning.modules.model import LitRegression | |
| 125 import fidle | |
| 126 import json | |
| 127 import argparse | |
| 128 from pprint import pprint | |
| 129 | |
# Folder receiving the output files. It is created at startup when missing;
# the run stops if the path exists but is not a directory.
OUTPUTS_PATH = "OUTPUTS"

# Messages accumulated during the run, copied into the report at the end.
error_msg, warn_msg, more_msg = "", "", ""  # default

# Output file paths, both located under OUTPUTS_PATH.
report_json_filepath = os.path.join(OUTPUTS_PATH, "report.json")
quality_txt_filepath = os.path.join(OUTPUTS_PATH, "quality.txt")
report = {}  # init
# Default quality value: kept (and written out) when no prediction is made.
pred = 99.99

# Inputs, filled in from the command line arguments.
data_json_filepath, data = None, None
model_ckpt_filepath = None
norm_config_json_filepath, norm_config = None, None
| 142 | |
# Replacement table used to restore the sanitized -data text: each key
# character appears in the incoming text as its __xx__ token value.
MAPPING = {'>': '__gt__', '<': '__lt__', "'": '__sq__', '"': '__dq__', '[': '__ob__', ']': '__cb__', '{': '__oc__', '}': '__cc__', '@': '__at__', '\n': '__cn__', '\r': '__cr__', '\t': '__tc__', '#': '__pd__'}


def is_not_given(argument):
    """Return True when a command line argument has no usable value.

    The literal string 'None' is treated like a missing value (presumably
    what the calling wrapper passes for an unset option -- to confirm).
    """
    return argument is None or argument == 'None'


def is_given(argument):
    """Return True when a command line argument has a usable value."""
    return not is_not_given(argument)


class _StopRun(Exception):
    """Run failure whose message has already been appended to error_msg."""


def _stop(message, cause=None):
    """Append message to error_msg, then abort the run with _StopRun."""
    global error_msg
    error_msg += message + " "
    raise _StopRun(message) from cause


def _check_file(path, option_name):
    """Return path if it names an existing file, otherwise stop the run."""
    if not os.path.isfile(path):
        _stop("'" + path + "' " + option_name + " file not found => STOP.")
    return path


def _load_json_file(filepath, what):
    """Return the JSON content of filepath; stop the run if it can't be read.

    what names the loaded content in the error message.
    """
    try:
        # 'with' guarantees the file is closed, even on a parse error.
        with open(filepath, 'r', encoding='utf-8') as inputfile:
            return json.load(inputfile)
    except (OSError, ValueError) as err:  # JSONDecodeError is a ValueError
        _stop("Failed to get json " + what + " from '" + filepath + "' file.",
              err)


try:
    if not os.path.exists(OUTPUTS_PATH):
        os.mkdir(OUTPUTS_PATH)
        message = "Outputs folder '" + OUTPUTS_PATH + "' does not exist => created."
        warn_msg += message + " "
        print(HEAD, "Warning :", message)

    # An existing path that is not a directory (e.g. a regular file) is
    # rejected here.
    if not os.path.isdir(OUTPUTS_PATH):
        _stop("Outputs folder '" + OUTPUTS_PATH + "' must exist.")

    # ## INPUTS
    print("\n"+HEAD, "# ## INPUTS\n")

    desc_text = "Predict Quality of a Wine"
    parser = argparse.ArgumentParser(prog='wine_quality_predict',
                                     description=desc_text)

    help_required_keys = "\"fixed acidity\";\"volatile acidity\";\"citric acid\";\"residual sugar\";\"chlorides\";\"free sulfur dioxide\";\"total sulfur dioxide\";\"density\";\"pH\";\"sulphates\";\"alcohol\""
    help_text = "data file path (.json), required keys:"+help_required_keys
    parser.add_argument("-data_json_filepath", type=str, help=help_text)
    help_text = "data in string format, required keys:"+help_required_keys
    parser.add_argument("-data", type=str, help=help_text)

    help_text = "checkpoint model file path (.ckpt)"
    parser.add_argument("-model_ckpt_filepath", type=str, help=help_text)

    help_text = "normalization configuration file path (.json), "
    help_text += "information returned by wine_quality_train_eval running."
    parser.add_argument("-norm_config_json_filepath", type=str, help=help_text)

    args = parser.parse_args()

    # 1 and only 1 among -data_json_filepath and -data
    if is_given(args.data_json_filepath) and is_given(args.data):
        _stop("Both -data_json_filepath and -data given "
              "(1 and only 1 of -data_json_filepath and -data expected) "
              "=> STOP.")
    if is_not_given(args.data_json_filepath) and is_not_given(args.data):
        _stop("NO data_json_filepath and NO -data given. "
              "(1 and only 1 of -data_json_filepath and -data expected) "
              "=> STOP.")

    if is_given(args.data_json_filepath):
        data_json_filepath = _check_file(args.data_json_filepath,
                                         "data_json_filepath")
        print(HEAD, "data_json_filepath used :", data_json_filepath)
        data = _load_json_file(data_json_filepath, "data")

    if is_given(args.data):
        # Restore the sanitized text: turn each __xx__ token back into
        # its original character before parsing the JSON.
        data_text = args.data
        for key, value in MAPPING.items():
            data_text = data_text.replace(value, key)
        try:
            data = json.loads(data_text)  # get data
        except ValueError as err:
            _stop("Failed to get json data from string '" + data_text + "'",
                  err)

    if is_not_given(args.model_ckpt_filepath):
        _stop("NO model_ckpt_filepath given => STOP.")
    model_ckpt_filepath = _check_file(args.model_ckpt_filepath,
                                      "model_ckpt_filepath")
    print(HEAD, "model_ckpt_filepath used :", model_ckpt_filepath)

    if is_not_given(args.norm_config_json_filepath):
        _stop("NO norm_config_json_filepath given => STOP.")
    norm_config_json_filepath = _check_file(args.norm_config_json_filepath,
                                            "norm_config_json_filepath")
    print(HEAD, "norm_config_json_filepath used :",
          norm_config_json_filepath)
    norm_config = _load_json_file(norm_config_json_filepath, "norm_config")

    print(HEAD, "INPUTS:")
    print("- data:", data)
    print("- model checkpoint:", model_ckpt_filepath)
    print("- norm_config:", norm_config)

    # Init Fidle environment
    print("\n"+HEAD, "# Init Fidle environment\n")
    run_id, run_dir, datasets_dir = fidle.init('LWINE1_predict')

    # Verbosity during training :
    # - 0 = silent
    # - 1 = progress bar
    # - 2 = one line per epoch
    fit_verbosity = 1

    # Override parameters (batch mode) - Just forget this cell
    fidle.override('fit_verbosity')

    # ## Step 7 - Restore model :
    print("\n"+HEAD, "# ## Step 7 - Restore model :\n")

    # ### 7.1 - Reload model
    print("\n"+HEAD, "# ### 7.1 - Reload model\n")
    loaded_model = LitRegression.load_from_checkpoint(model_ckpt_filepath)
    print(HEAD, "Model loaded from checkpoint: ", model_ckpt_filepath)
    print("Loaded:", loaded_model)

    # ### 7.2 - Evaluate model : Not kept

    # ### 7.3 - Make a prediction
    print("\n"+HEAD, "# ### 7.3 - Make a prediction\n")

    mean_json = norm_config['mean_json']
    std_json = norm_config['std_json']
    min_json = norm_config['min_json']
    max_json = norm_config['max_json']
    print(HEAD, "Use Normalization mean: ", mean_json)
    print(HEAD, "Use Normalization std: ", std_json)
    print(HEAD, "Use Normalization min: ", min_json)
    print(HEAD, "Use Normalization max: ", max_json)
    NF = NormalizeFeatures(mean_json, std_json, min_json, max_json)
    if not NF.is_in_domain(data):
        _stop("data values out of domain => no prediction.")
    features = NF.get_features(data)
    transform = T.Compose([NF, ToTensorFeatures()])
    sample = transform(features)

    # Sets the model in evaluation mode
    loaded_model.eval()

    # Perform inference using the loaded model; gradients are not needed,
    # so autograd tracking is disabled.
    with torch.no_grad():
        y_pred = loaded_model(sample)
    pred = y_pred[0][0].item()
    print(HEAD, ":")
    print("Quality prediction :", f'{pred:.2f}', " , for wine data:")
    pprint(data)

    # ## OUTPUTS
    print("\n"+HEAD, "# ## OUTPUTS\n")

    # Report (json) :
    # - quality prediction value
    # - wine data
    # - error message, more message, warning message
    report["quality"] = pred
    report["data"] = data
    report["model_ckpt_filepath"] = model_ckpt_filepath
    report["norm_config"] = norm_config

    fidle.end()

except _StopRun:
    # Expected failure: its message was already appended to error_msg.
    pass
except Exception as e:
    # Unexpected failure: record its type and arguments.
    error_msg += type(e).__name__ + str(e.args) + ". "
| 340 | |
# Copy the non-empty accumulated messages into the report.
if error_msg != "":
    report["error"] = error_msg
if more_msg != "":
    report["more"] = more_msg
if warn_msg != "":
    report["warning"] = warn_msg

print("OUTPUT:", "Quality prediction :", pred)

print("OUTPUT:", "Report: ")
pprint(report)

## Save Report as .json file (currently disabled)
#try:
#    with open(report_json_filepath, "w") as outfile:
#        json.dump(report, outfile)
#    print("OUTPUT:", "Report file (containing report) :", report_json_filepath)
#except :
#    pass

# Save quality alone into .txt file. pred still holds its initial default
# value when the prediction step did not complete.
try:
    with open(quality_txt_filepath, "w") as outfile:
        outfile.write(str(pred))
    print("OUTPUT:", "Quality file (containing quality value) :",
          quality_txt_filepath)
except OSError as err:
    # Best effort: don't crash, but don't hide the failure either.
    print("OUTPUT:", "Failed to write quality file", quality_txt_filepath,
          ":", err)
| 366 | |
| 367 # --- | |
| 368 # <img width="80px" src="../fidle/img/logo-paysage.svg"></img> | |
| 369 |
