GitOrigin-RevId: 6370f6ea785709295b6abcf9c60717cacf3ac432
This commit is contained in:
commit
8157b39ea4
308 changed files with 29248 additions and 0 deletions
33
.forgejo/workflows/ci.yml
Normal file
33
.forgejo/workflows/ci.yml
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push: {}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: "docker"
|
||||
container:
|
||||
image: forgejo.csbx.dev/acmcarther/coder-dev-base-image:4
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install bazelisk
|
||||
run: |
|
||||
curl -fLO "https://bin-cache.csbx.dev/bazelisk/v1.27.0/bazelisk-linux-amd64"
|
||||
mkdir -p "${GITHUB_WORKSPACE}/bin/"
|
||||
mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
chmod +x "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
|
||||
- name: Test
|
||||
env:
|
||||
BAZELISK_BASE_URL: "https://bin-cache.csbx.dev/bazel"
|
||||
run: |
|
||||
${GITHUB_WORKSPACE}/bin/bazel test --config=ci --java_runtime_version=remotejdk_21 //...
|
||||
|
||||
- name: Upload test logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: bazel-test-logs
|
||||
path: bazel-testlogs/
|
||||
46
.forgejo/workflows/publish-container-images.yml
Normal file
46
.forgejo/workflows/publish-container-images.yml
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
name: Publish Container Images
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * *' # Run at 3 AM
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: docker
|
||||
container:
|
||||
image: forgejo.csbx.dev/acmcarther/coder-dev-base-image:4
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install bazelisk
|
||||
run: |
|
||||
curl -fLO "http://bin-cache-http.dev.svc.cluster.local/bazelisk/v1.27.0/bazelisk-linux-amd64"
|
||||
mkdir -p "${GITHUB_WORKSPACE}/bin/"
|
||||
mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
chmod +x "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
echo "${GITHUB_WORKSPACE}/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Login to Forgejo Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: forgejo.csbx.dev
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.YESOD_PACKAGE_TOKEN }}
|
||||
|
||||
- name: Publish Coder Dev Base Image
|
||||
env:
|
||||
BAZELISK_BASE_URL: "http://bin-cache-http.dev.svc.cluster.local/bazel"
|
||||
# rules_oci respects DOCKER_CONFIG or looks in ~/.docker/config.json
|
||||
# The login action typically sets up ~/.docker/config.json
|
||||
run: |
|
||||
# Ensure DOCKER_CONFIG is set to where login-action writes (default home)
|
||||
export DOCKER_CONFIG=$HOME/.docker
|
||||
|
||||
SHORT_SHA=$(git rev-parse --short HEAD)
|
||||
TAG="5-${SHORT_SHA}"
|
||||
|
||||
echo "Pushing image with tag: ${TAG}"
|
||||
bazel run --config=remote //k8s/container/coder-dev-base-image:push -- --tag ${TAG}
|
||||
157
.forgejo/workflows/publish-homebrew-packages.yml
Normal file
157
.forgejo/workflows/publish-homebrew-packages.yml
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
name: Publish Homebrew Packages
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 2 * * *' # Run at 2 AM
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build-and-publish:
|
||||
runs-on: docker
|
||||
container:
|
||||
image: forgejo.csbx.dev/acmcarther/coder-dev-base-image:4
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- package: tts-client
|
||||
target: //experimental/users/acmcarther/llm/tts_grpc:tts_client_go
|
||||
binary_name: tts-client-darwin-arm64
|
||||
- package: litellm-client
|
||||
target: //experimental/users/acmcarther/llm/litellm_grpc:client_go
|
||||
binary_name: litellm-client-darwin-arm64
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install bazelisk
|
||||
run: |
|
||||
curl -fLO "http://bin-cache-http.dev.svc.cluster.local/bazelisk/v1.27.0/bazelisk-linux-amd64"
|
||||
mkdir -p "${GITHUB_WORKSPACE}/bin/"
|
||||
mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
chmod +x "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
echo "${GITHUB_WORKSPACE}/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Build Binary (Darwin ARM64)
|
||||
env:
|
||||
BAZELISK_BASE_URL: "http://bin-cache-http.dev.svc.cluster.local/bazel"
|
||||
run: |
|
||||
bazel build --config=remote --platforms=@rules_go//go/toolchain:darwin_arm64 ${{ matrix.target }}
|
||||
|
||||
- name: Publish Artifact
|
||||
env:
|
||||
YESOD_PACKAGE_TOKEN: ${{ secrets.YESOD_PACKAGE_TOKEN }}
|
||||
PACKAGE_NAME: ${{ matrix.package }}
|
||||
BINARY_NAME: ${{ matrix.binary_name }}
|
||||
TARGET: ${{ matrix.target }}
|
||||
run: |
|
||||
# Calculate Version
|
||||
SHORT_SHA=$(git rev-parse --short HEAD)
|
||||
VERSION="0.0.2-${SHORT_SHA}"
|
||||
PACKAGE_URL="https://forgejo.csbx.dev/api/packages/acmcarther/generic/${PACKAGE_NAME}/${VERSION}/${BINARY_NAME}"
|
||||
|
||||
echo "Publishing ${PACKAGE_NAME} version ${VERSION}"
|
||||
|
||||
# Locate the binary using cquery for precision
|
||||
BINARY_PATH=$(bazel cquery --config=remote --platforms=@rules_go//go/toolchain:darwin_arm64 --output=files ${TARGET} 2>/dev/null)
|
||||
|
||||
if [ -z "$BINARY_PATH" ]; then
|
||||
echo "cquery failed to find the binary path for ${TARGET}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Upload
|
||||
echo "Uploading to ${PACKAGE_URL}..."
|
||||
curl -v --fail \
|
||||
-H "Authorization: token $YESOD_PACKAGE_TOKEN" \
|
||||
-X PUT \
|
||||
"${PACKAGE_URL}" \
|
||||
-T "$BINARY_PATH"
|
||||
|
||||
update-homebrew:
|
||||
needs: build-and-publish
|
||||
runs-on: docker
|
||||
container:
|
||||
image: forgejo.csbx.dev/acmcarther/coder-dev-base-image:4
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install bazelisk
|
||||
run: |
|
||||
curl -fLO "http://bin-cache-http.dev.svc.cluster.local/bazelisk/v1.27.0/bazelisk-linux-amd64"
|
||||
mkdir -p "${GITHUB_WORKSPACE}/bin/"
|
||||
mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
chmod +x "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
echo "${GITHUB_WORKSPACE}/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Generate and Update Formulas
|
||||
env:
|
||||
BAZELISK_BASE_URL: "http://bin-cache-http.dev.svc.cluster.local/bazel"
|
||||
PACKAGES: "tts-client litellm-client"
|
||||
run: |
|
||||
SHORT_SHA=$(git rev-parse --short HEAD)
|
||||
VERSION="0.0.2-${SHORT_SHA}"
|
||||
|
||||
for PACKAGE in $PACKAGES; do
|
||||
echo "Processing $PACKAGE..."
|
||||
|
||||
# Convert hyphen to underscore for bazel target name convention
|
||||
TARGET_NAME="generate_${PACKAGE//-/_}_rb"
|
||||
BAZEL_TARGET="//homebrew:${TARGET_NAME}"
|
||||
BINARY_NAME="${PACKAGE}-darwin-arm64"
|
||||
PACKAGE_URL="https://forgejo.csbx.dev/api/packages/acmcarther/generic/${PACKAGE}/${VERSION}/${BINARY_NAME}"
|
||||
|
||||
echo "Building formula target: $BAZEL_TARGET"
|
||||
bazel build --config=remote --platforms=@rules_go//go/toolchain:darwin_arm64 $BAZEL_TARGET
|
||||
|
||||
FORMULA_PATH=$(bazel cquery --config=remote --platforms=@rules_go//go/toolchain:darwin_arm64 --output=files $BAZEL_TARGET 2>/dev/null)
|
||||
|
||||
if [ -z "$FORMULA_PATH" ]; then
|
||||
echo "Failed to find generated formula for $PACKAGE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Generated formula at: $FORMULA_PATH"
|
||||
|
||||
# Inject Version and URL
|
||||
sed -e "s|{VERSION}|${VERSION}|g" \
|
||||
-e "s|{URL}|${PACKAGE_URL}|g" \
|
||||
"$FORMULA_PATH" > homebrew/${PACKAGE}.rb
|
||||
|
||||
echo "Updated homebrew/${PACKAGE}.rb"
|
||||
done
|
||||
|
||||
- name: Configure Git
|
||||
run: |
|
||||
git config --global user.name "Copybara"
|
||||
git config --global user.email "copybara@csbx.dev"
|
||||
git config --global url."https://acmcarther:${{ secrets.HOMEBREW_REPO_TOKEN }}@forgejo.csbx.dev/acmcarther/yesod-homebrew-tools.git".insteadOf "https://forgejo.csbx.dev/acmcarther/yesod-homebrew-tools.git"
|
||||
|
||||
- name: Run Copybara
|
||||
env:
|
||||
BAZELISK_BASE_URL: "http://bin-cache-http.dev.svc.cluster.local/bazel"
|
||||
run: |
|
||||
# Stage the new formulas
|
||||
git add homebrew/*.rb
|
||||
|
||||
# Check if there are changes
|
||||
if git diff --cached --quiet; then
|
||||
echo "No changes to commit."
|
||||
else
|
||||
SHORT_SHA=$(git rev-parse --short HEAD)
|
||||
VERSION="0.0.2-${SHORT_SHA}"
|
||||
git commit -m "Update homebrew formulas to ${VERSION} [skip ci]"
|
||||
|
||||
# Patch copy.bara.sky to use local origin
|
||||
sed -i "s|sourceUrl = .*|sourceUrl = \"file://${GITHUB_WORKSPACE}\"|" "${GITHUB_WORKSPACE}/tools/copybara/homebrew/copy.bara.sky"
|
||||
|
||||
# Run Copybara
|
||||
bazel run //tools/copybara:copybara --config=remote --java_runtime_version=remotejdk_21 -- \
|
||||
migrate \
|
||||
"${GITHUB_WORKSPACE}/tools/copybara/homebrew/copy.bara.sky" \
|
||||
--force
|
||||
fi
|
||||
43
.forgejo/workflows/publish-yesod-mirror.yml
Normal file
43
.forgejo/workflows/publish-yesod-mirror.yml
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
name: Publish Yesod Mirror
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 1 * * *' # Run at 1 AM
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
update-yesod-mirror:
|
||||
runs-on: docker
|
||||
container:
|
||||
image: forgejo.csbx.dev/acmcarther/coder-dev-base-image:4
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install bazelisk
|
||||
run: |
|
||||
curl -fLO "http://bin-cache-http.dev.svc.cluster.local/bazelisk/v1.27.0/bazelisk-linux-amd64"
|
||||
mkdir -p "${GITHUB_WORKSPACE}/bin/"
|
||||
mv bazelisk-linux-amd64 "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
chmod +x "${GITHUB_WORKSPACE}/bin/bazel"
|
||||
echo "${GITHUB_WORKSPACE}/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Configure Git
|
||||
run: |
|
||||
git config --global user.name "Copybara"
|
||||
git config --global user.email "copybara@csbx.dev"
|
||||
git config --global url."https://acmcarther:${{ secrets.YESOD_MIRROR_TOKEN }}@forgejo.csbx.dev/acmcarther/yesod-mirror.git".insteadOf "https://forgejo.csbx.dev/acmcarther/yesod-mirror.git"
|
||||
git config --global url."https://acmcarther:${{ secrets.YESOD_MIRROR_TOKEN }}@forgejo.csbx.dev/acmcarther/yesod.git".insteadOf "https://forgejo.csbx.dev/acmcarther/yesod.git"
|
||||
|
||||
- name: Run Copybara
|
||||
env:
|
||||
BAZELISK_BASE_URL: "http://bin-cache-http.dev.svc.cluster.local/bazel"
|
||||
run: |
|
||||
# Run Copybara
|
||||
bazel run //tools/copybara:copybara --config=remote --java_runtime_version=remotejdk_21 -- \
|
||||
migrate \
|
||||
"${GITHUB_WORKSPACE}/tools/copybara/yesod-mirror/copy.bara.sky" \
|
||||
--force
|
||||
25
.gemini/settings.json
Normal file
25
.gemini/settings.json
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
"context": {
|
||||
"includeDirectories": [
|
||||
"~/.gemini/tmp/",
|
||||
"~/.cache/bazel/"
|
||||
]
|
||||
},
|
||||
"general": {
|
||||
"preferredEditor": "vim"
|
||||
},
|
||||
"mcpServers": {
|
||||
"playwright": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"@playwright/mcp@latest"
|
||||
]
|
||||
}
|
||||
},
|
||||
"telemetry": {
|
||||
"enabled": true,
|
||||
"target": "local",
|
||||
"otlpEndpoint": "",
|
||||
"logPrompts": true
|
||||
}
|
||||
}
|
||||
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
11
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
### Link to Task
|
||||
|
||||
<!-- Provide a link to the project checkpoint or task file that initiated this work. -->
|
||||
|
||||
### Summary of Changes
|
||||
|
||||
<!-- Describe the modifications made in this pull request. -->
|
||||
|
||||
### Verification Steps
|
||||
|
||||
<!-- Outline the specific commands run and tests passed to validate the changes. -->
|
||||
0
experimental/users/acmcarther/BUILD.bazel
Normal file
0
experimental/users/acmcarther/BUILD.bazel
Normal file
23
experimental/users/acmcarther/build_defs.bzl
Normal file
23
experimental/users/acmcarther/build_defs.bzl
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
load("@rules_pkg//pkg:tar.bzl", "pkg_tar")
|
||||
|
||||
# Old stuff moved out of scripts/
|
||||
def copy_to_dist(name):
|
||||
native.genrule(
|
||||
name = name + "_publish",
|
||||
srcs = [":" + name],
|
||||
outs = ["dist/scripts/" + name],
|
||||
cmd = "cp $(execpath :" + name + ") $@",
|
||||
)
|
||||
|
||||
# Old stuff moved out of scripts/
|
||||
def package_script(name):
|
||||
"""Packages a py_binary script into a tarball.
|
||||
|
||||
Args:
|
||||
name: The name of the py_binary rule.
|
||||
"""
|
||||
pkg_tar(
|
||||
name = name + "_tar",
|
||||
srcs = [":" + name],
|
||||
)
|
||||
106
experimental/users/acmcarther/examples/grpc_example/BUILD
Normal file
106
experimental/users/acmcarther/examples/grpc_example/BUILD
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary", "py_unpacked_wheel")
|
||||
load("@build_stack_rules_proto//rules:proto_compile.bzl", "proto_compile")
|
||||
load("@build_stack_rules_proto//rules/py:grpc_py_library.bzl", "grpc_py_library")
|
||||
load("@build_stack_rules_proto//rules/py:proto_py_library.bzl", "proto_py_library")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
load("@rules_go//go:def.bzl", "go_library", "go_test")
|
||||
load("@rules_proto//proto:defs.bzl", "proto_library")
|
||||
|
||||
# gazelle:resolve go forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example //experimental/users/acmcarther/examples/grpc_example:example_go_proto
|
||||
|
||||
py_binary(
|
||||
name = "example_client",
|
||||
srcs = ["example_client.py"],
|
||||
deps = [
|
||||
":example_grpc_py_library",
|
||||
":example_py_library",
|
||||
requirement("grpcio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "example_server",
|
||||
srcs = ["example_server.py"],
|
||||
deps = [
|
||||
":example_grpc_py_library",
|
||||
":example_py_library",
|
||||
requirement("grpcio"),
|
||||
],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "example_proto",
|
||||
srcs = ["example.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
proto_compile(
|
||||
name = "example_go_grpc_compile",
|
||||
output_mappings = [
|
||||
"example.pb.go=forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example/example.pb.go",
|
||||
"example_grpc.pb.go=forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example/example_grpc.pb.go",
|
||||
],
|
||||
outputs = [
|
||||
"example.pb.go",
|
||||
"example_grpc.pb.go",
|
||||
],
|
||||
plugins = [
|
||||
"@build_stack_rules_proto//plugin/golang/protobuf:protoc-gen-go",
|
||||
"@build_stack_rules_proto//plugin/grpc/grpc-go:protoc-gen-go-grpc",
|
||||
],
|
||||
proto = "example_proto",
|
||||
)
|
||||
|
||||
grpc_py_library(
|
||||
name = "example_grpc_py_library",
|
||||
srcs = ["example_pb2_grpc.py"],
|
||||
deps = [
|
||||
":example_py_library",
|
||||
"@pip_third_party//grpcio:pkg",
|
||||
],
|
||||
)
|
||||
|
||||
proto_compile(
|
||||
name = "example_python_grpc_compile",
|
||||
outputs = [
|
||||
"example_pb2.py",
|
||||
"example_pb2.pyi",
|
||||
"example_pb2_grpc.py",
|
||||
],
|
||||
plugins = [
|
||||
"@build_stack_rules_proto//plugin/builtin:pyi",
|
||||
"@build_stack_rules_proto//plugin/builtin:python",
|
||||
"@build_stack_rules_proto//plugin/grpc/grpc:protoc-gen-grpc-python",
|
||||
],
|
||||
proto = "example_proto",
|
||||
)
|
||||
|
||||
proto_py_library(
|
||||
name = "example_py_library",
|
||||
srcs = ["example_pb2.py"],
|
||||
deps = ["@com_google_protobuf//:protobuf_python"],
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "example_go_proto",
|
||||
srcs = [":example_go_grpc_compile"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//codes",
|
||||
"@org_golang_google_grpc//status",
|
||||
"@org_golang_google_protobuf//reflect/protoreflect",
|
||||
"@org_golang_google_protobuf//runtime/protoimpl",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "grpc_example_test",
|
||||
srcs = ["example_test.go"],
|
||||
deps = [
|
||||
":example_go_proto",
|
||||
"@org_golang_google_grpc//:grpc",
|
||||
"@org_golang_google_grpc//credentials/insecure",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
load("@rules_go//go:def.bzl", "go_binary", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "client_lib",
|
||||
srcs = ["main.go"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/grpc_example/client",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//experimental/users/acmcarther/examples/grpc_example:example_go_proto",
|
||||
"@org_golang_google_grpc//:grpc",
|
||||
"@org_golang_google_grpc//credentials/insecure",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "client",
|
||||
embed = [":client_lib"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
pb "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultName = "world"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
name = flag.String("name", defaultName, "Name to greet")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
// Set up a connection to the server.
|
||||
conn, err := grpc.Dial(*addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewExampleClient(conn)
|
||||
|
||||
// Contact the server and print out its response.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
r, err := c.SayHello(ctx, &pb.HelloRequest{Name: name})
|
||||
if err != nil {
|
||||
log.Fatalf("could not greet: %v", err)
|
||||
}
|
||||
log.Printf("Greeting: %s", r.GetMessage())
|
||||
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
package experimental.users.acmcarther.examples.grpc_example;
|
||||
|
||||
option go_package = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example";
|
||||
|
||||
service Example {
|
||||
rpc SayHello(HelloRequest) returns (HelloReply) {}
|
||||
}
|
||||
|
||||
// The request message containing the user's name.
|
||||
message HelloRequest {
|
||||
optional string name = 1;
|
||||
}
|
||||
|
||||
// The response message containing the greetings
|
||||
message HelloReply {
|
||||
optional string message = 1;
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from experimental.users.acmcarther.examples.grpc_example import example_pb2_grpc, example_pb2
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
|
||||
def main():
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = example_pb2_grpc.ExampleStub(channel)
|
||||
response = stub.SayHello(example_pb2.HelloRequest(name="you"))
|
||||
print("Greeter client received: " + response.message)
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
from experimental.users.acmcarther.examples.grpc_example import example_pb2_grpc, example_pb2
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
|
||||
class ExampleService(example_pb2_grpc.ExampleServicer):
|
||||
def SayHello(self, request, context):
|
||||
response_message = f"Hello, {request.name}!"
|
||||
return example_pb2.HelloReply(message=response_message)
|
||||
|
||||
|
||||
def main():
|
||||
port = 50051
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
example_pb2_grpc.add_ExampleServicer_to_server(ExampleService(), server)
|
||||
server.add_insecure_port(f'[::]:{port}')
|
||||
server.start()
|
||||
print(f"gRPC server is running on port {port}...")
|
||||
server.wait_for_termination()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
package main_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
pb "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example"
|
||||
)
|
||||
|
||||
// server is used to implement helloworld.GreeterServer.
|
||||
type server struct {
|
||||
pb.UnimplementedExampleServer
|
||||
}
|
||||
|
||||
// SayHello implements helloworld.GreeterServer
|
||||
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
log.Printf("Received: %v", in.GetName())
|
||||
msg := "Hello " + in.GetName()
|
||||
return &pb.HelloReply{Message: &msg}, nil
|
||||
}
|
||||
|
||||
func TestSayHello(t *testing.T) {
|
||||
// Start server
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
pb.RegisterExampleServer(s, &server{})
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
log.Printf("server exited with error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer s.Stop()
|
||||
|
||||
// Connect client
|
||||
conn, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
t.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewExampleClient(conn)
|
||||
|
||||
// Call method
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
name := "TestUser"
|
||||
r, err := c.SayHello(ctx, &pb.HelloRequest{Name: &name})
|
||||
if err != nil {
|
||||
t.Fatalf("could not greet: %v", err)
|
||||
}
|
||||
if r.GetMessage() != "Hello "+name {
|
||||
t.Errorf("got %s, want Hello %s", r.GetMessage(), name)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
load("@rules_go//go:def.bzl", "go_binary", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "server_lib",
|
||||
srcs = ["main.go"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example/server",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"//experimental/users/acmcarther/examples/grpc_example:example_go_proto",
|
||||
"@org_golang_google_grpc//:grpc",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "server",
|
||||
embed = [":server_lib"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
pb "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/grpc_example"
|
||||
)
|
||||
|
||||
var (
|
||||
port = flag.Int("port", 50051, "The server port")
|
||||
)
|
||||
|
||||
// server is used to implement helloworld.GreeterServer.
|
||||
type server struct {
|
||||
pb.UnimplementedExampleServer
|
||||
}
|
||||
|
||||
// SayHello implements helloworld.GreeterServer
|
||||
func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
log.Printf("Received: %v", in.GetName())
|
||||
msg := "Hello " + in.GetName()
|
||||
return &pb.HelloReply{Message: &msg}, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
pb.RegisterExampleServer(s, &server{})
|
||||
log.Printf("server listening at %v", lis.Addr())
|
||||
if err := s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}
|
||||
25
experimental/users/acmcarther/examples/jsonnet/BUILD.bazel
Normal file
25
experimental/users/acmcarther/examples/jsonnet/BUILD.bazel
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
load("@rules_jsonnet//jsonnet:jsonnet.bzl", "jsonnet_library", "jsonnet_to_json", "jsonnet_to_json_test")
|
||||
|
||||
jsonnet_library(
|
||||
name = "base_lib",
|
||||
srcs = ["lib/base.libsonnet"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//third_party/jsonnet:k8s_libsonnet",
|
||||
],
|
||||
)
|
||||
|
||||
jsonnet_library(
|
||||
name = "dev_env_lib",
|
||||
srcs = ["environments/dev/main.jsonnet"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":base_lib"],
|
||||
)
|
||||
|
||||
jsonnet_to_json(
|
||||
name = "dev_env",
|
||||
src = "environments/dev/main.jsonnet",
|
||||
outs = ["dev_env.json"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":base_lib"],
|
||||
)
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
local base = import "experimental/users/acmcarther/examples/jsonnet/lib/base.libsonnet";
|
||||
|
||||
base {
|
||||
other_field: 10,
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
local k = import "external/+jsonnet_deps+github_com_jsonnet_libs_k8s_libsonnet_1_29/1.29/main.libsonnet";
|
||||
|
||||
{
|
||||
myNamespace: k.core.v1.namespace.new("example-namespace")
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary", "py_unpacked_wheel")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
|
||||
py_binary(
|
||||
name = "hello_fastapi",
|
||||
srcs = ["hello_fastapi.py"],
|
||||
deps = [
|
||||
requirement("fastapi"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "hello_socketio",
|
||||
srcs = ["hello_socketio.py"],
|
||||
deps = [
|
||||
requirement("python-socketio"),
|
||||
requirement("asyncio"),
|
||||
requirement("aiohttp"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "hello_requests",
|
||||
srcs = ["hello_requests.py"],
|
||||
deps = [
|
||||
requirement("requests"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "hello_numpy",
|
||||
srcs = ["hello_numpy.py"],
|
||||
deps = [
|
||||
requirement("numpy"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "hello_yaml",
|
||||
srcs = ["hello_yaml.py"],
|
||||
deps = [
|
||||
requirement("pyyaml"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "hello_pandas",
|
||||
srcs = ["hello_pandas.py"],
|
||||
deps = [
|
||||
requirement("pandas"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "hello_beautifulsoup",
|
||||
srcs = ["hello_beautifulsoup.py"],
|
||||
deps = [
|
||||
requirement("beautifulsoup4"),
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_library", "py_test")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
|
||||
py_library(
|
||||
name = "example",
|
||||
srcs = ["example.py"],
|
||||
deps = [
|
||||
requirement("absl-py"),
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "example_test",
|
||||
srcs = ["example_test.py"],
|
||||
deps = [
|
||||
":example",
|
||||
requirement("absl-py"),
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
from absl.testing import absltest
|
||||
|
||||
class SampleTest(absltest.TestCase):
|
||||
|
||||
def test_subtest(self):
|
||||
for i in (1, 2):
|
||||
with self.subTest(i=i):
|
||||
self.assertEqual(i, i)
|
||||
print('msg_for_test')
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from bs4 import BeautifulSoup
|
||||
|
||||
def main():
|
||||
html_doc = """
|
||||
<html><head><title>The Dormouse's story</title></head>
|
||||
<body>
|
||||
<p class="title"><b>The Dormouse's story</b></p>
|
||||
<p class="story">Once upon a time...</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
soup = BeautifulSoup(html_doc, 'html.parser')
|
||||
print("Successfully parsed HTML:")
|
||||
print(soup.title.string)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1 @@
|
|||
print("hello fastapi")
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import numpy as np
|
||||
|
||||
def main():
|
||||
arr = np.array([1, 2, 3])
|
||||
print(f"Numpy array: {arr}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
def main():
|
||||
s = pd.Series([1, 3, 5, np.nan, 6, 8])
|
||||
print("Successfully created pandas Series:")
|
||||
print(s)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import requests
|
||||
|
||||
def main():
|
||||
response = requests.get("https://www.google.com")
|
||||
print(f"Status Code: {response.status_code}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
import asyncio
|
||||
import socketio
|
||||
from aiohttp import web
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server setup
|
||||
# ---------------------------------------------------------------------------
|
||||
sio = socketio.AsyncServer(async_mode="aiohttp")
|
||||
app = web.Application()
|
||||
sio.attach(app)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def connect(sid, environ):
|
||||
print(f"Client connected: {sid}")
|
||||
# Send a greeting to the newly connected client
|
||||
await sio.emit("greeting", {"msg": "Hello from Socket.IO server!"}, to=sid)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def disconnect(sid):
|
||||
print(f"Client disconnected: {sid}")
|
||||
|
||||
|
||||
@sio.on("reply")
|
||||
async def on_reply(sid, data):
|
||||
print(f"Received reply from client {sid}: {data}")
|
||||
|
||||
|
||||
async def start_server():
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "localhost", 5000)
|
||||
await site.start()
|
||||
print("Socket.IO server listening on http://localhost:5000")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Client implementation (runs after a short delay to ensure the server is up)
|
||||
# ---------------------------------------------------------------------------
|
||||
async def start_client():
|
||||
# Wait briefly for the server to be ready
|
||||
await asyncio.sleep(1)
|
||||
client = socketio.AsyncClient()
|
||||
|
||||
@client.event
|
||||
async def connect():
|
||||
print("Client connected to server")
|
||||
|
||||
@client.on("greeting")
|
||||
async def on_greeting(data):
|
||||
print(f"Server says: {data['msg']}")
|
||||
# Slight delay before replying to ensure the namespace is fully ready
|
||||
await asyncio.sleep(0.1)
|
||||
await client.emit("reply", {"response": "Hello from client!"})
|
||||
|
||||
@client.event
|
||||
async def disconnect():
|
||||
print("Client disconnected")
|
||||
|
||||
await client.connect("http://localhost:5000")
|
||||
# Keep the client alive for a short while to exchange messages
|
||||
await asyncio.sleep(5)
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
async def main():
|
||||
# Run server and client concurrently
|
||||
await asyncio.gather(start_server(), start_client())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
def main():
|
||||
print("Hello, Bazel!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import yaml
|
||||
|
||||
def main():
|
||||
doc = """
|
||||
a: 1
|
||||
b:
|
||||
- c: 2
|
||||
- d: 3
|
||||
"""
|
||||
data = yaml.safe_load(doc)
|
||||
print("Successfully loaded YAML:")
|
||||
print(data)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19
experimental/users/acmcarther/examples/tanka/BUILD
Normal file
19
experimental/users/acmcarther/examples/tanka/BUILD
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
load("@rules_go//go:def.bzl", "go_binary", "go_library")
|
||||
|
||||
go_binary(
|
||||
name = "tanka",
|
||||
embed = [":tanka_lib"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "tanka_lib",
|
||||
srcs = ["main.go"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/examples/tanka",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"@com_github_grafana_tanka//pkg/kubernetes",
|
||||
"@com_github_grafana_tanka//pkg/process",
|
||||
"@com_github_grafana_tanka//pkg/spec/v1alpha1",
|
||||
],
|
||||
)
|
||||
14
experimental/users/acmcarther/examples/tanka/dummy_spec.json
Normal file
14
experimental/users/acmcarther/examples/tanka/dummy_spec.json
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"apiVersion": "tanka.dev/v1alpha1",
|
||||
"kind": "Environment",
|
||||
"metadata": {
|
||||
"name": "experimental-env",
|
||||
"namespace": "default"
|
||||
},
|
||||
"spec": {
|
||||
"apiServer": "https://0.0.0.0:6443",
|
||||
"namespace": "default",
|
||||
"resourceDefaults": {},
|
||||
"expectVersions": {}
|
||||
}
|
||||
}
|
||||
97
experimental/users/acmcarther/examples/tanka/main.go
Normal file
97
experimental/users/acmcarther/examples/tanka/main.go
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/grafana/tanka/pkg/kubernetes"
|
||||
"github.com/grafana/tanka/pkg/process"
|
||||
"github.com/grafana/tanka/pkg/spec/v1alpha1"
|
||||
)
|
||||
|
||||
func main() {
|
||||
specPath := flag.String("spec", "", "Path to spec.json")
|
||||
mainPath := flag.String("main", "", "Path to main.json")
|
||||
action := flag.String("action", "show", "Action to perform: show, diff, apply")
|
||||
flag.Parse()
|
||||
|
||||
if *specPath == "" || *mainPath == "" {
|
||||
fmt.Fprintln(os.Stderr, "Usage: tanka --spec <spec.json> --main <main.json> [--action <action>]")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 1. Load Spec
|
||||
specData, err := os.ReadFile(*specPath)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("reading spec: %w", err))
|
||||
}
|
||||
|
||||
var env v1alpha1.Environment
|
||||
if err := json.Unmarshal(specData, &env); err != nil {
|
||||
panic(fmt.Errorf("unmarshaling spec: %w", err))
|
||||
}
|
||||
|
||||
// 2. Load Main (Data)
|
||||
mainData, err := os.ReadFile(*mainPath)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("reading main: %w", err))
|
||||
}
|
||||
|
||||
var rawData interface{}
|
||||
if err := json.Unmarshal(mainData, &rawData); err != nil {
|
||||
panic(fmt.Errorf("unmarshaling main: %w", err))
|
||||
}
|
||||
env.Data = rawData
|
||||
|
||||
// 3. Process (Extract, Label, Filter)
|
||||
// We use empty matchers for now
|
||||
list, err := process.Process(env, process.Matchers{})
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("processing manifests: %w", err))
|
||||
}
|
||||
|
||||
fmt.Printf("Processed %d manifests for env %s (namespace: %s)\n", len(list), env.Metadata.Name, env.Spec.Namespace)
|
||||
|
||||
if *action == "show" {
|
||||
for _, m := range list {
|
||||
fmt.Printf("- %s: %s\n", m.Kind(), m.Metadata().Name())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 4. Initialize Kubernetes Client
|
||||
// This will fail if no valid kubeconfig/context is found matching spec.json
|
||||
kube, err := kubernetes.New(env)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to initialize Kubernetes client (expected if no cluster context): %v\n", err)
|
||||
return
|
||||
}
|
||||
defer kube.Close()
|
||||
|
||||
// 5. Perform Action
|
||||
switch *action {
|
||||
case "diff":
|
||||
fmt.Println("Running Diff...")
|
||||
diff, err := kube.Diff(context.Background(), list, kubernetes.DiffOpts{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if diff != nil {
|
||||
fmt.Println(*diff)
|
||||
} else {
|
||||
fmt.Println("No changes.")
|
||||
}
|
||||
case "apply":
|
||||
fmt.Println("Running Apply...")
|
||||
err := kube.Apply(list, kubernetes.ApplyOpts{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("Apply finished.")
|
||||
default:
|
||||
fmt.Printf("Unknown action: %s\n", *action)
|
||||
}
|
||||
}
|
||||
7
experimental/users/acmcarther/git-eradicate.sh
Executable file
7
experimental/users/acmcarther/git-eradicate.sh
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
git filter-branch -f --index-filter \
|
||||
'git rm --force --cached --ignore-unmatch kubectl' \
|
||||
-- --all
|
||||
rm -Rf .git/refs/original && \
|
||||
git reflog expire --expire=now --all && \
|
||||
git gc --aggressive && \
|
||||
git prune
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
load("@rules_jsonnet//jsonnet:jsonnet.bzl", "jsonnet_to_json")
|
||||
load("//tools:tanka.bzl", "tanka_environment")
|
||||
|
||||
jsonnet_to_json(
|
||||
name = "main",
|
||||
src = "main.jsonnet",
|
||||
outs = ["main.json"],
|
||||
data = [
|
||||
"@helm_crossplane_crossplane//:chart",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//k8s/configs/templates",
|
||||
"//experimental/users/acmcarther/k8s/configs/templates",
|
||||
],
|
||||
)
|
||||
|
||||
tanka_environment(
|
||||
name = "crossplane",
|
||||
main = ":main",
|
||||
spec = "spec.json",
|
||||
)
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
local base = import "k8s/configs/base.libsonnet";
|
||||
local crossplane = import "experimental/users/acmcarther/k8s/configs/templates/crossplane.libsonnet";
|
||||
|
||||
local namespace = "crossplane-system";
|
||||
local ctx = base.NewContext(base.helm);
|
||||
|
||||
{
|
||||
namespace: {
|
||||
apiVersion: "v1",
|
||||
kind: "Namespace",
|
||||
metadata: {
|
||||
name: namespace,
|
||||
},
|
||||
},
|
||||
apps: {
|
||||
crossplane: crossplane.App(crossplane.Params {
|
||||
namespace: namespace,
|
||||
name: "crossplane",
|
||||
context: ctx,
|
||||
values: {
|
||||
# Add any specific values here
|
||||
},
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"apiVersion": "tanka.dev/v1alpha1",
|
||||
"kind": "Environment",
|
||||
"metadata": {
|
||||
"name": "environments/crossplane",
|
||||
"namespace": "environments/crossplane/main.jsonnet"
|
||||
},
|
||||
"spec": {
|
||||
"apiServer": "https://k8s.dominion.lan:6443",
|
||||
"namespace": "crossplane-system",
|
||||
"resourceDefaults": {},
|
||||
"expectVersions": {},
|
||||
"injectLabels": true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
load("@rules_jsonnet//jsonnet:jsonnet.bzl", "jsonnet_library", "jsonnet_to_json", "jsonnet_to_json_test")
|
||||
load("//tools:tanka.bzl", "tanka_environment")
|
||||
load("//tools:sops.bzl", "sops_decrypt")
|
||||
|
||||
sops_decrypt(
|
||||
name = "secrets",
|
||||
src = "secrets.sops.yaml",
|
||||
out = "secrets.json",
|
||||
)
|
||||
|
||||
jsonnet_library(
|
||||
name = "secrets_lib",
|
||||
srcs = [":secrets"],
|
||||
)
|
||||
|
||||
jsonnet_to_json(
|
||||
name = "main",
|
||||
src = "main.jsonnet",
|
||||
outs = ["main.json"],
|
||||
data = [
|
||||
"@helm_jetstack_cert_manager//:chart",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":secrets_lib",
|
||||
"//k8s/configs/templates",
|
||||
"//experimental/users/acmcarther/k8s/configs/templates",
|
||||
],
|
||||
)
|
||||
|
||||
tanka_environment(
|
||||
name = "dominion",
|
||||
main = ":main",
|
||||
spec = "spec.json",
|
||||
)
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
local base = import "k8s/configs/base.libsonnet";
|
||||
local secrets = import "experimental/users/acmcarther/k8s/configs/environments/dominion/secrets.json";
|
||||
|
||||
local freshrss = import "k8s/configs/templates/personal/media/freshrss.libsonnet";
|
||||
local monica = import "k8s/configs/templates/personal/home/monica.libsonnet";
|
||||
local jellyfin = import "k8s/configs/templates/personal/media/jellyfin.libsonnet";
|
||||
local transmission = import "k8s/configs/templates/personal/media/transmission.libsonnet";
|
||||
|
||||
local lanraragi = import "experimental/users/acmcarther/k8s/configs/templates/lanraragi.libsonnet";
|
||||
|
||||
local nginxIngress = import "k8s/configs/templates/core/network/nginx-ingress.libsonnet";
|
||||
local mariadb = import "k8s/configs/templates/core/storage/mariadb.libsonnet";
|
||||
|
||||
local namespace = "dominion";
|
||||
{
|
||||
namespace: {
|
||||
apiVersion: "v1",
|
||||
kind: "Namespace",
|
||||
metadata: {
|
||||
name: namespace,
|
||||
},
|
||||
},
|
||||
secrets: {
|
||||
monica: mariadb.Secret(mariadb.SecretParams{
|
||||
name: "monica",
|
||||
namespace: "dominion",
|
||||
rootPassword: secrets.monica_mariadb_root_db_pwd,
|
||||
password: secrets.monica_mariadb_db_pwd,
|
||||
}),
|
||||
},
|
||||
apps: {
|
||||
/*
|
||||
jellyfin: {
|
||||
app: jellyfin.App(jellyfin.Params {
|
||||
namespace: namespace,
|
||||
name: "jellyfin",
|
||||
filePath: std.thisFile,
|
||||
// Defined in "dominion"
|
||||
configClaimName: "jellyfin-config",
|
||||
// Defined in "dominion"
|
||||
serialClaimName: "serial-lake",
|
||||
// Defined in "dominion"
|
||||
filmClaimName: "film-lake",
|
||||
// Defined in "dominion"
|
||||
transcodeClaimName: "jellyfin-transcode",
|
||||
}),
|
||||
ingress: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "jellyfin-ion",
|
||||
hosts: [
|
||||
"ion.cheapassbox.com",
|
||||
],
|
||||
serviceName: "jellyfin-vui",
|
||||
}),
|
||||
pvcs: {
|
||||
pvcJellyfinConfig: kube.RecoverableSimpleManyPvc(namespace, "jellyfin-config", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-287055fe-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsPath: "/volume3/fs/dominion-jellyfin-config-pvc-287055fe-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
pvcJellyfinTranscode: kube.RecoverableSimpleManyPvc(namespace, "jellyfin-transcode", "nfs-client", "200Gi", {
|
||||
volumeName: "pvc-2871f840-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsPath: "/volume3/fs/dominion-jellyfin-transcode-pvc-2871f840-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
// NOTE: These are different!
|
||||
pvcSerialLake: kube.RecoverableSimpleManyPvc(namespace, "serial-lake", "nfs-bulk", "160Gi", {
|
||||
volumeName: "pvc-2873b76a-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsPath: "/volume4/fs-bulk/dominion-serial-lake-pvc-2873b76a-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsServer: "apollo2.dominion.lan",
|
||||
}),
|
||||
pvcFilmLake: kube.RecoverableSimpleManyPvc(namespace, "film-lake", "nfs-bulk", "80Gi", {
|
||||
volumeName: "pvc-286ce6ea-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsPath: "/volume4/fs-bulk/dominion-film-lake-pvc-286ce6ea-b436-11e9-bad8-b8aeed7dc356",
|
||||
nfsServer: "apollo2.dominion.lan",
|
||||
}),
|
||||
},
|
||||
},
|
||||
*/
|
||||
freshrss: {
|
||||
configPvc: base.RecoverableSimplePvc(namespace, "freshrss-config", "nfs-client", "32Gi", {
|
||||
volumeName: "pvc-26b893fc-c3bf-11e9-8ccb-b8aeed7dc356",
|
||||
nfsPath: "/volume3/fs/dominion-freshrss-config-pvc-26b893fc-c3bf-11e9-8ccb-b8aeed7dc356",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
app: freshrss.App(freshrss.Params {
|
||||
namespace: namespace,
|
||||
name: "freshrss",
|
||||
filePath: std.thisFile,
|
||||
// Defined in "dominion"
|
||||
configClaimName: "freshrss-config",
|
||||
}),
|
||||
ingress: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "freshrss",
|
||||
hosts: [
|
||||
"rss.cheapassbox.com",
|
||||
],
|
||||
serviceName: "freshrss-ui",
|
||||
annotations: nginxIngress.KubeOauthProxyAnnotations,
|
||||
}),
|
||||
ingress2: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "freshrss-csbx",
|
||||
hosts: [
|
||||
"rss.csbx.dev",
|
||||
],
|
||||
serviceName: "freshrss-ui",
|
||||
annotations: nginxIngress.KubeCsbxOauthProxyAnnotations,
|
||||
}),
|
||||
},
|
||||
transmission2: {
|
||||
configPvc: base.RecoverableSimpleManyPvc(namespace, "transmission-config", "nfs-client", "50Mi", {
|
||||
volumeName: "pvc-3d93c19b-c177-11e9-8ccb-b8aeed7dc356",
|
||||
nfsPath: "/volume3/fs/dominion-transmission-config-pvc-3d93c19b-c177-11e9-8ccb-b8aeed7dc356",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
torrentFilesPvc: base.RecoverableSimpleManyPvc(namespace, "torrent-files", "nfs-client", "100Mi", {
|
||||
volumeName: "pvc-73528d8b-c177-11e9-8ccb-b8aeed7dc356",
|
||||
nfsPath: "/volume3/fs/dominion-torrent-files-pvc-73528d8b-c177-11e9-8ccb-b8aeed7dc356",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
incompleteDownloadsPvc: base.RecoverableSimpleManyPvc(namespace, "transmission-incomplete-downloads", "nfs-bulk", "100Gi", {
|
||||
volumeName: "pvc-1c1a00ff-b9a8-4f92-b3a7-70f81752141d",
|
||||
nfsPath: "/volume4/fs-bulk/dominion-transmission-incomplete-downloads-pvc-1c1a00ff-b9a8-4f92-b3a7-70f81752141d",
|
||||
nfsServer: "apollo2.dominion.lan",
|
||||
}),
|
||||
app: transmission.App(transmission.Params {
|
||||
namespace: namespace,
|
||||
name: "transmission2",
|
||||
filePath: std.thisFile,
|
||||
configClaimName: "transmission-config",
|
||||
incompleteDownloadsClaimName: "transmission-incomplete-downloads",
|
||||
downloadsClaimName: "lanraragi-content",
|
||||
torrentFilesClaimName: "torrent-files",
|
||||
// TODO(acmcarther): Import from central location
|
||||
dataNodePort: 32701,
|
||||
}),
|
||||
ingress: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "transmission",
|
||||
hosts: [
|
||||
"ex-transmission.cheapassbox.com",
|
||||
],
|
||||
serviceName: "transmission2-ui",
|
||||
annotations: nginxIngress.DominionOauthProxyAnnotations,
|
||||
}),
|
||||
ingress2: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "transmission-csbx",
|
||||
hosts: [
|
||||
"ex-transmission.csbx.dev",
|
||||
],
|
||||
serviceName: "transmission2-ui",
|
||||
annotations: nginxIngress.DominionCsbxOauthProxyAnnotations,
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"apiVersion": "tanka.dev/v1alpha1",
|
||||
"kind": "Environment",
|
||||
"metadata": {
|
||||
"name": "environments/dominion",
|
||||
"namespace": "environments/dominion/main.jsonnet"
|
||||
},
|
||||
"spec": {
|
||||
"apiServer": "https://k8s.dominion.lan:6443",
|
||||
"namespace": "dominion",
|
||||
"resourceDefaults": {},
|
||||
"expectVersions": {},
|
||||
"injectLabels": true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
load("@rules_jsonnet//jsonnet:jsonnet.bzl", "jsonnet_library", "jsonnet_to_json", "jsonnet_to_json_test")
|
||||
load("//tools:tanka.bzl", "tanka_environment")
|
||||
|
||||
jsonnet_to_json(
|
||||
name = "main",
|
||||
src = "main.jsonnet",
|
||||
outs = ["main.json"],
|
||||
data = [
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//k8s/configs/templates",
|
||||
"//experimental/users/acmcarther/k8s/configs/templates",
|
||||
],
|
||||
)
|
||||
|
||||
tanka_environment(
|
||||
name = "semantic-search",
|
||||
main = ":main",
|
||||
spec = "spec.json",
|
||||
)
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
local base = import "k8s/configs/base.libsonnet";
|
||||
local semanticSearch = import "experimental/users/acmcarther/k8s/configs/templates/semantic-search.libsonnet";
|
||||
local nginxIngress = import "k8s/configs/templates/core/network/nginx-ingress.libsonnet";
|
||||
|
||||
local namespace = "semantic-search";
|
||||
local appName = "semantic-search-server";
|
||||
|
||||
{
|
||||
namespace: {
|
||||
apiVersion: "v1",
|
||||
kind: "Namespace",
|
||||
metadata: {
|
||||
name: namespace,
|
||||
},
|
||||
},
|
||||
pvc: base.RecoverableSimpleManyPvc(namespace, appName + "-data", "nfs-client", "2Gi", {
|
||||
volumeName: "pvc-a10eadb8-b2a3-45b2-a50b-83ab11ae7f39",
|
||||
nfsPath: "/volume3/fs/semantic-search-semantic-search-server-data-pvc-a10eadb8-b2a3-45b2-a50b-83ab11ae7f39",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
apps: {
|
||||
server: semanticSearch.App(semanticSearch.Params {
|
||||
namespace: namespace,
|
||||
name: appName,
|
||||
filePath: std.thisFile,
|
||||
dataClaimName: appName + "-data",
|
||||
}),
|
||||
ingress: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: appName,
|
||||
hosts: [
|
||||
"search.csbx.dev",
|
||||
],
|
||||
serviceName: appName + "-ui",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"apiVersion": "tanka.dev/v1alpha1",
|
||||
"kind": "Environment",
|
||||
"metadata": {
|
||||
"name": "environments/semantic-search"
|
||||
},
|
||||
"spec": {
|
||||
"apiServer": "https://k8s.dominion.lan:6443",
|
||||
"namespace": "semantic-search",
|
||||
"resourceDefaults": {},
|
||||
"expectVersions": {},
|
||||
"injectLabels": true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
load("@rules_jsonnet//jsonnet:jsonnet.bzl", "jsonnet_library", "jsonnet_to_json", "jsonnet_to_json_test")
|
||||
load("//tools:tanka.bzl", "tanka_environment")
|
||||
|
||||
jsonnet_to_json(
|
||||
name = "main",
|
||||
src = "main.jsonnet",
|
||||
outs = ["main.json"],
|
||||
data = [
|
||||
"@helm_hashicorp_vault//:chart",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//k8s/configs/templates",
|
||||
],
|
||||
)
|
||||
|
||||
tanka_environment(
|
||||
name = "vault",
|
||||
main = ":main",
|
||||
spec = "spec.json",
|
||||
)
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
local base = import "k8s/configs/base.libsonnet";
|
||||
local nginxIngress = import "k8s/configs/templates/core/network/nginx-ingress.libsonnet";
|
||||
local vault = import "k8s/configs/templates/core/security/vault.libsonnet";
|
||||
|
||||
local namespace = "vault";
|
||||
local ctx = base.NewContext(base.helm);
|
||||
{
|
||||
namespace: {
|
||||
apiVersion: "v1",
|
||||
kind: "Namespace",
|
||||
metadata: {
|
||||
name: namespace,
|
||||
},
|
||||
},
|
||||
apps: {
|
||||
/*
|
||||
consul: consul.App(consul.Params {
|
||||
namespace: namespace,
|
||||
context: ctx,
|
||||
bootstrapTokenSecretName: "consul-bootstrap-acl-token",
|
||||
}),
|
||||
*/
|
||||
vault: vault.App(vault.Params {
|
||||
namespace: namespace,
|
||||
context: ctx,
|
||||
}),
|
||||
/*
|
||||
vaultIngress1: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "vault",
|
||||
hosts: [
|
||||
"vault.cheapassbox.com",
|
||||
],
|
||||
serviceName: "vault", # TODO
|
||||
annotations: nginxIngress.KubeOauthProxyAnnotations,
|
||||
}),
|
||||
*/
|
||||
vaultIngress2: nginxIngress.Ingress(nginxIngress.IngressParams {
|
||||
namespace: namespace,
|
||||
name: "vault-csbx",
|
||||
hosts: [
|
||||
"vault.csbx.dev",
|
||||
],
|
||||
serviceName: "vault-ui", # TODO
|
||||
servicePort: 8200,
|
||||
annotations: nginxIngress.KubeCsbxOauthProxyAnnotations,
|
||||
}),
|
||||
},
|
||||
volumes: {
|
||||
data0: base.RecoverableSimplePvc(namespace, "data-vault-0", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-0aa9f845-baef-476b-971f-8cd30932b874",
|
||||
nfsPath: "/volume3/fs/vault-data-vault-0-pvc-0aa9f845-baef-476b-971f-8cd30932b874",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
data1: base.RecoverableSimplePvc(namespace, "data-vault-1", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-90241eff-1ed4-49e0-87bb-8485cd0f6aca",
|
||||
nfsPath: "/volume3/fs/vault-data-vault-1-pvc-90241eff-1ed4-49e0-87bb-8485cd0f6aca",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
data2: base.RecoverableSimplePvc(namespace, "data-vault-2", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-5c23b9b5-3fbf-4898-9784-83d9bbef185c",
|
||||
nfsPath: "/volume3/fs/vault-data-vault-2-pvc-5c23b9b5-3fbf-4898-9784-83d9bbef185c",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
audit0: base.RecoverableSimplePvc(namespace, "audit-vault-0", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-1d037ee0-836c-4079-a96f-f61ed13c9626",
|
||||
nfsPath: "/volume3/fs/vault-audit-vault-0-pvc-1d037ee0-836c-4079-a96f-f61ed13c9626",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
audit1: base.RecoverableSimplePvc(namespace, "audit-vault-1", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-6f63b89d-b007-440a-adea-b503b885b914",
|
||||
nfsPath: "/volume3/fs/vault-audit-vault-1-pvc-6f63b89d-b007-440a-adea-b503b885b914",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
audit2: base.RecoverableSimplePvc(namespace, "audit-vault-2", "nfs-client", "10Gi", {
|
||||
volumeName: "pvc-44121280-3a8c-4252-abe2-95e177e78efc",
|
||||
nfsPath: "/volume3/fs/vault-audit-vault-2-pvc-44121280-3a8c-4252-abe2-95e177e78efc",
|
||||
nfsServer: "apollo1.dominion.lan",
|
||||
}),
|
||||
|
||||
},
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
"apiVersion": "tanka.dev/v1alpha1",
|
||||
"kind": "Environment",
|
||||
"metadata": {
|
||||
"name": "environments/vault",
|
||||
"namespace": "environments/vault/main.jsonnet"
|
||||
},
|
||||
"spec": {
|
||||
"apiServer": "https://k8s.dominion.lan:6443",
|
||||
"namespace": "vault",
|
||||
"resourceDefaults": {},
|
||||
"expectVersions": {},
|
||||
"injectLabels": true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
load("@rules_jsonnet//jsonnet:jsonnet.bzl", "jsonnet_library", "jsonnet_to_json", "jsonnet_to_json_test")
|
||||
|
||||
jsonnet_library(
|
||||
name = "templates",
|
||||
srcs = glob(include = ["**/*.libsonnet"]),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//k8s/configs:base",
|
||||
"//k8s/configs:images",
|
||||
"//k8s/configs/templates",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
local base = import "k8s/configs/base.libsonnet";
|
||||
|
||||
local Params = base.SimpleFieldStruct([
|
||||
"namespace",
|
||||
"name",
|
||||
"context",
|
||||
"values",
|
||||
]);
|
||||
|
||||
local App(params) = {
|
||||
# The chart is provided by the @helm_crossplane_crossplane repository.
|
||||
# Note: The path construction might need adjustment depending on how helm_deps handles the repo name.
|
||||
# In chartfile.yaml, repo name is 'crossplane'.
|
||||
local chartPath = "../../external/+helm_deps+helm_crossplane_crossplane",
|
||||
|
||||
app: params.context.helm.template(params.name, chartPath, {
|
||||
namespace: params.namespace,
|
||||
values: params.values,
|
||||
# Crossplane often needs includeCRDs: true or similar if it's not default in values.
|
||||
# But for helm template, it's usually handled by includeCRDs option in the helm function if supported
|
||||
# or just let helm handle it. Tanka's helm.template usually passes args to `helm template`.
|
||||
includeCRDs: true,
|
||||
})
|
||||
};
|
||||
|
||||
{
|
||||
Params: Params,
|
||||
App: App,
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
local kube = import "k8s/configs/base.libsonnet";
|
||||
local images = import "k8s/configs/images.libsonnet";
|
||||
local templates = import "k8s/configs/templates/templates.libsonnet";
|
||||
|
||||
local WebPort = 3000;
|
||||
|
||||
local Params = kube.simpleFieldStruct([
|
||||
"namespace",
|
||||
"name",
|
||||
"contentClaimName",
|
||||
"databaseClaimName",
|
||||
"thumbClaimName",
|
||||
"filePath",
|
||||
]) {
|
||||
image: images.Prod["difegue/lanraragi"],
|
||||
webPort: WebPort,
|
||||
gatekeeperSidecar: null,
|
||||
resources: {
|
||||
requests: {
|
||||
cpu: "1000m",
|
||||
memory: "1000Mi",
|
||||
},
|
||||
limits: {
|
||||
cpu: "2000m",
|
||||
memory: "2000Mi",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
local App(params) = {
|
||||
local nskube = kube.UsingNamespace(params.namespace),
|
||||
local selector = {
|
||||
name: params.name,
|
||||
phase: "prod",
|
||||
},
|
||||
local selectorMixin = {
|
||||
selector: selector
|
||||
},
|
||||
service: nskube.Service(params.name + '-ui') {
|
||||
spec+: kube.SvcUtil.BasicHttpClusterIpSpec(WebPort) {
|
||||
selector: selector
|
||||
}
|
||||
},
|
||||
deployment: nskube.Deployment(params.name) {
|
||||
metadata+: {
|
||||
annotations: templates.annotations(params.filePath, std.thisFile),
|
||||
},
|
||||
spec+: {
|
||||
strategy: kube.DeployUtil.SimpleRollingUpdate(),
|
||||
replicas: 1,
|
||||
selector: {
|
||||
matchLabels: selector,
|
||||
},
|
||||
template: {
|
||||
metadata: {
|
||||
labels: selector,
|
||||
annotations: templates.annotations(params.filePath, std.thisFile),
|
||||
},
|
||||
spec+: {
|
||||
imagePullSecrets: [
|
||||
{
|
||||
name: "docker-auth",
|
||||
}
|
||||
],
|
||||
containers: [
|
||||
{
|
||||
image: params.image,
|
||||
name: "lanraragi",
|
||||
ports: [
|
||||
kube.DeployUtil.ContainerPort("http", params.webPort),
|
||||
],
|
||||
resources: params.resources,
|
||||
readinessProbe: {
|
||||
httpGet: {
|
||||
path: "/",
|
||||
port: params.webPort,
|
||||
},
|
||||
initialDelaySeconds: 30,
|
||||
},
|
||||
|
||||
livenessProbe: {
|
||||
httpGet: {
|
||||
path: "/",
|
||||
port: params.webPort,
|
||||
},
|
||||
initialDelaySeconds: 30,
|
||||
periodSeconds: 15,
|
||||
failureThreshold: 10
|
||||
},
|
||||
args: [],
|
||||
volumeMounts: [
|
||||
kube.DeployUtil.VolumeMount("content", "/home/koyomi/lanraragi/content"),
|
||||
kube.DeployUtil.VolumeMount("database", "/home/koyomi/lanraragi/database"),
|
||||
kube.DeployUtil.VolumeMount("thumb", "/home/koyomi/lanraragi/thumb"),
|
||||
]
|
||||
},
|
||||
],
|
||||
volumes: [
|
||||
kube.DeployUtil.VolumeClaimRef("content", params.contentClaimName),
|
||||
kube.DeployUtil.VolumeClaimRef("database", params.databaseClaimName),
|
||||
kube.DeployUtil.VolumeClaimRef("thumb", params.thumbClaimName),
|
||||
],
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
WebPort: WebPort,
|
||||
Params: Params,
|
||||
App(params): App(params),
|
||||
}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
local kube = import "k8s/configs/base.libsonnet";
|
||||
local linuxserver = import "k8s/configs/templates/core/linuxserver.libsonnet";
|
||||
local images = import "k8s/configs/images.libsonnet";
|
||||
|
||||
local probe(delaySeconds) = {
|
||||
initialDelaySeconds: delaySeconds,
|
||||
periodSeconds: 20,
|
||||
tcpSocket: {
|
||||
port: "http",
|
||||
},
|
||||
};
|
||||
|
||||
local WebPort = 7860;
|
||||
|
||||
local Params = kube.simpleFieldStruct([
|
||||
"namespace",
|
||||
"name",
|
||||
"filePath",
|
||||
"storageClaimName",
|
||||
"outputClaimName",
|
||||
//"ingressHost",
|
||||
]) {
|
||||
labels: {},
|
||||
gatekeeperSidecar: null,
|
||||
lsParams: linuxserver.AppParams {
|
||||
name: $.name,
|
||||
namespace: $.namespace,
|
||||
filePath: $.filePath,
|
||||
templatePath: std.thisFile,
|
||||
baseAppName: "naifu",
|
||||
imageName: "naifu2",
|
||||
imagePullSecrets: ["regcred"],
|
||||
labels+: $.labels,
|
||||
gatekeeperSidecar: $.gatekeeperSidecar,
|
||||
isPrivileged: true,
|
||||
services: [
|
||||
linuxserver.Service {
|
||||
suffix: "ui",
|
||||
spec: {
|
||||
type: "ClusterIP",
|
||||
ports: [
|
||||
kube.SvcUtil.TCPServicePort("http", 80) {
|
||||
targetPort: WebPort
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
nodeSelector: {
|
||||
"gpu": "nvidia"
|
||||
},
|
||||
ports: [
|
||||
kube.DeployUtil.ContainerPort("http", WebPort),
|
||||
],
|
||||
env: linuxserver.Env {
|
||||
others: [
|
||||
kube.NameVal("CLI_ARGS", "--allow-code --ui-config-file /stable-diffusion-webui/models/Stable-diffusion/ui-config.json --styles-file /stable-diffusion-webui/models/Stable-diffusion/styles.csv --deepdanbooru"),
|
||||
kube.NameVal("NVIDIA_VISIBLE_DEVICES", "all"),
|
||||
//kube.NameVal("CLI_FLAGS", "--extra-models-cpu --optimized-turbo"),
|
||||
//--precision full --no-half
|
||||
//kube.NameVal("CLI_FLAGS", "--no-half"),
|
||||
//kube.NameVal("CUDA_VISIBLE_DEVICES", "0"),
|
||||
#kube.NameVal("TOKEN", "example-token"),
|
||||
]
|
||||
},
|
||||
args: [
|
||||
],
|
||||
pvcs: [
|
||||
linuxserver.Pvc{
|
||||
name: "naifu-storage",
|
||||
mountPath: "/data",
|
||||
bindName: $.storageClaimName,
|
||||
},
|
||||
linuxserver.Pvc{
|
||||
name: "naifu-output",
|
||||
mountPath: "/output",
|
||||
bindName: $.outputClaimName,
|
||||
},
|
||||
|
||||
],
|
||||
hostPaths: [
|
||||
linuxserver.HostPath{
|
||||
name: "nvidia-nvidia-uvm",
|
||||
hostPath: "/dev/nvidia-uvm",
|
||||
mountPath: "/dev/nvidia-uvm",
|
||||
},
|
||||
linuxserver.HostPath{
|
||||
name: "nvidia-nvidia0",
|
||||
hostPath: "/dev/nvidia0",
|
||||
mountPath: "/dev/nvidia0",
|
||||
},
|
||||
linuxserver.HostPath{
|
||||
name: "nvidia-nvidiactrl",
|
||||
hostPath: "/dev/nvidiactrl",
|
||||
mountPath: "/dev/nvidiactrl",
|
||||
},
|
||||
linuxserver.HostPath{
|
||||
name: "nvidia-drivers",
|
||||
hostPath: "/opt/drivers/nvidia",
|
||||
mountPath: "/usr/local/nvidia",
|
||||
},
|
||||
|
||||
],
|
||||
resources: {
|
||||
requests: {
|
||||
cpu: "1000m",
|
||||
memory: "12000Mi",
|
||||
},
|
||||
limits: {
|
||||
cpu: "4000m",
|
||||
memory: "24000Mi",
|
||||
},
|
||||
},
|
||||
//livenessProbe: probe(/*delaySeconds=*/60),
|
||||
//readinessProbe: probe(/*delaySeconds=*/60),
|
||||
},
|
||||
};
|
||||
|
||||
local App(params) = linuxserver.App(params.lsParams) {
|
||||
};
|
||||
|
||||
{
|
||||
WebPort: WebPort,
|
||||
Params: Params,
|
||||
App(params): App(params),
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
local kube = import "k8s/configs/base.libsonnet";
|
||||
local linuxserver = import "k8s/configs/templates/core/linuxserver.libsonnet";
|
||||
local images = import "k8s/configs/images.libsonnet";
|
||||
|
||||
local searchProbe(delaySeconds) = {
|
||||
initialDelaySeconds: delaySeconds,
|
||||
periodSeconds: 30,
|
||||
tcpSocket: {
|
||||
port: "http",
|
||||
},
|
||||
};
|
||||
|
||||
local WebPort = 8000;
|
||||
local DataDir = "/app/ai/data/vectordb";
|
||||
local ModelCacheDir = DataDir + "/models";
|
||||
|
||||
local Params = kube.simpleFieldStruct([
|
||||
"namespace",
|
||||
"name",
|
||||
"filePath",
|
||||
"dataClaimName",
|
||||
]) {
|
||||
labels: {},
|
||||
gatekeeperSidecar: null,
|
||||
lsParams: linuxserver.AppParams {
|
||||
name: $.name,
|
||||
namespace: $.namespace,
|
||||
filePath: $.filePath,
|
||||
templatePath: std.thisFile,
|
||||
baseAppName: "semantic-search",
|
||||
imageName: "semantic-search-server",
|
||||
labels+: $.labels,
|
||||
gatekeeperSidecar: $.gatekeeperSidecar,
|
||||
env+: linuxserver.Env {
|
||||
others: [
|
||||
kube.NameVal("TRANSFORMERS_CACHE", ModelCacheDir),
|
||||
],
|
||||
},
|
||||
services: [
|
||||
linuxserver.Service {
|
||||
suffix: "ui",
|
||||
spec: kube.SvcUtil.BasicHttpClusterIpSpec(WebPort)
|
||||
},
|
||||
],
|
||||
ports: [ kube.DeployUtil.ContainerPort("http", WebPort), ],
|
||||
pvcs: [
|
||||
linuxserver.Pvc {
|
||||
name: "data",
|
||||
mountPath: DataDir,
|
||||
bindName: $.dataClaimName,
|
||||
},
|
||||
],
|
||||
resources: {
|
||||
requests: {
|
||||
cpu: "100m",
|
||||
memory: "512Mi",
|
||||
},
|
||||
limits: {
|
||||
cpu: "500m",
|
||||
memory: "2Gi",
|
||||
},
|
||||
},
|
||||
livenessProbe: searchProbe(/*delaySeconds=*/60),
|
||||
readinessProbe: searchProbe(/*delaySeconds=*/60),
|
||||
},
|
||||
};
|
||||
|
||||
local App(params) =
|
||||
local baseApp = linuxserver.App(params.lsParams);
|
||||
baseApp {
|
||||
deployment+: {
|
||||
spec+: {
|
||||
template+: {
|
||||
spec+: {
|
||||
containers: [
|
||||
c { imagePullPolicy: "Always" }
|
||||
for c in super.containers
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
{
|
||||
Params: Params,
|
||||
WebPort: WebPort,
|
||||
App(params): App(params),
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
// A template for deploying a generic static website with Nginx.
|
||||
local kube = import "k8s/configs/base.libsonnet";
|
||||
local linuxserver = import "k8s/configs/templates/core/linuxserver.libsonnet";
|
||||
local images = import "k8s/configs/images.libsonnet";
|
||||
|
||||
local WebPort = 80;
|
||||
|
||||
local Params = kube.simpleFieldStruct([
|
||||
"namespace",
|
||||
"name",
|
||||
"filePath",
|
||||
"dataClaimName",
|
||||
]) {
|
||||
labels: {},
|
||||
gatekeeperSidecar: null,
|
||||
lsParams: linuxserver.AppParams {
|
||||
name: $.name,
|
||||
namespace: $.namespace,
|
||||
filePath: $.filePath,
|
||||
templatePath: std.thisFile,
|
||||
baseAppName: "static-site",
|
||||
imageName: "nginx:1.29.1-alpine",
|
||||
labels+: $.labels,
|
||||
gatekeeperSidecar: $.gatekeeperSidecar,
|
||||
services: [
|
||||
linuxserver.Service {
|
||||
suffix: "ui",
|
||||
spec: kube.SvcUtil.BasicHttpClusterIpSpec(WebPort)
|
||||
},
|
||||
],
|
||||
ports: [ kube.DeployUtil.ContainerPort("http", WebPort), ],
|
||||
pvcs: [
|
||||
linuxserver.Pvc{
|
||||
name: "static-content",
|
||||
mountPath: "/usr/share/nginx/html",
|
||||
bindName: $.dataClaimName,
|
||||
},
|
||||
],
|
||||
resources: {
|
||||
requests: {
|
||||
cpu: "10m",
|
||||
memory: "32Mi",
|
||||
},
|
||||
limits: {
|
||||
cpu: "50m",
|
||||
memory: "64Mi",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
local App(params) = linuxserver.App(params.lsParams);
|
||||
|
||||
{
|
||||
Params: Params,
|
||||
WebPort: WebPort,
|
||||
App(params): App(params),
|
||||
}
|
||||
97
experimental/users/acmcarther/llm/litellm/BUILD.bazel
Normal file
97
experimental/users/acmcarther/llm/litellm/BUILD.bazel
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
|
||||
py_binary(
|
||||
name = "litellm_agent",
|
||||
srcs = ["litellm_agent.py"],
|
||||
main = "litellm_agent.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
requirement("typer"),
|
||||
],
|
||||
)
|
||||
|
||||
py_pex_binary(
|
||||
name = "litellm_agent_pex",
|
||||
binary = ":litellm_agent",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_01_basic_connectivity",
|
||||
srcs = ["test_01_basic_connectivity.py"],
|
||||
main = "test_01_basic_connectivity.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_02_system_prompt",
|
||||
srcs = ["test_02_system_prompt.py"],
|
||||
main = "test_02_system_prompt.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_03_multi_turn",
|
||||
srcs = ["test_03_multi_turn.py"],
|
||||
main = "test_03_multi_turn.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_04_function_calling",
|
||||
srcs = ["test_04_function_calling.py"],
|
||||
main = "test_04_function_calling.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_04_debug_function_calling",
|
||||
srcs = ["test_04_debug_function_calling.py"],
|
||||
main = "test_04_debug_function_calling.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_05_error_handling",
|
||||
srcs = ["test_05_error_handling.py"],
|
||||
main = "test_05_error_handling.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "test_integration_comprehensive",
|
||||
srcs = ["test_integration_comprehensive.py"],
|
||||
main = "test_integration_comprehensive.py",
|
||||
visibility = ["//scripts:__pkg__"],
|
||||
deps = [
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
# Local Model Integration Findings & Requirements
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document summarizes the comprehensive testing of LiteLLM integration with local models (specifically Qwen3-Coder-30B) and provides requirements for retrofitting the agent harness to support local model execution.
|
||||
|
||||
## Test Results Overview
|
||||
|
||||
| Test | Status | Key Findings |
|
||||
|------|--------|--------------|
|
||||
| 1. Basic Connectivity | ✅ PASS | Stable connection, reliable responses |
|
||||
| 2. System Prompts | ✅ PASS | Excellent persona compliance and context adherence |
|
||||
| 3. Multi-Turn Conversation | ✅ PASS | Strong context retention and evolution capabilities |
|
||||
| 4. Function Calling | ⚠️ LIMITED | No native support, but manual JSON parsing works |
|
||||
| 5. Error Handling | ✅ PASS | Robust error recovery and graceful failure handling |
|
||||
|
||||
## Critical Technical Findings
|
||||
|
||||
### 1. Function Calling Limitation
|
||||
|
||||
**Issue**: The local model does not support native OpenAI-style function calling via the `tools` parameter.
|
||||
|
||||
**Impact**: This requires a fundamental change in how agent tools are invoked and processed.
|
||||
|
||||
**Solution**: Implement manual tool parsing with explicit JSON instruction prompts.
|
||||
|
||||
### 2. Response Quality
|
||||
|
||||
**Strengths**:
|
||||
- Excellent system prompt compliance
|
||||
- Strong conversational context management
|
||||
- High-quality code generation and analysis
|
||||
- Robust error handling
|
||||
|
||||
**Considerations**:
|
||||
- Slightly slower response times compared to cloud APIs (expected)
|
||||
- No rate limiting issues observed
|
||||
- Consistent behavior across multiple test runs
|
||||
|
||||
## Integration Requirements
|
||||
|
||||
### A. Harness Retrofit Requirements
|
||||
|
||||
#### 1. Tool System Redesign
|
||||
|
||||
**Current Implementation**:
|
||||
```python
|
||||
# Uses native function calling
|
||||
response = litellm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tool_definitions # Native OpenAI format
|
||||
)
|
||||
```
|
||||
|
||||
**Required Implementation**:
|
||||
```python
|
||||
# Manual tool parsing
|
||||
tool_prompt = f"""
|
||||
Available tools: {json.dumps(tool_definitions)}
|
||||
When you need to use a tool, respond with JSON:
|
||||
{{"tool": "tool_name", "parameters": {{...}}}}
|
||||
"""
|
||||
|
||||
messages.append({"role": "system", "content": tool_prompt})
|
||||
response = litellm.completion(model=model, messages=messages)
|
||||
|
||||
# Manual parsing of tool calls
|
||||
if is_tool_call(response.content):
|
||||
tool_call = json.loads(response.content)
|
||||
result = execute_tool(tool_call)
|
||||
```
|
||||
|
||||
#### 2. Session Management Updates
|
||||
|
||||
**File**: `ai/harness/session.py`
|
||||
|
||||
**Required Changes**:
|
||||
- Replace `qwen` CLI calls with LiteLLM direct API calls
|
||||
- Update environment variable handling for local model configuration
|
||||
- Implement tool call parsing and execution loop
|
||||
- Add retry logic for failed tool calls
|
||||
|
||||
#### 3. Configuration Management
|
||||
|
||||
**New Environment Variables Required**:
|
||||
```bash
|
||||
OPENAI_API_BASE=http://192.168.0.236:1234/v1
|
||||
OPENAI_API_KEY=lm-studio
|
||||
LOCAL_MODEL_NAME=openai/qwen3-coder-30b-a3b-instruct-mlx
|
||||
```
|
||||
|
||||
### B. Agent Framework Updates
|
||||
|
||||
#### 1. Tool Suite Integration
|
||||
|
||||
**Current**: Direct function calls via multitool framework
|
||||
**Required**: JSON-based tool invocation with response parsing
|
||||
|
||||
#### 2. Persona System
|
||||
|
||||
**Status**: ✅ Fully compatible
|
||||
- No changes needed to persona files
|
||||
- System prompts work excellently
|
||||
- Agent behavior remains consistent
|
||||
|
||||
#### 3. Context Management
|
||||
|
||||
**Status**: ✅ Enhanced performance
|
||||
- Better context retention than some cloud models
|
||||
- Improved conversation flow
|
||||
- Strong memory across turns
|
||||
|
||||
## Implementation Roadmap
|
||||
|
||||
### Phase 1: Core Integration (High Priority)
|
||||
1. **Update SessionManager** to use LiteLLM instead of qwen CLI
|
||||
2. **Implement Tool Parser** for manual function call handling
|
||||
3. **Add Configuration Layer** for local model settings
|
||||
4. **Create Fallback Mechanism** to cloud models if local unavailable
|
||||
|
||||
### Phase 2: Tool System Migration (High Priority)
|
||||
1. **Refactor Agent Multitool** to support JSON-based invocation
|
||||
2. **Update Tool Definitions** for consistent JSON schema
|
||||
3. **Implement Tool Response Processing** loop
|
||||
4. **Add Error Recovery** for failed tool executions
|
||||
|
||||
### Phase 3: Optimization (Medium Priority)
|
||||
1. **Implement Response Caching** for common queries
|
||||
2. **Add Request Batching** for improved performance
|
||||
3. **Create Monitoring Dashboard** for local model health
|
||||
4. **Optimize Prompt Engineering** for local model characteristics
|
||||
|
||||
### Phase 4: Advanced Features (Low Priority)
|
||||
1. **Model Switching** capabilities
|
||||
2. **Load Balancing** across multiple local instances
|
||||
3. **Performance Analytics** and optimization
|
||||
4. **Custom Tool Development** for local model strengths
|
||||
|
||||
## Risk Assessment
|
||||
|
||||
### High Risks
|
||||
1. **Tool Compatibility**: Manual parsing may not cover all edge cases
|
||||
2. **Performance**: Local model may be slower for complex queries
|
||||
3. **Maintenance**: Additional complexity in tool management
|
||||
|
||||
### Medium Risks
|
||||
1. **Reliability**: Local model availability depends on local infrastructure
|
||||
2. **Consistency**: Response patterns may differ from cloud models
|
||||
3. **Debugging**: More complex error scenarios to handle
|
||||
|
||||
### Low Risks
|
||||
1. **Quality**: Response quality is excellent and consistent
|
||||
2. **Integration**: LiteLLM provides stable abstraction layer
|
||||
3. **Scalability**: Can be scaled with additional hardware
|
||||
|
||||
## Testing Recommendations
|
||||
|
||||
### Continuous Testing
|
||||
1. **Automated Test Suite**: Run all 5 test scripts on each deployment
|
||||
2. **Performance Benchmarks**: Track response times and quality metrics
|
||||
3. **Load Testing**: Validate behavior under concurrent usage
|
||||
4. **Failover Testing**: Ensure fallback mechanisms work correctly
|
||||
|
||||
### Monitoring Requirements
|
||||
1. **Response Time Monitoring**: Alert on performance degradation
|
||||
2. **Error Rate Tracking**: Monitor for increased failure rates
|
||||
3. **Resource Usage**: Track CPU/memory usage of local model
|
||||
4. **Quality Metrics**: Periodic evaluation of response quality
|
||||
|
||||
## Conclusion
|
||||
|
||||
The local model integration is **highly viable** with excellent response quality and robust error handling. The primary challenge is the lack of native function calling support, which requires a well-designed manual parsing system.
|
||||
|
||||
**Recommendation**: Proceed with Phase 1 implementation immediately, as the core functionality is proven to work reliably. The tool system redesign (Phase 2) should be prioritized as it's critical for agent functionality.
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Review and approve** this integration plan
|
||||
2. **Begin Phase 1 implementation** with SessionManager updates
|
||||
3. **Develop tool parsing prototype** based on test findings
|
||||
4. **Create integration test suite** combining all 5 test scenarios
|
||||
5. **Plan gradual migration** strategy for existing agents
|
||||
|
||||
---
|
||||
|
||||
*Document generated: 2025-10-22*
|
||||
*Test suite location: `scripts/test_*.py`*
|
||||
*Integration target: `ai/harness/` framework*
|
||||
76
experimental/users/acmcarther/llm/litellm/TEMP_MEMORY.md
Normal file
76
experimental/users/acmcarther/llm/litellm/TEMP_MEMORY.md
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
# Session Summary: LiteLLM Integration
|
||||
|
||||
## Overview
|
||||
Successfully integrated LiteLLM as an alternative AI provider to Google GenAI in the agent_multitool, with LiteLLM now set as the default provider.
|
||||
|
||||
## Key Changes Made
|
||||
|
||||
### 1. Updated scripts/litellm_agent.py
|
||||
- **Purpose**: Standalone LiteLLM client for testing and direct usage
|
||||
- **Key Changes**:
|
||||
- Updated default model to `openai/qwen3-coder-30b-a3b-instruct-mlx`
|
||||
- Updated default API base to `http://192.168.0.236:1234/v1`
|
||||
- Updated default API key to `lm-studio`
|
||||
- Added `typer` dependency to BUILD.bazel
|
||||
- **Status**: Working and tested successfully
|
||||
- **Build Target**: `//scripts:litellm_agent` (py_binary) and `//scripts:litellm_agent_pex` (pex)
|
||||
|
||||
### 2. Updated ai/harness/tool_suites/agent_multitool/
|
||||
- **Purpose**: Main agent multitool with provider selection
|
||||
- **Files Modified**:
|
||||
- `BUILD.bazel`: Added `litellm` dependency
|
||||
- `commands/microagent.py`: Added LiteLLM integration
|
||||
|
||||
#### microagent.py Changes:
|
||||
- Added `import litellm`
|
||||
- Created `invoke_agent_with_litellm()` function (mirrors Google GenAI function)
|
||||
- Added CLI options:
|
||||
- `--provider`: Choose between "google" or "litellm" (default: "litellm")
|
||||
- `--api-base`: API base URL for LiteLLM (default: "http://192.168.0.236:1234/v1")
|
||||
- `--api-key`: API key for LiteLLM (default: "lm-studio")
|
||||
- Updated default model to `openai/qwen3-coder-30b-a3b-instruct-mlx`
|
||||
- Updated default provider to `litellm`
|
||||
- Modified invoke command logic to route to appropriate provider
|
||||
|
||||
## Build Targets
|
||||
- `//scripts:litellm_agent` - Standalone LiteLLM client (working)
|
||||
- `//scripts:litellm_agent_pex` - PEX version (has Python 3.12 constraint issues)
|
||||
- `//ai/harness/tool_suites/agent_multitool:main` - Main multitool with LiteLLM integration (working)
|
||||
|
||||
## Testing Results
|
||||
- ✅ Standalone litellm_agent works: `./bazel-bin/scripts/litellm_agent noop "Please return OK"` → "OK"
|
||||
- ✅ Agent multitool with LiteLLM works: `./bazel-bin/ai/harness/tool_suites/agent_multitool/main microagent invoke noop "Please return OK"` → "OK"
|
||||
- ✅ Provider selection works: `--provider litellm` and `--provider google` options functional
|
||||
- ❌ Google provider needs API key (expected)
|
||||
- ❌ PEX version has Python 3.12 constraint vs system 3.13
|
||||
|
||||
## Configuration Details
|
||||
- **LM Studio Server**: `http://192.168.0.236:1234/v1`
|
||||
- **Model**: `qwen3-coder-30b-a3b-instruct-mlx` (requires `openai/` prefix for LiteLLM)
|
||||
- **API Key**: `lm-studio`
|
||||
- **Default Provider**: LiteLLM (changed from Google)
|
||||
|
||||
## Next Steps / Future Work
|
||||
1. Fix PEX Python version constraint or use alternative packaging
|
||||
2. Consider adding environment variable support for API configuration
|
||||
3. Potentially add more LiteLLM provider options (different models/endpoints)
|
||||
4. Test with actual Google API key when available
|
||||
5. Consider making LiteLLM the only provider if Google GenAI is deprecated
|
||||
|
||||
## File Locations
|
||||
- Standalone script: `/Users/acmcarther/Projects/yesod/scripts/litellm_agent.py`
|
||||
- Multitool integration: `/Users/acmcarther/Projects/yesod/ai/harness/tool_suites/agent_multitool/commands/microagent.py`
|
||||
- BUILD files: Both locations have updated dependencies
|
||||
|
||||
## Dependencies Added
|
||||
- `litellm` - Main LiteLLM library
|
||||
- `typer` - CLI framework (was missing from scripts BUILD.bazel)
|
||||
|
||||
## Architecture Notes
|
||||
- Both implementations share similar function signatures for consistency
|
||||
- LiteLLM uses OpenAI-compatible message format (system/user roles)
|
||||
- Google GenAI uses concatenated prompt approach
|
||||
- Provider selection is handled at CLI level with routing logic
|
||||
|
||||
---
|
||||
*Session completed successfully - LiteLLM integration is functional and default*
|
||||
148
experimental/users/acmcarther/llm/litellm/litellm_agent.py
Normal file
148
experimental/users/acmcarther/llm/litellm/litellm_agent.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
import typer
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
import os
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
# --- Configuration ---
|
||||
PROJECT_ROOT = Path(os.getcwd())
|
||||
MICROAGENTS_DIR = PROJECT_ROOT / "ai" / "process" / "microagents"
|
||||
|
||||
def invoke_agent_with_litellm(
|
||||
agent_name: str,
|
||||
user_prompt: str,
|
||||
context_files: list[str] = [],
|
||||
model_name: str = "openai/gpt-3.5-turbo",
|
||||
api_base: str | None = None,
|
||||
api_key: str | None = None,
|
||||
system_prompt_path: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Invokes a microagent using the litellm library with OpenAI-compatible providers.
|
||||
"""
|
||||
# 1. Load API Key and Base URL
|
||||
dotenv_path = PROJECT_ROOT / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Use provided parameters or fall back to environment variables
|
||||
api_key = api_key or os.getenv('OPENAI_API_KEY')
|
||||
if not api_key:
|
||||
raise ValueError("OPENAI_API_KEY not found in environment or .env file.")
|
||||
|
||||
api_base = api_base or os.getenv('OPENAI_API_BASE')
|
||||
if not api_base:
|
||||
raise ValueError("OPENAI_API_BASE not found in environment or .env file.")
|
||||
|
||||
# 2. Load System Prompt
|
||||
if system_prompt_path:
|
||||
system_prompt_file = Path(system_prompt_path)
|
||||
else:
|
||||
system_prompt_file = MICROAGENTS_DIR / f"{agent_name.lower()}.md"
|
||||
|
||||
if not system_prompt_file.exists():
|
||||
raise FileNotFoundError(f"System prompt not found for agent '{agent_name}' at {system_prompt_file}")
|
||||
system_prompt = system_prompt_file.read_text()
|
||||
|
||||
# 3. Construct Full User Prompt
|
||||
full_user_prompt = ""
|
||||
if context_files:
|
||||
for file_path in context_files:
|
||||
try:
|
||||
p = Path(file_path)
|
||||
if not p.is_absolute():
|
||||
p = PROJECT_ROOT / p
|
||||
full_user_prompt += f"--- CONTEXT FILE: {p.name} ---\n"
|
||||
try:
|
||||
full_user_prompt += p.read_text() + "\n\n"
|
||||
except UnicodeDecodeError:
|
||||
full_user_prompt += "[Binary file - content not displayed]\n\n"
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Context file not found: {file_path}")
|
||||
except Exception as e:
|
||||
raise IOError(f"Error reading context file {file_path}: {e}")
|
||||
|
||||
full_user_prompt += "--- USER PROMPT ---\n"
|
||||
full_user_prompt += user_prompt
|
||||
|
||||
# 4. Construct Messages for litellm
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": full_user_prompt}
|
||||
]
|
||||
|
||||
# 5. Invoke Model using litellm
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
|
||||
# Extract the response content
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
return response.choices[0].message.content
|
||||
else:
|
||||
return f"Unexpected response format: {response}"
|
||||
|
||||
except Exception as e:
|
||||
return f"Error generating response: {e}\n\nFull response object:\n{response if 'response' in locals() else 'No response generated'}"
|
||||
|
||||
|
||||
@app.command()
|
||||
def invoke(
|
||||
agent_name: str = typer.Argument(..., help="The name of the agent to invoke (e.g., 'librarian')."),
|
||||
user_prompt: Optional[str] = typer.Argument(None, help="The user's prompt for the agent. Required if --prompt-file is not used."),
|
||||
prompt_file: Optional[Path] = typer.Option(None, "--prompt-file", "-p", help="Path to a file containing the user's prompt."),
|
||||
context_file: Optional[List[Path]] = typer.Option(None, "--context-file", "-c", help="Path to a context file to prepend to the prompt. Can be specified multiple times."),
|
||||
# TODO: acmcarther@ - Disabled to test summarization performance.
|
||||
#model: str = typer.Option("openai/qwen3-coder-30b-a3b-instruct-mlx", help="The name of the model to use (e.g., 'openai/gpt-4', 'openai/claude-3-sonnet')."),
|
||||
model: str = typer.Option("openai/gpt-oss-120b", help="The name of the model to use (e.g., 'openai/gpt-4', 'openai/claude-3-sonnet')."),
|
||||
api_base: Optional[str] = typer.Option("http://192.168.0.235:1234/v1", "--api-base", help="The API base URL for the OpenAI-compatible provider. Defaults to OPENAI_API_BASE env var."),
|
||||
api_key: Optional[str] = typer.Option("lm-studio", "--api-key", help="The API key for the provider. Defaults to OPENAI_API_KEY env var."),
|
||||
):
|
||||
"""
|
||||
Invokes a specialized, single-purpose 'microagent' using litellm with OpenAI-compatible providers.
|
||||
"""
|
||||
if not user_prompt and not prompt_file:
|
||||
print("Error: Either a user prompt or a prompt file must be provided.")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if prompt_file:
|
||||
try:
|
||||
prompt_text = prompt_file.read_text()
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Prompt file not found at {prompt_file}")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as e:
|
||||
print(f"Error reading prompt file: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
elif user_prompt:
|
||||
prompt_text = user_prompt
|
||||
else:
|
||||
return
|
||||
|
||||
context_paths = [str(p) for p in context_file] if context_file else []
|
||||
|
||||
try:
|
||||
response = invoke_agent_with_litellm(
|
||||
agent_name=agent_name,
|
||||
user_prompt=prompt_text,
|
||||
context_files=context_paths,
|
||||
model_name=model,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
print(response)
|
||||
except (ValueError, FileNotFoundError, IOError) as e:
|
||||
print(f"Error: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test 1: Basic API Connectivity Test
|
||||
Validates that we can establish a connection to the local model and get a response.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def test_basic_connectivity():
|
||||
"""
|
||||
Test basic connectivity to the local model using LiteLLM.
|
||||
This is the simplest possible test - just send a ping and get a pong.
|
||||
"""
|
||||
print("=== Test 1: Basic API Connectivity ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
print(f"API Base: {api_base}")
|
||||
print(f"Model: {model_name}")
|
||||
|
||||
# Simple test message
|
||||
try:
|
||||
print("\nSending simple ping...")
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "user", "content": "Respond with exactly: 'Connection successful'"}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=50
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"Response: {result}")
|
||||
|
||||
# Basic validation
|
||||
if "successful" in result.lower():
|
||||
print("✅ Basic connectivity test PASSED")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Response received but unexpected content")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Connection failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_basic_connectivity()
|
||||
sys.exit(0 if success else 1)
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test 2: Single-Turn Conversation with System Prompt
|
||||
Validates that we can properly set system prompts and get contextually appropriate responses.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def test_system_prompt():
|
||||
"""
|
||||
Test that system prompts are properly respected and the model responds
|
||||
in character according to the system prompt.
|
||||
"""
|
||||
print("=== Test 2: System Prompt Compliance ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
print(f"API Base: {api_base}")
|
||||
print(f"Model: {model_name}")
|
||||
|
||||
# Test with a specific persona
|
||||
system_prompt = """
|
||||
You are a medieval alchemist named Magnus. You always speak in a formal, archaic tone
|
||||
and refer to modern concepts as if they were mystical alchemical processes.
|
||||
You believe programming is a form of transmutation and code is the philosopher's stone.
|
||||
Keep responses brief but in character.
|
||||
"""
|
||||
|
||||
user_prompt = "Explain what a function is in programming."
|
||||
|
||||
try:
|
||||
print("\nTesting system prompt compliance...")
|
||||
print(f"System: {system_prompt.strip()}")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nMagnus: {result}")
|
||||
|
||||
# Validate persona compliance
|
||||
alchemy_keywords = ['transmut', 'philosoph', 'stone', 'mystic', 'arcane', 'alchem']
|
||||
formal_keywords = ['indeed', 'verily', 'hark', 'pray', 'thus']
|
||||
|
||||
has_alchemy = any(keyword in result.lower() for keyword in alchemy_keywords)
|
||||
is_formal = len(result.split()) > 10 and not any(word in result.lower() for word in ['lol', 'yeah', 'ok'])
|
||||
|
||||
if has_alchemy or is_formal:
|
||||
print("✅ System prompt test PASSED - Model responded in character")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Response received but may not fully comply with system prompt")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ System prompt test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_agent_persona_simulation():
|
||||
"""
|
||||
Test using an actual agent persona file from our microagents directory.
|
||||
This simulates how we'd use the system in practice.
|
||||
"""
|
||||
print("\n=== Test 2b: Agent Persona Simulation ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
# Try to load an actual agent persona
|
||||
microagents_dir = project_root / "ai" / "process" / "microagents"
|
||||
persona_file = microagents_dir / "librarian.md"
|
||||
|
||||
if not persona_file.exists():
|
||||
print(f"⚠️ Persona file not found at {persona_file}, using fallback")
|
||||
system_prompt = "You are a helpful librarian who organizes information and provides structured responses."
|
||||
else:
|
||||
system_prompt = persona_file.read_text()
|
||||
print(f"Loaded persona from: {persona_file}")
|
||||
|
||||
user_prompt = "How should I organize documentation for a software project?"
|
||||
|
||||
try:
|
||||
print("\nTesting agent persona simulation...")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=300
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nLibrarian: {result}")
|
||||
|
||||
# Basic validation - should be structured and helpful
|
||||
if len(result) > 50 and ('organiz' in result.lower() or 'document' in result.lower()):
|
||||
print("✅ Agent persona simulation PASSED")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Response received but may not be fully in character")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Agent persona simulation failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success1 = test_system_prompt()
|
||||
success2 = test_agent_persona_simulation()
|
||||
|
||||
overall_success = success1 and success2
|
||||
print(f"\n=== Test 2 Summary ===")
|
||||
print(f"System Prompt Compliance: {'✅ PASS' if success1 else '❌ FAIL'}")
|
||||
print(f"Agent Persona Simulation: {'✅ PASS' if success2 else '❌ FAIL'}")
|
||||
print(f"Overall: {'✅ PASS' if overall_success else '❌ FAIL'}")
|
||||
|
||||
sys.exit(0 if overall_success else 1)
|
||||
184
experimental/users/acmcarther/llm/litellm/test_03_multi_turn.py
Normal file
184
experimental/users/acmcarther/llm/litellm/test_03_multi_turn.py
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test 3: Multi-Turn Conversation with Context Management
|
||||
Validates that the model can maintain context across multiple conversation turns.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def test_context_retention():
|
||||
"""
|
||||
Test that the model can remember information from previous turns
|
||||
and reference it appropriately in subsequent responses.
|
||||
"""
|
||||
print("=== Test 3: Multi-Turn Context Management ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
print(f"API Base: {api_base}")
|
||||
print(f"Model: {model_name}")
|
||||
|
||||
# Multi-turn conversation
|
||||
system_prompt = "You are a helpful programming assistant. Remember details from our conversation and reference them when relevant."
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I'm working on a Python project called 'DataAnalyzer' that processes CSV files."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I understand you're working on a Python project called 'DataAnalyzer' for processing CSV files. That sounds like a data processing application. What specific features are you planning to implement?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to add data validation and export functionality. Can you suggest the best approach?"
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
print("\nTesting multi-turn context retention...")
|
||||
|
||||
# Build the message list with system prompt
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
messages.extend(conversation)
|
||||
|
||||
print("\nConversation so far:")
|
||||
for i, msg in enumerate(conversation):
|
||||
role = msg["role"].upper()
|
||||
content = msg["content"][:100] + "..." if len(msg["content"]) > 100 else msg["content"]
|
||||
print(f"{role}: {content}")
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=300
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nAssistant: {result}")
|
||||
|
||||
# Validate context retention
|
||||
context_keywords = ['dataanalyzer', 'csv', 'python', 'validation', 'export']
|
||||
context_retained = any(keyword.lower() in result.lower() for keyword in context_keywords)
|
||||
|
||||
if context_retained:
|
||||
print("✅ Context retention test PASSED - Model referenced previous conversation")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Response received but may not fully retain context")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Context retention test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_context_evolution():
|
||||
"""
|
||||
Test that the model can handle evolving context and update its understanding
|
||||
as new information is provided.
|
||||
"""
|
||||
print("\n=== Test 3b: Context Evolution ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
system_prompt = "You are a software architect. Track project requirements and update your recommendations as the scope evolves."
|
||||
|
||||
# Evolving conversation where requirements change
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "I need to build a simple todo list web app with basic CRUD operations."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "For a simple todo list web app with CRUD operations, I'd recommend: 1) Frontend: React or Vue.js for the UI, 2) Backend: Node.js/Express or Python Flask, 3) Database: SQLite for simplicity. This gives you a clean separation of concerns and easy deployment options."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Actually, I need to scale this to support 10,000 concurrent users with real-time collaboration features."
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "That's a significant scale increase! For 10,000 concurrent users with real-time collaboration, we need to rethink the architecture: 1) Frontend: React with WebSocket connections, 2) Backend: Node.js with Socket.IO for real-time features, 3) Database: PostgreSQL with connection pooling, 4) Caching layer: Redis for session management and real-time data sync. We'll also need load balancing and horizontal scaling capabilities."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the most critical component I should implement first?"
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
print("\nTesting context evolution...")
|
||||
|
||||
# Build the message list with system prompt
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
messages.extend(conversation)
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=250
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nArchitect: {result}")
|
||||
|
||||
# Validate that the response considers the scaled requirements
|
||||
scale_keywords = ['scal', 'concurrent', 'real-time', 'websocket', 'redis', 'load balanc']
|
||||
evolution_context = any(keyword.lower() in result.lower() for keyword in scale_keywords)
|
||||
|
||||
if evolution_context:
|
||||
print("✅ Context evolution test PASSED - Model adapted to evolved requirements")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Response received but may not fully account for evolved context")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Context evolution test failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success1 = test_context_retention()
|
||||
success2 = test_context_evolution()
|
||||
|
||||
overall_success = success1 and success2
|
||||
print(f"\n=== Test 3 Summary ===")
|
||||
print(f"Context Retention: {'✅ PASS' if success1 else '❌ FAIL'}")
|
||||
print(f"Context Evolution: {'✅ PASS' if success2 else '❌ FAIL'}")
|
||||
print(f"Overall: {'✅ PASS' if overall_success else '❌ FAIL'}")
|
||||
|
||||
sys.exit(0 if overall_success else 1)
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Debug version to understand what the model actually returns for function calling prompts.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def debug_function_calling():
|
||||
"""
|
||||
Debug what the model actually returns when prompted for function calls.
|
||||
"""
|
||||
print("=== Debug: Function Calling Response Analysis ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {"type": "string", "description": "The absolute path to the file to read"}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Test 1: Standard function calling prompt
|
||||
print("\n--- Test 1: Standard Function Calling ---")
|
||||
system_prompt = """
|
||||
You are an AI assistant with access to tools. When you need to use a tool, respond with ONLY a JSON object
|
||||
containing the function call. Do not include any explanatory text.
|
||||
"""
|
||||
|
||||
user_prompt = "Please read the file at /Users/acmcarther/Projects/infra2/README.md"
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"Raw response: '{result}'")
|
||||
print(f"Response type: {type(result)}")
|
||||
print(f"Response length: {len(result) if result else 0}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
# Test 2: More explicit instruction
|
||||
print("\n--- Test 2: Explicit JSON Instruction ---")
|
||||
system_prompt2 = """
|
||||
You are an AI assistant. When asked to perform actions, respond with a JSON object in this exact format:
|
||||
{"tool": "tool_name", "parameters": {"param": "value"}}
|
||||
|
||||
Available tools: read_file, list_directory
|
||||
|
||||
Example response: {"tool": "read_file", "parameters": {"file_path": "/path/to/file"}}
|
||||
|
||||
Do not include any text other than the JSON.
|
||||
"""
|
||||
|
||||
try:
|
||||
response2 = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt2},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response2, 'choices') and len(response2.choices) > 0:
|
||||
result2 = response2.choices[0].message.content
|
||||
print(f"Raw response: '{result2}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
# Test 3: Check if model understands tools at all
|
||||
print("\n--- Test 3: Tool Awareness Check ---")
|
||||
user_prompt2 = "What tools do you have access to? Please list them."
|
||||
|
||||
try:
|
||||
response3 = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are an AI assistant with access to tools."},
|
||||
{"role": "user", "content": user_prompt2}
|
||||
],
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response3, 'choices') and len(response3.choices) > 0:
|
||||
result3 = response3.choices[0].message.content
|
||||
print(f"Tool awareness response: '{result3}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
debug_function_calling()
|
||||
|
|
@ -0,0 +1,329 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test 4: Function Calling / Tool Usage Simulation
|
||||
Validates that the model can understand and generate function calls in the expected format.
|
||||
This is critical for agent tool integration.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def test_function_calling_basic():
|
||||
"""
|
||||
Test basic function calling capability - can the model generate proper
|
||||
function call JSON when prompted with available tools.
|
||||
"""
|
||||
print("=== Test 4: Function Calling / Tool Usage ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
print(f"API Base: {api_base}")
|
||||
print(f"Model: {model_name}")
|
||||
|
||||
# Define available tools (similar to our agent framework)
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read the contents of a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the file to read"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "list_directory",
|
||||
"description": "List the contents of a directory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "The absolute path to the directory"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
system_prompt = """
|
||||
You are an AI assistant with access to tools. When you need to use a tool, respond with ONLY a JSON object
|
||||
containing the function call. Do not include any explanatory text. The format should be:
|
||||
{"name": "function_name", "arguments": {"param": "value"}}
|
||||
|
||||
If you don't need to use a tool, respond normally.
|
||||
"""
|
||||
|
||||
user_prompt = "Please read the file at /Users/acmcarther/Projects/yesod/README.md"
|
||||
|
||||
try:
|
||||
print("\nTesting basic function calling...")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nAssistant: {result}")
|
||||
|
||||
# Try to parse as JSON function call
|
||||
try:
|
||||
function_call = json.loads(result)
|
||||
if "name" in function_call and "arguments" in function_call:
|
||||
print(f"✅ Function call generated: {function_call['name']}")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ JSON response but not in expected function call format")
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
print("⚠️ Response is not valid JSON - model may not support function calling")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Function calling test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_tool_selection():
|
||||
"""
|
||||
Test that the model can select appropriate tools based on user requests.
|
||||
"""
|
||||
print("\n=== Test 4b: Tool Selection ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_files",
|
||||
"description": "Search for files matching a pattern",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Search pattern"},
|
||||
"directory": {"type": "string", "description": "Directory to search in"}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_command",
|
||||
"description": "Execute a shell command",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "Command to execute"},
|
||||
"directory": {"type": "string", "description": "Working directory"}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
system_prompt = """
|
||||
You are an AI assistant with access to tools. When you need to use a tool, respond with ONLY a JSON object
|
||||
containing the function call. Do not include any explanatory text.
|
||||
"""
|
||||
|
||||
user_prompt = "Find all Python files in the current directory and then run git status"
|
||||
|
||||
try:
|
||||
print("\nTesting tool selection...")
|
||||
print(f"User: {user_prompt}")
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
],
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nAssistant: {result}")
|
||||
|
||||
# Check if it's trying to use the right tool
|
||||
try:
|
||||
function_call = json.loads(result)
|
||||
if "name" in function_call:
|
||||
tool_name = function_call["name"]
|
||||
if tool_name in ["search_files", "run_command"]:
|
||||
print(f"✅ Appropriate tool selected: {tool_name}")
|
||||
return True
|
||||
else:
|
||||
print(f"⚠️ Tool selected but may not be optimal: {tool_name}")
|
||||
return False
|
||||
else:
|
||||
print("⚠️ JSON response but missing 'name' field")
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
print("⚠️ Response is not valid JSON")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Tool selection test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_function_calling_with_context():
|
||||
"""
|
||||
Test function calling in a conversational context where the model
|
||||
should use tools based on previous conversation context.
|
||||
"""
|
||||
print("\n=== Test 4c: Function Calling with Context ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
# Configuration
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {"type": "string", "description": "Path to the file"},
|
||||
"content": {"type": "string", "description": "Content to write"}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
system_prompt = """
|
||||
You are an AI assistant with access to tools. When you need to use a tool, respond with ONLY a JSON object
|
||||
containing the function call. Do not include any explanatory text.
|
||||
"""
|
||||
|
||||
conversation = [
|
||||
{"role": "user", "content": "I need to create a simple Python script that prints 'Hello World'"},
|
||||
{"role": "assistant", "content": "I'll help you create a Python script that prints 'Hello World'. Let me write it to a file for you."},
|
||||
{"role": "user", "content": "Please save it as hello.py in the current directory"}
|
||||
]
|
||||
|
||||
try:
|
||||
print("\nTesting function calling with context...")
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
messages.extend(conversation)
|
||||
|
||||
response = litellm.completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=200
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
print(f"\nAssistant: {result}")
|
||||
|
||||
try:
|
||||
function_call = json.loads(result)
|
||||
if "name" in function_call and function_call["name"] == "write_file":
|
||||
args = function_call.get("arguments", {})
|
||||
if "hello.py" in args.get("file_path", ""):
|
||||
print("✅ Context-aware function call generated correctly")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Function call generated but with incorrect arguments")
|
||||
return False
|
||||
else:
|
||||
print("⚠️ Function call generated but not the expected tool")
|
||||
return False
|
||||
except json.JSONDecodeError:
|
||||
print("⚠️ Response is not valid JSON")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected response format: {response}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Context-aware function calling test failed: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success1 = test_function_calling_basic()
|
||||
success2 = test_tool_selection()
|
||||
success3 = test_function_calling_with_context()
|
||||
|
||||
overall_success = success1 and success2 and success3
|
||||
print(f"\n=== Test 4 Summary ===")
|
||||
print(f"Basic Function Calling: {'✅ PASS' if success1 else '❌ FAIL'}")
|
||||
print(f"Tool Selection: {'✅ PASS' if success2 else '❌ FAIL'}")
|
||||
print(f"Context-Aware Function Calling: {'✅ PASS' if success3 else '❌ FAIL'}")
|
||||
print(f"Overall: {'✅ PASS' if overall_success else '❌ FAIL'}")
|
||||
|
||||
if not overall_success:
|
||||
print("\n⚠️ NOTE: Function calling support may be limited in this model.")
|
||||
print("This could impact agent tool integration capabilities.")
|
||||
|
||||
sys.exit(0 if overall_success else 1)
|
||||
|
|
@ -0,0 +1,240 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test 5: Error Handling and Recovery
|
||||
Validates that the integration can handle various error conditions gracefully
|
||||
and provides meaningful feedback for debugging.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def test_invalid_api_endpoint():
|
||||
"""
|
||||
Test behavior when the API endpoint is unreachable.
|
||||
"""
|
||||
print("=== Test 5: Error Handling and Recovery ===")
|
||||
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
print("\n--- Test 5a: Invalid API Endpoint ---")
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="openai/qwen3-coder-30b-a3b-instruct-mlx",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
api_key="test-key",
|
||||
api_base="http://invalid-endpoint:1234/v1",
|
||||
max_tokens=50
|
||||
)
|
||||
print("❌ Expected connection error but got response")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✅ Correctly caught error: {type(e).__name__}: {str(e)[:100]}...")
|
||||
return True
|
||||
|
||||
def test_invalid_model_name():
|
||||
"""
|
||||
Test behavior with an invalid model name.
|
||||
"""
|
||||
print("\n--- Test 5b: Invalid Model Name ---")
|
||||
|
||||
# Use valid endpoint but invalid model
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="invalid-model-name",
|
||||
messages=[
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=50
|
||||
)
|
||||
print("❌ Expected model error but got response")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✅ Correctly caught error: {type(e).__name__}: {str(e)[:100]}...")
|
||||
return True
|
||||
|
||||
def test_timeout_handling():
|
||||
"""
|
||||
Test behavior with very short timeout.
|
||||
"""
|
||||
print("\n--- Test 5c: Timeout Handling ---")
|
||||
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = litellm.completion(
|
||||
model="openai/qwen3-coder-30b-a3b-instruct-mlx",
|
||||
messages=[
|
||||
{"role": "user", "content": "Write a very long detailed story"}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=1000,
|
||||
timeout=1 # Very short timeout
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
print(f"⚠️ Request completed in {elapsed:.2f}s (timeout may not be supported)")
|
||||
return True # Pass if timeout isn't supported
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
if "timeout" in str(e).lower():
|
||||
print(f"✅ Correctly timed out after {elapsed:.2f}s")
|
||||
else:
|
||||
print(f"⚠️ Got different error (may be expected): {type(e).__name__}")
|
||||
return True
|
||||
|
||||
def test_malformed_request():
|
||||
"""
|
||||
Test behavior with malformed request data.
|
||||
"""
|
||||
print("\n--- Test 5d: Malformed Request ---")
|
||||
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
|
||||
try:
|
||||
# Test with empty messages
|
||||
response = litellm.completion(
|
||||
model="openai/qwen3-coder-30b-a3b-instruct-mlx",
|
||||
messages=[], # Empty messages
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=50
|
||||
)
|
||||
print("❌ Expected validation error but got response")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"✅ Correctly caught error: {type(e).__name__}: {str(e)[:100]}...")
|
||||
return True
|
||||
|
||||
def test_recovery_after_error():
|
||||
"""
|
||||
Test that the system can recover after an error and continue working.
|
||||
"""
|
||||
print("\n--- Test 5e: Recovery After Error ---")
|
||||
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
|
||||
# First, make a request that will fail
|
||||
try:
|
||||
litellm.completion(
|
||||
model="invalid-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except:
|
||||
pass # Expected to fail
|
||||
|
||||
# Now try a valid request
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="openai/qwen3-coder-30b-a3b-instruct-mlx",
|
||||
messages=[
|
||||
{"role": "user", "content": "Respond with: 'Recovery successful'"}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=50
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
if "recovery" in result.lower() or "successful" in result.lower():
|
||||
print("✅ Recovery test PASSED - System recovered after error")
|
||||
return True
|
||||
else:
|
||||
print(f"⚠️ Got response but unexpected content: {result}")
|
||||
return False
|
||||
else:
|
||||
print("❌ Unexpected response format")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Recovery failed: {e}")
|
||||
return False
|
||||
|
||||
def test_rate_limiting():
|
||||
"""
|
||||
Test behavior with multiple rapid requests.
|
||||
"""
|
||||
print("\n--- Test 5f: Rate Limiting ---")
|
||||
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
# Make 5 rapid requests
|
||||
for i in range(5):
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model="openai/qwen3-coder-30b-a3b-instruct-mlx",
|
||||
messages=[
|
||||
{"role": "user", "content": f"Request {i+1}: Respond with 'OK'"}
|
||||
],
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_tokens=10
|
||||
)
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
print(f"Request {i+1} failed: {type(e).__name__}")
|
||||
|
||||
print(f"Rapid requests: {success_count} successful, {error_count} failed")
|
||||
|
||||
if success_count >= 3: # At least some should succeed
|
||||
print("✅ Rate limiting test PASSED - System handled rapid requests")
|
||||
return True
|
||||
else:
|
||||
print("⚠️ Too many failures - may indicate rate limiting issues")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [
|
||||
test_invalid_api_endpoint,
|
||||
test_invalid_model_name,
|
||||
test_timeout_handling,
|
||||
test_malformed_request,
|
||||
test_recovery_after_error,
|
||||
test_rate_limiting
|
||||
]
|
||||
|
||||
results = []
|
||||
for test in tests:
|
||||
try:
|
||||
result = test()
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
print(f"❌ Test {test.__name__} crashed: {e}")
|
||||
results.append(False)
|
||||
|
||||
overall_success = sum(results) >= len(results) * 0.7 # 70% pass rate
|
||||
|
||||
print(f"\n=== Test 5 Summary ===")
|
||||
for i, (test, result) in enumerate(zip(tests, results)):
|
||||
status = "✅ PASS" if result else "❌ FAIL"
|
||||
print(f"{test.__name__}: {status}")
|
||||
|
||||
print(f"Overall: {'✅ PASS' if overall_success else '❌ FAIL'}")
|
||||
print(f"Passed: {sum(results)}/{len(results)} tests")
|
||||
|
||||
sys.exit(0 if overall_success else 1)
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Integration Test
|
||||
Combines all validated capabilities into a single end-to-end test that simulates
|
||||
real agent usage patterns with the local model.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
class LocalModelAgent:
|
||||
"""
|
||||
Simulates the proposed agent harness integration with local model.
|
||||
This demonstrates how the retrofitted system would work.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Load environment
|
||||
project_root = Path(__file__).parent.parent
|
||||
dotenv_path = project_root / '.env'
|
||||
load_dotenv(dotenv_path=dotenv_path)
|
||||
|
||||
self.api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
self.api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
self.model_name = "openai/qwen3-coder-30b-a3b-instruct-mlx"
|
||||
|
||||
# Define available tools (manual function calling)
|
||||
self.tools = {
|
||||
"read_file": {
|
||||
"description": "Read the contents of a file",
|
||||
"parameters": {"file_path": "string (absolute path)"}
|
||||
},
|
||||
"list_directory": {
|
||||
"description": "List contents of a directory",
|
||||
"parameters": {"path": "string (absolute path)"}
|
||||
},
|
||||
"write_file": {
|
||||
"description": "Write content to a file",
|
||||
"parameters": {"file_path": "string", "content": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
def _create_tool_prompt(self):
|
||||
"""
|
||||
Creates the system prompt for manual tool usage.
|
||||
"""
|
||||
tools_json = json.dumps(self.tools, indent=2)
|
||||
return f"""
|
||||
You are an AI assistant with access to the following tools:
|
||||
|
||||
{tools_json}
|
||||
|
||||
When you need to use a tool, respond with ONLY a JSON object in this format:
|
||||
{{"tool": "tool_name", "parameters": {{"param_name": "value"}}}}
|
||||
|
||||
Do not include any explanatory text. If you don't need to use a tool, respond normally.
|
||||
"""
|
||||
|
||||
def _is_tool_call(self, response):
|
||||
"""
|
||||
Checks if a response is a tool call.
|
||||
"""
|
||||
try:
|
||||
parsed = json.loads(response.strip())
|
||||
return isinstance(parsed, dict) and "tool" in parsed
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
|
||||
def _execute_tool(self, tool_call):
|
||||
"""
|
||||
Simulates tool execution (in real implementation, this would call actual tools).
|
||||
"""
|
||||
tool_name = tool_call["tool"]
|
||||
parameters = tool_call.get("parameters", {})
|
||||
|
||||
if tool_name == "read_file":
|
||||
file_path = parameters.get("file_path", "")
|
||||
return f"[SIMULATED] Read file: {file_path} - Content would be displayed here"
|
||||
|
||||
elif tool_name == "list_directory":
|
||||
path = parameters.get("path", "")
|
||||
return f"[SIMULATED] Directory listing for: {path} - Files would be listed here"
|
||||
|
||||
elif tool_name == "write_file":
|
||||
file_path = parameters.get("file_path", "")
|
||||
content = parameters.get("content", "")[:50] + "..."
|
||||
return f"[SIMULATED] Wrote to {file_path}: {content}"
|
||||
|
||||
else:
|
||||
return f"[ERROR] Unknown tool: {tool_name}"
|
||||
|
||||
def converse(self, user_message, conversation_history=None):
|
||||
"""
|
||||
Handles a conversation turn with manual tool calling support.
|
||||
"""
|
||||
if conversation_history is None:
|
||||
conversation_history = []
|
||||
|
||||
# Build messages
|
||||
messages = [
|
||||
{"role": "system", "content": self._create_tool_prompt()},
|
||||
*conversation_history,
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
try:
|
||||
response = litellm.completion(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
api_base=self.api_base,
|
||||
max_tokens=300
|
||||
)
|
||||
|
||||
if hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
result = response.choices[0].message.content
|
||||
|
||||
# Check if this is a tool call
|
||||
if self._is_tool_call(result):
|
||||
tool_call = json.loads(result.strip())
|
||||
tool_result = self._execute_tool(tool_call)
|
||||
|
||||
# Continue conversation with tool result
|
||||
messages.append({"role": "assistant", "content": result})
|
||||
messages.append({"role": "user", "content": f"Tool result: {tool_result}"})
|
||||
|
||||
# Get final response
|
||||
final_response = litellm.completion(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
api_key=self.api_key,
|
||||
api_base=self.api_base,
|
||||
max_tokens=300
|
||||
)
|
||||
|
||||
if hasattr(final_response, 'choices') and len(final_response.choices) > 0:
|
||||
return final_response.choices[0].message.content, True # True = tool was used
|
||||
|
||||
return result, False # False = no tool used
|
||||
|
||||
except Exception as e:
|
||||
return f"Error: {e}", False
|
||||
|
||||
def test_comprehensive_integration():
|
||||
"""
|
||||
Runs a comprehensive integration test simulating real agent usage.
|
||||
"""
|
||||
print("=== Comprehensive Integration Test ===")
|
||||
print("Simulating retrofitted agent harness with local model...\n")
|
||||
|
||||
agent = LocalModelAgent()
|
||||
conversation_history = []
|
||||
tools_used = 0
|
||||
|
||||
# Test scenarios
|
||||
test_scenarios = [
|
||||
{
|
||||
"name": "System Prompt Compliance",
|
||||
"message": "Hi, I need help organizing my project documentation. Can you act as a documentation specialist and give me advice?"
|
||||
},
|
||||
{
|
||||
"name": "Tool Usage - File Reading",
|
||||
"message": "Please read the file at /Users/acmcarther/Projects/yesod/README.md to understand the project structure."
|
||||
},
|
||||
{
|
||||
"name": "Context Retention",
|
||||
"message": "Based on what you just read, what do you think the main purpose of this project is?"
|
||||
},
|
||||
{
|
||||
"name": "Tool Usage - Directory Listing",
|
||||
"message": "Now list the contents of the /Users/acmcarther/Projects/yesod/scripts directory to see what test files we have."
|
||||
},
|
||||
{
|
||||
"name": "Complex Task with Multiple Tools",
|
||||
"message": "Create a summary document called 'project_summary.md' that includes the project purpose and the test files available."
|
||||
}
|
||||
]
|
||||
|
||||
results = []
|
||||
|
||||
for i, scenario in enumerate(test_scenarios, 1):
|
||||
print(f"--- Test {i}: {scenario['name']} ---")
|
||||
print(f"User: {scenario['message']}")
|
||||
|
||||
response, used_tool = agent.converse(scenario['message'], conversation_history)
|
||||
print(f"Agent: {response[:200]}{'...' if len(response) > 200 else ''}")
|
||||
|
||||
if used_tool:
|
||||
tools_used += 1
|
||||
print("🔧 Tool was used in this response")
|
||||
|
||||
# Add to conversation history
|
||||
conversation_history.append({"role": "user", "content": scenario['message']})
|
||||
conversation_history.append({"role": "assistant", "content": response})
|
||||
|
||||
# Simple validation
|
||||
if len(response) > 20 and not response.startswith("Error:"):
|
||||
print("✅ Test passed")
|
||||
results.append(True)
|
||||
else:
|
||||
print("❌ Test failed")
|
||||
results.append(False)
|
||||
|
||||
print()
|
||||
|
||||
# Summary
|
||||
success_rate = sum(results) / len(results)
|
||||
print(f"=== Integration Test Summary ===")
|
||||
print(f"Tests passed: {sum(results)}/{len(results)} ({success_rate:.1%})")
|
||||
print(f"Tools used: {tools_used}")
|
||||
print(f"Conversation turns: {len(conversation_history) // 2}")
|
||||
|
||||
if success_rate >= 0.8:
|
||||
print("✅ Comprehensive integration test PASSED")
|
||||
print("\nThe local model integration is ready for production retrofitting.")
|
||||
return True
|
||||
else:
|
||||
print("❌ Comprehensive integration test FAILED")
|
||||
print("\nAdditional refinement needed before production deployment.")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_comprehensive_integration()
|
||||
sys.exit(0 if success else 1)
|
||||
99
experimental/users/acmcarther/llm/litellm_grpc/BUILD.bazel
Normal file
99
experimental/users/acmcarther/llm/litellm_grpc/BUILD.bazel
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary")
|
||||
load("@rules_go//go:def.bzl", "go_binary", "go_library")
|
||||
load("@build_stack_rules_proto//rules:proto_compile.bzl", "proto_compile")
|
||||
load("@build_stack_rules_proto//rules/py:grpc_py_library.bzl", "grpc_py_library")
|
||||
load("@build_stack_rules_proto//rules/py:proto_py_library.bzl", "proto_py_library")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
load("@rules_proto//proto:defs.bzl", "proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "litellm_proto",
|
||||
srcs = ["litellm.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
proto_compile(
|
||||
name = "litellm_python_compile",
|
||||
outputs = [
|
||||
"litellm_pb2.py",
|
||||
"litellm_pb2.pyi",
|
||||
"litellm_pb2_grpc.py",
|
||||
],
|
||||
plugins = [
|
||||
"@build_stack_rules_proto//plugin/builtin:pyi",
|
||||
"@build_stack_rules_proto//plugin/builtin:python",
|
||||
"@build_stack_rules_proto//plugin/grpc/grpc:protoc-gen-grpc-python",
|
||||
],
|
||||
proto = ":litellm_proto",
|
||||
)
|
||||
|
||||
proto_py_library(
|
||||
name = "litellm_proto_py_lib",
|
||||
srcs = ["litellm_pb2.py"],
|
||||
deps = ["@com_google_protobuf//:protobuf_python"],
|
||||
)
|
||||
|
||||
grpc_py_library(
|
||||
name = "litellm_grpc_py_library",
|
||||
srcs = ["litellm_pb2_grpc.py"],
|
||||
deps = [
|
||||
requirement("grpcio"),
|
||||
":litellm_proto_py_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "server_main",
|
||||
srcs = ["server_main.py"],
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
deps = [
|
||||
":litellm_grpc_py_library",
|
||||
":litellm_proto_py_lib",
|
||||
requirement("grpcio"),
|
||||
requirement("litellm"),
|
||||
requirement("python-dotenv"),
|
||||
requirement("asyncio"),
|
||||
],
|
||||
)
|
||||
|
||||
proto_compile(
|
||||
name = "litellm_go_compile",
|
||||
output_mappings = [
|
||||
"litellm.pb.go=forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/litellm_grpc/litellm.pb.go",
|
||||
"litellm_grpc.pb.go=forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/litellm_grpc/litellm_grpc.pb.go",
|
||||
],
|
||||
outputs = [
|
||||
"litellm.pb.go",
|
||||
"litellm_grpc.pb.go",
|
||||
],
|
||||
plugins = [
|
||||
"@build_stack_rules_proto//plugin/golang/protobuf:protoc-gen-go",
|
||||
"@build_stack_rules_proto//plugin/grpc/grpc-go:protoc-gen-go-grpc",
|
||||
],
|
||||
proto = ":litellm_proto",
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "litellm_go_proto",
|
||||
srcs = [":litellm_go_compile"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/litellm_grpc",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//codes",
|
||||
"@org_golang_google_grpc//status",
|
||||
"@org_golang_google_protobuf//reflect/protoreflect",
|
||||
"@org_golang_google_protobuf//runtime/protoimpl",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "client_go",
|
||||
srcs = ["main.go"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":litellm_go_proto",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//credentials/insecure",
|
||||
],
|
||||
)
|
||||
66
experimental/users/acmcarther/llm/litellm_grpc/README.md
Normal file
66
experimental/users/acmcarther/llm/litellm_grpc/README.md
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
# LiteLLM gRPC Example
|
||||
|
||||
This directory contains an example of wrapping [LiteLLM](https://github.com/BerriAI/litellm) in a gRPC service and consuming it with a Go client.
|
||||
|
||||
## Components
|
||||
|
||||
1. **`litellm.proto`**: Defines the `LiteLLMService` with `Chat` and `StreamChat` methods.
|
||||
2. **`server_main.py`**: A Python gRPC server that implements the service using `litellm`.
|
||||
3. **`main.go`**: A Go client that calls the service.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Bazel
|
||||
- A `.env` file with your LLM API keys/base (optional, but recommended if you aren't using default OpenAI).
|
||||
|
||||
## Setup
|
||||
|
||||
Create a `.env` file in this directory (or anywhere `python-dotenv` can find it, though running via Bazel requires care with file placement).
|
||||
Alternatively, you can export environment variables before running the server (but Bazel sanitizes envs).
|
||||
For local development, it's often easiest to run the python binary directly from `bazel-bin` or use `bazel run` with `--action_env`.
|
||||
|
||||
Example `.env`:
|
||||
```
|
||||
OPENAI_API_KEY=your_key
|
||||
OPENAI_API_BASE=http://localhost:1234/v1
|
||||
```
|
||||
|
||||
## Running the Server
|
||||
|
||||
```bash
|
||||
bazel run //experimental/users/acmcarther/litellm_grpc:server_main
|
||||
```
|
||||
|
||||
To pass environment variables:
|
||||
|
||||
```bash
|
||||
bazel run --action_env=OPENAI_API_KEY=$OPENAI_API_KEY --action_env=OPENAI_API_BASE=$OPENAI_API_BASE //experimental/users/acmcarther/litellm_grpc:server_main
|
||||
```
|
||||
|
||||
## Running the Client
|
||||
|
||||
Open a new terminal.
|
||||
|
||||
To run a basic chat:
|
||||
|
||||
```bash
|
||||
bazel run //experimental/users/acmcarther/litellm_grpc:client_go -- --prompt "Tell me a joke" --model "gpt-3.5-turbo"
|
||||
```
|
||||
|
||||
To run with streaming:
|
||||
|
||||
```bash
|
||||
bazel run //experimental/users/acmcarther/litellm_grpc:client_go -- --prompt "Write a poem" --stream
|
||||
```
|
||||
|
||||
To run embeddings:
|
||||
|
||||
```bash
|
||||
bazel run //experimental/users/acmcarther/litellm_grpc:client_go -- --embedding_input "This is a test sentence." --model "openai/qwen3-embedding-8b-dwq"
|
||||
```
|
||||
|
||||
## Customizing
|
||||
|
||||
- Edit `litellm.proto` to add more fields (e.g., `top_p`, `presence_penalty`).
|
||||
- Update `server_main.py` to pass these fields to `litellm.completion`.
|
||||
- Update `main.go` to support setting these fields via flags.
|
||||
48
experimental/users/acmcarther/llm/litellm_grpc/litellm.proto
Normal file
48
experimental/users/acmcarther/llm/litellm_grpc/litellm.proto
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package experimental.users.acmcarther.llm.litellm_grpc;
|
||||
|
||||
option go_package = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/litellm_grpc";
|
||||
|
||||
message Message {
|
||||
string role = 1;
|
||||
string content = 2;
|
||||
}
|
||||
|
||||
message ChatRequest {
|
||||
string model = 1;
|
||||
repeated Message messages = 2;
|
||||
float temperature = 3;
|
||||
int32 max_tokens = 4;
|
||||
}
|
||||
|
||||
message ChatResponse {
|
||||
string content = 1;
|
||||
string role = 2;
|
||||
string finish_reason = 3;
|
||||
}
|
||||
|
||||
message EmbeddingRequest {
|
||||
string model = 1;
|
||||
repeated string inputs = 2;
|
||||
}
|
||||
|
||||
message EmbeddingResponse {
|
||||
message Embedding {
|
||||
repeated float values = 1;
|
||||
int32 index = 2;
|
||||
}
|
||||
repeated Embedding embeddings = 1;
|
||||
string model = 2;
|
||||
message Usage {
|
||||
int32 prompt_tokens = 1;
|
||||
int32 total_tokens = 2;
|
||||
}
|
||||
Usage usage = 3;
|
||||
}
|
||||
|
||||
service LiteLLMService {
|
||||
rpc Chat(ChatRequest) returns (ChatResponse) {}
|
||||
rpc StreamChat(ChatRequest) returns (stream ChatResponse) {}
|
||||
rpc Embed(EmbeddingRequest) returns (EmbeddingResponse) {}
|
||||
}
|
||||
110
experimental/users/acmcarther/llm/litellm_grpc/main.go
Normal file
110
experimental/users/acmcarther/llm/litellm_grpc/main.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
pb "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/litellm_grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
prompt = flag.String("prompt", "Hello, how are you?", "The prompt to send to the LLM")
|
||||
model = flag.String("model", "gpt-3.5-turbo", "The model to use")
|
||||
stream = flag.Bool("stream", false, "Use streaming API")
|
||||
maxTokens = flag.Int("max_tokens", 100, "Max tokens to generate")
|
||||
temperature = flag.Float64("temperature", 0.7, "Temperature")
|
||||
embeddingInput = flag.String("embedding_input", "", "Text to embed (if set, runs embedding instead of chat)")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
// Set up a connection to the server.
|
||||
conn, err := grpc.Dial(*addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewLiteLLMServiceClient(conn)
|
||||
|
||||
// Contact the server and print out its response.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if *embeddingInput != "" {
|
||||
runEmbedding(ctx, c)
|
||||
return
|
||||
}
|
||||
|
||||
req := &pb.ChatRequest{
|
||||
Model: *model,
|
||||
Messages: []*pb.Message{
|
||||
{Role: "user", Content: *prompt},
|
||||
},
|
||||
MaxTokens: int32(*maxTokens),
|
||||
Temperature: float32(*temperature),
|
||||
}
|
||||
|
||||
log.Printf("Sending request for model: %s, prompt: %q", *model, *prompt)
|
||||
|
||||
if *stream {
|
||||
runStream(ctx, c, req)
|
||||
} else {
|
||||
runChat(ctx, c, req)
|
||||
}
|
||||
}
|
||||
|
||||
func runChat(ctx context.Context, c pb.LiteLLMServiceClient, req *pb.ChatRequest) {
|
||||
resp, err := c.Chat(ctx, req)
|
||||
if err != nil {
|
||||
log.Fatalf("could not get chat response: %v", err)
|
||||
}
|
||||
log.Printf("Response: %s", resp.GetContent())
|
||||
}
|
||||
|
||||
func runStream(ctx context.Context, c pb.LiteLLMServiceClient, req *pb.ChatRequest) {
|
||||
stream, err := c.StreamChat(ctx, req)
|
||||
if err != nil {
|
||||
log.Fatalf("could not start stream: %v", err)
|
||||
}
|
||||
|
||||
log.Println("Stream started:")
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("streaming error: %v", err)
|
||||
}
|
||||
// Print content as it comes
|
||||
if content := resp.GetContent(); content != "" {
|
||||
fmt.Print(content)
|
||||
}
|
||||
}
|
||||
fmt.Println() // Ensure valid newline at the end
|
||||
log.Println("Stream finished.")
|
||||
}
|
||||
|
||||
func runEmbedding(ctx context.Context, c pb.LiteLLMServiceClient) {
|
||||
log.Printf("Sending embedding request for model: %s, input: %q", *model, *embeddingInput)
|
||||
resp, err := c.Embed(ctx, &pb.EmbeddingRequest{
|
||||
Model: *model,
|
||||
Inputs: []string{*embeddingInput},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("could not get embedding response: %v", err)
|
||||
}
|
||||
|
||||
for _, emb := range resp.GetEmbeddings() {
|
||||
log.Printf("Embedding %d (size: %d): [%.4f, %.4f, ...]", emb.GetIndex(), len(emb.GetValues()), emb.GetValues()[0], emb.GetValues()[1])
|
||||
}
|
||||
log.Printf("Usage: Prompt Tokens: %d, Total Tokens: %d", resp.GetUsage().GetPromptTokens(), resp.GetUsage().GetTotalTokens())
|
||||
}
|
||||
146
experimental/users/acmcarther/llm/litellm_grpc/server_main.py
Normal file
146
experimental/users/acmcarther/llm/litellm_grpc/server_main.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
from typing import AsyncIterable
|
||||
|
||||
import grpc
|
||||
import litellm
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from experimental.users.acmcarther.llm.litellm_grpc import litellm_pb2, litellm_pb2_grpc
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
api_base = os.getenv('OPENAI_API_BASE', 'http://192.168.0.235:1234/v1')
|
||||
api_key = os.getenv('OPENAI_API_KEY', 'lm-studio')
|
||||
|
||||
def get_safe(obj, attr, default=None):
|
||||
"""Safely get an attribute from an object or key from a dict."""
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(attr, default)
|
||||
return getattr(obj, attr, default)
|
||||
|
||||
class LiteLLMService(litellm_pb2_grpc.LiteLLMServiceServicer):
|
||||
def __init__(self):
|
||||
# Optional: Set defaults from env if needed
|
||||
pass
|
||||
|
||||
async def Chat(self, request, context):
|
||||
logger.info(f"Received Chat request for model: {request.model}")
|
||||
|
||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
temperature=request.temperature if request.temperature else None,
|
||||
max_tokens=request.max_tokens if request.max_tokens > 0 else None,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
choice = response.choices[0]
|
||||
return litellm_pb2.ChatResponse(
|
||||
content=choice.message.content,
|
||||
role=choice.message.role,
|
||||
finish_reason=choice.finish_reason if hasattr(choice, 'finish_reason') else ""
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Chat: {e}")
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return litellm_pb2.ChatResponse()
|
||||
|
||||
async def StreamChat(self, request, context):
|
||||
logger.info(f"Received StreamChat request for model: {request.model}")
|
||||
|
||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model=request.model,
|
||||
messages=messages,
|
||||
temperature=request.temperature if request.temperature else None,
|
||||
max_tokens=request.max_tokens if request.max_tokens > 0 else None,
|
||||
stream=True,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
async for chunk in response:
|
||||
if len(chunk.choices) > 0:
|
||||
delta = chunk.choices[0].delta
|
||||
content = delta.content if hasattr(delta, 'content') and delta.content else ""
|
||||
role = delta.role if hasattr(delta, 'role') and delta.role else ""
|
||||
# finish_reason might be on the choice, not delta
|
||||
finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], 'finish_reason') else ""
|
||||
|
||||
if content or role or finish_reason:
|
||||
yield litellm_pb2.ChatResponse(
|
||||
content=content,
|
||||
role=role,
|
||||
finish_reason=finish_reason or ""
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in StreamChat: {e}")
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
|
||||
async def Embed(self, request, context):
|
||||
logger.info(f"Received Embed request for model: {request.model}")
|
||||
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
model=request.model,
|
||||
input=request.inputs,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
logger.debug(f"Embedding response: {response}")
|
||||
|
||||
embeddings = []
|
||||
for data in response.data:
|
||||
embeddings.append(litellm_pb2.EmbeddingResponse.Embedding(
|
||||
values=get_safe(data, 'embedding'),
|
||||
index=get_safe(data, 'index')
|
||||
))
|
||||
|
||||
usage_obj = get_safe(response, 'usage', {})
|
||||
usage = litellm_pb2.EmbeddingResponse.Usage(
|
||||
prompt_tokens=get_safe(usage_obj, 'prompt_tokens', 0),
|
||||
total_tokens=get_safe(usage_obj, 'total_tokens', 0)
|
||||
)
|
||||
|
||||
return litellm_pb2.EmbeddingResponse(
|
||||
embeddings=embeddings,
|
||||
model=get_safe(response, 'model', ""),
|
||||
usage=usage
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Embed: {e}")
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(str(e))
|
||||
return litellm_pb2.EmbeddingResponse()
|
||||
|
||||
|
||||
async def serve():
|
||||
port = os.getenv("PORT", "50051")
|
||||
server = grpc.aio.server()
|
||||
litellm_pb2_grpc.add_LiteLLMServiceServicer_to_server(LiteLLMService(), server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
logger.info(f"Starting gRPC server on port {port}...")
|
||||
await server.start()
|
||||
await server.wait_for_termination()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(serve())
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
load("@rules_go//go:def.bzl", "go_binary", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "mcp_server_prototype_lib",
|
||||
srcs = ["main.go"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/mcp_server_prototype",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
"@com_github_modelcontextprotocol_go_sdk//mcp",
|
||||
"@in_gopkg_yaml_v3//:yaml_v3",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "mcp_server_prototype",
|
||||
embed = [":mcp_server_prototype_lib"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "mcp_server_prototype_test",
|
||||
srcs = ["main_test.go"],
|
||||
data = [":mcp_server_prototype"],
|
||||
embed = [":mcp_server_prototype_lib"],
|
||||
tags = ["manual"],
|
||||
deps = ["@com_github_modelcontextprotocol_go_sdk//mcp"],
|
||||
)
|
||||
170
experimental/users/acmcarther/llm/mcp_server_prototype/main.go
Normal file
170
experimental/users/acmcarther/llm/mcp_server_prototype/main.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
projectRoot = flag.String("project-root", ".", "The absolute path to the project root directory.")
|
||||
listenAddr = flag.String("listen-addr", "", "If non-empty, listen on this HTTP address instead of using stdio.")
|
||||
)
|
||||
|
||||
// Arguments for the hello_world tool
|
||||
type HelloWorldArgs struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Handler for the hello_world tool
|
||||
func handleHelloWorld(ctx context.Context, session *mcp.ServerSession, req *mcp.CallToolParamsFor[HelloWorldArgs]) (*mcp.CallToolResultFor[any], error) {
|
||||
log.Println("Executing handleHelloWorld")
|
||||
return &mcp.CallToolResultFor[any]{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Hello, " + req.Arguments.Name},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Arguments for the invoke_agent tool
|
||||
type InvokeAgentArgs struct {
|
||||
TaskFile string `json:"task_file"`
|
||||
}
|
||||
|
||||
// TaskYAML defines the structure of the agent task file.
|
||||
type TaskYAML struct {
|
||||
Agent string `yaml:"agent"`
|
||||
Model string `yaml:"model"`
|
||||
Prompt string `yaml:"prompt"`
|
||||
ContextFiles []string `yaml:"context_files"`
|
||||
}
|
||||
|
||||
// Handler for the invoke_agent tool
|
||||
func handleInvokeAgent(ctx context.Context, session *mcp.ServerSession, req *mcp.CallToolParamsFor[InvokeAgentArgs]) (*mcp.CallToolResultFor[any], error) {
|
||||
log.Println("Executing handleInvokeAgent")
|
||||
// Ensure the task file path is absolute
|
||||
taskFile := req.Arguments.TaskFile
|
||||
if !filepath.IsAbs(taskFile) {
|
||||
taskFile = filepath.Join(*projectRoot, taskFile)
|
||||
}
|
||||
log.Printf("Invoking agent with task file: %s", taskFile)
|
||||
|
||||
// 1. Read and parse the task file
|
||||
log.Println("Reading and parsing task file...")
|
||||
yamlFile, err := os.ReadFile(taskFile)
|
||||
if err != nil {
|
||||
log.Printf("Error reading task file %s: %v", taskFile, err)
|
||||
return &mcp.CallToolResultFor[any]{
|
||||
Content: []mcp.Content{&mcp.TextContent{Text: "Error: " + err.Error()}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var task TaskYAML
|
||||
err = yaml.Unmarshal(yamlFile, &task)
|
||||
if err != nil {
|
||||
log.Printf("Error unmarshaling YAML from %s: %v", taskFile, err)
|
||||
return &mcp.CallToolResultFor[any]{
|
||||
Content: []mcp.Content{&mcp.TextContent{Text: "Error: " + err.Error()}},
|
||||
}, nil
|
||||
}
|
||||
log.Println("Task file parsed successfully.")
|
||||
|
||||
// 2. Construct the command to invoke the microagent, ensuring paths are absolute
|
||||
log.Println("Constructing agent invocation command...")
|
||||
scriptPath := filepath.Join(*projectRoot, "scripts/invoke_microagent.sh")
|
||||
args := []string{
|
||||
task.Agent,
|
||||
task.Prompt,
|
||||
"--model",
|
||||
task.Model,
|
||||
}
|
||||
for _, file := range task.ContextFiles {
|
||||
absFile := file
|
||||
if !filepath.IsAbs(absFile) {
|
||||
absFile = filepath.Join(*projectRoot, file)
|
||||
}
|
||||
args = append(args, "--context-file", absFile)
|
||||
}
|
||||
|
||||
log.Printf("Executing command: %s %v", scriptPath, args)
|
||||
cmd := exec.Command(scriptPath, args...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Printf("Error invoking agent script: %v", err)
|
||||
log.Printf("Agent script output: %s", string(output))
|
||||
return &mcp.CallToolResultFor[any]{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: "Error invoking agent script: " + err.Error() + "\nOutput:\n" + string(output)},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Printf("Agent script output: %s", string(output))
|
||||
|
||||
// 3. Wrap the output in a "context firewall" to prevent identity bleed.
|
||||
firewalledOutput := fmt.Sprintf(
|
||||
"--- BEGIN SUB-AGENT OUTPUT (%s) ---\n%s\n--- END SUB-AGENT OUTPUT ---",
|
||||
task.Agent,
|
||||
string(output),
|
||||
)
|
||||
|
||||
return &mcp.CallToolResultFor[any]{
|
||||
Content: []mcp.Content{
|
||||
&mcp.TextContent{Text: firewalledOutput},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
log.Println("--- MCP Server Starting ---")
|
||||
log.Printf("Project Root: %s", *projectRoot)
|
||||
|
||||
log.Println("Creating new MCP server implementation...")
|
||||
serverImpl := &mcp.Implementation{Name: "prototype-server"}
|
||||
log.Println("Server implementation created.")
|
||||
|
||||
log.Println("Creating new MCP server...")
|
||||
server := mcp.NewServer(serverImpl, nil)
|
||||
log.Println("MCP server created.")
|
||||
|
||||
log.Println("Adding 'hello_world' tool...")
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "hello_world",
|
||||
Description: "A simple hello world tool for testing.",
|
||||
}, handleHelloWorld)
|
||||
log.Println("'hello_world' tool added.")
|
||||
|
||||
log.Println("Adding 'invoke_agent' tool...")
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "invoke_agent",
|
||||
Description: "Invokes an asynchronous agent task.",
|
||||
}, handleInvokeAgent)
|
||||
log.Println("'invoke_agent' tool added.")
|
||||
|
||||
if *listenAddr != "" {
|
||||
handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||
return server
|
||||
}, nil)
|
||||
log.Printf("MCP handler listening at %s", *listenAddr)
|
||||
if err := http.ListenAndServe(*listenAddr, handler); err != nil {
|
||||
log.Fatalf("HTTP server failed: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Println("Starting MCP server with stdio transport...")
|
||||
transport := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr)
|
||||
if err := server.Run(context.Background(), transport); err != nil {
|
||||
log.Fatalf("Server run loop failed: %v", err)
|
||||
}
|
||||
log.Println("Server run loop finished.")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Local definitions of MCP JSON-RPC structures for robust testing,
|
||||
// avoiding dependency on a specific SDK version which has proven unstable.
|
||||
|
||||
type Request struct {
|
||||
Version string `json:"jsonrpc"`
|
||||
ID string `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// Error defines the structure for a JSON-RPC error object.
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Version string `json:"jsonrpc"`
|
||||
ID string `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ListToolsResult struct {
|
||||
Tools []*mcp.Tool `json:"tools"`
|
||||
}
|
||||
|
||||
type CallToolResult struct {
|
||||
Content []json.RawMessage `json:"content"`
|
||||
}
|
||||
|
||||
type TextContent struct {
|
||||
Text string `json:"text"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// startTestServer starts the compiled server binary as a subprocess for testing.
|
||||
func startTestServer(t *testing.T) (*exec.Cmd, *bufio.Writer, *bufio.Reader) {
|
||||
t.Helper()
|
||||
path, ok := os.LookupEnv("TEST_SRCDIR")
|
||||
if !ok {
|
||||
t.Fatal("TEST_SRCDIR not set")
|
||||
}
|
||||
cmdPath := filepath.Join(path, "_main/experimental/mcp_server_prototype/mcp_server_prototype_/mcp_server_prototype")
|
||||
cmd := exec.Command(cmdPath)
|
||||
|
||||
stdinPipe, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get stdin pipe: %v", err)
|
||||
}
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get stdout pipe: %v", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Failed to start command: %v", err)
|
||||
}
|
||||
return cmd, bufio.NewWriter(stdinPipe), bufio.NewReader(stdoutPipe)
|
||||
}
|
||||
|
||||
// sendRequest marshals and sends a request object to the server's stdin.
|
||||
func sendRequest(t *testing.T, stdin *bufio.Writer, reqID, method string, params any) {
|
||||
t.Helper()
|
||||
req := Request{
|
||||
Version: "2.0",
|
||||
ID: reqID,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal request: %v", err)
|
||||
}
|
||||
if _, err := stdin.Write(append(reqBytes, '\n')); err != nil {
|
||||
t.Fatalf("Failed to write to stdin: %v", err)
|
||||
}
|
||||
if err := stdin.Flush(); err != nil {
|
||||
t.Fatalf("Failed to flush stdin: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// readResponse reads and unmarshals a response object from the server's stdout.
|
||||
func readResponse(t *testing.T, stdout *bufio.Reader) *Response {
|
||||
t.Helper()
|
||||
line, err := stdout.ReadBytes('\n')
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read from stdout: %v", err)
|
||||
}
|
||||
var resp Response
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %s\nError: %v", string(line), err)
|
||||
}
|
||||
return &resp
|
||||
}
|
||||
|
||||
func TestMCPWithSubprocess(t *testing.T) {
|
||||
cmd, stdin, stdout := startTestServer(t)
|
||||
defer cmd.Process.Kill()
|
||||
|
||||
// Step 1: Initialize the session
|
||||
initID := "init-1"
|
||||
sendRequest(t, stdin, initID, "mcp.initialize", &mcp.InitializeParams{})
|
||||
initResp := readResponse(t, stdout)
|
||||
if initResp.ID != initID {
|
||||
t.Fatalf("Expected init response ID '%s', got '%s'", initID, initResp.ID)
|
||||
}
|
||||
if initResp.Error != nil {
|
||||
t.Fatalf("Initialize failed: %v", initResp.Error)
|
||||
}
|
||||
|
||||
// Step 2: ListTools
|
||||
t.Run("ListTools", func(t *testing.T) {
|
||||
reqID := "list-tools-1"
|
||||
sendRequest(t, stdin, reqID, "mcp.listTools", &mcp.ListToolsParams{})
|
||||
|
||||
resp := readResponse(t, stdout)
|
||||
if resp.ID != reqID {
|
||||
t.Errorf("Expected response ID '%s', got '%s'", reqID, resp.ID)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("Received unexpected error: %v", resp.Error)
|
||||
}
|
||||
|
||||
var result ListToolsResult
|
||||
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", len(result.Tools))
|
||||
}
|
||||
if result.Tools[0].Name != "hello_world" {
|
||||
t.Errorf("Expected tool 0 to be 'hello_world', got '%s'", result.Tools[0].Name)
|
||||
}
|
||||
})
|
||||
|
||||
// Step 3: CallTool 'invoke_agent'
|
||||
t.Run("CallTool_InvokeAgent", func(t *testing.T) {
|
||||
reqID := "call-tool-1"
|
||||
taskFile := "/path/to/task.yaml"
|
||||
args := InvokeAgentArgs{TaskFile: taskFile}
|
||||
|
||||
sendRequest(t, stdin, reqID, "mcp.callTool", &mcp.CallToolParams{
|
||||
Name: "invoke_agent",
|
||||
Arguments: &args,
|
||||
})
|
||||
|
||||
resp := readResponse(t, stdout)
|
||||
if resp.ID != reqID {
|
||||
t.Errorf("Expected response ID '%s', got '%s'", reqID, resp.ID)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
t.Fatalf("Received unexpected error: %v", resp.Error)
|
||||
}
|
||||
|
||||
var result CallToolResult
|
||||
if err := json.Unmarshal(resp.Result, &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Content) != 1 {
|
||||
t.Fatalf("Expected 1 content block, got %d", len(result.Content))
|
||||
}
|
||||
|
||||
var textContent TextContent
|
||||
if err := json.Unmarshal(result.Content[0], &textContent); err != nil {
|
||||
t.Fatalf("Failed to unmarshal text content: %v", err)
|
||||
}
|
||||
|
||||
expected := "Successfully received task_file: " + taskFile
|
||||
if textContent.Text != expected {
|
||||
t.Errorf("Expected output '%s', got '%s'", expected, textContent.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
33
experimental/users/acmcarther/llm/mlx/BUILD.bazel
Normal file
33
experimental/users/acmcarther/llm/mlx/BUILD.bazel
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary", "py_unpacked_wheel")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
|
||||
py_unpacked_wheel(
|
||||
name = "en_core_web_sm",
|
||||
src = "@spacy_en_core_web_sm//file",
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "mlx_testing_main",
|
||||
srcs = ["mlx_testing_main.py"],
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
deps = [
|
||||
":en_core_web_sm",
|
||||
requirement("mlx-audio"),
|
||||
requirement("mlx"),
|
||||
# Transitive dep of mlx-audio?
|
||||
requirement("soundfile"),
|
||||
requirement("sounddevice"),
|
||||
requirement("scipy"),
|
||||
requirement("loguru"),
|
||||
requirement("misaki"),
|
||||
requirement("num2words"),
|
||||
requirement("spacy"),
|
||||
requirement("huggingface_hub"),
|
||||
],
|
||||
)
|
||||
|
||||
py_pex_binary(
|
||||
name = "mlx_testing",
|
||||
binary = ":mlx_testing_main",
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
)
|
||||
22
experimental/users/acmcarther/llm/mlx/mlx_testing_main.py
Normal file
22
experimental/users/acmcarther/llm/mlx/mlx_testing_main.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from mlx_audio.tts.generate import generate_audio
|
||||
|
||||
def main():
|
||||
generate_audio(
|
||||
text=("White sparks cascaded onto the trembling wick. It was as if there were shooting stars in his hands, like the stars at the bottom of the grave to which Silk and Hyacinth had driven Orpine’s body in a dream he recalled with uncanny clarity. Here we dig holes in the ground for our dead, he thought, to bring them nearer the Outsider; and on Blue we do the same because we did it here, though it takes them away from him."),
|
||||
model_path="prince-canuma/Kokoro-82M",
|
||||
#voice="af_heart",
|
||||
voice="am_santa",
|
||||
#voice="am_echo",
|
||||
speed=1.2,
|
||||
lang_code="a", # Kokoro: (a)f_heart, or comment out for auto
|
||||
file_prefix="audiobook_chapter1",
|
||||
audio_format="wav",
|
||||
sample_rate=24000,
|
||||
join_audio=True,
|
||||
verbose=True # Set to False to disable print messages
|
||||
)
|
||||
print("Audiobook chapter successfully generated!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
155
experimental/users/acmcarther/llm/prompts/dependency_research.md
Normal file
155
experimental/users/acmcarther/llm/prompts/dependency_research.md
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
# Dependency Research Agent Prompt
|
||||
|
||||
## Task Overview
|
||||
You are tasked with analyzing the contents of a `package.json` file and conducting comprehensive research on each direct dependency and devDependency listed within it. Your goal is to produce a detailed report explaining what each package does, its purpose in web development/Node.js ecosystems, and provide credible references for all information presented.
|
||||
|
||||
## Input Format
|
||||
You will be provided with a `package.json` file. Extract and analyze:
|
||||
- All dependencies listed under the `"dependencies"` field
|
||||
- All devDependencies listed under the `"devDependencies"` field
|
||||
|
||||
## Research Requirements
|
||||
|
||||
For each dependency, you must research and provide:
|
||||
|
||||
### 1. Package Overview
|
||||
- **Name**: The exact package name from the manifest
|
||||
- **Version**: Current specified version (if provided)
|
||||
- **Primary Functionality**: What the package does in plain language
|
||||
- **Category**: Type of tool (e.g., build tool, testing framework, utility library, etc.)
|
||||
|
||||
### 2. Technical Details
|
||||
- **Purpose**: What specific problem does this package solve?
|
||||
- **Key Features**: Main capabilities and functionalities
|
||||
- **Ecosystem Role**: How it fits into the Node.js/web development ecosystem
|
||||
- **Usage Contexts**: Common scenarios where this package is typically used
|
||||
|
||||
### 3. Impact Assessment
|
||||
- **Industry Adoption**: How widely used is this package?
|
||||
- **Development Benefits**: What advantages does it provide to developers?
|
||||
- **Potential Concerns**: Any known issues, security considerations, or maintenance status
|
||||
- **Alternatives**: Brief mention of similar packages (if relevant)
|
||||
|
||||
### 4. References and Verification
|
||||
- **Official Documentation**: Link to official docs, GitHub repository, or website
|
||||
- **NPM Registry Information**: Official npm page and statistics
|
||||
- **Author/Maintainer Information**: Who maintains the package
|
||||
- **Community Resources**: Stack Overflow discussions, tutorials, or articles that explain its usage
|
||||
|
||||
## Output Format Requirements
|
||||
|
||||
### Report Structure
|
||||
```
|
||||
# Dependency Research Report - [Project Name]
|
||||
|
||||
## Executive Summary
|
||||
[Brief overview of the total dependencies analyzed, notable patterns, or key findings]
|
||||
|
||||
## Detailed Dependency Analysis
|
||||
|
||||
### Dependencies (Production)
|
||||
|
||||
#### [Package Name 1]
|
||||
- **Version**: [version]
|
||||
- **Primary Functionality**: [clear description]
|
||||
- **Technical Purpose**: [detailed explanation of what it does technically]
|
||||
- **Key Features**:
|
||||
- Feature 1
|
||||
- Feature 2
|
||||
- Feature 3
|
||||
- **Ecosystem Role**: [how it fits into the ecosystem]
|
||||
- **Development Impact**: [benefits and considerations]
|
||||
- **References**:
|
||||
- Official: [URL to docs/repo]
|
||||
- NPM: [npm page URL]
|
||||
- Additional: [relevant URLs]
|
||||
|
||||
#### [Package Name 2]
|
||||
[Repeat structure for each dependency]
|
||||
|
||||
### Dev Dependencies
|
||||
|
||||
#### [Dev Package Name 1]
|
||||
[Same structure as above]
|
||||
|
||||
## Summary Analysis
|
||||
- Total dependencies analyzed: [number]
|
||||
- Most common dependency types: [analysis of patterns]
|
||||
- Security considerations: [overall assessment]
|
||||
- Recommendations: [if applicable]
|
||||
```
|
||||
|
||||
## Quality Standards
|
||||
|
||||
### Research Depth
|
||||
1. **Minimum Viable Information**: Every package must have basic functionality explained clearly
|
||||
2. **Credible Sources Only**: Use official documentation, npm registry, GitHub repositories, established tech blogs
|
||||
3. **Technical Accuracy**: Ensure technical details are correct and up-to-date
|
||||
4. **Context Awareness**: Consider how the package fits into modern web development practices
|
||||
|
||||
### Reference Requirements
|
||||
- At least 2 credible sources per package (official documentation counts as one)
|
||||
- Prefer primary sources (official docs, GitHub) over secondary sources
|
||||
- Include both official resources and community validation when possible
|
||||
- All claims should be verifiable through provided references
|
||||
|
||||
### Writing Quality
|
||||
- Use clear, jargon-free language where possible
|
||||
- Maintain consistency in terminology throughout the report
|
||||
- Provide sufficient technical detail without being overwhelming
|
||||
- Include both high-level purpose and specific implementation details
|
||||
|
||||
## Research Process Guidelines
|
||||
|
||||
### Step 1: Initial Assessment
|
||||
- Start with the official npm page for basic information
|
||||
- Check the package's GitHub repository (if available) for documentation and README
|
||||
- Review recent commit activity to gauge maintenance status
|
||||
|
||||
### Step 2: Deep Research
|
||||
- Explore official documentation thoroughly
|
||||
- Look for usage examples and real-world implementations
|
||||
- Check for any security advisories or known issues
|
||||
- Assess the package's reputation and community adoption
|
||||
|
||||
### Step 3: Verification and Cross-referencing
|
||||
- Confirm information across multiple sources
|
||||
- Look for recent articles or discussions about the package's current relevance
|
||||
- Verify version compatibility and ecosystem position
|
||||
|
||||
### Step 4: Analysis Synthesis
|
||||
- Synthesize findings into coherent explanations
|
||||
- Identify patterns across dependencies where relevant
|
||||
- Provide context for why certain types of packages are commonly used together
|
||||
|
||||
## Special Considerations
|
||||
|
||||
### Popular vs. Niche Packages
|
||||
- **Popular Packages**: Focus on real-world impact, ecosystem integration, and community adoption metrics
|
||||
- **Niche/Specialized Packages**: Emphasize specific use cases and technical capabilities
|
||||
- **Security-Critical Packages**: Pay special attention to maintenance status, security track record, and alternatives
|
||||
|
||||
### Legacy vs. Modern Packages
|
||||
- **Legacy Packages**: Address maintenance status, deprecation risks, and modern alternatives
|
||||
- **Modern/Active Packages**: Emphasize current relevance and ongoing development
|
||||
|
||||
### Framework-Specific vs. Utility Packages
|
||||
- **Framework-Specific**: Explain integration with specific frameworks (React, Vue, Angular, etc.)
|
||||
- **Utility Packages**: Focus on general-purpose functionality and cross-framework applicability
|
||||
|
||||
## Completion Checklist
|
||||
Before finalizing your report, ensure:
|
||||
- [ ] Every dependency from both dependencies and devDependencies has been analyzed
|
||||
- [ ] All required sections are completed for each package
|
||||
- [ ] At least 2 credible references per dependency are provided
|
||||
- [ ] Technical explanations are accurate and up-to-date
|
||||
- [ ] Report structure follows the specified format exactly
|
||||
- [ ] Summary analysis provides overall insights about the dependency landscape
|
||||
- [ ] All links and references are functional and relevant
|
||||
|
||||
## Expected Deliverables
|
||||
1. **Complete Dependency Report**: Structured report covering all specified elements
|
||||
2. **Reference Compilation**: Separate list of all sources cited for verification purposes (if requested)
|
||||
3. **Key Insights Summary**: High-level observations about the project's dependency ecosystem
|
||||
|
||||
Remember: The goal is to provide actionable intelligence that helps developers understand not just what their dependencies are, but why they exist and what impact they have on the project
|
||||
20
experimental/users/acmcarther/llm/stt/BUILD.bazel
Normal file
20
experimental/users/acmcarther/llm/stt/BUILD.bazel
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary", "py_unpacked_wheel")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
|
||||
py_binary(
|
||||
name = "basic_recorder_main",
|
||||
srcs = ["basic_recorder.py"],
|
||||
deps = [
|
||||
requirement("sounddevice"),
|
||||
requirement("pyqt6"),
|
||||
requirement("pyqt6-qt6"),
|
||||
requirement("pyqt6-sip"),
|
||||
requirement("numpy"),
|
||||
requirement("scipy"),
|
||||
],
|
||||
)
|
||||
|
||||
py_pex_binary(
|
||||
name = "basic_recorder",
|
||||
binary = ":basic_recorder_main",
|
||||
)
|
||||
109
experimental/users/acmcarther/llm/stt/basic_recorder.py
Normal file
109
experimental/users/acmcarther/llm/stt/basic_recorder.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple PyQt6 microphone recorder with start/stop functionality.
|
||||
Requirements: pip install PyQt6 sounddevice numpy scipy
|
||||
"""
|
||||
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
from PyQt6. QtWidgets import (
|
||||
QApplication, QMainWindow, QVBoxLayout,
|
||||
QWidget, QPushButton, QLabel
|
||||
)
|
||||
import scipy.io.wavfile as wavfile
|
||||
|
||||
|
||||
class MicrophoneRecorder(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Simple Mic Recorder")
|
||||
self.setGeometry(100, 100, 400, 200)
|
||||
|
||||
# Audio state
|
||||
self.is_recording = False
|
||||
self.audio_data = []
|
||||
|
||||
# Setup UI
|
||||
self.setup_ui()
|
||||
|
||||
def setup_ui(self):
|
||||
central_widget = QWidget()
|
||||
self.setCentralWidget(central_widget)
|
||||
layout = QVBoxLayout(central_widget)
|
||||
|
||||
# Status label
|
||||
self.status_label = QLabel("Ready to record")
|
||||
layout.addWidget(self.status_label)
|
||||
|
||||
# Record button
|
||||
self.record_button = QPushButton("Start Recording")
|
||||
self.record_button.clicked.connect(self.toggle_recording)
|
||||
layout.addWidget(self.record_button)
|
||||
|
||||
def toggle_recording(self):
|
||||
if not self.is_recording:
|
||||
self.start_recording()
|
||||
else:
|
||||
self.stop_recording()
|
||||
|
||||
def start_recording(self):
|
||||
"""Start audio stream in background thread"""
|
||||
self.is_recording = True
|
||||
self.audio_data = []
|
||||
|
||||
# Update UI
|
||||
self.record_button.setText("Stop Recording")
|
||||
self.status_label.setText("Recording... (click Stop when finished)")
|
||||
|
||||
def record_audio():
|
||||
with sd.InputStream(samplerate=44100, channels=1, callback=self.audio_callback):
|
||||
while self.is_recording:
|
||||
sd.sleep(100) # Keep thread alive
|
||||
|
||||
threading.Thread(target=record_audio, daemon=True).start()
|
||||
|
||||
def stop_recording(self):
|
||||
"""Stop recording"""
|
||||
self.is_recording = False
|
||||
self.record_button.setText("Start Recording")
|
||||
self.status_label.setText(f"Recording stopped ({len(self.audio_data)/44100:.1f}s)")
|
||||
|
||||
def audio_callback(self, indata, frames, time, status):
|
||||
"""Called by sounddevice for each audio chunk"""
|
||||
if self.is_recording:
|
||||
self.audio_data.extend(indata[:, 0]) # Mono channel
|
||||
|
||||
def save_recording(self):
|
||||
"""Save recorded audio to WAV file"""
|
||||
if not self.audio_data:
|
||||
return
|
||||
|
||||
# Convert and save as 16-bit WAV
|
||||
audio_array = np.array(self.audio_data, dtype=np.float32)
|
||||
audio_16bit = (audio_array * 32767).astype(np.int16)
|
||||
|
||||
filename = f"recording_{len(self.audio_data)}.wav"
|
||||
wavfile.write(filename, 44100, audio_16bit)
|
||||
|
||||
self.status_label.setText(f"Saved: {filename}")
|
||||
|
||||
|
||||
def main():
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
window = MicrophoneRecorder()
|
||||
window.show()
|
||||
|
||||
print("🎤 PyQt6 Mic Recorder")
|
||||
print("- Click 'Start Recording' to begin")
|
||||
print("- Click 'Stop Recording' when finished")
|
||||
|
||||
sys.exit(app.exec())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
63
experimental/users/acmcarther/llm/tts/BUILD.bazel
Normal file
63
experimental/users/acmcarther/llm/tts/BUILD.bazel
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary", "py_unpacked_wheel")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
|
||||
py_binary(
|
||||
name = "spacy_demo_main",
|
||||
srcs = ["spacy_demo.py"],
|
||||
deps = [
|
||||
"//third_party/python/spacy:en_core_web_sm",
|
||||
requirement("spacy"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "spacy_tts_pipeline_main",
|
||||
srcs = ["spacy_tts_pipeline.py"],
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
deps = [
|
||||
"//third_party/python/spacy:en_core_web_sm",
|
||||
requirement("spacy"),
|
||||
requirement("mlx"),
|
||||
requirement("mlx-audio"),
|
||||
# Transitive dep of mlx-audio?
|
||||
requirement("soundfile"),
|
||||
requirement("sounddevice"),
|
||||
requirement("scipy"),
|
||||
requirement("loguru"),
|
||||
requirement("misaki"),
|
||||
requirement("num2words"),
|
||||
requirement("huggingface_hub"),
|
||||
],
|
||||
)
|
||||
|
||||
py_pex_binary(
|
||||
name = "spacy_tts_pipeline",
|
||||
binary = ":spacy_tts_pipeline_main",
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "enhanced_tts_pipeline_main",
|
||||
srcs = ["enhanced_tts_pipeline.py"],
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
deps = [
|
||||
"//third_party/python/spacy:en_core_web_sm",
|
||||
requirement("spacy"),
|
||||
requirement("mlx"),
|
||||
requirement("mlx-audio"),
|
||||
# Transitive dep of mlx-audio?
|
||||
requirement("soundfile"),
|
||||
requirement("sounddevice"),
|
||||
requirement("scipy"),
|
||||
requirement("loguru"),
|
||||
requirement("misaki"),
|
||||
requirement("num2words"),
|
||||
requirement("huggingface_hub"),
|
||||
],
|
||||
)
|
||||
|
||||
py_pex_binary(
|
||||
name = "enhanced_tts_pipeline",
|
||||
binary = ":enhanced_tts_pipeline_main",
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
)
|
||||
297
experimental/users/acmcarther/llm/tts/enhanced_tts_pipeline.py
Normal file
297
experimental/users/acmcarther/llm/tts/enhanced_tts_pipeline.py
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Enhanced TTS Pipeline with SpaCy Emphasis Detection
|
||||
Integration with your existing Kokoro TTS pipeline
|
||||
"""
|
||||
|
||||
import spacy
|
||||
import re
|
||||
from typing import cast, List, Dict, Tuple
|
||||
from mlx_audio.tts.generate import generate_audio
|
||||
|
||||
class EnhancedTTSPipeline:
|
||||
"""Enhanced TTS pipeline with automatic emphasis detection and annotation"""
|
||||
|
||||
def __init__(self):
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
def enhance_text_with_emphasis(self, raw_text: str) -> str:
|
||||
"""Add emphasis annotations to text for improved TTS delivery"""
|
||||
|
||||
# 1. Run your existing SpaCy preprocessing
|
||||
doc = self.nlp(raw_text)
|
||||
|
||||
# 2. Detect emphasis points
|
||||
emphasis_points = self._detect_emphasis_annotations(raw_text)
|
||||
|
||||
# 3. Apply SSML emphasis annotations
|
||||
enhanced_text = self._apply_emphasis_to_text(raw_text, emphasis_points)
|
||||
|
||||
# 4. Return both versions for comparison
|
||||
return enhanced_text
|
||||
|
||||
def _detect_emphasis_annotations(self, text: str) -> List[Dict]:
|
||||
"""Detect all emphasis points in the text"""
|
||||
|
||||
doc = self.nlp(text)
|
||||
annotations = []
|
||||
|
||||
# KEY EMPHASIS PATTERNS FROM YOUR LITERARY TEXT:
|
||||
|
||||
# Pattern 1: Sensory Imagery - High Priority
|
||||
sensory_patterns = {
|
||||
'white sparks cascaded': {'type': 'visual_emphasis', 'level': 0.9},
|
||||
'trembling wick': {'type': 'tactile_emphasis', 'level': 0.8},
|
||||
'shooting stars in his hands': {'type': 'visual_cosmic', 'level': 0.95}
|
||||
}
|
||||
|
||||
for pattern, config in sensory_patterns.items():
|
||||
if self._pattern_exists_in_text(text.lower(), pattern):
|
||||
annotations.append({
|
||||
'text': pattern,
|
||||
'type': config['type'],
|
||||
'emphasis_level': config['level'],
|
||||
'ssml_tag': self._get_ssml_emphasis_tag(cast(float, config['level'])),
|
||||
'reason': f"sensory imagery: {config['type']}"
|
||||
})
|
||||
|
||||
# Pattern 2: Spiritual/Religious Content - High Priority
|
||||
spiritual_terms = ['grave', 'dead', 'outsider']
|
||||
|
||||
for token in doc:
|
||||
if any(term.lower() == token.lemma_.lower() for term in spiritual_terms):
|
||||
emphasis_level = 0.9 if token.text.lower().capitalize() == 'Outsider' else 0.8
|
||||
|
||||
annotations.append({
|
||||
'text': token.text,
|
||||
'type': 'spiritual_content',
|
||||
'emphasis_level': emphasis_level,
|
||||
'ssml_tag': self._get_ssml_emphasis_tag(emphasis_level),
|
||||
'reason': 'spiritual/metaphysical content requires emphasis'
|
||||
})
|
||||
|
||||
# Pattern 3: Metaphorical Language - High Priority
|
||||
metaphor_patterns = [
|
||||
'stars at the bottom of the grave',
|
||||
'dream he recalled with uncanny clarity',
|
||||
'like the stars at the bottom of the grave'
|
||||
]
|
||||
|
||||
for metaphor in metaphor_patterns:
|
||||
if self._pattern_exists_in_text(text.lower(), metaphor):
|
||||
annotations.append({
|
||||
'text': metaphor,
|
||||
'type': 'metaphorical_emphasis',
|
||||
'emphasis_level': 0.85,
|
||||
'ssml_tag': '<emphasis level="strong">',
|
||||
'reason': 'metaphorical/literary device'
|
||||
})
|
||||
|
||||
# Pattern 4: Complex Syntax - Medium Priority
|
||||
for sent in doc.sents:
|
||||
# Detect complex relative clauses
|
||||
has_relative_clause = any('which' in str(token) or 'that' in str(token)
|
||||
for token in sent if hasattr(token, '__str__'))
|
||||
|
||||
if has_relative_clause and len(sent.text) > 100:
|
||||
annotations.append({
|
||||
'text': sent.text.strip()[:50] + "...",
|
||||
'type': 'syntactic_emphasis',
|
||||
'emphasis_level': 0.7,
|
||||
'ssml_tag': '<pause time="500ms"/>',
|
||||
'reason': 'complex syntax - relative clause pause needed'
|
||||
})
|
||||
|
||||
return annotations
|
||||
|
||||
def _pattern_exists_in_text(self, text: str, pattern: str) -> bool:
|
||||
"""Check if pattern exists in text with some flexibility"""
|
||||
|
||||
# Direct match
|
||||
if pattern.lower() in text:
|
||||
return True
|
||||
|
||||
# Partial matches for complex patterns
|
||||
words = pattern.lower().split()
|
||||
if len(words) >= 2:
|
||||
return all(word in text for word in words[:2]) # First two words
|
||||
|
||||
return False
|
||||
|
||||
def _get_ssml_emphasis_tag(self, level: float) -> str:
|
||||
"""Convert emphasis level to SSML tag"""
|
||||
|
||||
if level >= 0.9:
|
||||
return '<emphasis level="strong">'
|
||||
elif level >= 0.8:
|
||||
return '<emphasis>'
|
||||
else:
|
||||
return '<!-- light emphasis -->' # Don't mark very low priority
|
||||
|
||||
def _apply_emphasis_to_text(self, original: str, annotations: List[Dict]) -> str:
|
||||
"""Apply emphasis annotations to create SSML-enhanced text"""
|
||||
|
||||
# Start with original text
|
||||
enhanced = original
|
||||
|
||||
# Apply annotations in reverse order (avoid index shifting)
|
||||
for ann in sorted(annotations, key=lambda x: original.lower().find(x['text'].lower()), reverse=True):
|
||||
|
||||
text_to_replace = ann['text']
|
||||
|
||||
if enhanced.lower().find(text_to_replace.lower()) != -1:
|
||||
# Find the position (case-insensitive)
|
||||
import re
|
||||
|
||||
pattern = re.escape(text_to_replace)
|
||||
match = re.search(pattern, enhanced, re.IGNORECASE)
|
||||
|
||||
if match:
|
||||
start_idx = match.start()
|
||||
|
||||
# Insert SSML tags around the matched text
|
||||
before_text = enhanced[:start_idx]
|
||||
highlighted_text = ann['ssml_tag'] + match.group()
|
||||
|
||||
# Close tag - need to find where text actually ends in enhanced
|
||||
if ann['emphasis_level'] >= 0.8:
|
||||
close_tag = '</emphasis>'
|
||||
|
||||
# Add to highlighted text
|
||||
highlighted_text += close_tag
|
||||
|
||||
after_text = enhanced[start_idx + len(match.group()):]
|
||||
enhanced = before_text + highlighted_text + after_text
|
||||
|
||||
return f"<speak>{enhanced}</speak>"
|
||||
|
||||
def run_enhanced_pipeline(self, raw_text: str, output_prefix: str = "enhanced") -> None:
|
||||
"""Run complete enhanced TTS pipeline with emphasis detection"""
|
||||
|
||||
print("🚀 **ENHANCED TTS PIPELINE**")
|
||||
print(f"📝 Processing: {raw_text[:100]}...")
|
||||
|
||||
# Create both versions for comparison
|
||||
original_processed = self._preprocess_for_tts(raw_text)
|
||||
enhanced_with_emphasis = self.enhance_text_with_emphasis(original_processed)
|
||||
|
||||
print("\n🔄 **PROCESSING RESULTS:**\n")
|
||||
|
||||
# Show original (your existing pipeline)
|
||||
print("**ORIGINAL TTS TEXT:**")
|
||||
print(original_processed)
|
||||
|
||||
# Generate audio for original
|
||||
generate_audio(
|
||||
text=original_processed,
|
||||
model_path="prince-canuma/Kokoro-82M",
|
||||
voice="bm_george",
|
||||
speed=1.0,
|
||||
lang_code="b",
|
||||
file_prefix=f"{output_prefix}_original",
|
||||
audio_format="wav",
|
||||
sample_rate=24000,
|
||||
join_audio=True,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
print("\n**ENHANCED TTS WITH EMPHASIS:**")
|
||||
# Remove SSML tags for display but keep them in audio
|
||||
display_text = enhanced_with_emphasis.replace('<speak>', '').replace('</speak>', '')
|
||||
print(display_text[:200] + "..." if len(display_text) > 200 else display_text)
|
||||
|
||||
# Extract clean text for TTS (SSML tags might not be supported by Kokoro)
|
||||
import re
|
||||
clean_text = self._extract_clean_tts_text(enhanced_with_emphasis)
|
||||
|
||||
# Generate audio for enhanced version
|
||||
generate_audio(
|
||||
text=clean_text,
|
||||
model_path="prince-canuma/Kokoro-82M",
|
||||
voice="bm_george",
|
||||
speed=1.0,
|
||||
lang_code="b",
|
||||
file_prefix=f"{output_prefix}_enhanced",
|
||||
audio_format="wav",
|
||||
sample=24000,
|
||||
join_audio=True,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
print(f"\n✅ **AUDIO GENERATED:**")
|
||||
print(f" • Original: {output_prefix}_original.wav")
|
||||
print(f" • Enhanced: {output_prefix}_enhanced.wav")
|
||||
|
||||
def _extract_clean_tts_text(self, ssml_text: str) -> str:
|
||||
"""Extract clean text from SSML annotations for TTS generation"""
|
||||
|
||||
# Remove SSML tags but preserve the emphasis context in punctuation
|
||||
import re
|
||||
|
||||
# Handle strong emphasis - add period for dramatic pause
|
||||
ssml_text = re.sub(r'<emphasis level="strong">(.*?)</emphasis>',
|
||||
r'[\1](+2)', ssml_text)
|
||||
|
||||
# Handle moderate emphasis - add comma for slight pause
|
||||
ssml_text = re.sub(r'<emphasis>(.*?)</emphasis>',
|
||||
r'[\1](+1)', ssml_text)
|
||||
|
||||
# Remove remaining SSML tags
|
||||
clean = re.sub(r'<[^>]+>', '', ssml_text)
|
||||
|
||||
# Clean up spacing
|
||||
clean = re.sub(r'\.\s*\.\s*\.', '...', clean) # ... patterns
|
||||
clean = re.sub(r'\s+', ' ', clean).strip() # Extra spaces
|
||||
|
||||
return clean
|
||||
|
||||
def _preprocess_for_tts(self, raw_text: str) -> str:
|
||||
"""Your existing preprocessing logic"""
|
||||
|
||||
doc = self.nlp(raw_text)
|
||||
|
||||
enhanced_paragraphs = []
|
||||
for sentence in doc.sents:
|
||||
|
||||
# 1. Clean up common abbreviations
|
||||
text = sentence.text.strip()
|
||||
replacements = {
|
||||
"Dr.": "Doctor",
|
||||
"Mr.": "Mister",
|
||||
"Mrs.": "Misses"
|
||||
}
|
||||
|
||||
for abbrev, full_form in replacements.items():
|
||||
text = text.replace(abbrev, full_form)
|
||||
|
||||
# 2. Add natural pauses for better speech rhythm
|
||||
if "but" in text.lower() or "," in text:
|
||||
# Pause before conjunctions/arguments
|
||||
if "but" in text.lower():
|
||||
text = re.sub(r'\s+but\s+', ' ..., ... but ', text, flags=re.IGNORECASE)
|
||||
elif "," in text:
|
||||
# Natural pause at commas for complex phrases
|
||||
if "which" in text.lower():
|
||||
parts = text.split(",")
|
||||
if len(parts) >= 2:
|
||||
comma_enhanced = parts[0] + ",..." + "".join(parts[1:])
|
||||
text = comma_enhanced
|
||||
|
||||
enhanced_paragraphs.append(text)
|
||||
|
||||
return ". ".join(enhanced_paragraphs)
|
||||
|
||||
|
||||
# Usage example for your specific literary text
|
||||
if __name__ == "__main__":
|
||||
# Your literary passage
|
||||
literary_text = """White sparks cascaded onto the trembling wick. It was as if there were shooting stars in his hands, like the stars at the bottom of the grave to which Silk and Hyacinth had driven Orpine's body in a dream he recalled with uncanny clarity. Here we dig holes in the ground for our dead, he thought, to bring them nearer the Outsider; and on Blue we do the same because we did it here, though it takes them away from him."""
|
||||
literary_text_2 = """I had never seen war, or even talked of it at length with someone who had, but I was young and knew something of violence, and so believed that war would be no more than a new experience for me, as other things—the possession of authority in Thrax, say, or my escape from the House Absolute—had been new experiences. War is not a new experience; it is a new world. Its inhabitants are more different from human beings than [Famulimus](fa'mu'lie'mus) and her friends. Its laws are new, and even its geography is new, because it is a geography in which insignificant hills and hollows are lifted to the importance of cities. Just as our familiar Urth holds such monstrosities as Erebus, [Abaia](Ah-by-ya), and [Arioch](Ari-och), so the world of war is stalked by the monsters called battles, whose cells are individuals but who have a life and intelligence of their own, and whom one approaches through an ever-thickening array of portents."""
|
||||
literary_text_3 = """The executions I have seen performed and have performed myself so often are no more than a trade, a butchery of human beings who are for the most part less innocent and less valuable than cattle."""
|
||||
|
||||
# Initialize enhanced pipeline
|
||||
tts_pipeline = EnhancedTTSPipeline()
|
||||
|
||||
# Run the complete enhanced TTS pipeline
|
||||
tts_pipeline.run_enhanced_pipeline(literary_text, output_prefix="literary_passage")
|
||||
tts_pipeline.run_enhanced_pipeline(literary_text_2, output_prefix="literary_passage_2")
|
||||
tts_pipeline.run_enhanced_pipeline(literary_text_3, output_prefix="literary_passage_3")
|
||||
104
experimental/users/acmcarther/llm/tts/spacy_demo.py
Normal file
104
experimental/users/acmcarther/llm/tts/spacy_demo.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""
|
||||
SpaCy Feature Demonstration for TTS Preparation
|
||||
Shows specific NLP features and their TTS applications
|
||||
"""
|
||||
|
||||
import spacy
|
||||
|
||||
def demonstrate_spacy_features(text):
|
||||
"""Show different SpaCy features for TTS preparation"""
|
||||
|
||||
# Load the English model
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
doc = nlp(text)
|
||||
|
||||
print(f"📝 **Original Text**: {text}\n")
|
||||
|
||||
# 1. Tokenization - break text into manageable chunks
|
||||
print("🔤 **Token Analysis**")
|
||||
for i, token in enumerate(doc[:10]): # First 10 tokens
|
||||
print(f" {token.text:<12} | POS: {token.pos_:<6} | Lemma: {token.lemma_}")
|
||||
print()
|
||||
|
||||
# 2. Named Entity Recognition - identify important information
|
||||
print("🏷️ **Named Entities**")
|
||||
entities = [(ent.text, ent.label_, spacy.explain(ent.label_)) for ent in doc.ents]
|
||||
if entities:
|
||||
for text, label, description in entities[:5]: # First 5 entities
|
||||
print(f" {text:<15} | {label}: {description}")
|
||||
else:
|
||||
print(" No named entities found")
|
||||
print()
|
||||
|
||||
# 3. Part-of-Speech tagging - understand word types for pronunciation
|
||||
print("📊 **Part of Speech Tags**")
|
||||
pos_summary = {}
|
||||
for token in doc[:15]: # First 15 tokens
|
||||
tag_name = spacy.explain(token.pos_) or token.pos_
|
||||
pos_summary[tag_name] = pos_summary.get(tag_name, 0) + 1
|
||||
|
||||
for tag, count in list(pos_summary.items())[:5]:
|
||||
print(f" {tag}: {count} words")
|
||||
print()
|
||||
|
||||
# 4. Sentence boundary detection - natural pause points
|
||||
print("📄 **Sentence Structure**")
|
||||
for i, sent in enumerate(doc.sents):
|
||||
if len(sent.text.strip()) > 10: # Skip very short sentences
|
||||
print(f" Sentence {i+1}: {sent.text.strip()}")
|
||||
print()
|
||||
|
||||
def compare_preprocessing_approaches(text):
|
||||
"""Compare different TTS text preprocessing strategies"""
|
||||
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
doc = nlp(text)
|
||||
|
||||
# Strategy 1: Abbreviation expansion only
|
||||
abbreviations_expanded = text.replace("Dr.", "Doctor").replace("Mr.", "Mister")
|
||||
|
||||
# Strategy 2: Sentence-based with pause optimization
|
||||
processed_sentences = []
|
||||
for sent in doc.sents:
|
||||
# Add natural pauses at conjunction boundaries
|
||||
processed_text = "".join([
|
||||
token.text_with_ws if not (token.pos_ == "CCONJ" and len(str(sent).strip()) > 20)
|
||||
else f", ... {token.text_with_ws}"
|
||||
for token in sent
|
||||
])
|
||||
processed_sentences.append(processed_text)
|
||||
|
||||
sentence_aware = "".join(processed_sentences)
|
||||
|
||||
# Strategy 3: Full semantic enhancement
|
||||
semantic_enhanced = text
|
||||
doc_processed = nlp(text)
|
||||
|
||||
# Identify potential pronunciation issues
|
||||
tricky_words = ["colonel", "choir", "restaurant"]
|
||||
enhanced_text = text
|
||||
for word in tricky_words:
|
||||
if word.lower() in text.lower():
|
||||
enhanced_text = enhanced_text.replace(word, f"<proper_pronunciation>{word}</proper_pronunciation>")
|
||||
|
||||
return {
|
||||
"original": text,
|
||||
"abbreviations_only": abbreviations_expanded,
|
||||
"sentence_aware": sentence_aware,
|
||||
"semantic_enhanced": enhanced_text
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Demo text with various TTS challenges
|
||||
demo_text = "Dr. Smith, who is a colonel in the USA army, went to his favorite restaurant with Mr. Johnson."
|
||||
|
||||
# Show detailed spaCy features
|
||||
demonstrate_spacy_features(demo_text)
|
||||
|
||||
# Compare different preprocessing strategies
|
||||
print("🔄 **Preprocessing Strategy Comparison**\n")
|
||||
results = compare_preprocessing_approaches(demo_text)
|
||||
|
||||
for approach, processed in results.items():
|
||||
print(f"**{approach.replace('_', ' ').title()}:**")
|
||||
print(f"{processed}\n")
|
||||
107
experimental/users/acmcarther/llm/tts/spacy_tts_pipeline.py
Normal file
107
experimental/users/acmcarther/llm/tts/spacy_tts_pipeline.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""
|
||||
SpaCy + Kokoro TTS Integration Pipeline
|
||||
Example demonstrating how to use SpaCy for text preprocessing before TTS conversion
|
||||
"""
|
||||
|
||||
import spacy
|
||||
from mlx_audio.tts.generate import generate_audio
|
||||
|
||||
def complete_tts_pipeline(raw_text):
|
||||
"""
|
||||
Complete pipeline: Raw text → SpaCy processing → Enhanced TTS-ready text
|
||||
"""
|
||||
|
||||
# Load spaCy model (assuming en_core_web_sm is installed)
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
# Process text with spaCy
|
||||
doc = nlp(raw_text)
|
||||
|
||||
enhanced_sentences = []
|
||||
|
||||
for sentence in doc.sents:
|
||||
# Process each sentence individually
|
||||
processed_sentence = process_single_sentence(sentence.text_with_ws, nlp)
|
||||
enhanced_sentences.append(processed_sentence)
|
||||
|
||||
# Join sentences with appropriate pauses
|
||||
final_text = ". ".join(enhanced_sentences)
|
||||
|
||||
return final_text
|
||||
|
||||
def process_single_sentence(sentence, nlp):
|
||||
"""Process a single sentence for TTS optimization"""
|
||||
|
||||
# 1. Clean up common abbreviations
|
||||
replacements = {
|
||||
"Dr.": "Doctor",
|
||||
"Mr.": "Mister",
|
||||
"Mrs.": "Misses",
|
||||
"vs": "versus",
|
||||
"&": "and"
|
||||
}
|
||||
|
||||
cleaned = sentence
|
||||
for abbrev, full_form in replacements.items():
|
||||
cleaned = cleaned.replace(abbrev, full_form)
|
||||
|
||||
# 2. Add natural pauses for better speech rhythm
|
||||
doc = nlp(cleaned)
|
||||
|
||||
# Identify clause boundaries (relative pronouns, conjunctions)
|
||||
pause_indicators = ["but", "however", "although", "while", "when", "where"]
|
||||
final_text = cleaned
|
||||
|
||||
for token in doc:
|
||||
if token.text.lower() in pause_indicators and token.pos_ == "CCONJ":
|
||||
# Add slight pause before conjunctions
|
||||
final_text = final_text.replace(f" {token.text} ", f",... {token.text} ")
|
||||
|
||||
return final_text
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
# Sample raw text that might be challenging for TTS
|
||||
sample_text = """
|
||||
Dr. Smith went to the store but forgot his wallet.
|
||||
The cat, which was sleeping on the mat, suddenly woke up.
|
||||
He vs his friend decided to go shopping & eat dinner.
|
||||
"""
|
||||
|
||||
processed = complete_tts_pipeline(sample_text)
|
||||
print("Original:")
|
||||
print(sample_text)
|
||||
print("\nProcessed for TTS:")
|
||||
print(processed)
|
||||
|
||||
generate_audio(
|
||||
text=(sample_text),
|
||||
model_path="prince-canuma/Kokoro-82M",
|
||||
#voice="af_heart",
|
||||
voice="am_santa",
|
||||
#voice="am_echo",
|
||||
speed=1.2,
|
||||
lang_code="a", # Kokoro: (a)f_heart, or comment out for auto
|
||||
file_prefix="original",
|
||||
audio_format="wav",
|
||||
sample_rate=24000,
|
||||
join_audio=True,
|
||||
verbose=True # Set to False to disable print messages
|
||||
)
|
||||
print("Original audio generated")
|
||||
|
||||
generate_audio(
|
||||
text=(processed),
|
||||
model_path="prince-canuma/Kokoro-82M",
|
||||
#voice="af_heart",
|
||||
voice="am_santa",
|
||||
#voice="am_echo",
|
||||
speed=1.2,
|
||||
lang_code="a", # Kokoro: (a)f_heart, or comment out for auto
|
||||
file_prefix="processed",
|
||||
audio_format="wav",
|
||||
sample_rate=24000,
|
||||
join_audio=True,
|
||||
verbose=True # Set to False to disable print messages
|
||||
)
|
||||
print("Original audio generated")
|
||||
141
experimental/users/acmcarther/llm/tts_grpc/BUILD.bazel
Normal file
141
experimental/users/acmcarther/llm/tts_grpc/BUILD.bazel
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_binary", "py_library", "py_pex_binary", "py_unpacked_wheel")
|
||||
load("@rules_go//go:def.bzl", "go_binary", "go_library")
|
||||
|
||||
# gazelle:proto disable
|
||||
load("@build_stack_rules_proto//rules:proto_compile.bzl", "proto_compile")
|
||||
load("@build_stack_rules_proto//rules/py:grpc_py_library.bzl", "grpc_py_library")
|
||||
load("@build_stack_rules_proto//rules/go:proto_go_library.bzl", "proto_go_library")
|
||||
load("@build_stack_rules_proto//rules/py:proto_py_library.bzl", "proto_py_library")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
load("@rules_proto//proto:defs.bzl", "proto_library")
|
||||
|
||||
proto_library(
|
||||
name = "tts_proto",
|
||||
srcs = ["tts.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
proto_compile(
|
||||
name = "tts_python_compile",
|
||||
outputs = [
|
||||
"tts_pb2.py",
|
||||
"tts_pb2.pyi",
|
||||
"tts_pb2_grpc.py",
|
||||
],
|
||||
plugins = [
|
||||
"@build_stack_rules_proto//plugin/builtin:pyi",
|
||||
"@build_stack_rules_proto//plugin/builtin:python",
|
||||
"@build_stack_rules_proto//plugin/grpc/grpc:protoc-gen-grpc-python",
|
||||
],
|
||||
proto = ":tts_proto",
|
||||
)
|
||||
|
||||
proto_py_library(
|
||||
name = "tts_proto_py_lib",
|
||||
srcs = ["tts_pb2.py"],
|
||||
deps = ["@com_google_protobuf//:protobuf_python"],
|
||||
)
|
||||
|
||||
grpc_py_library(
|
||||
name = "tts_grpc_py_library",
|
||||
srcs = ["tts_pb2_grpc.py"],
|
||||
deps = [
|
||||
":tts_py_library",
|
||||
"@pip_third_party//grpcio:pkg",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "tts_server_main",
|
||||
srcs = ["tts_server_main.py"],
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
deps = [
|
||||
":tts_grpc_py_library",
|
||||
":tts_proto_py_lib",
|
||||
"//third_party/python/spacy:en_core_web_sm",
|
||||
requirement("grpcio"),
|
||||
requirement("spacy"),
|
||||
requirement("asyncio"),
|
||||
requirement("mlx-audio"),
|
||||
# Transitive (mlx-audio tts)
|
||||
requirement("soundfile"),
|
||||
requirement("sounddevice"),
|
||||
requirement("kokoro"),
|
||||
requirement("num2words"),
|
||||
requirement("misaki"),
|
||||
requirement("espeakng-loader"),
|
||||
requirement("phonemizer-fork"),
|
||||
requirement("spacy-curated-transformers"),
|
||||
requirement("scipy"),
|
||||
],
|
||||
)
|
||||
|
||||
py_pex_binary(
|
||||
name = "tts_server",
|
||||
binary = ":tts_server_main",
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "tts_client_main",
|
||||
srcs = ["tts_client_main.py"],
|
||||
target_compatible_with = ["@platforms//os:macos"],
|
||||
deps = [
|
||||
":tts_grpc_py_library",
|
||||
":tts_proto_py_lib",
|
||||
requirement("asyncio"),
|
||||
requirement("sounddevice"),
|
||||
requirement("soundfile"),
|
||||
requirement("grpcio"),
|
||||
requirement("absl-py"),
|
||||
requirement("PyObjC"),
|
||||
],
|
||||
)
|
||||
|
||||
proto_compile(
|
||||
name = "tts_go_compile",
|
||||
output_mappings = [
|
||||
"tts.pb.go=forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/tts_grpc/tts.pb.go",
|
||||
"tts_grpc.pb.go=forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/tts_grpc/tts_grpc.pb.go",
|
||||
],
|
||||
outputs = [
|
||||
"tts.pb.go",
|
||||
"tts_grpc.pb.go",
|
||||
],
|
||||
plugins = [
|
||||
"@build_stack_rules_proto//plugin/golang/protobuf:protoc-gen-go",
|
||||
"@build_stack_rules_proto//plugin/grpc/grpc-go:protoc-gen-go-grpc",
|
||||
],
|
||||
proto = ":tts_proto",
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "tts_go_proto",
|
||||
srcs = [":tts_go_compile"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/tts_grpc",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//codes",
|
||||
"@org_golang_google_grpc//status",
|
||||
"@org_golang_google_protobuf//reflect/protoreflect",
|
||||
"@org_golang_google_protobuf//runtime/protoimpl",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "tts_client_go",
|
||||
srcs = ["main.go"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tts_go_proto",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//credentials/insecure",
|
||||
],
|
||||
)
|
||||
|
||||
proto_py_library(
|
||||
name = "tts_py_library",
|
||||
srcs = ["tts_pb2.py"],
|
||||
deps = ["@com_google_protobuf//:protobuf_python"],
|
||||
)
|
||||
198
experimental/users/acmcarther/llm/tts_grpc/main.go
Normal file
198
experimental/users/acmcarther/llm/tts_grpc/main.go
Normal file
|
|
@ -0,0 +1,198 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
pb "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/tts_grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||
inputText = flag.String("text", "Hello: This is a test of the TTS Client", "Text to convert to speech")
|
||||
voiceModel = flag.String("voice", "bm_george", "Voice model to use")
|
||||
speakingRate = flag.Float64("rate", 1.0, "Speaking rate")
|
||||
useBytes = flag.Bool("use_bytes", false, "Use the bytes API (GenerateTTS) instead of local file API")
|
||||
stream = flag.Bool("stream", true, "Use the streaming API (GenerateTTSStream)")
|
||||
play = flag.Bool("play", true, "Play the generated audio using afplay")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
// Set up a connection to the server.
|
||||
conn, err := grpc.Dial(*addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewTTSServiceClient(conn)
|
||||
|
||||
// Contact the server and print out its response.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
log.Printf("Sending request: text=%q, voice=%q, rate=%f", *inputText, *voiceModel, *speakingRate)
|
||||
|
||||
var audioFile string
|
||||
|
||||
if *stream {
|
||||
audioFile, err = handleStream(ctx, c)
|
||||
if err != nil {
|
||||
log.Fatalf("streaming error: %v", err)
|
||||
}
|
||||
} else if *useBytes {
|
||||
resp, err := c.GenerateTTS(ctx, &pb.GenerateTTSRequest{
|
||||
InputText: inputText,
|
||||
VoiceModel: voiceModel,
|
||||
SpeakingRate: getFloat32Pointer(*speakingRate),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("could not generate TTS (bytes): %v", err)
|
||||
}
|
||||
log.Printf("TTS bytes received. Sample rate: %d, Size: %d bytes", resp.GetAudioSampleRate(), len(resp.GetAudioContent()))
|
||||
|
||||
audioFile, err = saveWav(resp.GetAudioContent(), int(resp.GetAudioSampleRate()))
|
||||
if err != nil {
|
||||
log.Fatalf("could not save WAV file: %v", err)
|
||||
}
|
||||
log.Printf("Saved audio to: %s", audioFile)
|
||||
|
||||
} else {
|
||||
r, err := c.GenerateTTSLocalFile(ctx, &pb.GenerateTTSLocalFileRequest{
|
||||
InputText: inputText,
|
||||
VoiceModel: voiceModel,
|
||||
SpeakingRate: getFloat32Pointer(*speakingRate),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("could not generate TTS (local file): %v", err)
|
||||
}
|
||||
audioFile = r.GetLocalTtsFilePath()
|
||||
log.Printf("TTS generated successfully. File path: %s", audioFile)
|
||||
}
|
||||
|
||||
if *play {
|
||||
log.Printf("Playing audio file: %s", audioFile)
|
||||
if err := playAudio(audioFile); err != nil {
|
||||
log.Printf("Failed to play audio: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleStream(ctx context.Context, c pb.TTSServiceClient) (string, error) {
|
||||
stream, err := c.GenerateTTSStream(ctx, &pb.GenerateTTSRequest{
|
||||
InputText: inputText,
|
||||
VoiceModel: voiceModel,
|
||||
SpeakingRate: getFloat32Pointer(*speakingRate),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("calling GenerateTTSStream: %w", err)
|
||||
}
|
||||
|
||||
var totalBytes []byte
|
||||
var sampleRate int32 = -1
|
||||
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("receiving stream: %w", err)
|
||||
}
|
||||
|
||||
if sampleRate == -1 {
|
||||
sampleRate = resp.GetAudioSampleRate()
|
||||
} else if sampleRate != resp.GetAudioSampleRate() {
|
||||
log.Printf("Warning: Sample rate changed mid-stream from %d to %d", sampleRate, resp.GetAudioSampleRate())
|
||||
}
|
||||
|
||||
chunk := resp.GetAudioContent()
|
||||
log.Printf("Received chunk: size=%d", len(chunk))
|
||||
totalBytes = append(totalBytes, chunk...)
|
||||
}
|
||||
|
||||
if sampleRate == -1 {
|
||||
return "", fmt.Errorf("received no data from stream")
|
||||
}
|
||||
|
||||
log.Printf("Stream finished. Total size: %d bytes, Sample rate: %d", len(totalBytes), sampleRate)
|
||||
return saveWav(totalBytes, int(sampleRate))
|
||||
}
|
||||
|
||||
func getFloat32Pointer(v float64) *float32 {
|
||||
f := float32(v)
|
||||
return &f
|
||||
}
|
||||
|
||||
func playAudio(filePath string) error {
|
||||
cmd := exec.Command("afplay", filePath)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// saveWav writes the raw float32 audio data to a WAV file with a temporary name.
|
||||
func saveWav(data []byte, sampleRate int) (string, error) {
|
||||
// Create a temporary file
|
||||
f, err := os.CreateTemp("", "tts_output_*.wav")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating temp file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// 1 channel, 32-bit float
|
||||
numChannels := 1
|
||||
bitsPerSample := 32
|
||||
byteRate := sampleRate * numChannels * (bitsPerSample / 8)
|
||||
blockAlign := numChannels * (bitsPerSample / 8)
|
||||
audioFormat := 3 // IEEE Float
|
||||
|
||||
// Total data size
|
||||
dataSize := uint32(len(data))
|
||||
// File size - 8 bytes
|
||||
fileSize := 36 + dataSize
|
||||
|
||||
// Write WAV Header
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// RIFF header
|
||||
buf.WriteString("RIFF")
|
||||
binary.Write(buf, binary.LittleEndian, fileSize)
|
||||
buf.WriteString("WAVE")
|
||||
|
||||
// fmt subchunk
|
||||
buf.WriteString("fmt ")
|
||||
binary.Write(buf, binary.LittleEndian, uint32(16)) // Subchunk1Size
|
||||
binary.Write(buf, binary.LittleEndian, uint16(audioFormat))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(numChannels))
|
||||
binary.Write(buf, binary.LittleEndian, uint32(sampleRate))
|
||||
binary.Write(buf, binary.LittleEndian, uint32(byteRate))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(blockAlign))
|
||||
binary.Write(buf, binary.LittleEndian, uint16(bitsPerSample))
|
||||
|
||||
// data subchunk
|
||||
buf.WriteString("data")
|
||||
binary.Write(buf, binary.LittleEndian, dataSize)
|
||||
|
||||
// Write header to file
|
||||
if _, err := f.Write(buf.Bytes()); err != nil {
|
||||
return "", fmt.Errorf("writing header: %w", err)
|
||||
}
|
||||
|
||||
// Write audio data
|
||||
if _, err := f.Write(data); err != nil {
|
||||
return "", fmt.Errorf("writing data: %w", err)
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
33
experimental/users/acmcarther/llm/tts_grpc/tts.proto
Normal file
33
experimental/users/acmcarther/llm/tts_grpc/tts.proto
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package experimental.users.acmcarther.llm.tts_grpc;
|
||||
|
||||
option go_package = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/llm/tts_grpc";
|
||||
|
||||
// The request message for generating a TTS local file.
|
||||
message GenerateTTSLocalFileRequest {
|
||||
optional string input_text = 1;
|
||||
optional string voice_model = 2;
|
||||
optional float speaking_rate = 3;
|
||||
}
|
||||
|
||||
// The response message containing the TTS local file path and "enhanced" text.
|
||||
message GenerateTTSLocalFileResponse {
|
||||
optional string local_tts_file_path = 1;
|
||||
}
|
||||
|
||||
message GenerateTTSRequest {
|
||||
optional string input_text = 1;
|
||||
optional string voice_model = 2;
|
||||
optional float speaking_rate = 3;
|
||||
}
|
||||
|
||||
message GenerateTTSResponse {
|
||||
optional bytes audio_content = 1;
|
||||
optional int32 audio_sample_rate = 2;
|
||||
}
|
||||
|
||||
// TTSService defines a gRPC service for Text-to-Speech generation.
|
||||
service TTSService {
|
||||
rpc GenerateTTSLocalFile(GenerateTTSLocalFileRequest) returns (GenerateTTSLocalFileResponse) {}
|
||||
rpc GenerateTTS(GenerateTTSRequest) returns (GenerateTTSResponse) {}
|
||||
rpc GenerateTTSStream(GenerateTTSRequest) returns (stream GenerateTTSResponse) {}
|
||||
}
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
import asyncio
|
||||
from experimental.users.acmcarther.llm.tts_grpc import tts_pb2_grpc, tts_pb2
|
||||
import grpc
|
||||
from absl import app, flags, logging
|
||||
import soundfile # type: ignore
|
||||
import sounddevice # type: ignore
|
||||
import tempfile
|
||||
import os
|
||||
import io
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("text", None, "Input text to convert to speech.")
|
||||
flags.DEFINE_string("voice_model", "bm_george", "Voice model to use for TTS.")
|
||||
flags.DEFINE_float("speaking_rate", 1.0, "Speaking rate for TTS.")
|
||||
flags.DEFINE_bool("use_fileless_api", False, "Whether to use the fileless API.")
|
||||
|
||||
|
||||
def play_sound_from_bytes(audio_bytes: bytes, sample_rate: int):
|
||||
"""Play sound from bytes using sounddevice."""
|
||||
wav_buf = io.BytesIO(audio_bytes)
|
||||
wav_buf.name = "input.RAW"
|
||||
data, sample_rate = soundfile.read(
|
||||
wav_buf, channels=1, samplerate=sample_rate, subtype="FLOAT"
|
||||
)
|
||||
sounddevice.play(data, samplerate=sample_rate)
|
||||
sounddevice.wait()
|
||||
|
||||
|
||||
def play_sound_from_file(file_path: str):
|
||||
"""Play sound from file using sounddevice."""
|
||||
data, sample_rate = soundfile.read(file_path)
|
||||
sounddevice.play(data, samplerate=sample_rate)
|
||||
sounddevice.wait()
|
||||
|
||||
|
||||
async def run_tts_client():
|
||||
|
||||
async with grpc.aio.insecure_channel("localhost:50051") as channel:
|
||||
stub = tts_pb2_grpc.TTSServiceStub(channel)
|
||||
if FLAGS.use_fileless_api:
|
||||
response = await stub.GenerateTTS(
|
||||
tts_pb2.GenerateTTSRequest(
|
||||
input_text=FLAGS.text,
|
||||
voice_model=FLAGS.voice_model,
|
||||
speaking_rate=FLAGS.speaking_rate,
|
||||
)
|
||||
)
|
||||
# Play the audio
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: play_sound_from_bytes(
|
||||
response.audio_content, response.audio_sample_rate
|
||||
),
|
||||
)
|
||||
else:
|
||||
response = await stub.GenerateTTSLocalFile(
|
||||
tts_pb2.GenerateTTSLocalFileRequest(
|
||||
input_text=FLAGS.text,
|
||||
voice_model=FLAGS.voice_model,
|
||||
speaking_rate=FLAGS.speaking_rate,
|
||||
)
|
||||
)
|
||||
print(
|
||||
"TTS client received local file path: " + response.local_tts_file_path
|
||||
)
|
||||
# Play the audio
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None, lambda: play_sound_from_file(response.local_tts_file_path)
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv # Unused
|
||||
if not FLAGS.text:
|
||||
logging.error("The --text flag is required.")
|
||||
return
|
||||
asyncio.run(run_tts_client())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
150
experimental/users/acmcarther/llm/tts_grpc/tts_server_main.py
Normal file
150
experimental/users/acmcarther/llm/tts_grpc/tts_server_main.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from experimental.users.acmcarther.llm.tts_grpc import tts_pb2_grpc, tts_pb2
|
||||
from mlx_audio.tts.generate import generate_audio
|
||||
from mlx_audio.tts.utils import load_model
|
||||
from typing import List, Dict, Tuple, AsyncGenerator
|
||||
import asyncio
|
||||
import grpc
|
||||
import re
|
||||
import spacy
|
||||
import uuid
|
||||
import soundfile # type: ignore
|
||||
import mlx.nn as nn
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineParams:
|
||||
text: str
|
||||
voice: str | None = None
|
||||
speaking_rate: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineResult:
|
||||
audio_bytes: mx.array
|
||||
audio_sample_rate: int
|
||||
|
||||
|
||||
class TTSPipeline:
|
||||
nlp: spacy.Language
|
||||
tts_model: nn.Module
|
||||
|
||||
def __init__(self):
|
||||
# TODO: acmcarther@ - Perform some post-processing on text during pipeline execution.
|
||||
self.nlp = spacy.load("en_core_web_sm")
|
||||
self.tts_model = load_model("prince-canuma/Kokoro-82M")
|
||||
|
||||
async def run_pipeline(self, params: PipelineParams) -> PipelineResult:
|
||||
audio_list = []
|
||||
sample_rate = 24000 # Default fallback
|
||||
async for audio_chunk, sr in self.run_pipeline_stream(params):
|
||||
audio_list.append(audio_chunk)
|
||||
sample_rate = sr
|
||||
|
||||
if not audio_list:
|
||||
return PipelineResult(
|
||||
audio_bytes=mx.array([]), audio_sample_rate=sample_rate
|
||||
)
|
||||
|
||||
audio = mx.concatenate(audio_list, axis=0)
|
||||
return PipelineResult(audio_bytes=audio, audio_sample_rate=sample_rate)
|
||||
|
||||
async def run_pipeline_stream(
|
||||
self, params: PipelineParams
|
||||
) -> AsyncGenerator[Tuple[mx.array, int], None]:
|
||||
voice = params.voice or "bm_george"
|
||||
|
||||
# Note: self.tts_model.generate is a synchronous generator.
|
||||
# We wrap it or iterate it. Since MLX ops might be heavy, running in a thread might be better,
|
||||
# but for now we keep it simple.
|
||||
results = self.tts_model.generate(
|
||||
text=params.text,
|
||||
voice=voice,
|
||||
speed=params.speaking_rate or 1.0,
|
||||
lang_code=voice[0], # "am_santa" -> "a"
|
||||
audio_format="wav",
|
||||
sample=24000,
|
||||
)
|
||||
|
||||
for result in results:
|
||||
# Yield control to allow asyncio loop to run other tasks if needed,
|
||||
# though this loop itself is sync.
|
||||
# In a real heavy server, we'd run generation in an executor.
|
||||
yield result.audio, self.tts_model.sample_rate
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
class TTSService(tts_pb2_grpc.TTSServiceServicer):
|
||||
def __init__(self):
|
||||
self.tts_pipeline = TTSPipeline()
|
||||
|
||||
async def GenerateTTSLocalFile(self, request, context):
|
||||
params = PipelineParams(
|
||||
text=request.input_text,
|
||||
voice=request.voice_model,
|
||||
speaking_rate=request.speaking_rate,
|
||||
)
|
||||
result = await self.tts_pipeline.run_pipeline(params)
|
||||
|
||||
output_prefix = "/tmp/tts_output_" + str(uuid.uuid4())
|
||||
output_full_path = output_prefix + ".wav"
|
||||
soundfile.write(output_full_path, result.audio_bytes, result.audio_sample_rate)
|
||||
return tts_pb2.GenerateTTSLocalFileResponse(
|
||||
local_tts_file_path=output_full_path
|
||||
)
|
||||
|
||||
async def GenerateTTS(self, request, context):
|
||||
params = PipelineParams(
|
||||
text=request.input_text,
|
||||
voice=request.voice_model,
|
||||
speaking_rate=request.speaking_rate,
|
||||
)
|
||||
result = await self.tts_pipeline.run_pipeline(params)
|
||||
|
||||
return tts_pb2.GenerateTTSResponse(
|
||||
audio_content=memoryview(result.audio_bytes).tobytes(),
|
||||
audio_sample_rate=result.audio_sample_rate,
|
||||
)
|
||||
|
||||
async def GenerateTTSStream(self, request, context):
|
||||
params = PipelineParams(
|
||||
text=request.input_text,
|
||||
voice=request.voice_model,
|
||||
speaking_rate=request.speaking_rate,
|
||||
)
|
||||
|
||||
# 1MB Chunk size to safely stay under gRPC 4MB limit
|
||||
CHUNK_SIZE = 1024 * 1024
|
||||
|
||||
async for audio_chunk, sample_rate in self.tts_pipeline.run_pipeline_stream(
|
||||
params
|
||||
):
|
||||
# Convert mlx array to bytes
|
||||
audio_bytes = memoryview(audio_chunk).tobytes()
|
||||
|
||||
# Split into smaller chunks if necessary
|
||||
for i in range(0, len(audio_bytes), CHUNK_SIZE):
|
||||
chunk_data = audio_bytes[i : i + CHUNK_SIZE]
|
||||
yield tts_pb2.GenerateTTSResponse(
|
||||
audio_content=chunk_data, audio_sample_rate=sample_rate
|
||||
)
|
||||
|
||||
|
||||
async def serve():
|
||||
port = 50051
|
||||
server = grpc.aio.server()
|
||||
tts_pb2_grpc.add_TTSServiceServicer_to_server(TTSService(), server)
|
||||
server.add_insecure_port(f"[::]:{port}")
|
||||
await server.start()
|
||||
print(f"gRPC server is running on port {port}...")
|
||||
await server.wait_for_termination()
|
||||
|
||||
|
||||
def main():
|
||||
asyncio.run(serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
36
experimental/users/acmcarther/temporal/BUILD.bazel
Normal file
36
experimental/users/acmcarther/temporal/BUILD.bazel
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
load("@gazelle//:def.bzl", "gazelle")
|
||||
load("@rules_go//go:def.bzl", "go_binary", "go_library")
|
||||
|
||||
gazelle(
|
||||
name = "gazelle",
|
||||
)
|
||||
|
||||
go_library(
|
||||
name = "helloworld",
|
||||
srcs = ["helloworld.go"],
|
||||
importpath = "forgejo.csbx.dev/acmcarther/yesod/experimental/users/acmcarther/temporal",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@io_temporal_go_sdk//activity",
|
||||
"@io_temporal_go_sdk//workflow",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "starter_main",
|
||||
srcs = ["starter_main.go"],
|
||||
deps = [
|
||||
":helloworld",
|
||||
"@io_temporal_go_sdk//client",
|
||||
],
|
||||
)
|
||||
|
||||
go_binary(
|
||||
name = "worker_main",
|
||||
srcs = ["worker_main.go"],
|
||||
deps = [
|
||||
":helloworld",
|
||||
"@io_temporal_go_sdk//client",
|
||||
"@io_temporal_go_sdk//worker",
|
||||
],
|
||||
)
|
||||
121
experimental/users/acmcarther/temporal/git_workflow/BUILD.bazel
Normal file
121
experimental/users/acmcarther/temporal/git_workflow/BUILD.bazel
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
load("@aspect_rules_py//py:defs.bzl", "py_image_layer")
|
||||
load("@pip_third_party//:requirements.bzl", "requirement")
|
||||
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_push")
|
||||
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
|
||||
|
||||
py_library(
|
||||
name = "briefing",
|
||||
srcs = ["briefing.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "workspace",
|
||||
srcs = ["workspace.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [requirement("requests")],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "activities",
|
||||
srcs = ["activities.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":briefing",
|
||||
":workspace",
|
||||
requirement("temporalio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "workflow",
|
||||
srcs = ["workflow.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":activities",
|
||||
":briefing",
|
||||
requirement("temporalio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "worker",
|
||||
srcs = ["worker.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":activities",
|
||||
":workflow",
|
||||
requirement("temporalio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "trigger",
|
||||
srcs = ["trigger.py"],
|
||||
deps = [
|
||||
":briefing",
|
||||
":workflow",
|
||||
requirement("temporalio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "workspace_test",
|
||||
srcs = ["workspace_test.py"],
|
||||
deps = [":workspace"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "activities_test",
|
||||
srcs = ["activities_test.py"],
|
||||
deps = [
|
||||
":activities",
|
||||
":briefing",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "workflow_test",
|
||||
srcs = ["workflow_test.py"],
|
||||
deps = [
|
||||
":briefing",
|
||||
":workflow",
|
||||
requirement("pytest"),
|
||||
requirement("pytest-mock"),
|
||||
requirement("temporalio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "e2e_test",
|
||||
srcs = ["e2e_test.py"],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
requirement("requests"),
|
||||
requirement("temporalio"),
|
||||
],
|
||||
)
|
||||
|
||||
py_image_layer(
|
||||
name = "worker_binary_layer",
|
||||
binary = ":worker",
|
||||
)
|
||||
|
||||
oci_image(
|
||||
name = "image",
|
||||
base = "//k8s/container/coder-dev-base-image:noble",
|
||||
entrypoint = [
|
||||
"python3",
|
||||
"/experimental/users/acmcarther/temporal/git_workflow/worker",
|
||||
],
|
||||
tars = [
|
||||
":worker_binary_layer",
|
||||
],
|
||||
)
|
||||
|
||||
oci_push(
|
||||
name = "push",
|
||||
image = ":image",
|
||||
remote_tags = ["latest"],
|
||||
repository = "forgejo.csbx.dev/acmcarther/temporal-worker-image",
|
||||
)
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from temporalio import activity
|
||||
from experimental.users.acmcarther.temporal.git_workflow.workspace import Workspace
|
||||
from experimental.users.acmcarther.temporal.git_workflow.briefing import Briefing
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
@activity.defn(name="provision_workspace_activity")
|
||||
async def provision_workspace(briefing: Briefing) -> str:
|
||||
"""Provisions a workspace by cloning a repo and creating a branch."""
|
||||
if not briefing.repo_url or not briefing.branch_name:
|
||||
raise ValueError("repo_url and branch_name must be set in the briefing.")
|
||||
|
||||
workspace = Workspace(repo_url=briefing.repo_url, branch_name=briefing.branch_name)
|
||||
workspace_path = workspace.provision()
|
||||
return workspace_path
|
||||
|
||||
from typing import Dict
|
||||
|
||||
@activity.defn(name="apply_changes_in_workspace_activity")
|
||||
async def apply_changes_in_workspace(briefing: Briefing, files_to_create: Dict[str, str]):
|
||||
"""Applies changes to the workspace."""
|
||||
if not briefing.workspace_path or not briefing.branch_name:
|
||||
raise ValueError("workspace_path and branch_name must be set in the briefing.")
|
||||
|
||||
workspace = Workspace(branch_name=briefing.branch_name, path=Path(briefing.workspace_path))
|
||||
workspace.apply_changes(files_to_create)
|
||||
|
||||
@activity.defn(name="commit_and_push_changes_activity")
|
||||
async def commit_and_push_changes(briefing: Briefing, commit_message: str):
|
||||
"""Commits and pushes changes."""
|
||||
if not briefing.workspace_path or not briefing.branch_name:
|
||||
raise ValueError("workspace_path and branch_name must be set in the briefing.")
|
||||
|
||||
workspace = Workspace(branch_name=briefing.branch_name, path=Path(briefing.workspace_path))
|
||||
workspace.commit_and_push(commit_message)
|
||||
|
||||
@activity.defn(name="run_tests_activity")
|
||||
async def run_tests(briefing: Briefing):
|
||||
"""Runs tests in the workspace."""
|
||||
if not briefing.workspace_path or not briefing.branch_name or not briefing.tests_to_run:
|
||||
raise ValueError("workspace_path, branch_name, and tests_to_run must be set in the briefing.")
|
||||
|
||||
workspace = Workspace(branch_name=briefing.branch_name, path=Path(briefing.workspace_path))
|
||||
# For now, we assume a single test command. This could be extended to support multiple commands.
|
||||
workspace.run_tests(briefing.tests_to_run[0])
|
||||
|
||||
@activity.defn(name="create_pull_request_activity")
|
||||
async def create_pull_request(briefing: Briefing):
|
||||
"""Creates a pull request."""
|
||||
if not briefing.repo_url or not briefing.branch_name or not briefing.pr_title or not briefing.pr_body or not briefing.forgejo_token:
|
||||
raise ValueError("repo_url, branch_name, pr_title, pr_body, and forgejo_token must be set in the briefing.")
|
||||
|
||||
workspace = Workspace(repo_url=briefing.repo_url, branch_name=briefing.branch_name)
|
||||
return workspace.create_pull_request(briefing.pr_title, briefing.pr_body, briefing.forgejo_token)
|
||||
|
||||
@activity.defn(name="merge_pull_request_activity")
|
||||
async def merge_pull_request(briefing: Briefing, pr_number: int):
|
||||
"""Merges a pull request."""
|
||||
if not briefing.repo_url or not briefing.forgejo_token:
|
||||
raise ValueError("repo_url and forgejo_token must be set in the briefing.")
|
||||
|
||||
# We don't need the branch_name for this activity.
|
||||
workspace = Workspace(repo_url=briefing.repo_url, branch_name="dummy_branch")
|
||||
return workspace.merge_pull_request(pr_number, briefing.forgejo_token)
|
||||
|
||||
@activity.defn(name="cleanup_workspace_activity")
|
||||
async def cleanup_workspace(briefing: Briefing):
|
||||
"""Cleans up the workspace."""
|
||||
if not briefing.workspace_path or not briefing.branch_name:
|
||||
raise ValueError("workspace_path and branch_name must be set in the briefing.")
|
||||
|
||||
workspace = Workspace(branch_name=briefing.branch_name, path=Path(briefing.workspace_path))
|
||||
workspace.cleanup_workspace()
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
from experimental.users.acmcarther.temporal.git_workflow.briefing import Briefing
|
||||
from experimental.users.acmcarther.temporal.git_workflow.activities import (
|
||||
provision_workspace,
|
||||
apply_changes_in_workspace,
|
||||
commit_and_push_changes,
|
||||
run_tests,
|
||||
create_pull_request,
|
||||
merge_pull_request,
|
||||
cleanup_workspace,
|
||||
)
|
||||
|
||||
class ActivitiesTest(unittest.TestCase):
|
||||
|
||||
def test_provision_workspace_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
repo_url="https://test.com/repo.git",
|
||||
branch_name="feature/test",
|
||||
task_description="", pr_title="", pr_body="", forgejo_token="", tests_to_run=[], files_to_create={}
|
||||
)
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_instance.provision.return_value = "/workspace/feature_test"
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
result = asyncio.run(provision_workspace(briefing))
|
||||
|
||||
# Assert
|
||||
mock_workspace_class.assert_called_once_with(repo_url="https://test.com/repo.git", branch_name="feature/test")
|
||||
mock_workspace_instance.provision.assert_called_once()
|
||||
self.assertEqual(result, "/workspace/feature_test")
|
||||
|
||||
def test_apply_changes_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
workspace_path="/workspace/feature_test",
|
||||
branch_name="feature/test",
|
||||
files_to_create={"hello.txt": "world"},
|
||||
repo_url="", task_description="", pr_title="", pr_body="", forgejo_token="", tests_to_run=[]
|
||||
)
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
asyncio.run(apply_changes_in_workspace(briefing, briefing.files_to_create))
|
||||
|
||||
# Assert
|
||||
mock_workspace_instance.apply_changes.assert_called_once_with({"hello.txt": "world"})
|
||||
|
||||
def test_commit_and_push_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
workspace_path="/workspace/feature_test",
|
||||
branch_name="feature/test",
|
||||
repo_url="", task_description="", pr_title="", pr_body="", forgejo_token="", tests_to_run=[], files_to_create={}
|
||||
)
|
||||
commit_message = "A test commit"
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
asyncio.run(commit_and_push_changes(briefing, commit_message))
|
||||
|
||||
# Assert
|
||||
mock_workspace_instance.commit_and_push.assert_called_once_with(commit_message)
|
||||
|
||||
def test_run_tests_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
workspace_path="/workspace/feature_test",
|
||||
branch_name="feature/test",
|
||||
tests_to_run=["pytest ."],
|
||||
repo_url="", task_description="", pr_title="", pr_body="", forgejo_token="", files_to_create={}
|
||||
)
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
asyncio.run(run_tests(briefing))
|
||||
|
||||
# Assert
|
||||
mock_workspace_instance.run_tests.assert_called_once_with("pytest .")
|
||||
|
||||
def test_create_pull_request_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
repo_url="https://test.com/repo.git",
|
||||
branch_name="feature/test",
|
||||
pr_title="Test PR",
|
||||
pr_body="This is a test.",
|
||||
forgejo_token="fake-token",
|
||||
task_description="", tests_to_run=[], files_to_create={}
|
||||
)
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_instance.create_pull_request.return_value = {"html_url": "https://test.com/repo/pulls/1"}
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
result = asyncio.run(create_pull_request(briefing))
|
||||
|
||||
# Assert
|
||||
mock_workspace_instance.create_pull_request.assert_called_once_with("Test PR", "This is a test.", "fake-token")
|
||||
self.assertEqual(result, {"html_url": "https://test.com/repo/pulls/1"})
|
||||
|
||||
def test_merge_pull_request_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
repo_url="https://test.com/repo.git",
|
||||
forgejo_token="fake-token",
|
||||
branch_name="", task_description="", pr_title="", pr_body="", tests_to_run=[], files_to_create={}
|
||||
)
|
||||
pr_number = 123
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_instance.merge_pull_request.return_value = {"merged": True}
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
result = asyncio.run(merge_pull_request(briefing, pr_number))
|
||||
|
||||
# Assert
|
||||
mock_workspace_instance.merge_pull_request.assert_called_once_with(123, "fake-token")
|
||||
self.assertEqual(result, {"merged": True})
|
||||
|
||||
def test_cleanup_workspace_activity(self):
|
||||
# Arrange
|
||||
briefing = Briefing(
|
||||
workspace_path="/workspace/feature_test",
|
||||
branch_name="feature/test",
|
||||
repo_url="", task_description="", pr_title="", pr_body="", forgejo_token="", tests_to_run=[], files_to_create={}
|
||||
)
|
||||
|
||||
with patch("experimental.users.acmcarther.temporal.git_workflow.activities.Workspace") as mock_workspace_class:
|
||||
mock_workspace_instance = MagicMock()
|
||||
mock_workspace_class.return_value = mock_workspace_instance
|
||||
|
||||
# Act
|
||||
asyncio.run(cleanup_workspace(briefing))
|
||||
|
||||
# Assert
|
||||
mock_workspace_instance.cleanup_workspace.assert_called_once()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
|
||||
@dataclass
|
||||
class Changeset:
|
||||
commit_message: str
|
||||
files_to_create: Dict[str, str]
|
||||
|
||||
@dataclass
|
||||
class Briefing:
|
||||
"""
|
||||
A dataclass to hold the state of the git workflow.
|
||||
"""
|
||||
repo_url: str
|
||||
branch_name: str
|
||||
task_description: str
|
||||
pr_title: str
|
||||
pr_body: str
|
||||
forgejo_token: str
|
||||
tests_to_run: list[str]
|
||||
# A list of changesets to be applied and committed.
|
||||
changesets: List[Changeset] = field(default_factory=list)
|
||||
# DEPRECATED: Use changesets instead.
|
||||
files_to_create: dict[str, str] = field(default_factory=dict)
|
||||
workspace_path: Optional[str] = None
|
||||
163
experimental/users/acmcarther/temporal/git_workflow/e2e_test.py
Normal file
163
experimental/users/acmcarther/temporal/git_workflow/e2e_test.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
import requests
|
||||
import time
|
||||
|
||||
from temporalio.client import Client
|
||||
|
||||
# This is a placeholder for a more robust end-to-end test.
|
||||
# A full implementation would require a running Temporal worker and a way to
|
||||
# interact with the Forgejo API to verify the results of the workflow.
|
||||
|
||||
FORGEJO_URL = "https://forgejo.csbx.dev"
|
||||
REPO_OWNER = "gemini-thinker"
|
||||
API_URL = f"{FORGEJO_URL}/api/v1"
|
||||
LOG_FILE = "/home/coder/yesod/logs/e2e_test.log"
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
def get_forgejo_token():
|
||||
logging.info("Attempting to get Forgejo token for gemini-thinker...")
|
||||
# The e2e test runs as the 'gemini-thinker' agent to have the correct permissions.
|
||||
token_path = os.path.expanduser("/home/coder/yesod/ai/agents/gemini-thinker/.forgejo_token")
|
||||
if not os.path.exists(token_path):
|
||||
logging.error(f"Forgejo token not found at {token_path}")
|
||||
raise FileNotFoundError(f"Forgejo token not found at {token_path}")
|
||||
with open(token_path, "r") as f:
|
||||
token = f.read().strip()
|
||||
logging.info("Successfully retrieved Forgejo token.")
|
||||
return token
|
||||
|
||||
def create_test_repo(repo_name):
|
||||
"""Creates a new repository in Forgejo for testing."""
|
||||
logging.info(f"Creating test repository: {repo_name}")
|
||||
headers = {
|
||||
"Authorization": f"token {get_forgejo_token()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"name": repo_name,
|
||||
"private": False,
|
||||
"auto_init": True,
|
||||
"default_branch": "main",
|
||||
}
|
||||
url = f"{API_URL}/user/repos"
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
logging.info("Test repository created.")
|
||||
return response.json()
|
||||
|
||||
def delete_test_repo(repo_name):
|
||||
"""Deletes a repository from Forgejo."""
|
||||
logging.info(f"Deleting test repository: {repo_name}")
|
||||
headers = {"Authorization": f"token {get_forgejo_token()}"}
|
||||
url = f"{API_URL}/repos/{REPO_OWNER}/{repo_name}"
|
||||
response = requests.delete(url, headers=headers)
|
||||
if response.status_code != 204:
|
||||
logging.warning(f"Failed to delete repository {repo_name}: {response.text}")
|
||||
else:
|
||||
logging.info("Test repository deleted.")
|
||||
|
||||
async def test_git_workflow_end_to_end():
|
||||
"""
|
||||
Tests the full end-to-end GitWorkflow.
|
||||
"""
|
||||
repo_name = f"test-repo-e2e-{uuid.uuid4()}"
|
||||
logging.info(f"--- Starting E2E Git Workflow Test for repo: {repo_name} ---")
|
||||
|
||||
try:
|
||||
create_test_repo(repo_name)
|
||||
|
||||
# 1. Connect to Temporal
|
||||
logging.info("STEP 1: Connecting to Temporal...")
|
||||
client = await Client.connect("temporal-frontend.temporal.svc.cluster.local:7233", namespace="temporal-system")
|
||||
logging.info("STEP 1 COMPLETE: Connected to Temporal.")
|
||||
|
||||
# 2. Set up workflow parameters
|
||||
token = get_forgejo_token()
|
||||
remote_url_with_token = f"https://{REPO_OWNER}:{token}@{FORGEJO_URL.split('//')[1]}/{REPO_OWNER}/{repo_name}.git"
|
||||
branch_name = f"feature/test-e2e-{uuid.uuid4()}"
|
||||
pr_title = f"E2E Test PR {branch_name}"
|
||||
pr_body = "This is an end-to-end test PR."
|
||||
workflow_id = f"git-workflow-e2e-{uuid.uuid4()}"
|
||||
changesets = [
|
||||
{
|
||||
"commit_message": "Add new_file.txt",
|
||||
"files_to_create": {"new_file.txt": "Hello, World!"},
|
||||
},
|
||||
{
|
||||
"commit_message": "Add another_file.txt",
|
||||
"files_to_create": {"another_file.txt": "Hello, again!"},
|
||||
},
|
||||
]
|
||||
|
||||
briefing = {
|
||||
"task_description": "E2E Test",
|
||||
"repo_url": remote_url_with_token,
|
||||
"branch_name": branch_name,
|
||||
"pr_title": pr_title,
|
||||
"pr_body": pr_body,
|
||||
"forgejo_token": token,
|
||||
"changesets": changesets,
|
||||
"tests_to_run": [],
|
||||
}
|
||||
|
||||
# 3. Start the workflow
|
||||
logging.info(f"STEP 2: Starting workflow with ID: {workflow_id}")
|
||||
handle = await client.start_workflow(
|
||||
"GitWorkflow",
|
||||
briefing,
|
||||
id=workflow_id,
|
||||
task_queue="git-workflow-queue",
|
||||
)
|
||||
logging.info(f"STEP 2 COMPLETE: Workflow '{handle.id}' started.")
|
||||
|
||||
# 4. Wait for the workflow to be ready for approval (this is an approximation)
|
||||
logging.info("STEP 3: Waiting for workflow to create PR...")
|
||||
time.sleep(15) # Give it some time to create the PR
|
||||
|
||||
# 5. Send the approval signal
|
||||
logging.info(f"STEP 4: Sending 'approve' signal to workflow {handle.id}...")
|
||||
await handle.signal("approve")
|
||||
logging.info("STEP 4 COMPLETE: Signal sent.")
|
||||
|
||||
# 6. Wait for the workflow result
|
||||
logging.info("STEP 5: Waiting for workflow to complete...")
|
||||
result = await handle.result()
|
||||
logging.info(f"STEP 5 COMPLETE: Workflow finished with result: {result}")
|
||||
|
||||
# 7. Verify the files were created
|
||||
logging.info("STEP 6: Verifying file creation in the repository...")
|
||||
|
||||
# Verify first file
|
||||
file1_content_url = f"{FORGEJO_URL}/{REPO_OWNER}/{repo_name}/raw/branch/main/new_file.txt"
|
||||
response1 = requests.get(file1_content_url, timeout=5)
|
||||
response1.raise_for_status()
|
||||
assert response1.text == "Hello, World!"
|
||||
|
||||
# Verify second file
|
||||
file2_content_url = f"{FORGEJO_URL}/{REPO_OWNER}/{repo_name}/raw/branch/main/another_file.txt"
|
||||
response2 = requests.get(file2_content_url, timeout=5)
|
||||
response2.raise_for_status()
|
||||
assert response2.text == "Hello, again!"
|
||||
|
||||
logging.info("STEP 6 COMPLETE: Files verified.")
|
||||
|
||||
logging.info("--- E2E Test Completed Successfully ---")
|
||||
|
||||
finally:
|
||||
delete_test_repo(repo_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_git_workflow_end_to_end())
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
import json
|
||||
|
||||
from temporalio.client import Client
|
||||
|
||||
from experimental.users.acmcarther.temporal.git_workflow.briefing import Briefing, Changeset
|
||||
from experimental.users.acmcarther.temporal.git_workflow.workflow import GitWorkflow
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Connects to Temporal and starts the GitWorkflow."""
|
||||
parser = argparse.ArgumentParser(description="Trigger the GitWorkflow.")
|
||||
parser.add_argument("--task-description", required=True)
|
||||
parser.add_argument("--repo-url", required=True)
|
||||
parser.add_argument("--branch-name", required=True)
|
||||
parser.add_argument("--pr-title", required=True)
|
||||
parser.add_argument("--pr-body", required=True)
|
||||
parser.add_argument("--forgejo-token", required=True)
|
||||
parser.add_argument("--tests-to-run", nargs='*', default=[])
|
||||
parser.add_argument("--workflow-id", default=f"git-workflow-{uuid.uuid4()}")
|
||||
parser.add_argument("--changesets", type=str, default='[]', help='A JSON string of a list of changesets.')
|
||||
parser.add_argument("--signal-feedback", action="store_true", help="Send a feedback signal instead of starting a workflow.")
|
||||
parser.add_argument("--feedback-changeset", type=str, help="JSON string for the feedback changeset.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
client = await Client.connect("temporal-frontend.temporal.svc.cluster.local:7233", namespace="temporal-system")
|
||||
|
||||
if args.signal_feedback:
|
||||
if not args.workflow_id or not args.feedback_changeset:
|
||||
logging.error("--workflow-id and --feedback-changeset are required for signaling feedback.")
|
||||
return
|
||||
try:
|
||||
changeset_data = json.loads(args.feedback_changeset)
|
||||
changeset = Changeset(**changeset_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logging.error("Invalid JSON format for --feedback-changeset.")
|
||||
return
|
||||
|
||||
handle = client.get_workflow_handle(args.workflow_id)
|
||||
logging.info(f"Sending feedback to workflow {args.workflow_id}...")
|
||||
await handle.signal("incorporate_feedback", changeset)
|
||||
logging.info("Feedback signal sent.")
|
||||
return
|
||||
|
||||
logging.info(f"Workflow ID: {args.workflow_id}")
|
||||
|
||||
try:
|
||||
changesets_data = json.loads(args.changesets)
|
||||
changesets = [Changeset(**cs) for cs in changesets_data]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logging.error("Invalid JSON format for --changesets.")
|
||||
return
|
||||
|
||||
briefing = Briefing(
|
||||
task_description=args.task_description,
|
||||
repo_url=args.repo_url,
|
||||
branch_name=args.branch_name,
|
||||
pr_title=args.pr_title,
|
||||
pr_body=args.pr_body,
|
||||
forgejo_token=args.forgejo_token,
|
||||
tests_to_run=args.tests_to_run,
|
||||
changesets=changesets,
|
||||
)
|
||||
|
||||
workflow_id = args.workflow_id
|
||||
task_queue = "git-workflow-queue"
|
||||
|
||||
logging.info(f"Starting workflow {workflow_id} on task queue {task_queue}...")
|
||||
handle = await client.start_workflow(
|
||||
GitWorkflow.run,
|
||||
briefing,
|
||||
id=workflow_id,
|
||||
task_queue=task_queue,
|
||||
)
|
||||
logging.info(f"Workflow started with ID: {handle.id}")
|
||||
# Note: This script now only starts the workflow.
|
||||
# You can use other tools or scripts to signal it.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
|
||||
# Import the workflow and activities
|
||||
from experimental.users.acmcarther.temporal.git_workflow.workflow import GitWorkflow
|
||||
from experimental.users.acmcarther.temporal.git_workflow.activities import (
|
||||
provision_workspace,
|
||||
apply_changes_in_workspace,
|
||||
commit_and_push_changes,
|
||||
run_tests,
|
||||
create_pull_request,
|
||||
merge_pull_request,
|
||||
cleanup_workspace,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Creates and runs a Temporal worker for the GitWorkflow."""
|
||||
logging.info("Connecting to Temporal server...")
|
||||
client = await Client.connect(
|
||||
"temporal-frontend.temporal.svc.cluster.local:7233",
|
||||
namespace="temporal-system"
|
||||
)
|
||||
logging.info("Successfully connected to Temporal server.")
|
||||
|
||||
task_queue = "git-workflow-queue"
|
||||
logging.info(f"Creating worker for task queue: '{task_queue}'")
|
||||
|
||||
worker = Worker(
|
||||
client,
|
||||
task_queue=task_queue,
|
||||
workflows=[GitWorkflow],
|
||||
activities=[
|
||||
provision_workspace,
|
||||
apply_changes_in_workspace,
|
||||
commit_and_push_changes,
|
||||
run_tests,
|
||||
create_pull_request,
|
||||
merge_pull_request,
|
||||
cleanup_workspace,
|
||||
],
|
||||
)
|
||||
|
||||
logging.info(f"Worker created. Starting run...")
|
||||
await worker.run()
|
||||
logging.info("Worker stopped.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("Worker shutting down.")
|
||||
169
experimental/users/acmcarther/temporal/git_workflow/workflow.py
Normal file
169
experimental/users/acmcarther/temporal/git_workflow/workflow.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
from datetime import timedelta
|
||||
from temporalio import workflow
|
||||
from temporalio.exceptions import ActivityError
|
||||
import asyncio
|
||||
|
||||
# Import the activities
|
||||
from experimental.users.acmcarther.temporal.git_workflow.activities import (
|
||||
provision_workspace,
|
||||
apply_changes_in_workspace,
|
||||
commit_and_push_changes,
|
||||
run_tests,
|
||||
create_pull_request,
|
||||
merge_pull_request,
|
||||
cleanup_workspace,
|
||||
)
|
||||
from experimental.users.acmcarther.temporal.git_workflow.briefing import Briefing, Changeset
|
||||
from typing import Optional
|
||||
|
||||
@workflow.defn
|
||||
class GitWorkflow:
|
||||
def __init__(self):
|
||||
self._approved = False
|
||||
self._retry_tests_flag = False
|
||||
self._feedback_received = False
|
||||
self._feedback_changeset: Optional[Changeset] = None
|
||||
self._ready_for_approval = False
|
||||
|
||||
@workflow.signal
|
||||
def approve(self):
|
||||
self._approved = True
|
||||
|
||||
@workflow.signal
|
||||
def retry_tests(self):
|
||||
self._retry_tests_flag = True
|
||||
|
||||
@workflow.signal
|
||||
def incorporate_feedback(self, changeset: Changeset):
|
||||
self._feedback_received = True
|
||||
self._feedback_changeset = changeset
|
||||
|
||||
@workflow.query
|
||||
def is_ready_for_approval(self) -> bool:
|
||||
return self._ready_for_approval
|
||||
|
||||
@workflow.run
|
||||
async def run(self, briefing: Briefing) -> str:
|
||||
"""Executes the git workflow."""
|
||||
|
||||
try:
|
||||
# 1. Provision the workspace
|
||||
workspace_path = await workflow.execute_activity(
|
||||
provision_workspace,
|
||||
briefing,
|
||||
start_to_close_timeout=timedelta(minutes=15),
|
||||
)
|
||||
briefing.workspace_path = workspace_path
|
||||
|
||||
# 2. Apply changes and commit for each changeset
|
||||
if briefing.changesets:
|
||||
for changeset in briefing.changesets:
|
||||
await workflow.execute_activity(
|
||||
apply_changes_in_workspace,
|
||||
args=[briefing, changeset.files_to_create],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
await workflow.execute_activity(
|
||||
commit_and_push_changes,
|
||||
args=[briefing, changeset.commit_message],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
# DEPRECATED: Handle old-style files_to_create
|
||||
elif briefing.files_to_create:
|
||||
await workflow.execute_activity(
|
||||
apply_changes_in_workspace,
|
||||
args=[briefing, briefing.files_to_create],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
commit_message = f"Apply changes for: {briefing.task_description}"
|
||||
await workflow.execute_activity(
|
||||
commit_and_push_changes,
|
||||
args=[briefing, commit_message],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
|
||||
# 4. Run tests with retry logic
|
||||
while True:
|
||||
try:
|
||||
await workflow.execute_activity(
|
||||
run_tests,
|
||||
briefing,
|
||||
start_to_close_timeout=timedelta(minutes=10),
|
||||
)
|
||||
break # Tests passed, exit the loop
|
||||
except ActivityError as e:
|
||||
workflow.logger.warning(f"Tests failed: {e}. Waiting for retry signal.")
|
||||
await workflow.wait_condition(lambda: self._retry_tests_flag)
|
||||
self._retry_tests_flag = False # Reset for next potential failure
|
||||
|
||||
# 5. Create a pull request
|
||||
pr = await workflow.execute_activity(
|
||||
create_pull_request,
|
||||
briefing,
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
pr_number = pr.get("number")
|
||||
pr_url = pr.get("html_url")
|
||||
|
||||
self._ready_for_approval = True
|
||||
|
||||
# 6. Wait for feedback or approval
|
||||
while True:
|
||||
# Wait for either signal
|
||||
await workflow.wait_condition(
|
||||
lambda: self._approved or self._feedback_received
|
||||
)
|
||||
|
||||
if self._feedback_received and self._feedback_changeset:
|
||||
changeset = self._feedback_changeset
|
||||
self._feedback_received = False # Reset for the next feedback cycle
|
||||
self._feedback_changeset = None
|
||||
|
||||
# Apply and commit the feedback
|
||||
await workflow.execute_activity(
|
||||
apply_changes_in_workspace,
|
||||
args=[briefing, changeset.files_to_create],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
await workflow.execute_activity(
|
||||
commit_and_push_changes,
|
||||
args=[briefing, changeset.commit_message],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
|
||||
# Re-run tests after applying feedback
|
||||
while True:
|
||||
try:
|
||||
await workflow.execute_activity(
|
||||
run_tests,
|
||||
briefing,
|
||||
start_to_close_timeout=timedelta(minutes=10),
|
||||
)
|
||||
break # Tests passed
|
||||
except ActivityError as e:
|
||||
workflow.logger.warning(f"Tests failed after applying feedback: {e}. Waiting for retry signal.")
|
||||
await workflow.wait_condition(lambda: self._retry_tests_flag)
|
||||
self._retry_tests_flag = False
|
||||
|
||||
# Continue the loop to wait for more feedback or approval
|
||||
continue
|
||||
|
||||
if self._approved:
|
||||
break # Exit the loop to merge
|
||||
|
||||
# 7. Merge the pull request
|
||||
await workflow.execute_activity(
|
||||
merge_pull_request,
|
||||
args=[briefing, pr_number],
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
|
||||
return f"Pull request merged: {pr_url}"
|
||||
finally:
|
||||
# 8. Clean up the workspace
|
||||
if briefing.workspace_path:
|
||||
await workflow.execute_activity(
|
||||
cleanup_workspace,
|
||||
briefing,
|
||||
start_to_close_timeout=timedelta(minutes=2),
|
||||
)
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
from temporalio.testing import WorkflowEnvironment
|
||||
from temporalio.worker import Worker
|
||||
from temporalio.exceptions import ActivityError
|
||||
|
||||
from experimental.users.acmcarther.temporal.git_workflow.workflow import GitWorkflow
|
||||
from experimental.users.acmcarther.temporal.git_workflow.briefing import Briefing, Changeset
|
||||
from experimental.users.acmcarther.temporal.git_workflow.activities import (
|
||||
provision_workspace,
|
||||
apply_changes_in_workspace,
|
||||
commit_and_push_changes,
|
||||
run_tests,
|
||||
create_pull_request,
|
||||
merge_pull_request,
|
||||
cleanup_workspace,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_workflow_success_path_single_commit(mocker):
|
||||
"""Tests the successful path of the GitWorkflow with a single commit."""
|
||||
|
||||
briefing = Briefing(
|
||||
repo_url="https://test.com/repo.git",
|
||||
branch_name="feature/test",
|
||||
task_description="Test the workflow",
|
||||
pr_title="Test PR",
|
||||
pr_body="This is a test.",
|
||||
forgejo_token="fake-token",
|
||||
tests_to_run=["pytest ."],
|
||||
files_to_create={"test.txt": "hello"},
|
||||
)
|
||||
|
||||
activities = [
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.provision_workspace", return_value="/mock/workspace/path"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.apply_changes_in_workspace"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.commit_and_push_changes"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.run_tests"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.create_pull_request", return_value={"number": 123, "html_url": "https://test.com/repo/pulls/1"}),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.merge_pull_request", return_value={"merged": True}),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.cleanup_workspace"),
|
||||
]
|
||||
|
||||
async with await WorkflowEnvironment.start_time_skipping() as env:
|
||||
async with Worker(
|
||||
env.client,
|
||||
task_queue="test-git-workflow",
|
||||
workflows=[GitWorkflow],
|
||||
activities=activities,
|
||||
):
|
||||
handle = await env.client.start_workflow(
|
||||
GitWorkflow.run,
|
||||
briefing,
|
||||
id="test-git-workflow-id",
|
||||
task_queue="test-git-workflow",
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await handle.signal("approve")
|
||||
result = await handle.result()
|
||||
|
||||
assert result == "Pull request merged: https://test.com/repo/pulls/1"
|
||||
activities[1].assert_called_once()
|
||||
activities[2].assert_called_once()
|
||||
activities[-1].assert_called_once() # Check that cleanup was called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_workflow_success_path_multiple_commits(mocker):
|
||||
"""Tests the successful path of the GitWorkflow with multiple commits."""
|
||||
|
||||
changesets = [
|
||||
Changeset(commit_message="First commit", files_to_create={"a.txt": "1"}),
|
||||
Changeset(commit_message="Second commit", files_to_create={"b.txt": "2"}),
|
||||
]
|
||||
briefing = Briefing(
|
||||
repo_url="https://test.com/repo.git",
|
||||
branch_name="feature/test",
|
||||
task_description="Test the workflow",
|
||||
pr_title="Test PR",
|
||||
pr_body="This is a test.",
|
||||
forgejo_token="fake-token",
|
||||
tests_to_run=["pytest ."],
|
||||
changesets=changesets,
|
||||
)
|
||||
|
||||
activities = [
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.provision_workspace", return_value="/mock/workspace/path"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.apply_changes_in_workspace"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.commit_and_push_changes"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.run_tests"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.create_pull_request", return_value={"number": 123, "html_url": "https://test.com/repo/pulls/1"}),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.merge_pull_request", return_value={"merged": True}),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.cleanup_workspace"),
|
||||
]
|
||||
|
||||
async with await WorkflowEnvironment.start_time_skipping() as env:
|
||||
async with Worker(
|
||||
env.client,
|
||||
task_queue="test-git-workflow",
|
||||
workflows=[GitWorkflow],
|
||||
activities=activities,
|
||||
):
|
||||
handle = await env.client.start_workflow(
|
||||
GitWorkflow.run,
|
||||
briefing,
|
||||
id="test-git-workflow-id",
|
||||
task_queue="test-git-workflow",
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await handle.signal("approve")
|
||||
result = await handle.result()
|
||||
|
||||
assert result == "Pull request merged: https://test.com/repo/pulls/1"
|
||||
assert activities[1].call_count == 2
|
||||
assert activities[2].call_count == 2
|
||||
activities[-1].assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_workflow_feedback_loop(mocker):
|
||||
"""Tests the feedback loop of the GitWorkflow."""
|
||||
|
||||
briefing = Briefing(
|
||||
repo_url="https://test.com/repo.git",
|
||||
branch_name="feature/test",
|
||||
task_description="Test the workflow",
|
||||
pr_title="Test PR",
|
||||
pr_body="This is a test.",
|
||||
forgejo_token="fake-token",
|
||||
tests_to_run=["pytest ."],
|
||||
files_to_create={"test.txt": "hello"},
|
||||
)
|
||||
|
||||
feedback_changeset = Changeset(
|
||||
commit_message="Incorporate feedback",
|
||||
files_to_create={"test.txt": "hello world"},
|
||||
)
|
||||
|
||||
run_tests_mock = mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.run_tests")
|
||||
|
||||
activities = [
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.provision_workspace", return_value="/mock/workspace/path"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.apply_changes_in_workspace"),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.commit_and_push_changes"),
|
||||
run_tests_mock,
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.create_pull_request", return_value={"number": 123, "html_url": "https://test.com/repo/pulls/1"}),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.merge_pull_request", return_value={"merged": True}),
|
||||
mocker.patch("experimental.users.acmcarther.temporal.git_workflow.activities.cleanup_workspace"),
|
||||
]
|
||||
|
||||
async with await WorkflowEnvironment.start_time_skipping() as env:
|
||||
async with Worker(
|
||||
env.client,
|
||||
task_queue="test-git-workflow",
|
||||
workflows=[GitWorkflow],
|
||||
activities=activities,
|
||||
):
|
||||
handle = await env.client.start_workflow(
|
||||
GitWorkflow.run,
|
||||
briefing,
|
||||
id="test-git-workflow-id",
|
||||
task_queue="test-git-workflow",
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
await handle.signal("incorporate_feedback", feedback_changeset)
|
||||
await asyncio.sleep(0.1)
|
||||
await handle.signal("approve")
|
||||
|
||||
result = await handle.result()
|
||||
|
||||
assert result == "Pull request merged: https://test.com/repo/pulls/1"
|
||||
# apply_changes and commit_and_push are called once initially, and once for feedback
|
||||
assert activities[1].call_count == 2
|
||||
assert activities[2].call_count == 2
|
||||
# run_tests is called once initially, and once after feedback
|
||||
assert run_tests_mock.call_count == 2
|
||||
activities[-1].assert_called_once()
|
||||
166
experimental/users/acmcarther/temporal/git_workflow/workspace.py
Normal file
166
experimental/users/acmcarther/temporal/git_workflow/workspace.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
import subprocess
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
import shutil
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
)
|
||||
|
||||
class Workspace:
|
||||
"""
|
||||
A class to encapsulate git operations for a workspace.
|
||||
"""
|
||||
|
||||
def __init__(self, branch_name: str, repo_url: Optional[str] = None, path: Optional[Path] = None):
|
||||
self.repo_url = repo_url
|
||||
self.branch_name = branch_name
|
||||
|
||||
if path:
|
||||
self.path = path
|
||||
else:
|
||||
workspace_base = "/workspace"
|
||||
# Sanitize branch name to be a valid directory name
|
||||
sanitized_branch_name = "".join(c if c.isalnum() else '_' for c in branch_name)
|
||||
self.path = Path(workspace_base) / sanitized_branch_name
|
||||
|
||||
def _run_command(self, *args, cwd=None):
|
||||
"""Runs a command and logs its output."""
|
||||
cwd = cwd or self.path
|
||||
logging.info(f"Running command: {' '.join(str(arg) for arg in args)}")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
args,
|
||||
cwd=cwd,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
logging.info(result.stdout)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"Command failed: {e.stderr}")
|
||||
raise
|
||||
|
||||
def provision(self):
|
||||
"""
|
||||
Clones a repository, creates a new branch, and pushes it to the remote.
|
||||
"""
|
||||
if not self.repo_url:
|
||||
raise ValueError("repo_url must be set to provision a workspace.")
|
||||
|
||||
logging.info(f"Provisioning workspace for branch '{self.branch_name}' from repo '{self.repo_url}'")
|
||||
|
||||
if self.path.exists():
|
||||
logging.info(f"Workspace path {self.path} already exists, removing it.")
|
||||
shutil.rmtree(self.path)
|
||||
|
||||
os.makedirs(self.path.parent, exist_ok=True)
|
||||
|
||||
# Clone the repository
|
||||
self._run_command(
|
||||
"git", "-c", "http.sslVerify=false", "clone", self.repo_url, str(self.path),
|
||||
cwd=self.path.parent
|
||||
)
|
||||
|
||||
# Create and check out the new branch
|
||||
self._run_command("git", "checkout", "-b", self.branch_name)
|
||||
|
||||
# Configure git user
|
||||
self._run_command("git", "config", "user.name", self.branch_name.split('/')[1])
|
||||
self._run_command("git", "config", "user.email", "gemini-prime@localhost")
|
||||
|
||||
# Push the new branch to the remote
|
||||
self._run_command("git", "-c", "http.sslVerify=false", "push", "-u", "origin", self.branch_name)
|
||||
|
||||
logging.info(f"Workspace created successfully at: {self.path}")
|
||||
return str(self.path)
|
||||
|
||||
def apply_changes(self, files_to_create: dict[str, str]):
|
||||
"""Creates or overwrites files in the workspace."""
|
||||
for filename, content in files_to_create.items():
|
||||
file_path = self.path / filename
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(file_path, "w") as f:
|
||||
f.write(content)
|
||||
logging.info(f"Created/updated file: {filename}")
|
||||
|
||||
def commit_and_push(self, commit_message: str):
|
||||
"""Commits all changes and pushes the branch to the remote."""
|
||||
self._run_command("git", "add", ".")
|
||||
self._run_command("git", "commit", "-m", commit_message)
|
||||
self._run_command("git", "-c", "http.sslVerify=false", "push", "origin", self.branch_name)
|
||||
logging.info("Committed and pushed changes.")
|
||||
|
||||
def run_tests(self, test_command: str):
|
||||
"""Runs a test command in the workspace."""
|
||||
logging.info(f"Running tests with command: {test_command}")
|
||||
# The command is expected to be a single string, so we split it.
|
||||
self._run_command(*test_command.split())
|
||||
logging.info("Tests passed.")
|
||||
|
||||
def create_pull_request(self, title: str, body: str, forgejo_token: str):
|
||||
"""Creates a pull request on Forgejo."""
|
||||
import requests
|
||||
if not self.repo_url:
|
||||
raise ValueError("repo_url must be set to create a pull request.")
|
||||
|
||||
# Extract owner and repo name from the repo_url
|
||||
# Example: https://forgejo.csbx.dev/gemini-thinker/test-repo.git
|
||||
parts = self.repo_url.split("/")
|
||||
owner = parts[-2]
|
||||
repo_name = parts[-1].replace(".git", "")
|
||||
|
||||
api_url = f"https://forgejo.csbx.dev/api/v1/repos/{owner}/{repo_name}/pulls"
|
||||
headers = {
|
||||
"Authorization": f"token {forgejo_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"head": self.branch_name,
|
||||
"base": "main", # Assuming 'main' is the default branch
|
||||
"title": title,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
logging.info(f"Creating pull request: {title}")
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
logging.info("Pull request created successfully.")
|
||||
return response.json()
|
||||
|
||||
def merge_pull_request(self, pr_number: int, forgejo_token: str):
|
||||
"""Merges a pull request on Forgejo."""
|
||||
import requests
|
||||
if not self.repo_url:
|
||||
raise ValueError("repo_url must be set to merge a pull request.")
|
||||
|
||||
parts = self.repo_url.split("/")
|
||||
owner = parts[-2]
|
||||
repo_name = parts[-1].replace(".git", "")
|
||||
|
||||
api_url = f"https://forgejo.csbx.dev/api/v1/repos/{owner}/{repo_name}/pulls/{pr_number}/merge"
|
||||
headers = {
|
||||
"Authorization": f"token {forgejo_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"Do": "merge",
|
||||
}
|
||||
|
||||
logging.info(f"Merging pull request #{pr_number}")
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
logging.info(f"Pull request #{pr_number} merged successfully.")
|
||||
return response.json()
|
||||
|
||||
def cleanup_workspace(self):
|
||||
"""Deletes the local workspace directory."""
|
||||
if self.path.exists():
|
||||
logging.info(f"Cleaning up workspace at {self.path}")
|
||||
shutil.rmtree(self.path)
|
||||
logging.info("Workspace cleaned up successfully.")
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
from experimental.users.acmcarther.temporal.git_workflow.workspace import Workspace
|
||||
|
||||
class WorkspaceTest(unittest.TestCase):
|
||||
|
||||
@patch("subprocess.run")
|
||||
@patch("os.makedirs")
|
||||
def test_provision_success(self, mock_makedirs, mock_subprocess_run):
|
||||
# Arrange
|
||||
repo_url = "https://forgejo.csbx.dev/gemini-thinker/test-repo.git"
|
||||
branch_name = "feature/new-thing"
|
||||
|
||||
# Mock the subprocess result
|
||||
mock_process = MagicMock()
|
||||
mock_process.returncode = 0
|
||||
mock_process.stdout = "Success"
|
||||
mock_process.stderr = ""
|
||||
mock_subprocess_run.return_value = mock_process
|
||||
|
||||
# Act
|
||||
workspace = Workspace(branch_name=branch_name, repo_url=repo_url)
|
||||
result_path = workspace.provision()
|
||||
|
||||
# Assert
|
||||
self.assertEqual(result_path, str(workspace.path))
|
||||
mock_makedirs.assert_called_once_with(workspace.path.parent, exist_ok=True)
|
||||
|
||||
self.assertEqual(mock_subprocess_run.call_count, 5)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "-c", "http.sslVerify=false", "clone", repo_url, str(workspace.path)),
|
||||
cwd=workspace.path.parent, check=True, capture_output=True, text=True
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "checkout", "-b", branch_name),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "config", "user.name", "new-thing"),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "config", "user.email", "gemini-prime@localhost"),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "-c", "http.sslVerify=false", "push", "-u", "origin", branch_name),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
|
||||
@patch("pathlib.Path.mkdir")
|
||||
@patch("builtins.open")
|
||||
def test_apply_changes(self, mock_open, mock_mkdir):
|
||||
# Arrange
|
||||
workspace = Workspace(branch_name="branch")
|
||||
files_to_create = {"test.txt": "hello"}
|
||||
|
||||
# Act
|
||||
workspace.apply_changes(files_to_create)
|
||||
|
||||
# Assert
|
||||
mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
|
||||
mock_open.assert_called_once_with(workspace.path / "test.txt", "w")
|
||||
mock_open.return_value.__enter__.return_value.write.assert_called_once_with("hello")
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_commit_and_push(self, mock_subprocess_run):
|
||||
# Arrange
|
||||
workspace = Workspace(branch_name="branch")
|
||||
commit_message = "Test commit"
|
||||
|
||||
# Act
|
||||
workspace.commit_and_push(commit_message)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(mock_subprocess_run.call_count, 3)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "add", "."),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "commit", "-m", commit_message),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
("git", "-c", "http.sslVerify=false", "push", "origin", "branch"),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_run_tests(self, mock_subprocess_run):
|
||||
# Arrange
|
||||
workspace = Workspace(branch_name="branch")
|
||||
test_command = "pytest ."
|
||||
|
||||
# Act
|
||||
workspace.run_tests(test_command)
|
||||
|
||||
# Assert
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
("pytest", "."),
|
||||
cwd=workspace.path, check=True, capture_output=True, text=True
|
||||
)
|
||||
|
||||
@patch("requests.post")
|
||||
def test_create_pull_request(self, mock_post):
|
||||
# Arrange
|
||||
repo_url = "https://forgejo.csbx.dev/gemini-thinker/test-repo.git"
|
||||
workspace = Workspace(branch_name="feature/new-thing", repo_url=repo_url)
|
||||
title = "Test PR"
|
||||
body = "This is a test."
|
||||
token = "fake-token"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {"id": 123}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = workspace.create_pull_request(title, body, token)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(result, {"id": 123})
|
||||
mock_post.assert_called_once()
|
||||
|
||||
@patch("requests.post")
|
||||
def test_merge_pull_request(self, mock_post):
|
||||
# Arrange
|
||||
repo_url = "https://forgejo.csbx.dev/gemini-thinker/test-repo.git"
|
||||
workspace = Workspace(branch_name="feature/new-thing", repo_url=repo_url)
|
||||
pr_number = 123
|
||||
token = "fake-token"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {"merged": True}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Act
|
||||
result = workspace.merge_pull_request(pr_number, token)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(result, {"merged": True})
|
||||
mock_post.assert_called_once()
|
||||
|
||||
@patch("shutil.rmtree")
|
||||
@patch("pathlib.Path.exists")
|
||||
def test_cleanup_workspace(self, mock_exists, mock_rmtree):
|
||||
# Arrange
|
||||
mock_exists.return_value = True
|
||||
workspace = Workspace(branch_name="branch")
|
||||
|
||||
# Act
|
||||
workspace.cleanup_workspace()
|
||||
|
||||
# Assert
|
||||
mock_exists.assert_called_once()
|
||||
mock_rmtree.assert_called_once_with(workspace.path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
37
experimental/users/acmcarther/temporal/helloworld.go
Normal file
37
experimental/users/acmcarther/temporal/helloworld.go
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
package helloworld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.temporal.io/sdk/activity"
|
||||
"go.temporal.io/sdk/workflow"
|
||||
)
|
||||
|
||||
// Workflow is a Hello World workflow definition.
|
||||
func Workflow(ctx workflow.Context, name string) (string, error) {
|
||||
ao := workflow.ActivityOptions{
|
||||
StartToCloseTimeout: 10 * time.Second,
|
||||
}
|
||||
ctx = workflow.WithActivityOptions(ctx, ao)
|
||||
|
||||
logger := workflow.GetLogger(ctx)
|
||||
logger.Info("HelloWorld workflow started", "name", name)
|
||||
|
||||
var result string
|
||||
err := workflow.ExecuteActivity(ctx, Activity, name).Get(ctx, &result)
|
||||
if err != nil {
|
||||
logger.Error("Activity failed.", "Error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
logger.Info("HelloWorld workflow completed.", "result", result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func Activity(ctx context.Context, name string) (string, error) {
|
||||
logger := activity.GetLogger(ctx)
|
||||
logger.Info("Activity", "name", name)
|
||||
return "Hello " + name + "!", nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue