#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_pp_doclayout_v3.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
import torch
import torchvision.transforms.v2.functional as tvF

from ...image_processing_utils_fast import BaseImageProcessorFast, BatchFeature
from ...image_transforms import group_images_by_shape, reorder_images
from ...image_utils import PILImageResampling, SizeDict
from ...utils import auto_docstring, is_cv2_available, requires_backends
from ...utils.generic import TensorType


if is_cv2_available():
    import cv2


@auto_docstring
class PPDocLayoutV3ImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BICUBIC
    # Mean 0 / std 1 makes normalization an identity; only rescaling changes pixel values.
    image_mean = [0, 0, 0]
    image_std = [1, 1, 1]
    size = {"height": 800, "width": 800}
    do_resize = True
    do_rescale = True
    do_normalize = True

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    # We require `self.resize(..., antialias=False)` to approximate the output of `cv2.resize`
    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["tvF.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | list[float] | None,
        image_std: float | list[float] | None,
        do_pad: bool | None,
        pad_size: SizeDict | None,
        disable_grouping: bool | None,
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> BatchFeature:
        """
        Resizes (without antialiasing), optionally center-crops, rescales/normalizes and optionally pads a batch of
        images, returning them under the `"pixel_values"` key of a `BatchFeature`.
        """
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(
                    image=stacked_images, size=size, interpolation=interpolation, antialias=False
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        if do_pad:
            processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

    def _get_order_seqs(self, order_logits):
        """
        Computes the order sequences for a batch of inputs based on logits.

        This function takes in the order logits, calculates order scores using a sigmoid activation,
        and determines the order sequences by ranking the votes derived from the scores.

        Args:
            order_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_queries)`):
                Stacked order logits.

        Returns:
            torch.Tensor: A tensor of shape `(batch_size, num_queries)`:
                `order_seq[b, q]` is the rank of query `q` in image `b` (0 = smallest vote).
        """
        order_scores = torch.sigmoid(order_logits)
        batch_size, sequence_length, _ = order_scores.shape

        # For each query j: sum of scores[i, j] over i < j plus sum of (1 - scores[j, i]) over i > j.
        # NOTE(review): presumably scores[i, j] is the probability that i precedes j, making this the expected
        # number of predecessors of j — confirm against the model definition.
        order_votes = order_scores.triu(diagonal=1).sum(dim=1) + (1.0 - order_scores.transpose(1, 2)).tril(
            diagonal=-1
        ).sum(dim=1)

        # argsort gives query indices sorted by vote; scattering 0..N-1 back inverts that permutation into ranks.
        order_pointers = torch.argsort(order_votes, dim=1)
        order_seq = torch.empty_like(order_pointers)
        ranks = torch.arange(sequence_length, device=order_pointers.device, dtype=order_pointers.dtype).expand(
            batch_size, -1
        )
        order_seq.scatter_(1, order_pointers, ranks)

        return order_seq

    def extract_custom_vertices(self, polygon, sharp_angle_thresh=45):
        """
        Filters polygon vertices by turning direction and moves vertices whose angle is close to a target value.

        For every vertex, the 2D cross product of the edges pointing to its previous and next neighbours is computed
        (indices wrap around). Only vertices with a negative cross product are kept; the others are dropped. A kept
        vertex whose angle between those two edges lies within 1 degree of `sharp_angle_thresh` is replaced by a
        point moved along the sum of the two unit edge vectors, by the mean of the two edge lengths.

        Args:
            polygon (array-like of shape `(num_points, 2)`):
                Ordered polygon vertices.
            sharp_angle_thresh (`float`, *optional*, defaults to 45):
                Angle in degrees that triggers moving a vertex.

        Returns:
            `list[tuple]`: The kept (possibly moved) vertices, in input order.
        """
        points = np.array(polygon)
        num_points = len(points)
        vertices = []
        for i in range(num_points):
            previous_point = points[(i - 1) % num_points]
            current_point = points[i]
            next_point = points[(i + 1) % num_points]
            vector_1 = previous_point - current_point
            vector_2 = next_point - current_point
            cross_product_value = (vector_1[1] * vector_2[0]) - (vector_1[0] * vector_2[1])
            # A non-zero cross product implies both edge vectors are non-zero, so the norms below are safe.
            if cross_product_value >= 0:
                continue
            norm_1 = np.linalg.norm(vector_1)
            norm_2 = np.linalg.norm(vector_2)
            angle_cos = np.clip((vector_1 @ vector_2) / (norm_1 * norm_2), -1.0, 1.0)
            angle = np.degrees(np.arccos(angle_cos))
            if abs(angle - sharp_angle_thresh) < 1:
                # Calculate the new point based on the direction of two vectors.
                dir_vec = vector_1 / norm_1 + vector_2 / norm_2
                dir_vec = dir_vec / np.linalg.norm(dir_vec)
                step_size = (norm_1 + norm_2) / 2
                vertices.append(tuple(current_point + dir_vec * step_size))
            else:
                vertices.append(tuple(current_point))
        return vertices

    def _mask2polygon(self, mask, epsilon_ratio=0.004):
        """
        Converts a binary mask into a simplified polygon.

        The largest external contour of `mask` (by `cv2.contourArea`) is simplified with `cv2.approxPolyDP`, using a
        tolerance of `epsilon_ratio` times the contour perimeter, and the resulting vertices are filtered by
        `extract_custom_vertices`.

        Args:
            mask (`np.ndarray` of shape `(height, width)`):
                Binary mask accepted by `cv2.findContours` (callers in this class pass `uint8`).
            epsilon_ratio (`float`, *optional*, defaults to 0.004):
                Simplification tolerance as a fraction of the contour perimeter.

        Returns:
            `list[tuple] | None`: Polygon vertices in mask pixel coordinates, or `None` if the mask has no contour.
        """
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if not contours:
            return None

        largest_contour = max(contours, key=cv2.contourArea)
        epsilon = epsilon_ratio * cv2.arcLength(largest_contour, True)
        approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)
        # squeeze drops the OpenCV singleton axis; atleast_2d restores (1, 2) for single-point results.
        polygon_points = np.atleast_2d(approx_contour.squeeze())

        return self.extract_custom_vertices(polygon_points)

    def _extract_polygon_points_by_masks(self, boxes, masks, scale_ratio):
        """
        Extracts one polygon per box from its instance mask, falling back to the box rectangle when no usable polygon
        can be derived.

        Args:
            boxes (`np.ndarray` of shape `(num_boxes, 4)`):
                Boxes as `(x_min, y_min, x_max, y_max)` in target-image pixels.
            masks (`np.ndarray` of shape `(num_boxes, mask_height, mask_width)`):
                Binary instance masks aligned with `boxes`.
            scale_ratio (sequence of 2 numbers):
                `[model_input_width / target_width, model_input_height / target_height]`. Box coordinates are
                multiplied by `scale_ratio / 4` to index the masks. NOTE(review): the `/ 4` presumably reflects masks
                predicted at a quarter of the model input resolution — confirm against the model.

        Returns:
            `list[np.ndarray]`: One `(num_vertices, 2)` array per box, in target-image pixel coordinates.
        """
        scale_width, scale_height = scale_ratio[0] / 4, scale_ratio[1] / 4
        mask_height, mask_width = masks.shape[1:]
        polygon_points = []

        for box, mask in zip(boxes, masks):
            x_min, y_min, x_max, y_max = box.astype(np.int32)
            box_w, box_h = x_max - x_min, y_max - y_min

            # Default polygon: the box rectangle itself.
            rect = np.array(
                [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]],
                dtype=np.float32,
            )

            if box_w <= 0 or box_h <= 0:
                polygon_points.append(rect)
                continue

            # Crop the part of the mask covered by the box, clamped to the mask bounds.
            x_coordinates = [int(round((x_min * scale_width).item())), int(round((x_max * scale_width).item()))]
            x_start, x_end = np.clip(x_coordinates, 0, mask_width)
            y_coordinates = [int(round((y_min * scale_height).item())), int(round((y_max * scale_height).item()))]
            y_start, y_end = np.clip(y_coordinates, 0, mask_height)
            cropped_mask = mask[y_start:y_end, x_start:x_end]

            # A very thin box, or one lying outside the mask, can round to an empty crop; cv2.resize rejects empty
            # inputs, so use the rectangle instead.
            if cropped_mask.size == 0:
                polygon_points.append(rect)
                continue

            # Resize the crop to the box size so polygon coordinates are in box-local pixels.
            resized_mask = cv2.resize(cropped_mask.astype(np.uint8), (box_w, box_h), interpolation=cv2.INTER_NEAREST)

            polygon = self._mask2polygon(resized_mask)
            # No contour at all, or too few vertices for a meaningful shape: fall back to the rectangle.
            if polygon is None or len(polygon) < 4:
                polygon_points.append(rect)
                continue

            # Translate from box-local coordinates to target-image coordinates.
            polygon_points.append(np.asarray(polygon) + np.array([x_min, y_min]))

        return polygon_points

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.5,
        target_sizes: TensorType | list[tuple] | None = None,
    ):
        """
        Converts the raw output of [`PPDocLayoutV3ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format. Only supports PyTorch.

        Args:
            outputs ([`DetrObjectDetectionOutput`]):
                Raw outputs of the model; `pred_boxes`, `logits`, `order_logits` and `out_masks` are read.
            threshold (`float`, *optional*, defaults to 0.5):
                Score threshold for keeping detections. The same value is also used to binarize the sigmoid of the
                predicted masks.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `list[tuple[int, int]]`):
                `(height, width)` of each image in the batch. Required, because boxes are rescaled to and polygons
                are extracted in absolute pixel coordinates.

        Returns:
            `list[Dict]`: A list of dictionaries, one per image, with keys `scores`, `labels`, `boxes`,
            `polygon_points` and `order_seq`. Detections are sorted by their predicted order rank.

        Raises:
            ValueError: If `target_sizes` is missing or its length differs from the batch size.
        """
        requires_backends(self, ["torch", "cv2"])
        boxes = outputs.pred_boxes
        logits = outputs.logits
        masks = outputs.out_masks

        if target_sizes is None:
            raise ValueError(
                "`target_sizes` is required: boxes and polygon points are computed in absolute pixel coordinates."
            )
        if len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

        order_seqs = self._get_order_seqs(outputs.order_logits)

        # (center_x, center_y, width, height) -> (x_min, y_min, x_max, y_max)
        box_centers, box_dims = torch.split(boxes, 2, dim=-1)
        boxes = torch.cat([box_centers - 0.5 * box_dims, box_centers + 0.5 * box_dims], dim=-1)

        # Scale relative boxes to absolute pixels of each target image.
        if isinstance(target_sizes, list):
            img_height, img_width = torch.as_tensor(target_sizes).unbind(1)
        else:
            img_height, img_width = target_sizes.unbind(1)
        scale_factor = torch.stack([img_width, img_height, img_width, img_height], dim=1).to(boxes.device)
        boxes = boxes * scale_factor[:, None, :]

        num_top_queries = logits.shape[1]
        num_classes = logits.shape[2]

        # Top-k over the flattened (query, class) grid; decode class and query index from the flat index.
        scores = torch.nn.functional.sigmoid(logits)
        scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1)
        labels = index % num_classes
        index = index // num_classes
        boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
        masks = masks.gather(
            dim=1, index=index.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, masks.shape[-2], masks.shape[-1])
        )
        masks = (masks.sigmoid() > threshold).int()
        order_seqs = order_seqs.gather(dim=1, index=index)

        results = []
        for score, label, box, order_seq, target_size, mask in zip(
            scores, labels, boxes, order_seqs, target_sizes, masks
        ):
            keep = score >= threshold
            # Sort the kept detections by their order rank.
            order_seq, indices = torch.sort(order_seq[keep])
            kept_boxes = box[keep][indices]
            polygon_points = self._extract_polygon_points_by_masks(
                kept_boxes.detach().cpu().numpy(),
                mask[keep][indices].detach().cpu().numpy(),
                [self.size["width"] / target_size[1], self.size["height"] / target_size[0]],
            )
            results.append(
                {
                    "scores": score[keep][indices],
                    "labels": label[keep][indices],
                    "boxes": kept_boxes,
                    "polygon_points": polygon_points,
                    "order_seq": order_seq,
                }
            )

        return results


# Public API of this module (names exported by `from ... import *`).
__all__ = ["PPDocLayoutV3ImageProcessorFast"]
