Skip to content

Commit 5b4d72c

Browse files
Add an alternative scenario to EoMT post_process_semantic_segmentation in case path_offsets is None (#42716)
* Add an alternative scenario in case patch_offsets is None * Fixup * Fix an error * Simplified the function --------- Co-authored-by: Yoni Gozlan <[email protected]>
1 parent 3f3cae7 commit 5b4d72c

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

src/transformers/models/eomt/image_processing_eomt.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,19 @@ def post_process_semantic_segmentation(
815815

816816
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
817817

818-
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
818+
if patch_offsets:
819+
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
820+
else:
821+
output_logits = []
822+
823+
for idx in range(len(segmentation_logits)):
824+
resized_logits = torch.nn.functional.interpolate(
825+
segmentation_logits[idx].unsqueeze(dim=0),
826+
size=target_sizes[idx],
827+
mode="bilinear",
828+
align_corners=False,
829+
)
830+
output_logits.append(resized_logits[0])
819831

820832
preds = [logit.argmax(dim=0) for logit in output_logits]
821833
return preds

src/transformers/models/eomt/image_processing_eomt_fast.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,19 @@ def post_process_semantic_segmentation(
385385

386386
segmentation_logits = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
387387

388-
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
388+
if patch_offsets:
389+
output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size)
390+
else:
391+
output_logits = []
392+
393+
for idx in range(len(segmentation_logits)):
394+
resized_logits = torch.nn.functional.interpolate(
395+
segmentation_logits[idx].unsqueeze(dim=0),
396+
size=target_sizes[idx],
397+
mode="bilinear",
398+
align_corners=False,
399+
)
400+
output_logits.append(resized_logits[0])
389401

390402
preds = [logit.argmax(dim=0) for logit in output_logits]
391403
return preds

0 commit comments

Comments
 (0)