Skip to content

Commit

Permalink
💥 Simpler credentials passing around (#918)
Browse files Browse the repository at this point in the history
Fix #149

Should have backwards compatibilty, just marked as deprecated
  • Loading branch information
coyotte508 committed Sep 23, 2024
1 parent 68602cd commit 400ea89
Show file tree
Hide file tree
Showing 31 changed files with 466 additions and 452 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

await createRepo({
repo: {type: "model", name: "my-user/nlp-model"},
credentials: {accessToken: HF_TOKEN}
accessToken: HF_TOKEN
});

await uploadFile({
repo: "my-user/nlp-model",
credentials: {accessToken: HF_TOKEN},
accessToken: HF_TOKEN,
// Can work with native File in browsers
file: {
path: "pytorch_model.bin",
Expand Down Expand Up @@ -79,7 +79,7 @@ Then import the libraries in your code:
import { HfInference } from "@huggingface/inference";
import { HfAgent } from "@huggingface/agents";
import { createRepo, commit, deleteRepo, listFiles } from "@huggingface/hub";
import type { RepoId, Credentials } from "@huggingface/hub";
import type { RepoId } from "@huggingface/hub";
```

### From CDN or Static hosting
Expand Down Expand Up @@ -182,12 +182,12 @@ const HF_TOKEN = "hf_...";

await createRepo({
repo: "my-user/nlp-model", // or {type: "model", name: "my-user/nlp-test"},
credentials: {accessToken: HF_TOKEN}
accessToken: HF_TOKEN
});

await uploadFile({
repo: "my-user/nlp-model",
credentials: {accessToken: HF_TOKEN},
accessToken: HF_TOKEN,
// Can work with native File in browsers
file: {
path: "pytorch_model.bin",
Expand All @@ -197,7 +197,7 @@ await uploadFile({

await deleteFiles({
repo: {type: "space", name: "my-user/my-space"}, // or "spaces/my-user/my-space"
credentials: {accessToken: HF_TOKEN},
accessToken: HF_TOKEN,
paths: ["README.md", ".gitattributes"]
});
```
Expand Down
17 changes: 8 additions & 9 deletions packages/hub/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,21 @@ Learn how to find free models using the hub package in this [interactive tutoria

```ts
import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI } from "@huggingface/hub";
import type { RepoDesignation, Credentials } from "@huggingface/hub";
import type { RepoDesignation } from "@huggingface/hub";

const repo: RepoDesignation = { type: "model", name: "myname/some-model" };
const credentials: Credentials = { accessToken: "hf_..." };

const {name: username} = await whoAmI({credentials});
const {name: username} = await whoAmI({accessToken: "hf_..."});

for await (const model of listModels({search: {owner: username}, credentials})) {
for await (const model of listModels({search: {owner: username}, accessToken: "hf_..."})) {
console.log("My model:", model);
}

await createRepo({ repo, credentials, license: "mit" });
await createRepo({ repo, accessToken: "hf_...", license: "mit" });

await uploadFiles({
repo,
credentials,
accessToken: "hf_...",
files: [
// path + blob content
{
Expand All @@ -70,23 +69,23 @@ await uploadFiles({

for await (const progressEvent of await uploadFilesWithProgress({
repo,
credentials,
accessToken: "hf_...",
files: [
...
],
})) {
console.log(progressEvent);
}

await deleteFile({repo, credentials, path: "myfile.bin"});
await deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"});

await (await downloadFile({ repo, path: "README.md" })).text();

for await (const fileInfo of listFiles({repo})) {
console.log(fileInfo);
}

await deleteRepo({ repo, credentials });
await deleteRepo({ repo, accessToken: "hf_..." });
```

## OAuth Login
Expand Down
16 changes: 4 additions & 12 deletions packages/hub/src/lib/commit.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ describe("commit", () => {
};

await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo,
license: "mit",
Expand All @@ -50,9 +48,7 @@ describe("commit", () => {
await commit({
repo,
title: "Some commit",
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
operations: [
{
Expand Down Expand Up @@ -135,9 +131,7 @@ size ${lfsContent.length}
};

await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
repo,
hubUrl: TEST_HUB_URL,
});
Expand All @@ -163,9 +157,7 @@ size ${lfsContent.length}
);
await commit({
repo,
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
title: "upload model",
operations,
Expand Down
16 changes: 8 additions & 8 deletions packages/hub/src/lib/commit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import type {
ApiPreuploadRequest,
ApiPreuploadResponse,
} from "../types/api/api-commit";
import type { Credentials, RepoDesignation } from "../types/public";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { chunk } from "../utils/chunk";
import { promisesQueue } from "../utils/promisesQueue";
Expand Down Expand Up @@ -54,12 +54,11 @@ type CommitBlob = Omit<CommitFile, "content"> & { content: Blob };
export type CommitOperation = CommitDeletedEntry | CommitFile /* | CommitRenameFile */;
type CommitBlobOperation = Exclude<CommitOperation, CommitFile> | CommitBlob;

export interface CommitParams {
export type CommitParams = {
title: string;
description?: string;
repo: RepoDesignation;
operations: CommitOperation[];
credentials?: Credentials;
/** @default "main" */
branch?: string;
/**
Expand All @@ -82,7 +81,8 @@ export interface CommitParams {
*/
fetch?: typeof fetch;
abortSignal?: AbortSignal;
}
// Credentials are optional due to custom fetch functions or cookie auth
} & Partial<CredentialsParams>;

export interface CommitOutput {
pullRequestUrl?: string;
Expand Down Expand Up @@ -121,7 +121,7 @@ export type CommitProgressEvent =
* Can be exposed later to offer fine-tuned progress info
*/
export async function* commitIter(params: CommitParams): AsyncGenerator<CommitProgressEvent, CommitOutput> {
checkCredentials(params.credentials);
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);
yield { event: "phase", phase: "preuploading" };

Expand Down Expand Up @@ -189,7 +189,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
Expand Down Expand Up @@ -263,7 +263,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
Accept: "application/vnd.git-lfs+json",
"Content-Type": "application/vnd.git-lfs+json",
},
Expand Down Expand Up @@ -468,7 +468,7 @@ export async function* commitIter(params: CommitParams): AsyncGenerator<CommitPr
{
method: "POST",
headers: {
...(params.credentials && { Authorization: `Bearer ${params.credentials.accessToken}` }),
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
"Content-Type": "application/x-ndjson",
},
body: [
Expand Down
27 changes: 14 additions & 13 deletions packages/hub/src/lib/count-commits.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { Credentials, RepoDesignation } from "../types/public";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function countCommits(params: {
credentials?: Credentials;
repo: RepoDesignation;
/**
* Revision to list commits from. Defaults to the default branch.
*/
revision?: string;
hubUrl?: string;
fetch?: typeof fetch;
}): Promise<number> {
checkCredentials(params.credentials);
export async function countCommits(
params: {
repo: RepoDesignation;
/**
* Revision to list commits from. Defaults to the default branch.
*/
revision?: string;
hubUrl?: string;
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<number> {
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);

// Could upgrade to 1000 commits per page
Expand All @@ -23,7 +24,7 @@ export async function countCommits(params: {
}?limit=1`;

const res: Response = await (params.fetch ?? fetch)(url, {
headers: params.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : {},
headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
});

if (!res.ok) {
Expand Down
12 changes: 3 additions & 9 deletions packages/hub/src/lib/create-repo.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ describe("createRepo", () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
repo: {
name: repoName,
type: "model",
Expand Down Expand Up @@ -62,9 +60,7 @@ describe("createRepo", () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo: repoName,
files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
Expand All @@ -88,9 +84,7 @@ describe("createRepo", () => {
const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;

const result = await createRepo({
credentials: {
accessToken: TEST_ACCESS_TOKEN,
},
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo: repoName,
files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
Expand Down
41 changes: 21 additions & 20 deletions packages/hub/src/lib/create-repo.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiCreateRepoPayload } from "../types/api/api-create-repo";
import type { Credentials, RepoDesignation, SpaceSdk } from "../types/public";
import type { CredentialsParams, RepoDesignation, SpaceSdk } from "../types/public";
import { base64FromBytes } from "../utils/base64FromBytes";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function createRepo(params: {
repo: RepoDesignation;
credentials: Credentials;
private?: boolean;
license?: string;
/**
* Only a few lightweight files are supported at repo creation
*/
files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
}): Promise<{ repoUrl: string }> {
checkCredentials(params.credentials);
export async function createRepo(
params: {
repo: RepoDesignation;
private?: boolean;
license?: string;
/**
* Only a few lightweight files are supported at repo creation
*/
files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
/** @required for when {@link repo.type} === "space" */
sdk?: SpaceSdk;
hubUrl?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & CredentialsParams
): Promise<{ repoUrl: string }> {
const accessToken = checkCredentials(params);
const repoId = toRepoId(params.repo);
const [namespace, repoName] = repoId.name.split("/");

Expand Down Expand Up @@ -61,7 +62,7 @@ export async function createRepo(params: {
: undefined,
} satisfies ApiCreateRepoPayload),
headers: {
Authorization: `Bearer ${params.credentials.accessToken}`,
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
},
});
Expand Down
9 changes: 3 additions & 6 deletions packages/hub/src/lib/delete-file.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@ describe("deleteFile", () => {
it("should delete a file", async () => {
const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
const repo = { type: "model", name: repoName } satisfies RepoId;
const credentials = {
accessToken: TEST_ACCESS_TOKEN,
};

try {
const result = await createRepo({
credentials,
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
repo,
files: [
Expand All @@ -39,7 +36,7 @@ describe("deleteFile", () => {

assert.strictEqual(await content?.text(), "file1");

await deleteFile({ path: "file1", repo, credentials, hubUrl: TEST_HUB_URL });
await deleteFile({ path: "file1", repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });

content = await downloadFile({
repo,
Expand All @@ -59,7 +56,7 @@ describe("deleteFile", () => {
} finally {
await deleteRepo({
repo,
credentials,
accessToken: TEST_ACCESS_TOKEN,
hubUrl: TEST_HUB_URL,
});
}
Expand Down
Loading

0 comments on commit 400ea89

Please sign in to comment.