Skip to content

Commit

Permalink
fix: perf
Browse files Browse the repository at this point in the history
  • Loading branch information
jayanth-kumar-morem authored and shivam-singhal committed Jan 8, 2025
1 parent cc07fce commit 4e64f9f
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 18 deletions.
63 changes: 59 additions & 4 deletions preswald/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import pandas as pd
import hashlib
import time

# Configure logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -217,13 +218,59 @@ def plotly(fig):
fig: A Plotly figure object.
"""
try:
import time
start_time = time.time()
logger.debug(f"[PLOTLY] Starting plotly render")

id = generate_id("plot")
logger.debug(f"Creating plot component with id {id}")
logger.debug(f"[PLOTLY] Created plot component with id {id}")

# Optimize the figure for web rendering
optimize_start = time.time()

# Reduce precision of numeric values
for trace in fig.data:
for attr in ['x', 'y', 'z', 'lat', 'lon']:
if hasattr(trace, attr):
values = getattr(trace, attr)
if isinstance(values, (list, np.ndarray)):
if np.issubdtype(np.array(values).dtype, np.floating):
setattr(trace, attr, np.round(values, decimals=4))

# Optimize marker sizes
if hasattr(trace, 'marker') and hasattr(trace.marker, 'size'):
if isinstance(trace.marker.size, (list, np.ndarray)):
# Scale marker sizes to a reasonable range
sizes = np.array(trace.marker.size)
if len(sizes) > 0:
min_size, max_size = 5, 20 # Reasonable size range for web rendering
normalized_sizes = (sizes - sizes.min()) / (sizes.max() - sizes.min())
scaled_sizes = min_size + normalized_sizes * (max_size - min_size)
trace.marker.size = scaled_sizes.tolist()

# Optimize layout
if hasattr(fig, 'layout'):
# Set reasonable margins
fig.update_layout(
margin=dict(l=50, r=50, t=50, b=50),
autosize=True
)

# Optimize font sizes
fig.update_layout(
font=dict(size=12),
title=dict(font=dict(size=14))
)

logger.debug(f"[PLOTLY] Figure optimization took {time.time() - optimize_start:.3f}s")

# Convert the figure to JSON-serializable format
fig_dict_start = time.time()
fig_dict = fig.to_dict()
logger.debug(f"[PLOTLY] Figure to dict conversion took {time.time() - fig_dict_start:.3f}s")

# Clean up any NaN values in the data
clean_start = time.time()
for trace in fig_dict.get('data', []):
if isinstance(trace.get('marker'), dict):
marker = trace['marker']
Expand All @@ -236,9 +283,12 @@ def plotly(fig):
trace[key] = [None if isinstance(x, (float, np.floating)) and np.isnan(x) else x for x in value]
elif isinstance(value, (float, np.floating)) and np.isnan(value):
trace[key] = None
logger.debug(f"[PLOTLY] NaN cleanup took {time.time() - clean_start:.3f}s")

# Convert to JSON-serializable format
serialize_start = time.time()
serializable_fig_dict = convert_to_serializable(fig_dict)
logger.debug(f"[PLOTLY] Serialization took {time.time() - serialize_start:.3f}s")

component = {
"type": "plot",
Expand All @@ -250,20 +300,25 @@ def plotly(fig):
"responsive": True,
"displayModeBar": True,
"modeBarButtonsToRemove": ["lasso2d", "select2d"],
"displaylogo": False
"displaylogo": False,
"scrollZoom": True, # Enable scroll zoom for better interaction
"showTips": False, # Disable hover tips for better performance
}
}
}

# Verify JSON serialization
json_start = time.time()
json.dumps(component)
logger.debug(f"[PLOTLY] JSON verification took {time.time() - json_start:.3f}s")

logger.debug(f"Plot data created successfully for id {id}")
logger.debug(f"[PLOTLY] Plot data created successfully for id {id}")
logger.debug(f"[PLOTLY] Total plotly render took {time.time() - start_time:.3f}s")
_rendered_html.append(component)
return component

except Exception as e:
logger.error(f"Error creating plot: {str(e)}", exc_info=True)
logger.error(f"[PLOTLY] Error creating plot: {str(e)}", exc_info=True)
error_component = {
"type": "plot",
"id": id,
Expand Down
86 changes: 72 additions & 14 deletions preswald/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import os
import toml
import time

# Configure logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,12 +127,16 @@ def connect(source, name=None):
"""
Connect to a data source such as a CSV, JSON, or database.
If source is a connection name from config.toml, it will use that configuration.
Otherwise, it will treat source as a direct file path or connection string.
Otherwise, it will treat source as a direct path or connection string.
Args:
source (str): Either a connection name from config.toml or a direct path/connection string
name (str, optional): A unique name for the connection
"""
import time
start_time = time.time()
logger.info(f"[CONNECT] Starting connection to {source}")

if name is None:
name = f"connection_{len(connections) + 1}"

Expand All @@ -155,16 +160,18 @@ def connect(source, name=None):

# Load main config
if os.path.exists(config_path):
config_start = time.time()
with open(config_path, 'r') as f:
config = toml.load(f)
logger.info(f"[CONNECT] Loaded config: {config}")
logger.info(f"[CONNECT] Loaded config in {time.time() - config_start:.3f}s")

# Load secrets if available
secrets = {}
if os.path.exists(secrets_path):
secrets_start = time.time()
with open(secrets_path, 'r') as f:
secrets = toml.load(f)
logger.info("[CONNECT] Loaded secrets file")
logger.info(f"[CONNECT] Loaded secrets in {time.time() - secrets_start:.3f}s")

# Handle nested data section
if source.startswith("data."):
Expand Down Expand Up @@ -221,6 +228,28 @@ def connect(source, name=None):
raise # Re-raise the exception to handle it properly

logger.info(f"[CONNECT] Final source path: {source}")

# Function to optimize large dataframes
def optimize_dataframe(df, max_rows=10000):
"""Optimize dataframe by sampling and converting types"""
if len(df) > max_rows:
logger.info(f"[CONNECT] Sampling large dataframe from {len(df)} to {max_rows} rows")
df = df.sample(n=max_rows, random_state=42)

# Optimize numeric columns
for col in df.select_dtypes(include=['float64']).columns:
if df[col].nunique() < 1000: # If column has few unique values
df[col] = df[col].astype('float32')

# Optimize integer columns
for col in df.select_dtypes(include=['int64']).columns:
if df[col].min() >= -32768 and df[col].max() <= 32767:
df[col] = df[col].astype('int16')
elif df[col].min() >= -2147483648 and df[col].max() <= 2147483647:
df[col] = df[col].astype('int32')

return df

# Process the source as a direct path/connection string
if source.endswith(".csv"):
logger.info(f"[CONNECT] Reading CSV from: {source}")
Expand All @@ -233,7 +262,9 @@ def connect(source, name=None):
# Create a StringIO object from the response content
from io import StringIO
csv_data = StringIO(response.text)
connections[name] = pd.read_csv(csv_data)
read_start = time.time()
df = pd.read_csv(csv_data)
logger.info(f"[CONNECT] CSV read from URL took {time.time() - read_start:.3f}s")
else:
# Handle local file path
script_path = get_script_path()
Expand All @@ -244,13 +275,36 @@ def connect(source, name=None):
else:
csv_path = source
logger.info(f"[CONNECT] Reading CSV from: {csv_path}")
connections[name] = pd.read_csv(csv_path)
read_start = time.time()
df = pd.read_csv(csv_path)
logger.info(f"[CONNECT] CSV read from file took {time.time() - read_start:.3f}s")

# Optimize the dataframe
optimize_start = time.time()
df = optimize_dataframe(df)
logger.info(f"[CONNECT] Dataframe optimization took {time.time() - optimize_start:.3f}s")
connections[name] = df

elif source.endswith(".json"):
logger.info(f"[CONNECT] Reading JSON from: {source}")
connections[name] = pd.read_json(source)
read_start = time.time()
df = pd.read_json(source)
logger.info(f"[CONNECT] JSON read took {time.time() - read_start:.3f}s")
optimize_start = time.time()
df = optimize_dataframe(df)
logger.info(f"[CONNECT] Dataframe optimization took {time.time() - optimize_start:.3f}s")
connections[name] = df

elif source.endswith(".parquet"):
logger.info(f"[CONNECT] Reading Parquet from: {source}")
connections[name] = pd.read_parquet(source)
read_start = time.time()
df = pd.read_parquet(source)
logger.info(f"[CONNECT] Parquet read took {time.time() - read_start:.3f}s")
optimize_start = time.time()
df = optimize_dataframe(df)
logger.info(f"[CONNECT] Dataframe optimization took {time.time() - optimize_start:.3f}s")
connections[name] = df

elif any(source.startswith(prefix) for prefix in ["postgresql://", "postgres://", "mysql://"]):
logger.info(f"[CONNECT] Creating database engine")
engine = create_engine(source)
Expand All @@ -267,7 +321,7 @@ def connect(source, name=None):

# Broadcast connection updates
asyncio.create_task(broadcast_connections())
logger.info(f"[CONNECT] Successfully created connection '{name}' to {source}")
logger.info(f"[CONNECT] Successfully created connection '{name}' to {source} in {time.time() - start_time:.3f}s")
except Exception as e:
# Clean up if connection failed
if name in connections:
Expand Down Expand Up @@ -393,7 +447,8 @@ def convert_to_serializable(obj):

def get_rendered_components():
"""Get all rendered components as JSON"""
logger.debug(f"[CORE] Getting rendered components, count: {len(_rendered_html)}")
start_time = time.time()
logger.debug(f"[RENDER] Getting rendered components, count: {len(_rendered_html)}")
components = []

# Create a set to track unique component IDs
Expand All @@ -403,7 +458,9 @@ def get_rendered_components():
try:
if isinstance(item, dict):
# Clean any NaN values in the component
clean_start = time.time()
cleaned_item = _clean_nan_values(item)
logger.debug(f"[RENDER] NaN cleanup took {time.time() - clean_start:.3f}s")

# Ensure component has current state
if 'id' in cleaned_item:
Expand All @@ -414,24 +471,25 @@ def get_rendered_components():
current_state = get_component_state(component_id)
if current_state is not None:
cleaned_item['value'] = _clean_nan_values(current_state)
logger.debug(f"[CORE] Updated component {component_id} with state: {current_state}")
logger.debug(f"[RENDER] Updated component {component_id} with state: {current_state}")
components.append(cleaned_item)
seen_ids.add(component_id)
logger.debug(f"[CORE] Added component with state: {cleaned_item}")
logger.debug(f"[RENDER] Added component with state: {cleaned_item}")
else:
# Components without IDs are added as-is
components.append(cleaned_item)
logger.debug(f"[CORE] Added component without ID: {cleaned_item}")
logger.debug(f"[RENDER] Added component without ID: {cleaned_item}")
else:
# Convert HTML string to component data
component = {
"type": "html",
"content": str(item)
}
components.append(component)
logger.debug(f"[CORE] Added HTML component: {component}")
logger.debug(f"[RENDER] Added HTML component: {component}")
except Exception as e:
logger.error(f"[CORE] Error processing component: {e}", exc_info=True)
logger.error(f"[RENDER] Error processing component: {e}", exc_info=True)
continue

logger.debug(f"[RENDER] Total rendering took {time.time() - start_time:.3f}s")
return components
8 changes: 8 additions & 0 deletions preswald/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import signal
import sys
import time

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -359,6 +360,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: str):

async def handle_websocket_message(websocket: WebSocket, message: str):
"""Handle incoming WebSocket messages"""
start_time = time.time()
try:
data = json_loads(message)
logger.debug(f"[WebSocket] Received message: {data}")
Expand All @@ -377,6 +379,7 @@ async def handle_websocket_message(websocket: WebSocket, message: str):
return

logger.info("[Component Update] Processing updates:")
update_start = time.time()
for component_id, value in states.items():
try:
# Update component state
Expand All @@ -390,17 +393,22 @@ async def handle_websocket_message(websocket: WebSocket, message: str):
logger.error(f"Error updating component {component_id}: {e}")
await send_error(websocket, f"Failed to update component {component_id}")
continue
logger.info(f"[Component Update] State updates took {time.time() - update_start:.3f}s")

# Trigger script rerun with all states
logger.info(f"[Script Rerun] Triggering with states: {states}")
rerun_start = time.time()
await rerun_script(websocket, states)
logger.info(f"[Script Rerun] Script rerun took {time.time() - rerun_start:.3f}s")

except json.JSONDecodeError as e:
logger.error(f"[WebSocket] Error decoding message: {e}")
await send_error(websocket, "Invalid message format")
except Exception as e:
logger.error(f"[WebSocket] Error processing message: {e}")
await send_error(websocket, f"Error processing message: {str(e)}")
finally:
logger.info(f"[WebSocket] Total message handling took {time.time() - start_time:.3f}s")

async def send_error(websocket: WebSocket, message: str):
"""Send error message to client"""
Expand Down

0 comments on commit 4e64f9f

Please sign in to comment.