ML OPS/Inference & Serving

[ML OPS] 파이썬으로 딥러닝 모델 서빙하기 (ft. flask)

seokhyun2 2022. 2. 8. 00:45

딥러닝 모델을 서빙하는 방식은 여러가지 방법이 있습니다.

여러가지 방법 중에서, 오늘은 flask를 활용하는 방법을 소개해보도록 하겠습니다.

 

https://github.com/hsh2438/MLops/tree/main/1_flask_rest_api

 

GitHub - hsh2438/MLops

Contribute to hsh2438/MLops development by creating an account on GitHub.

github.com

코드는 우선 위의 깃헙 레포지토리를 참고하시면 됩니다.

 

라이브러리 설치

필요한 라이브러리는 requirements.txt 파일에 저장해두었으므로 아래 명령을 실행하여 라이브러리를 설치해주시면 됩니다.

pip install -r requirements.txt

 

서버 구현 

flask로 rest api server를 구현한 코드는 app.py 입니다.

app.py 코드를 조금씩 보도록 하겠습니다.

 

# initializing
app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()

먼저 필요한 것들을 초기화하는 부분입니다.

app = Flask(__name__) 을 통해서 flask를 먼저 초기화 해주었습니다.

다음은 imagenet 데이터의 index와 label을 json 파일로부터 로딩해줍니다.

 

저희는 서빙을 중점으로 볼 예정으로 모델은 pretrain 모델을 가지고 왔습니다.

인퍼런스만 해줄 예정이므로 model.eval()을 실행하여 평가 모드로 모델을 변경합니다.

(eval을 실행해서 dropout이나 batchnorm과 같은 evaluation 과정에서는 사용하지 않아야 하는 layer를 off해줍니다.)

 

 

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

transform_image는 이미지를 전처리 해주는 함수입니다.

이미지 전처리 코드는 학습과 서빙에서 동일한 로직을 활용해야 학습에서의 성능을 서빙에서도 보장할 수 있습니다.

 

 

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

get_prediction 함수는 inference를 수행하고, json파일로 로딩해 온 정보를 활용하여 최종 라벨을 반환하는 함수입니다.

 

 

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

이제 flask를 활용해서 서버를 구성하는 부분입니다.

post만 받도록 구성하고, file을 입력받아서 get_prediction 함수를 수행한 결과를 반환하게만 해주었습니다.

이제 app.py 파일을 실행해보면 localhost:5000 주소로 서버가 실행됩니다.

 

테스트

import requests

response = requests.post("http://localhost:5000/predict", files={"file": open('cat.jpeg','rb')})
print(response.json())

서버가 실제로 잘 동작하는지 테스트를 해보기 위해서 테스트 코드도 준비했습니다.

requests 라이브러리를 활용하여 cat.jpeg 파일을 보내서 결과를 프린트 해주었습니다.

 

위의 코드를 실행해보면 아래와 같은 결과를 볼 수 있으면 성공입니다!