diff --git a/README.md b/README.md
index 8dba7c8..284d5ca 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,4 @@
-# VDebugger
-
-This repo is for **VDebugger: Harnessing Execution Feedback for Debugging Visual Programs**
+# VDebugger: Harnessing Execution Feedback for Debugging Visual Programs
 
 [Paper](https://arxiv.org/abs/2406.13444), [Website](https://shirley-wu.github.io/vdebugger/index.html), [Models and Data](https://huggingface.co/VDebugger)
 
@@ -10,6 +8,7 @@ This repo is for **VDebugger: Harnessing Execution Feedback for Debugging Visual Programs**
 - [Dataset Setup](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#dataset-setup)
 - [Generation and Execution of Visual Programs](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#generation-and-execution-of-visual-programs)
 - [Inference of VDebugger](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#inference-of-vdebugger)
+- [Error Injection](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#error-injection)
 
 ## Environment Setup
 
@@ -135,6 +134,13 @@ Then you can execute the programs in `critic-refine-infer.csv` as in step 2 of [
 If you want to reproduce our training of VDebugger, please use `vdebugger/training_scripts/train_{critic, refiner}.sh`. You will need to install `deepspeed==0.14.0`.
 
+## Error Injection
+
+To perform error injection and generate incorrect programs as described in Section 4 of our paper, you first need a `.csv` file containing the visual programs generated for the training set and their execution results. Then, please go to `vdebugger/` and run:
+```bash
+python error_injection.py YOUR_CSV_FILE --error_injection {greedy, mask-best}
+```
+
 ## Citation
 
 Please cite our paper if this repository inspires your work.
 
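A note on the two `--error_injection` strategies, summarized from `error_injection.py` below: `greedy` regenerates the masked span with temperature 0, while `mask-best` (the default) installs the custom `ErrorSampler`, which bans the most likely token whenever the gap between the top-2 token probabilities is below 0.9, at most 3 times per sequence, so the recovered program stays plausible but is likely wrong. Unless `--output` is given, results are written next to the input. A quick sketch (the input filename is a placeholder):

```bash
# Writes my_programs.error-injection-by-mask-best.csv next to the input
# ("my_programs.csv" is a hypothetical filename)
python error_injection.py my_programs.csv --error_injection mask-best
```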
diff --git a/vdebugger/error_injection.py b/vdebugger/error_injection.py
new file mode 100644
index 0000000..72bd40b
--- /dev/null
+++ b/vdebugger/error_injection.py
@@ -0,0 +1,162 @@
+import argparse
+import ast
+import os
+import random
+
+import numpy as np
+import pandas as pd
+import torch
+from vllm import LLM, SamplingParams
+
+from error_sampler import ErrorSampler
+from my_datasets import datasets
+
+
+def get_all_candidates(code, head: str):
+    code = head + code
+    try:
+        code = ast.unparse(ast.parse(code))
+    except:
+        return None
+
+    lines = code.splitlines()
+    tree = ast.parse(code)
+
+    def get_text(nodes):
+        start = nodes[0]
+        end = nodes[-1]
+        end_lineno = getattr(end, 'end_lineno', end.lineno)
+        if start.lineno == end_lineno:
+            return lines[start.lineno - 1][start.col_offset: end.end_col_offset]
+        else:
+            return '\n'.join(
+                [lines[start.lineno - 1][start.col_offset:], ] + lines[start.lineno: end_lineno - 1] + \
+                [lines[end_lineno - 1][:end.end_col_offset], ]
+            )
+
+    def mask_nodes(nodes):
+        start = nodes[0]
+        end = nodes[-1]
+        ret = lines[:start.lineno - 1] + [lines[start.lineno - 1][:start.col_offset] + '<masked>', ]
+        ret[-1] += lines[end.end_lineno - 1][end.end_col_offset:]
+        ret += lines[end.end_lineno:]
+        return '\n'.join(ret)
+
+    def is_ImagePatch(x):
+        return isinstance(x, ast.Call) and get_text([x.func, ]) == 'ImagePatch'
+
+    candidate_nodes_to_mask = []
+    for node in ast.walk(tree):
+        if node is tree:
+            continue  # don't add root-level nodes
+        if hasattr(node, 'body') and isinstance(node.body, list):
+            # candidate spans of 1-3 consecutive statements, skipping spans that initialize an ImagePatch
+            for i in range(len(node.body)):
+                for j in range(i + 1, min(i + 4, len(node.body) + 1)):
+                    if not any((isinstance(x, ast.Assign) and is_ImagePatch(x.value)) for x in node.body[i:j]):
+                        candidate_nodes_to_mask.append(node.body[i:j])
+        if isinstance(node, ast.Assign) or isinstance(node, ast.Return):
+            if not is_ImagePatch(node.value) and node.value is not None:
+                candidate_nodes_to_mask.append([node.value, ])
+        if isinstance(node, ast.If):
+            candidate_nodes_to_mask.append([node.test, ])
+        if isinstance(node, ast.Call) and not is_ImagePatch(node):
+            if node.args:  # don't corrupt ImagePatch initializations; skip zero-argument calls
+                candidate_nodes_to_mask.append(node.args)
+
+    return [(mask_nodes(nodes), get_text(nodes)) for nodes in candidate_nodes_to_mask]
+
+
+def sample_one_masked_code(code, head):
+    candidates = get_all_candidates(code, head)
+    if not candidates:  # unparsable code, or no candidate spans to mask
+        return None
+    return random.choice(candidates)[0]
+
+
+def get_prompt():
+    with open(os.path.join(os.path.dirname(__file__), '../viper/prompts/benchmarks/joint.py')) as f:
+        base_prompt = f.read().strip()
+    api_definition = base_prompt.replace("# INSERT_QUERY_HERE", "").strip()
+    return """[INST] I am writing code to handle visual question answering tasks by calling computer vision APIs. Some content from the code is masked (represented as "<masked>"). Please recover the original code.
+
+My code:
+```python
+# {QUESTION_1}
+{CODE}
+```
+
+Your code should be wrapped in ```python and ```. The code should be exactly the same as my code, except recovering the masked content.
+
+---
+
+Below are the available APIs and some example usages:
+
+```python
+""" + api_definition + """
+```[/INST] Here's the original code with the `<masked>` section replaced:
+```python
+# {QUESTION_1}
+{QUESTION_2}"""
+
+
+def main(args):
+    dataset = datasets[args.dataset]
+
+    data = pd.read_csv(args.input)
+    data['acc'] = [dataset.accuracy([r, ], [a, ]) for r, a in zip(data['result'], data['answer'])]
+
+    PROMPT = get_prompt()
+    llm = LLM(model=args.model_id, dtype='bfloat16', tensor_parallel_size=torch.cuda.device_count())
+    if args.error_injection == 'mask-best':
+        llm.llm_engine.model_executor.driver_worker.model_runner.model.sampler = ErrorSampler()
+
+    data['orig_code'] = data.pop('code')
+    data.insert(len(data.keys()), 'masked', '')
+    data.insert(len(data.keys()), 'generation', '')
+    data.insert(len(data.keys()), 'code', '')
+
+    prompt_inds = []
+    prompts = []
+    masked_codes = []
+    for i, row in data.iterrows():
+        if row['acc'] > 0:  # only corrupt programs whose execution gave a correct answer
+            question_1, question_2 = row['query'].splitlines()
+            masked = sample_one_masked_code(row['orig_code'], question_2)
+            if masked is not None:
+                prompt_inds.append(i)
+                masked_codes.append(masked)
+                prompt = PROMPT.format(QUESTION_1=question_1, QUESTION_2=question_2, CODE=masked)
+                prompts.append(prompt)
+
+    prompt_inds = np.array(prompt_inds)
+    data['masked'][prompt_inds] = np.array(masked_codes)
+
+    if args.error_injection == 'mask-best':
+        sampling_params = SamplingParams(max_tokens=512, temperature=1.0)
+    else:
+        assert args.error_injection == 'greedy'
+        sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
+    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+    generation = [o.outputs[0].text for o in outputs]
+    new_code = [x.split("```")[0] for x in generation]
+
+    data['generation'][prompt_inds] = np.array(generation)
+    data['code'][prompt_inds] = np.array(new_code)
+
+    data.to_csv(args.output, escapechar='\\')
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('input')
+    parser.add_argument('--dataset', default='gqa', choices=sorted(datasets))  # was missing; main() reads args.dataset
+    parser.add_argument('--output')
+    parser.add_argument('--model_id', default='codellama/CodeLlama-7b-Instruct-hf')
+    parser.add_argument('--error_injection', default='mask-best', choices=['mask-best', 'greedy', ])
+    args = parser.parse_args()
+    if args.output is None:
+        args.output = args.input.replace('.csv', '.error-injection-by-{}.csv'.format(args.error_injection))
+    assert not os.path.exists(args.output), "Refusing to overwrite existing output: " + args.output
+
+    main(args)
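As a quick sanity check of the masking step above, here is a minimal sketch of what `sample_one_masked_code` produces, assuming the repo's dependencies (including `vllm`) are installed and it is run from `vdebugger/`; the toy program is hypothetical:

```python
# Toy demo of span masking; the span choice is random, so outputs vary.
from error_injection import sample_one_masked_code

head = "def execute_command(image):"
code = """
    image_patch = ImagePatch(image)
    muffins = image_patch.find("muffin")
    return str(len(muffins))
"""
print(sample_one_masked_code(code, head))
# One possible output (ImagePatch initializations are never masked):
# def execute_command(image):
#     image_patch = ImagePatch(image)
#     muffins = image_patch.find(<masked>)
#     return str(len(muffins))
```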
diff --git a/vdebugger/error_sampler.py b/vdebugger/error_sampler.py
new file mode 100644
index 0000000..f3211b6
--- /dev/null
+++ b/vdebugger/error_sampler.py
@@ -0,0 +1,101 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from vllm.model_executor.layers.sampler import (
+    _apply_min_tokens_penalty, _apply_penalties, _apply_top_k_top_p,
+    _apply_min_p, _sample, _get_logprobs, _build_sampler_output
+)
+from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
+from vllm.sequence import SamplerOutput
+
+MASK_TOP_TH = 0.9
+MASK_MAX_TIMES = 3
+
+
+class ErrorSampler(nn.Module):
+    def forward(
+            self, logits: torch.Tensor, sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        assert logits is not None
+        _, vocab_size = logits.shape
+
+        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
+        # have not been generated yet
+        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
+
+        # Prepare sampling tensors with pinned memory to avoid blocking.
+        (sampling_tensors, do_penalties, do_top_p_top_k,
+         do_min_p) = SamplingTensors.from_sampling_metadata(
+            sampling_metadata, vocab_size, logits.device, logits.dtype)
+
+        # Apply presence and frequency penalties.
+        if do_penalties:
+            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
+                                      sampling_tensors.output_tokens,
+                                      sampling_tensors.presence_penalties,
+                                      sampling_tensors.frequency_penalties,
+                                      sampling_tensors.repetition_penalties)
+
+        # Apply temperature scaling.
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
+
+        if do_top_p_top_k:
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                        sampling_tensors.top_ks)
+
+        if do_min_p:
+            logits = _apply_min_p(logits, sampling_tensors.min_ps)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities.
+        # Use log_softmax to ensure numerical stability.
+        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
+
+        # ------ My custom code
+        # collect seq_inds
+        seq_inds = []
+        for seq_group in sampling_metadata.seq_groups:
+            seq_ids_, _ = seq_group
+            seq_inds += seq_ids_
+        # collect per-sequence perturbation counters
+        perturbed = []
+        for i in seq_inds:
+            if not hasattr(sampling_metadata.seq_data[i], 'perturbed'):
+                sampling_metadata.seq_data[i].perturbed = 0
+            perturbed.append(sampling_metadata.seq_data[i].perturbed)
+        # cast to tensors to ease computation
+        seq_inds = torch.LongTensor(seq_inds).to(logits.device)
+        perturbed = torch.LongTensor(perturbed).to(logits.device)
+        # decide which rows to mask, based on the probability-gap threshold and the perturbation counter
+        top2 = logprobs.exp().topk(2, dim=1, sorted=True)
+        should_mask_top = torch.bitwise_and(
+            perturbed < MASK_MAX_TIMES, top2.values[:, 0] - top2.values[:, 1] < MASK_TOP_TH
+        )
+        if should_mask_top.sum() > 0:
+            # actually perturb: ban the most likely token
+            logits[should_mask_top, top2.indices[should_mask_top, 0]] = -float("inf")
+            # re-compute softmax
+            probs[should_mask_top] = torch.softmax(logits[should_mask_top], dim=-1, dtype=torch.float)
+            logprobs[should_mask_top] = torch.log_softmax(logits[should_mask_top], dim=-1, dtype=torch.float)
+            for i in seq_inds[should_mask_top]:  # allow at most MASK_MAX_TIMES perturbations per sequence
+                sampling_metadata.seq_data[int(i)].perturbed += 1
+        # ------ My custom code
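+        # Illustration of the rule above, with hypothetical numbers: if the
+        # top-2 probabilities are 0.95 vs. 0.03, the gap 0.92 >= MASK_TOP_TH
+        # keeps the confident top token; if they are 0.55 vs. 0.40, the gap
+        # 0.15 < MASK_TOP_TH bans the top token (its logit is set to -inf), so
+        # a plausible alternative is sampled instead. The counter caps this at
+        # MASK_MAX_TIMES = 3 perturbations per sequence.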
+
+        # Sample the next tokens.
+        sample_results = _sample(probs, logprobs, sampling_metadata,
+                                 sampling_tensors)
+        # Get the logprobs query results.
+        prompt_logprobs, sample_logprobs = _get_logprobs(
+            logprobs, sampling_metadata, sample_results)
+        return _build_sampler_output(sample_results, sampling_metadata,
+                                     prompt_logprobs, sample_logprobs)
diff --git a/vdebugger/infer_critic.py b/vdebugger/infer_critic.py
index b284241..b338365 100644
--- a/vdebugger/infer_critic.py
+++ b/vdebugger/infer_critic.py
@@ -7,11 +7,10 @@ import pandas as pd
 import sklearn.metrics
 import torch
 
-from dump_data import datasets
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from my_datasets import process_result
+from my_datasets import datasets, process_result
 
 
 def parse(code):
diff --git a/vdebugger/my_datasets/__init__.py b/vdebugger/my_datasets/__init__.py
index f440e36..80b1086 100644
--- a/vdebugger/my_datasets/__init__.py
+++ b/vdebugger/my_datasets/__init__.py
@@ -33,3 +33,7 @@ def __repr__(self):
     except:
         print("Weird or invalid ImagePatch:", x)
     return x
+
+
+datasets = {'gqa': GQADataset(), 'vsr': VSRDataset(), 'tallyqa': TallyQADataset(), 'covr': COVRDataset(),
+            'refcoco': RefCOCODataset(), 'nlvr': NLVRDataset(), }
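For completeness, a minimal usage sketch of the shared `datasets` registry added in `my_datasets/__init__.py`; the key comes from the dict above, while the prediction and answer values are made up:

```python
from my_datasets import datasets

# Look up a dataset wrapper by its registry key and score one prediction.
dataset = datasets['gqa']
acc = dataset.accuracy(['yes'], ['yes'])  # > 0 means correct, as checked in error_injection.py
print(acc)
```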