Commit: update error injection code
Xueqing Wu committed Jun 24, 2024
1 parent 74e6954 commit b3b54c8
Showing 5 changed files with 268 additions and 5 deletions.
README.md: 9 additions & 3 deletions
@@ -1,6 +1,4 @@
-# VDebugger
-
-This repo is for **VDebugger: Harnessing Execution Feedback for Debugging Visual Programs**
+# VDebugger: Harnessing Execution Feedback for Debugging Visual Programs
 
 [Paper](https://arxiv.org/abs/2406.13444), [Website](https://shirley-wu.github.io/vdebugger/index.html), [Models and Data](https://huggingface.co/VDebugger)
 
@@ -10,6 +8,7 @@
 - [Dataset Setup](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#dataset-setup)
 - [Generation and Execution of Visual Programs](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#generation-and-execution-of-visual-programs)
 - [Inference of VDebugger](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#inference-of-vdebugger)
+- [Error Injection](https://github.com/shirley-wu/vdebugger/tree/main?tab=readme-ov-file#error-injection)
 
 ## Environment Setup
@@ -135,6 +134,13 @@
 
 If you want to reproduce our training of VDebugger, please use `vdebugger/training_scripts/train_{critic, refiner}.sh`. You will need to install `deepspeed==0.14.0`.
 
+## Error Injection
+
+To perform error injection and generate incorrect programs as described in Section 4 of our paper, you first need a `.csv` file containing the visual programs generated for the training set and their execution results. Then, please go to `vdebugger/` and run:
+```bash
+python error_injection.py YOUR_CSV_FILE --error_injection {greedy,mask-best}
+```
+
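The script expects a handful of columns in this `.csv`. Below is a minimal sanity check; the column set is inferred from how `error_injection.py` reads the file, not a documented schema:

```python
import pandas as pd

# Columns consumed by error_injection.py (inferred from the script):
#   query  - two lines: the natural-language question and the program header
#   code   - the generated visual program
#   result - its execution result
#   answer - the ground-truth answer
df = pd.read_csv("YOUR_CSV_FILE")
assert {"query", "code", "result", "answer"} <= set(df.columns)
```

Unless `--output` is given, results are written next to the input, with `.csv` replaced by `.error-injection-by-<mode>.csv`.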
## Citation

Please cite our paper if this repository inspires your work.
vdebugger/error_injection.py: 159 additions & 0 deletions (new file)
@@ -0,0 +1,159 @@
import argparse
import ast
import os
import random

import numpy as np
import pandas as pd
import torch
from vllm import LLM, SamplingParams

from error_sampler import ErrorSampler  # defined in vdebugger/error_sampler.py; this import was missing from the original file
from my_datasets import datasets


def get_all_candidates(code, head: str):
    code = head + code
    try:
        code = ast.unparse(ast.parse(code))
    except SyntaxError:  # unparsable program: nothing to mask
        return None

    lines = code.splitlines()
    tree = ast.parse(code)

    def get_text(nodes):
        start = nodes[0]
        end = nodes[-1]
        end_lineno = getattr(end, 'end_lineno', end.lineno)
        if start.lineno == end_lineno:
            return lines[start.lineno - 1][start.col_offset: end.end_col_offset]
        else:
            return '\n'.join(
                [lines[start.lineno - 1][start.col_offset:], ] + lines[start.lineno: end_lineno - 1] + \
                [lines[end_lineno - 1][:end.end_col_offset], ]
            )

    def mask_nodes(nodes):
        start = nodes[0]
        end = nodes[-1]
        ret = lines[:start.lineno - 1] + [lines[start.lineno - 1][:start.col_offset] + '<MASKED>', ]
        ret[-1] += lines[end.end_lineno - 1][end.end_col_offset:]
        ret += lines[end.end_lineno:]
        return '\n'.join(ret)

    def is_ImagePatch(x):
        return isinstance(x, ast.Call) and get_text([x.func, ]) == 'ImagePatch'

    candidate_nodes_to_mask = []
    for node in ast.walk(tree):
        if node is tree:
            continue  # don't mask spans of top-level statements
        if hasattr(node, 'body') and isinstance(node.body, list):
            # spans of 1-3 consecutive statements, as long as none constructs an ImagePatch
            for i in range(len(node.body)):
                for j in range(i + 1, min(i + 4, len(node.body) + 1)):
                    if not any((isinstance(x, ast.Assign) and is_ImagePatch(x.value)) for x in node.body[i:j]):
                        candidate_nodes_to_mask.append(node.body[i:j])
        if isinstance(node, ast.Assign) or isinstance(node, ast.Return):
            if node.value is not None and not is_ImagePatch(node.value):
                candidate_nodes_to_mask.append([node.value, ])
        if isinstance(node, ast.If):
            candidate_nodes_to_mask.append([node.test, ])
        if isinstance(node, ast.Call) and is_ImagePatch(node):
            if node.args:  # mask the constructor's arguments, never the ImagePatch call itself
                candidate_nodes_to_mask.append(node.args)

    return [(mask_nodes(nodes), get_text(nodes)) for nodes in candidate_nodes_to_mask]


def sample_one_masked_code(code, head):
    candidates = get_all_candidates(code, head)
    if not candidates:  # also covers an empty candidate list, where random.choice would raise
        return None
    return random.choice(candidates)[0]


def get_prompt():
    with open(os.path.join(os.path.dirname(__file__), '../viper/prompts/benchmarks/joint.py')) as f:
        base_prompt = f.read().strip()
    api_definition = base_prompt.replace("# INSERT_QUERY_HERE", "").strip()
    return """[INST] I am writing code to handle visual question answering tasks by calling computer vision APIs. Some content from the code is masked (represented as "<MASKED>"). Please recover the original code.
My code:
```python
# {QUESTION_1}
{CODE}
```
Your code should be wrapped in ```python and ```. The code should be exactly the same as my code, except recovering the masked content.
---
Below are the available APIs and some example usages:
```python
""" + api_definition + """
```[/INST] Here's the original code with the `<MASKED>` section replaced:
```python
# {QUESTION_1}
{QUESTION_2}"""


def main(args):
    dataset = datasets[args.dataset]

    data = pd.read_csv(args.input)
    data['acc'] = [dataset.accuracy([r, ], [a, ]) for r, a in zip(data['result'], data['answer'])]

    PROMPT = get_prompt()
    llm = LLM(model=args.model_id, dtype='bfloat16', tensor_parallel_size=torch.cuda.device_count())
    if args.error_injection == 'mask-best':
        llm.llm_engine.model_executor.driver_worker.model_runner.model.sampler = ErrorSampler()

    data['orig_code'] = data.pop('code')
    data.insert(len(data.keys()), 'masked', '')
    data.insert(len(data.keys()), 'generation', '')
    data.insert(len(data.keys()), 'code', '[')  # '[' appears to act as an invalid-program placeholder for rows left untouched

    prompt_inds = []
    prompts = []
    masked_codes = []
    for i, row in data.iterrows():
        if row['acc'] > 0:  # only inject errors into programs that executed correctly
            question_1, question_2 = row['query'].splitlines()
            masked = sample_one_masked_code(row['orig_code'], question_2)
            if masked is not None:
                prompt_inds.append(i)
                masked_codes.append(masked)
                prompt = PROMPT.format(QUESTION_1=question_1, QUESTION_2=question_2, CODE=masked)
                prompts.append(prompt)

    prompt_inds = np.array(prompt_inds)
    data.loc[prompt_inds, 'masked'] = np.array(masked_codes)

    if args.error_injection == 'mask-best':
        sampling_params = SamplingParams(max_tokens=512, temperature=1.0)
    else:
        assert args.error_injection == 'greedy'
        sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    generation = [o.outputs[0].text for o in outputs]
    new_code = [x.split("```")[0] for x in generation]

    data.loc[prompt_inds, 'generation'] = np.array(generation)
    data.loc[prompt_inds, 'code'] = np.array(new_code)

    data.to_csv(args.output, escapechar='\\')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('--output')
    parser.add_argument('--dataset', default='gqa',
                        choices=['gqa', 'vsr', 'tallyqa', 'covr', 'refcoco', 'nlvr'])
    # NOTE: --dataset is consumed by main() but was missing from the original parser; the 'gqa' default is an assumption
    parser.add_argument('--model_id', default='codellama/CodeLlama-7b-Instruct-hf')
    parser.add_argument('--error_injection', default='mask-best', choices=['mask-best', 'greedy', ])
    args = parser.parse_args()
    if args.output is None:
        args.output = args.input.replace('.csv', '.error-injection-by-{}.csv'.format(args.error_injection))
    assert not os.path.exists(args.output), "Refusing to overwrite existing file: " + args.output

    main(args)
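As a sanity check of the masking step, here is a small, self-contained walk-through of `get_all_candidates` and `sample_one_masked_code` (run from `vdebugger/` with the repo's dependencies installed; the program and its `def execute_command(image):` header are hypothetical stand-ins for the `code` column and the second line of the `query` column):

```python
import random

from error_injection import get_all_candidates, sample_one_masked_code

head = "def execute_command(image):\n"  # hypothetical program header
code = """    patch = ImagePatch(image)
    cats = patch.find('cat')
    if len(cats) > 0:
        return 'yes'
    return 'no'
"""

# Each candidate is (masked_program, masked_out_text); ImagePatch constructions are never masked
for masked, original in get_all_candidates(code, head)[:3]:
    print(masked)
    print('-> masked out:', repr(original))

random.seed(0)
print(sample_one_masked_code(code, head))  # one random masked variant, as used in main()
```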
vdebugger/error_sampler.py: 95 additions & 0 deletions (new file)
@@ -0,0 +1,95 @@
from typing import Optional

import torch
import torch.nn as nn
from vllm.model_executor.layers.sampler import (
_apply_min_tokens_penalty, _apply_penalties, _apply_top_k_top_p,
_apply_min_p, _sample, _get_logprobs, _build_sampler_output
)
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.sequence import SamplerOutput

MASK_TOP_TH = 0.9  # perturb only when p(top-1) - p(top-2) is below this gap
MASK_MAX_TIMES = 3  # perturb each sequence at most this many times


class ErrorSampler(nn.Module):
    def forward(
            self, logits: torch.Tensor, sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        assert logits is not None
        _, vocab_size = logits.shape

        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
        # have not been generated yet
        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
            sampling_metadata, vocab_size, logits.device, logits.dtype)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # ------ My custom code
        # collect the sequence ids in this batch
        seq_inds = []
        for seq_group in sampling_metadata.seq_groups:
            seq_ids_, _ = seq_group
            seq_inds += seq_ids_
        # collect per-sequence perturbation counters, initializing them on first use
        perturbed = []
        for i in seq_inds:
            if not hasattr(sampling_metadata.seq_data[i], 'perturbed'):
                sampling_metadata.seq_data[i].perturbed = 0
            perturbed.append(sampling_metadata.seq_data[i].perturbed)
        # cast to tensors for vectorized computation
        seq_inds = torch.LongTensor(seq_inds).to(logits.device)
        perturbed = torch.LongTensor(perturbed).to(logits.device)
        # decide which rows to perturb: the top-1/top-2 probability gap must be below
        # MASK_TOP_TH, and the sequence must still have perturbation budget left
        top2 = logprobs.exp().topk(2, dim=1, sorted=True)
        should_mask_top = torch.bitwise_and(
            perturbed < MASK_MAX_TIMES, top2.values[:, 0] - top2.values[:, 1] < MASK_TOP_TH
        )
        if should_mask_top.sum() > 0:
            # actually perturb: make the most likely token impossible to sample
            logits[should_mask_top, top2.indices[should_mask_top, 0]] = -float("inf")
            # re-compute the (log-)probabilities for the perturbed rows
            probs[should_mask_top] = torch.softmax(logits[should_mask_top], dim=-1, dtype=torch.float)
            logprobs[should_mask_top] = torch.log_softmax(logits[should_mask_top], dim=-1, dtype=torch.float)
            for i in seq_inds[should_mask_top]:  # update counters; MASK_MAX_TIMES caps perturbations per sequence
                sampling_metadata.seq_data[int(i)].perturbed += 1
        # ------ My custom code

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata,
                                 sampling_tensors)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs)
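The custom block is easier to follow on a toy example. Below is a minimal sketch of the same perturbation rule outside of vLLM; the thresholds are copied from above and the tensor values are made up:

```python
import torch

MASK_TOP_TH = 0.9   # same thresholds as in ErrorSampler
MASK_MAX_TIMES = 3

logits = torch.tensor([[4.0, 3.5, 0.1],    # close top-2: eligible for masking
                       [9.0, 0.0, -1.0]])  # confident top-1: left alone
perturbed = torch.zeros(2, dtype=torch.long)  # per-sequence perturbation counters

logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
top2 = logprobs.exp().topk(2, dim=1, sorted=True)
should_mask_top = (perturbed < MASK_MAX_TIMES) & \
                  (top2.values[:, 0] - top2.values[:, 1] < MASK_TOP_TH)

# make the most likely token impossible to sample for the selected rows
logits[should_mask_top, top2.indices[should_mask_top, 0]] = -float("inf")
print(torch.softmax(logits, dim=-1))  # row 0: top-1 now has probability 0; row 1 unchanged
```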
vdebugger/infer_critic.py: 1 addition & 2 deletions
@@ -7,11 +7,10 @@
 import pandas as pd
 import sklearn.metrics
 import torch
-from dump_data import datasets
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 
-from my_datasets import process_result
+from my_datasets import datasets, process_result


def parse(code):
vdebugger/my_datasets/__init__.py: 4 additions & 0 deletions
@@ -33,3 +33,7 @@ def __repr__(self):
     except:
         print("Weird or invalid ImagePatch:", x)
     return x
+
+
+datasets = {'gqa': GQADataset(), 'vsr': VSRDataset(), 'tallyqa': TallyQADataset(), 'covr': COVRDataset(),
+            'refcoco': RefCOCODataset(), 'nlvr': NLVRDataset(), }

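For reference, a quick sketch of how the new registry is consumed elsewhere in this commit; the `accuracy` call shape follows its call site in `error_injection.py`, and the example strings are hypothetical:

```python
from my_datasets import datasets

gqa = datasets['gqa']
# error_injection.py scores one row at a time with single-element lists:
acc = gqa.accuracy(['yes'], ['yes'])  # hypothetical (result, answer) pair
print(acc > 0)  # only rows with acc > 0 are candidates for error injection
```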