1. Exportieren von Modellen nach ONNX
import torch
import torch.nn as nn
# Einfaches Modell
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
model = SimpleModel()
# Dummy-Eingabe für das Modell
dummy_input = torch.randn(1, 10)
# Exportieren des Modells nach ONNX
torch.onnx.export(model, dummy_input, "simple_model.onnx",
input_names=["input"], output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})
print("Modell erfolgreich nach ONNX exportiert")
Output:
Modell erfolgreich nach ONNX exportiert
2. Bereitstellung von PyTorch-Modellen mit Flask oder FastAPI
Flask und FastAPI sind zwei beliebte Python-Webframeworks, die es ermöglichen, ML-Modelle als Web-APIs bereitzustellen. Flask ist bekannt für seine Einfachheit und Flexibilität, während FastAPI für seine hohe Leistung und Unterstützung von asynchronen Operationen geschätzt wird.
Bereitstellung mit Flask
Beispiel: Bereitstellung eines PyTorch-Modells mit Flask:
from flask import Flask, request, jsonify
import torch
app = Flask(__name__)
# Einfaches Modell
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# Laden des Modells
model = SimpleModel()
model.load_state_dict(torch.load("simple_model.pth"))
model.eval()
@app.route('/predict', methods=['POST'])
def predict():
data = request.get_json(force=True)
input_tensor = torch.tensor(data['input'], dtype=torch.float32)
with torch.no_grad():
output = model(input_tensor)
return jsonify({'output': output.tolist()})
if __name__ == '__main__':
app.run(debug=True)
# Output:
# * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
Bereitstellung mit FastAPI
Beispiel: Bereitstellung eines PyTorch-Modells mit FastAPI:
from fastapi import FastAPI
from pydantic import BaseModel
import torch
app = FastAPI()
# Einfaches Modell
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# Laden des Modells
model = SimpleModel()
model.load_state_dict(torch.load("simple_model.pth"))
model.eval()
class InputData(BaseModel):
input: list
@app.post('/predict')
def predict(data: InputData):
input_tensor = torch.tensor(data.input, dtype=torch.float32)
with torch.no_grad():
output = model(input_tensor)
return {'output': output.tolist()}
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host='127.0.0.1', port=8000)
# Output:
# INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
3. Bereitstellung von PyTorch-Modellen mit Django
Django ist ein leistungsfähiges Webframework für Python, das für die schnelle Entwicklung von Webanwendungen verwendet wird. Mit Django können Sie komplexe Webanwendungen erstellen und ML-Modelle nahtlos integrieren.
3.1 Beispiel: Bereitstellung eines PyTorch-Modells mit Django
Schritt 1: Erstellen eines neuen Django-Projekts und einer App:
django-admin startproject myproject
cd myproject
django-admin startapp myapp
Schritt 2: Modifizieren der views.py
Datei:
# myapp/views.py
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
import json
import torch
# Einfaches Modell
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# Laden des Modells
model = SimpleModel()
model.load_state_dict(torch.load("simple_model.pth"))
model.eval()
@csrf_exempt
def predict(request):
if request.method == 'POST':
data = json.loads(request.body)
input_tensor = torch.tensor(data['input'], dtype=torch.float32)
with torch.no_grad():
output = model(input_tensor)
return JsonResponse({'output': output.tolist()})
Schritt 3: Hinzufügen der URL zu urls.py
# myapp/urls.py
from django.urls import path
from . import views
urlpatterns = [
path('predict/', views.predict, name='predict'),
]
Schritt 4: Einfügen der App-URLs in das Haupt-URLs-Dateisystem
# myproject/urls.py
from django.contrib import admin
from django.urls import path, include
urlpatterns = [
path('admin/', admin.site.urls),
path('', include('myapp.urls')),
]
Schritt 5: Starten des Django-Servers
python manage.py runserver
# Output:
# Starting development server at http://127.0.0.1:8000/
# Quit the server with CONTROL-C.
3.2 Bereitstellung von PyTorch-Modellen mit Django und FastAPI
Schritt 1: Installieren von django-ninja
pip install django-ninja
Schritt 2: Erstellen eines neuen Ninja-API-Endpunkts
Schritt 2: Erstellen einer neuen API-Datei:
touch myapp/api.py
Schritt 3: Konfigurieren der API-Datei
# myapp/api.py
from ninja import NinjaAPI
import torch
from pydantic import BaseModel
api = NinjaAPI()
# Einfaches Modell
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# Laden des Modells
model = SimpleModel()
model.load_state_dict(torch.load("simple_model.pth"))
model.eval()
class InputData(BaseModel):
input: list
@api.post("/predict")
def predict(request, data: InputData):
input_tensor = torch.tensor(data.input, dtype=torch.float32)
with torch.no_grad():
output = model(input_tensor)
return {'output': output.tolist()}
Schritt 4: Hinzufügen der API-URLs zu urls.py
# myapp/urls.py
from django.urls import path
from .api import api
urlpatterns = [
path('api/', api.urls),
]
Schritt 5: Starten des Django-Servers
python manage.py runserver
# Output:
# Starting development server at http://127.0.0.1:8000/
# Quit the server with CONTROL-C.
4. Verwendung von PyTorch mit TorchServe
TorchServe ist ein flexibles und einfach zu verwendendes Tool zur Bereitstellung von PyTorch-Modellen. Es bietet Funktionen wie Multi-Model-Serving, Model Versioning und Metrics.
Installation und Konfiguration von TorchServe
Installation:
pip install torchserve torch-model-archiver
Modellarchivierung:
torch-model-archiver --model-name simple_model --version 1.0 --model-file simple_model.py --serialized-file simple_model.pth --handler torchserve_handler.py --export-path model_store
Erstellen einer benutzerdefinierten Handler-Datei (torchserve_handler.py):
import torch
import torch.nn as nn
from ts.torch_handler.base_handler import BaseHandler
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
class SimpleModelHandler(BaseHandler):
def __init__(self):
super(SimpleModelHandler, self).__init__()
self.model = SimpleModel()
def initialize(self, context):
self.manifest = context.manifest
self.model_pt_path = self.manifest['model']['serializedFile']
self.model.load_state_dict(torch.load(self.model_pt_path))
self.model.eval()
def handle(self, data, context):
input_tensor = torch.tensor(data[0]['body'], dtype=torch.float32)
with torch.no_grad():
output = self.model(input_tensor)
return output.tolist()
Fazit
Die Bereitstellung von PyTorch-Modellen kann auf verschiedene Weisen erfolgen, je nach Anwendungsfall und Anforderungen. ONNX ermöglicht eine breite Kompatibilität mit anderen Frameworks, während Flask, FastAPI, Django und Django FastAPI schnelle und einfache Möglichkeiten bieten, Modelle als Web-APIs bereitzustellen. TorchServe bietet eine robuste und skalierbare Lösung für die Bereitstellung von PyTorch-Modellen in der Produktion. Nutzen Sie diese Techniken, um Ihre Machine-Learning-Modelle effektiv und effizient bereitzustellen.