#include <pybind11/pybind11.h>

#include <arpa/inet.h>
#include <bzlib.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"

namespace py = pybind11;

// Converts a 32-bit float from network byte order to host byte order.
// The value is moved through a uint32_t with memcpy so that ntohl() can do
// the byte swap without any pointer-cast type punning.
inline float ntohf(float f) {
    uint32_t wire_bits;
    std::memcpy(&wire_bits, &f, sizeof(wire_bits));

    const uint32_t host_bits = ntohl(wire_bits);

    float host_value;
    std::memcpy(&host_value, &host_bits, sizeof(host_value));
    return host_value;
}

#pragma pack(push, 1)
// 24-byte record found at the very start of a file. process_chunk() skips
// sizeof(VolumeHeader) bytes when the input begins with the bytes "AR2";
// none of the individual fields are read in this file.
struct VolumeHeader {
    char     tape_filename[9];    // bytes  0..8
    char     extension[3];        // bytes  9..11
    uint32_t julian_date;         // bytes 12..15
    uint32_t ms_since_midnight;   // bytes 16..19
    char     icao[4];             // bytes 20..23
};
// 16-byte header that precedes every message inside a decompressed record.
// Multi-byte fields are stored in network byte order; process_chunk() reads
// only message_size and message_type.
struct MessageHeader {
    uint16_t message_size;          // bytes 0..1; doubled by process_chunk() to get a byte count
    uint8_t  rda_channel;           // byte  2
    uint8_t  message_type;          // byte  3; 31 selects the radial-data path
    uint16_t id_sequence;           // bytes 4..5
    uint16_t julian_date;           // bytes 6..7
    uint32_t ms_since_midnight;     // bytes 8..11
    uint16_t num_message_segments;  // bytes 12..13
    uint16_t message_segment_num;   // bytes 14..15
};
// 68-byte header of a type-31 message, located immediately after the
// MessageHeader. Multi-byte fields are in network byte order. process_chunk()
// reads azimuth_angle, elevation_number, data_block_count and block_pointers.
struct Msg31Header {
    char     radar_id[4];          // bytes  0..3
    uint32_t radial_time;          // bytes  4..7
    uint16_t julian_date;          // bytes  8..9
    uint16_t azimuth_number;       // bytes 10..11
    float    azimuth_angle;        // bytes 12..15; decoded with ntohf()
    uint8_t  compression;          // byte  16
    uint8_t  spare;                // byte  17
    uint16_t radial_length;        // bytes 18..19
    uint8_t  azimuth_resolution;   // byte  20
    uint8_t  radial_status;        // byte  21
    uint8_t  elevation_number;     // byte  22; only values <= 2 are rendered
    uint8_t  cut_sector;           // byte  23
    float    elevation_angle;      // bytes 24..27
    uint8_t  spot_blanking;        // byte  28
    uint8_t  azimuth_indexing;     // byte  29
    uint16_t data_block_count;     // bytes 30..31; number of block_pointers entries in use
    uint32_t block_pointers[9];    // bytes 32..67; byte offsets of data blocks, 0 = unused
};
// First four bytes of any data block referenced by Msg31Header::block_pointers.
// process_chunk() compares bytes 1..3 against "REF" to pick out reflectivity
// blocks; byte 0 is not part of the name.
struct DataBlockHeader {
    char block_name[4];
};
// 28-byte header of a reflectivity data block; the per-gate samples follow it
// directly. Multi-byte fields are in network byte order. process_chunk() reads
// num_gates, scale and offset and maps each raw sample to
// (raw - offset) / scale.
struct ReflectivityBlock {
    char     block_name[4];      // bytes  0..3; byte 0 is a block-type byte, bytes 1..3 are "REF"
    uint32_t reserved;           // bytes  4..7
    uint16_t num_gates;          // bytes  8..9; number of samples after the header
    uint16_t first_gate_range;   // bytes 10..11
    uint16_t gate_spacing;       // bytes 12..13
    int16_t  threshold;          // bytes 14..15
    int16_t  snr_threshold;      // bytes 16..17
    uint8_t  control_flags;      // byte  18
    uint8_t  data_size;          // byte  19; not consulted, samples are read one byte each
    float    scale;              // bytes 20..23
    float    offset;             // bytes 24..27
};
#pragma pack(pop)

// Writes the RGBA colour for reflectivity value `dbz` into pixel (x, y) of a
// 1200-pixel-wide, 4-bytes-per-pixel image stored row-major in `img`.
//
// Colour bands (upper bound exclusive):
//   < -20 dBZ or NaN : fully transparent
//   <   5            : grey-blue, alpha 100
//   <  15            : cyan
//   <  30            : green
//   <  40            : yellow
//   <  50            : orange
//   <  60            : red
//   otherwise        : magenta
//
// Coordinates outside [0, 1200) and pixels that do not fit inside `img` are
// ignored instead of corrupting a neighbouring row or writing out of bounds.
void set_pixel(std::vector<uint8_t>& img, int x, int y, float dbz) {
    constexpr int kSide = 1200;
    if (x < 0 || x >= kSide || y < 0 || y >= kSide) return;

    const std::size_t idx =
        (static_cast<std::size_t>(y) * kSide + static_cast<std::size_t>(x)) * 4;
    if (img.size() < 4 || idx > img.size() - 4) return;

    uint8_t r = 0, g = 0, b = 0, a = 255;

    // NaN fails every ordered comparison, so it is tested explicitly; without
    // this it would fall through to the opaque "extreme" colour.
    if (std::isnan(dbz) || dbz < -20) { a = 0; }
    else if (dbz < 5)  { r = 100; g = 120; b = 150; a = 100; }
    else if (dbz < 15) { r = 0;   g = 255; b = 255; }
    else if (dbz < 30) { r = 0;   g = 255; b = 0;   }
    else if (dbz < 40) { r = 255; g = 255; b = 0;   }
    else if (dbz < 50) { r = 255; g = 127; b = 0;   }
    else if (dbz < 60) { r = 255; g = 0;   b = 0;   }
    else               { r = 255; g = 0;   b = 255; }

    img[idx]     = r;
    img[idx + 1] = g;
    img[idx + 2] = b;
    img[idx + 3] = a;
}

void process_chunk(py::buffer b, size_t chunk_size, const std::string& output_path) {
    py::buffer_info info = b.request();
    uint8_t* data = static_cast<uint8_t*>(info.ptr);
    py::gil_scoped_release release;

    uint8_t* active_data = data;
    size_t remaining_size = chunk_size;

    if (active_data[0] == 'A' && active_data[1] == 'R' && active_data[2] == '2') {
        active_data += sizeof(VolumeHeader);
        remaining_size -= sizeof(VolumeHeader);
    }

    std::vector<std::vector<float>> radar_map(3600, std::vector<float>(600, -999.0f));
    
    while (remaining_size > 10) {
        size_t bz_offset = 0; bool found = false;
        for (size_t i = 0; i < 100 && i < remaining_size - 3; ++i) {
            if (active_data[i] == 'B' && active_data[i+1] == 'Z' && active_data[i+2] == 'h') {
                bz_offset = i; found = true; break;
            }
        }
        if (!found) break;

        active_data += bz_offset; remaining_size -= bz_offset;

        bz_stream strm; strm.bzalloc = NULL; strm.bzfree = NULL; strm.opaque = NULL;
        BZ2_bzDecompressInit(&strm, 0, 0);
        strm.next_in = reinterpret_cast<char*>(active_data); strm.avail_in = remaining_size;

        unsigned int destLen = 5 * 1024 * 1024; 
        std::vector<char> uncompressed_buffer(destLen);
        strm.next_out = uncompressed_buffer.data(); strm.avail_out = destLen;

        BZ2_bzDecompress(&strm);
        size_t bytes_consumed = strm.total_in_lo32; size_t bytes_produced = strm.total_out_lo32;
        BZ2_bzDecompressEnd(&strm);

        if (bytes_produced > 0) {
            uint8_t* uncomp_ptr = reinterpret_cast<uint8_t*>(uncompressed_buffer.data());
            size_t uncomp_remaining = bytes_produced;

            while (uncomp_remaining >= sizeof(MessageHeader) + 12) {
                uint32_t ctm_offset = 0;
                MessageHeader* m_hdr = reinterpret_cast<MessageHeader*>(uncomp_ptr + 12);
                
                if (m_hdr->message_type == 0 || m_hdr->message_type > 31) {
                    m_hdr = reinterpret_cast<MessageHeader*>(uncomp_ptr);
                    ctm_offset = 0;
                } else {
                    ctm_offset = 12;
                }

                uint16_t msg_size_words = ntohs(m_hdr->message_size);
                uint32_t msg_size_bytes = msg_size_words * 2;
                uint32_t total_msg_size = ctm_offset + msg_size_bytes;

                if (msg_size_bytes == 0 || total_msg_size > uncomp_remaining) break;

                uint8_t* msg_start = uncomp_ptr + ctm_offset;

                if (m_hdr->message_type == 31) {
                    uint8_t* payload_start = msg_start + sizeof(MessageHeader);
                    Msg31Header* msg31 = reinterpret_cast<Msg31Header*>(payload_start);
                    
                    if (msg31->elevation_number <= 2) {
                        uint16_t block_count = ntohs(msg31->data_block_count);
                        float real_azimuth = ntohf(msg31->azimuth_angle);
                        
                        for (int i = 0; i < block_count && i < 9; ++i) {
                            uint32_t pointer_offset = ntohl(msg31->block_pointers[i]);
                            if (pointer_offset == 0) continue;
                            
                            DataBlockHeader* block_hdr = reinterpret_cast<DataBlockHeader*>(msg_start + pointer_offset);
                            
                            // THE FIX: NOAA block name starts at byte [1], because byte [0] is the Block Type Integer
                            if (block_hdr->block_name[1] == 'R' && block_hdr->block_name[2] == 'E' && block_hdr->block_name[3] == 'F') {
                                ReflectivityBlock* ref_block = reinterpret_cast<ReflectivityBlock*>(msg_start + pointer_offset);
                                uint16_t num_gates = ntohs(ref_block->num_gates);
                                float scale = ntohf(ref_block->scale); 
                                float offset = ntohf(ref_block->offset);
                                uint8_t* raw_data = reinterpret_cast<uint8_t*>(ref_block) + sizeof(ReflectivityBlock);
                                
                                int base_az = (int)(real_azimuth * 10.0f);
                                for (int a = -3; a <= 3; ++a) { 
                                    int az_idx = (base_az + a + 3600) % 3600;
                                    for (int g = 0; g < num_gates && g < 600; ++g) {
                                        uint8_t raw_val = raw_data[g];
                                        if (raw_val > 1) { 
                                            radar_map[az_idx][g] = (raw_val - offset) / scale; 
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                uncomp_ptr += total_msg_size;
                uncomp_remaining -= total_msg_size;
            }
        }
        if (bytes_consumed == 0) break;
        active_data += bytes_consumed; remaining_size -= bytes_consumed;
    }

    std::vector<uint8_t> pixels(1200 * 1200 * 4, 0); 
    for (int y = 0; y < 1200; ++y) {
        for (int x = 0; x < 1200; ++x) {
            float dx = x - 600.0f; float dy = 600.0f - y; 
            float r = std::sqrt(dx*dx + dy*dy);
            if (r >= 600.0f) continue; 
            
            float az = std::atan2(dx, dy) * 180.0f / M_PI;
            if (az < 0) az += 360.0f;
            int az_idx = (int)(az * 10.0f) % 3600;
            int gate = (int)r;
            
            float dbz = radar_map[az_idx][gate];
            if (dbz > -900.0f) { 
                set_pixel(pixels, x, y, dbz); 
            } else {
                int idx = (y * 1200 + x) * 4;
                pixels[idx] = 15;     
                pixels[idx+1] = 23;   
                pixels[idx+2] = 42;   
                pixels[idx+3] = 40;   
            }
        }
    }

    stbi_write_png(output_path.c_str(), 1200, 1200, 4, pixels.data(), 1200 * 4);
}

// Python module definition: `radar_core` exposes process_chunk() as its only
// function.
PYBIND11_MODULE(radar_core, m) {
    m.def(
        "process_chunk",
        &process_chunk,
        "Render AWS Weather Data to PNG");
}
