comparison wine_quality.py @ 0:143b15001522 draft

planemo upload for repository https://forgemia.inra.fr/nathalie.rousse/use/-/tree/dnn/DNN/galaxy-tools/wine_quality commit e7c4e447552083db7eaecbdf139a7c359fe9becc
author siwaa
date Wed, 04 Dec 2024 15:25:26 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:143b15001522
1 #!/usr/bin/env python
2 # coding: utf-8
3
4 ######### ######### ######### ######### ######### ######### #########
5 #
6 # File created from fidlemore/model_wine_lightning/wine_quality_predict.py
7 #
8 # Modifications :
9 #
10 # - restore sanitized text for -data value
11 #
12 # - quality.txt file containing only quality value (extracted from report)
13 #
14 ######### ######### ######### ######### ######### ######### #########
15
16 ###############################################################################
17 # Module : model_wine_lightning
18 #
19 # This code has been extracted from 01-DNN-Wine-Regression-lightning.ipynb
20 # (fidle-tp/fidle-master-3.0.11/Wine.Lightning) then modified.
21 # Only last part kept :
22 # - Restore saved model from checkpoint
23 # - Evaluate the model not kept
24 # - Make some predictions 1 prediction
25 #
26 # Inputs :
27 #
28 # Data of wine for which quality is going to be predicted is given
29 # by -data_json_filepath or by -data (only one of both).
30 #
31 # -data_json_filepath : data file path (.json) containing data.
32 # -data : data (string format)
33 #
34 # Example of wine data : { "fixed acidity": 11.2,
35 # "volatile acidity": 0.28,
36 # "citric acid": 0.56,
37 # "residual sugar": 1.9,
38 # "chlorides": 0.075,
39 # "free sulfur dioxide": 17,
40 # "total sulfur dioxide": 60,
41 # "density": 0.998,
42 # "pH": 3.16,
43 # "sulphates": 0.58,
44 # "alcohol": 9.8 }
45 #
46 # -model_ckpt_filepath : checkpoint model file path (.ckpt) to be loaded.
47 #
48 # -norm_config_json_filepath : normalization configuration file (.json)
49 # containing information (norm_config) that has been returned by the model
50 # wine_quality_train_eval running.
51 #
52 # Outputs :
53 #
54 # Output files under "OUTPUTS" folder (must exist !!!)
55 #
56 # - Quality prediction value (float)
57 #
58 # - Report file (report_json_filepath) (.json) containing:
59 # - Quality prediction value
60 # - Wine data
61 # - error message, more message, warning message
62 #
63 # - Log files into Wine.Lightning/run/LWINE1/logs/reg_logs
64 #
65 # - Screen display containing running information
66 #
67 ###############################################################################
68
69 # <img width="800px" src="../fidle/img/header.svg"></img>
70 #
71 # # <!-- TITLE --> [LWINE1] - Wine quality prediction with a Dense Network (DNN)
72 # <!-- DESC --> Another example of regression, with a wine quality prediction, using PyTorch Lightning
73 # <!-- AUTHOR : Achille Mbogol Touye (EFFILIA-MIAI/SIMaP) -->
74 #
75 # ## Objectives :
76 # - Predict the **quality of wines**, based on their analysis
77 # - Understanding the principle and the architecture of a regression with a dense neural network with backup and restore of the trained model.
78 #
79 # The **[Wine Quality datasets](https://archive.ics.uci.edu/ml/datasets/wine+Quality)** are made up of analyses of a large number of wines, with an associated quality (between 0 and 10)
80 # This dataset is provided by :
81 # Paulo Cortez, University of Minho, Guimarães, Portugal, http://www3.dsi.uminho.pt/pcortez
82 # A. Cerdeira, F. Almeida, T. Matos and J. Reis, Viticulture Commission of the Vinho Verde Region(CVRVV), Porto, Portugal, @2009
83 # This dataset can be retrieved at [University of California Irvine (UCI)](https://archive-beta.ics.uci.edu/ml/datasets/wine+quality)
84 #
85 #
86 # Due to privacy and logistic issues, only physicochemical and sensory variables are available
87 # There is no data about grape types, wine brand, wine selling price, etc.
88 #
89 # - fixed acidity
90 # - volatile acidity
91 # - citric acid
92 # - residual sugar
93 # - chlorides
94 # - free sulfur dioxide
95 # - total sulfur dioxide
96 # - density
97 # - pH
98 # - sulphates
99 # - alcohol
100 # - quality (score between 0 and 10)
101 #
102 # ## What we're going to do :
103 #
104 # - (Retrieve data)
105 # - (Preparing the data)
106 # - (Build a model)
107 # - Train and save the model
108 # - Restore saved model
109 # - Evaluate the model
110 # - Make some predictions
111 #
# Prefix identifying this script in every console message it prints.
HEAD = "[wine_quality/wine_quality]"

# ## Step 1 - Import and init
print("\n"+HEAD,"# ## Step 1 - Import and init\n")
116
117 # Import some packages
118 import os
119 import pandas as pd
120 import torch
121 import torchvision.transforms as T
122 from model_wine_lightning.modules.data_load import NormalizeFeatures
123 from model_wine_lightning.modules.data_load import ToTensorFeatures
124 from model_wine_lightning.modules.model import LitRegression
125 import fidle
126 import json
127 import argparse
128 from pprint import pprint
129
# Folder receiving every output file (must exist; the main workflow creates
# it, with a warning, when it is missing).
OUTPUTS_PATH = "OUTPUTS"

# Messages accumulated while running; the non-empty ones are added to the
# report at the end of the script.
error_msg = ""
warn_msg = ""
more_msg = ""

# Output file paths and the report content.
report_json_filepath = os.path.join(OUTPUTS_PATH, "report.json")
quality_txt_filepath = os.path.join(OUTPUTS_PATH, "quality.txt")
report = dict()  # filled by the main workflow
pred = 99.99  # default value, kept as the output when no prediction is made

# Inputs, resolved from the command line by the main workflow.
data_json_filepath, data = None, None
model_ckpt_filepath = None
norm_config_json_filepath, norm_config = None, None
142
# Main workflow: check the outputs folder, read and validate the command-line
# inputs, restore the model from its checkpoint and predict the quality of one
# wine. Any failure is recorded into error_msg (never re-raised), so that the
# output section after this block always runs and reports the default
# prediction value together with the error text.
try:
    def is_not_given(argument):
        """Return True when an option is absent (None or the text 'None')."""
        return argument is None or argument == 'None'

    def is_given(argument):
        """Return True when an option holds a value other than None/'None'."""
        return not is_not_given(argument)

    # The case where OUTPUTS_PATH exists but is a regular file is not handled
    # here; it is rejected by the isdir() check just below.
    if not os.path.exists(OUTPUTS_PATH):
        os.mkdir(OUTPUTS_PATH)
        message = "Outputs folder '"+OUTPUTS_PATH+"' does not exist => created."
        warn_msg += message + " "
        print(HEAD, "Warning :", message)

    if not os.path.isdir(OUTPUTS_PATH):
        message = "Outputs folder '" + OUTPUTS_PATH + "' must exist."
        error_msg += message + " "
        raise Exception(message)

    # ## INPUTS
    print("\n"+HEAD, "# ## INPUTS\n")

    desc_text = "Predict Quality of a Wine"
    parser = argparse.ArgumentParser(prog='wine_quality_predict',
                                     description=desc_text)

    help_required_keys = "\"fixed acidity\";\"volatile acidity\";\"citric acid\";\"residual sugar\";\"chlorides\";\"free sulfur dioxide\";\"total sulfur dioxide\";\"density\";\"pH\";\"sulphates\";\"alcohol\""
    help_text = "data file path (.json), required keys:"+help_required_keys
    parser.add_argument("-data_json_filepath", type=str, help=help_text)
    help_text = "data in string format, required keys:"+help_required_keys
    parser.add_argument("-data", type=str, help=help_text)

    help_text = "checkpoint model file path (.ckpt)"
    parser.add_argument("-model_ckpt_filepath", type=str, help=help_text)

    help_text = "normalization configuration file path (.json), "
    help_text += "information returned by wine_quality_train_eval running."
    parser.add_argument("-norm_config_json_filepath", type=str, help=help_text)

    args = parser.parse_args()

    # 1 and only 1 among -data_json_filepath and -data
    if is_given(args.data_json_filepath) and is_given(args.data):
        message = "Both -data_json_filepath and -data given "
        message += "(1 and only 1 of -data_json_filepath and -data expected) "
        message += "=> STOP."
        error_msg += message + " "
        raise Exception(message)
    if is_not_given(args.data_json_filepath) and is_not_given(args.data):
        message = "NO data_json_filepath and NO -data given. "
        message += "(1 and only 1 of -data_json_filepath and -data expected) "
        message += "=> STOP."
        error_msg += message + " "
        raise Exception(message)

    # Wine data given as a .json file
    path = args.data_json_filepath
    if is_given(path):
        if os.path.isfile(path):
            data_json_filepath = path
            print(HEAD, "data_json_filepath used :", data_json_filepath)
            try:
                # 'with' closes the file even when json.load fails
                with open(data_json_filepath, 'r') as inputfile:
                    data = json.load(inputfile)
            except (OSError, ValueError) as exc:
                # ValueError covers json.JSONDecodeError and decoding errors
                message = "Failed to get json data from "
                message += "'" + data_json_filepath + "'" + " file."
                error_msg += message + " "
                raise Exception(message) from exc
        else:
            message = path + " data_json_filepath file not found => STOP."
            error_msg += message + " "
            raise Exception(message)

    # Wine data given as a (sanitized) string
    if is_given(args.data):
        data_text = args.data
        try:
            # restore sanitized text : each placeholder token is replaced by
            # the character it stands for before the text is parsed as JSON
            MAPPING = {'>': '__gt__', '<': '__lt__', "'": '__sq__', '"': '__dq__', '[': '__ob__', ']': '__cb__', '{': '__oc__', '}': '__cc__', '@': '__at__', '\n': '__cn__', '\r': '__cr__', '\t': '__tc__', '#': '__pd__'}
            for key, value in MAPPING.items():
                data_text = data_text.replace(value, key)
            data = json.loads(data_text)  # get data
        except ValueError as exc:
            message = "Failed to get json data from string '"+data_text+"'"
            error_msg += message + " "
            raise Exception(message) from exc

    # Model checkpoint file (required)
    path = args.model_ckpt_filepath
    if is_given(path):
        if os.path.isfile(path):
            model_ckpt_filepath = path
            print(HEAD, "model_ckpt_filepath used :", model_ckpt_filepath)
        else:
            message = path + " model_ckpt_filepath file not found => STOP."
            error_msg += message + " "
            raise Exception(message)
    else:
        message = "NO model_ckpt_filepath given => STOP."
        error_msg += message + " "
        raise Exception(message)

    # Normalization configuration file (required)
    path = args.norm_config_json_filepath
    if is_given(path):
        if os.path.isfile(path):
            norm_config_json_filepath = path
            print(HEAD, "norm_config_json_filepath used :",
                  norm_config_json_filepath)
            try:
                with open(norm_config_json_filepath, 'r') as inputfile:
                    norm_config = json.load(inputfile)
            except (OSError, ValueError) as exc:
                message = "Failed to get json norm_config from "
                message += "'" + norm_config_json_filepath + "'" + " file."
                error_msg += message + " "
                raise Exception(message) from exc
        else:
            message = path + " norm_config_json_filepath file not found => STOP."
            error_msg += message + " "
            raise Exception(message)
    else:
        message = "NO norm_config_json_filepath given => STOP."
        error_msg += message + " "
        raise Exception(message)

    print(HEAD, "INPUTS:")
    print("- data:", data)
    print("- model checkpoint:", model_ckpt_filepath)
    print("- norm_config:", norm_config)

    # Init Fidle environment
    print("\n"+HEAD, "# Init Fidle environment\n")
    run_id, run_dir, datasets_dir = fidle.init('LWINE1_predict')

    # Verbosity during training :
    # - 0 = silent
    # - 1 = progress bar
    # - 2 = one line per epoch
    fit_verbosity = 1

    # Override parameters (batch mode) - Just forget this cell
    fidle.override('fit_verbosity')

    # ## Step 7 - Restore model :
    print("\n"+HEAD, "# ## Step 7 - Restore model :\n")

    # ### 7.1 - Reload model
    print("\n"+HEAD, "# ### 7.1 - Reload model\n")
    loaded_model = LitRegression.load_from_checkpoint(model_ckpt_filepath)
    print(HEAD, "Model loaded from checkpoint: ", model_ckpt_filepath)
    print("Loaded:", loaded_model)

    ## ### 7.2 - Evaluate model : Not kept

    # ### 7.3 - Make a prediction
    print("\n"+HEAD, "# ### 7.3 - Make a prediction\n")

    mean_json = norm_config['mean_json']
    std_json = norm_config['std_json']
    min_json = norm_config['min_json']
    max_json = norm_config['max_json']
    print(HEAD, "Use Normalization mean: ", mean_json)
    print(HEAD, "Use Normalization std: ", std_json)
    print(HEAD, "Use Normalization min: ", min_json)
    print(HEAD, "Use Normalization max: ", max_json)
    NF = NormalizeFeatures(mean_json, std_json, min_json, max_json)
    if not NF.is_in_domain(data):
        message = "data values out of domain => no prediction."
        error_msg += message + " "
        raise Exception(message)
    features = NF.get_features(data)
    transform = T.Compose([NF, ToTensorFeatures()])
    sample = transform(features)

    # Sets the model in evaluation mode
    loaded_model.eval()

    # Perform inference using the loaded model; gradients are not needed for
    # a single prediction, so they are not tracked.
    with torch.no_grad():
        y_pred = loaded_model(sample)
    pred = y_pred[0][0].item()
    print(HEAD, ":")
    print("Quality prediction :", f'{pred:.2f}', " , for wine data:")
    pprint(data)

    # ## OUTPUTS
    print("\n"+HEAD, "# ## OUTPUTS\n")

    # Report (json) :
    # - quality prediction value
    # - wine data
    # - error message, more message, warning message
    report["quality"] = pred
    report["data"] = data
    report["model_ckpt_filepath"] = model_ckpt_filepath
    report["norm_config"] = norm_config

    fidle.end()

except Exception as e:
    # Record the failure instead of stopping: the output section that follows
    # must still run.
    error_msg += type(e).__name__ + str(e.args) + ". "
340
# Final outputs: complete the report with the collected messages, display the
# prediction and the report, then write the quality value alone into a file.
if error_msg:
    report["error"] = error_msg
if more_msg:
    report["more"] = more_msg
if warn_msg:
    report["warning"] = warn_msg

print("OUTPUT:", "Quality prediction :", pred)

print("OUTPUT:", "Report: ")
pprint(report)

## Save Report as .json file (currently disabled)
#try:
#    with open(report_json_filepath, "w") as outfile:
#        json.dump(report, outfile)
#    print("OUTPUT:", "Report file (containing report) :", report_json_filepath)
#except :
#    pass

# Save quality alone into .txt file. This stays best-effort (a write failure
# does not stop the script) but the failure is displayed instead of being
# silently ignored.
try:
    with open(quality_txt_filepath, "w") as outfile:
        outfile.write(str(pred))
    print("OUTPUT:", "Quality file (containing quality value) :",
          quality_txt_filepath)
except OSError as exc:
    print("OUTPUT:", "Warning : failed to write quality file",
          quality_txt_filepath, ":", exc)
366
367 # ---
368 # <img width="80px" src="../fidle/img/logo-paysage.svg"></img>
369