Skip to content

Commit

Permalink
Ensure that TensorFlow session is cleared even if an exception occurs (
Browse files Browse the repository at this point in the history
  • Loading branch information
dimasciput authored Jan 26, 2024
1 parent 127310b commit a99091e
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions django_project/monitor/observation_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,30 +127,26 @@ def classify_image(image):
if not model:
return {'error': 'Cannot load model'}
try:

# Convert the image file to a PIL Image
pil_image = Image.open(image)

# Resize the image to the target size
img = pil_image.resize((224, 224))

img_array = tf.keras.utils.img_to_array(img)

img_array = tf.expand_dims(img_array, 0)

predictions = model.predict(img_array)

score = tf.nn.softmax(predictions[0])
predicted_class = classes[np.argmax(score)]
confidence = 100 * np.max(score)

# Clear TensorFlow session to release resources
clear_tensorflow_session()

return {'class': predicted_class, 'confidence': confidence}
with Image.open(image) as pil_image:
img = pil_image.resize((224, 224))
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)

predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])
predicted_class = classes[np.argmax(score)]
confidence = 100 * np.max(score)
return {
'class': predicted_class,
'confidence': confidence
}
except Exception as e:
print(f"Error during image classification: {e}")
return {'error': str(e)}
finally:
# Clear TensorFlow session to release resources
clear_tensorflow_session()


# end of ai score calculation section

Expand All @@ -162,7 +158,7 @@ def convert_to_int(value, default=0):
return int(value.strip().replace('"', ''))
except (ValueError, TypeError):
return default

@csrf_exempt
@login_required
def upload_pest_image(request):
Expand Down Expand Up @@ -194,7 +190,7 @@ def upload_pest_image(request):
user = request.user
site_id = convert_to_int(site_id)
observation_id = convert_to_int(observation_id)

try:
site = Sites.objects.get(gid=site_id)
except Sites.DoesNotExist:
Expand Down Expand Up @@ -348,7 +344,7 @@ def create_observations(request):
selectedSite = 0
if site_id_str.lower() == 'undefined':
selectedSite = int(datainput.get('selectedSite', 0))

observation_id_str = str(request.POST.get('observationId', '0'))

# Remove leading and trailing whitespaces, and replace double quotes
Expand Down

0 comments on commit a99091e

Please sign in to comment.