Skip to content

Commit

Permalink
Custom middleware function
Browse files Browse the repository at this point in the history
  • Loading branch information
Bslabe123 committed Oct 29, 2024
1 parent e06b6db commit 9a4a448
Showing 1 changed file with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
import requests
import time
from typing import AsyncGenerator, List, Optional, Tuple, Dict
from prometheus_client import start_http_server, Histogram
from prometheus_client import start_http_server, Histogram, Gauge

import google.auth
import google.auth.transport.requests

import aiohttp
from aiohttp_prometheus_exporter.trace import PrometheusTraceConfig
import numpy as np
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerBase
Expand All @@ -35,6 +34,18 @@
# Prometheus metrics exported by the benchmark client.
# The two length histograms use power-of-two buckets from 2 up to 2**15.
prompt_length_metric = Histogram(
    "LatencyProfileGenerator:prompt_length",
    "Input prompt length",
    buckets=[2**exponent for exponent in range(1, 16)],
)
response_length_metric = Histogram(
    "LatencyProfileGenerator:response_length",
    "Response length",
    buckets=[2**exponent for exponent in range(1, 16)],
)
# Uses the library's default histogram buckets.
tpot_metric = Histogram(
    'LatencyProfileGenerator:time_per_output_token',
    'Time per output token per request',
)
# Gauge that goes up and down with the number of in-flight HTTP requests.
active_requests_metric = Gauge(
    'LatencyProfileGenerator:active_requests',
    'How many requests actively being processed',
)

# Trace hooks that track in-flight requests in the active_requests gauge.
async def on_request_start(session, trace_config_ctx, params):
    """Trace hook: record that one more request is in flight.

    The three arguments are passed in by the tracing machinery and are not
    used; the hook only bumps the active-requests gauge.
    """
    active_requests_metric.inc(1)

async def on_request_end(session, trace_config_ctx, params):
    """Trace hook: record that one in-flight request has finished.

    The three arguments are passed in by the tracing machinery and are not
    used; the hook only lowers the active-requests gauge.
    """
    active_requests_metric.dec(1)

# Shared TraceConfig that wires the gauge hooks into any ClientSession that
# receives it through ``trace_configs``.
trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(on_request_start)
trace_config.on_request_end.append(on_request_end)
# aiohttp emits on_request_exception *instead of* on_request_end when a request
# fails (connection error, timeout, ...). Without this registration a failed
# request increments the gauge but never decrements it, so the reported number
# of active requests drifts upward. The signal has the same callback signature,
# so the decrement hook is reused.
trace_config.on_request_exception.append(on_request_end)

def sample_requests(
dataset_path: str,
Expand Down Expand Up @@ -209,7 +220,7 @@ async def send_request(

# Set client timeout to be 3 hrs.
timeout = aiohttp.ClientTimeout(total=CLIENT_TIMEOUT_SEC)
async with aiohttp.ClientSession(timeout=timeout,trust_env=True,trace_configs=[PrometheusTraceConfig()]) as session:
async with aiohttp.ClientSession(timeout=timeout,trust_env=True,trace_configs=[trace_config]) as session:
while True:
try:
async with session.post(api_url, headers=headers, json=pload, ssl=False) as response:
Expand Down

0 comments on commit 9a4a448

Please sign in to comment.