import { useState, useMemo, useEffect, useRef } from 'react';
import './app.scss';

/**
 * Root viewer: accepts a dropped .safetensors file, parses its header with
 * `playWithFile`, and shows the tensor tree plus a float preview of the
 * selected tensor.
 */
function TailwindPlayground() {
    const [metadata, setMetadata] = useState<any>(null);
    const [file, setFile] = useState<File | null>(null);
    const [error, setError] = useState<any>(null);
    const [selectedTensor, setSelectedTensor] = useState<string | null>(null);
    // The most recently dropped file. Parse results for any other file arrived
    // too late (a newer drop happened meanwhile) and are discarded.
    const latestFile = useRef<File | null>(null);

    const onFilesSelected = (files: ArrayLike<File>) => {
        const picked = files[0];
        if (!picked) {
            return;
        }

        // Reset everything derived from the previous file. Otherwise a failed parse
        // would keep showing the old tree, and the float preview would read the new
        // file at the old file's offsets.
        latestFile.current = picked;
        setFile(picked);
        setMetadata(null);
        setError(null);
        setSelectedTensor(null);

        playWithFile(picked).then(parsed => {
            if (latestFile.current !== picked) {
                return;
            }
            setMetadata(parsed);
        }).catch(e => {
            console.error(e);
            if (latestFile.current !== picked) {
                return;
            }
            setError(e);
        });
    };

    /** Save the byte range `[range[0], range[1])` of the current file as `filename`. */
    const download = (range: [number, number], filename: string) => {
        if (!file) {
            return;
        }
        downloadBlobSlice(file, range[0], range[1], filename);
    };

    const selectTensor = (obj: any) => {
        setSelectedTensor(obj?.path ?? null);
    };

    return (
        <div className="viewer-root">
            <div className="view-wrapper">
                {(file && metadata) ? <FloatExplorer selected={selectedTensor} metadata={metadata} file={file} /> : null}

                <DropZone onFilesSelected={onFilesSelected} className={(file && metadata) ? "has-file" : ""}>
                    <div>
                        {!file && <div className="dropzone-info">
                            <p>Drop a <b>.safetensors</b> file here.</p>
                            <p className="footnote-quiet">The file won't be uploaded, nor leave this page.</p>
                        </div>}

                        {file && metadata && <div className="main-text-container">
                            <DataExplorer data={metadata.metadata} filename={file.name} download={download} selectTensor={selectTensor} selectedTensor={selectedTensor} />
                        </div>}

                        {file && error && <div className="parsing-error">
                            <p><b>Error</b>: {error?.message ?? String(error)}</p>
                        </div>}
                    </div>
                </DropZone>
            </div>
        </div>
    );
}

/**
 * Right-hand panel that previews the selected tensor's values, or a hint when
 * nothing is selected.
 *
 * Props: `selected` (tensor path or null), `metadata` (the object returned by
 * `playWithFile`, whose `.metadata` maps tensor paths to descriptors), and
 * `file` (the dropped File).
 */
const FloatExplorer = (props: any) => {
    const { selected = null, metadata } = props;
    const file = props.file as File;

    // Depend on both inputs. With `selected` alone, loading a new file that has
    // the same tensor path would keep the old file's descriptor.
    const tensor = useMemo(
        () => (selected ? metadata?.metadata?.[selected] ?? null : null),
        [selected, metadata],
    );

    return (
        <div className="value-viewer dropzone">
            {selected ? <FloatArray tensor={tensor} file={file} />
                : <span className="float-placeholder">Select a tensor from the navigator to see its contents.</span>}
        </div>
    );
};

/**
 * Read a tensor's bytes from `file` and view them as 32-bit floats.
 *
 * NOTE(review): the bytes are always interpreted as float32, whatever the
 * tensor's dtype, and always in the host's native byte order.
 *
 * @param file   - The source file.
 * @param tensor - Descriptor with absolute `data_offsets: [start, end)` byte offsets.
 * @param index  - First float32 element to read (skips `index * 4` bytes).
 * @param max    - Accepted for API compatibility; it does not currently limit the read.
 * @returns Every whole float32 value in the range. A trailing partial value
 *          (fewer than 4 bytes) is dropped.
 */
async function readTensorData(file: File, tensor: any, index: number, max: number = -1) {
    const [start, end] = tensor.data_offsets;
    const elementBytes = Float32Array.BYTES_PER_ELEMENT;

    const bytes = await file.slice(start + index * elementBytes, end).arrayBuffer();

    // `new Float32Array(buffer)` throws a RangeError if the byte length is not a
    // multiple of 4 (e.g. an odd-length F16/U8 tensor). View only whole elements.
    const count = Math.floor(bytes.byteLength / elementBytes);
    return new Float32Array(bytes, 0, count);
}

/**
 * Scrollable grid of float32 values for one tensor.
 *
 * Props: `tensor` (descriptor with `data_offsets`) and `file` (the dropped File).
 * The mouse wheel moves the window of `pageSize` values shown.
 */
const FloatArray = (props: any) => {
    const { tensor } = props;
    const file = props.file as File;

    const [values, setValues] = useState<Float32Array | null>(null);
    const [offset, setOffset] = useState(0);
    const [pageSize] = useState(145);

    useEffect(() => {
        // A new tensor starts at the top, with no values left over from the previous one.
        setOffset(0);
        setValues(null);

        // Ignore reads that resolve after the tensor/file changed or the component
        // unmounted. Otherwise a slow read could overwrite the newer tensor's values.
        let cancelled = false;

        if (tensor) {
            readTensorData(file, tensor, 0, 32)
                .then(data => {
                    if (!cancelled) {
                        setValues(data);
                    }
                })
                .catch(err => {
                    if (!cancelled) {
                        console.error(err);
                    }
                });
        }

        return () => {
            cancelled = true;
        };
    }, [file, tensor]);

    const selection = useMemo(() => {
        return values ? values.subarray(offset, offset + pageSize) : null;
    }, [values, offset, pageSize]);

    const lengthStr = useMemo(() => `${values?.length ?? 0}`, [values]);
    const offsetPadded = useMemo(() => `${offset}`.padStart(lengthStr.length), [offset, lengthStr]);

    const onScroll = (ev: React.WheelEvent<HTMLDivElement>) => {
        if (values) {
            // Move at least one element per wheel event, or deltaY / 5 for bigger scrolls.
            const delta = (Math.max(1, Math.abs(ev.deltaY / 5)) * Math.sign(ev.deltaY)) | 0;
            // Clamp so at least one value stays visible (the offset never reaches values.length).
            const lastOffset = Math.max(0, values.length - 1);
            setOffset(off => Math.max(0, Math.min((off + delta) | 0, lastOffset)));
        }
        ev.stopPropagation();
    };

    return (<div className="float-data-wrapper">
        <div className="float-nav">
            Offset: [-] {offsetPadded} / {lengthStr} [+]
        </div>

        <div className="float-grid" onWheel={onScroll}>
            {selection ? Array.from(selection, (f: number, i: number) => <div className="float-value" key={i}><div>{f.toFixed(10)}</div></div>) : null}
        </div>
    </div>);
};

/**
 * Drag-and-drop target. It shows a callout while something is dragged over it
 * and passes the dropped `FileList` to `onFilesSelected`.
 *
 * Props: `onFilesSelected(files)`, `children`, and an optional extra `className`.
 */
const DropZone = (props: any) => {
    const { onFilesSelected, children, className = "" } = props;

    const [isDragging, setIsDragging] = useState(false);

    const onDragEnter = (e: React.DragEvent<HTMLDivElement>) => {
        e.preventDefault();
        e.stopPropagation();

        setIsDragging(true);
    };

    const onDragLeave = (e: React.DragEvent<HTMLDivElement>) => {
        // Moving between child elements also fires dragleave; only react when the
        // pointer actually leaves the zone.
        if (e.currentTarget.contains(e.relatedTarget as Node)) return;
        e.preventDefault();
        e.stopPropagation();

        setIsDragging(false);
    };

    const onDrop = (e: React.DragEvent<HTMLDivElement>) => {
        e.preventDefault();
        e.stopPropagation();

        // Always hide the callout. A drop without files (e.g. dragged text) used to
        // return early here and leave the overlay stuck on screen.
        setIsDragging(false);

        const files = e.dataTransfer.files;
        if (files.length === 0) {
            return;
        }

        onFilesSelected(files);
    };

    // Skip the extra class instead of rendering "undefined" when none is given.
    const rootClassName = className ? `dropzone ${className}` : "dropzone";

    return (
        <div className={rootClassName} onDragEnter={onDragEnter} onDragLeave={onDragLeave} onDrop={onDrop} onDragOver={e => e.preventDefault()}>
            {isDragging && <div className="drag-callout-wrapper">
                <div className="drag-centered">
                    <div className="dropzone-callout">
                        <p className="text-center">Drop it like it's hot!</p>
                    </div>
                </div>
            </div>}

            {children}
        </div>
    );
};

/**
 * Tensor navigator. It rebuilds the dotted tensor names of the header
 * (e.g. "add_embedding.linear_1.bias") into a nested folder tree, with
 * per-folder byte totals, and renders it.
 *
 * Tree shape:
 * - Folder node: `{ type: 'folder', totalSize, path, ...children }`.
 * - Leaf node: `{ type: 'value', value, totalSize, path }`.
 * - Root: a bare container of children, with no bookkeeping fields.
 *
 * NOTE(review): children share an object with the folder's own fields, so a
 * name segment literally called "type", "totalSize" or "path" collides with them.
 */
const DataExplorer = ({ data, filename, download, selectTensor, selectedTensor }: { data: Record<string, any>, filename: string, download: any, selectTensor: any, selectedTensor: any; }) => {
    const tree = useMemo(() => {
        const root: Record<string, any> = {};

        for (const [key, value] of Object.entries(data)) {
            const parts = key.split('.');
            const size = sizeOf(value);
            let current: any = root;

            for (let depth = 0; depth < parts.length; depth++) {
                const part = parts[depth];
                const path = parts.slice(0, depth + 1).join('.');

                if (depth === parts.length - 1) {
                    current[part] = { type: 'value', value, totalSize: size, path };
                } else {
                    // Create the folder the first time it is seen, and add this tensor's
                    // bytes to it on the way down. The old second walk up also
                    // hit the root and stored `undefined + size` (NaN) on it.
                    current[part] ??= { type: 'folder', totalSize: 0, path };
                    current[part].totalSize += size;
                    current = current[part];
                }
            }
        }

        return root;
    }, [data]);

    // Strip only a trailing extension, not the first ".safetensors" anywhere in the name.
    const jsonName = filename.replace(/\.safetensors$/, "") + ".metadata.json";

    return (<>
        <p className="file-title">
            {filename} <button className="download-label" onClick={() => exportJson(data, jsonName)}>export .json tree</button>
        </p>

        <div className="data-explorer-wrapper">
            <DataExplorerNode data={tree} depth={0} download={download} selectTensor={selectTensor} selectedTensor={selectedTensor} />
        </div>
    </>
    );
};

/**
 * Render one level of the tensor tree as a `<ul>`. Folder entries become
 * `DataExplorerFolder` and leaf entries become `DataExplorerValue`.
 *
 * Entries that are not node objects are skipped and render nothing. This
 * covers a folder's own primitive bookkeeping fields (`type`, `totalSize`,
 * `path`), which sit next to its children.
 */
const DataExplorerNode = ({ data, depth, download, selectTensor, selectedTensor }: { data: Record<string, any>, depth: number, download: any, selectTensor: any, selectedTensor: any; }) => {
    return (<ul className="data-explorer-node">
        {Object.entries(data).map(([key, value]) => {
            if (value === null || typeof value !== 'object') {
                return null;
            }

            if (value.type === 'folder') {
                return <DataExplorerFolder key={key} name={key} depth={depth} data={value} download={download} selectTensor={selectTensor} selectedTensor={selectedTensor} />;
            }

            if (value.type === 'value') {
                return <DataExplorerValue key={key} name={key} depth={depth} obj={value} download={download} selectTensor={selectTensor} selectedTensor={selectedTensor} />;
            }

            return null;
        })}
    </ul>);
};

/**
 * Bytes per element for each dtype string found in a safetensors header.
 * Sub-byte types (e.g. 4/6-bit floats) are not listed and size to 0.
 */
const DTYPE_TO_BYTES: Record<string, number> = {
    'BOOL': 1,
    'Bool': 1, // earlier spelling used by this table; kept for compatibility
    'U8': 1,
    'I8': 1,
    'F8_E5M2': 1,
    'F8_E4M3': 1,
    'F8_E8M0': 1,
    'I16': 2,
    'U16': 2,
    'F16': 2,
    'BF16': 2,
    'I32': 4,
    'U32': 4,
    'F32': 4,
    'I64': 8,
    'U64': 8,
    'F64': 8,
    'C64': 8,
};

/**
 * Byte size of a tensor descriptor: product of `shape` times the dtype width.
 *
 * An empty shape is a scalar (one element). A missing shape, a missing value
 * or an unknown dtype gives 0.
 */
function sizeOf(value: any) {
    const shape = value?.shape;
    const numElements = Array.isArray(shape)
        ? shape.reduce((acc: number, dim: number) => acc * dim, 1)
        : 0;

    // Own-property check so dtypes like "toString" don't resolve to Object.prototype members.
    const dtype = value?.dtype;
    const bytesPerElement = typeof dtype === 'string' && Object.prototype.hasOwnProperty.call(DTYPE_TO_BYTES, dtype)
        ? DTYPE_TO_BYTES[dtype]
        : 0;

    return numElements * bytesPerElement;
}

/**
 * Scale a byte count to a human-readable unit, rounded to two decimals.
 *
 * A value moves up a unit once it reaches 1024. The check uses the rounded
 * value, so 1048575 bytes shows as "1 MB" rather than "1024 KB".
 *
 * @returns `{ unit, size }`, where `unit` is one of bytes/KB/MB/GB/TB.
 */
function formatSize(size: number) {
    const units = ['bytes', 'KB', 'MB', 'GB', 'TB'];
    const round2 = (x: number) => Math.round(x * 100) / 100;

    let scaled = size;
    let unitIndex = 0;

    while (unitIndex < units.length - 1 && round2(scaled) >= 1024) {
        scaled /= 1024;
        unitIndex++;
    }

    return { unit: units[unitIndex], size: round2(scaled) };
}

/**
 * One tensor leaf in the navigator. It shows dtype, shape, byte size and
 * element count. Clicking the row selects the tensor; the "extract" button
 * downloads its raw bytes.
 *
 * Entries without a value, shape or dtype render nothing.
 */
const DataExplorerValue = ({ name, depth, obj, download, selectTensor, selectedTensor }: { name: string, depth: number, obj: any, download: any, selectTensor: any, selectedTensor: any; }) => {
    if (!obj.value || !obj.value.shape || !obj.value.dtype) {
        return null;
    }

    const value = obj.value;
    const selected = selectedTensor === obj.path;

    // `shape` is known to be present past the guard above; [] (a scalar) gives 1 element.
    const numElements = value.shape.reduce((acc: number, dim: number) => acc * dim, 1);
    const { unit, size } = formatSize(obj.totalSize);

    const filename = `${obj.path}--${value.shape.join('x')}.${value.dtype.toLowerCase()}.bin`;

    // Only add "selected" when it applies. `${selected && 'selected'}` produced a literal "false" class.
    const labelClass = selected ? 'tree-label-wrapper selected' : 'tree-label-wrapper';

    return (
        <li>
            <div className={labelClass} onClick={() => selectTensor(obj)}>
                <span className="text-thin-mono"> - {name}:</span>
                <span className="text-mono">{value.dtype} {'{'}{value.shape.join(',')}{'}'}</span>
                <span className='text-thin'> ({size} {unit}, {numElements} elem.)</span>

                {value.data_offsets && <button className='download-label' title={filename} onClick={() => download(value.data_offsets, filename)}> extract</button>}
            </div>
        </li>
    );
};

/**
 * One collapsible folder in the navigator. It shows the folder name and its
 * total byte size, and renders the folder's children one level deeper when
 * expanded. Folders start collapsed.
 */
const DataExplorerFolder = ({ name, depth, data, download, selectTensor, selectedTensor }: { name: string, depth: number, data: any, download: any, selectTensor: any, selectedTensor: any; }) => {
    const [isOpen, setIsOpen] = useState(false);
    const toggle = () => setIsOpen(previous => !previous);

    const formatted = formatSize(data.totalSize);

    const childList = isOpen
        ? <DataExplorerNode data={data} depth={depth + 1} download={download} selectTensor={selectTensor} selectedTensor={selectedTensor} />
        : null;

    return (
        <li>
            <div className="data-explorer-folder">
                <button onClick={toggle} className="btn-folder-expand">
                    <div className="btn-folder-expand-text">
                        <span>{isOpen ? '-' : '+'}</span>
                    </div>
                </button>

                <span className="data-explorer-filename">{name}</span>

                <span className='text-thin'> ({formatted.size} {formatted.unit})</span>
            </div>

            {childList}
        </li>
    );
};

/**
 * Parse the header of a .safetensors file.
 *
 * Layout: 8-byte little-endian unsigned N, then an N-byte UTF-8 JSON header,
 * then the byte buffer. `data_offsets` in the header are relative to the start
 * of the byte buffer (file offset 8 + N). This function rewrites them in place
 * to absolute file offsets.
 *
 * @returns `{ metadata, size }`: the parsed header object and its length N.
 * @throws Error if the file is shorter than 8 bytes, the declared header length
 *         does not fit in the file, the header cannot be read, or it is not a
 *         JSON object.
 */
async function playWithFile(file: File): Promise<any> {
    if (file.size < 8) {
        throw new Error('File is too small; safetensors header needs at least 8 bytes.');
    }

    const prefix = await file.slice(0, 8).arrayBuffer();

    // N is a u64. Read both 32-bit halves so a non-zero high word is not silently
    // dropped (getUint32 alone only saw the low half).
    const view = new DataView(prefix);
    const low = view.getUint32(0, true);
    const high = view.getUint32(4, true);
    const size = high * 0x100000000 + low;

    const available = file.size - 8;
    if (size > available) {
        throw new Error(`Metadata size mismatch: expected ${size} bytes, got ${available} bytes.`);
    }

    let meta: ArrayBuffer;

    try {
        meta = await file.slice(8, size + 8).arrayBuffer();
    } catch (e: any) {
        throw new Error('Failed to read metadata: ' + e.message);
    }

    if (meta.byteLength !== size) {
        throw new Error(`Metadata size mismatch: expected ${size} bytes, got ${meta.byteLength} bytes.`);
    }

    let metadata: Record<string, any>;
    try {
        metadata = JSON.parse(new TextDecoder().decode(meta));
    } catch (e: any) {
        throw new Error('Failed to parse metadata JSON: ' + e.message);
    }

    if (metadata === null || typeof metadata !== 'object' || Array.isArray(metadata)) {
        throw new Error('Failed to parse metadata JSON: header is not a JSON object.');
    }

    // The byte buffer starts after the 8-byte length prefix AND the header.
    // Adding only `size` put every tensor 8 bytes too early.
    const dataStart = 8 + size;
    for (const entry of Object.values(metadata)) {
        if (entry !== null && typeof entry === 'object' && Array.isArray(entry.data_offsets)) {
            entry.data_offsets = entry.data_offsets.map((off: number) => off + dataStart);
        }
    }

    return { metadata, size };
}

/**
 * Offer the byte range `[start, end)` of `file` as a download named `filename`.
 * It delegates to `downloadBlob` so both download paths share one implementation.
 */
function downloadBlobSlice(file: File, start: number, end: number, filename: string) {
    downloadBlob(file.slice(start, end), filename);
}

/**
 * Trigger a browser download of `blob` as `filename`, using a temporary
 * object URL and a synthetic anchor click.
 */
function downloadBlob(blob: Blob, filename: string) {
    const url = URL.createObjectURL(blob);
    const anchor = document.createElement("a");
    anchor.href = url;
    anchor.download = filename;
    anchor.click();

    // Revoking right after click() can cancel the download in some browsers,
    // because the navigation starts asynchronously. Release the URL a bit later.
    setTimeout(() => URL.revokeObjectURL(url), 10000);
}

/**
 * Download `json` as a pretty-printed (4-space indented) JSON file named
 * `filename`, tagged with the `application/json` MIME type.
 */
function exportJson(json: unknown, filename: string) {
    const text = JSON.stringify(json, null, 4);
    const blob = new Blob([text], { type: 'application/json' });
    downloadBlob(blob, filename);
}

/**
 * Debug helper that logs a byte buffer to the console in two forms:
 * 1. A single 7-bit ASCII line. Each byte is masked with 0x7f; control
 *    characters and DEL become '?'.
 * 2. An xxd-style hex dump: 16 bytes per row, an 8-digit hex address, and an
 *    ASCII column where non-printable bytes show as '.'.
 */
function debugBuffer(x: Uint8Array) {
    // Build the string incrementally. Spreading the whole buffer into
    // String.fromCharCode exceeds the argument limit for large buffers.
    let ascii = '';
    for (const byte of x) {
        const code = byte & 0x7f;
        ascii += (code < 32 || code === 0x7f) ? '?' : String.fromCharCode(code);
    }
    console.log(ascii);

    const BYTES_PER_ROW = 16;
    const HEX_WIDTH = BYTES_PER_ROW * 3 - 1; // "xx " per byte, no trailing space

    console.log('');
    console.log('          00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f   0123456789abcdef');
    for (let rowStart = 0; rowStart < x.length; rowStart += BYTES_PER_ROW) {
        const row = x.subarray(rowStart, rowStart + BYTES_PER_ROW);
        const hex = Array.from(row, b => b.toString(16).padStart(2, '0')).join(' ');
        const text = Array.from(row, b => (b >= 0x20 && b < 0x7f) ? String.fromCharCode(b) : '.').join('');
        const addr = rowStart.toString(16).padStart(8, '0');

        // Pad a short final row so its ASCII column lines up with the full rows.
        console.log(`${addr}  ${hex.padEnd(HEX_WIDTH)}   ${text}`);
    }
    console.log('');
}

// The module's default export is the root viewer component.
export default TailwindPlayground;
