code contribution #3

Open
falseywinchnet opened this issue Mar 4, 2025 · 0 comments

Comments

@falseywinchnet
I would like to volunteer a lightweight visualization for attention and tokens across the batch. You will need a way to return the attention weights at each iteration, and you will also need the per-token loss.
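A minimal sketch of how these might be obtained, assuming a PyTorch model (the return_attn flag and the model call are hypothetical placeholders; adapt them to however your model exposes its attention):

import torch
import torch.nn.functional as F

def get_visualization_products(model, inputs, targets):
    # Hypothetical: assumes the model can be asked to return its attention weights.
    logits, attn_weights = model(inputs, return_attn=True)
    # Per-token loss: reduction="none" keeps one loss value per position
    # instead of averaging over the batch.
    token_losses = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        targets.view(-1),
        reduction="none",
    ).view(targets.shape)  # (BATCH_SIZE, SEQ_LEN)
    return logits, attn_weights, token_losses

Note that update_framebuffer below indexes attn_weights[row, col], so a raw (batch, heads, query, key) attention tensor needs to be reduced to (batch, seq) first, e.g. by averaging over heads and query positions.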


import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import ipywidgets as widgets
from IPython.display import display, clear_output
import time

# --- Configuration Constants ---
CHAR_WIDTH = 8    # Pixel width of each character cell (8 pt monospaced font)
CHAR_HEIGHT = 11  # Pixel height of each character cell
SEQ_LEN = 128
BATCH_SIZE = 16
LOSS_BAR_HEIGHT = 32
EWMA_HEIGHT = 32  # Header strip for the large EWMA/entropy readouts

# Full-resolution framebuffer dimensions
container_width = CHAR_WIDTH * SEQ_LEN  # 8 * 128 = 1024 pixels
container_height = CHAR_HEIGHT * BATCH_SIZE  # 11 * 16 = 176 pixels
total_height = container_height + LOSS_BAR_HEIGHT + EWMA_HEIGHT  # 176 + 32 + 32 = 240 pixels

# Final scaled-down dimensions (displayed at half size)
scaled_width = container_width // 2  # 512 pixels
scaled_height = total_height // 2  # 120 pixels

# Initialize framebuffer
framebuffer = np.zeros((total_height, container_width, 3), dtype=np.uint8)

# EWMA storage
ticker_history = np.zeros(SEQ_LEN, dtype=np.float32)  # Stock ticker moving buffer
loss_memory = 0.0  # EWMA state for the smoothed loss readout

# Load fonts; fall back to Pillow's default if DejaVu is unavailable
try:
    font = ImageFont.truetype("DejaVuSansMono.ttf", 8)  # Monospaced font for the token grid
    font_large = ImageFont.truetype("DejaVuSansMono.ttf", 32)  # Large font for the EWMA/entropy readouts
except OSError:
    font = ImageFont.load_default()
    font_large = font

# --- Color Mapping Functions ---
def get_flame_color(val):
    """Map a normalized value to a flame-like color."""
    return np.array([int(val * 255), int(val * 0.5 * 255), 0], dtype=np.uint8)

# --- IPython Display Setup ---
out = widgets.Output()
display(out)

def get_dynamic_color(attn_val, loss_val):
    """
    Compute a dynamic color transition between flame orange (uncertain) and phosphor green (confident).
    
    attn_val: Normalized attention value (0 to 1)
    loss_val: Normalized loss value (0 to 1, inverted as 1 - loss)
    
    Returns an RGB color as a NumPy array.
    Note: late in training, colors will often trend red; consider swapping this
    out for get_flame_color, but only when fine-tuning on new data.
    """
    certainty = 1 - loss_val  # High certainty = low loss
    
    # Define RGB endpoints
    orange = np.array([attn_val * 255, attn_val * 0.5 * 255, 0], dtype=np.uint8)   # Uncertain (High Loss)
    green = np.array([attn_val * 0.5 * 255, attn_val * 255, attn_val * 0.25 * 255], dtype=np.uint8)  # Confident (Low Loss)
    
    # Interpolate based on certainty (0 = uncertain/orange, 1 = confident/green)
    color = (certainty * green) + ((1 - certainty) * orange)
    
    return color.astype(np.uint8)
# --- Framebuffer Update Function ---
def update_framebuffer(attn_weights, token_losses, current_loss, tokens, entropy):
    """Render the text grid with coloration based on attn * inverse loss."""
    global framebuffer, ticker_history, loss_memory

    # Normalize both arrays to [0, 1], guarding against a zero range
    attn_weights = (attn_weights - attn_weights.min()) / max(np.ptp(attn_weights), 1e-8)
    token_losses = (token_losses - token_losses.min()) / max(np.ptp(token_losses), 1e-8)

    # Create image buffer
    img = Image.new("RGB", (container_width, total_height), (0, 0, 0))
    draw = ImageDraw.Draw(img)

    # Render text with colored intensity
    char_positions = [
        (col * CHAR_WIDTH, row * CHAR_HEIGHT + EWMA_HEIGHT + LOSS_BAR_HEIGHT, tokens[row][col])
        for row in range(BATCH_SIZE) for col in range(SEQ_LEN)
    ]
    colors = [
        tuple(get_dynamic_color(attn_weights[row, col], token_losses[row, col]))
        for row in range(BATCH_SIZE) for col in range(SEQ_LEN)
    ]
    for (x, y, char), color in zip(char_positions, colors):
        draw.text((x, y), char, font=font, fill=color)

                 
    etcerta = 0.367879441  # Smoothing constant, ~1/e
    et = 1 - etcerta
    # Clamp the incoming loss to [0, 12], then apply the double EWMA update
    update = loss_memory * et + min(max(current_loss, 0.0), 12.0) * etcerta
    loss_memory = loss_memory * et + update * etcerta
    # --- EWMA Display (LARGE FONT) ---
    ewma = loss_memory
    ewma_text = f"{ewma:.4f}"
    draw.text((container_width - 128, 0), ewma_text, font=font_large, fill=(65, 255, 125))
    ent_text = f"{entropy:.4f}"
    draw.text((10, 0), ent_text, font=font_large, fill=(255, 125, 0))

    # --- Moving Loss Ticker Graph ---
    ticker_history = np.roll(ticker_history, -1)  # Shift left
    ticker_history[-1] = current_loss  # Insert new loss on the right

    # Rescale ticker dynamically like a stock ticker (normalize to min-max range)
    min_loss = np.min(ticker_history)
    max_loss = np.max(ticker_history)
    range_loss = max_loss - min_loss if max_loss != min_loss else 1
    normalized_ticker = (ticker_history - min_loss) / range_loss

    # Draw ticker graph line
    # Optimized drawing loop (fewer function calls)
    y_vals = EWMA_HEIGHT + (1 - normalized_ticker) * LOSS_BAR_HEIGHT
    x_vals = np.arange(SEQ_LEN) * CHAR_WIDTH
    for i in range(SEQ_LEN - 1):
        draw.line([(x_vals[i], y_vals[i]), (x_vals[i + 1], y_vals[i + 1])], fill=(0, 255, 255), width=2)

    framebuffer = np.array(img)

# --- IPython Display Update Function ---
def update_display():
    """Show the framebuffer, scaled down by half using ipywidgets."""
    img = Image.fromarray(framebuffer)
    img_resized = img.resize((scaled_width, scaled_height), Image.LANCZOS)
    
    with out:
        clear_output(wait=True)
        display(img_resized)

loss_history = []

Adjust BATCH_SIZE and SEQ_LEN to match your training configuration, then call

update_framebuffer(attn_weights, token_losses, current_loss, tokens, entropy)

in your training loop, with all of the necessary tensors detached, AFTER backprop; this minimizes GPU memory thrashing. A sketch of the hookup follows.
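As an illustration only (model, loader, optimizer, the return_attn flag, and the tokens/entropy inputs are stand-ins for your own objects):

import torch.nn.functional as F

for step, (inputs, targets) in enumerate(loader):
    logits, attn_weights = model(inputs, return_attn=True)  # hypothetical flag
    token_losses = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    ).view(targets.shape)
    loss = token_losses.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detach and move to CPU only after backprop, per the note above.
    # Reduce (batch, heads, query, key) attention to (batch, seq) by averaging.
    attn = attn_weights.mean(dim=(1, 2)).detach().cpu().numpy()
    update_framebuffer(
        attn,
        token_losses.detach().cpu().numpy(),
        loss.item(),
        tokens,   # [BATCH_SIZE][SEQ_LEN] list of decoded token strings
        entropy,  # scalar; see the note below
    )
    update_display()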
The entropy readout is a relatively new idea, and the implementation seems to need work; see https://arxiv.org/abs/2203.0919
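For what it's worth, one plausible reading (an assumption on my part, not necessarily what the paper intends) is the mean Shannon entropy of the model's predictive distribution:

import torch.nn.functional as F

def mean_prediction_entropy(logits):
    # Mean entropy of the next-token distribution, in nats.
    # Assumes logits of shape (batch, seq, vocab).
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # (batch, seq)
    return entropy.mean().item()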
