470 lines
12 KiB
Go
470 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
b64 "encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"image/draw"
|
|
|
|
"image"
|
|
"image/jpeg"
|
|
"image/png"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/textproto"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/nfnt/resize"
|
|
"github.com/urfave/cli/v2"
|
|
)
|
|
|
|
func main() {
|
|
commands := []*cli.Command{
|
|
getJobResultCmd(),
|
|
resizeCmd(),
|
|
}
|
|
|
|
app := &cli.App{
|
|
Name: "i2v",
|
|
Usage: "A command line tool to convert images to videos, based on Stability.ai's Rest API.",
|
|
Version: "v1.0.0",
|
|
Description: "i2v (Image to Video) is a command line tool to convert images to videos, based on Stability.ai's Rest API.",
|
|
Commands: commands,
|
|
Flags: []cli.Flag{
|
|
&cli.StringFlag{
|
|
Name: "output",
|
|
Value: "./output.mp4",
|
|
Usage: "path to output file (should be a mp4)",
|
|
Aliases: []string{
|
|
"o",
|
|
},
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "format",
|
|
Aliases: []string{"f"},
|
|
Usage: "The format to resize the image to wide, tall, or square",
|
|
Value: "wide",
|
|
},
|
|
},
|
|
Authors: []*cli.Author{
|
|
{
|
|
Name: "Mason Payne",
|
|
Email: "mason@masonitestudios.com",
|
|
},
|
|
},
|
|
Copyright: "2023 Masonite Studios LLC",
|
|
UseShortOptionHandling: true,
|
|
Action: func(c *cli.Context) error {
|
|
var needsCleanup bool
|
|
// use the first argument as the file name
|
|
fileLocation, err := getFileLocation(c)
|
|
if err != nil {
|
|
return fmt.Errorf("error getting file location | %w", err)
|
|
}
|
|
|
|
// preprocess the image to make sure it is the right size
|
|
/*
|
|
Supported Dimensions:
|
|
1024x576
|
|
576x1024
|
|
768x768
|
|
*/
|
|
// get the format
|
|
format := c.String("format")
|
|
if format != "wide" && format != "tall" && format != "square" {
|
|
return fmt.Errorf("invalid format %s", format)
|
|
}
|
|
|
|
// make sure the image is either jpeg or png
|
|
fileMimeType := mime.TypeByExtension(path.Ext(fileLocation))
|
|
if fileMimeType != "image/jpeg" && fileMimeType != "image/png" {
|
|
return fmt.Errorf("unsupported file type | %v", fileMimeType)
|
|
}
|
|
|
|
// resize the image
|
|
tempLocation, err := resizeImage(fileLocation, format)
|
|
if err != nil {
|
|
return fmt.Errorf("error resizing image | %w", err)
|
|
}
|
|
|
|
if tempLocation != fileLocation {
|
|
fmt.Printf("Resized image to %s\n", tempLocation)
|
|
needsCleanup = true
|
|
fileLocation = tempLocation // use the resized image
|
|
}
|
|
|
|
id, err := initiateGeneratingAnimation(fileLocation)
|
|
if err != nil {
|
|
if needsCleanup {
|
|
// remove the temp file
|
|
err = cleanUpTempFile(tempLocation)
|
|
if err != nil {
|
|
return fmt.Errorf("error removing temp file | %w", err)
|
|
}
|
|
}
|
|
return fmt.Errorf("error making request | %w", err)
|
|
}
|
|
|
|
fmt.Println("Video is being rendered, this may take a while.")
|
|
fmt.Printf("Job ID: %v\n", id)
|
|
|
|
// wait for the job to finish
|
|
err = job(id, c.String("output"))
|
|
if err != nil {
|
|
if needsCleanup {
|
|
// remove the temp file
|
|
err = cleanUpTempFile(tempLocation)
|
|
if err != nil {
|
|
return fmt.Errorf("error removing temp file | %w", err)
|
|
}
|
|
}
|
|
return fmt.Errorf("error getting job result | %w", err)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
}
|
|
//fmt.Println("This package will be used for interacting with a running StormV2 service via a terminal or command line.")
|
|
err := app.Run(os.Args)
|
|
if err != nil {
|
|
fmt.Println(fmt.Errorf("error running app | %w", err))
|
|
return
|
|
}
|
|
}
|
|
|
|
// cleanUpTempFile deletes the file at fileLocation.
//
// It returns nil on success, or the os.Remove failure wrapped with a
// "error removing temp file" prefix.
func cleanUpTempFile(fileLocation string) error {
	if removeErr := os.Remove(fileLocation); removeErr != nil {
		return fmt.Errorf("error removing temp file | %w", removeErr)
	}
	return nil
}
|
|
|
|
func resizeImage(fileLocation string, format string) (string, error) {
|
|
// load the image
|
|
file, err := os.Open(fileLocation)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error opening file | %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
var img image.Image
|
|
fileMimeType := mime.TypeByExtension(path.Ext(fileLocation))
|
|
if fileMimeType == "image/jpeg" {
|
|
// decode the image
|
|
img, err = jpeg.Decode(file)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error decoding jpeg image | %w", err)
|
|
}
|
|
} else if fileMimeType == "image/png" {
|
|
// decode the image
|
|
img, err = png.Decode(file)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error decoding png image | %w", err)
|
|
}
|
|
} else {
|
|
return "", fmt.Errorf("unsupported file type | %v", fileMimeType)
|
|
}
|
|
|
|
// check if the image is already the correct size
|
|
rect := img.Bounds()
|
|
width := rect.Max.X - rect.Min.X
|
|
height := rect.Max.Y - rect.Min.Y
|
|
x1 := 1024
|
|
y1 := 576
|
|
if format == "wide" {
|
|
if width == 1024 && height == 576 {
|
|
fmt.Println("Image is already the correct size.")
|
|
return fileLocation, nil
|
|
}
|
|
}
|
|
if format == "tall" {
|
|
if width == 576 && height == 1024 {
|
|
fmt.Println("Image is already the correct size.")
|
|
return fileLocation, nil
|
|
}
|
|
x1 = 576
|
|
y1 = 1024
|
|
}
|
|
if format == "square" {
|
|
if width == 768 && height == 768 {
|
|
fmt.Println("Image is already the correct size.")
|
|
return fileLocation, nil
|
|
}
|
|
x1 = 768
|
|
y1 = 768
|
|
}
|
|
|
|
inFormat := "wide"
|
|
if width < height {
|
|
inFormat = "tall"
|
|
}
|
|
if width == height {
|
|
inFormat = "square"
|
|
}
|
|
|
|
// if not, resize the image
|
|
// scale the original image to the new size
|
|
var resizedImage image.Image
|
|
if format == "wide" {
|
|
resizedImage = resize.Resize(uint(x1), 0, img, resize.Lanczos3)
|
|
}
|
|
if format == "tall" {
|
|
resizedImage = resize.Resize(0, uint(y1), img, resize.Lanczos3)
|
|
}
|
|
if format == "square" {
|
|
if inFormat == "wide" {
|
|
resizedImage = resize.Resize(0, uint(y1), img, resize.Lanczos3)
|
|
}
|
|
if inFormat == "tall" {
|
|
resizedImage = resize.Resize(uint(x1), 0, img, resize.Lanczos3)
|
|
}
|
|
if inFormat == "square" {
|
|
resizedImage = resize.Resize(uint(x1), uint(y1), img, resize.Lanczos3)
|
|
}
|
|
}
|
|
|
|
// crop the image to the final correct size
|
|
|
|
// start by getting the center of the image
|
|
tempBounds := resizedImage.Bounds()
|
|
x0 := tempBounds.Max.X/2 - x1/2
|
|
y0 := tempBounds.Max.Y/2 - y1/2
|
|
xMax := x0 + x1
|
|
yMax := y0 + y1
|
|
|
|
croppedImageRect := image.Rect(x0, y0, xMax, yMax)
|
|
rgba := convertToRGBA(resizedImage)
|
|
croppedImage := rgba.SubImage(croppedImageRect)
|
|
|
|
// save the image to a temp file
|
|
tempFile, err := os.CreateTemp("", "i2v*"+path.Ext(fileLocation))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating temp file | %w", err)
|
|
}
|
|
defer tempFile.Close()
|
|
|
|
// encode the image to the temp file
|
|
err = jpeg.Encode(tempFile, croppedImage, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error encoding resized image | %w", err)
|
|
}
|
|
|
|
// return the temp file location
|
|
return tempFile.Name(), nil
|
|
}
|
|
|
|
// convertToRGBA copies img into a freshly allocated *image.RGBA that has the
// same bounds as img, so callers get a concrete type that supports SubImage.
func convertToRGBA(img image.Image) *image.RGBA {
	area := img.Bounds()
	dst := image.NewRGBA(area)
	// Src replaces the (zeroed) destination pixels with the source pixels.
	draw.Src.Draw(dst, area, img, area.Min)
	return dst
}
|
|
|
|
func getFileLocation(c *cli.Context) (string, error) {
|
|
fileLocation := c.Args().Get(0)
|
|
// if no argument is provided, use stdin
|
|
if fileLocation == "" {
|
|
fmt.Println("No file provided, using stdin.")
|
|
// read from stdin
|
|
var err error
|
|
fileLocationBytes, err := io.ReadAll(os.Stdin)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error reading from stdin | %w", err)
|
|
}
|
|
fileLocation = string(fileLocationBytes)
|
|
}
|
|
return fileLocation, nil
|
|
}
|
|
|
|
func initiateGeneratingAnimation(fileLocation string) (string, error) {
|
|
// get base filename from file location
|
|
filename := path.Base(fileLocation)
|
|
|
|
// Create a buffer to store the request body
|
|
bodyBuf := &bytes.Buffer{}
|
|
|
|
// Create a multipart writer with the buffer
|
|
writer := multipart.NewWriter(bodyBuf)
|
|
|
|
// Add form fields
|
|
err := writer.WriteField("seed", "0")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing form field seed | %w", err)
|
|
}
|
|
err = writer.WriteField("cfg_scale", "2.5")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing form field cfg_scale | %w", err)
|
|
}
|
|
err = writer.WriteField("motion_bucket_id", "40")
|
|
if err != nil {
|
|
return "", fmt.Errorf("error writing form field motion_bucket_id | %w", err)
|
|
}
|
|
|
|
// Add a file
|
|
file, err := os.Open(fileLocation)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error opening file | %w", err)
|
|
}
|
|
defer file.Close()
|
|
|
|
// Create a form file part
|
|
//filePart, err := writer.CreateFormFile("image", filename)
|
|
filePart, err := CreateImageFormFile(writer, filename)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating form file | %w", err)
|
|
}
|
|
|
|
// Copy the file content to the form file part
|
|
_, err = io.Copy(filePart, file)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error copying file content | %w", err)
|
|
}
|
|
|
|
// Close the multipart writer
|
|
writer.Close()
|
|
|
|
url := "https://api.stability.ai/v2alpha/generation/image-to-video"
|
|
|
|
// Create the HTTP request
|
|
req, err := http.NewRequest("POST", url, bodyBuf)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating request | %w", err)
|
|
}
|
|
|
|
// Set the Content-Type header
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
req.Header.Add("authorization", "Bearer "+os.Getenv("STABILITY_API_KEY"))
|
|
|
|
// Make the HTTP request
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error making request | %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Print the response status and body
|
|
fmt.Println("Status:", resp.Status)
|
|
fmt.Println("Body:", resp.Body)
|
|
|
|
//bodyBytes, err := io.ReadAll(resp.Body)
|
|
//if err != nil {
|
|
// return "", fmt.Errorf("error reading response body | %w", err)
|
|
//}
|
|
//bodyString := string(bodyBytes)
|
|
//fmt.Println(bodyString)
|
|
|
|
res := struct {
|
|
ID string `json:"id"`
|
|
Name string `json:"name"`
|
|
Errors []string `json:"errors"`
|
|
}{}
|
|
|
|
// decode response body
|
|
err = json.NewDecoder(resp.Body).Decode(&res)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error decoding response body | %w", err)
|
|
}
|
|
|
|
// print response body
|
|
if res.Name != "" {
|
|
return "", fmt.Errorf("error generating animation | %v: Errors: %v", res.Name, strings.Join(res.Errors, ", "))
|
|
}
|
|
|
|
return res.ID, nil
|
|
}
|
|
|
|
// CreateImageFormFile creates a new multipart part on w for a file upload in
// the form field "image" and returns a writer for the part's content.
//
// Unlike (*multipart.Writer).CreateFormFile, which always labels the part
// application/octet-stream, the part's Content-Type is derived from the
// extension of filename. When the extension is unknown the Content-Type
// falls back to application/octet-stream.
//
// Backslashes and double quotes in filename are escaped so the
// Content-Disposition header stays well formed.
func CreateImageFormFile(w *multipart.Writer, filename string) (io.Writer, error) {
	fileMimeType := mime.TypeByExtension(path.Ext(filename))
	if fileMimeType == "" {
		fileMimeType = "application/octet-stream"
	}

	// The file name sits inside a quoted string, so quote and backslash
	// characters must be backslash-escaped.
	escapedName := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(filename)

	h := make(textproto.MIMEHeader)
	h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "image", escapedName))
	h.Set("Content-Type", fileMimeType)
	return w.CreatePart(h)
}
|
|
|
|
// getJobResult asks the Stability.ai API for the result of the image-to-video
// job identified by jobID.
//
// It returns:
//   - ("", false, nil) while the job is still in progress,
//   - (video, true, nil) once the job has finished, where video is the
//     "video" field of the JSON response (decoded as base64 by the caller),
//   - a non-nil error if the request fails, the response cannot be decoded,
//     the API reports an error, or a finished response contains no video.
//
// The API key is read from the STABILITY_API_KEY environment variable.
func getJobResult(jobID string) (string, bool, error) {
	url := fmt.Sprintf("https://api.stability.ai/v2alpha/generation/image-to-video/result/%s", jobID)

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return "", false, fmt.Errorf("error creating request | %w", err)
	}

	req.Header.Add("authorization", "Bearer "+os.Getenv("STABILITY_API_KEY"))
	// Ask for JSON so the video arrives as a string field rather than raw bytes.
	req.Header.Add("accept", "application/json")

	// A timeout keeps a stalled connection from blocking the poll loop forever.
	client := &http.Client{Timeout: 2 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return "", false, fmt.Errorf("error making request | %w", err)
	}
	defer resp.Body.Close()

	// Finished responses carry "video"; pending ones carry "status"; error
	// responses carry "name" and "errors".
	res := struct {
		Video        string   `json:"video"`
		FinishReason string   `json:"finishReason"`
		Seed         int64    `json:"seed"`
		ID           string   `json:"id"`
		Status       string   `json:"status"`
		Name         string   `json:"name"`
		Errors       []string `json:"errors"`
	}{}

	if err = json.NewDecoder(resp.Body).Decode(&res); err != nil {
		return "", false, fmt.Errorf("error unmarshalling response body (status %s) | %w", resp.Status, err)
	}

	if res.Name != "" {
		return "", false, fmt.Errorf("error generating animation | %v: Errors: %v", res.Name, strings.Join(res.Errors, ", "))
	}

	// 202 Accepted means the request has not finished processing, so it is
	// treated as in progress alongside the explicit status field.
	if res.Status == "in-progress" || resp.StatusCode == http.StatusAccepted {
		return "", false, nil
	}

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", false, fmt.Errorf("error generating animation | unexpected response status %s", resp.Status)
	}
	// Never report a job as finished without a video: the caller would
	// otherwise write an empty output file and claim success.
	if res.Video == "" {
		return "", false, fmt.Errorf("error generating animation | response did not contain a video (status %q, finishReason %q)", res.Status, res.FinishReason)
	}

	return res.Video, true, nil
}
|
|
|
|
func job(id, outputLocation string) error {
|
|
var video string
|
|
var finished bool
|
|
var err error
|
|
// poll the job result until it is finished
|
|
for {
|
|
video, finished, err = getJobResult(id)
|
|
if err != nil {
|
|
return fmt.Errorf("error getting job result | %w", err)
|
|
}
|
|
|
|
if finished {
|
|
fmt.Println("Video has completed rendering.")
|
|
break
|
|
}
|
|
time.Sleep(5 * time.Second)
|
|
}
|
|
|
|
// decode the video, it is in base64 and is expected to be a mp4
|
|
decodedVideo, err := b64.StdEncoding.DecodeString(video)
|
|
if err != nil {
|
|
return fmt.Errorf("error decoding video | %w", err)
|
|
}
|
|
|
|
// write the video to the current directory
|
|
err = os.WriteFile(outputLocation, decodedVideo, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("error writing video to file | %w", err)
|
|
}
|
|
|
|
fmt.Println("Video has been saved to ", outputLocation)
|
|
return nil
|
|
}
|