#include "PatternGenerator.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
#include <vector>

#include "../mesh/MeshGenerator.h"

/**
 * Builds a height-field visualisation of a 2D Voronoi diagram in the XZ plane.
 *
 * `cellCount` seed points are drawn from a fixed-seed (42) Mersenne Twister,
 * uniformly inside the square [-size/2, size/2]^2 at y = 0. The square is then
 * sampled on a fixed 100x100 grid; every grid cell becomes an independent quad
 * (4 vertices, 2 triangles) whose height and colour are keyed on the index of
 * its nearest seed. A small sphere is merged in above each seed.
 *
 * @param cellCount number of Voronoi seeds; values <= 0 produce no seeds, in
 *                  which case every grid cell is assigned seed index 0.
 * @param size      side length of the square region; must be >= 0 because
 *                  std::uniform_real_distribution requires min <= max.
 * @param height    quads are lifted to height * (seedIndex % 3) / 3; the seed
 *                  marker spheres are placed at y = height.
 * @return the combined mesh with normals computed.
 */
Mesh PatternGenerator::createVoronoi2D(int cellCount, float size, float height) {
    Mesh mesh;
    std::mt19937 gen(42);
    const float halfSize = size / 2;
    std::uniform_real_distribution<float> dis(-halfSize, halfSize);

    // Generate random seed points. x and z are drawn into named locals because
    // the evaluation order of constructor arguments is unspecified: an inline
    // `Vec3(dis(gen), 0, dis(gen))` may consume the two draws in either order,
    // so the fixed seed would not give the same layout on every compiler.
    std::vector<Vec3> seeds;
    if (cellCount > 0) {
        seeds.reserve(static_cast<size_t>(cellCount));
    }
    for (int i = 0; i < cellCount; i++) {
        const float sx = dis(gen);
        const float sz = dis(gen);
        seeds.push_back(Vec3(sx, 0, sz));
    }

    // Sample the region on a regular grid and emit one quad per sample.
    const int resolution = 100;
    const float cellSize = size / resolution;

    for (int x = 0; x < resolution; x++) {
        for (int z = 0; z < resolution; z++) {
            const float px = -halfSize + x * cellSize;
            const float pz = -halfSize + z * cellSize;
            Vec3 p(px, 0, pz);

            // Nearest seed by Euclidean distance; ties keep the lowest index.
            float minDist = std::numeric_limits<float>::max();
            size_t nearestSeed = 0;
            for (size_t i = 0; i < seeds.size(); i++) {
                const float dist = (p - seeds[i]).length();
                if (dist < minDist) {
                    minDist = dist;
                    nearestSeed = i;
                }
            }

            // Height and colour depend only on the seed index (mod 3 / mod 7),
            // so distinct cells can share a height or colour.
            const float h = height * ((nearestSeed % 3) / 3.0f);
            const float colorVal = (nearestSeed % 7) / 7.0f;
            Vec3 color(colorVal, 1.0f - colorVal, 0.5f);

            int baseIdx = mesh.getVertexCount();

            Vertex v0, v1, v2, v3;
            v0.position = Vec3(px, h, pz);
            v1.position = Vec3(px + cellSize, h, pz);
            v2.position = Vec3(px + cellSize, h, pz + cellSize);
            v3.position = Vec3(px, h, pz + cellSize);
            v0.color = color;
            v1.color = color;
            v2.color = color;
            v3.color = color;

            mesh.addVertex(v0);
            mesh.addVertex(v1);
            mesh.addVertex(v2);
            mesh.addVertex(v3);

            // NOTE(review): these triangles turn from +x towards +z, the opposite
            // rotation to the grid in createPerlinNoiseSurface; confirm which
            // winding computeNormals / the renderer treat as front-facing.
            mesh.addTriangle(baseIdx, baseIdx + 1, baseIdx + 2);
            mesh.addTriangle(baseIdx, baseIdx + 2, baseIdx + 3);
        }
    }

    // Mark each seed with a small sphere raised to y = height.
    for (const Vec3& seed : seeds) {
        Mesh sphere = MeshGenerator::createSphere(0.1f, 8, 8);
        sphere.transform(seed + Vec3(0, height, 0), Vec3(1, 1, 1), Vec3(0, 0, 0));
        mesh.merge(sphere);
    }

    mesh.computeNormals();
    return mesh;
}

/**
 * Builds a 2D Delaunay triangulation of random points in the XZ plane using
 * Bowyer-Watson incremental insertion.
 *
 * `pointCount` points are drawn from a fixed-seed (42) Mersenne Twister,
 * uniformly inside [-size/2, size/2]^2 at y = 0. They become the first mesh
 * vertices (in generation order), and the triangles index them directly. A
 * small marker sphere is then merged in slightly above every point.
 *
 * @param pointCount number of points; fewer than 3 yields no triangles.
 * @param size       side length of the sampling square; must be >= 0 because
 *                   std::uniform_real_distribution requires min <= max.
 * @return the combined mesh with normals computed.
 */
Mesh PatternGenerator::createDelaunayTriangulation(int pointCount, float size) {
    Mesh mesh;
    std::mt19937 gen(42);
    std::uniform_real_distribution<float> dis(-size/2, size/2);

    // Generate random points. x and z are drawn into named locals because the
    // evaluation order of constructor arguments is unspecified: an inline
    // `Vec3(dis(gen), 0, dis(gen))` may consume the draws in either order, so
    // the fixed seed would not give the same point set on every compiler.
    std::vector<Vec3> points;
    if (pointCount > 0) {
        points.reserve(static_cast<size_t>(pointCount));
    }
    for (int i = 0; i < pointCount; i++) {
        const float px = dis(gen);
        const float pz = dis(gen);
        points.push_back(Vec3(px, 0, pz));
    }

    // Triangle over indices into allPoints, with its cached circumcircle.
    struct Triangle {
        int a, b, c;
        Vec3 circumcenter;
        float circumradius;
    };

    // Undirected edge: (a, b) and (b, a) compare equal.
    struct Edge {
        int a, b;
        bool operator==(const Edge& other) const {
            return (a == other.a && b == other.b) || (a == other.b && b == other.a);
        }
    };

    // Circumcircle of three points using only their x and z coordinates.
    // Nearly collinear triples (|d| below an absolute 1e-4 threshold) get the
    // centroid and a huge radius, so later insertions almost always classify
    // them as "bad" and replace them.
    auto calcCircumcircle = [](const Vec3& p1, const Vec3& p2, const Vec3& p3, Vec3& center, float& radius) {
        float ax = p1.x, ay = p1.z;
        float bx = p2.x, by = p2.z;
        float cx = p3.x, cy = p3.z;

        float d = 2.0f * (ax * (by - cy) + bx * (cy - ay) + cx * (ay - by));
        if (std::abs(d) < 0.0001f) {
            center = (p1 + p2 + p3) * (1.0f/3.0f);
            radius = 1000000.0f;
            return;
        }

        float ux = ((ax*ax + ay*ay) * (by - cy) + (bx*bx + by*by) * (cy - ay) + (cx*cx + cy*cy) * (ay - by)) / d;
        float uy = ((ax*ax + ay*ay) * (cx - bx) + (bx*bx + by*by) * (ax - cx) + (cx*cx + cy*cy) * (bx - ax)) / d;

        center = Vec3(ux, 0, uy);
        radius = (Vec3(ax, 0, ay) - center).length();
    };

    std::vector<Triangle> triangles;

    // Super triangle enclosing the whole sampling square. Its vertices occupy
    // allPoints[0..2]; the real points follow at allPoints[3..].
    float big = size * 10.0f;
    Vec3 st1(-big, 0, -big);
    Vec3 st2(big, 0, -big);
    Vec3 st3(0, 0, big * 2.0f);

    std::vector<Vec3> allPoints;
    allPoints.push_back(st1);
    allPoints.push_back(st2);
    allPoints.push_back(st3);
    allPoints.insert(allPoints.end(), points.begin(), points.end());

    Triangle superTri;
    superTri.a = 0;
    superTri.b = 1;
    superTri.c = 2;
    calcCircumcircle(st1, st2, st3, superTri.circumcenter, superTri.circumradius);
    triangles.push_back(superTri);

    // Scratch buffers reused across insertions.
    std::vector<Triangle> keptTriangles;
    std::vector<Triangle> badTriangles;
    std::vector<Edge> polygon;

    for (size_t pi = 3; pi < allPoints.size(); pi++) {
        Vec3 point = allPoints[pi];
        const int pointIndex = static_cast<int>(pi);

        keptTriangles.clear();
        badTriangles.clear();
        polygon.clear();

        // Single pass: triangles whose circumcircle contains the point (with a
        // 0.001 tolerance) are "bad"; every other triangle is kept, in order.
        keptTriangles.reserve(triangles.size());
        for (const Triangle& tri : triangles) {
            const float dist = (point - tri.circumcenter).length();
            if (dist < tri.circumradius + 0.001f) {
                badTriangles.push_back(tri);
            } else {
                keptTriangles.push_back(tri);
            }
        }

        // Boundary of the polygonal hole: edges of bad triangles that are not
        // shared (in either direction) with another bad triangle.
        for (size_t ti = 0; ti < badTriangles.size(); ti++) {
            const Triangle& tri = badTriangles[ti];
            const Edge edges[3] = {{tri.a, tri.b}, {tri.b, tri.c}, {tri.c, tri.a}};
            for (const Edge& edge : edges) {
                bool shared = false;
                for (size_t oi = 0; oi < badTriangles.size() && !shared; oi++) {
                    if (oi == ti) {
                        continue;
                    }
                    const Triangle& other = badTriangles[oi];
                    shared = edge == Edge{other.a, other.b} ||
                             edge == Edge{other.b, other.c} ||
                             edge == Edge{other.c, other.a};
                }
                if (!shared) {
                    polygon.push_back(edge);
                }
            }
        }

        // Remove the bad triangles (survivors keep their relative order).
        triangles.swap(keptTriangles);

        // Re-triangulate the hole by connecting each boundary edge to the new point.
        for (const Edge& edge : polygon) {
            Triangle newTri;
            newTri.a = edge.a;
            newTri.b = edge.b;
            newTri.c = pointIndex;
            calcCircumcircle(allPoints[edge.a], allPoints[edge.b], allPoints[pi],
                             newTri.circumcenter, newTri.circumradius);
            triangles.push_back(newTri);
        }
    }

    // Real points become mesh vertices 0..pointCount-1.
    for (const Vec3& p : points) {
        mesh.addVertex(Vertex(p));
    }

    // Keep only triangles made entirely of real points, shifting allPoints
    // indices by 3 to skip the super-triangle vertices. Because the super
    // triangle is finite, this may also drop a few legitimate triangles along
    // the convex hull.
    for (const Triangle& tri : triangles) {
        if (tri.a >= 3 && tri.b >= 3 && tri.c >= 3) {
            mesh.addTriangle(tri.a - 3, tri.b - 3, tri.c - 3);
        }
    }

    // Small marker spheres slightly above each point.
    for (const Vec3& p : points) {
        Mesh sphere = MeshGenerator::createSphere(0.05f, 6, 6);
        sphere.transform(p + Vec3(0, 0.1f, 0), Vec3(1, 1, 1), Vec3(0, 0, 0));
        mesh.merge(sphere);
    }

    mesh.computeNormals();
    return mesh;
}

/**
 * Builds a terrain surface from 4-octave fractal 2D Perlin-style noise.
 *
 * The surface is a (resolution + 1) x (resolution + 1) vertex grid spanning a
 * fixed 10 x 10 square centred on the origin in the XZ plane, triangulated as
 * two triangles per grid quad. Vertex colour blends from dark to light green
 * with the normalised noise value.
 *
 * @param resolution number of quads per side; values <= 0 yield an empty mesh.
 * @param scale      multiplier applied to the integer grid indices (x, y)
 *                   before sampling noise. Samples depend on the grid index,
 *                   not on the world position, so raising `resolution` covers
 *                   more noise space rather than sampling it more finely. The
 *                   noise is exactly 0 at integer lattice points, so an
 *                   integer-valued `scale` (e.g. 1.0) gives a flat surface.
 * @param amplitude  vertex height is fbm(x * scale, y * scale) * amplitude.
 * @return the terrain mesh with normals computed.
 */
Mesh PatternGenerator::createPerlinNoiseSurface(int resolution, float scale, float amplitude) {
    Mesh mesh;

    // resolution == 0 would evaluate 0 / 0 in the vertex position formula and
    // emit a single NaN vertex; negative values emit nothing at all. Return an
    // empty mesh for both.
    if (resolution <= 0) {
        return mesh;
    }

    // Quintic fade curve 6t^5 - 15t^4 + 10t^3, and linear interpolation.
    auto fade = [](float t) { return t * t * t * (t * (t * 6 - 15) + 10); };
    auto lerp = [](float t, float a, float b) { return a + t * (b - a); };

    // Permutation table: a fixed-seed (42) shuffle of 0..255, duplicated to 512
    // entries so the hashed indices below (at most 511) never need wrapping.
    std::vector<int> p(512);
    for (int i = 0; i < 256; i++) p[i] = i;
    std::mt19937 gen(42);
    std::shuffle(p.begin(), p.begin() + 256, gen);
    for (int i = 0; i < 256; i++) p[256 + i] = p[i];

    // Dot product of (x, y) with one of four diagonal gradients chosen by hash.
    auto grad = [](int hash, float x, float y) {
        int h = hash & 3;
        float u = h < 2 ? x : y;
        float v = h < 2 ? y : x;
        return ((h & 1) ? -u : u) + ((h & 2) ? -v : v);
    };

    auto noise = [&](float x, float y) {
        const float cellX = std::floor(x);
        const float cellY = std::floor(y);
        const int X = static_cast<int>(cellX) & 255;
        const int Y = static_cast<int>(cellY) & 255;

        // Position inside the unit cell.
        x -= cellX;
        y -= cellY;

        const float u = fade(x);
        const float v = fade(y);

        const int a = p[X] + Y;
        const int aa = p[a];
        const int ab = p[a + 1];
        const int b = p[X + 1] + Y;
        const int ba = p[b];
        const int bb = p[b + 1];

        return lerp(v,
                    lerp(u, grad(p[aa], x, y), grad(p[ba], x - 1, y)),
                    lerp(u, grad(p[ab], x, y - 1), grad(p[bb], x - 1, y - 1)));
    };

    // Fractal sum of 4 octaves (frequency x2, weight x0.5 per octave),
    // normalised by the total weight.
    auto fbm = [&](float x, float y) {
        float total = 0;
        float frequency = 1.0f;
        float octaveAmplitude = 1.0f;  // distinct from the `amplitude` parameter
        float maxValue = 0;
        const int octaves = 4;

        for (int i = 0; i < octaves; i++) {
            total += noise(x * frequency, y * frequency) * octaveAmplitude;
            maxValue += octaveAmplitude;
            octaveAmplitude *= 0.5f;
            frequency *= 2.0f;
        }

        return total / maxValue;
    };

    const float terrainSize = 10.0f;

    for (int y = 0; y <= resolution; y++) {
        for (int x = 0; x <= resolution; x++) {
            const float px = static_cast<float>(x) / resolution * terrainSize - terrainSize / 2;
            const float pz = static_cast<float>(y) / resolution * terrainSize - terrainSize / 2;

            const float noiseValue = fbm(x * scale, y * scale);
            const float py = noiseValue * amplitude;

            Vertex v;
            v.position = Vec3(px, py, pz);

            // Colour from the normalised noise value. This is equivalent to
            // py / amplitude for non-zero amplitude but avoids 0 / 0 = NaN
            // colours when amplitude == 0.
            const float t = (noiseValue + 1.0f) * 0.5f;
            v.color = Vec3(0.2f + t * 0.3f, 0.4f + t * 0.4f, 0.1f + t * 0.2f);

            mesh.addVertex(v);
        }
    }

    // Two triangles per grid quad; vertex rows are (resolution + 1) long.
    for (int y = 0; y < resolution; y++) {
        for (int x = 0; x < resolution; x++) {
            int i0 = y * (resolution + 1) + x;
            int i1 = i0 + 1;
            int i2 = i0 + resolution + 1;
            int i3 = i2 + 1;

            mesh.addTriangle(i0, i2, i1);
            mesh.addTriangle(i1, i2, i3);
        }
    }

    mesh.computeNormals();
    return mesh;
}

/**
 * Placeholder for a wave-function-collapse tiling generator.
 *
 * No tile solving is performed yet. The result is whatever
 * MeshGenerator::createGrid returns when given `gridSize` for both of its
 * first two arguments and `tileSize` as its third.
 */
Mesh PatternGenerator::createWavefunctionCollapse(int gridSize, float tileSize) {
    const int cellsPerSide = gridSize;
    Mesh tiles = MeshGenerator::createGrid(cellsPerSide, cellsPerSide, tileSize);
    return tiles;
}
