Mercurial > repos > siwaa > wine_quality
comparison wine_quality.py @ 0:143b15001522 draft
planemo upload for repository https://forgemia.inra.fr/nathalie.rousse/use/-/tree/dnn/DNN/galaxy-tools/wine_quality commit e7c4e447552083db7eaecbdf139a7c359fe9becc
author | siwaa |
---|---|
date | Wed, 04 Dec 2024 15:25:26 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:143b15001522 |
---|---|
1 #!/usr/bin/env python | |
2 # coding: utf-8 | |
3 | |
4 ######### ######### ######### ######### ######### ######### ######### | |
5 # | |
6 # File created from fidlemore/model_wine_lightning/wine_quality_predict.py | |
7 # | |
8 # Modifications : | |
9 # | |
10 # - restore sanitized text for -data value | |
11 # | |
12 # - quality.txt file containing only quality value (extracted from report) | |
13 # | |
14 ######### ######### ######### ######### ######### ######### ######### | |
15 | |
16 ############################################################################### | |
17 # Module : model_wine_lightning | |
18 # | |
19 # This code has been extracted from 01-DNN-Wine-Regression-lightning.ipynb | |
20 # (fidle-tp/fidle-master-3.0.11/Wine.Lightning) then modified. | |
21 # Only last part kept : | |
22 # - Restore saved model from checkpoint | |
23 # - Evaluate the model (not kept) | |
24 # - Make some predictions (only 1 prediction) | |
25 # | |
26 # Inputs : | |
27 # | |
28 # The data of the wine whose quality is to be predicted is given | |
29 # by -data_json_filepath or by -data (exactly one of the two). | |
30 # | |
31 # -data_json_filepath : data file path (.json) containing data. | |
32 # -data : data (string format) | |
33 # | |
34 # Example of wine data : { "fixed acidity": 11.2, | |
35 # "volatile acidity": 0.28, | |
36 # "citric acid": 0.56, | |
37 # "residual sugar": 1.9, | |
38 # "chlorides": 0.075, | |
39 # "free sulfur dioxide": 17, | |
40 # "total sulfur dioxide": 60, | |
41 # "density": 0.998, | |
42 # "pH": 3.16, | |
43 # "sulphates": 0.58, | |
44 # "alcohol": 9.8 } | |
45 # | |
46 # -model_ckpt_filepath : checkpoint model file path (.ckpt) to be loaded. | |
47 # | |
48 # -norm_config_json_filepath : normalization configuration file (.json) | |
49 # containing information (norm_config) that has been returned by the model | |
50 # wine_quality_train_eval running. | |
51 # | |
52 # Outputs : | |
53 # | |
54 # Output files under "OUTPUTS" folder (created if it does not exist) | |
55 # | |
56 # - Quality prediction value (float) | |
57 # | |
58 # - Report file (report_json_filepath) (.json) containing: | |
59 # - Quality prediction value | |
60 # - Wine data | |
61 # - error message, more message, warning message | |
62 # | |
63 # - Log files into Wine.Lightning/run/LWINE1/logs/reg_logs | |
64 # | |
65 # - Screen display containing running information | |
66 # | |
67 ############################################################################### | |
68 | |
69 # <img width="800px" src="../fidle/img/header.svg"></img> | |
70 # | |
71 # # <!-- TITLE --> [LWINE1] - Wine quality prediction with a Dense Network (DNN) | |
72 # <!-- DESC --> Another example of regression, with a wine quality prediction, using PyTorch Lightning | |
73 # <!-- AUTHOR : Achille Mbogol Touye (EFFILIA-MIAI/SIMaP) --> | |
74 # | |
75 # ## Objectives : | |
76 # - Predict the **quality of wines**, based on their analysis | |
77 # - Understand the principle and the architecture of a regression with a dense neural network, including saving and restoring the trained model. | |
78 # | |
79 # The **[Wine Quality datasets](https://archive.ics.uci.edu/ml/datasets/wine+Quality)** are made up of analyses of a large number of wines, with an associated quality (between 0 and 10) | |
80 # This dataset is provided by : | |
81 # Paulo Cortez, University of Minho, Guimarães, Portugal, http://www3.dsi.uminho.pt/pcortez | |
82 # A. Cerdeira, F. Almeida, T. Matos and J. Reis, Viticulture Commission of the Vinho Verde Region(CVRVV), Porto, Portugal, @2009 | |
83 # This dataset can be retrieved from [University of California Irvine (UCI)](https://archive-beta.ics.uci.edu/ml/datasets/wine+quality) | |
84 # | |
85 # | |
86 # Due to privacy and logistic issues, only physicochemical and sensory variables are available | |
87 # There is no data about grape types, wine brand, wine selling price, etc. | |
88 # | |
89 # - fixed acidity | |
90 # - volatile acidity | |
91 # - citric acid | |
92 # - residual sugar | |
93 # - chlorides | |
94 # - free sulfur dioxide | |
95 # - total sulfur dioxide | |
96 # - density | |
97 # - pH | |
98 # - sulphates | |
99 # - alcohol | |
100 # - quality (score between 0 and 10) | |
101 # | |
102 # ## What we're going to do : | |
103 # | |
104 # - (Retrieve data) | |
105 # - (Preparing the data) | |
106 # - (Build a model) | |
107 # - Train and save the model | |
108 # - Restore saved model | |
109 # - Evaluate the model | |
110 # - Make some predictions | |
111 # | |
HEAD = "[wine_quality/wine_quality]"  # tag prefixed to this script's screen messages

# ## Step 1 - Import and init
# The banner is printed before the imports below, so it shows up even if
# one of those imports fails.
print(f"\n{HEAD}", "# ## Step 1 - Import and init\n")

# Import some packages
import os
import pandas as pd
import torch
import torchvision.transforms as T
from model_wine_lightning.modules.data_load import (NormalizeFeatures,
                                                    ToTensorFeatures)
from model_wine_lightning.modules.model import LitRegression
import fidle
import json
import argparse
from pprint import pprint

# Folder receiving the output files. It is created at run time when it is
# missing; an error is reported if the path exists but is not a folder.
OUTPUTS_PATH = "OUTPUTS"

# Messages accumulated during the run and copied into the report.
error_msg = ""
warn_msg = ""
more_msg = ""

report_json_filepath = os.path.join(OUTPUTS_PATH, "report.json")
quality_txt_filepath = os.path.join(OUTPUTS_PATH, "quality.txt")
report = {}

# Placeholder prediction: this is the value written out when no prediction
# could be made.
pred = 99.99

# Inputs, filled in later from the command-line arguments.
data_json_filepath = None
data = None
model_ckpt_filepath = None
norm_config_json_filepath = None
norm_config = None
142 | |
def is_not_given(argument):
    """Return True when a command-line value is missing.

    Besides None (option absent), the literal string 'None' also counts as
    missing; presumably this is what the calling wrapper passes for an unset
    optional parameter.
    """
    return argument is None or argument == 'None'


def is_given(argument):
    """Return True when a command-line value is present (not is_not_given)."""
    return not is_not_given(argument)


# Keys expected in the wine data (listed in the -data* help texts).
WINE_DATA_KEYS = ("fixed acidity", "volatile acidity", "citric acid",
                  "residual sugar", "chlorides", "free sulfur dioxide",
                  "total sulfur dioxide", "density", "pH", "sulphates",
                  "alcohol")

# Character -> placeholder table of the sanitization applied to the -data
# value before it reaches this script. It is used below to restore the
# original text. The insertion order is the order in which the replacements
# are undone.
MAPPING = {'>': '__gt__', '<': '__lt__', "'": '__sq__', '"': '__dq__',
           '[': '__ob__', ']': '__cb__', '{': '__oc__', '}': '__cc__',
           '@': '__at__', '\n': '__cn__', '\r': '__cr__', '\t': '__tc__',
           '#': '__pd__'}

try:
    # If OUTPUTS_PATH exists as a regular file, this branch is skipped and
    # the isdir check below reports the error.
    if not os.path.exists(OUTPUTS_PATH):
        os.mkdir(OUTPUTS_PATH)
        message = "Outputs folder '" + OUTPUTS_PATH + "' does not exist => created."
        warn_msg += message + " "
        print(HEAD, "Warning :", message)

    if not os.path.isdir(OUTPUTS_PATH):
        message = "Outputs folder '" + OUTPUTS_PATH + "' must exist."
        error_msg += message + " "
        raise Exception(message)

    # ## INPUTS
    print("\n" + HEAD, "# ## INPUTS\n")

    desc_text = "Predict Quality of a Wine"
    parser = argparse.ArgumentParser(prog='wine_quality_predict',
                                     description=desc_text)

    # Same text as listing each key in double quotes, separated by ';'.
    help_required_keys = ";".join('"' + key + '"' for key in WINE_DATA_KEYS)
    help_text = "data file path (.json), required keys:" + help_required_keys
    parser.add_argument("-data_json_filepath", type=str, help=help_text)
    help_text = "data in string format, required keys:" + help_required_keys
    parser.add_argument("-data", type=str, help=help_text)

    help_text = "checkpoint model file path (.ckpt)"
    parser.add_argument("-model_ckpt_filepath", type=str, help=help_text)

    help_text = "normalization configuration file path (.json), "
    help_text += "information returned by wine_quality_train_eval running."
    parser.add_argument("-norm_config_json_filepath", type=str, help=help_text)

    # NOTE: on invalid arguments argparse raises SystemExit, which is not an
    # Exception subclass and therefore is not caught by the handler below.
    args = parser.parse_args()

    # 1 and only 1 among -data_json_filepath and -data
    if is_given(args.data_json_filepath) and is_given(args.data):
        message = "Both -data_json_filepath and -data given "
        message += "(1 and only 1 of -data_json_filepath and -data expected) "
        message += "=> STOP."
        error_msg += message + " "
        raise Exception(message)
    if is_not_given(args.data_json_filepath) and is_not_given(args.data):
        message = "NO data_json_filepath and NO -data given. "
        message += "(1 and only 1 of -data_json_filepath and -data expected) "
        message += "=> STOP."
        error_msg += message + " "
        raise Exception(message)

    # Wine data from a .json file
    path = args.data_json_filepath
    if is_given(path):
        if not os.path.isfile(path):
            message = "data_json_filepath file '" + path + "' not found => STOP."
            error_msg += message + " "
            raise Exception(message)
        data_json_filepath = path
        print(HEAD, "data_json_filepath used :", data_json_filepath)
        try:
            with open(data_json_filepath, 'r') as inputfile:
                data = json.load(inputfile)
        except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError
            message = "Failed to get json data from "
            message += "'" + data_json_filepath + "'" + " file."
            error_msg += message + " "
            raise Exception(message) from err

    # Wine data given directly as a (sanitized) JSON string
    if is_given(args.data):
        data_text = args.data
        try:
            # restore sanitized text
            for char, placeholder in MAPPING.items():
                data_text = data_text.replace(placeholder, char)
            data = json.loads(data_text)  # get data
        except ValueError as err:  # JSONDecodeError is a ValueError
            message = "Failed to get json data from string '" + data_text + "'"
            error_msg += message + " "
            raise Exception(message) from err

    # Model checkpoint (mandatory)
    path = args.model_ckpt_filepath
    if is_not_given(path):
        message = "NO model_ckpt_filepath given => STOP."
        error_msg += message + " "
        raise Exception(message)
    if not os.path.isfile(path):
        message = "model_ckpt_filepath file '" + path + "' not found => STOP."
        error_msg += message + " "
        raise Exception(message)
    model_ckpt_filepath = path
    print(HEAD, "model_ckpt_filepath used :", model_ckpt_filepath)

    # Normalization configuration (mandatory)
    path = args.norm_config_json_filepath
    if is_not_given(path):
        message = "NO norm_config_json_filepath given => STOP."
        error_msg += message + " "
        raise Exception(message)
    if not os.path.isfile(path):
        message = "norm_config_json_filepath file '" + path + "' not found => STOP."
        error_msg += message + " "
        raise Exception(message)
    norm_config_json_filepath = path
    print(HEAD, "norm_config_json_filepath used :",
          norm_config_json_filepath)
    try:
        with open(norm_config_json_filepath, 'r') as inputfile:
            norm_config = json.load(inputfile)
    except (OSError, ValueError) as err:
        message = "Failed to get json norm_config from "
        message += "'" + norm_config_json_filepath + "'" + " file."
        error_msg += message + " "
        raise Exception(message) from err

    print(HEAD, "INPUTS:")
    print("- data:", data)
    print("- model checkpoint:", model_ckpt_filepath)
    print("- norm_config:", norm_config)

    # Init Fidle environment
    print("\n" + HEAD, "# Init Fidle environment\n")
    run_id, run_dir, datasets_dir = fidle.init('LWINE1_predict')

    # Verbosity during training :
    # - 0 = silent
    # - 1 = progress bar
    # - 2 = one line per epoch
    fit_verbosity = 1

    # Override parameters (batch mode) - Just forget this cell
    # (fidle.override receives the variable *name*; presumably it resolves
    #  it in this module's scope, so fit_verbosity must stay module-level)
    fidle.override('fit_verbosity')

    # ## Step 7 - Restore model :
    print("\n" + HEAD, "# ## Step 7 - Restore model :\n")

    # ### 7.1 - Reload model
    print("\n" + HEAD, "# ### 7.1 - Reload model\n")
    loaded_model = LitRegression.load_from_checkpoint(model_ckpt_filepath)
    print(HEAD, "Model loaded from checkpoint: ", model_ckpt_filepath)
    print("Loaded:", loaded_model)

    # ### 7.2 - Evaluate model : not kept

    # ### 7.3 - Make a prediction
    print("\n" + HEAD, "# ### 7.3 - Make a prediction\n")

    mean_json = norm_config['mean_json']
    std_json = norm_config['std_json']
    min_json = norm_config['min_json']
    max_json = norm_config['max_json']
    print(HEAD, "Use Normalization mean: ", mean_json)
    print(HEAD, "Use Normalization std: ", std_json)
    print(HEAD, "Use Normalization min: ", min_json)
    print(HEAD, "Use Normalization max: ", max_json)
    NF = NormalizeFeatures(mean_json, std_json, min_json, max_json)
    if not NF.is_in_domain(data):
        message = "data values out of domain => no prediction."
        error_msg += message + " "
        raise Exception(message)
    features = NF.get_features(data)
    transform = T.Compose([NF, ToTensorFeatures()])
    sample = transform(features)

    # Sets the model in evaluation mode
    loaded_model.eval()

    # Perform inference using the loaded model. no_grad() skips building the
    # autograd graph, which a prediction does not need.
    with torch.no_grad():
        y_pred = loaded_model(sample)
    pred = y_pred[0][0].item()
    print(HEAD, ":")
    print("Quality prediction :", f'{pred:.2f}', " , for wine data:")
    pprint(data)

    # ## OUTPUTS
    print("\n" + HEAD, "# ## OUTPUTS\n")

    # Report (json) :
    # - quality prediction value
    # - wine data
    # - error message, more message, warning message
    report["quality"] = pred
    report["data"] = data
    report["model_ckpt_filepath"] = model_ckpt_filepath
    report["norm_config"] = norm_config

    fidle.end()

except Exception as e:
    # Record the exception type and arguments. Execution then continues with
    # the module-level code after this try statement.
    # NOTE: for the exceptions raised explicitly above, the message was
    # already appended to error_msg, so it appears twice.
    error_msg += type(e).__name__ + str(e.args) + ". "
340 | |
# Copy the non-empty accumulated messages into the report.
if error_msg:
    report["error"] = error_msg
if more_msg:
    report["more"] = more_msg
if warn_msg:
    report["warning"] = warn_msg

print("OUTPUT:", "Quality prediction :", pred)

print("OUTPUT:", "Report: ")
pprint(report)

# NOTE: saving the report as a .json file is currently disabled; the code is
# kept for reference.
# try:
#     with open(report_json_filepath, "w") as outfile:
#         json.dump(report, outfile)
#     print("OUTPUT:", "Report file (containing report) :", report_json_filepath)
# except OSError:
#     pass

# Save the quality value alone into the .txt file. pred still holds its
# placeholder value when no prediction could be made. Writing is best-effort:
# a failure is reported on screen but does not stop the script.
try:
    with open(quality_txt_filepath, "w") as outfile:
        outfile.write(str(pred))
    print("OUTPUT:", "Quality file (containing quality value) :",
          quality_txt_filepath)
except OSError as err:
    print("OUTPUT:", "Failed to write quality file", quality_txt_filepath,
          ":", err)
366 | |
367 # --- | |
368 # <img width="80px" src="../fidle/img/logo-paysage.svg"></img> | |
369 |