Skip to content

Commit

Permalink
re-connect on NoHostAvailable and better logging
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Aug 27, 2024
1 parent 7a412aa commit 35badac
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 29 deletions.
22 changes: 12 additions & 10 deletions impl/astra_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,11 @@ def connect(self, retry=False):
raise

except DriverException as e:
logger.error(e)
logger.error(f"dbid: {self.dbid}, error: {e}")
raise HTTPException(400, f"Failed to connect to cluster - the database may be hibernated")

except Exception as e:
logger.warning(f"Failed to connect to AstraDB: {e}")
logger.warning(f"Failed to connect to AstraDB: {e}, dbid: {self.dbid}")
# sleep and retry
time.sleep(5)
if retry:
Expand Down Expand Up @@ -1390,6 +1390,8 @@ def upsert_table_from_dict(self, table_name : str, obj : Dict):
)
logger.debug(f"upserted {table_name}: {obj}")
except Exception as e:
if isinstance(e, NoHostAvailable):
self.session = self.connect()
logger.error(f"failed to upsert {table_name}: {obj}, {query_string}")
raise e

Expand Down Expand Up @@ -1611,10 +1613,10 @@ def upsert_chunks_content_only(self, file_id, content, created_at):
),
)
except Exception as e:
logger.warning(f"Exception inserting into table: {e}")
logger.warning(f"Exception inserting into table: {e}, dbid: {self.dbid}")
raise

def queue_up_chunks(self, statements_and_params: [PreparedStatement], json: dict[str, Any], embedding_model,
def queue_up_chunks(self, statements_and_params: List[PreparedStatement], json: dict[str, Any], embedding_model,
**litellm_kwargs):
statement = self.make_chunks_statement(embedding_model, json, litellm_kwargs)
statements_and_params.append(
Expand All @@ -1631,13 +1633,13 @@ def queue_up_chunks(self, statements_and_params: [PreparedStatement], json: dict
)
return statements_and_params

def upsert_chunks_concurrently(self, statements_and_params: [SimpleStatement]):
def upsert_chunks_concurrently(self, statements_and_params: List[SimpleStatement]):
results = execute_concurrent(
self.session, statements_and_params, concurrency=100, results_generator=True)

for (success, result) in results:
if not success:
logger.warning(f"Exception inserting into table: {result}")
logger.warning(f"Exception inserting into table: {result}, dbid: {self.dbid}")
raise

def do_upsert_chunks(self, json: dict[str, Any], embedding_model, **litellm_kwargs):
Expand All @@ -1655,7 +1657,7 @@ def do_upsert_chunks(self, json: dict[str, Any], embedding_model, **litellm_kwar
),
)
except Exception as e:
logger.warning(f"Exception inserting into table: {e}")
logger.warning(f"Exception inserting into table: {e}, dbid: {self.dbid}")
raise

def make_chunks_statement(self, embedding_model, json, litellm_kwargs):
Expand Down Expand Up @@ -1838,7 +1840,7 @@ def handle_multiple_partitions(self, embeddings, limit, queryString, vector_inde
json_rows = []
for (success, result) in rows:
if not success:
logger.error(f"problem with async query: {result}") # result will be an Exception
logger.error(f"problem with async query: {result}, dbid: {self.dbid}") # result will be an Exception
else:
for row in result:
json_rows.append(dict(row))
Expand Down Expand Up @@ -1884,8 +1886,8 @@ def execute_and_get_json(self, boundStatement, vector_index_column, tries=0):
return json_rows
except Exception as e:
if tries < 3:
logger.warning(f"Exception during query (retrying): {e}")
logger.warning(f"Exception during query (retrying): {e}, dbid: {self.dbid}")
time.sleep(1)
return self.execute_and_get_json(boundStatement, vector_index_column, tries + 1)
else:
raise HTTPException(status_code=500, detail=f"Exception during recall")
raise HTTPException(status_code=500, detail=f"Exception during recall, dbid: {self.dbid}")
11 changes: 7 additions & 4 deletions impl/background.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import logging

from impl.astra_vector import CassandraClient

logger = logging.getLogger(__name__)
background_task_set = set()

Expand All @@ -15,7 +17,7 @@ async def add_background_task(function, run_id, thread_id, astradb):
task.add_done_callback(lambda t: on_task_completion(t, astradb=astradb, run_id=run_id, thread_id=thread_id))


def on_task_completion(task, astradb, run_id, thread_id):
def on_task_completion(task, astradb: CassandraClient, run_id, thread_id):
background_task_set.remove(task)
logger.debug(f"Task stopped for run_id: {run_id} and thread_id: {thread_id}")

Expand All @@ -26,13 +28,14 @@ def on_task_completion(task, astradb, run_id, thread_id):
try:
exception = task.exception()
if exception is not None:
logger.warning(f"Task raised an exception, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
logger.error(exception)
logger.error(f"Task raised an exception, setting status to failed for run_id: "
f"{run_id} and thread_id: {thread_id} and dbid: {astradb.dbid}"
f"\nException:\n{exception}")
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed");
raise exception
else:
logger.debug(f"Task completed successfully for run_id: {run_id} and thread_id: {thread_id}")
except asyncio.CancelledError:
logger.warning(f"why wasn't this caught in task.cancelled()")
logger.debug(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id}")
logger.debug(f"Task cancelled, setting status to failed for run_id: {run_id} and thread_id: {thread_id} and dbid: {astradb.dbid}")
astradb.update_run_status(id=run_id, thread_id=thread_id, status="failed");
8 changes: 4 additions & 4 deletions impl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def dispatch(self, request: Request, call_next):
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error: {e}")
logger.error(f"Error: {e}, dbid: {request.state.dbid}")
print(e)
raise e

Expand Down Expand Up @@ -233,7 +233,7 @@ async def shutdown_event():
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
# Log the error
logger.error(f"Unexpected error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}")
logger.error(f"Unexpected error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}, dbid: {request.state.dbid}")

if isinstance(exc, HTTPException):
raise exec
Expand All @@ -245,9 +245,9 @@ async def generic_exception_handler(request: Request, exc: Exception):

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
logging.error(f"Validation error for request: {request.url}")
logging.error(f"Validation error for request: {request.url}, dbid: {request.state.dbid}")
logging.error(f"Body: {exc.body}")
logger.error(f"Validation error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}")
logging.error(f"Validation error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}")
logging.error(f"Errors: {exc.errors()}")
return JSONResponse(
status_code=422,
Expand Down
18 changes: 9 additions & 9 deletions impl/routes_v2/threads_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ async def run_event_stream(run, message_id, astradb):
if e.status_code == 404:
run_step = None
else:
logger.error(e)
logger.error(f"Error reading run step, dbid: {astradb.dbid}, error: {e}")
raise e

#run_step = astradb.get_run_step(run_id=run.id, id=run_step_id)
Expand Down Expand Up @@ -579,7 +579,7 @@ async def run_event_stream(run, message_id, astradb):
except Exception as e:
# This usually means the client is broken
# TODO: cancel the run.
logger.error(e)
logger.error(f"Error in run event stream, dbid: {astradb.dbid}, error: {e}")


async def stream_message_events(astradb, thread_id, limit, order, after, before, run):
Expand Down Expand Up @@ -671,7 +671,7 @@ async def stream_message_events(astradb, thread_id, limit, order, after, before,
last_message.content = message.content
break
except Exception as e:
logger.error(e)
logger.error(f"Error in stream message events, dbid: {astradb.dbid}, error: {e}")
# TODO - cancel run, mark message incomplete
# yield f"data: []"

Expand Down Expand Up @@ -885,7 +885,7 @@ async def create_run(
try:
message = await get_chat_completion(messages=message_content, model=model, **litellm_kwargs[0])
except Exception as e:
logger.error(f"error: {e}, tenant {astradb.dbid}, model {model}, messages.data {messages.data}, create_run_request {create_run_request}")
logger.error(f"error: {e}, dbid: {astradb.dbid}, model {model}, messages.data {messages.data}, create_run_request {create_run_request}")
raise HTTPException(status_code=500, detail=f"Error processing message, {e}")

logger.info(f"tool_call message: {message}")
Expand Down Expand Up @@ -1232,13 +1232,13 @@ async def process_rag(
**litellm_kwargs[0],
)
except asyncio.CancelledError as e:
logger.error(e)
logger.error(f"process_rag cancelled, dbid: {astradb.dbid}, error: {e}")
# TODO maybe do a cancelled run step with more details?
await update_run_status(thread_id=thread_id, id=run_id, status="failed", astradb=astradb)
logger.error("process_rag cancelled")
raise RuntimeError("process_rag cancelled")
except Exception as e:
logger.error(e)
logger.error(f"process_rag failed, dbid: {astradb.dbid}, error: {e}")
# TODO maybe do a cancelled run step with more details?
await update_run_status(thread_id=thread_id, id=run_id, status="failed", astradb=astradb)
logger.error("process_rag cancelled")
Expand Down Expand Up @@ -1278,7 +1278,7 @@ async def process_rag(
except Exception as e:
await update_run_status(thread_id=thread_id, id=run_id, status="failed", astradb=astradb)
logger.error(traceback.format_exc())
logger.error(e)
logger.error(f"Error in process_rag, dbid: {astradb.dbid}, error: {e}")
raise e
except asyncio.CancelledError:
logger.error("process_rag cancelled")
Expand Down Expand Up @@ -1642,7 +1642,7 @@ async def submit_tool_ouputs_to_run(
media_type="text/event-stream")

except Exception as e:
logger.info(e)
logger.info(f"Error in submit_tool_ouputs_to_run, dbid: {astradb.dbid}, error: {e}")
await update_run_status(thread_id=thread_id, id=run_id, status="failed", astradb=astradb)
raise

Expand Down Expand Up @@ -1797,7 +1797,7 @@ async def message_delta_streamer(message_id, created_at, response, run, astradb)
logger.info(f"completed run_id {run.id} thread_id {run.thread_id} with tool submission")

except Exception as e:
logger.info(e)
logger.info(f"Error in message_delta_streamer, dbid: {astradb.dbid}, error: {e}")
await update_run_status(thread_id=run.thread_id, id=run.id, status="failed", astradb=astradb)
raise

Expand Down
5 changes: 3 additions & 2 deletions impl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def store_object(astradb: CassandraClient, obj: BaseModel, target_class: T
astradb.upsert_table_from_dict(table_name=table_name, obj=obj_dict)
return combined_obj
except Exception as e:
logger.error(f"store_object failed {e} for table {table_name} and object {obj}")
logger.error(f"store_object failed {e} for table {table_name} and object {obj}, dbid: {astradb.dbid}")
raise HTTPException(status_code=500, detail=f"Error reading {table_name}: {e}")


Expand All @@ -88,11 +88,12 @@ def read_object(astradb: CassandraClient, target_class: Type[BaseModel], table_n
try:
objs = read_objects(astradb, target_class, table_name, partition_keys, args)
except Exception as e:
logger.error(f"read_object failed {e} for table {table_name}")
logger.error(f"read_object failed {e} for table {table_name}, dbid: {astradb.dbid}")
logger.error(f"trace: {traceback.format_exc()}")
raise HTTPException(status_code=404, detail=f"{target_class.__name__} not found.")
if len(objs) == 0:
# Maybe pass down name
logger.warn(f"did not find partition_keys {partition_keys} and args {args} for {target_class.__name__} in table {table_name} for dbid: {astradb.dbid}")
raise HTTPException(status_code=404, detail=f"{target_class.__name__} not found.")
return objs[0]

Expand Down

0 comments on commit 35badac

Please sign in to comment.