Files
i2v/cmd/cli/cli.go

411 lines
11 KiB
Go

package main
import (
"bytes"
b64 "encoding/base64"
"encoding/json"
"fmt"
"image/draw"
"image"
"image/jpeg"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"os"
"path"
"strings"
"time"
"github.com/disintegration/imaging"
"github.com/urfave/cli/v2"
)
// main builds and runs the i2v CLI. The root action takes an image path
// (first argument, or read from stdin) and does the following:
//   - validates the --format flag and the file type (jpeg/png by extension);
//   - resizes/crops the image to an API-supported size when needed;
//   - starts an image-to-video job and waits for it to finish;
//   - writes the resulting video to --output.
//
// Any temp file created by resizing is removed on every exit path. If the
// app returns an error, it is printed to stderr and the process exits with
// status 1.
func main() {
	commands := []*cli.Command{
		getJobResultCmd(),
		resizeCmd(),
	}
	app := &cli.App{
		Name:        "i2v",
		Usage:       "A command line tool to convert images to videos, based on Stability.ai's Rest API.",
		Version:     "v1.0.0",
		Description: "i2v (Image to Video) is a command line tool to convert images to videos, based on Stability.ai's Rest API.",
		Commands:    commands,
		Flags: []cli.Flag{
			&cli.StringFlag{
				Name:  "output",
				Value: "./output.mp4",
				Usage: "path to output file (should be a mp4)",
				Aliases: []string{
					"o",
				},
			},
			&cli.StringFlag{
				Name:    "format",
				Aliases: []string{"f"},
				Usage:   "The format to resize the image to wide, tall, or square",
				Value:   "wide",
			},
		},
		Authors: []*cli.Author{
			{
				Name:  "Mason Payne",
				Email: "mason@masonitestudios.com",
			},
		},
		Copyright:              "2023 Masonite Studios LLC",
		UseShortOptionHandling: true,
		// err is a named result so the deferred temp-file cleanup below can
		// report its own failure without masking an earlier error.
		Action: func(c *cli.Context) (err error) {
			// use the first argument (or stdin) as the file name
			fileLocation, err := getFileLocation(c)
			if err != nil {
				return fmt.Errorf("error getting file location | %w", err)
			}
			/*
				Supported Dimensions:
				1024x576  (wide)
				576x1024  (tall)
				768x768   (square)
			*/
			format := c.String("format")
			if format != "wide" && format != "tall" && format != "square" {
				return fmt.Errorf("invalid format %s", format)
			}
			// make sure the image is either jpeg or png
			fileMimeType := mime.TypeByExtension(path.Ext(fileLocation))
			if fileMimeType != "image/jpeg" && fileMimeType != "image/png" {
				return fmt.Errorf("unsupported file type | %v", fileMimeType)
			}
			resizedLocation, err := resizeImage(fileLocation, format)
			if err != nil {
				return fmt.Errorf("error resizing image | %w", err)
			}
			if resizedLocation != fileLocation {
				fmt.Printf("Resized image to %s\n", resizedLocation)
				// The resized copy is a temp file owned by this run: remove it
				// on success and failure alike. A cleanup failure becomes the
				// returned error only if nothing else failed; otherwise it is
				// reported as a warning so the primary error is preserved.
				defer func() {
					if cleanupErr := cleanUpTempFile(resizedLocation); cleanupErr != nil {
						if err == nil {
							err = cleanupErr
						} else {
							fmt.Fprintln(os.Stderr, "warning:", cleanupErr)
						}
					}
				}()
				fileLocation = resizedLocation // use the resized image
			}
			id, err := initiateGeneratingAnimation(fileLocation)
			if err != nil {
				return fmt.Errorf("error making request | %w", err)
			}
			fmt.Println("Video is being rendered, this may take a while.")
			fmt.Printf("Job ID: %v\n", id)
			// wait for the job to finish and save the video
			if err = job(id, c.String("output")); err != nil {
				return fmt.Errorf("error getting job result | %w", err)
			}
			return nil
		},
	}
	if err := app.Run(os.Args); err != nil {
		fmt.Fprintf(os.Stderr, "error running app | %v\n", err)
		os.Exit(1)
	}
}
// cleanUpTempFile deletes the file at fileLocation. A failure from
// os.Remove is returned wrapped with "error removing temp file"; success
// returns nil.
func cleanUpTempFile(fileLocation string) error {
	if removeErr := os.Remove(fileLocation); removeErr != nil {
		return fmt.Errorf("error removing temp file | %w", removeErr)
	}
	return nil
}
func resizeImage(fileLocation string, format string) (string, error) {
// load the image
var img image.Image
img, err := imaging.Open(fileLocation, imaging.AutoOrientation(true))
if err != nil {
return "", fmt.Errorf("error opening image | %w", err)
}
// check if the image is already the correct size
rect := img.Bounds()
width := rect.Max.X - rect.Min.X
height := rect.Max.Y - rect.Min.Y
x1 := 1024
y1 := 576
if format == "wide" {
if width == 1024 && height == 576 {
fmt.Println("Image is already the correct size.")
return fileLocation, nil
}
}
if format == "tall" {
if width == 576 && height == 1024 {
fmt.Println("Image is already the correct size.")
return fileLocation, nil
}
x1 = 576
y1 = 1024
}
if format == "square" {
if width == 768 && height == 768 {
fmt.Println("Image is already the correct size.")
return fileLocation, nil
}
x1 = 768
y1 = 768
}
croppedImage := imaging.Fill(img, x1, y1, imaging.Center, imaging.Lanczos)
// save the image to a temp file
tempFile, err := os.CreateTemp("", "i2v*"+path.Ext(fileLocation))
if err != nil {
return "", fmt.Errorf("error creating temp file | %w", err)
}
defer tempFile.Close()
// encode the image to the temp file
err = jpeg.Encode(tempFile, croppedImage, nil)
if err != nil {
return "", fmt.Errorf("error encoding resized image | %w", err)
}
// return the temp file location
return tempFile.Name(), nil
}
// convertToRGBA returns a newly allocated *image.RGBA with the same bounds
// as img, with every pixel of img copied over (draw.Src) and converted to
// the RGBA color model.
func convertToRGBA(img image.Image) *image.RGBA {
	area := img.Bounds()
	out := image.NewRGBA(area)
	draw.Draw(out, out.Bounds(), img, area.Min, draw.Src)
	return out
}
// getFileLocation returns the path of the image to process. The first
// positional argument is used as-is when present. Otherwise the whole of
// stdin is read as the path, with surrounding whitespace removed. This
// drops the trailing newline that `echo photo.jpg | i2v` produces, which
// would otherwise break the extension lookup and the file open.
// An error is returned if stdin cannot be read or yields an empty path.
func getFileLocation(c *cli.Context) (string, error) {
	if fileLocation := c.Args().Get(0); fileLocation != "" {
		return fileLocation, nil
	}
	fmt.Println("No file provided, using stdin.")
	input, err := io.ReadAll(os.Stdin)
	if err != nil {
		return "", fmt.Errorf("error reading from stdin | %w", err)
	}
	fileLocation := strings.TrimSpace(string(input))
	if fileLocation == "" {
		return "", fmt.Errorf("no file location provided as an argument or on stdin")
	}
	return fileLocation, nil
}
// initiateGeneratingAnimation uploads the image at fileLocation to the
// Stability AI image-to-video endpoint and returns the ID of the started job.
//
// The request is a multipart form with fixed fields: seed=0, cfg_scale=2.5,
// motion_bucket_id=40, plus the image part. It is authenticated with the
// STABILITY_API_KEY environment variable.
//
// An error is returned in these cases:
//   - the key is unset;
//   - the file cannot be read;
//   - the request fails;
//   - the response is not valid JSON;
//   - the response carries an error name;
//   - the HTTP status is not 2xx;
//   - no job ID comes back.
func initiateGeneratingAnimation(fileLocation string) (string, error) {
	apiKey := os.Getenv("STABILITY_API_KEY")
	if apiKey == "" {
		return "", fmt.Errorf("STABILITY_API_KEY environment variable is not set")
	}

	// Build the multipart request body in memory.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	fields := []struct{ name, value string }{
		{"seed", "0"},
		{"cfg_scale", "2.5"},
		{"motion_bucket_id", "40"},
	}
	for _, f := range fields {
		if err := writer.WriteField(f.name, f.value); err != nil {
			return "", fmt.Errorf("error writing form field %s | %w", f.name, err)
		}
	}

	file, err := os.Open(fileLocation)
	if err != nil {
		return "", fmt.Errorf("error opening file | %w", err)
	}
	defer file.Close()

	filePart, err := CreateImageFormFile(writer, path.Base(fileLocation))
	if err != nil {
		return "", fmt.Errorf("error creating form file | %w", err)
	}
	if _, err := io.Copy(filePart, file); err != nil {
		return "", fmt.Errorf("error copying file content | %w", err)
	}
	// Close writes the terminating boundary; without it the body is malformed.
	if err := writer.Close(); err != nil {
		return "", fmt.Errorf("error finalizing multipart body | %w", err)
	}

	url := "https://api.stability.ai/v2alpha/generation/image-to-video"
	req, err := http.NewRequest(http.MethodPost, url, body)
	if err != nil {
		return "", fmt.Errorf("error creating request | %w", err)
	}
	req.Header.Set("Content-Type", writer.FormDataContentType())
	req.Header.Set("Authorization", "Bearer "+apiKey)

	// Generous upper bound so a stalled connection cannot hang the CLI forever.
	client := &http.Client{Timeout: 5 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("error making request | %w", err)
	}
	defer resp.Body.Close()
	fmt.Println("Status:", resp.Status)

	// Read the body once so it can be quoted in errors when it isn't the
	// expected JSON (e.g. a proxy or gateway error page).
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("error reading response body | %w", err)
	}
	var res struct {
		ID     string   `json:"id"`
		Name   string   `json:"name"`
		Errors []string `json:"errors"`
	}
	if err := json.Unmarshal(respBody, &res); err != nil {
		return "", fmt.Errorf("error decoding response body (status %s): %q | %w", resp.Status, respBody, err)
	}
	// A non-empty name signals an API error payload.
	if res.Name != "" {
		return "", fmt.Errorf("error generating animation | %v: Errors: %v", res.Name, strings.Join(res.Errors, ", "))
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("unexpected response status %s: %s", resp.Status, respBody)
	}
	if res.ID == "" {
		return "", fmt.Errorf("response did not contain a job id: %s", respBody)
	}
	return res.ID, nil
}
// CreateImageFormFile adds a file part with form name "image" to w and
// returns a writer for the file's contents.
//
// It behaves like multipart.Writer.CreateFormFile, except that the part's
// Content-Type comes from filename's extension (mime.TypeByExtension).
// Unknown extensions fall back to application/octet-stream. Backslashes and
// double quotes in filename are escaped so they cannot terminate the quoted
// filename parameter early and corrupt the Content-Disposition header.
func CreateImageFormFile(w *multipart.Writer, filename string) (io.Writer, error) {
	contentType := mime.TypeByExtension(path.Ext(filename))
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	escape := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	h := make(textproto.MIMEHeader)
	h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "image", escape.Replace(filename)))
	h.Set("Content-Type", contentType)
	return w.CreatePart(h)
}
// getJobResult queries the result endpoint for jobID once.
//
// Return values:
//   - ("", false, nil) while the API reports status "in-progress";
//   - (video, true, nil) once a video is available; video is the string the
//     API returns, which job() base64-decodes;
//   - an error when STABILITY_API_KEY is unset, the request fails, the
//     response is not valid JSON, it carries an API error name, the HTTP
//     status is not 2xx, or the job finished without a video.
func getJobResult(jobID string) (string, bool, error) {
	apiKey := os.Getenv("STABILITY_API_KEY")
	if apiKey == "" {
		return "", false, fmt.Errorf("STABILITY_API_KEY environment variable is not set")
	}
	url := fmt.Sprintf("https://api.stability.ai/v2alpha/generation/image-to-video/result/%s", jobID)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", false, fmt.Errorf("error creating request | %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Accept", "application/json")

	// Generous upper bound (the finished response carries the whole video)
	// so a stalled connection cannot hang the polling loop forever.
	client := &http.Client{Timeout: 5 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return "", false, fmt.Errorf("error making request | %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", false, fmt.Errorf("error reading response body | %w", err)
	}
	var res struct {
		Video        string   `json:"video"`
		FinishReason string   `json:"finishReason"`
		Seed         int64    `json:"seed"`
		ID           string   `json:"id"`
		Status       string   `json:"status"`
		Name         string   `json:"name"`
		Errors       []string `json:"errors"`
	}
	if err := json.Unmarshal(respBody, &res); err != nil {
		return "", false, fmt.Errorf("error unmarshalling response body (status %s): %q | %w", resp.Status, respBody, err)
	}
	// A non-empty name signals an API error payload.
	if res.Name != "" {
		return "", false, fmt.Errorf("error generating animation | %v: Errors: %v", res.Name, strings.Join(res.Errors, ", "))
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", false, fmt.Errorf("unexpected response status %s: %s", resp.Status, respBody)
	}
	if res.Status == "in-progress" {
		return "", false, nil
	}
	// Without this check a finished-but-empty result would be written out
	// as a zero-byte video.
	if res.Video == "" {
		return "", false, fmt.Errorf("job finished without a video (finish reason: %q)", res.FinishReason)
	}
	return res.Video, true, nil
}
// job polls getJobResult for id every five seconds until the render
// finishes. It then base64-decodes the returned video (expected to be an
// mp4) and writes it to outputLocation with mode 0644. A polling, decoding,
// or write error aborts and is returned wrapped.
func job(id, outputLocation string) error {
	const pollInterval = 5 * time.Second

	var video string
	for {
		v, finished, err := getJobResult(id)
		if err != nil {
			return fmt.Errorf("error getting job result | %w", err)
		}
		if finished {
			video = v
			break
		}
		time.Sleep(pollInterval)
	}
	fmt.Println("Video has completed rendering.")

	decodedVideo, err := b64.StdEncoding.DecodeString(video)
	if err != nil {
		return fmt.Errorf("error decoding video | %w", err)
	}
	if err := os.WriteFile(outputLocation, decodedVideo, 0o644); err != nil {
		return fmt.Errorf("error writing video to file | %w", err)
	}
	// Println already separates operands with a space; the original's
	// trailing space in the literal produced a double space.
	fmt.Println("Video has been saved to", outputLocation)
	return nil
}