Skip to content

Commit

Permalink
Use chatId from URL rather than from payload for chats (#700)
Browse files Browse the repository at this point in the history
### Motivation and Context
The verify access to a chat, we use HandleRequest() with the chatId
provided. Currently, we get this from the payload, which can differ from
the chatId from the URL, which opens us to a security problem where a
user could inject an arbitrary chatId in the payload, which doesn't
match what's in the URL.

### Description
- Use chatId from URL and only from URL
- Add integrations test to validate this

### Contribution Checklist
- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [Contribution
Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
glahaye authored Dec 9, 2023
1 parent 8e8e284 commit 6a31dd5
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 53 deletions.
50 changes: 50 additions & 0 deletions integration-tests/ChatTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text.Json;
using CopilotChat.WebApi.Models.Request;
using CopilotChat.WebApi.Models.Response;
using Xunit;
using static CopilotChat.WebApi.Models.Storage.CopilotChatMessage;

namespace ChatCopilotIntegrationTests;

public class ChatTests : ChatCopilotIntegrationTest
{
[Fact]
public async void ChatMessagePostSucceedsWithValidInput()
{
await this.SetUpAuth();

// Create chat session
var createChatParams = new CreateChatParameters() { Title = nameof(ChatMessagePostSucceedsWithValidInput) };
HttpResponseMessage response = await this._httpClient.PostAsJsonAsync("chats", createChatParams);
response.EnsureSuccessStatusCode();

var contentStream = await response.Content.ReadAsStreamAsync();
var createChatResponse = await JsonSerializer.DeserializeAsync<CreateChatResponse>(contentStream, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
Assert.NotNull(createChatResponse);

// Ask something to the bot
var ask = new Ask
{
Input = "Who is Satya Nadella?",
Variables = new KeyValuePair<string, string>[] { new("MessageType", ChatMessageType.Message.ToString()) }
};
response = await this._httpClient.PostAsJsonAsync($"chats/{createChatResponse.ChatSession.Id}/messages", ask);
response.EnsureSuccessStatusCode();

contentStream = await response.Content.ReadAsStreamAsync();
var askResult = await JsonSerializer.DeserializeAsync<AskResult>(contentStream, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
Assert.NotNull(askResult);
Assert.False(string.IsNullOrEmpty(askResult.Value));


// Clean up
response = await this._httpClient.DeleteAsync($"chats/{createChatResponse.ChatSession.Id}");
response.EnsureSuccessStatusCode();
}
}

28 changes: 22 additions & 6 deletions webapi/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ public async Task<IActionResult> ChatAsync(
[FromServices] IKernel kernel,
[FromServices] IHubContext<MessageRelayHub> messageRelayHubContext,
[FromServices] CopilotChatPlanner planner,
[FromServices] AskConverter askConverter,
[FromServices] ChatSessionRepository chatSessionRepository,
[FromServices] ChatParticipantRepository chatParticipantRepository,
[FromServices] IAuthInfo authInfo,
Expand All @@ -108,7 +107,7 @@ public async Task<IActionResult> ChatAsync(
{
this._logger.LogDebug("Chat message received.");

return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
return await this.HandleRequest(ChatFunctionName, kernel, messageRelayHubContext, planner, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
}

/// <summary>
Expand All @@ -135,7 +134,6 @@ public async Task<IActionResult> ProcessPlanAsync(
[FromServices] IKernel kernel,
[FromServices] IHubContext<MessageRelayHub> messageRelayHubContext,
[FromServices] CopilotChatPlanner planner,
[FromServices] AskConverter askConverter,
[FromServices] ChatSessionRepository chatSessionRepository,
[FromServices] ChatParticipantRepository chatParticipantRepository,
[FromServices] IAuthInfo authInfo,
Expand All @@ -144,7 +142,7 @@ public async Task<IActionResult> ProcessPlanAsync(
{
this._logger.LogDebug("plan request received.");

return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, askConverter, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
return await this.HandleRequest(ProcessPlanFunctionName, kernel, messageRelayHubContext, planner, chatSessionRepository, chatParticipantRepository, authInfo, ask, chatId.ToString());
}

/// <summary>
Expand All @@ -166,15 +164,14 @@ private async Task<IActionResult> HandleRequest(
IKernel kernel,
IHubContext<MessageRelayHub> messageRelayHubContext,
CopilotChatPlanner planner,
AskConverter askConverter,
ChatSessionRepository chatSessionRepository,
ChatParticipantRepository chatParticipantRepository,
IAuthInfo authInfo,
Ask ask,
string chatId)
{
// Put ask's variables in the context we will use.
var contextVariables = askConverter.GetContextVariables(ask);
var contextVariables = GetContextVariables(ask, authInfo, chatId);

// Verify that the chat exists and that the user has access to it.
ChatSession? chat = null;
Expand Down Expand Up @@ -415,6 +412,25 @@ await planner.Kernel.ImportOpenAIPluginFunctionsAsync(
return;
}

private static ContextVariables GetContextVariables(Ask ask, IAuthInfo authInfo, string chatId)
{
const string UserIdKey = "userId";
const string UserNameKey = "userName";
const string ChatIdKey = "chatId";

var contextVariables = new ContextVariables(ask.Input);
foreach (var variable in ask.Variables)
{
contextVariables.Set(variable.Key, variable.Value);
}

contextVariables.Set(UserIdKey, authInfo.UserId);
contextVariables.Set(UserNameKey, authInfo.Name);
contextVariables.Set(ChatIdKey, chatId);

return contextVariables;
}

/// <summary>
/// Dispose of the object.
/// </summary>
Expand Down
5 changes: 0 additions & 5 deletions webapi/Extensions/ServiceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ internal static void AddOptions<TOptions>(this IServiceCollection services, ICon
.PostConfigure(TrimStringProperties);
}

internal static IServiceCollection AddUtilities(this IServiceCollection services)
{
return services.AddScoped<AskConverter>();
}

internal static IServiceCollection AddPlugins(this IServiceCollection services, IConfiguration configuration)
{
var plugins = configuration.GetSection("Plugins").Get<List<Plugin>>() ?? new List<Plugin>();
Expand Down
1 change: 0 additions & 1 deletion webapi/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ public static async Task Main(string[] args)
.AddOptions(builder.Configuration)
.AddPersistentChatStore()
.AddPlugins(builder.Configuration)
.AddUtilities()
.AddChatCopilotAuthentication(builder.Configuration)
.AddChatCopilotAuthorization();

Expand Down
41 changes: 0 additions & 41 deletions webapi/Utilities/AskConverter.cs

This file was deleted.

0 comments on commit 6a31dd5

Please sign in to comment.