from qmdl.mdl import Mdl, SEAM_FLAG
from qmdl.palette import palette
from collections import defaultdict


class Helper:
    """
    Convenience wrapper around a Mdl with methods for common editing tasks.

    Attributes:
    mdl -- the Mdl being edited
    filename -- the default path used by save(); set by load()

    Most editing methods return self so calls can be chained, e.g.
    Helper().load("in.mdl").merge_vertices().save("out.mdl")
    """
    def __init__(self):
        self.mdl = Mdl()
        self.filename = "model.mdl"

    def load(self, filename):
        """
        Read a model from filename and make it the current model.

        The file is parsed into a fresh Mdl which only replaces self.mdl once
        reading has succeeded, so a failed load leaves the helper unchanged.
        Returns self.
        """
        mdl = Mdl()
        with open(filename, "rb") as file:
            mdl.read(file)
        self.mdl = mdl
        self.filename = filename
        return self

    def save(self, filename=None):
        """
        Write the current model to filename (default: self.filename).

        Before writing, the header is recalculated and every frame's bounds
        are recomputed so the written data is self-consistent. Returns self.
        """
        # perform operations to make mdl data consistent
        self.mdl.recalculate_header()
        for f in self.mdl.frames:
            f.calculate_bounds()
        if filename is None:
            filename = self.filename
        with open(filename, "wb") as file:
            self.mdl.write(file)
        return self

    @staticmethod
    def rename_range(frames, start, name, length):
        """
        Sets sequential names on a run of the given frames.

        Frames are named name1, name2, ... as ASCII bytes.

        Parameters:
        frames -- a list of frame objects
        start -- the index of the first frame to rename
        name -- the root name to which numbers are appended (must be ASCII)
        length -- the number of frames to rename
        """
        for i, frame in enumerate(frames[start:(start + length)], 1):
            frame.name = (name + str(i)).encode("ascii")

    def rename_frames(self, names):
        """
        Set sequential names for the frames in the model.

        The names parameter should be a list of (name, length) tuples:
        name -- the root name of this frame sequence
        length -- the number of frames in this sequence

        Renaming starts at frame 0, and each sequence starts on the frame one
        past the end of the previous sequence. The intended use is to rename
        every sequence in the model at once, so the entries in names should
        cover every frame. Returns self.
        """
        start = 0
        for root, length in names:
            Helper.rename_range(self.mdl.frames, start, root, length)
            start += length
        return self

    def group_frames(self, start, end, sequence_name, duration=0.1):
        """
        Takes the frames from start to end inclusive and makes them a framegroup.

        Parameters:
        start -- the index of the first frame to include in the framegroup
        end -- the index of the last frame to include in the framegroup
               (inclusive; negative indices count from the end of the list)
        sequence_name -- what to call the new framegroup (must be ASCII)
        duration -- how long each frame in the framegroup will last

        Raises ValueError if the range selects no frames.

        Note that this function modifies the frames list: the new framegroup
        replaces the frames that previously sat between start and end. Bear
        this in mind when calling the function several times in a row, because
        all the frames after the first framegroup move down the list and get
        new indices. One way to make this simpler is to start with the final
        framegroup in your model and work your way back to the first one. That
        way you have already finished with the frames which move position.
        Returns self.
        """
        frames = self.mdl.frames
        count = len(frames)
        # Allow Python-style negative indices, e.g. end=-1 for the last frame.
        if start < 0:
            start += count
        if end < 0:
            end += count
        # Slices exclude their end index, so step one past the last frame.
        stop = end + 1
        grouped = frames[start:stop]
        if not grouped:
            raise ValueError(
                "no frames selected between {} and {}".format(start, end))
        frame_group = Mdl.FrameGroup()
        frame_group.name = sequence_name.encode("ascii")
        # Durations are cumulative: one entry per grouped frame.
        frame_group.duration = [(1 + x) * duration
                                for x in range(len(grouped))]
        frame_group.frames = grouped
        self.mdl.frames = frames[:start] + [frame_group] + frames[stop:]
        return self

    def remove_frames(self, removal_list):
        """
        Delete the frames at the given indices.

        Indices may be negative (counted from the end) and may repeat. Indices
        that refer to the same frame, such as -1 and len - 1, remove it only
        once. All indices are validated before anything is deleted, so an
        out-of-range index raises IndexError and leaves the frames untouched.
        Returns self.
        """
        frames = self.mdl.frames
        count = len(frames)
        normalized = set()
        for index in removal_list:
            if not -count <= index < count:
                raise IndexError("frame index out of range: {}".format(index))
            normalized.add(index % count)
        # Delete from the highest index down so lower indices stay valid.
        for index in sorted(normalized, reverse=True):
            del frames[index]
        return self

    def merge_vertices(self, ignore_normals=False):
        """
        Finds and merges all vertices which can be safely combined.

        By default this function only merges vertices when doing so will not
        change the appearance of the model in any way. It will try to create
        seam vertices where there are corresponding left and right hand
        vertices. Vertices can be merged when:

        * They occupy the same position in every frame
        * They have mergeable UV coordinates
        * They have the same normal vector in each frame

        The ignore_normals parameter lets the function skip the third
        condition. Merged vertices take the average vertex normal. This
        removes smoothing groups and erases some smoothing errors.

        Mergeable UV coordinates cover two cases. In the first, the vertices
        have exactly the same coordinates on the skin. In the second, they are
        half a skin-width apart, provided each is connected only to suitably
        facing triangles: front faces on the left and back faces on the right.
        Returns self.
        """

        # We build a set of normals to merge later on, to avoid giving extra
        # weight to any vertex through repeated averaging.
        # Maps surviving vertex index -> indices of vertices merged into it.
        normals_to_average = defaultdict(set)

        def merge_vertex_normals(v_from, v_to):
            # Carry over anything previously merged into v_from so the final
            # average covers each original vertex exactly once.
            if v_from in normals_to_average:
                normals_to_average[v_to].update(normals_to_average[v_from])
                del normals_to_average[v_from]
            normals_to_average[v_to].add(v_from)

        def average_normals(frame):
            # Sum the decoded vectors eagerly. Chaining lazy map/zip iterators
            # would nest one level deeper for every merged vertex.
            # NOTE(review): assumes decode() yields the vertex normal as a
            # 3-vector and encode() stores one back; confirm in qmdl.mdl.
            for v_to, v_from in normals_to_average.items():
                norm = (0, 0, 0)
                for v_index in list(v_from) + [v_to]:
                    vec = frame.vertices[v_index].decode()
                    norm = tuple(a + b for a, b in zip(vec, norm))
                frame.vertices[v_to].encode(norm)

        # The tri_map lets us go quickly from a vertex index to joined
        # triangles. A set avoids duplicate entries for degenerate triangles
        # and makes merging two entries a simple update.
        tri_map = defaultdict(set)
        for triangle in self.mdl.triangles:
            for vertex in triangle.vertices:
                tri_map[vertex].add(triangle)

        # While we're working we only "merge" vertices by changing vertex
        # indices on the triangles. The result is a valid model with a number
        # of vertices which are no longer joined to any triangles.
        # This set tracks these orphaned vertices. At the end we delete those
        # entries from the vertex lists on the model and on each frame, and
        # renumber the triangle indices.
        removed_vertices = set()

        def merge_vertices_on_triangles(merge_from, merge_to):
            # Repoint every triangle using a vertex in merge_from at merge_to.
            for vert in merge_from:
                for tri in tri_map[vert]:
                    tri.vertices = tuple(merge_to if x == vert else x
                                         for x in tri.vertices)
                tri_map[merge_to].update(tri_map[vert])
                merge_vertex_normals(vert, merge_to)
                del tri_map[vert]
            removed_vertices.update(merge_from)

        # Multiplying by zero makes every normal compare equal.
        normflag = 0 if ignore_normals else 1

        # We go frame by frame, each time building a dictionary of lists.
        # Each list holds vertices which can be merged in all frames so far.
        # The key combines the vertex's previous group with its position (and
        # normal) in the current frame. Vertices with the same key can still
        # be merged, so they go into the same list.
        # We start with all vertices in one list: before looking at any frame
        # we have no evidence that they can't be merged.
        current_merge_lists = {"only": range(len(self.mdl.vertices))}
        for frame in self.mdl.basic_frames():
            prior_merge_lists = current_merge_lists
            current_merge_lists = defaultdict(list)
            for group_id, merge_list in enumerate(prior_merge_lists.values()):
                for v_index in merge_list:
                    coord = frame.vertices[v_index]
                    key = (coord.position, coord.normal * normflag, group_id)
                    current_merge_lists[key].append(v_index)

        # To create new seam vertices, every triangle must report correctly
        # whether it is backfacing.
        self.mdl.calculate_facings()

        half_width = self.mdl.skinwidth // 2

        for merge_list in current_merge_lists.values():
            uv_matches = defaultdict(set)
            # Group mergeable vertices by their UV coordinates. In this first
            # pass onseam vertices only merge with onseam ones, and offseam
            # with offseam, so the seam status is part of the key.
            for v_index in merge_list:
                vertex = self.mdl.vertices[v_index]
                key = (vertex.u, vertex.v, bool(vertex.onseam))
                uv_matches[key].add(v_index)
            # Merge each group down to its lowest-numbered vertex. Afterwards
            # each key in uv_matches maps directly to that vertex index.
            for match_key, match_set in uv_matches.items():
                keep = min(match_set)
                if len(match_set) > 1:
                    merge_vertices_on_triangles(match_set - {keep}, keep)
                uv_matches[match_key] = keep
            # Now try to create seam vertices from pairs of offseam vertices.
            for match_key in uv_matches:
                u, v, onseam = match_key
                # Only right-hand, offseam vertices start a pairing.
                if onseam or u < half_width:
                    continue
                # Build the key for the corresponding left-hand vertex.
                left_key = (u - half_width, v, False)
                if left_key not in uv_matches:
                    continue
                # Any triangle containing the left-hand vertex must be a
                # front-facing triangle, because calculate_facings marks mixed
                # triangles as FACE_FRONT. So we only need to check that the
                # triangles joined to the right-hand vertex are all FACE_BACK.
                right_vertex = uv_matches[match_key]
                tri_set = tri_map[right_vertex]
                # backface is non-zero for front-facing triangles
                if any(t.backface for t in tri_set):
                    continue
                left_vertex = uv_matches[left_key]
                # Merge right into left because left already has the correct
                # UV coordinates.
                merge_vertices_on_triangles([right_vertex], left_vertex)
                self.mdl.vertices[left_vertex].onseam = SEAM_FLAG
                # The new seam vertex may now merge with a vertex at the same
                # UV that was already onseam, which the first pass skipped.
                onseam_key = (u - half_width, v, True)
                if onseam_key not in uv_matches:
                    continue
                merge_vertices_on_triangles([left_vertex],
                                            uv_matches[onseam_key])

        # Now tidy up the orphaned vertices and renumber the triangle vertex
        # indices to match. vertex_remap[old_index] is the new index, or None
        # if the vertex was removed.
        vertex_remap = []
        kept = 0
        for vert in range(len(self.mdl.vertices)):
            if vert in removed_vertices:
                vertex_remap.append(None)
            else:
                vertex_remap.append(kept)
                kept += 1

        self.mdl.vertices = [vert for i, vert in enumerate(self.mdl.vertices)
                             if i not in removed_vertices]

        for frame in self.mdl.basic_frames():
            # Average before dropping vertices: normals_to_average is keyed
            # by the original (pre-removal) indices.
            average_normals(frame)
            frame.vertices = [vert for i, vert in enumerate(frame.vertices)
                              if i not in removed_vertices]

        for tri in self.mdl.triangles:
            tri.vertices = tuple(vertex_remap[v] for v in tri.vertices)
        return self

    def append_skin(self, filename):
        """
        Import an image from the given filename and append it as a skin.

        This method uses Pillow (PIL) and will open any file format it
        supports. If the loaded image uses an indexed palette, it is assumed
        to already be in the Quake palette and no conversion is done. Any
        other mode is converted to RGB and quantized to the Quake palette.
        That conversion may not be pretty, so for best results convert the
        image externally.

        If the model has no skins yet, its skinwidth and skinheight are taken
        from the image. Otherwise the image must match the existing skin size,
        or ValueError is raised. Returns self.
        """
        from PIL import Image

        with Image.open(filename) as im:
            # update the model's skinwidth and height if this is the first skin
            if not self.mdl.skins:
                self.mdl.skinwidth, self.mdl.skinheight = im.size
            size = (self.mdl.skinwidth, self.mdl.skinheight)
            if tuple(im.size) != size:
                raise ValueError(
                    "skin image is {}x{} but the model's skins are {}x{}"
                    .format(im.size[0], im.size[1], size[0], size[1]))

            if im.mode != "P":
                rgb = im.convert("RGB")
                palette_image = Image.new("P", size)
                palette_image.putpalette(palette)
                pim = rgb.quantize(palette=palette_image)
            else:
                pim = im
            # Read the pixels while the file is still open; Image.tostring()
            # no longer exists in Pillow, tobytes() is its replacement.
            pixels = pim.tobytes()

        sk = Mdl.Skin(self.mdl.skinwidth, self.mdl.skinheight)
        sk.pixels = pixels
        self.mdl.skins.append(sk)
        return self