딥러닝 모델을 서빙하는 방식은 여러가지 방법이 있습니다.
여러가지 방법 중에서, 오늘은 flask를 활용하는 방법을 소개해보도록 하겠습니다.
https://github.com/hsh2438/MLops/tree/main/1_flask_rest_api
코드는 우선 위의 깃헙 레포지토리를 참고하시면 됩니다.
라이브러리 설치
필요한 라이브러리는 requirements.txt 파일에 저장해두었으므로 아래 명령을 실행하여 라이브러리를 설치해주시면 됩니다.
pip install -r requirements.txt
서버 구현
flask로 rest api server를 구현한 코드는 app.py 입니다.
app.py 코드를 조금씩 보도록 하겠습니다.
# initializing
app = Flask(__name__)
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
먼저 필요한 것들을 초기화하는 부분입니다.
app = Flask(__name__) 을 통해서 flask를 먼저 초기화 해주었습니다.
다음은 imagenet 데이터의 index와 label을 json 파일로부터 로딩해줍니다.
저희는 서빙을 중점으로 볼 예정으로 모델은 pretrain 모델을 가지고 왔습니다.
인퍼런스만 해줄 예정이므로 model.eval()을 실행하여 평가 모드로 모델을 변경합니다.
(eval을 실행해서 dropout이나 batchnorm과 같은 evaluation 과정에서는 사용하지 않아야 하는 layer를 off해줍니다.)
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
transform_image는 이미지를 전처리 해주는 함수입니다.
이미지 전처리 코드는 학습과 서빙에서 동일한 로직을 활용해야 학습에서의 성능을 서빙에서도 보장할 수 있습니다.
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
get_prediction 함수는 inference를 수행하고, json파일로 로딩해 온 정보를 활용하여 최종 라벨을 반환하는 함수입니다.
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
이제 flask를 활용해서 서버를 구성하는 부분입니다.
post만 받도록 구성하고, file을 입력받아서 get_prediction 함수를 수행한 결과를 반환하게만 해주었습니다.
이제 app.py 파일을 실행해보면 localhost:5000 주소로 서버가 실행됩니다.
테스트
import requests
response = requests.post("http://localhost:5000/predict", files={"file": open('cat.jpeg','rb')})
print(response.json())
서버가 실제로 잘 동작하는지 테스트를 해보기 위해서 테스트 코드도 준비했습니다.
requests 라이브러리를 활용하여 cat.jpeg 파일을 보내서 결과를 프린트 해주었습니다.
위의 코드를 실행해보면 아래와 같은 결과를 볼 수 있으면 성공입니다!
'ML OPS > Inference & Serving' 카테고리의 다른 글
[MLOps, LLMOps] In-flight batching (0) | 2023.12.24 |
---|---|
[MLOps] BentoML - Adaptive Batching (0) | 2023.03.12 |
[ML OPS] quantization을 활용한 인퍼런스 최적화 (ft. ONNX, TensorRT) (0) | 2022.07.23 |
[ML OPS] transformers inference (ft. colab, onnx, gpu) (0) | 2022.07.10 |
[ML OPS] ONNX로 pytorch 모델 변환하기 (0) | 2022.02.15 |