Skip to content

Commit

Permalink
add gemini_1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
KainingYing committed Apr 28, 2024
1 parent d60b78f commit 96c2254
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
11 changes: 10 additions & 1 deletion vlmeval/api/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
headers = 'Content-Type: application/json'


MODEL_DICT = {
"1": "gemini-pro-vision",
"1.5": "gemini-1.5-pro-latest"
}


class GeminiWrapper(BaseAPI):

is_api: bool = True

def __init__(self,
version: str = "1",
retry: int = 5,
wait: int = 5,
key: str = None,
Expand All @@ -18,6 +25,8 @@ def __init__(self,
max_tokens: int = 1024,
proxy: str = None,
**kwargs):
self.version = version
assert self.version in ["1", "1.5"]

self.fail_msg = 'Failed to obtain answer via API. '
self.max_tokens = max_tokens
Expand All @@ -44,7 +53,7 @@ def generate_inner(self, inputs, **kwargs) -> str:
assert isinstance(inputs, list)
pure_text = np.all([x['type'] == 'text' for x in inputs])
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel('gemini-pro-vision')
model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel(MODEL_DICT[self.version])
messages = self.build_msgs(inputs)
gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
gen_config.update(kwargs)
Expand Down
3 changes: 2 additions & 1 deletion vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
'GPT4V_HIGH': partial(GPT4V, model='gpt-4-1106-vision-preview', temperature=0, img_size=-1, img_detail='high', retry=10),
'GPT4V_20240409': partial(GPT4V, model='gpt-4-turbo-2024-04-09', temperature=0, img_size=512, img_detail='low', retry=10),
'GPT4V_20240409_HIGH': partial(GPT4V, model='gpt-4-turbo-2024-04-09', temperature=0, img_size=-1, img_detail='high', retry=10),
'GeminiProVision': partial(GeminiProVision, temperature=0, retry=10),
'GeminiProVision': partial(GeminiProVision, version="1", temperature=0, retry=10),
'GeminiProVision_v1.5': partial(GeminiProVision, version="1.5", temperature=0, retry=10),
'QwenVLPlus': partial(QwenVLAPI, model='qwen-vl-plus', temperature=0, retry=10),
'QwenVLMax': partial(QwenVLAPI, model='qwen-vl-max', temperature=0, retry=10),
# Internal Only
Expand Down

0 comments on commit 96c2254

Please sign in to comment.