Post

PyTorch DataLoader에서 Progress Bar 나타내기





문제

PyTorch를 사용하여 한 번의 epoch에 대해 학습을 진행할 때 DataLoader 객체를 순회하게 됩니다. 그 때 tqdm을 이용하여 progress bar를 표시하는 경우가 있는데 간혹 제대로 표시가 안되는 경우가 있습니다. 아래 코드를 보겠습니다. 아래 코드는 tqdm 이슈란을 살펴보다가 가져오게 되었습니다. [Ref]

1
2
for i_batch, feed_dict in tqdm.tqdm(enumerate(dataloader)):
    sleep(0.01)

해결

굉장히 일반적인 방법인데 이렇게 tqdm으로 enumerate를 감싸는 경우 progress bar가 올바르게 표시가 안됩니다. 시간과 초당 반복 횟수만 표시가 됩니다. Progress bar를 제대로 표시하려면 tqdmenumerate를 사용하는 순서를 바꾸면 됩니다.

1
2
for i_batch, feed_dict in enumerate(tqdm.tqdm(dataloader)):
    sleep(0.01)

해당 이슈 게시물을 보면 tqdm의 문제가 아니고 단순히 enumerate의 기능상 특징 때문이라고 합니다. DataLoader 객체의 경우 __len__()가 있지만 enumerate__len__()을 갖지 않기 때문에 enumeratetqdm으로 감싸게 되면 전체 길이에 대한 정보가 없어 progress bar를 표시하지 못하게 됩니다.

더 나아가서

이런 문제는 zip()을 사용할 때도 비슷합니다. zip() 역시 __len__()이 없기 때문에 tqdm을 이용해서 progress bar를 표시하려면 tqdm을 먼저 사용한 다음 zip()으로 감싸줘야 합니다.

1
2
for e1, e2 in zip(tqdm.tqdm(iter1), iter2):
    ...


This post is licensed under CC BY 4.0 by the author.