Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions scriptworker_client/src/scriptworker_client/github_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,58 @@ async def __aenter__(self):
async def __aexit__(self, *excinfo):
await self.close()

async def create_branch(self, branch_name: str, from_branch: Optional[str] = None, dry_run: bool = False) -> None:
"""Create a new branch in the repository.

Args:
branch_name (str): The name of the new branch to create.
from_branch (str): The branch to create the new branch from. Uses the
repository's default branch if unspecified (optional).
dry_run (bool): If it's a dry run
"""
# Get the repository ID and source OID in one query
source_branch = from_branch or "HEAD"
info_query = Template(
dedent("""
query getRepoInfo {
repository(owner: "$owner", name: "$repo") {
id
object(expression: "$branch") {
oid
}
}
}""")
)
str_info_query = info_query.substitute(owner=self.owner, repo=self.repo, branch=source_branch)
repo = (await self._client.execute(str_info_query))["repository"]

if repo.get("object") is None:
raise UnknownBranchError(f"branch '{source_branch}' not found in repo!")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should also check if the branch already exists and bail early? I'm seeing the "branch creation passed but version bump got 500s" already...

Although there's an argument for saying that if the branch already exists it could've been created by mistake by a human and shouldn't be used and thus failing here is the right thing to do...


repo_id = repo["id"]
source_oid = repo["object"]["oid"]

create_branch_mutation = dedent("""
mutation ($input: CreateRefInput!) {
createRef(input: $input) {
ref {
name
}
}
}""")
variables = {
"input": {
"repositoryId": repo_id,
"name": f"refs/heads/{branch_name}",
"oid": source_oid,
}
}

verb = "Would create" if dry_run else "Creating"
log.debug(f"{verb} {branch_name} on repo {self.repo}[{repo_id}] at {source_branch}@{source_oid}")
if not dry_run:
await self._client.execute(create_branch_mutation, variables=variables)

async def commit(self, branch: str, message: str, additions: Optional[Dict[str, str]] = None, deletions: Optional[List[str]] = None) -> None:
"""Commit changes to the given repository and branch.

Expand Down
98 changes: 98 additions & 0 deletions scriptworker_client/tests/test_github_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,104 @@
from scriptworker_client.github_client import UnknownBranchError


@pytest.mark.asyncio
async def test_create_branch(aioresponses, github_client):
branch_name = "new-feature"
from_branch = "main"
repo_id = "R_abc123"
source_oid = "def456"

# First query gets repo ID and source OID
aioresponses.post(
GITHUB_GRAPHQL_ENDPOINT,
status=200,
payload={"data": {"repository": {"id": repo_id, "object": {"oid": source_oid}}}},
)
# Second query creates the branch
aioresponses.post(
GITHUB_GRAPHQL_ENDPOINT,
status=200,
payload={"data": {"createRef": {"ref": {"name": f"refs/heads/{branch_name}"}}}},
)

await github_client.create_branch(branch_name=branch_name, from_branch=from_branch)

aioresponses.assert_called()
key = ("POST", URL(GITHUB_GRAPHQL_ENDPOINT))
info_request = aioresponses.requests[key][-2][1]["json"]
create_request = aioresponses.requests[key][-1][1]["json"]

assert info_request == {
"query": dedent(
f"""
query getRepoInfo {{
repository(owner: "{github_client.owner}", name: "{github_client.repo}") {{
id
object(expression: "{from_branch}") {{
oid
}}
}}
}}"""
).strip(),
}
assert create_request == {
"query": dedent(
"""
mutation ($input: CreateRefInput!) {
createRef(input: $input) {
ref {
name
}
}
}"""
).strip(),
"variables": {
"input": {
"repositoryId": repo_id,
"name": f"refs/heads/{branch_name}",
"oid": source_oid,
}
},
}


@pytest.mark.asyncio
async def test_create_branch_dry_run(aioresponses, github_client):
branch_name = "new-feature"
from_branch = "main"
repo_id = "R_abc123"
source_oid = "def456"

# Only the info query should be made in dry_run mode
aioresponses.post(
GITHUB_GRAPHQL_ENDPOINT,
status=200,
payload={"data": {"repository": {"id": repo_id, "object": {"oid": source_oid}}}},
)

await github_client.create_branch(branch_name=branch_name, from_branch=from_branch, dry_run=True)

aioresponses.assert_called()
key = ("POST", URL(GITHUB_GRAPHQL_ENDPOINT))
# Only one request should be made (the info query, not the mutation)
assert len(aioresponses.requests[key]) == 1


@pytest.mark.asyncio
async def test_create_branch_unknown_source_branch(aioresponses, github_client):
branch_name = "new-feature"
from_branch = "nonexistent"

aioresponses.post(
GITHUB_GRAPHQL_ENDPOINT,
status=200,
payload={"data": {"repository": {"id": "R_abc123", "object": None}}},
)

with pytest.raises(UnknownBranchError, match=f"branch '{from_branch}' not found in repo!"):
await github_client.create_branch(branch_name=branch_name, from_branch=from_branch)


@pytest.mark.asyncio
async def test_commit(aioresponses, github_client):
branch = "main"
Expand Down
14 changes: 13 additions & 1 deletion treescript/src/treescript/data/treescript_task_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@
"files"
]
},
"create_branch_info": {
"type": "object",
"properties": {
"branch_name": {
"type": "string"
}
},
"required": [
"branch_name"
]
},
"android_l10n_import_info": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -343,7 +354,8 @@
"merge_day",
"android_l10n_import",
"android_l10n_sync",
"push"
"push",
"create_branch"
]
}
}
Expand Down
3 changes: 3 additions & 0 deletions treescript/src/treescript/github/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from scriptworker_client.github import extract_github_repo_owner_and_name
from scriptworker_client.github_client import GithubClient

from treescript.github.branch import create_branch
from treescript.github.versionmanip import bump_version
from treescript.util.task import get_source_repo, task_action_types

Expand All @@ -23,5 +24,7 @@ async def do_actions(config, task):
actions = task_action_types(config, task)

async with GithubClient(config["github_config"], owner, repo) as client:
if "create_branch" in actions:
await create_branch(client, task)
if "version_bump" in actions:
await bump_version(client, task)
41 changes: 41 additions & 0 deletions treescript/src/treescript/github/branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python
"""Treescript branch methods."""

from typing import Dict

from scriptworker_client.github_client import GithubClient

from treescript.exceptions import TreeScriptError
from treescript.util.task import get_branch, get_create_branch_info, should_push


def get_branch_name(task: Dict) -> str:
"""Get the branch_name from a task's create_branch_info.

Args:
task (Dict): The task definition containing create_branch_info.

Returns:
str: The name of the target branch.

Raises:
TreeScriptError: If branch_name is not specified in the task.
"""
create_branch_info = get_create_branch_info(task)
if "branch_name" not in create_branch_info:
raise TreeScriptError("branch_name is required in task")
return create_branch_info["branch_name"]


async def create_branch(client: GithubClient, task: Dict) -> None:
"""Create a new branch in the repository based on task configuration.

Args:
client (GithubClient): GithubClient instance for associated repo.
task (Dict): The task definition containing branch configuration.
"""
await client.create_branch(
branch_name=get_branch_name(task),
from_branch=get_branch(task),
dry_run=not should_push(task),
)
22 changes: 21 additions & 1 deletion treescript/src/treescript/util/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
VALID_ACTIONS = {
"comm": {"tag", "version_bump", "l10n_bump", "l10n_bump_github", "push", "merge_day"},
"gecko": {"tag", "version_bump", "l10n_bump", "l10n_bump_github", "push", "merge_day", "android_l10n_import", "android_l10n_sync"},
"mobile": {"version_bump"},
"mobile": {"version_bump", "create_branch"},
}

DONTBUILD_MSG = " DONTBUILD"
Expand Down Expand Up @@ -154,6 +154,26 @@ def get_version_bump_info(task):
return version_info


# get_create_branch_info {{{1
def get_create_branch_info(task):
"""Get the create branch information from the task metadata.

Args:
task: the task definition.

Returns:
object: the info structure as passed to the task payload.

Raises:
TaskVerificationError: If expected item missing from task definition.

"""
create_branch_info = task.get("payload", {}).get("create_branch_info")
if not create_branch_info:
raise TaskVerificationError("Requested create branch but no create_branch_info in payload")
return create_branch_info


# get_l10n_bump_info {{{1
def get_l10n_bump_info(task, raise_on_empty=True):
"""Get the l10n bump information from the task metadata.
Expand Down
75 changes: 75 additions & 0 deletions treescript/tests/test_github_branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python
"""Tests for treescript github branch methods."""

import pytest

import treescript.github.branch as branch
from treescript.exceptions import TreeScriptError


def test_get_branch_name():
task = {
"payload": {
"create_branch_info": {
"branch_name": "release-v1.0",
}
}
}
assert branch.get_branch_name(task) == "release-v1.0"


def test_get_branch_name_not_specified():
task = {
"payload": {
"create_branch_info": {
"from_branch": "main",
}
}
}
with pytest.raises(TreeScriptError):
branch.get_branch_name(task)


@pytest.mark.asyncio
async def test_create_branch(mocker, github_client):
task = {
"payload": {
"branch": "main",
"create_branch_info": {
"branch_name": "release-v1.0",
},
}
}

mock_create_branch = mocker.patch.object(github_client, "create_branch")

await branch.create_branch(github_client, task)

mock_create_branch.assert_called_once_with(
branch_name="release-v1.0",
from_branch="main",
dry_run=False,
)


@pytest.mark.asyncio
async def test_create_branch_dry_run(mocker, github_client):
task = {
"payload": {
"branch": "main",
"create_branch_info": {
"branch_name": "release-v1.0",
},
"dry_run": True,
}
}

mock_create_branch = mocker.patch.object(github_client, "create_branch")

await branch.create_branch(github_client, task)

mock_create_branch.assert_called_once_with(
branch_name="release-v1.0",
from_branch="main",
dry_run=True,
)