Déployer un modèle de classification d'images avec TensorFlow et FastAPI 🚀
Apprenez à transformer votre modèle de classification d'images TensorFlow en une API REST scalable avec FastAPI. Cet article vous guide pas à pas : préparation des données, entraînement, export du modèle, création d’une API, tests et bonnes pratiques de mise en production. Vous découvrirez des extraits de code clairs, des conseils de performance et des astuces de sécurisation pour offrir une solution IA fiable et prête à l’emploi. Idéal pour les développeurs souhaitant passer du prototype à la mise en service rapidement.
Introduction
Le machine learning ne suffit pas : pour exploiter un modèle en production, il faut le rendre accessible via une API. Nous allons combiner TensorFlow pour l'entraînement et FastAPI pour l'exposition du modèle. Le résultat : une micro‑service capable de classer des images en temps réel, avec une documentation interactive grâce à Swagger UI.
🛠️ Prérequis
- Python 3.9+
- TensorFlow 2.13+
- FastAPI 0.110+
- Uvicorn (serveur ASGI)
- Docker (optionnel pour le déploiement)
1️⃣ Entraîner le modèle
Chargement et pré‑traitement des données
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
# Répertoire contenant les sous‑dossiers class1, class2, ...
train_ds = image_dataset_from_directory(
"./data/train",
validation_split=0.2,
subset="training",
seed=123,
image_size=(224, 224),
batch_size=32,
)
val_ds = image_dataset_from_directory(
"./data/train",
validation_split=0.2,
subset="validation",
seed=123,
image_size=(224, 224),
batch_size=32,
)
Définition du modèle
model = tf.keras.applications.MobileNetV2(
input_shape=(224, 224, 3),
include_top=False,
weights='imagenet'
)
model.trainable = False # Fine‑tuning léger
inputs = tf.keras.Input(shape=(224, 224, 3))
x = model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(train_ds.num_classes, activation='softmax')(x)
clf = tf.keras.Model(inputs, outputs)
clf.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
clf.fit(train_ds, validation_data=val_ds, epochs=5)
Export du modèle
# Sauvegarde au format SavedModel
clf.save('saved_model/')
Cette version sera chargée par notre API FastAPI.
2️⃣ Créer l'API avec FastAPI
Structure du projet
project_root/
├─ app/
│ ├─ main.py
│ ├─ model.py
│ └─ utils.py
├─ saved_model/
└─ requirements.txt
Chargement du modèle (model.py)
import tensorflow as tf
class ImageClassifier:
def __init__(self, model_path: str = "../saved_model"):
self.model = tf.keras.models.load_model(model_path)
self.class_names = ["cat", "dog", "bird"] # À adapter
def predict(self, image_bytes: bytes) -> dict:
import numpy as np
from PIL import Image
img = Image.open(io.BytesIO(image_bytes)).resize((224, 224))
img_array = np.expand_dims(np.array(img) / 255.0, axis=0)
preds = self.model.predict(img_array)[0]
top_idx = np.argmax(preds)
return {"label": self.class_names[top_idx], "confidence": float(preds[top_idx])}
Endpoint FastAPI (main.py)
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from model import ImageClassifier
import io
app = FastAPI(title="API de classification d'images", version="1.0.0")
classifier = ImageClassifier()
@app.post("/predict", summary="Classifie une image", description="Envoie une image (JPEG/PNG) et récupère le label prédit.")
async def predict(file: UploadFile = File(...)):
if file.content_type not in ["image/jpeg", "image/png"]:
raise HTTPException(status_code=400, detail="Format d'image non supporté")
content = await file.read()
try:
result = classifier.predict(content)
return JSONResponse(content=result)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", summary="Vérifie la santé du service")
async def health():
return {"status": "ok"}
Lancement du serveur
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
FastAPI expose automatiquement la documentation interactive à http://localhost:8000/docs 🎉.
3️⃣ Bonnes pratiques de mise en production
- Versionnage du modèle : stockez chaque version dans un bucket (S3, GCS) et chargez‑la via une variable d’environnement.
- Gestion des dépendances : utilisez un
requirements.txtoupoetry.lockpour garantir la reproductibilité. - Conteneurisation : créez un
Dockerfileminimal (python‑slim) pour isoler l’environnement.FROM python:3.11-slim WORKDIR /app COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt COPY . . EXPOSE 8000 CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] - Surveillance : intégrez Prometheus + Grafana pour monitorer latence, taux d’erreur et utilisation GPU/CPU.
- Sécurité : activez le middleware
HTTPSRedirectMiddlewareet limitez la taille des uploads (ex.max_length=2*1024*1024).
4️⃣ Tests et validation
Test unitaire du prédicteur
import pytest
from model import ImageClassifier
@pytest.fixture
def classifier():
return ImageClassifier()
def test_predict_cat(classifier, sample_cat_image):
result = classifier.predict(sample_cat_image.read())
assert result["label"] == "cat"
assert result["confidence"] > 0.7
Test d’intégration avec HTTPX
import httpx
import pytest
@pytest.mark.asyncio
async def test_api_predict():
async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
files = {"file": ("cat.jpg", open("tests/fixtures/cat.jpg", "rb"), "image/jpeg")}
resp = await client.post("/predict", files=files)
assert resp.status_code == 200
json = resp.json()
assert json["label"] == "cat"
🔚 Conclusion
En suivant ce guide, vous avez transformé un modèle TensorFlow en une API FastAPI prête à être déployée en production. Vous avez également découvert les meilleures pratiques de versionnage, de conteneurisation et de monitoring qui garantissent la fiabilité et la scalabilité de votre service IA. N’hésitez pas à explorer le scaling horizontal avec Kubernetes ou à ajouter du caching (Redis) pour réduire la latence des prédictions fréquentes.
