Skip to content

Commit

Permalink
🎶 add methods ListModels and GetModel
Browse files Browse the repository at this point in the history
  • Loading branch information
Jochen Kirstätter committed Feb 29, 2024
1 parent 337bff6 commit bb24743
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 5 deletions.
39 changes: 37 additions & 2 deletions src/GenerativeModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
#endif
using System.Text.RegularExpressions;
using System.Linq;
using System.Net;

namespace Mscc.GenerativeAI
{
public class GenerativeModel
{
private readonly string urlGoogleAI = "https://generativelanguage.googleapis.com/{version}/models/{model}:{method}?key={apiKey}";
private readonly string endpointGoogleAI = "generativelanguage.googleapis.com";
private readonly string urlGoogleAI = "https://{endpointGoogleAI}/{version}/models/{model}:{method}?key={apiKey}";
private readonly string urlVertexAI = "https://{region}-aiplatform.googleapis.com/{version}/projects/{projectId}/locations/{region}/publishers/{publisher}/models/{model}:{method}";
private readonly string model;
private readonly string apiKey = default;
Expand Down Expand Up @@ -116,6 +118,38 @@ public GenerativeModel(string projectId, string region, string model, Generation
this.safetySettings = safetySettings;
}

/// <summary>
/// Get a list of available models and description.
/// </summary>
/// <returns></returns>
public async Task<List<ModelResponse>> ListModels()
{
if (string.IsNullOrEmpty(apiKey))
throw new NotSupportedException();

var url = $"https://{endpointGoogleAI}/{Version}/models?key={apiKey}";
var response = await Client.GetAsync(url);
response.EnsureSuccessStatusCode();
var models = await Deserialize<ListModelsResponse>(response);
return models?.Models!;
}

/// <summary>
/// Get information about the model, including default values.
/// </summary>
/// <param name="model">The model to query</param>
/// <returns></returns>
public async Task<ModelResponse> GetModel(string model = Model.GeminiPro)
{
if (string.IsNullOrEmpty(apiKey))
throw new NotSupportedException();

var url = $"https://{endpointGoogleAI}/{Version}/models/{model}?key={apiKey}";
var response = await Client.GetAsync(url);
response.EnsureSuccessStatusCode();
return await Deserialize<ModelResponse>(response);
}

/// <summary>
/// Produces a single request and response.
/// </summary>
Expand Down Expand Up @@ -264,7 +298,7 @@ public ChatSession StartChat(List<Content>? history = null, GenerationConfig? ge
/// <param name="url"></param>
/// <param name="method"></param>
/// <returns></returns>
private string ParseUrl(string url, string? method)
private string ParseUrl(string url, string? method = default)
{
var replacements = GetReplacements();
replacements.Add("method", method);
Expand All @@ -278,6 +312,7 @@ Dictionary<string, string> GetReplacements()
{
return new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
{
{ "endpointGoogleAI", endpointGoogleAI },
{ "version", Version },
{ "model", model },
{ "apikey", apiKey },
Expand Down
11 changes: 11 additions & 0 deletions src/ListModelsResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System.Collections.Generic;
#endif

namespace Mscc.GenerativeAI
{
internal class ListModelsResponse
{
public List<ModelResponse>? Models { get; set; }
}
}
23 changes: 23 additions & 0 deletions src/ModelResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#if NET472_OR_GREATER || NETSTANDARD2_0
using System.Collections.Generic;
#endif

using System.Diagnostics;

namespace Mscc.GenerativeAI
{
[DebuggerDisplay("{DisplayName} ({Name})")]
public class ModelResponse
{
public string? Name { get; set; } = default;
public string? Version { get; set; } = default;
public string? DisplayName { get; set; } = default;
public string? Description { get; set; } = default;
public int? InputTokenLimit { get; set; } = default;
public int? OutputTokenLimit { get; set; } = default;
public List<string>? SupportedGenerationMethods { get; set; }
public float? Temperature { get; set; } = default;
public float? TopP { get; set; } = default;
public int? TopK { get; set; } = default;
}
}
4 changes: 2 additions & 2 deletions src/Mscc.GenerativeAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
<RepositoryUrl>https://github.com/mscraftsman/generative-ai</RepositoryUrl>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>False</PackageRequireLicenseAcceptance>
<PackageReleaseNotes>- Initial Release</PackageReleaseNotes>
<Version>0.2.1</Version>
<PackageReleaseNotes>- Add methods ListModels and GetModel</PackageReleaseNotes>
<Version>0.3.1</Version>
</PropertyGroup>

<PropertyGroup Condition="$(TargetFramework.StartsWith('net6.0')) or $(TargetFramework.StartsWith('net7.0')) or $(TargetFramework.StartsWith('net8.0'))">
Expand Down
39 changes: 39 additions & 0 deletions tests/GoogleAi_GeminiPro_Should.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,45 @@ public void Initialize_Model()
model.Name().Should().Be(Model.GeminiPro);
}

[Fact]
public async void List_Models()
{
// Arrange
var model = new GenerativeModel(apiKey: fixture.ApiKey);

// Act
var sut = await model.ListModels();

// Assert
sut.Should().NotBeNull();
sut.Should().NotBeNull().And.HaveCountGreaterThanOrEqualTo(1);
sut.ForEach(x =>
{
output.WriteLine($"Model: {x.DisplayName} ({x.Name})");
x.SupportedGenerationMethods.ForEach(m => output.WriteLine($" Method: {m}"));
});
}

[Theory]
[InlineData(Model.GeminiPro)]
[InlineData(Model.GeminiProVision)]
[InlineData(Model.BisonText)]
[InlineData(Model.BisonChat)]
public async void Get_Model_Information(string modelName)
{
// Arrange
var model = new GenerativeModel(apiKey: fixture.ApiKey);

// Act
var sut = await model.GetModel(model: modelName);

// Assert
sut.Should().NotBeNull();
sut.Name.Should().Be($"models/{modelName}");
output.WriteLine($"Model: {sut.DisplayName} ({sut.Name})");
sut.SupportedGenerationMethods.ForEach(m => output.WriteLine($" Method: {m}"));
}

[Fact]
public async void Generate_Content()
{
Expand Down
29 changes: 29 additions & 0 deletions tests/VertexAi_GeminiPro_Should.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using FluentAssertions;
using Mscc.GenerativeAI;
using System;
using System.Collections.Generic;
using System.Linq;
using Xunit;
Expand Down Expand Up @@ -60,6 +61,34 @@ public void Initialize_Model()
model.Name().Should().Be(Model.Gemini10Pro);
}

[Fact]
public async void List_Models()
{
// Arrange
var vertex = new VertexAI(projectId: fixture.ProjectId, region: fixture.Region);
var model = vertex.GenerativeModel(model: this.model);
model.AccessToken = fixture.AccessToken;

// Act & Assert
await Assert.ThrowsAsync<NotSupportedException>(() => model.ListModels());
}

[Theory]
[InlineData(Model.GeminiPro)]
[InlineData(Model.GeminiProVision)]
[InlineData(Model.BisonText)]
[InlineData(Model.BisonChat)]
public async void Get_Model_Information(string modelName)
{
// Arrange
var vertex = new VertexAI(projectId: fixture.ProjectId, region: fixture.Region);
var model = vertex.GenerativeModel(model: this.model);
model.AccessToken = fixture.AccessToken;

// Act & Assert
await Assert.ThrowsAsync<NotSupportedException>(() => model.GetModel(model: modelName));
}

[Fact]
public async void Generate_Content()
{
Expand Down
2 changes: 1 addition & 1 deletion tests/appsettings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"api_key": "",
"project_id": "",
"region": "",
"region": "us-central1",
"access_token": ""
}

0 comments on commit bb24743

Please sign in to comment.