import struct
import math


def _struct_read(s, file):
    """Read one record described by the precompiled struct *s* from *file*.

    Exactly ``s.size`` bytes are requested from the file and handed to
    ``s.unpack``; the unpacked tuple is returned.  Centralising this keeps
    the many record-reading loops tidy.
    """
    chunk = file.read(s.size)
    return s.unpack(chunk)


class ParseError(Exception):
    """Raised when the data being read is not a valid MD3 model.

    Derives from ``Exception`` rather than ``BaseException`` so that normal
    ``except Exception`` handlers catch malformed-file errors, instead of the
    error escaping them like ``KeyboardInterrupt`` or ``SystemExit`` would.
    """


class Md3:
    """In-memory representation of an MD3 model (ident ``IDP3``, version 15).

    ``read`` fills the object from a binary file object and ``write``
    serializes it back.  The nested classes mirror the on-disk records and
    each reads/writes itself with one of the precompiled structs below.
    The ``read*``/``write*`` methods return ``self`` so calls can be chained.
    """

    # STRUCTS
    # A bunch of static objects to avoid repeated construction during loops
    header_struct = struct.Struct("<4sl64s9l")
    frame_struct = struct.Struct("<10f16s")
    tag_struct = struct.Struct("<64s12f")
    triangle_struct = struct.Struct("<3l")
    shader_struct = struct.Struct("<64sl")
    texcoord_struct = struct.Struct("<2f")
    vertex_struct = struct.Struct("<3hH")
    surface_struct = struct.Struct("<4s64s10l")

    # Vertex positions are stored as int16 ("h" in vertex_struct); multiplying
    # them by this factor gives the float units used for frame bounds.
    MODEL_SCALE = 1 / 64

    # INIT
    # Fill in defaults everywhere to avoid AttributeError and create any lists
    def __init__(self):
        # initialise all the exported properties to minimal defaults
        self.ident = b"IDP3"
        self.version = 15
        self.name = b""
        self.flags = 0
        self.frames = []
        self.tags = []
        self.surfaces = []

    # HEADER
    # This is stored against the body of the Md3 object
    # a) because you can only have a single header per model and
    # b) because quantities like num_frames only make sense in terms of an
    # actual collection of frames belonging to the model
    # Note: private fields like _num_frames are only used to store the lengths
    # of field parsed from the header of a model file we are reading
    # Any code not reading from a binary file should get the size of collections
    # from len(self.frames) etc. as per the write_header function below
    def read_header(self, file):
        """Read and validate the model header at the current file position.

        Counts and offsets are kept in private ``_num_*``/``_ofs_*`` fields;
        the other ``read_*`` methods seek to those offsets directly.

        Raises:
            ParseError: if the ident is not ``IDP3`` or the version is not 15.
        """
        (ident, version,
         self.name, self.flags,
         self._num_frames,
         self._num_tags,
         self._num_surfaces,
         self._num_skins,
         self._ofs_frames,
         self._ofs_tags,
         self._ofs_surfaces,
         self._ofs_eof) = _struct_read(Md3.header_struct, file)
        # run some basic checks that we have been given a md3 file
        if ident != b"IDP3":
            raise ParseError("Ident does not match IDP3")
        if version != 15:
            raise ParseError("MD3 file is not version 15")
        return self

    def write_header(self, file, ofs_frames=0, ofs_tags=0,
                     ofs_surfaces=0, ofs_eof=0):
        """Write the model header at the current file position.

        Counts come from the lengths of the frame/tag/surface lists, the
        skin count is always written as 0, and the offsets are written as
        given.
        """
        file.write(Md3.header_struct.pack(
            b"IDP3", 15,
            self.name, self.flags,
            len(self.frames),
            len(self.tags),
            len(self.surfaces),
            0,
            ofs_frames,
            ofs_tags,
            ofs_surfaces,
            ofs_eof,
        ))
        return self

    class Frame:
        """One frame record: bounding box, local origin, radius and name."""

        def __init__(self):
            self.origin = (0.0, 0.0, 0.0)
            self.name = b"nameless"

        def read(self, file):
            """Read one frame record; bounds and radius are kept privately."""
            data = _struct_read(Md3.frame_struct, file)
            self._mins = data[0:3]
            self._maxs = data[3:6]
            self.origin = data[6:9]
            self._radius = data[9]
            self.name = data[10]
            return self

        def write(self, file, mins=(0, 0, 0), maxs=(0, 0, 0), radius=None):
            """Write this frame with the bounding box (*mins*, *maxs*).

            The radius written is, in order of preference: the *radius*
            argument, the radius loaded by ``read``, or the distance from
            ``origin`` to the farthest corner of the box.  The last fallback
            means frames built in code (which never get ``_radius``) can be
            written without an AttributeError.
            """
            if radius is None:
                radius = getattr(self, "_radius", None)
            if radius is None:
                # farthest corner of the box, chosen independently per axis
                radius = math.sqrt(sum(
                    max(abs(lo - o), abs(hi - o)) ** 2
                    for lo, hi, o in zip(mins, maxs, self.origin)))
            file.write(Md3.frame_struct.pack(
                mins[0], mins[1], mins[2],
                maxs[0], maxs[1], maxs[2],
                self.origin[0], self.origin[1], self.origin[2],
                radius, self.name))
            return self

    def read_frames(self, file):
        """Seek to the frame table and read ``_num_frames`` frames."""
        file.seek(self._ofs_frames)
        for _ in range(self._num_frames):
            self.frames.append(Md3.Frame().read(file))
        return self

    def _frame_bounds(self, index):
        """Return (mins, maxs) of frame *index* across all surfaces.

        Raw vertex positions are scaled by MODEL_SCALE.  Surfaces with no
        vertices for this frame are skipped; if none contribute, both bounds
        fall back to (0, 0, 0), the same defaults Frame.write uses.
        """
        lows = [b for b in (s.get_frame_mins(index) for s in self.surfaces) if b]
        highs = [b for b in (s.get_frame_maxs(index) for s in self.surfaces) if b]
        if not lows:
            return (0, 0, 0), (0, 0, 0)
        # zip(*...) transposes per-surface (x, y, z) tuples into per-axis columns
        mins = tuple(min(axis) * Md3.MODEL_SCALE for axis in zip(*lows))
        maxs = tuple(max(axis) * Md3.MODEL_SCALE for axis in zip(*highs))
        return mins, maxs

    def write_frames(self, file):
        """Write every frame, with bounds recomputed from the surfaces."""
        for index, frame in enumerate(self.frames):
            mins, maxs = self._frame_bounds(index)
            frame.write(file, mins, maxs)
        return self

    class Tag:
        """A named attachment point: origin plus a 3x3 axis."""

        def __init__(self):
            self.name = b""
            self.origin = (0.0, 0.0, 0.0)
            self.axis = ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))

        def read(self, file):
            data = _struct_read(Md3.tag_struct, file)
            self.name = data[0]
            self.origin = data[1:4]
            self.axis = (data[4:7], data[7:10], data[10:13])
            return self

        def write(self, file):
            # star-unpacking accepts lists as well as tuples for origin/axis
            data = (self.name, *self.origin,
                    *self.axis[0], *self.axis[1], *self.axis[2])
            file.write(Md3.tag_struct.pack(*data))
            return self

    def read_tags(self, file):
        """Seek to the tag table and read ``_num_tags`` tags."""
        file.seek(self._ofs_tags)
        for _ in range(self._num_tags):
            self.tags.append(Md3.Tag().read(file))
        return self

    def write_tags(self, file):
        """Write every tag at the current file position."""
        for tag in self.tags:
            tag.write(file)
        return self

    class Triangle:
        """Three vertex indices."""

        def __init__(self):
            self.vertices = (0, 0, 0)

        def read(self, file):
            self.vertices = _struct_read(Md3.triangle_struct, file)
            return self

        def write(self, file):
            file.write(Md3.triangle_struct.pack(*self.vertices))
            return self

    class Shader:
        """A shader name and its index."""

        def __init__(self):
            self.name = b""
            self.index = 0

        def read(self, file):
            (self.name, self.index) = _struct_read(Md3.shader_struct, file)
            return self

        def write(self, file):
            file.write(Md3.shader_struct.pack(self.name, self.index))
            return self

    class Texcoord:
        """A (u, v) texture coordinate pair."""

        def __init__(self):
            self.u, self.v = (0, 0)

        def read(self, file):
            self.u, self.v = _struct_read(Md3.texcoord_struct, file)
            return self

        def write(self, file):
            file.write(Md3.texcoord_struct.pack(self.u, self.v))
            return self

    class Vertex:
        """An int16 position plus a 16-bit packed normal."""

        def __init__(self):
            self.position = (0, 0, 0)
            self.normal = 0

        def read(self, file):
            data = _struct_read(Md3.vertex_struct, file)
            self.position = data[0:3]
            self.normal = data[3]
            return self

        def write(self, file):
            file.write(Md3.vertex_struct.pack(self.position[0],
                                              self.position[1],
                                              self.position[2],
                                              self.normal))
            return self

        def encode(self, vector):
            """Pack the unit normal *vector* (x, y, z) into ``self.normal``.

            The layout matches ``decode``: the low byte holds the polar
            angle acos(z) and the high byte holds the azimuth atan2(y, x).
            Each angle is scaled so that 2*pi maps to 255.  Returns ``self``.
            """
            x, y, z = vector
            # clamp so rounding error on a unit vector cannot break acos
            z = max(-1.0, min(1.0, z))
            scale = 255 / (2 * math.pi)
            polar = int(math.acos(z) * scale) & 255
            # a negative azimuth wraps into 0..255 via the mask
            azimuth = int(math.atan2(y, x) * scale) & 255
            self.normal = polar | (azimuth << 8)
            return self

        def decode(self):
            """Return the (x, y, z) normal packed in ``self.normal``.

            This is the inverse of ``encode``: byte values are converted back
            to radians before applying the trigonometric formulas.
            """
            scale = 2 * math.pi / 255
            polar = (self.normal & 255) * scale
            azimuth = ((self.normal >> 8) & 255) * scale
            x = math.cos(azimuth) * math.sin(polar)
            y = math.sin(azimuth) * math.sin(polar)
            z = math.cos(polar)
            return x, y, z

    class Surface:
        """A mesh: shaders, triangles, texcoords and per-frame vertices.

        ``vertices`` holds all frames back to back, so each frame occupies
        ``len(vertices) // _frame_count`` consecutive entries.
        """

        def __init__(self):
            self.name = b""
            self.flags = 0
            self.shaders = []
            self.triangles = []
            self.texcoords = []
            self.vertices = []
            # Although all MD3s *should* have the same number of frames
            # in each surface, we keep track of the count per Surface
            # as it can't be reconstructed from the length of a list
            # FIXME more sensible to record number of vertices per frame?
            self._frame_count = 0

        def read(self, file):
            """Read one surface starting at the current file position.

            All offsets in the surface header are relative to the surface
            start.  On return the file is positioned at ``ofs_end``, where
            the next surface header is expected.

            Raises:
                ParseError: if the surface ident is not ``IDP3``.
            """
            surface_start = file.tell()
            (ident, self.name, self.flags,
             self._frame_count, num_shaders, num_vertices, num_triangles,
             ofs_triangles, ofs_shaders, ofs_texcoords, ofs_vertices,
             ofs_end) = _struct_read(Md3.surface_struct, file)
            if ident != b"IDP3":
                raise ParseError("Surface ident does not match IDP3")
            # we must seek to the offset of each structure although if all
            # goes well these will be seeks of size 0 in most models
            file.seek(surface_start + ofs_shaders)  # usually first
            for _ in range(num_shaders):
                self.shaders.append(Md3.Shader().read(file))

            file.seek(surface_start + ofs_triangles)
            for _ in range(num_triangles):
                self.triangles.append(Md3.Triangle().read(file))

            file.seek(surface_start + ofs_texcoords)
            for _ in range(num_vertices):
                self.texcoords.append(Md3.Texcoord().read(file))

            file.seek(surface_start + ofs_vertices)
            for _ in range(num_vertices * self._frame_count):
                self.vertices.append(Md3.Vertex().read(file))

            # we must leave the file pointer at the end of structure given
            file.seek(surface_start + ofs_end)
            return self

        def _frame_length(self):
            """Vertices per frame; 0 when no frames are recorded."""
            if not self._frame_count:
                return 0
            return len(self.vertices) // self._frame_count

        def write_header(self, file, ofs_scratch=None):
            """Write this surface's header at the current file position.

            *ofs_scratch* is the absolute position where ``write_content``
            will write the data.  It defaults to directly after this header,
            which makes ``ofs_end`` point at the next surface header as
            ``read`` expects.  Offsets are computed from what
            ``write_content`` actually writes.
            """
            surface_start = file.tell()
            if ofs_scratch is None:
                ofs_scratch = surface_start + Md3.surface_struct.size
            num_shaders = len(self.shaders)
            num_triangles = len(self.triangles)
            ofs_shaders = ofs_scratch - surface_start
            ofs_triangles = ofs_shaders + num_shaders * Md3.shader_struct.size
            ofs_texcoords = ofs_triangles + num_triangles * Md3.triangle_struct.size
            ofs_vertices = ofs_texcoords + len(self.texcoords) * Md3.texcoord_struct.size
            ofs_end = ofs_vertices + len(self.vertices) * Md3.vertex_struct.size
            file.write(Md3.surface_struct.pack(
                b"IDP3", self.name, self.flags,
                self._frame_count, num_shaders, self._frame_length(),
                num_triangles,
                ofs_triangles, ofs_shaders, ofs_texcoords, ofs_vertices,
                ofs_end))
            return self

        def write_content(self, file):
            """Write shaders, triangles, texcoords and vertices, in that order."""
            for shader in self.shaders:
                shader.write(file)
            for triangle in self.triangles:
                triangle.write(file)
            for texcoord in self.texcoords:
                texcoord.write(file)
            for vertex in self.vertices:
                vertex.write(file)
            return self

        def write(self, file):
            """Write the header immediately followed by its content."""
            self.write_header(file)
            self.write_content(file)
            return self

        def get_frame_vertices(self, frame_index):
            """Return the slice of ``vertices`` belonging to *frame_index*."""
            frame_length = self._frame_length()
            frame_start = frame_index * frame_length
            return self.vertices[frame_start:frame_start + frame_length]

        def get_frame_mins(self, frame_index):
            """Per-axis minimum raw position in a frame; () if it is empty."""
            positions = [v.position for v in self.get_frame_vertices(frame_index)]
            return tuple(min(axis) for axis in zip(*positions))

        def get_frame_maxs(self, frame_index):
            """Per-axis maximum raw position in a frame; () if it is empty."""
            positions = [v.position for v in self.get_frame_vertices(frame_index)]
            return tuple(max(axis) for axis in zip(*positions))

    def read_surfaces(self, file):
        """Read ``_num_surfaces`` surfaces, each chained via its ``ofs_end``."""
        file.seek(self._ofs_surfaces)
        for _ in range(self._num_surfaces):
            self.surfaces.append(Md3.Surface().read(file))
        return self

    def write_surfaces(self, file, ofs_surface_content=None):
        """Write every surface, header then content, at the current position.

        Each surface is written contiguously so that its ``ofs_end`` lands on
        the next surface header, which is how ``read_surfaces`` walks them.
        ``ofs_surface_content`` is accepted for backward compatibility and
        ignored: the old layout put all contents after all headers, which
        made ``ofs_end`` point into the wrong place for multi-surface models.
        """
        for surface in self.surfaces:
            surface.write(file)
        return self

    def read(self, file):
        """Read a complete model (header, frames, tags, surfaces)."""
        self.read_header(file)
        self.read_frames(file)
        self.read_tags(file)
        self.read_surfaces(file)
        return self

    def write(self, file):
        """Serialize the model starting at the current file position.

        Header offsets are absolute file positions, matching how the reader
        seeks to them.  The header is written twice because ``ofs_eof`` is
        only known after the surfaces are out.  On return the file is
        positioned at the end of the written model.
        """
        ofs_header = file.tell()
        ofs_frames = ofs_header + Md3.header_struct.size
        ofs_tags = ofs_frames + Md3.frame_struct.size * len(self.frames)
        ofs_surfaces = ofs_tags + Md3.tag_struct.size * len(self.tags)
        self.write_header(file, ofs_frames, ofs_tags, ofs_surfaces)
        self.write_frames(file)
        self.write_tags(file)
        self.write_surfaces(file)
        ofs_eof = file.tell()
        file.seek(ofs_header)
        self.write_header(file, ofs_frames, ofs_tags, ofs_surfaces, ofs_eof)
        file.seek(ofs_eof)
        return self