#include "CPPNHyperNEAT.h"
#include "../mesh/MeshGenerator.h"
#include <cmath>

/**
 * Creates the CPPN that backs this HyperNEAT generator.
 *
 * The network is constructed from the size list
 * {inputDimension, 12, 12, outputDimension} (presumably input width, two
 * hidden layers of 12 units, output width — confirm against the CPPN class)
 * and is configured to use tanh as its activation function.
 *
 * @param inputDimension  Number of network inputs. Note that queryNetwork()
 *                        and generateMesh() in this file always feed 4 values.
 * @param outputDimension Number of network outputs.
 */
CPPNHyperNEAT::CPPNHyperNEAT(int inputDimension, int outputDimension)
    : cppn({inputDimension, 12, 12, outputDimension}),
      inputDim(inputDimension),
      outputDim(outputDimension) {
    // tanh maps any real input into (-1, 1), giving a bounded activation range.
    auto tanhActivation = [](double value) { return std::tanh(value); };
    cppn.setActivationFunction(tanhActivation);
}

/**
 * Evaluates the CPPN at a single 3-D point.
 *
 * The network input is {x, y, z, r}, where r is the Euclidean distance of
 * (x, y, z) from the origin.
 *
 * @param x, y, z Coordinates of the query point.
 * @return The first network output, or 0.0 if the network produced no output.
 */
double CPPNHyperNEAT::queryNetwork(double x, double y, double z) {
    const double radius = std::sqrt(x*x + y*y + z*z);
    std::vector<double> features{x, y, z, radius};

    const auto result = cppn.forward(features);
    if (result.empty()) {
        return 0.0;
    }
    return result[0];
}

/**
 * Samples the CPPN over a resolution x resolution grid on [-1, 1]^2 and builds
 * a displaced, vertex-coloured triangle mesh from its outputs.
 *
 * For each grid point (u, v) the network is fed {u, v, atan2(v, u), |(u, v)|}.
 * Its first three outputs displace the vertex (x and z lightly, y by the full
 * height scale) and, remapped via (o + 1) / 2, form the vertex colour.
 *
 * @param resolution Number of samples along each grid axis. Values below 2
 *                   produce no triangles.
 * @return The generated mesh. If the network yields fewer than three outputs
 *         for some sample, that sample gets no vertex and no triangles are
 *         emitted, rather than emitting triangles that index missing vertices.
 */
Mesh CPPNHyperNEAT::generateMesh(int resolution) {
    Mesh mesh;
    const float meshSize = 10.0f;    // Base mesh extent along x and z
    const float heightScale = 5.0f;  // Scale applied to the network outputs

    // Spacing between samples so the first and last row/column land exactly on
    // -1 and +1. Mapping with `i / resolution * 2 - 1` stopped one step short
    // of +1 and left the surface off-centre. With a single sample the only
    // point is -1, as before.
    const double step = resolution > 1 ? 2.0 / (resolution - 1) : 0.0;

    long long verticesAdded = 0;

    // Generate a parametric surface using the CPPN.
    for (int i = 0; i < resolution; i++) {
        for (int j = 0; j < resolution; j++) {
            const double u = -1.0 + i * step;
            const double v = -1.0 + j * step;

            // Query the network with Cartesian plus polar coordinates.
            // Note: atan2 returns values in [-pi, pi], a wider range than u, v.
            const double distance = std::sqrt(u*u + v*v);
            const double angle = std::atan2(v, u);
            std::vector<double> input = {u, v, angle, distance};
            auto output = cppn.forward(input);

            // Three outputs are required for the x/y/z displacement and colour.
            if (output.size() < 3) {
                continue;
            }

            const float x = static_cast<float>(u * meshSize * 0.5f + output[0] * heightScale * 0.2f);
            const float y = static_cast<float>(output[1] * heightScale);
            const float z = static_cast<float>(v * meshSize * 0.5f + output[2] * heightScale * 0.2f);

            Vertex vert;
            vert.position = Vec3(x, y, z);
            // Placeholder normal; presumably replaced by computeNormals() below.
            vert.normal = Vec3(0, 1, 0);

            // Colour from the network outputs, remapped from [-1, 1] to [0, 1].
            vert.color = Vec3(
                (output[0] + 1.0f) * 0.5f,
                (output[1] + 1.0f) * 0.5f,
                (output[2] + 1.0f) * 0.5f
            );
            vert.u = u;
            vert.v = v;

            mesh.addVertex(vert);
            ++verticesAdded;
        }
    }

    // Triangle indices assume vertex (i, j) lives at index i * resolution + j,
    // which only holds when every grid point produced a vertex. Otherwise the
    // indices would reference vertices that were never added.
    const bool gridComplete =
        verticesAdded == static_cast<long long>(resolution) * resolution;
    if (gridComplete) {
        for (int i = 0; i < resolution - 1; i++) {
            for (int j = 0; j < resolution - 1; j++) {
                const unsigned int idx = static_cast<unsigned int>(i * resolution + j);
                const unsigned int nextRow = idx + static_cast<unsigned int>(resolution);
                mesh.addTriangle(idx, nextRow, nextRow + 1);
                mesh.addTriangle(idx, nextRow + 1, idx + 1);
            }
        }
    }

    mesh.computeNormals();
    return mesh;
}

/**
 * Placeholder for evolving the CPPN over a number of generations.
 *
 * Currently a no-op: the network is left unchanged and the generation count
 * is ignored. A full implementation would run an evolutionary algorithm here.
 */
void CPPNHyperNEAT::evolve(int /*generations*/) {
    // Intentionally empty (simplified implementation).
}
