DeepLearning/Service

pytorch와 flask를 활용한 딥러닝 모델 서빙하기

seokhyun2 2020. 2. 4. 10:07

tensorflow 2.0을 활용해서 어떻게 서빙하는지 다뤄봤었는데, 요즘엔 pytorch를 사용하시는 분들도 많으니까 이번엔 pytorch를 서빙하는 방법에 대해서 설명드리려고 합니다!

 

이전 글들과 똑같이 mnist를 준비했고, 학습은 pytorch 공식 예제 참조하여 학습을 수행하였습니다. 

아래 링크 참조하셔서 학습 진행해보시길 추천드려요!

https://github.com/pytorch/examples/tree/master/mnist

 

pytorch/examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc. - pytorch/examples

github.com

 

이전 포스팅에서 tensorflow 예제를 다룰 때는 pixel을 255로 나누어줬었는데, pytorch 예제를 보시면  0.1307과 0.3081이란 숫자를 활용해서 정규화를 해주는 것을 보실 수 있습니다.

이 부분은 mnist 데이터에서 전체 평균과 표준편차를 구하여 그 값을 활용하여 정규화를 수행하도록 한 것입니다.

pixel을 단순히 255로 나누는 것보다 평균과 표준편차를 활용하여 -1 ~ 1 사이의 값으로 정규화 해주는 것이 더 좋다고 하네요. 

서빙이랑은 상관이 없으니 이쯤에서 넘어가도록 하겠습니다.

 

이제 pytorch 모델을 서빙하는 소스코드를 보도록 하겠습니다.

# flask_server.py

import torch
import numpy as np
from torchvision import transforms
from flask import Flask, jsonify, request

from model import CNN


model = CNN()
model.load_state_dict(torch.load('mnist_model.pt'), strict=False)
model.eval()

normalize = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

app = Flask(__name__)
@app.route('/inference', methods=['POST'])
def inference():
    data = request.json
    _, result = model.forward(normalize(np.array(data['images'], dtype=np.uint8)).unsqueeze(0)).max(1)
    return str(result.item())


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=2431, threaded=False)

매우 간단하죠?

모델을 불러오고 정규화를 미리 정의해두고 그 후에는 요청이 들어올 때 마다 결과를 출력해서 반환하도록 구현하였습니다.

 

서버에 요청하는 소스코드는 아래와 같습니다.

# flask_test.py

import json
import requests
import numpy as np
from PIL import Image

image = Image.open('test_image.jpg')
pixels = np.array(image)

headers = {'Content-Type':'application/json'}
address = "http://127.0.0.1:2431/inference"
data = {'images':pixels.tolist()}

result = requests.post(address, data=json.dumps(data), headers=headers)

print(str(result.content, encoding='utf-8'))

이미지를 불러와서 픽셀을 담아서 보내주면 끝!

 

이렇게 pytorch 모델까지 어떻게 서빙을 할 수 있는지 간단하게 살펴보았습니다.

하지만, 이렇게만 서빙한다고 서비스를 할 수 있는 건 절대 아니겠죠?

제일 먼저 병렬처리에 대해서 궁금해 하실 것 같네요.

 

flask_server.py에서 맨 밑에 줄에 threaded=True 옵션을 주면, 각 요청들이 각각의 쓰레드로 동작하면서 병렬처리가 가능하도록 flask에서 제공은 하고 있지만 pytorch에서는 그 기능을 사용할 경우에는 내부에서 데이터가 꼬이는 현상이 발생하게 됩니다. (Tensorflow에서는 문제가 없어서 threaded=True 옵션으로도 병렬처리가 가능은 합니다.)

그래서 쓰레드 방식보다는 프로세스를 여러개 띄우는 방식을 사용해야만 해요!

 

그런 부분에 대해서는 다음 포스팅에서 쓰레드와 프로세스의 차이, 그리고 파이썬에서의 병렬처리에 대해서도 간단하게 정리하고 어떤 방법을 활용해서 병렬처리를 해주면 좋을지 다음 포스팅에서 정리해보도록 하겠습니다.

 

이번 포스팅에서 활용된 전체 소스코드는 아래의 깃헙 주소로 가시면 모두 보실 수 있습니다!

오늘도 즐거운 딥러닝하세요!

https://github.com/hsh2438/mnist_serving_pytorch_flask.git