#pragma once
#include "Vec3.h"
#include "Matrix4.h"
#include <cmath>

/// Quaternion stored as (w, x, y, z): w is the scalar part, (x, y, z) the
/// vector part. The rotation helpers (toMatrix, slerp) assume unit length;
/// call normalized() to remove accumulated floating-point drift.
class Quaternion {
public:
    float w, x, y, z;

    /// Identity rotation (1, 0, 0, 0).
    Quaternion() : w(1), x(0), y(0), z(0) {}

    /// Component-wise constructor. No normalization is performed.
    Quaternion(float w, float x, float y, float z) : w(w), x(x), y(y), z(z) {}

    /// Builds the rotation of `angle` radians about `axis`.
    ///
    /// The axis is normalized here, so any non-zero vector is accepted
    /// (previously a non-unit axis silently produced a non-unit quaternion).
    /// A zero-length axis has no direction to rotate about and yields the
    /// identity rotation.
    /// @param axis  rotation axis; any non-zero length
    /// @param angle rotation angle in radians (passed to std::sin/std::cos)
    /// @return (cos(angle/2), sin(angle/2) * axis / |axis|)
    static Quaternion fromAxisAngle(const Vec3& axis, float angle) {
        const float axisLen = std::sqrt(axis.x * axis.x + axis.y * axis.y + axis.z * axis.z);
        if (axisLen <= 0.0f) {
            return Quaternion();
        }
        const float halfAngle = angle * 0.5f;
        // Fold the axis normalization into the sine factor: one divide instead of three.
        const float s = std::sin(halfAngle) / axisLen;
        return Quaternion(std::cos(halfAngle), axis.x * s, axis.y * s, axis.z * s);
    }

    /// Builds a rotation from Tait-Bryan angles, in radians.
    ///
    /// The result equals the Hamilton product
    ///   fromAxisAngle(Z, yaw) * fromAxisAngle(Y, pitch) * fromAxisAngle(X, roll),
    /// i.e. roll about X, pitch about Y, yaw about Z.
    /// @param pitch rotation about the Y axis
    /// @param yaw   rotation about the Z axis
    /// @param roll  rotation about the X axis
    static Quaternion fromEuler(float pitch, float yaw, float roll) {
        float cy = std::cos(yaw * 0.5f);
        float sy = std::sin(yaw * 0.5f);
        float cp = std::cos(pitch * 0.5f);
        float sp = std::sin(pitch * 0.5f);
        float cr = std::cos(roll * 0.5f);
        float sr = std::sin(roll * 0.5f);

        return Quaternion(
            cr * cp * cy + sr * sp * sy,
            sr * cp * cy - cr * sp * sy,
            cr * sp * cy + sr * cp * sy,
            cr * cp * sy - sr * sp * cy
        );
    }

    /// Hamilton product (this * q). Not commutative.
    /// With the usual v' = q v q* rotation convention, (a * b) applies b first,
    /// then a. NOTE(review): confirm against how callers rotate vectors.
    Quaternion operator*(const Quaternion& q) const {
        return Quaternion(
            w * q.w - x * q.x - y * q.y - z * q.z,
            w * q.x + x * q.w + y * q.z - z * q.y,
            w * q.y - x * q.z + y * q.w + z * q.x,
            w * q.z + x * q.y - y * q.x + z * q.w
        );
    }

    /// 4D dot product of the components. For unit quaternions this is the
    /// cosine of half the angle between the two rotations.
    static float dot(const Quaternion& a, const Quaternion& b) {
        return a.w * b.w + a.x * b.x + a.y * b.y + a.z * b.z;
    }

    /// Euclidean norm of the four components.
    float length() const {
        return std::sqrt(dot(*this, *this));
    }

    /// Returns this quaternion scaled to unit length.
    /// A zero quaternion has no direction; the identity is returned instead of
    /// dividing by zero (which would produce NaN in every component).
    Quaternion normalized() const {
        const float len = length();
        if (len <= 0.0f) {
            return Quaternion();
        }
        return Quaternion(w / len, x / len, y / len, z / len);
    }

    /// Spherical linear interpolation from `a` (t = 0) to `b` (t = 1).
    ///
    /// Expects unit quaternions. Always takes the shorter arc.
    /// @param t interpolation parameter; normally in [0, 1]
    static Quaternion slerp(const Quaternion& a, const Quaternion& b, float t) {
        float cosTheta = dot(a, b);

        // q and -q encode the same rotation. Flip b into a's hemisphere so the
        // interpolation follows the shorter great-circle arc.
        Quaternion end = b;
        if (cosTheta < 0.0f) {
            end = Quaternion(-b.w, -b.x, -b.y, -b.z);
            cosTheta = -cosTheta;
        }

        // Nearly parallel: sin(theta) approaches 0 and the slerp weights become
        // ill-conditioned, so fall back to normalized linear interpolation.
        if (cosTheta > 0.9995f) {
            return Quaternion(
                a.w + t * (end.w - a.w),
                a.x + t * (end.x - a.x),
                a.y + t * (end.y - a.y),
                a.z + t * (end.z - a.z)
            ).normalized();
        }

        // Here 0 <= cosTheta <= 0.9995, so acos is in-domain and sin(theta) > 0.
        const float theta = std::acos(cosTheta);
        const float sinTheta = std::sin(theta);
        const float wa = std::sin((1 - t) * theta) / sinTheta;
        const float wb = std::sin(t * theta) / sinTheta;

        return Quaternion(
            a.w * wa + end.w * wb,
            a.x * wa + end.x * wb,
            a.y * wa + end.y * wb,
            a.z * wa + end.z * wb
        );
    }

    /// Converts this (assumed unit) quaternion to a 4x4 rotation matrix with
    /// no translation.
    ///
    /// The indices written place the standard rotation matrix R at
    /// m[col * 4 + row] (column-major). NOTE(review): confirm this matches
    /// Matrix4's storage convention.
    /// All 16 entries are written explicitly, so the result does not depend on
    /// what Matrix4's default constructor leaves in the fourth row/column.
    /// NOTE(review): assumes Matrix4::m holds 16 floats, as the stride-4
    /// indexing implies.
    Matrix4 toMatrix() const {
        Matrix4 mat;
        const float xx = x * x, yy = y * y, zz = z * z;
        const float xy = x * y, xz = x * z, yz = y * z;
        const float wx = w * x, wy = w * y, wz = w * z;

        // Column 0
        mat.m[0] = 1 - 2 * (yy + zz);
        mat.m[1] = 2 * (xy + wz);
        mat.m[2] = 2 * (xz - wy);
        mat.m[3] = 0.0f;

        // Column 1
        mat.m[4] = 2 * (xy - wz);
        mat.m[5] = 1 - 2 * (xx + zz);
        mat.m[6] = 2 * (yz + wx);
        mat.m[7] = 0.0f;

        // Column 2
        mat.m[8] = 2 * (xz + wy);
        mat.m[9] = 2 * (yz - wx);
        mat.m[10] = 1 - 2 * (xx + yy);
        mat.m[11] = 0.0f;

        // Column 3: no translation, homogeneous w = 1.
        mat.m[12] = 0.0f;
        mat.m[13] = 0.0f;
        mat.m[14] = 0.0f;
        mat.m[15] = 1.0f;

        return mat;
    }
};
