Дообучить модель эмбеддингов
Примечание
Функциональность находится на стадии Preview.
Этот пример показывает, как дообучить модель эмбеддингов по методу LoRA в AI Studio. Ссылки на другие примеры доступны в разделе См. также.
После завершения операции дообучения сохраните URI дообученной модели вида emb://<идентификатор_каталога>/text-embeddings/latest@<суффикс_дообучения>. Используйте его в качестве пользовательской модели эмбеддингов, если это необходимо. Например, можно указать model_uri при построении поискового индекса.
Чтобы воспользоваться примером, вам понадобится сервисный аккаунт с ролями ai.assistants.editor и ai.languageModels.user, а также API-ключ с областью действия yc.ai.foundationModels.execute. API-ключ, который вы можете создать в AI Studio, имеет такие разрешения. Пример того, как настроить рабочее окружение, можно найти в разделе Начало работы.
Подготовьте данные
- Подготовьте данные в необходимом формате. Для дообучения модели эмбеддингов используйте датасеты пар
TextEmbeddingPairParamsили триплетовTextEmbeddingTripletParams. - Создайте датасет любым удобным способом. В интерфейсе AI Studio вы также сможете создать датасет позднее на этапе создания дообучения.
Запустите дообучение
import grpc
import os
import time
import re
from yandex.cloud.ai.tuning.v1 import tuning_service_pb2
from yandex.cloud.ai.tuning.v1 import tuning_service_pb2_grpc
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub
YANDEX_API_KEY = os.environ["YANDEX_API_KEY"]
YANDEX_FOLDER_ID = os.environ["YANDEX_FOLDER_ID"]
YANDEX_DATASET_ID = os.environ["YANDEX_DATASET_ID"]
# --- Вспомогательная функция для конвертации CamelCase в snake_case ---
def to_snake_case(name):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
s2 = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1)
return s2.lower()
# Получаем тип и конвертируем в snake_case
raw_type = os.environ.get("YANDEX_EMBEDDING_TYPE", "text_embedding_pair_params")
YANDEX_EMBEDDING_TYPE = to_snake_case(raw_type)
print(f"🎯 Используемый тип эмбеддингов (поле): {YANDEX_EMBEDDING_TYPE}")
# Настройка канала и стуба
credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel("ai.api.cloud.yandex.net:443", credentials)
tuning_stub = tuning_service_pb2_grpc.TuningServiceStub(channel)
# Канал для операций
op_channel = grpc.secure_channel("ai.api.cloud.yandex.net:443", credentials)
op_stub = OperationServiceStub(op_channel)
metadata = [("authorization", f"Api-Key {YANDEX_API_KEY}")]
# Подготовка аргументов запроса
tune_kwargs = {
"base_model_uri": f"emb://{YANDEX_FOLDER_ID}/text-embeddings/latest",
"train_datasets": [{"dataset_id": YANDEX_DATASET_ID, "weight": 1.0}],
"name": "train-embeddings",
}
# Добавляем поле типа эмбеддингов (теперь ключ точно в snake_case)
tune_kwargs[YANDEX_EMBEDDING_TYPE] = {}
try:
# ИСПРАВЛЕНИЕ 1: Используем TuningRequest вместо TuneRequest
request = tuning_service_pb2.TuningRequest(**tune_kwargs)
print("📤 Отправка запроса на дообучение эмбеддингов...")
operation = tuning_stub.Tune(request, metadata=metadata)
print(f"✅ Операция запущена: {operation.id}")
print("⏳ Ожидание завершения дообучения...")
while True:
op = op_stub.Get(GetOperationRequest(operation_id=operation.id), metadata=metadata)
if op.done:
if op.HasField("error"):
print(f"❌ Операция завершилась с ошибкой: {op.error.message}")
break
print("✅ Дообучение завершено!")
if op.HasField("response"):
print(f"📦 Ответ получен: {op.response}")
# Для получения чистого URI нужно делать unpack ответа,
# но для базовой проверки достаточно увидеть успех.
break
print("🔄 Статус: выполняется...")
time.sleep(30)
except Exception as e:
print(f"❌ Произошла ошибка: {e}")
import traceback
traceback.print_exc()
Где:
-
YANDEX_API_KEY— API-ключ для работы в AI Studio. -
YANDEX_FOLDER_ID— идентификатор каталога сервисного аккаунта. -
YANDEX_DATASET_ID— идентификатор датасета, сохраненный ранее. -
YANDEX_EMBEDDING_TYPE— тип созданного ранее датасета. Возможные значения:text_embedding_pair_params— датасет пар.text_embedding_triplet_params— датасет триплетов.
Пример ответа:
🎯 Используемый тип эмбеддингов (поле): text_embedding_triplet_params
📤 Отправка запроса на дообучение эмбеддингов...
✅ Операция запущена: ftnvcb6ifmjq********
⏳ Ожидание завершения дообучения...
🔄 Статус: выполняется...
import * as grpc from "@grpc/grpc-js";
import protobuf from "protobufjs";
import descriptor from "protobufjs/ext/descriptor/index.js";
// ============================================================================
// КОНФИГУРАЦИЯ
// ============================================================================
const YANDEX_API_KEY = process.env.YANDEX_API_KEY;
const YANDEX_FOLDER_ID = process.env.YANDEX_FOLDER_ID;
const YANDEX_DATASET_ID = process.env.YANDEX_DATASET_ID;
// Тип датасета: \"text_embedding_pair_params\" или \"text_embedding_triplet_params\"
const YANDEX_EMBEDDING_TYPE = process.env.YANDEX_EMBEDDING_TYPE;
const MODEL_NAME = process.env.MODEL_NAME || "train-embeddings";
const BASE_MODEL = process.env.BASE_MODEL || "text-embeddings/latest";
const GRPC_ADDRESS = "ai.api.cloud.yandex.net:443";
// Преобразуем тип embeddings в формат для API
// TextEmbeddingPair -> text_embedding_pair_params
// TextEmbeddingTriplet -> text_embedding_triplet_params
const EMBEDDING_TYPE_API = YANDEX_EMBEDDING_TYPE
.replace(/([A-Z])/g, "_$1")
.toLowerCase()
.replace(/^_/, "")
.replace(/_params$/, "") + "_params";
// ============================================================================
// ПРОВЕРКА ПЕРЕМЕННЫХ ОКРУЖЕНИЯ
// ============================================================================
if (!YANDEX_API_KEY) {
console.error("❌ YANDEX_API_KEY не установлен");
process.exit(1);
}
if (!YANDEX_FOLDER_ID) {
console.error("❌ YANDEX_FOLDER_ID не установлен");
process.exit(1);
}
if (!YANDEX_DATASET_ID) {
console.error("❌ YANDEX_DATASET_ID не установлен");
process.exit(1);
}
if (!YANDEX_EMBEDDING_TYPE) {
console.error("❌ YANDEX_EMBEDDING_TYPE не установлен");
process.exit(1);
}
// ============================================================================
// gRPC SERVER REFLECTION
// ============================================================================
const { FileDescriptorProto, FileDescriptorSet } = descriptor;
/** Сериализация ServerReflectionRequest */
function serializeReflectionRequest(msg) {
const writer = protobuf.Writer.create();
if (msg.list_services !== undefined) {
writer.uint32(58).string(msg.list_services);
}
if (msg.file_containing_symbol !== undefined) {
writer.uint32(34).string(msg.file_containing_symbol);
}
if (msg.file_by_filename !== undefined) {
writer.uint32(26).string(msg.file_by_filename);
}
return Buffer.from(writer.finish());
}
/** Десериализация ServerReflectionResponse */
function deserializeReflectionResponse(buf) {
const reader = protobuf.Reader.create(buf);
const result = {};
while (reader.pos < reader.len) {
const tag = reader.uint32();
const fieldNumber = tag >>> 3;
const wireType = tag & 7;
switch (fieldNumber) {
case 1:
result.validHost = reader.string();
break;
case 2:
reader.skipType(wireType);
break;
case 4: {
const end = reader.uint32() + reader.pos;
const fileDescriptorProto = [];
while (reader.pos < end) {
const innerTag = reader.uint32();
if ((innerTag >>> 3) === 1) {
fileDescriptorProto.push(Buffer.from(reader.bytes()));
} else {
reader.skipType(innerTag & 7);
}
}
result.fileDescriptorResponse = { fileDescriptorProto };
break;
}
case 6: {
const end = reader.uint32() + reader.pos;
const services = [];
while (reader.pos < end) {
const innerTag = reader.uint32();
if ((innerTag >>> 3) === 1) {
const svcEnd = reader.uint32() + reader.pos;
let name = "";
while (reader.pos < svcEnd) {
const svcTag = reader.uint32();
if ((svcTag >>> 3) === 1) {
name = reader.string();
} else {
reader.skipType(svcTag & 7);
}
}
services.push({ name });
} else {
reader.skipType(innerTag & 7);
}
}
result.listServicesResponse = { services };
break;
}
case 7: {
const end = reader.uint32() + reader.pos;
const err = {};
while (reader.pos < end) {
const innerTag = reader.uint32();
const innerField = innerTag >>> 3;
if (innerField === 1) err.errorCode = reader.int32();
else if (innerField === 2) err.errorMessage = reader.string();
else reader.skipType(innerTag & 7);
}
result.errorResponse = err;
break;
}
default:
reader.skipType(wireType);
}
}
return result;
}
/** Создает Reflection-клиент */
function createReflectionClient(address, credentials) {
const serviceDef = {
ServerReflectionInfo: {
path: "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
requestStream: true,
responseStream: true,
requestSerialize: serializeReflectionRequest,
requestDeserialize: (buf) => buf,
responseSerialize: (msg) => msg,
responseDeserialize: deserializeReflectionResponse,
},
};
const Client = grpc.makeGenericClientConstructor(serviceDef);
return new Client(address, credentials);
}
/** Один reflection-запрос через bidi stream */
function reflectionCall(client, metadata, request) {
return new Promise((resolve, reject) => {
const call = client.ServerReflectionInfo(metadata);
const results = [];
call.on("data", (resp) => results.push(resp));
call.on("end", () => resolve(results));
call.on("error", reject);
call.write(request);
call.end();
});
}
/**
* Получает все FileDescriptor для указанных сервисов, включая зависимости.
*/
async function resolveServices(client, metadata, serviceNames) {
const allFdBytes = new Map();
for (const svcName of serviceNames) {
const responses = await reflectionCall(client, metadata, {
file_containing_symbol: svcName,
});
for (const resp of responses) {
if (resp.errorResponse) {
throw new Error(
`Reflection error for "${svcName}": ${resp.errorResponse.errorMessage}`
);
}
if (resp.fileDescriptorResponse) {
for (const fdBytes of resp.fileDescriptorResponse.fileDescriptorProto) {
const fd = FileDescriptorProto.decode(fdBytes);
allFdBytes.set(fd.name, fdBytes);
}
}
}
}
// Рекурсивно подгружаем зависимости
const processed = new Set();
let hasNew = true;
while (hasNew) {
hasNew = false;
for (const [name, fdBytes] of allFdBytes) {
if (processed.has(name)) continue;
processed.add(name);
const fd = FileDescriptorProto.decode(fdBytes);
if (fd.dependency) {
for (const dep of fd.dependency) {
if (!allFdBytes.has(dep)) {
try {
const responses = await reflectionCall(client, metadata, {
file_by_filename: dep,
});
for (const resp of responses) {
if (resp.fileDescriptorResponse) {
for (const b of resp.fileDescriptorResponse
.fileDescriptorProto) {
const d = FileDescriptorProto.decode(b);
allFdBytes.set(d.name, b);
}
}
}
} catch {
// Зависимость недоступна — пропускаем
}
hasNew = true;
}
}
}
}
}
return allFdBytes;
}
/**
* Строит protobufjs Root из FileDescriptor bytes.
*/
function buildRoot(fdBytesMap) {
const files = [];
for (const fdBytes of fdBytesMap.values()) {
files.push(FileDescriptorProto.decode(fdBytes));
}
const fds = FileDescriptorSet.create({ file: files });
const root = protobuf.Root.fromDescriptor(fds);
root.resolveAll();
return root;
}
/**
* Создает gRPC-клиент из protobufjs Root.
*
* ВАЖНО: protobufjs при fromDescriptor() сохраняет оригинальные snake_case
* имена полей из proto. Поэтому в запросах нужно использовать snake_case
* (base_model_uri, train_datasets и т.д.), и в ответах поля тоже в snake_case.
*/
function createServiceClient(root, servicePath, address, credentials) {
const service = root.lookupService(servicePath);
service.resolveAll();
const serviceDef = {};
for (const method of service.methodsArray) {
method.resolve();
const reqType = method.resolvedRequestType;
const resType = method.resolvedResponseType;
serviceDef[method.name] = {
path: `/${service.fullName.replace(/^\./, "")}/${method.name}`,
requestStream: !!method.requestStream,
responseStream: !!method.responseStream,
requestSerialize: (msg) =>
Buffer.from(reqType.encode(reqType.create(msg)).finish()),
requestDeserialize: (buf) =>
reqType.toObject(reqType.decode(buf), {
longs: String,
enums: String,
defaults: true,
}),
responseSerialize: (msg) =>
Buffer.from(resType.encode(resType.create(msg)).finish()),
responseDeserialize: (buf) =>
resType.toObject(resType.decode(buf), {
longs: String,
enums: String,
defaults: true,
}),
};
}
const Client = grpc.makeGenericClientConstructor(serviceDef);
return new Client(address, credentials);
}
// ============================================================================
// УТИЛИТЫ
// ============================================================================
/** Промисифицированный unary gRPC-вызов */
function rpc(client, method, request, metadata) {
return new Promise((resolve, reject) => {
client[method](request, metadata, (err, response) => {
if (err) reject(err);
else resolve(response);
});
});
}
// ============================================================================
// ОСНОВНАЯ ЛОГИКА
// ============================================================================
async function tuneEmbeddings() {
console.log("🎓 Запуск дообучения embeddings...");
console.log(` Базовая модель: ${BASE_MODEL}`);
console.log(` Датасет: ${YANDEX_DATASET_ID}`);
console.log(` Название: ${MODEL_NAME}`);
console.log(` Тип embeddings: ${EMBEDDING_TYPE_API}\n`);
const sslCreds = grpc.credentials.createSsl();
const metadata = new grpc.Metadata();
metadata.add("authorization", `Api-Key ${YANDEX_API_KEY}`);
// ---------------------------------------------------------------
// Шаг 0: gRPC Server Reflection — получаем описание сервисов
// ---------------------------------------------------------------
console.log("🔍 Получение описания сервисов через gRPC Reflection...");
const reflectionClient = createReflectionClient(GRPC_ADDRESS, sslCreds);
const fdBytes = await resolveServices(reflectionClient, metadata, [
"yandex.cloud.ai.tuning.v1.TuningService",
"yandex.cloud.operation.OperationService",
]);
console.log(`✅ Получено ${fdBytes.size} proto-описаний с сервера\n`);
reflectionClient.close();
// Строим protobufjs Root
const root = buildRoot(fdBytes);
// Создаем gRPC-клиенты
const tuningClient = createServiceClient(
root,
"yandex.cloud.ai.tuning.v1.TuningService",
GRPC_ADDRESS,
sslCreds
);
const operationClient = createServiceClient(
root,
"yandex.cloud.operation.OperationService",
GRPC_ADDRESS,
sslCreds
);
// ---------------------------------------------------------------
// Шаг 1: Запускаем дообучение
// Имена полей — snake_case (как в proto)
// ---------------------------------------------------------------
console.log("📤 Отправка запроса на дообучение...");
const tuneRequest = {
base_model_uri: `emb://${YANDEX_FOLDER_ID}/${BASE_MODEL}`,
train_datasets: [{ dataset_id: YANDEX_DATASET_ID, weight: 1.0 }],
name: MODEL_NAME,
[EMBEDDING_TYPE_API]: {},
};
const operation = await rpc(tuningClient, "Tune", tuneRequest, metadata);
const operationId = operation.id;
console.log(`✅ Операция запущена: ${operationId}\n`);
// ---------------------------------------------------------------
// Шаг 2: Ожидаем завершения
// ---------------------------------------------------------------
console.log("⏳ Ожидание завершения дообучения...");
console.log(" (проверка статуса каждые 30 секунд)\n");
while (true) {
const op = await rpc(
operationClient,
"Get",
{ operation_id: operationId },
metadata
);
if (op.done) {
if (op.error) {
throw new Error(
`Ошибка дообучения: ${op.error.message || JSON.stringify(op.error)}`
);
}
console.log("✅ Дообучение завершено!\n");
// Извлекаем URI модели из ответа
if (op.response && op.response.value) {
try {
// Декодируем ответ TuneResponse
const TuneResponse = root.lookupType(
"yandex.cloud.ai.tuning.v1.TuneResponse"
);
const responseData = TuneResponse.toObject(
TuneResponse.decode(op.response.value),
{
longs: String,
enums: String,
defaults: true,
}
);
const modelUri =
responseData.target_model_uri || responseData.model_uri;
if (modelUri) {
console.log(`🎉 URI дообученной модели: ${modelUri}`);
return modelUri;
} else {
console.log("⚠️ Не удалось извлечь URI модели из ответа");
console.log("Ответ:", JSON.stringify(responseData, null, 2));
}
} catch (decodeError) {
console.log(
"⚠️ Ошибка декодирования ответа:",
decodeError.message
);
console.log("Сырой ответ:", JSON.stringify(op, null, 2));
}
} else {
console.log("⚠️ Не удалось получить URI модели из ответа");
console.log("Ответ:", JSON.stringify(op, null, 2));
}
break;
}
console.log("🔄 Статус: дообучение выполняется...");
await new Promise((r) => setTimeout(r, 30000));
}
}
// ============================================================================
// ЗАПУСК
// ============================================================================
tuneEmbeddings()
.then(() => {
console.log("\n✅ Процесс завершен успешно");
process.exit(0);
})
.catch((err) => {
console.error("\n❌ Процесс завершен с ошибкой:", err);
process.exit(1);
});
Где:
-
YANDEX_API_KEY— API-ключ для работы в AI Studio. -
YANDEX_FOLDER_ID— идентификатор каталога сервисного аккаунта. -
YANDEX_DATASET_ID— идентификатор датасета, сохраненный ранее. -
YANDEX_EMBEDDING_TYPE— тип созданного ранее датасета. Возможные значения:text_embedding_pair_params— датасет пар.text_embedding_triplet_params— датасет триплетов.
Пример ответа:
🎯 Используемый тип эмбеддингов (поле): text_embedding_triplet_params
📤 Отправка запроса на дообучение эмбеддингов...
✅ Операция запущена: ftnvcb6ifmjq********
⏳ Ожидание завершения дообучения...
🔄 Статус: выполняется...
package main
import (
"context"
"fmt"
"log"
"os"
"strings"
"time"
"github.com/jhump/protoreflect/grpcreflect"
"github.com/jhump/protoreflect/dynamic"
"github.com/jhump/protoreflect/dynamic/grpcdynamic"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
oppb "github.com/yandex-cloud/go-genproto/yandex/cloud/operation"
)
var (
YANDEX_API_KEY = os.Getenv("YANDEX_API_KEY")
YANDEX_FOLDER_ID = os.Getenv("YANDEX_FOLDER_ID")
YANDEX_DATASET_ID = os.Getenv("YANDEX_DATASET_ID")
// Тип датасета: "text_embedding_pair_params" или "text_embedding_triplet_params"
YANDEX_EMBEDDING_TYPE = os.Getenv("YANDEX_EMBEDDING_TYPE")
// Базовая модель: "text-embeddings/latest" или "text-embeddings/rc"
BASE_MODEL = getEnvOrDefault("BASE_MODEL", "text-embeddings/latest")
MODEL_NAME = getEnvOrDefault("MODEL_NAME", "train-embeddings")
)
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func main() {
// Проверка переменных окружения
if YANDEX_API_KEY == "" {
log.Fatal("❌ Ошибка: YANDEX_API_KEY не установлен")
}
if YANDEX_FOLDER_ID == "" {
log.Fatal("❌ Ошибка: YANDEX_FOLDER_ID не установлен")
}
if YANDEX_DATASET_ID == "" {
log.Fatal("❌ Ошибка: YANDEX_DATASET_ID не установлен")
}
// Добавляем тип embeddings по умолчанию
if YANDEX_EMBEDDING_TYPE == "" {
YANDEX_EMBEDDING_TYPE = "text_embedding_pair_params"
}
fmt.Println("🎓 Запуск дообучения embeddings...")
fmt.Printf(" Базовая модель: %s\n", BASE_MODEL)
fmt.Printf(" Датасет: %s\n", YANDEX_DATASET_ID)
fmt.Printf(" Название: %s\n", MODEL_NAME)
fmt.Printf(" Тип embeddings: %s\n", YANDEX_EMBEDDING_TYPE)
fmt.Println()
// Подключаемся к tuning service
creds := credentials.NewClientTLSFromCert(nil, "")
tuningConn, err := grpc.Dial("ai.api.cloud.yandex.net:443", grpc.WithTransportCredentials(creds))
if err != nil {
log.Fatalf("Ошибка подключения к tuning service: %v", err)
}
defer tuningConn.Close()
ctx := metadata.AppendToOutgoingContext(context.Background(),
"authorization", "Api-Key "+YANDEX_API_KEY)
// Используем server reflection для получения схемы
fmt.Println("🔍 Получение схемы сервиса через reflection...")
refClient := grpc_reflection_v1alpha.NewServerReflectionClient(tuningConn)
reflectCtx := grpcreflect.NewClient(ctx, refClient)
defer reflectCtx.Reset()
// Получаем описание сервиса TuningService
tuningService, err := reflectCtx.ResolveService("yandex.cloud.ai.tuning.v1.TuningService")
if err != nil {
log.Fatalf("❌ Ошибка получения TuningService: %v", err)
}
// Находим метод Tune
tuneMethod := tuningService.FindMethodByName("Tune")
if tuneMethod == nil {
log.Fatal("❌ Не удалось найти метод Tune")
}
// Создаем динамический stub
stub := grpcdynamic.NewStub(tuningConn)
// Создаем динамическое сообщение для запроса
requestMsg := dynamic.NewMessage(tuneMethod.GetInputType())
requestMsg.SetFieldByName("base_model_uri", fmt.Sprintf("emb://%s/%s", YANDEX_FOLDER_ID, BASE_MODEL))
requestMsg.SetFieldByName("name", MODEL_NAME)
// Создаем train_datasets
trainDatasetsField := tuneMethod.GetInputType().FindFieldByName("train_datasets")
datasetMsg := dynamic.NewMessage(trainDatasetsField.GetMessageType())
datasetMsg.SetFieldByName("dataset_id", YANDEX_DATASET_ID)
datasetMsg.SetFieldByName("weight", float64(1.0))
requestMsg.SetFieldByName("train_datasets", []interface{}{datasetMsg})
// Добавляем поле для типа embeddings (пустое сообщение)
embeddingField := tuneMethod.GetInputType().FindFieldByName(YANDEX_EMBEDDING_TYPE)
if embeddingField == nil {
log.Fatalf("❌ Не удалось найти поле %s в TuneRequest", YANDEX_EMBEDDING_TYPE)
}
embeddingMsg := dynamic.NewMessage(embeddingField.GetMessageType())
requestMsg.SetFieldByName(YANDEX_EMBEDDING_TYPE, embeddingMsg)
// Выполняем gRPC вызов
fmt.Println("📤 Отправка запроса на дообучение...")
responseMsg, err := stub.InvokeRpc(ctx, tuneMethod, requestMsg)
if err != nil {
// Проверяем на ошибку квоты
if strings.Contains(err.Error(), "Quota") {
log.Fatalf("❌ Превышена квота на количество дообучений: %v", err)
}
log.Fatalf("❌ Ошибка при запуске дообучения: %v", err)
}
// Преобразуем ответ в dynamic.Message
dynResponse, ok := responseMsg.(*dynamic.Message)
if !ok {
log.Fatal("❌ Не удалось преобразовать ответ в dynamic.Message")
}
// Извлекаем ID операции
operationID := dynResponse.GetFieldByName("id").(string)
fmt.Printf("✅ Операция запущена: %s\n", operationID)
fmt.Println()
// Подключаемся к operation service для проверки статуса
opConn, err := grpc.Dial("ai.api.cloud.yandex.net:443", grpc.WithTransportCredentials(creds))
if err != nil {
log.Fatalf("Ошибка подключения к operation service: %v", err)
}
defer opConn.Close()
opClient := oppb.NewOperationServiceClient(opConn)
// Ожидаем завершения
fmt.Println("⏳ Ожидание завершения дообучения...")
fmt.Println(" (проверка статуса каждые 30 секунд)")
fmt.Println()
for {
op, err := opClient.Get(ctx, &oppb.GetOperationRequest{OperationId: operationID})
if err != nil {
// Если ошибка связана с квотой, выводим понятное сообщение
if strings.Contains(err.Error(), "Quota") {
log.Fatalf("❌ Превышена квота на количество дообучений: %v", err)
}
log.Fatalf("Ошибка получения статуса: %v", err)
}
if op.Done {
if op.GetError() != nil {
errMsg := op.GetError().Message
if strings.Contains(errMsg, "Quota") {
log.Fatalf("❌ Превышена квота: %s", errMsg)
}
log.Fatalf("❌ Операция завершилась с ошибкой: %v", op.GetError())
}
fmt.Println("✅ Дообучение завершено!")
fmt.Println()
// Извлекаем URI модели через reflection
if opResponse := op.GetResponse(); opResponse != nil {
// Получаем описание TuneResponse через reflection
tuneResponseDesc, err := reflectCtx.ResolveMessage("yandex.cloud.ai.tuning.v1.TuneResponse")
if err != nil {
fmt.Printf("⚠️ Ошибка получения TuneResponse: %v\n", err)
break
}
// Декодируем ответ
tuneResponse := dynamic.NewMessage(tuneResponseDesc)
if err := tuneResponse.Unmarshal(opResponse.Value); err == nil {
// Пробуем разные варианты имени поля
modelURI := ""
if uri := tuneResponse.GetFieldByName("target_model_uri"); uri != nil {
if uriStr, ok := uri.(string); ok {
modelURI = uriStr
}
}
if modelURI == "" {
if uri := tuneResponse.GetFieldByName("targetModelUri"); uri != nil {
if uriStr, ok := uri.(string); ok {
modelURI = uriStr
}
}
}
if modelURI == "" {
if uri := tuneResponse.GetFieldByName("model_uri"); uri != nil {
if uriStr, ok := uri.(string); ok {
modelURI = uriStr
}
}
}
if modelURI != "" {
fmt.Printf("🎉 URI дообученной модели: %s\n", modelURI)
} else {
fmt.Println("⚠️ Не удалось извлечь URI модели")
fmt.Printf("Ответ: %+v\n", tuneResponse)
}
} else {
fmt.Printf("⚠️ Ошибка декодирования ответа: %v\n", err)
}
}
break
}
fmt.Println("🔄 Статус: дообучение выполняется...")
time.Sleep(30 * time.Second)
}
fmt.Println()
fmt.Println("✅ Процесс завершен успешно")
}
Где:
-
YANDEX_API_KEY— API-ключ для работы в AI Studio. -
YANDEX_FOLDER_ID— идентификатор каталога сервисного аккаунта. -
YANDEX_DATASET_ID— идентификатор датасета, сохраненный ранее. -
TuningTaskType— тип созданного ранее датасета. Возможные значения:TextEmbeddingPairParams— датасет пар.TextEmbeddingTripletParams— датасет триплетов.
Пример ответа:
🎯 Используемый тип эмбеддингов (поле): text_embedding_triplet_params
📤 Отправка запроса на дообучение эмбеддингов...
✅ Операция запущена: ftnvcb6ifmjq********
⏳ Ожидание завершения дообучения...
🔄 Статус: выполняется...
#!/bin/bash
# Переменные окружения:
# YANDEX_API_KEY — API-ключ
# YANDEX_FOLDER_ID — идентификатор каталога
# YANDEX_DATASET_ID — идентификатор готового датасета (статус Ready)
# YANDEX_EMBEDDING_TYPE — тип датасета:
# pair — датасет пар (по умолчанию)
# triplet — датасет триплетов
#
# Использование:
# export YANDEX_API_KEY="..."
# export YANDEX_FOLDER_ID="..."
# export YANDEX_DATASET_ID="..."
# export YANDEX_EMBEDDING_TYPE="triplet"
# bash embeddings-tuning-automation.sh
set -e
API_KEY="${YANDEX_API_KEY:?Укажите YANDEX_API_KEY}"
FOLDER_ID="${YANDEX_FOLDER_ID:?Укажите YANDEX_FOLDER_ID}"
DATASET_ID="${YANDEX_DATASET_ID:?Укажите YANDEX_DATASET_ID}"
# Автоопределение типа датасета через API
echo "Определение типа датасета..."
grpcurl \
-H "Authorization: Api-Key $API_KEY" \
-d "{\"dataset_id\": \"$DATASET_ID\"}" \
ai.api.cloud.yandex.net:443 \
yandex.cloud.ai.dataset.v1.DatasetService/Describe \
> /tmp/emb_dataset_info.json 2>/dev/null || true
DETECTED_TASK_TYPE=$(jq -r '.dataset.taskType // empty' /tmp/emb_dataset_info.json 2>/dev/null)
# Если автоопределение не сработало — используем переменную окружения
if [ -z "$DETECTED_TASK_TYPE" ]; then
EMBEDDING_TYPE="${YANDEX_EMBEDDING_TYPE:?Укажите YANDEX_EMBEDDING_TYPE: pair или triplet}"
else
echo "Тип датасета: $DETECTED_TASK_TYPE"
case "$DETECTED_TASK_TYPE" in
*Triplet*) EMBEDDING_TYPE="triplet" ;;
*) EMBEDDING_TYPE="pair" ;;
esac
fi
if [ "$EMBEDDING_TYPE" = "triplet" ]; then
TUNE_TASK_KEY="text_embedding_triplet_params"
else
TUNE_TASK_KEY="text_embedding_pair_params"
fi
# Шаг 1: Запустить дообучение
echo "=== Шаг 1: Запуск дообучения ($TUNE_TASK_KEY) ==="
grpcurl \
-H "Authorization: Api-Key $API_KEY" \
-d "{\"base_model_uri\": \"emb://$FOLDER_ID/text-embeddings/latest\", \"train_datasets\": [{\"dataset_id\": \"$DATASET_ID\", \"weight\": 1.0}], \"name\": \"train-embeddings\", \"$TUNE_TASK_KEY\": {}}" \
ai.api.cloud.yandex.net:443 \
yandex.cloud.ai.tuning.v1.TuningService/Tune \
> /tmp/emb_tuning_start.json
cat /tmp/emb_tuning_start.json
TUNING_OP_ID=$(jq -r '.id' /tmp/emb_tuning_start.json)
TUNING_TASK_ID=$(jq -r '.metadata.tuningTaskId' /tmp/emb_tuning_start.json)
echo "Операция дообучения: $TUNING_OP_ID"
# Шаг 2: Дождаться завершения дообучения
echo "=== Шаг 2: Ожидание завершения дообучения ==="
while true; do
grpcurl \
-H "Authorization: Api-Key $API_KEY" \
-d "{\"operation_id\": \"$TUNING_OP_ID\"}" \
ai.api.cloud.yandex.net:443 \
yandex.cloud.operation.OperationService/Get \
> /tmp/emb_tuning_status.json
DONE=$(jq -r '.done' /tmp/emb_tuning_status.json)
if [ "$DONE" = "true" ]; then
cat /tmp/emb_tuning_status.json
STATUS=$(jq -r '.response.status' /tmp/emb_tuning_status.json)
if [ "$STATUS" = "COMPLETED" ]; then
MODEL_URI=$(jq -r '.response.targetModelUri' /tmp/emb_tuning_status.json)
TUNING_TASK_ID=$(jq -r '.response.tuningTaskId' /tmp/emb_tuning_status.json)
echo "Дообучение завершено"
else
echo "Ошибка дообучения: $STATUS"
exit 1
fi
break
fi
echo "Дообучение выполняется..."
sleep 30
done
# Шаг 3: Получить ссылку на метрики TensorBoard
echo "=== Шаг 3: Ссылка на метрики ==="
grpcurl \
-H "Authorization: Api-Key $API_KEY" \
-d "{\"task_id\": \"$TUNING_TASK_ID\"}" \
ai.api.cloud.yandex.net:443 \
yandex.cloud.ai.tuning.v1.TuningService/GetMetricsUrl \
> /tmp/emb_tuning_metrics.json
cat /tmp/emb_tuning_metrics.json
echo ""
echo "URI дообученной модели: $MODEL_URI"
Где:
-
YANDEX_API_KEY— API-ключ для работы в AI Studio. -
YANDEX_FOLDER_ID— идентификатор каталога сервисного аккаунта. -
YANDEX_DATASET_ID— идентификатор датасета, сохраненный ранее. -
YANDEX_EMBEDDING_TYPE— тип созданного ранее датасета. Возможные значения:text_embedding_pair_params— датасет пар.text_embedding_triplet_params— датасет триплетов.
Пример ответа
Определение типа датасета...
Тип датасета: TextEmbeddingsTriplet
=== Шаг 1: Запуск дообучения (text_embedding_triplet_params) ===
{
"id": "ftn48a8ct********",
"createdAt": "2026-04-20T11:59:22Z",
"modifiedAt": "2026-04-20T11:59:22Z",
"metadata": {
"@type": "type.googleapis.com/yandex.cloud.ai.tuning.v1.TuningMetadata",
"status": "CREATED",
"tuningTaskId": "ftn48a8ct********","
}
}
Операция дообучения: ftn48a8ct********
=== Шаг 2: Ожидание завершения дообучения ===
Дообучение выполняется...
...
Дообучение выполняется...
{
"id": "ftn48a8ct********",
"createdAt": "2026-04-20T11:59:22Z",
"modifiedAt": "2026-04-20T12:02:37Z",
"done": true,
"metadata": {
"@type": "type.googleapis.com/yandex.cloud.ai.tuning.v1.TuningMetadata",
"status": "COMPLETED",
"tuningTaskId": "ftn48a8ct********"
},
"response": {
"@type": "type.googleapis.com/yandex.cloud.ai.tuning.v1.TuningResponse",
"status": "COMPLETED",
"targetModelUri": "emb://a1gtjhna53s********/text-embeddings/latest@tbmrjlutqnbs********",
"tuningTaskId": "ftn48a8ct********"
}
}
Дообучение завершено
=== Шаг 3: Ссылка на метрики ===
{
"loadUrl": "..."
}
URI дообученной модели: emb://a1gtjhna53s********/text-embeddings/latest@tbmrjlutqnbs********
#!/usr/bin/env python3
from __future__ import annotations
import uuid
from yandex_ai_studio_sdk import AIStudio
# Конфигурация
folder = "<идентификатор_каталога>"
token = "<API-ключ>"
type = "<тип_датасета>"
dataset_id = "<идентификатор_датасета>"
def main():
sdk = AIStudio(
folder_id=folder,
auth=token,
)
# Зададим датасет для обучения и базовую модель
train_dataset = sdk.datasets.get(dataset_id)
base_model = sdk.models.text_embeddings("text-embeddings")
# Запускаем дообучение
# Дообучение может длиться до нескольких часов
tuning_task = base_model.tune_deferred(
train_dataset, name=str(uuid.uuid4()), embeddings_tune_type=type
)
tuned_model = tuning_task.wait()
print(f"Resulting {tuned_model}")
if __name__ == "__main__":
main()
Где:
-
<идентификатор_каталога>— идентификатор каталога, в котором создан сервисный аккаунт. -
<API-ключ>— API-ключ сервисного аккаунта, полученный ранее и необходимый для аутентификации в API.В примерах используется аутентификация с помощью API-ключа. Yandex AI Studio SDK также поддерживает аутентификацию с помощью IAM-токена и OAuth-токена. Подробнее см. в разделе Аутентификация в Yandex AI Studio SDK.
-
<тип_датасета>— тип созданного ранее датасета. Возможные значения:pair(пара) иtriplet(триплет). -
<идентификатор_датасета>— идентификатор созданного ранее датасета для обучения.
Дообучение модели может занять до 1 суток в зависимости от объема датасета и загрузки системы.
Используйте полученный URI дообученной модели (значение поля uri) при обращении к ней.
Метрики дообучения доступны в формате TensorBoard. Загруженный файл можно открыть, например, в проекте Yandex DataSphere:
metrics_url = new_model.get_metrics_url()
download_tensorboard(metrics_url)
См. также
- Дообучение моделей
- Дообучить модель генерации текста
- Дообучить модель классификации текста
- Больше примеров SDK доступно в репозитории на GitHub