-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
111 lines (92 loc) · 3.78 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os
from google.cloud import storage
import transformers
def download_model(model_name):
# Download the model from Hugging Face
model = transformers.AutoModel.from_pretrained(model_name)
return model
def upload_to_gcs(bucket_name, file_name, file_path):
# Initialize the client
storage_client = storage.Client()
# Get the bucket
bucket = storage_client.get_bucket(bucket_name)
# Create a new blob
blob = bucket.blob(file_name)
# Upload the file
blob.upload_from_filename(file_path)
print(f"File {file_name} has been uploaded to {bucket_name}.")
def download_from_gcs(bucket_name, file_name, file_path):
# Initialize the client
storage_client = storage.Client()
# Get the bucket
bucket = storage_client.get_bucket(bucket_name)
# Create a new blob
blob = bucket.blob(file_name)
# Download the file
blob.download_to_filename(file_path)
print(f"File {file_name} has been downloaded from {bucket_name}.")
def download_images_from_gcs(bucket_name, filenames):
# Initialize the client
storage_client = storage.Client()
# Get the bucket
bucket = storage_client.get_bucket(bucket_name)
images = []
for filename in filenames:
# Create a new blob
blob = bucket.blob(filename)
# Download the file
blob.download_to_filename(filename)
images.append(filename)
print(f"Files {', '.join(filenames)} have been downloaded from {bucket_name}.")
return images
def run_training(model, images, prompt):
# Code for running training on the model with the given images and prompt
pass
def run_inference(model, prompt):
# Code for running inference on the model with the given prompt
pass
def upload_images_to_gcs(bucket_name, filenames):
# Initialize the client
storage_client = storage.Client()
# Get the bucket
bucket = storage_client.get_bucket(bucket_name)
for filename in filenames:
# Create a new blob
blob = bucket.blob(filename)
# Upload the file
blob.upload_from_filename(filename)
print(f"Files {', '.join(filenames)} have been uploaded to {bucket_name}.")
def main(request):
model_name = 'runwayml/stable-diffusion-v1-5'
model_folder = 'models'
model_path = os.path.join(model_folder, model_name)
# Download the model from Hugging Face
model = download_model(model_name)
# Upload the model to GCS
upload_to_gcs('my-bucket-name', model_path, model_path)
# Delete the local files for the model
os.remove(model_path)
# Download the model from GCS
download_from_gcs('my-bucket-name', model_path, model_path)
# Download the images from GCS
image_filenames = ['photo1.jpg', 'photo2.jpg', 'photo3.jpg', 'photo4.jpg']
images = download_images_from_gcs('my-bucket-name', image_filenames)
# Run training on the model with the images and prompt
prompt = 'profile picture of xyz person'
run_training(model, images, prompt)
# Upload the trained model to GCS
upload_to_gcs('my-bucket-name', 'trained_models/' + model_name, model_path)
# Download the trained model from GCS
download_from_gcs('my-bucket-name', 'trained_models/' + model_name, model_path)
# Run inference with the prompt
prompt = 'photo of xyz person, xyz person on the beach close to the water, beautifull day, cristaline beautiful ocean water green transparent on a sunny day in indonesia'
generated_images = run_inference(model, prompt)
# Upload the generated images to GCS
upload_images_to_gcs('my-bucket-name', generated_images)
# Return a JSON response with the filenames and a success message
response = {
'files': generated_images,
'message': 'success'
}
return json.dumps(response)